janus-llm 1.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +9 -1
- janus/__main__.py +4 -0
- janus/_tests/test_cli.py +128 -0
- janus/_tests/test_translate.py +49 -7
- janus/cli.py +530 -46
- janus/converter.py +50 -19
- janus/embedding/_tests/test_collections.py +2 -8
- janus/embedding/_tests/test_database.py +32 -0
- janus/embedding/_tests/test_vectorize.py +9 -4
- janus/embedding/collections.py +49 -6
- janus/embedding/embedding_models_info.py +130 -0
- janus/embedding/vectorize.py +53 -62
- janus/language/_tests/__init__.py +0 -0
- janus/language/_tests/test_combine.py +62 -0
- janus/language/_tests/test_splitter.py +16 -0
- janus/language/binary/_tests/test_binary.py +16 -1
- janus/language/binary/binary.py +10 -3
- janus/language/block.py +31 -30
- janus/language/combine.py +26 -34
- janus/language/mumps/_tests/test_mumps.py +2 -2
- janus/language/mumps/mumps.py +93 -9
- janus/language/naive/__init__.py +4 -0
- janus/language/naive/basic_splitter.py +14 -0
- janus/language/naive/chunk_splitter.py +26 -0
- janus/language/naive/registry.py +13 -0
- janus/language/naive/simple_ast.py +18 -0
- janus/language/naive/tag_splitter.py +61 -0
- janus/language/splitter.py +168 -74
- janus/language/treesitter/_tests/test_treesitter.py +19 -14
- janus/language/treesitter/treesitter.py +37 -13
- janus/llm/model_callbacks.py +177 -0
- janus/llm/models_info.py +165 -72
- janus/metrics/__init__.py +8 -0
- janus/metrics/_tests/__init__.py +0 -0
- janus/metrics/_tests/reference.py +2 -0
- janus/metrics/_tests/target.py +2 -0
- janus/metrics/_tests/test_bleu.py +56 -0
- janus/metrics/_tests/test_chrf.py +67 -0
- janus/metrics/_tests/test_file_pairing.py +59 -0
- janus/metrics/_tests/test_llm.py +91 -0
- janus/metrics/_tests/test_reading.py +28 -0
- janus/metrics/_tests/test_rouge_score.py +65 -0
- janus/metrics/_tests/test_similarity_score.py +23 -0
- janus/metrics/_tests/test_treesitter_metrics.py +110 -0
- janus/metrics/bleu.py +66 -0
- janus/metrics/chrf.py +55 -0
- janus/metrics/cli.py +7 -0
- janus/metrics/complexity_metrics.py +208 -0
- janus/metrics/file_pairing.py +113 -0
- janus/metrics/llm_metrics.py +202 -0
- janus/metrics/metric.py +466 -0
- janus/metrics/reading.py +70 -0
- janus/metrics/rouge_score.py +96 -0
- janus/metrics/similarity.py +53 -0
- janus/metrics/splitting.py +38 -0
- janus/parsers/_tests/__init__.py +0 -0
- janus/parsers/_tests/test_code_parser.py +32 -0
- janus/parsers/code_parser.py +24 -253
- janus/parsers/doc_parser.py +169 -0
- janus/parsers/eval_parser.py +80 -0
- janus/parsers/reqs_parser.py +72 -0
- janus/prompts/prompt.py +103 -30
- janus/translate.py +636 -111
- janus/utils/_tests/__init__.py +0 -0
- janus/utils/_tests/test_logger.py +67 -0
- janus/utils/_tests/test_progress.py +20 -0
- janus/utils/enums.py +56 -3
- janus/utils/progress.py +56 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/METADATA +27 -11
- janus_llm-2.0.1.dist-info/RECORD +94 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/WHEEL +1 -1
- janus_llm-1.0.0.dist-info/RECORD +0 -48
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/LICENSE +0 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/entry_points.txt +0 -0
janus/cli.py
CHANGED
@@ -1,34 +1,50 @@
|
|
1
1
|
import json
|
2
|
+
import logging
|
2
3
|
import os
|
3
4
|
from pathlib import Path
|
4
5
|
from typing import Optional
|
5
6
|
|
6
7
|
import click
|
7
8
|
import typer
|
9
|
+
from pydantic import AnyHttpUrl
|
8
10
|
from rich import print
|
9
11
|
from rich.console import Console
|
10
12
|
from rich.prompt import Confirm
|
11
13
|
from typing_extensions import Annotated
|
12
14
|
|
15
|
+
from janus.language.naive.registry import CUSTOM_SPLITTERS
|
16
|
+
|
13
17
|
from .embedding.collections import Collections
|
14
18
|
from .embedding.database import ChromaEmbeddingDatabase
|
19
|
+
from .embedding.embedding_models_info import (
|
20
|
+
EMBEDDING_COST_PER_MODEL,
|
21
|
+
EMBEDDING_MODEL_CONFIG_DIR,
|
22
|
+
EMBEDDING_TOKEN_LIMITS,
|
23
|
+
EmbeddingModelType,
|
24
|
+
)
|
15
25
|
from .embedding.vectorize import ChromaDBVectorizer
|
16
26
|
from .language.binary import BinarySplitter
|
17
27
|
from .language.mumps import MumpsSplitter
|
18
28
|
from .language.treesitter import TreeSitterSplitter
|
19
|
-
from .llm.
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
29
|
+
from .llm.model_callbacks import COST_PER_1K_TOKENS
|
30
|
+
from .llm.models_info import MODEL_CONFIG_DIR, MODEL_TYPE_CONSTRUCTORS, TOKEN_LIMITS
|
31
|
+
from .metrics.cli import evaluate
|
32
|
+
from .translate import (
|
33
|
+
PARSER_TYPES,
|
34
|
+
DiagramGenerator,
|
35
|
+
Documenter,
|
36
|
+
MadLibsDocumenter,
|
37
|
+
MultiDocumenter,
|
38
|
+
RequirementsDocumenter,
|
39
|
+
Translator,
|
24
40
|
)
|
25
|
-
from .
|
26
|
-
from .translate import Translator
|
27
|
-
from .utils.enums import CUSTOM_SPLITTERS, LANGUAGES
|
41
|
+
from .utils.enums import LANGUAGES
|
28
42
|
from .utils.logger import create_logger
|
29
43
|
|
30
|
-
|
44
|
+
httpx_logger = logging.getLogger("httpx")
|
45
|
+
httpx_logger.setLevel(logging.WARNING)
|
31
46
|
|
47
|
+
log = create_logger(__name__)
|
32
48
|
homedir = Path.home().expanduser()
|
33
49
|
|
34
50
|
janus_dir = homedir / ".janus"
|
@@ -43,6 +59,17 @@ if not db_file.exists():
|
|
43
59
|
with open(db_file, "r") as f:
|
44
60
|
db_loc = f.read()
|
45
61
|
|
62
|
+
collections_config_file = Path(db_loc) / "collections.json"
|
63
|
+
|
64
|
+
|
65
|
+
def get_collections_config():
|
66
|
+
if collections_config_file.exists():
|
67
|
+
with open(collections_config_file, "r") as f:
|
68
|
+
config = json.load(f)
|
69
|
+
else:
|
70
|
+
config = {}
|
71
|
+
return config
|
72
|
+
|
46
73
|
|
47
74
|
app = typer.Typer(
|
48
75
|
help="Choose a command",
|
@@ -51,6 +78,7 @@ app = typer.Typer(
|
|
51
78
|
context_settings={"help_option_names": ["-h", "--help"]},
|
52
79
|
)
|
53
80
|
|
81
|
+
|
54
82
|
db = typer.Typer(
|
55
83
|
help="Database commands",
|
56
84
|
add_completion=False,
|
@@ -64,6 +92,43 @@ llm = typer.Typer(
|
|
64
92
|
context_settings={"help_option_names": ["-h", "--help"]},
|
65
93
|
)
|
66
94
|
|
95
|
+
embedding = typer.Typer(
|
96
|
+
help="Embedding model commands",
|
97
|
+
add_completion=False,
|
98
|
+
no_args_is_help=True,
|
99
|
+
context_settings={"help_option_names": ["-h", "--help"]},
|
100
|
+
)
|
101
|
+
|
102
|
+
|
103
|
+
def version_callback(value: bool) -> None:
|
104
|
+
if value:
|
105
|
+
from . import __version__ as version
|
106
|
+
|
107
|
+
print(f"Janus CLI [blue]v{version}[/blue]")
|
108
|
+
raise typer.Exit()
|
109
|
+
|
110
|
+
|
111
|
+
@app.callback()
|
112
|
+
def common(
|
113
|
+
ctx: typer.Context,
|
114
|
+
version: bool = typer.Option(
|
115
|
+
None,
|
116
|
+
"--version",
|
117
|
+
"-v",
|
118
|
+
callback=version_callback,
|
119
|
+
help="Print the version and exit.",
|
120
|
+
),
|
121
|
+
) -> None:
|
122
|
+
"""A function for getting the app version
|
123
|
+
|
124
|
+
This will call the version_callback function to print the version and exit.
|
125
|
+
|
126
|
+
Arguments:
|
127
|
+
ctx: The typer context
|
128
|
+
version: A boolean flag for the version
|
129
|
+
"""
|
130
|
+
pass
|
131
|
+
|
67
132
|
|
68
133
|
@app.command(
|
69
134
|
help="Translate code from one language to another using an LLM.",
|
@@ -73,41 +138,53 @@ def translate(
|
|
73
138
|
input_dir: Annotated[
|
74
139
|
Path,
|
75
140
|
typer.Option(
|
141
|
+
"--input",
|
142
|
+
"-i",
|
76
143
|
help="The directory containing the source code to be translated. "
|
77
|
-
"The files should all be in one flat directory."
|
144
|
+
"The files should all be in one flat directory.",
|
78
145
|
),
|
79
146
|
],
|
80
147
|
source_lang: Annotated[
|
81
148
|
str,
|
82
149
|
typer.Option(
|
150
|
+
"--source-language",
|
151
|
+
"-s",
|
83
152
|
help="The language of the source code.",
|
84
153
|
click_type=click.Choice(sorted(LANGUAGES)),
|
85
154
|
),
|
86
155
|
],
|
87
156
|
output_dir: Annotated[
|
88
157
|
Path,
|
89
|
-
typer.Option(
|
158
|
+
typer.Option(
|
159
|
+
"--output", "-o", help="The directory to store the translated code in."
|
160
|
+
),
|
90
161
|
],
|
91
162
|
target_lang: Annotated[
|
92
163
|
str,
|
93
164
|
typer.Option(
|
165
|
+
"--target-language",
|
166
|
+
"-t",
|
94
167
|
help="The desired output language to translate the source code to. The "
|
95
168
|
"format can follow a 'language-version' syntax. Use 'text' to get plaintext"
|
96
169
|
"results as returned by the LLM. Examples: `python-3.10`, `mumps`, `java-10`,"
|
97
|
-
"text."
|
170
|
+
"text.",
|
98
171
|
),
|
99
172
|
],
|
100
173
|
llm_name: Annotated[
|
101
174
|
str,
|
102
175
|
typer.Option(
|
176
|
+
"--llm",
|
177
|
+
"-L",
|
103
178
|
help="The custom name of the model set with 'janus llm add'.",
|
104
179
|
),
|
105
|
-
] = "gpt-3.5-turbo",
|
180
|
+
] = "gpt-3.5-turbo-0125",
|
106
181
|
max_prompts: Annotated[
|
107
182
|
int,
|
108
183
|
typer.Option(
|
184
|
+
"--max-prompts",
|
185
|
+
"-m",
|
109
186
|
help="The maximum number of times to prompt a model on one functional block "
|
110
|
-
"before exiting the application. This is to prevent wasting too much money."
|
187
|
+
"before exiting the application. This is to prevent wasting too much money.",
|
111
188
|
),
|
112
189
|
] = 10,
|
113
190
|
overwrite: Annotated[
|
@@ -119,18 +196,22 @@ def translate(
|
|
119
196
|
] = False,
|
120
197
|
temp: Annotated[
|
121
198
|
float,
|
122
|
-
typer.Option(help="Sampling temperature.", min=0, max=2),
|
199
|
+
typer.Option("--temperature", "-T", help="Sampling temperature.", min=0, max=2),
|
123
200
|
] = 0.7,
|
124
201
|
prompt_template: Annotated[
|
125
202
|
str,
|
126
203
|
typer.Option(
|
204
|
+
"--prompt-template",
|
205
|
+
"-p",
|
127
206
|
help="Name of the Janus prompt template directory or "
|
128
|
-
"path to a directory containing those template files."
|
207
|
+
"path to a directory containing those template files.",
|
129
208
|
),
|
130
209
|
] = "simple",
|
131
210
|
parser_type: Annotated[
|
132
211
|
str,
|
133
212
|
typer.Option(
|
213
|
+
"--parser",
|
214
|
+
"-P",
|
134
215
|
click_type=click.Choice(sorted(PARSER_TYPES)),
|
135
216
|
help="The type of parser to use.",
|
136
217
|
),
|
@@ -144,6 +225,24 @@ def translate(
|
|
144
225
|
"collection with the name provided.",
|
145
226
|
),
|
146
227
|
] = None,
|
228
|
+
custom_splitter: Annotated[
|
229
|
+
Optional[str],
|
230
|
+
typer.Option(
|
231
|
+
"-cs",
|
232
|
+
"--custom-splitter",
|
233
|
+
help="Name of custom splitter to use",
|
234
|
+
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
235
|
+
),
|
236
|
+
] = None,
|
237
|
+
max_tokens: Annotated[
|
238
|
+
int,
|
239
|
+
typer.Option(
|
240
|
+
"--max-tokens",
|
241
|
+
"-M",
|
242
|
+
help="The maximum number of tokens the model will take in. "
|
243
|
+
"If unspecificed, model's default max will be used.",
|
244
|
+
),
|
245
|
+
] = None,
|
147
246
|
):
|
148
247
|
try:
|
149
248
|
target_language, target_version = target_lang.split("-")
|
@@ -156,12 +255,7 @@ def translate(
|
|
156
255
|
raise ValueError
|
157
256
|
|
158
257
|
model_arguments = dict(temperature=temp)
|
159
|
-
|
160
|
-
if collection is not None:
|
161
|
-
_check_collection(collection, input_dir)
|
162
|
-
db = ChromaEmbeddingDatabase(db_loc)
|
163
|
-
collections = Collections(db)
|
164
|
-
output_collection = collections.get_or_create(collection)
|
258
|
+
collections_config = get_collections_config()
|
165
259
|
translator = Translator(
|
166
260
|
model=llm_name,
|
167
261
|
model_arguments=model_arguments,
|
@@ -169,21 +263,269 @@ def translate(
|
|
169
263
|
target_language=target_language,
|
170
264
|
target_version=target_version,
|
171
265
|
max_prompts=max_prompts,
|
266
|
+
max_tokens=max_tokens,
|
172
267
|
prompt_template=prompt_template,
|
173
268
|
parser_type=parser_type,
|
269
|
+
db_path=db_loc,
|
270
|
+
db_config=collections_config,
|
271
|
+
custom_splitter=custom_splitter,
|
174
272
|
)
|
175
|
-
translator.translate(input_dir, output_dir, overwrite,
|
273
|
+
translator.translate(input_dir, output_dir, overwrite, collection)
|
274
|
+
|
275
|
+
|
276
|
+
@app.command(
|
277
|
+
help="Document input code using an LLM.",
|
278
|
+
no_args_is_help=True,
|
279
|
+
)
|
280
|
+
def document(
|
281
|
+
input_dir: Annotated[
|
282
|
+
Path,
|
283
|
+
typer.Option(
|
284
|
+
"--input",
|
285
|
+
"-i",
|
286
|
+
help="The directory containing the source code to be translated. "
|
287
|
+
"The files should all be in one flat directory.",
|
288
|
+
),
|
289
|
+
],
|
290
|
+
language: Annotated[
|
291
|
+
str,
|
292
|
+
typer.Option(
|
293
|
+
"--language",
|
294
|
+
"-l",
|
295
|
+
help="The language of the source code.",
|
296
|
+
click_type=click.Choice(sorted(LANGUAGES)),
|
297
|
+
),
|
298
|
+
],
|
299
|
+
output_dir: Annotated[
|
300
|
+
Path,
|
301
|
+
typer.Option(
|
302
|
+
"--output-dir", "-o", help="The directory to store the translated code in."
|
303
|
+
),
|
304
|
+
],
|
305
|
+
llm_name: Annotated[
|
306
|
+
str,
|
307
|
+
typer.Option(
|
308
|
+
"--llm",
|
309
|
+
"-L",
|
310
|
+
help="The custom name of the model set with 'janus llm add'.",
|
311
|
+
),
|
312
|
+
] = "gpt-3.5-turbo-0125",
|
313
|
+
max_prompts: Annotated[
|
314
|
+
int,
|
315
|
+
typer.Option(
|
316
|
+
"--max-prompts",
|
317
|
+
"-m",
|
318
|
+
help="The maximum number of times to prompt a model on one functional block "
|
319
|
+
"before exiting the application. This is to prevent wasting too much money.",
|
320
|
+
),
|
321
|
+
] = 10,
|
322
|
+
overwrite: Annotated[
|
323
|
+
bool,
|
324
|
+
typer.Option(
|
325
|
+
"--overwrite/--preserve",
|
326
|
+
help="Whether to overwrite existing files in the output directory",
|
327
|
+
),
|
328
|
+
] = False,
|
329
|
+
doc_mode: Annotated[
|
330
|
+
str,
|
331
|
+
typer.Option(
|
332
|
+
"--doc-mode",
|
333
|
+
"-d",
|
334
|
+
help="The documentation mode.",
|
335
|
+
click_type=click.Choice(["madlibs", "summary", "multidoc", "requirements"]),
|
336
|
+
),
|
337
|
+
] = "madlibs",
|
338
|
+
comments_per_request: Annotated[
|
339
|
+
int,
|
340
|
+
typer.Option(
|
341
|
+
"--comments-per-request",
|
342
|
+
"-rc",
|
343
|
+
help="The maximum number of comments to generate per request when using "
|
344
|
+
"MadLibs documentation mode.",
|
345
|
+
),
|
346
|
+
] = None,
|
347
|
+
drop_comments: Annotated[
|
348
|
+
bool,
|
349
|
+
typer.Option(
|
350
|
+
"--drop-comments/--keep-comments",
|
351
|
+
help="Whether to drop or keep comments in the code sent to the LLM",
|
352
|
+
),
|
353
|
+
] = False,
|
354
|
+
temperature: Annotated[
|
355
|
+
float,
|
356
|
+
typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
|
357
|
+
] = 0.7,
|
358
|
+
collection: Annotated[
|
359
|
+
str,
|
360
|
+
typer.Option(
|
361
|
+
"--collection",
|
362
|
+
"-c",
|
363
|
+
help="If set, will put the translated result into a Chroma DB "
|
364
|
+
"collection with the name provided.",
|
365
|
+
),
|
366
|
+
] = None,
|
367
|
+
custom_splitter: Annotated[
|
368
|
+
Optional[str],
|
369
|
+
typer.Option(
|
370
|
+
"-cs",
|
371
|
+
"--custom-splitter",
|
372
|
+
help="Name of custom splitter to use",
|
373
|
+
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
374
|
+
),
|
375
|
+
] = None,
|
376
|
+
max_tokens: Annotated[
|
377
|
+
int,
|
378
|
+
typer.Option(
|
379
|
+
"--max-tokens",
|
380
|
+
"-M",
|
381
|
+
help="The maximum number of tokens the model will take in. "
|
382
|
+
"If unspecificed, model's default max will be used.",
|
383
|
+
),
|
384
|
+
] = None,
|
385
|
+
):
|
386
|
+
model_arguments = dict(temperature=temperature)
|
387
|
+
collections_config = get_collections_config()
|
388
|
+
kwargs = dict(
|
389
|
+
model=llm_name,
|
390
|
+
model_arguments=model_arguments,
|
391
|
+
source_language=language,
|
392
|
+
max_prompts=max_prompts,
|
393
|
+
max_tokens=max_tokens,
|
394
|
+
db_path=db_loc,
|
395
|
+
db_config=collections_config,
|
396
|
+
custom_splitter=custom_splitter,
|
397
|
+
)
|
398
|
+
if doc_mode == "madlibs":
|
399
|
+
documenter = MadLibsDocumenter(
|
400
|
+
comments_per_request=comments_per_request, **kwargs
|
401
|
+
)
|
402
|
+
elif doc_mode == "multidoc":
|
403
|
+
documenter = MultiDocumenter(drop_comments=drop_comments, **kwargs)
|
404
|
+
elif doc_mode == "requirements":
|
405
|
+
documenter = RequirementsDocumenter(drop_comments=drop_comments, **kwargs)
|
406
|
+
else:
|
407
|
+
documenter = Documenter(drop_comments=drop_comments, **kwargs)
|
408
|
+
|
409
|
+
documenter.translate(input_dir, output_dir, overwrite, collection)
|
410
|
+
|
411
|
+
|
412
|
+
@app.command(
|
413
|
+
help="Diagram input code using an LLM.",
|
414
|
+
no_args_is_help=True,
|
415
|
+
)
|
416
|
+
def diagram(
|
417
|
+
input_dir: Annotated[
|
418
|
+
Path,
|
419
|
+
typer.Option(
|
420
|
+
"--input",
|
421
|
+
"-i",
|
422
|
+
help="The directory containing the source code to be translated. "
|
423
|
+
"The files should all be in one flat directory.",
|
424
|
+
),
|
425
|
+
],
|
426
|
+
language: Annotated[
|
427
|
+
str,
|
428
|
+
typer.Option(
|
429
|
+
"--language",
|
430
|
+
"-l",
|
431
|
+
help="The language of the source code.",
|
432
|
+
click_type=click.Choice(sorted(LANGUAGES)),
|
433
|
+
),
|
434
|
+
],
|
435
|
+
output_dir: Annotated[
|
436
|
+
Path,
|
437
|
+
typer.Option(
|
438
|
+
"--output-dir", "-o", help="The directory to store the translated code in."
|
439
|
+
),
|
440
|
+
],
|
441
|
+
llm_name: Annotated[
|
442
|
+
str,
|
443
|
+
typer.Option(
|
444
|
+
"--llm",
|
445
|
+
"-L",
|
446
|
+
help="The custom name of the model set with 'janus llm add'.",
|
447
|
+
),
|
448
|
+
] = "gpt-3.5-turbo-0125",
|
449
|
+
max_prompts: Annotated[
|
450
|
+
int,
|
451
|
+
typer.Option(
|
452
|
+
"--max-prompts",
|
453
|
+
"-m",
|
454
|
+
help="The maximum number of times to prompt a model on one functional block "
|
455
|
+
"before exiting the application. This is to prevent wasting too much money.",
|
456
|
+
),
|
457
|
+
] = 10,
|
458
|
+
overwrite: Annotated[
|
459
|
+
bool,
|
460
|
+
typer.Option(
|
461
|
+
"--overwrite/--preserve",
|
462
|
+
help="Whether to overwrite existing files in the output directory",
|
463
|
+
),
|
464
|
+
] = False,
|
465
|
+
temperature: Annotated[
|
466
|
+
float,
|
467
|
+
typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
|
468
|
+
] = 0.7,
|
469
|
+
collection: Annotated[
|
470
|
+
str,
|
471
|
+
typer.Option(
|
472
|
+
"--collection",
|
473
|
+
"-c",
|
474
|
+
help="If set, will put the translated result into a Chroma DB "
|
475
|
+
"collection with the name provided.",
|
476
|
+
),
|
477
|
+
] = None,
|
478
|
+
diagram_type: Annotated[
|
479
|
+
str,
|
480
|
+
typer.Option(
|
481
|
+
"--diagram-type", "-dg", help="Diagram type to generate in PLANTUML"
|
482
|
+
),
|
483
|
+
] = "Activity",
|
484
|
+
add_documentation: Annotated[
|
485
|
+
bool,
|
486
|
+
typer.Option(
|
487
|
+
"--add-documentation/--no-documentation",
|
488
|
+
"-ad",
|
489
|
+
help="Whether to use documentation in generation",
|
490
|
+
),
|
491
|
+
] = False,
|
492
|
+
custom_splitter: Annotated[
|
493
|
+
Optional[str],
|
494
|
+
typer.Option(
|
495
|
+
"-cs",
|
496
|
+
"--custom-splitter",
|
497
|
+
help="Name of custom splitter to use",
|
498
|
+
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
499
|
+
),
|
500
|
+
] = None,
|
501
|
+
):
|
502
|
+
model_arguments = dict(temperature=temperature)
|
503
|
+
collections_config = get_collections_config()
|
504
|
+
diagram_generator = DiagramGenerator(
|
505
|
+
model=llm_name,
|
506
|
+
model_arguments=model_arguments,
|
507
|
+
source_language=language,
|
508
|
+
max_prompts=max_prompts,
|
509
|
+
db_path=db_loc,
|
510
|
+
db_config=collections_config,
|
511
|
+
diagram_type=diagram_type,
|
512
|
+
add_documentation=add_documentation,
|
513
|
+
custom_splitter=custom_splitter,
|
514
|
+
)
|
515
|
+
diagram_generator.translate(input_dir, output_dir, overwrite, collection)
|
176
516
|
|
177
517
|
|
178
518
|
@db.command("init", help="Connect to or create a database.")
|
179
519
|
def db_init(
|
180
|
-
path: Annotated[
|
181
|
-
|
182
|
-
),
|
520
|
+
path: Annotated[
|
521
|
+
str, typer.Option("--path", "-p", help="The path to the database file.")
|
522
|
+
] = str(janus_dir / "chroma.db"),
|
183
523
|
url: Annotated[
|
184
524
|
str,
|
185
525
|
typer.Option(
|
186
|
-
|
526
|
+
"--url",
|
527
|
+
"-u",
|
528
|
+
help="The URL of the database if the database is running externally.",
|
187
529
|
),
|
188
530
|
] = "",
|
189
531
|
) -> None:
|
@@ -219,7 +561,7 @@ def db_ls(
|
|
219
561
|
] = None,
|
220
562
|
peek: Annotated[
|
221
563
|
Optional[int],
|
222
|
-
typer.Option(help="Peek at N entries for a specific collection."),
|
564
|
+
typer.Option("--peek", "-p", help="Peek at N entries for a specific collection."),
|
223
565
|
] = None,
|
224
566
|
) -> None:
|
225
567
|
"""List the current database's collections"""
|
@@ -256,17 +598,24 @@ def db_ls(
|
|
256
598
|
@db.command("add", help="Add a collection to the current database.")
|
257
599
|
def db_add(
|
258
600
|
collection_name: Annotated[str, typer.Argument(help="The name of the collection.")],
|
601
|
+
model_name: Annotated[str, typer.Argument(help="The name of the embedding model.")],
|
259
602
|
input_dir: Annotated[
|
260
603
|
str,
|
261
|
-
typer.Option(
|
604
|
+
typer.Option(
|
605
|
+
"--input",
|
606
|
+
"-i",
|
607
|
+
help="The directory containing the source code to be added.",
|
608
|
+
),
|
262
609
|
] = "./",
|
263
610
|
input_lang: Annotated[
|
264
|
-
str, typer.Option(help="The language of the source code.")
|
611
|
+
str, typer.Option("--language", "-l", help="The language of the source code.")
|
265
612
|
] = "python",
|
266
613
|
max_tokens: Annotated[
|
267
614
|
int,
|
268
615
|
typer.Option(
|
269
|
-
|
616
|
+
"--max-tokens",
|
617
|
+
"-m",
|
618
|
+
help="The maximum number of tokens for each chunk of input source code.",
|
270
619
|
),
|
271
620
|
] = 4096,
|
272
621
|
) -> None:
|
@@ -274,13 +623,16 @@ def db_add(
|
|
274
623
|
|
275
624
|
Arguments:
|
276
625
|
collection_name: The name of the collection to add
|
626
|
+
model_name: The name of the embedding model to use
|
277
627
|
input_dir: The directory containing the source code to be added
|
278
628
|
input_lang: The language of the source code
|
629
|
+
max_tokens: The maximum number of tokens for each chunk of input source code
|
279
630
|
"""
|
280
631
|
# TODO: import factory
|
281
632
|
console = Console()
|
282
633
|
|
283
634
|
added_to = _check_collection(collection_name, input_dir)
|
635
|
+
collections_config = get_collections_config()
|
284
636
|
|
285
637
|
with console.status(
|
286
638
|
f"Adding collection: [bold salmon]{collection_name}[/bold salmon]",
|
@@ -288,13 +640,13 @@ def db_add(
|
|
288
640
|
):
|
289
641
|
vectorizer_factory = ChromaDBVectorizer()
|
290
642
|
vectorizer = vectorizer_factory.create_vectorizer(
|
291
|
-
|
292
|
-
path=db_loc,
|
293
|
-
max_tokens=max_tokens,
|
643
|
+
path=db_loc, config=collections_config
|
294
644
|
)
|
645
|
+
vectorizer.get_or_create_collection(collection_name, model_name=model_name)
|
295
646
|
input_dir = Path(input_dir)
|
296
|
-
|
297
|
-
|
647
|
+
suffix = LANGUAGES[input_lang]["suffix"]
|
648
|
+
source_glob = f"**/*.{suffix}"
|
649
|
+
input_paths = [p for p in input_dir.rglob(source_glob)]
|
298
650
|
if input_lang in CUSTOM_SPLITTERS:
|
299
651
|
if input_lang == "mumps":
|
300
652
|
splitter = MumpsSplitter(
|
@@ -311,15 +663,35 @@ def db_add(
|
|
311
663
|
)
|
312
664
|
for input_path in input_paths:
|
313
665
|
input_block = splitter.split(input_path)
|
314
|
-
vectorizer.
|
666
|
+
vectorizer.add_nodes_recursively(
|
315
667
|
input_block,
|
316
668
|
collection_name,
|
317
669
|
input_path.name,
|
318
670
|
)
|
671
|
+
total_files = len([p for p in Path.glob(input_dir, "**/*") if not p.is_dir()])
|
319
672
|
if added_to:
|
320
|
-
print(
|
673
|
+
print(
|
674
|
+
f"\nAdded to [bold salmon1]{collection_name}[/bold salmon1]:\n"
|
675
|
+
f" Embedding Model: [green]{model_name}[/green]\n"
|
676
|
+
f" Input Directory: {input_dir.absolute()}\n"
|
677
|
+
f" {input_lang.capitalize()} [green]*.{suffix}[/green] Files: "
|
678
|
+
f"{len(input_paths)}\n"
|
679
|
+
" Other Files (skipped): "
|
680
|
+
f"{total_files - len(input_paths)}\n"
|
681
|
+
)
|
682
|
+
[p for p in Path.glob(input_dir, f"**/*.{suffix}") if not p.is_dir()]
|
321
683
|
else:
|
322
|
-
print(
|
684
|
+
print(
|
685
|
+
f"\nCreated [bold salmon1]{collection_name}[/bold salmon1]:\n"
|
686
|
+
f" Embedding Model: '{model_name}'\n"
|
687
|
+
f" Input Directory: {input_dir.absolute()}\n"
|
688
|
+
f" {input_lang.capitalize()} [green]*.{suffix}[/green] Files: "
|
689
|
+
f"{len(input_paths)}\n"
|
690
|
+
" Other Files (skipped): "
|
691
|
+
f"{total_files - len(input_paths)}\n"
|
692
|
+
)
|
693
|
+
with open(collections_config_file, "w") as f:
|
694
|
+
json.dump(vectorizer.config, f, indent=2)
|
323
695
|
|
324
696
|
|
325
697
|
@db.command(
|
@@ -327,17 +699,28 @@ def db_add(
|
|
327
699
|
help="Remove a collection from the database.",
|
328
700
|
)
|
329
701
|
def db_rm(
|
330
|
-
collection_name: Annotated[str, typer.Argument(help="The name of the collection.")]
|
702
|
+
collection_name: Annotated[str, typer.Argument(help="The name of the collection.")],
|
703
|
+
confirm: Annotated[
|
704
|
+
bool,
|
705
|
+
typer.Option(
|
706
|
+
"--yes",
|
707
|
+
"-y",
|
708
|
+
help="Confirm the removal of the collection.",
|
709
|
+
),
|
710
|
+
],
|
331
711
|
) -> None:
|
332
712
|
"""Remove a collection from the database
|
333
713
|
|
334
714
|
Arguments:
|
335
715
|
collection_name: The name of the collection to remove
|
336
716
|
"""
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
717
|
+
if not confirm:
|
718
|
+
delete = Confirm.ask(
|
719
|
+
f"\nAre you sure you want to [bold red]remove[/bold red] "
|
720
|
+
f"[bold salmon1]{collection_name}[/bold salmon1]?",
|
721
|
+
)
|
722
|
+
else:
|
723
|
+
delete = True
|
341
724
|
if not delete:
|
342
725
|
raise typer.Abort()
|
343
726
|
db = ChromaEmbeddingDatabase(db_loc)
|
@@ -425,16 +808,115 @@ def llm_add(
|
|
425
808
|
"model_cost": {"input": in_cost, "output": out_cost},
|
426
809
|
}
|
427
810
|
elif model_type == "OpenAI":
|
428
|
-
model_name = typer.prompt("Enter the model name", default="gpt-3.5-turbo")
|
811
|
+
model_name = typer.prompt("Enter the model name", default="gpt-3.5-turbo-0125")
|
429
812
|
params = dict(
|
430
813
|
model_name=model_name,
|
431
814
|
temperature=0.7,
|
432
815
|
n=1,
|
433
816
|
)
|
434
817
|
max_tokens = TOKEN_LIMITS[model_name]
|
435
|
-
model_cost =
|
818
|
+
model_cost = COST_PER_1K_TOKENS[model_name]
|
819
|
+
cfg = {
|
820
|
+
"model_type": model_type,
|
821
|
+
"model_args": params,
|
822
|
+
"token_limit": max_tokens,
|
823
|
+
"model_cost": model_cost,
|
824
|
+
}
|
825
|
+
else:
|
826
|
+
raise ValueError(f"Unknown model type {model_type}")
|
827
|
+
with open(model_cfg, "w") as f:
|
828
|
+
json.dump(cfg, f, indent=2)
|
829
|
+
print(f"Model config written to {model_cfg}")
|
830
|
+
|
831
|
+
|
832
|
+
@embedding.command("add", help="Add an embedding model config to janus")
|
833
|
+
def embedding_add(
|
834
|
+
model_name: Annotated[
|
835
|
+
str, typer.Argument(help="The user's custom name for the model")
|
836
|
+
],
|
837
|
+
model_type: Annotated[
|
838
|
+
str,
|
839
|
+
typer.Option(
|
840
|
+
"--type",
|
841
|
+
"-t",
|
842
|
+
help="The type of the model",
|
843
|
+
click_type=click.Choice(list(val.value for val in EmbeddingModelType)),
|
844
|
+
),
|
845
|
+
] = "OpenAI",
|
846
|
+
):
|
847
|
+
if not EMBEDDING_MODEL_CONFIG_DIR.exists():
|
848
|
+
EMBEDDING_MODEL_CONFIG_DIR.mkdir(parents=True)
|
849
|
+
model_cfg = EMBEDDING_MODEL_CONFIG_DIR / f"{model_name}.json"
|
850
|
+
if model_type in EmbeddingModelType.HuggingFaceInferenceAPI.values:
|
851
|
+
hf = typer.style("HuggingFaceInferenceAPI", fg="yellow")
|
852
|
+
url = typer.prompt(f"Enter the {hf} model's URL", type=str, value_proc=AnyHttpUrl)
|
853
|
+
api_model_name = typer.prompt("Enter the model's name", type=str, default="")
|
854
|
+
api_key = typer.prompt("Enter the API key", type=str, default="")
|
855
|
+
max_tokens = typer.prompt(
|
856
|
+
"Enter the model's maximum tokens", default=8191, type=int
|
857
|
+
)
|
858
|
+
in_cost = typer.prompt("Enter the cost per input token", default=0, type=float)
|
859
|
+
out_cost = typer.prompt("Enter the cost per output token", default=0, type=float)
|
860
|
+
params = dict(
|
861
|
+
model_name=api_model_name,
|
862
|
+
api_key=api_key,
|
863
|
+
)
|
864
|
+
cfg = {
|
865
|
+
"model_type": model_type,
|
866
|
+
"model_identifier": str(url),
|
867
|
+
"model_args": params,
|
868
|
+
"token_limit": max_tokens,
|
869
|
+
"model_cost": {"input": in_cost, "output": out_cost},
|
870
|
+
}
|
871
|
+
elif model_type in EmbeddingModelType.HuggingFaceLocal.values:
|
872
|
+
hf = typer.style("HuggingFace", fg="yellow")
|
873
|
+
model_id = typer.prompt(
|
874
|
+
f"Enter the {hf} model ID",
|
875
|
+
default="sentence-transformers/all-MiniLM-L6-v2",
|
876
|
+
type=str,
|
877
|
+
)
|
878
|
+
cache_folder = str(
|
879
|
+
Path(
|
880
|
+
typer.prompt(
|
881
|
+
"Enter the model's cache folder",
|
882
|
+
default=EMBEDDING_MODEL_CONFIG_DIR / "cache",
|
883
|
+
type=str,
|
884
|
+
)
|
885
|
+
)
|
886
|
+
)
|
887
|
+
max_tokens = typer.prompt(
|
888
|
+
"Enter the model's maximum tokens", default=8191, type=int
|
889
|
+
)
|
890
|
+
params = dict(
|
891
|
+
cache_folder=str(cache_folder),
|
892
|
+
)
|
893
|
+
cfg = {
|
894
|
+
"model_type": model_type,
|
895
|
+
"model_identifier": model_id,
|
896
|
+
"model_args": params,
|
897
|
+
"token_limit": max_tokens,
|
898
|
+
"model_cost": {"input": 0, "output": 0},
|
899
|
+
}
|
900
|
+
elif model_type in EmbeddingModelType.OpenAI.values:
|
901
|
+
available_models = list(EMBEDDING_COST_PER_MODEL.keys())
|
902
|
+
|
903
|
+
open_ai = typer.style("OpenAI", fg="green")
|
904
|
+
prompt = f"Enter the {open_ai} model name"
|
905
|
+
|
906
|
+
model_name = typer.prompt(
|
907
|
+
prompt,
|
908
|
+
default="text-embedding-3-small",
|
909
|
+
type=click.types.Choice(available_models),
|
910
|
+
show_choices=False,
|
911
|
+
)
|
912
|
+
params = dict(
|
913
|
+
model=model_name,
|
914
|
+
)
|
915
|
+
max_tokens = EMBEDDING_TOKEN_LIMITS[model_name]
|
916
|
+
model_cost = EMBEDDING_COST_PER_MODEL[model_name]
|
436
917
|
cfg = {
|
437
918
|
"model_type": model_type,
|
919
|
+
"model_identifier": model_name,
|
438
920
|
"model_args": params,
|
439
921
|
"token_limit": max_tokens,
|
440
922
|
"model_cost": model_cost,
|
@@ -448,6 +930,8 @@ def llm_add(
|
|
448
930
|
|
449
931
|
app.add_typer(db, name="db")
|
450
932
|
app.add_typer(llm, name="llm")
|
933
|
+
app.add_typer(evaluate, name="evaluate")
|
934
|
+
app.add_typer(embedding, name="embedding")
|
451
935
|
|
452
936
|
|
453
937
|
if __name__ == "__main__":
|