janus-llm: janus_llm-1.0.0-py3-none-any.whl → janus_llm-2.0.1-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (74):
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +130 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +19 -14
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +165 -72
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/METADATA +27 -11
  70. janus_llm-2.0.1.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/entry_points.txt +0 -0
janus/cli.py CHANGED
@@ -1,34 +1,50 @@
1
1
  import json
2
+ import logging
2
3
  import os
3
4
  from pathlib import Path
4
5
  from typing import Optional
5
6
 
6
7
  import click
7
8
  import typer
9
+ from pydantic import AnyHttpUrl
8
10
  from rich import print
9
11
  from rich.console import Console
10
12
  from rich.prompt import Confirm
11
13
  from typing_extensions import Annotated
12
14
 
15
+ from janus.language.naive.registry import CUSTOM_SPLITTERS
16
+
13
17
  from .embedding.collections import Collections
14
18
  from .embedding.database import ChromaEmbeddingDatabase
19
+ from .embedding.embedding_models_info import (
20
+ EMBEDDING_COST_PER_MODEL,
21
+ EMBEDDING_MODEL_CONFIG_DIR,
22
+ EMBEDDING_TOKEN_LIMITS,
23
+ EmbeddingModelType,
24
+ )
15
25
  from .embedding.vectorize import ChromaDBVectorizer
16
26
  from .language.binary import BinarySplitter
17
27
  from .language.mumps import MumpsSplitter
18
28
  from .language.treesitter import TreeSitterSplitter
19
- from .llm.models_info import (
20
- COST_PER_MODEL,
21
- MODEL_CONFIG_DIR,
22
- MODEL_TYPE_CONSTRUCTORS,
23
- TOKEN_LIMITS,
29
+ from .llm.model_callbacks import COST_PER_1K_TOKENS
30
+ from .llm.models_info import MODEL_CONFIG_DIR, MODEL_TYPE_CONSTRUCTORS, TOKEN_LIMITS
31
+ from .metrics.cli import evaluate
32
+ from .translate import (
33
+ PARSER_TYPES,
34
+ DiagramGenerator,
35
+ Documenter,
36
+ MadLibsDocumenter,
37
+ MultiDocumenter,
38
+ RequirementsDocumenter,
39
+ Translator,
24
40
  )
25
- from .parsers.code_parser import PARSER_TYPES
26
- from .translate import Translator
27
- from .utils.enums import CUSTOM_SPLITTERS, LANGUAGES
41
+ from .utils.enums import LANGUAGES
28
42
  from .utils.logger import create_logger
29
43
 
30
- log = create_logger(__name__)
44
+ httpx_logger = logging.getLogger("httpx")
45
+ httpx_logger.setLevel(logging.WARNING)
31
46
 
47
+ log = create_logger(__name__)
32
48
  homedir = Path.home().expanduser()
33
49
 
34
50
  janus_dir = homedir / ".janus"
@@ -43,6 +59,17 @@ if not db_file.exists():
43
59
  with open(db_file, "r") as f:
44
60
  db_loc = f.read()
45
61
 
62
+ collections_config_file = Path(db_loc) / "collections.json"
63
+
64
+
65
+ def get_collections_config():
66
+ if collections_config_file.exists():
67
+ with open(collections_config_file, "r") as f:
68
+ config = json.load(f)
69
+ else:
70
+ config = {}
71
+ return config
72
+
46
73
 
47
74
  app = typer.Typer(
48
75
  help="Choose a command",
@@ -51,6 +78,7 @@ app = typer.Typer(
51
78
  context_settings={"help_option_names": ["-h", "--help"]},
52
79
  )
53
80
 
81
+
54
82
  db = typer.Typer(
55
83
  help="Database commands",
56
84
  add_completion=False,
@@ -64,6 +92,43 @@ llm = typer.Typer(
64
92
  context_settings={"help_option_names": ["-h", "--help"]},
65
93
  )
66
94
 
95
+ embedding = typer.Typer(
96
+ help="Embedding model commands",
97
+ add_completion=False,
98
+ no_args_is_help=True,
99
+ context_settings={"help_option_names": ["-h", "--help"]},
100
+ )
101
+
102
+
103
+ def version_callback(value: bool) -> None:
104
+ if value:
105
+ from . import __version__ as version
106
+
107
+ print(f"Janus CLI [blue]v{version}[/blue]")
108
+ raise typer.Exit()
109
+
110
+
111
+ @app.callback()
112
+ def common(
113
+ ctx: typer.Context,
114
+ version: bool = typer.Option(
115
+ None,
116
+ "--version",
117
+ "-v",
118
+ callback=version_callback,
119
+ help="Print the version and exit.",
120
+ ),
121
+ ) -> None:
122
+ """A function for getting the app version
123
+
124
+ This will call the version_callback function to print the version and exit.
125
+
126
+ Arguments:
127
+ ctx: The typer context
128
+ version: A boolean flag for the version
129
+ """
130
+ pass
131
+
67
132
 
68
133
  @app.command(
69
134
  help="Translate code from one language to another using an LLM.",
@@ -73,41 +138,53 @@ def translate(
73
138
  input_dir: Annotated[
74
139
  Path,
75
140
  typer.Option(
141
+ "--input",
142
+ "-i",
76
143
  help="The directory containing the source code to be translated. "
77
- "The files should all be in one flat directory."
144
+ "The files should all be in one flat directory.",
78
145
  ),
79
146
  ],
80
147
  source_lang: Annotated[
81
148
  str,
82
149
  typer.Option(
150
+ "--source-language",
151
+ "-s",
83
152
  help="The language of the source code.",
84
153
  click_type=click.Choice(sorted(LANGUAGES)),
85
154
  ),
86
155
  ],
87
156
  output_dir: Annotated[
88
157
  Path,
89
- typer.Option(help="The directory to store the translated code in."),
158
+ typer.Option(
159
+ "--output", "-o", help="The directory to store the translated code in."
160
+ ),
90
161
  ],
91
162
  target_lang: Annotated[
92
163
  str,
93
164
  typer.Option(
165
+ "--target-language",
166
+ "-t",
94
167
  help="The desired output language to translate the source code to. The "
95
168
  "format can follow a 'language-version' syntax. Use 'text' to get plaintext"
96
169
  "results as returned by the LLM. Examples: `python-3.10`, `mumps`, `java-10`,"
97
- "text."
170
+ "text.",
98
171
  ),
99
172
  ],
100
173
  llm_name: Annotated[
101
174
  str,
102
175
  typer.Option(
176
+ "--llm",
177
+ "-L",
103
178
  help="The custom name of the model set with 'janus llm add'.",
104
179
  ),
105
- ] = "gpt-3.5-turbo",
180
+ ] = "gpt-3.5-turbo-0125",
106
181
  max_prompts: Annotated[
107
182
  int,
108
183
  typer.Option(
184
+ "--max-prompts",
185
+ "-m",
109
186
  help="The maximum number of times to prompt a model on one functional block "
110
- "before exiting the application. This is to prevent wasting too much money."
187
+ "before exiting the application. This is to prevent wasting too much money.",
111
188
  ),
112
189
  ] = 10,
113
190
  overwrite: Annotated[
@@ -119,18 +196,22 @@ def translate(
119
196
  ] = False,
120
197
  temp: Annotated[
121
198
  float,
122
- typer.Option(help="Sampling temperature.", min=0, max=2),
199
+ typer.Option("--temperature", "-T", help="Sampling temperature.", min=0, max=2),
123
200
  ] = 0.7,
124
201
  prompt_template: Annotated[
125
202
  str,
126
203
  typer.Option(
204
+ "--prompt-template",
205
+ "-p",
127
206
  help="Name of the Janus prompt template directory or "
128
- "path to a directory containing those template files."
207
+ "path to a directory containing those template files.",
129
208
  ),
130
209
  ] = "simple",
131
210
  parser_type: Annotated[
132
211
  str,
133
212
  typer.Option(
213
+ "--parser",
214
+ "-P",
134
215
  click_type=click.Choice(sorted(PARSER_TYPES)),
135
216
  help="The type of parser to use.",
136
217
  ),
@@ -144,6 +225,24 @@ def translate(
144
225
  "collection with the name provided.",
145
226
  ),
146
227
  ] = None,
228
+ custom_splitter: Annotated[
229
+ Optional[str],
230
+ typer.Option(
231
+ "-cs",
232
+ "--custom-splitter",
233
+ help="Name of custom splitter to use",
234
+ click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
235
+ ),
236
+ ] = None,
237
+ max_tokens: Annotated[
238
+ int,
239
+ typer.Option(
240
+ "--max-tokens",
241
+ "-M",
242
+ help="The maximum number of tokens the model will take in. "
243
+ "If unspecificed, model's default max will be used.",
244
+ ),
245
+ ] = None,
147
246
  ):
148
247
  try:
149
248
  target_language, target_version = target_lang.split("-")
@@ -156,12 +255,7 @@ def translate(
156
255
  raise ValueError
157
256
 
158
257
  model_arguments = dict(temperature=temp)
159
- output_collection = None
160
- if collection is not None:
161
- _check_collection(collection, input_dir)
162
- db = ChromaEmbeddingDatabase(db_loc)
163
- collections = Collections(db)
164
- output_collection = collections.get_or_create(collection)
258
+ collections_config = get_collections_config()
165
259
  translator = Translator(
166
260
  model=llm_name,
167
261
  model_arguments=model_arguments,
@@ -169,21 +263,269 @@ def translate(
169
263
  target_language=target_language,
170
264
  target_version=target_version,
171
265
  max_prompts=max_prompts,
266
+ max_tokens=max_tokens,
172
267
  prompt_template=prompt_template,
173
268
  parser_type=parser_type,
269
+ db_path=db_loc,
270
+ db_config=collections_config,
271
+ custom_splitter=custom_splitter,
174
272
  )
175
- translator.translate(input_dir, output_dir, overwrite, output_collection)
273
+ translator.translate(input_dir, output_dir, overwrite, collection)
274
+
275
+
276
+ @app.command(
277
+ help="Document input code using an LLM.",
278
+ no_args_is_help=True,
279
+ )
280
+ def document(
281
+ input_dir: Annotated[
282
+ Path,
283
+ typer.Option(
284
+ "--input",
285
+ "-i",
286
+ help="The directory containing the source code to be translated. "
287
+ "The files should all be in one flat directory.",
288
+ ),
289
+ ],
290
+ language: Annotated[
291
+ str,
292
+ typer.Option(
293
+ "--language",
294
+ "-l",
295
+ help="The language of the source code.",
296
+ click_type=click.Choice(sorted(LANGUAGES)),
297
+ ),
298
+ ],
299
+ output_dir: Annotated[
300
+ Path,
301
+ typer.Option(
302
+ "--output-dir", "-o", help="The directory to store the translated code in."
303
+ ),
304
+ ],
305
+ llm_name: Annotated[
306
+ str,
307
+ typer.Option(
308
+ "--llm",
309
+ "-L",
310
+ help="The custom name of the model set with 'janus llm add'.",
311
+ ),
312
+ ] = "gpt-3.5-turbo-0125",
313
+ max_prompts: Annotated[
314
+ int,
315
+ typer.Option(
316
+ "--max-prompts",
317
+ "-m",
318
+ help="The maximum number of times to prompt a model on one functional block "
319
+ "before exiting the application. This is to prevent wasting too much money.",
320
+ ),
321
+ ] = 10,
322
+ overwrite: Annotated[
323
+ bool,
324
+ typer.Option(
325
+ "--overwrite/--preserve",
326
+ help="Whether to overwrite existing files in the output directory",
327
+ ),
328
+ ] = False,
329
+ doc_mode: Annotated[
330
+ str,
331
+ typer.Option(
332
+ "--doc-mode",
333
+ "-d",
334
+ help="The documentation mode.",
335
+ click_type=click.Choice(["madlibs", "summary", "multidoc", "requirements"]),
336
+ ),
337
+ ] = "madlibs",
338
+ comments_per_request: Annotated[
339
+ int,
340
+ typer.Option(
341
+ "--comments-per-request",
342
+ "-rc",
343
+ help="The maximum number of comments to generate per request when using "
344
+ "MadLibs documentation mode.",
345
+ ),
346
+ ] = None,
347
+ drop_comments: Annotated[
348
+ bool,
349
+ typer.Option(
350
+ "--drop-comments/--keep-comments",
351
+ help="Whether to drop or keep comments in the code sent to the LLM",
352
+ ),
353
+ ] = False,
354
+ temperature: Annotated[
355
+ float,
356
+ typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
357
+ ] = 0.7,
358
+ collection: Annotated[
359
+ str,
360
+ typer.Option(
361
+ "--collection",
362
+ "-c",
363
+ help="If set, will put the translated result into a Chroma DB "
364
+ "collection with the name provided.",
365
+ ),
366
+ ] = None,
367
+ custom_splitter: Annotated[
368
+ Optional[str],
369
+ typer.Option(
370
+ "-cs",
371
+ "--custom-splitter",
372
+ help="Name of custom splitter to use",
373
+ click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
374
+ ),
375
+ ] = None,
376
+ max_tokens: Annotated[
377
+ int,
378
+ typer.Option(
379
+ "--max-tokens",
380
+ "-M",
381
+ help="The maximum number of tokens the model will take in. "
382
+ "If unspecificed, model's default max will be used.",
383
+ ),
384
+ ] = None,
385
+ ):
386
+ model_arguments = dict(temperature=temperature)
387
+ collections_config = get_collections_config()
388
+ kwargs = dict(
389
+ model=llm_name,
390
+ model_arguments=model_arguments,
391
+ source_language=language,
392
+ max_prompts=max_prompts,
393
+ max_tokens=max_tokens,
394
+ db_path=db_loc,
395
+ db_config=collections_config,
396
+ custom_splitter=custom_splitter,
397
+ )
398
+ if doc_mode == "madlibs":
399
+ documenter = MadLibsDocumenter(
400
+ comments_per_request=comments_per_request, **kwargs
401
+ )
402
+ elif doc_mode == "multidoc":
403
+ documenter = MultiDocumenter(drop_comments=drop_comments, **kwargs)
404
+ elif doc_mode == "requirements":
405
+ documenter = RequirementsDocumenter(drop_comments=drop_comments, **kwargs)
406
+ else:
407
+ documenter = Documenter(drop_comments=drop_comments, **kwargs)
408
+
409
+ documenter.translate(input_dir, output_dir, overwrite, collection)
410
+
411
+
412
+ @app.command(
413
+ help="Diagram input code using an LLM.",
414
+ no_args_is_help=True,
415
+ )
416
+ def diagram(
417
+ input_dir: Annotated[
418
+ Path,
419
+ typer.Option(
420
+ "--input",
421
+ "-i",
422
+ help="The directory containing the source code to be translated. "
423
+ "The files should all be in one flat directory.",
424
+ ),
425
+ ],
426
+ language: Annotated[
427
+ str,
428
+ typer.Option(
429
+ "--language",
430
+ "-l",
431
+ help="The language of the source code.",
432
+ click_type=click.Choice(sorted(LANGUAGES)),
433
+ ),
434
+ ],
435
+ output_dir: Annotated[
436
+ Path,
437
+ typer.Option(
438
+ "--output-dir", "-o", help="The directory to store the translated code in."
439
+ ),
440
+ ],
441
+ llm_name: Annotated[
442
+ str,
443
+ typer.Option(
444
+ "--llm",
445
+ "-L",
446
+ help="The custom name of the model set with 'janus llm add'.",
447
+ ),
448
+ ] = "gpt-3.5-turbo-0125",
449
+ max_prompts: Annotated[
450
+ int,
451
+ typer.Option(
452
+ "--max-prompts",
453
+ "-m",
454
+ help="The maximum number of times to prompt a model on one functional block "
455
+ "before exiting the application. This is to prevent wasting too much money.",
456
+ ),
457
+ ] = 10,
458
+ overwrite: Annotated[
459
+ bool,
460
+ typer.Option(
461
+ "--overwrite/--preserve",
462
+ help="Whether to overwrite existing files in the output directory",
463
+ ),
464
+ ] = False,
465
+ temperature: Annotated[
466
+ float,
467
+ typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
468
+ ] = 0.7,
469
+ collection: Annotated[
470
+ str,
471
+ typer.Option(
472
+ "--collection",
473
+ "-c",
474
+ help="If set, will put the translated result into a Chroma DB "
475
+ "collection with the name provided.",
476
+ ),
477
+ ] = None,
478
+ diagram_type: Annotated[
479
+ str,
480
+ typer.Option(
481
+ "--diagram-type", "-dg", help="Diagram type to generate in PLANTUML"
482
+ ),
483
+ ] = "Activity",
484
+ add_documentation: Annotated[
485
+ bool,
486
+ typer.Option(
487
+ "--add-documentation/--no-documentation",
488
+ "-ad",
489
+ help="Whether to use documentation in generation",
490
+ ),
491
+ ] = False,
492
+ custom_splitter: Annotated[
493
+ Optional[str],
494
+ typer.Option(
495
+ "-cs",
496
+ "--custom-splitter",
497
+ help="Name of custom splitter to use",
498
+ click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
499
+ ),
500
+ ] = None,
501
+ ):
502
+ model_arguments = dict(temperature=temperature)
503
+ collections_config = get_collections_config()
504
+ diagram_generator = DiagramGenerator(
505
+ model=llm_name,
506
+ model_arguments=model_arguments,
507
+ source_language=language,
508
+ max_prompts=max_prompts,
509
+ db_path=db_loc,
510
+ db_config=collections_config,
511
+ diagram_type=diagram_type,
512
+ add_documentation=add_documentation,
513
+ custom_splitter=custom_splitter,
514
+ )
515
+ diagram_generator.translate(input_dir, output_dir, overwrite, collection)
176
516
 
177
517
 
178
518
  @db.command("init", help="Connect to or create a database.")
179
519
  def db_init(
180
- path: Annotated[str, typer.Option(help="The path to the database file.")] = str(
181
- janus_dir / "chroma.db"
182
- ),
520
+ path: Annotated[
521
+ str, typer.Option("--path", "-p", help="The path to the database file.")
522
+ ] = str(janus_dir / "chroma.db"),
183
523
  url: Annotated[
184
524
  str,
185
525
  typer.Option(
186
- help="The URL of the database if the database is running externally."
526
+ "--url",
527
+ "-u",
528
+ help="The URL of the database if the database is running externally.",
187
529
  ),
188
530
  ] = "",
189
531
  ) -> None:
@@ -219,7 +561,7 @@ def db_ls(
219
561
  ] = None,
220
562
  peek: Annotated[
221
563
  Optional[int],
222
- typer.Option(help="Peek at N entries for a specific collection."),
564
+ typer.Option("--peek", "-p", help="Peek at N entries for a specific collection."),
223
565
  ] = None,
224
566
  ) -> None:
225
567
  """List the current database's collections"""
@@ -256,17 +598,24 @@ def db_ls(
256
598
  @db.command("add", help="Add a collection to the current database.")
257
599
  def db_add(
258
600
  collection_name: Annotated[str, typer.Argument(help="The name of the collection.")],
601
+ model_name: Annotated[str, typer.Argument(help="The name of the embedding model.")],
259
602
  input_dir: Annotated[
260
603
  str,
261
- typer.Option(help="The directory containing the source code to be added."),
604
+ typer.Option(
605
+ "--input",
606
+ "-i",
607
+ help="The directory containing the source code to be added.",
608
+ ),
262
609
  ] = "./",
263
610
  input_lang: Annotated[
264
- str, typer.Option(help="The language of the source code.")
611
+ str, typer.Option("--language", "-l", help="The language of the source code.")
265
612
  ] = "python",
266
613
  max_tokens: Annotated[
267
614
  int,
268
615
  typer.Option(
269
- help="The maximum number of tokens for each chunk of input source code."
616
+ "--max-tokens",
617
+ "-m",
618
+ help="The maximum number of tokens for each chunk of input source code.",
270
619
  ),
271
620
  ] = 4096,
272
621
  ) -> None:
@@ -274,13 +623,16 @@ def db_add(
274
623
 
275
624
  Arguments:
276
625
  collection_name: The name of the collection to add
626
+ model_name: The name of the embedding model to use
277
627
  input_dir: The directory containing the source code to be added
278
628
  input_lang: The language of the source code
629
+ max_tokens: The maximum number of tokens for each chunk of input source code
279
630
  """
280
631
  # TODO: import factory
281
632
  console = Console()
282
633
 
283
634
  added_to = _check_collection(collection_name, input_dir)
635
+ collections_config = get_collections_config()
284
636
 
285
637
  with console.status(
286
638
  f"Adding collection: [bold salmon]{collection_name}[/bold salmon]",
@@ -288,13 +640,13 @@ def db_add(
288
640
  ):
289
641
  vectorizer_factory = ChromaDBVectorizer()
290
642
  vectorizer = vectorizer_factory.create_vectorizer(
291
- source_language=input_lang,
292
- path=db_loc,
293
- max_tokens=max_tokens,
643
+ path=db_loc, config=collections_config
294
644
  )
645
+ vectorizer.get_or_create_collection(collection_name, model_name=model_name)
295
646
  input_dir = Path(input_dir)
296
- source_glob = f"**/*.{LANGUAGES[input_lang]['suffix']}"
297
- input_paths = input_dir.rglob(source_glob)
647
+ suffix = LANGUAGES[input_lang]["suffix"]
648
+ source_glob = f"**/*.{suffix}"
649
+ input_paths = [p for p in input_dir.rglob(source_glob)]
298
650
  if input_lang in CUSTOM_SPLITTERS:
299
651
  if input_lang == "mumps":
300
652
  splitter = MumpsSplitter(
@@ -311,15 +663,35 @@ def db_add(
311
663
  )
312
664
  for input_path in input_paths:
313
665
  input_block = splitter.split(input_path)
314
- vectorizer._add_nodes_recursively(
666
+ vectorizer.add_nodes_recursively(
315
667
  input_block,
316
668
  collection_name,
317
669
  input_path.name,
318
670
  )
671
+ total_files = len([p for p in Path.glob(input_dir, "**/*") if not p.is_dir()])
319
672
  if added_to:
320
- print(f"Added to collection [bold salmon1]{collection_name}[/bold salmon1]")
673
+ print(
674
+ f"\nAdded to [bold salmon1]{collection_name}[/bold salmon1]:\n"
675
+ f" Embedding Model: [green]{model_name}[/green]\n"
676
+ f" Input Directory: {input_dir.absolute()}\n"
677
+ f" {input_lang.capitalize()} [green]*.{suffix}[/green] Files: "
678
+ f"{len(input_paths)}\n"
679
+ " Other Files (skipped): "
680
+ f"{total_files - len(input_paths)}\n"
681
+ )
682
+ [p for p in Path.glob(input_dir, f"**/*.{suffix}") if not p.is_dir()]
321
683
  else:
322
- print(f"Created collection [bold salmon1]{collection_name}[/bold salmon1]")
684
+ print(
685
+ f"\nCreated [bold salmon1]{collection_name}[/bold salmon1]:\n"
686
+ f" Embedding Model: '{model_name}'\n"
687
+ f" Input Directory: {input_dir.absolute()}\n"
688
+ f" {input_lang.capitalize()} [green]*.{suffix}[/green] Files: "
689
+ f"{len(input_paths)}\n"
690
+ " Other Files (skipped): "
691
+ f"{total_files - len(input_paths)}\n"
692
+ )
693
+ with open(collections_config_file, "w") as f:
694
+ json.dump(vectorizer.config, f, indent=2)
323
695
 
324
696
 
325
697
  @db.command(
@@ -327,17 +699,28 @@ def db_add(
327
699
  help="Remove a collection from the database.",
328
700
  )
329
701
  def db_rm(
330
- collection_name: Annotated[str, typer.Argument(help="The name of the collection.")]
702
+ collection_name: Annotated[str, typer.Argument(help="The name of the collection.")],
703
+ confirm: Annotated[
704
+ bool,
705
+ typer.Option(
706
+ "--yes",
707
+ "-y",
708
+ help="Confirm the removal of the collection.",
709
+ ),
710
+ ],
331
711
  ) -> None:
332
712
  """Remove a collection from the database
333
713
 
334
714
  Arguments:
335
715
  collection_name: The name of the collection to remove
336
716
  """
337
- delete = Confirm.ask(
338
- f"\nAre you sure you want to [bold red]remove[/bold red] "
339
- f"[bold salmon1]{collection_name}[/bold salmon1]?",
340
- )
717
+ if not confirm:
718
+ delete = Confirm.ask(
719
+ f"\nAre you sure you want to [bold red]remove[/bold red] "
720
+ f"[bold salmon1]{collection_name}[/bold salmon1]?",
721
+ )
722
+ else:
723
+ delete = True
341
724
  if not delete:
342
725
  raise typer.Abort()
343
726
  db = ChromaEmbeddingDatabase(db_loc)
@@ -425,16 +808,115 @@ def llm_add(
425
808
  "model_cost": {"input": in_cost, "output": out_cost},
426
809
  }
427
810
  elif model_type == "OpenAI":
428
- model_name = typer.prompt("Enter the model name", default="gpt-3.5-turbo")
811
+ model_name = typer.prompt("Enter the model name", default="gpt-3.5-turbo-0125")
429
812
  params = dict(
430
813
  model_name=model_name,
431
814
  temperature=0.7,
432
815
  n=1,
433
816
  )
434
817
  max_tokens = TOKEN_LIMITS[model_name]
435
- model_cost = COST_PER_MODEL[model_name]
818
+ model_cost = COST_PER_1K_TOKENS[model_name]
819
+ cfg = {
820
+ "model_type": model_type,
821
+ "model_args": params,
822
+ "token_limit": max_tokens,
823
+ "model_cost": model_cost,
824
+ }
825
+ else:
826
+ raise ValueError(f"Unknown model type {model_type}")
827
+ with open(model_cfg, "w") as f:
828
+ json.dump(cfg, f, indent=2)
829
+ print(f"Model config written to {model_cfg}")
830
+
831
+
832
+ @embedding.command("add", help="Add an embedding model config to janus")
833
+ def embedding_add(
834
+ model_name: Annotated[
835
+ str, typer.Argument(help="The user's custom name for the model")
836
+ ],
837
+ model_type: Annotated[
838
+ str,
839
+ typer.Option(
840
+ "--type",
841
+ "-t",
842
+ help="The type of the model",
843
+ click_type=click.Choice(list(val.value for val in EmbeddingModelType)),
844
+ ),
845
+ ] = "OpenAI",
846
+ ):
847
+ if not EMBEDDING_MODEL_CONFIG_DIR.exists():
848
+ EMBEDDING_MODEL_CONFIG_DIR.mkdir(parents=True)
849
+ model_cfg = EMBEDDING_MODEL_CONFIG_DIR / f"{model_name}.json"
850
+ if model_type in EmbeddingModelType.HuggingFaceInferenceAPI.values:
851
+ hf = typer.style("HuggingFaceInferenceAPI", fg="yellow")
852
+ url = typer.prompt(f"Enter the {hf} model's URL", type=str, value_proc=AnyHttpUrl)
853
+ api_model_name = typer.prompt("Enter the model's name", type=str, default="")
854
+ api_key = typer.prompt("Enter the API key", type=str, default="")
855
+ max_tokens = typer.prompt(
856
+ "Enter the model's maximum tokens", default=8191, type=int
857
+ )
858
+ in_cost = typer.prompt("Enter the cost per input token", default=0, type=float)
859
+ out_cost = typer.prompt("Enter the cost per output token", default=0, type=float)
860
+ params = dict(
861
+ model_name=api_model_name,
862
+ api_key=api_key,
863
+ )
864
+ cfg = {
865
+ "model_type": model_type,
866
+ "model_identifier": str(url),
867
+ "model_args": params,
868
+ "token_limit": max_tokens,
869
+ "model_cost": {"input": in_cost, "output": out_cost},
870
+ }
871
+ elif model_type in EmbeddingModelType.HuggingFaceLocal.values:
872
+ hf = typer.style("HuggingFace", fg="yellow")
873
+ model_id = typer.prompt(
874
+ f"Enter the {hf} model ID",
875
+ default="sentence-transformers/all-MiniLM-L6-v2",
876
+ type=str,
877
+ )
878
+ cache_folder = str(
879
+ Path(
880
+ typer.prompt(
881
+ "Enter the model's cache folder",
882
+ default=EMBEDDING_MODEL_CONFIG_DIR / "cache",
883
+ type=str,
884
+ )
885
+ )
886
+ )
887
+ max_tokens = typer.prompt(
888
+ "Enter the model's maximum tokens", default=8191, type=int
889
+ )
890
+ params = dict(
891
+ cache_folder=str(cache_folder),
892
+ )
893
+ cfg = {
894
+ "model_type": model_type,
895
+ "model_identifier": model_id,
896
+ "model_args": params,
897
+ "token_limit": max_tokens,
898
+ "model_cost": {"input": 0, "output": 0},
899
+ }
900
+ elif model_type in EmbeddingModelType.OpenAI.values:
901
+ available_models = list(EMBEDDING_COST_PER_MODEL.keys())
902
+
903
+ open_ai = typer.style("OpenAI", fg="green")
904
+ prompt = f"Enter the {open_ai} model name"
905
+
906
+ model_name = typer.prompt(
907
+ prompt,
908
+ default="text-embedding-3-small",
909
+ type=click.types.Choice(available_models),
910
+ show_choices=False,
911
+ )
912
+ params = dict(
913
+ model=model_name,
914
+ )
915
+ max_tokens = EMBEDDING_TOKEN_LIMITS[model_name]
916
+ model_cost = EMBEDDING_COST_PER_MODEL[model_name]
436
917
  cfg = {
437
918
  "model_type": model_type,
919
+ "model_identifier": model_name,
438
920
  "model_args": params,
439
921
  "token_limit": max_tokens,
440
922
  "model_cost": model_cost,
@@ -448,6 +930,8 @@ def llm_add(
448
930
 
449
931
  app.add_typer(db, name="db")
450
932
  app.add_typer(llm, name="llm")
933
+ app.add_typer(evaluate, name="evaluate")
934
+ app.add_typer(embedding, name="embedding")
451
935
 
452
936
 
453
937
  if __name__ == "__main__":