janus-llm 1.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +130 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +19 -14
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +165 -72
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/METADATA +27 -11
  70. janus_llm-2.0.1.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/entry_points.txt +0 -0
janus/cli.py CHANGED
@@ -1,34 +1,50 @@
1
1
  import json
2
+ import logging
2
3
  import os
3
4
  from pathlib import Path
4
5
  from typing import Optional
5
6
 
6
7
  import click
7
8
  import typer
9
+ from pydantic import AnyHttpUrl
8
10
  from rich import print
9
11
  from rich.console import Console
10
12
  from rich.prompt import Confirm
11
13
  from typing_extensions import Annotated
12
14
 
15
+ from janus.language.naive.registry import CUSTOM_SPLITTERS
16
+
13
17
  from .embedding.collections import Collections
14
18
  from .embedding.database import ChromaEmbeddingDatabase
19
+ from .embedding.embedding_models_info import (
20
+ EMBEDDING_COST_PER_MODEL,
21
+ EMBEDDING_MODEL_CONFIG_DIR,
22
+ EMBEDDING_TOKEN_LIMITS,
23
+ EmbeddingModelType,
24
+ )
15
25
  from .embedding.vectorize import ChromaDBVectorizer
16
26
  from .language.binary import BinarySplitter
17
27
  from .language.mumps import MumpsSplitter
18
28
  from .language.treesitter import TreeSitterSplitter
19
- from .llm.models_info import (
20
- COST_PER_MODEL,
21
- MODEL_CONFIG_DIR,
22
- MODEL_TYPE_CONSTRUCTORS,
23
- TOKEN_LIMITS,
29
+ from .llm.model_callbacks import COST_PER_1K_TOKENS
30
+ from .llm.models_info import MODEL_CONFIG_DIR, MODEL_TYPE_CONSTRUCTORS, TOKEN_LIMITS
31
+ from .metrics.cli import evaluate
32
+ from .translate import (
33
+ PARSER_TYPES,
34
+ DiagramGenerator,
35
+ Documenter,
36
+ MadLibsDocumenter,
37
+ MultiDocumenter,
38
+ RequirementsDocumenter,
39
+ Translator,
24
40
  )
25
- from .parsers.code_parser import PARSER_TYPES
26
- from .translate import Translator
27
- from .utils.enums import CUSTOM_SPLITTERS, LANGUAGES
41
+ from .utils.enums import LANGUAGES
28
42
  from .utils.logger import create_logger
29
43
 
30
- log = create_logger(__name__)
44
+ httpx_logger = logging.getLogger("httpx")
45
+ httpx_logger.setLevel(logging.WARNING)
31
46
 
47
+ log = create_logger(__name__)
32
48
  homedir = Path.home().expanduser()
33
49
 
34
50
  janus_dir = homedir / ".janus"
@@ -43,6 +59,17 @@ if not db_file.exists():
43
59
  with open(db_file, "r") as f:
44
60
  db_loc = f.read()
45
61
 
62
+ collections_config_file = Path(db_loc) / "collections.json"
63
+
64
+
65
def get_collections_config():
    """Load the collections configuration stored next to the database.

    Returns:
        The parsed JSON contents of `collections_config_file`, or an empty
        dict when that file does not exist.
    """
    if not collections_config_file.exists():
        return {}
    with open(collections_config_file, "r") as config_handle:
        return json.load(config_handle)
72
+
46
73
 
47
74
  app = typer.Typer(
48
75
  help="Choose a command",
@@ -51,6 +78,7 @@ app = typer.Typer(
51
78
  context_settings={"help_option_names": ["-h", "--help"]},
52
79
  )
53
80
 
81
+
54
82
  db = typer.Typer(
55
83
  help="Database commands",
56
84
  add_completion=False,
@@ -64,6 +92,43 @@ llm = typer.Typer(
64
92
  context_settings={"help_option_names": ["-h", "--help"]},
65
93
  )
66
94
 
95
+ embedding = typer.Typer(
96
+ help="Embedding model commands",
97
+ add_completion=False,
98
+ no_args_is_help=True,
99
+ context_settings={"help_option_names": ["-h", "--help"]},
100
+ )
101
+
102
+
103
def version_callback(value: bool) -> None:
    """Print the Janus CLI version and exit when the version flag is set.

    Arguments:
        value: The parsed value of the version flag; nothing happens if falsy.

    Raises:
        typer.Exit: After the version string has been printed.
    """
    if not value:
        return

    # Function-scope import: the package version is only looked up on request.
    from . import __version__ as version

    print(f"Janus CLI [blue]v{version}[/blue]")
    raise typer.Exit()
109
+
110
+
111
@app.callback()
def common(
    ctx: typer.Context,
    version: Optional[bool] = typer.Option(
        None,
        "--version",
        "-v",
        callback=version_callback,
        # Eager so the version is handled before any other parameter.
        is_eager=True,
        help="Print the version and exit.",
    ),
) -> None:
    """Top-level callback that exposes the `--version` / `-v` option.

    The body is intentionally empty: all of the work happens in
    `version_callback`, which prints the version and exits when the flag
    is given.

    Arguments:
        ctx: The typer context
        version: The version flag; `None` when the option is not passed
    """
131
+
67
132
 
68
133
  @app.command(
69
134
  help="Translate code from one language to another using an LLM.",
@@ -73,41 +138,53 @@ def translate(
73
138
  input_dir: Annotated[
74
139
  Path,
75
140
  typer.Option(
141
+ "--input",
142
+ "-i",
76
143
  help="The directory containing the source code to be translated. "
77
- "The files should all be in one flat directory."
144
+ "The files should all be in one flat directory.",
78
145
  ),
79
146
  ],
80
147
  source_lang: Annotated[
81
148
  str,
82
149
  typer.Option(
150
+ "--source-language",
151
+ "-s",
83
152
  help="The language of the source code.",
84
153
  click_type=click.Choice(sorted(LANGUAGES)),
85
154
  ),
86
155
  ],
87
156
  output_dir: Annotated[
88
157
  Path,
89
- typer.Option(help="The directory to store the translated code in."),
158
+ typer.Option(
159
+ "--output", "-o", help="The directory to store the translated code in."
160
+ ),
90
161
  ],
91
162
  target_lang: Annotated[
92
163
  str,
93
164
  typer.Option(
165
+ "--target-language",
166
+ "-t",
94
167
  help="The desired output language to translate the source code to. The "
95
168
  "format can follow a 'language-version' syntax. Use 'text' to get plaintext"
96
169
  "results as returned by the LLM. Examples: `python-3.10`, `mumps`, `java-10`,"
97
- "text."
170
+ "text.",
98
171
  ),
99
172
  ],
100
173
  llm_name: Annotated[
101
174
  str,
102
175
  typer.Option(
176
+ "--llm",
177
+ "-L",
103
178
  help="The custom name of the model set with 'janus llm add'.",
104
179
  ),
105
- ] = "gpt-3.5-turbo",
180
+ ] = "gpt-3.5-turbo-0125",
106
181
  max_prompts: Annotated[
107
182
  int,
108
183
  typer.Option(
184
+ "--max-prompts",
185
+ "-m",
109
186
  help="The maximum number of times to prompt a model on one functional block "
110
- "before exiting the application. This is to prevent wasting too much money."
187
+ "before exiting the application. This is to prevent wasting too much money.",
111
188
  ),
112
189
  ] = 10,
113
190
  overwrite: Annotated[
@@ -119,18 +196,22 @@ def translate(
119
196
  ] = False,
120
197
  temp: Annotated[
121
198
  float,
122
- typer.Option(help="Sampling temperature.", min=0, max=2),
199
+ typer.Option("--temperature", "-T", help="Sampling temperature.", min=0, max=2),
123
200
  ] = 0.7,
124
201
  prompt_template: Annotated[
125
202
  str,
126
203
  typer.Option(
204
+ "--prompt-template",
205
+ "-p",
127
206
  help="Name of the Janus prompt template directory or "
128
- "path to a directory containing those template files."
207
+ "path to a directory containing those template files.",
129
208
  ),
130
209
  ] = "simple",
131
210
  parser_type: Annotated[
132
211
  str,
133
212
  typer.Option(
213
+ "--parser",
214
+ "-P",
134
215
  click_type=click.Choice(sorted(PARSER_TYPES)),
135
216
  help="The type of parser to use.",
136
217
  ),
@@ -144,6 +225,24 @@ def translate(
144
225
  "collection with the name provided.",
145
226
  ),
146
227
  ] = None,
228
+ custom_splitter: Annotated[
229
+ Optional[str],
230
+ typer.Option(
231
+ "-cs",
232
+ "--custom-splitter",
233
+ help="Name of custom splitter to use",
234
+ click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
235
+ ),
236
+ ] = None,
237
+ max_tokens: Annotated[
238
+ int,
239
+ typer.Option(
240
+ "--max-tokens",
241
+ "-M",
242
+ help="The maximum number of tokens the model will take in. "
243
+ "If unspecificed, model's default max will be used.",
244
+ ),
245
+ ] = None,
147
246
  ):
148
247
  try:
149
248
  target_language, target_version = target_lang.split("-")
@@ -156,12 +255,7 @@ def translate(
156
255
  raise ValueError
157
256
 
158
257
  model_arguments = dict(temperature=temp)
159
- output_collection = None
160
- if collection is not None:
161
- _check_collection(collection, input_dir)
162
- db = ChromaEmbeddingDatabase(db_loc)
163
- collections = Collections(db)
164
- output_collection = collections.get_or_create(collection)
258
+ collections_config = get_collections_config()
165
259
  translator = Translator(
166
260
  model=llm_name,
167
261
  model_arguments=model_arguments,
@@ -169,21 +263,269 @@ def translate(
169
263
  target_language=target_language,
170
264
  target_version=target_version,
171
265
  max_prompts=max_prompts,
266
+ max_tokens=max_tokens,
172
267
  prompt_template=prompt_template,
173
268
  parser_type=parser_type,
269
+ db_path=db_loc,
270
+ db_config=collections_config,
271
+ custom_splitter=custom_splitter,
174
272
  )
175
- translator.translate(input_dir, output_dir, overwrite, output_collection)
273
+ translator.translate(input_dir, output_dir, overwrite, collection)
274
+
275
+
276
@app.command(
    help="Document input code using an LLM.",
    no_args_is_help=True,
)
def document(
    input_dir: Annotated[
        Path,
        typer.Option(
            "--input",
            "-i",
            help="The directory containing the source code to be documented. "
            "The files should all be in one flat directory.",
        ),
    ],
    language: Annotated[
        str,
        typer.Option(
            "--language",
            "-l",
            help="The language of the source code.",
            click_type=click.Choice(sorted(LANGUAGES)),
        ),
    ],
    output_dir: Annotated[
        Path,
        typer.Option(
            "--output-dir",
            "-o",
            help="The directory to store the generated documentation in.",
        ),
    ],
    llm_name: Annotated[
        str,
        typer.Option(
            "--llm",
            "-L",
            help="The custom name of the model set with 'janus llm add'.",
        ),
    ] = "gpt-3.5-turbo-0125",
    max_prompts: Annotated[
        int,
        typer.Option(
            "--max-prompts",
            "-m",
            help="The maximum number of times to prompt a model on one functional block "
            "before exiting the application. This is to prevent wasting too much money.",
        ),
    ] = 10,
    overwrite: Annotated[
        bool,
        typer.Option(
            "--overwrite/--preserve",
            help="Whether to overwrite existing files in the output directory",
        ),
    ] = False,
    doc_mode: Annotated[
        str,
        typer.Option(
            "--doc-mode",
            "-d",
            help="The documentation mode.",
            click_type=click.Choice(["madlibs", "summary", "multidoc", "requirements"]),
        ),
    ] = "madlibs",
    comments_per_request: Annotated[
        Optional[int],
        typer.Option(
            "--comments-per-request",
            "-rc",
            help="The maximum number of comments to generate per request when using "
            "MadLibs documentation mode.",
        ),
    ] = None,
    drop_comments: Annotated[
        bool,
        typer.Option(
            "--drop-comments/--keep-comments",
            help="Whether to drop or keep comments in the code sent to the LLM",
        ),
    ] = False,
    temperature: Annotated[
        float,
        typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
    ] = 0.7,
    collection: Annotated[
        Optional[str],
        typer.Option(
            "--collection",
            "-c",
            help="If set, will put the generated documentation into a Chroma DB "
            "collection with the name provided.",
        ),
    ] = None,
    custom_splitter: Annotated[
        Optional[str],
        typer.Option(
            "-cs",
            "--custom-splitter",
            help="Name of custom splitter to use",
            click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
        ),
    ] = None,
    max_tokens: Annotated[
        Optional[int],
        typer.Option(
            "--max-tokens",
            "-M",
            help="The maximum number of tokens the model will take in. "
            "If unspecified, model's default max will be used.",
        ),
    ] = None,
):
    """Document the source files in a directory using an LLM.

    Builds the documenter class selected by `doc_mode` and runs it over
    `input_dir`, writing results to `output_dir`.

    Arguments:
        input_dir: The directory containing the source code to document
        language: The language of the source code
        output_dir: The directory to write the results to
        llm_name: The custom name of the model set with 'janus llm add'
        max_prompts: The maximum number of prompts per functional block
        overwrite: Whether to overwrite existing files in the output directory
        doc_mode: One of "madlibs", "summary", "multidoc" or "requirements"
        comments_per_request: Passed to the documenter in "madlibs" mode only
        drop_comments: Passed to the documenter in every mode except "madlibs"
        temperature: Sampling temperature
        collection: Optional collection name passed through to `translate`
        custom_splitter: Optional name of a custom splitter
        max_tokens: Optional token limit; `None` leaves the choice to the model
    """
    # Arguments shared by every documenter class.
    kwargs = dict(
        model=llm_name,
        model_arguments=dict(temperature=temperature),
        source_language=language,
        max_prompts=max_prompts,
        max_tokens=max_tokens,
        db_path=db_loc,
        db_config=get_collections_config(),
        custom_splitter=custom_splitter,
    )
    if doc_mode == "madlibs":
        documenter = MadLibsDocumenter(
            comments_per_request=comments_per_request, **kwargs
        )
    elif doc_mode == "multidoc":
        documenter = MultiDocumenter(drop_comments=drop_comments, **kwargs)
    elif doc_mode == "requirements":
        documenter = RequirementsDocumenter(drop_comments=drop_comments, **kwargs)
    else:
        # The only remaining choice is "summary".
        documenter = Documenter(drop_comments=drop_comments, **kwargs)

    documenter.translate(input_dir, output_dir, overwrite, collection)
410
+
411
+
412
@app.command(
    help="Diagram input code using an LLM.",
    no_args_is_help=True,
)
def diagram(
    input_dir: Annotated[
        Path,
        typer.Option(
            "--input",
            "-i",
            help="The directory containing the source code to be diagrammed. "
            "The files should all be in one flat directory.",
        ),
    ],
    language: Annotated[
        str,
        typer.Option(
            "--language",
            "-l",
            help="The language of the source code.",
            click_type=click.Choice(sorted(LANGUAGES)),
        ),
    ],
    output_dir: Annotated[
        Path,
        typer.Option(
            "--output-dir",
            "-o",
            help="The directory to store the generated diagrams in.",
        ),
    ],
    llm_name: Annotated[
        str,
        typer.Option(
            "--llm",
            "-L",
            help="The custom name of the model set with 'janus llm add'.",
        ),
    ] = "gpt-3.5-turbo-0125",
    max_prompts: Annotated[
        int,
        typer.Option(
            "--max-prompts",
            "-m",
            help="The maximum number of times to prompt a model on one functional block "
            "before exiting the application. This is to prevent wasting too much money.",
        ),
    ] = 10,
    overwrite: Annotated[
        bool,
        typer.Option(
            "--overwrite/--preserve",
            help="Whether to overwrite existing files in the output directory",
        ),
    ] = False,
    temperature: Annotated[
        float,
        typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
    ] = 0.7,
    collection: Annotated[
        Optional[str],
        typer.Option(
            "--collection",
            "-c",
            help="If set, will put the generated diagrams into a Chroma DB "
            "collection with the name provided.",
        ),
    ] = None,
    diagram_type: Annotated[
        str,
        typer.Option(
            "--diagram-type", "-dg", help="Diagram type to generate in PLANTUML"
        ),
    ] = "Activity",
    add_documentation: Annotated[
        bool,
        typer.Option(
            "--add-documentation/--no-documentation",
            "-ad",
            help="Whether to use documentation in generation",
        ),
    ] = False,
    custom_splitter: Annotated[
        Optional[str],
        typer.Option(
            "-cs",
            "--custom-splitter",
            help="Name of custom splitter to use",
            click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
        ),
    ] = None,
):
    """Generate diagrams for the source files in a directory using an LLM.

    Builds a `DiagramGenerator` from the options and runs it over
    `input_dir`, writing results to `output_dir`.

    Arguments:
        input_dir: The directory containing the source code to diagram
        language: The language of the source code
        output_dir: The directory to write the results to
        llm_name: The custom name of the model set with 'janus llm add'
        max_prompts: The maximum number of prompts per functional block
        overwrite: Whether to overwrite existing files in the output directory
        temperature: Sampling temperature
        collection: Optional collection name passed through to `translate`
        diagram_type: The diagram type handed to the generator
        add_documentation: Whether to use documentation in generation
        custom_splitter: Optional name of a custom splitter
    """
    diagram_generator = DiagramGenerator(
        model=llm_name,
        model_arguments=dict(temperature=temperature),
        source_language=language,
        max_prompts=max_prompts,
        db_path=db_loc,
        db_config=get_collections_config(),
        diagram_type=diagram_type,
        add_documentation=add_documentation,
        custom_splitter=custom_splitter,
    )
    diagram_generator.translate(input_dir, output_dir, overwrite, collection)
176
516
 
177
517
 
178
518
  @db.command("init", help="Connect to or create a database.")
179
519
  def db_init(
180
- path: Annotated[str, typer.Option(help="The path to the database file.")] = str(
181
- janus_dir / "chroma.db"
182
- ),
520
+ path: Annotated[
521
+ str, typer.Option("--path", "-p", help="The path to the database file.")
522
+ ] = str(janus_dir / "chroma.db"),
183
523
  url: Annotated[
184
524
  str,
185
525
  typer.Option(
186
- help="The URL of the database if the database is running externally."
526
+ "--url",
527
+ "-u",
528
+ help="The URL of the database if the database is running externally.",
187
529
  ),
188
530
  ] = "",
189
531
  ) -> None:
@@ -219,7 +561,7 @@ def db_ls(
219
561
  ] = None,
220
562
  peek: Annotated[
221
563
  Optional[int],
222
- typer.Option(help="Peek at N entries for a specific collection."),
564
+ typer.Option("--peek", "-p", help="Peek at N entries for a specific collection."),
223
565
  ] = None,
224
566
  ) -> None:
225
567
  """List the current database's collections"""
@@ -256,17 +598,24 @@ def db_ls(
256
598
  @db.command("add", help="Add a collection to the current database.")
257
599
  def db_add(
258
600
  collection_name: Annotated[str, typer.Argument(help="The name of the collection.")],
601
+ model_name: Annotated[str, typer.Argument(help="The name of the embedding model.")],
259
602
  input_dir: Annotated[
260
603
  str,
261
- typer.Option(help="The directory containing the source code to be added."),
604
+ typer.Option(
605
+ "--input",
606
+ "-i",
607
+ help="The directory containing the source code to be added.",
608
+ ),
262
609
  ] = "./",
263
610
  input_lang: Annotated[
264
- str, typer.Option(help="The language of the source code.")
611
+ str, typer.Option("--language", "-l", help="The language of the source code.")
265
612
  ] = "python",
266
613
  max_tokens: Annotated[
267
614
  int,
268
615
  typer.Option(
269
- help="The maximum number of tokens for each chunk of input source code."
616
+ "--max-tokens",
617
+ "-m",
618
+ help="The maximum number of tokens for each chunk of input source code.",
270
619
  ),
271
620
  ] = 4096,
272
621
  ) -> None:
@@ -274,13 +623,16 @@ def db_add(
274
623
 
275
624
  Arguments:
276
625
  collection_name: The name of the collection to add
626
+ model_name: The name of the embedding model to use
277
627
  input_dir: The directory containing the source code to be added
278
628
  input_lang: The language of the source code
629
+ max_tokens: The maximum number of tokens for each chunk of input source code
279
630
  """
280
631
  # TODO: import factory
281
632
  console = Console()
282
633
 
283
634
  added_to = _check_collection(collection_name, input_dir)
635
+ collections_config = get_collections_config()
284
636
 
285
637
  with console.status(
286
638
  f"Adding collection: [bold salmon]{collection_name}[/bold salmon]",
@@ -288,13 +640,13 @@ def db_add(
288
640
  ):
289
641
  vectorizer_factory = ChromaDBVectorizer()
290
642
  vectorizer = vectorizer_factory.create_vectorizer(
291
- source_language=input_lang,
292
- path=db_loc,
293
- max_tokens=max_tokens,
643
+ path=db_loc, config=collections_config
294
644
  )
645
+ vectorizer.get_or_create_collection(collection_name, model_name=model_name)
295
646
  input_dir = Path(input_dir)
296
- source_glob = f"**/*.{LANGUAGES[input_lang]['suffix']}"
297
- input_paths = input_dir.rglob(source_glob)
647
+ suffix = LANGUAGES[input_lang]["suffix"]
648
+ source_glob = f"**/*.{suffix}"
649
+ input_paths = [p for p in input_dir.rglob(source_glob)]
298
650
  if input_lang in CUSTOM_SPLITTERS:
299
651
  if input_lang == "mumps":
300
652
  splitter = MumpsSplitter(
@@ -311,15 +663,35 @@ def db_add(
311
663
  )
312
664
  for input_path in input_paths:
313
665
  input_block = splitter.split(input_path)
314
- vectorizer._add_nodes_recursively(
666
+ vectorizer.add_nodes_recursively(
315
667
  input_block,
316
668
  collection_name,
317
669
  input_path.name,
318
670
  )
671
+ total_files = len([p for p in Path.glob(input_dir, "**/*") if not p.is_dir()])
319
672
  if added_to:
320
- print(f"Added to collection [bold salmon1]{collection_name}[/bold salmon1]")
673
+ print(
674
+ f"\nAdded to [bold salmon1]{collection_name}[/bold salmon1]:\n"
675
+ f" Embedding Model: [green]{model_name}[/green]\n"
676
+ f" Input Directory: {input_dir.absolute()}\n"
677
+ f" {input_lang.capitalize()} [green]*.{suffix}[/green] Files: "
678
+ f"{len(input_paths)}\n"
679
+ " Other Files (skipped): "
680
+ f"{total_files - len(input_paths)}\n"
681
+ )
682
+ [p for p in Path.glob(input_dir, f"**/*.{suffix}") if not p.is_dir()]
321
683
  else:
322
- print(f"Created collection [bold salmon1]{collection_name}[/bold salmon1]")
684
+ print(
685
+ f"\nCreated [bold salmon1]{collection_name}[/bold salmon1]:\n"
686
+ f" Embedding Model: '{model_name}'\n"
687
+ f" Input Directory: {input_dir.absolute()}\n"
688
+ f" {input_lang.capitalize()} [green]*.{suffix}[/green] Files: "
689
+ f"{len(input_paths)}\n"
690
+ " Other Files (skipped): "
691
+ f"{total_files - len(input_paths)}\n"
692
+ )
693
+ with open(collections_config_file, "w") as f:
694
+ json.dump(vectorizer.config, f, indent=2)
323
695
 
324
696
 
325
697
  @db.command(
@@ -327,17 +699,28 @@ def db_add(
327
699
  help="Remove a collection from the database.",
328
700
  )
329
701
  def db_rm(
330
- collection_name: Annotated[str, typer.Argument(help="The name of the collection.")]
702
+ collection_name: Annotated[str, typer.Argument(help="The name of the collection.")],
703
+ confirm: Annotated[
704
+ bool,
705
+ typer.Option(
706
+ "--yes",
707
+ "-y",
708
+ help="Confirm the removal of the collection.",
709
+ ),
710
+ ],
331
711
  ) -> None:
332
712
  """Remove a collection from the database
333
713
 
334
714
  Arguments:
335
715
  collection_name: The name of the collection to remove
336
716
  """
337
- delete = Confirm.ask(
338
- f"\nAre you sure you want to [bold red]remove[/bold red] "
339
- f"[bold salmon1]{collection_name}[/bold salmon1]?",
340
- )
717
+ if not confirm:
718
+ delete = Confirm.ask(
719
+ f"\nAre you sure you want to [bold red]remove[/bold red] "
720
+ f"[bold salmon1]{collection_name}[/bold salmon1]?",
721
+ )
722
+ else:
723
+ delete = True
341
724
  if not delete:
342
725
  raise typer.Abort()
343
726
  db = ChromaEmbeddingDatabase(db_loc)
@@ -425,16 +808,115 @@ def llm_add(
425
808
  "model_cost": {"input": in_cost, "output": out_cost},
426
809
  }
427
810
  elif model_type == "OpenAI":
428
- model_name = typer.prompt("Enter the model name", default="gpt-3.5-turbo")
811
+ model_name = typer.prompt("Enter the model name", default="gpt-3.5-turbo-0125")
429
812
  params = dict(
430
813
  model_name=model_name,
431
814
  temperature=0.7,
432
815
  n=1,
433
816
  )
434
817
  max_tokens = TOKEN_LIMITS[model_name]
435
- model_cost = COST_PER_MODEL[model_name]
818
+ model_cost = COST_PER_1K_TOKENS[model_name]
819
+ cfg = {
820
+ "model_type": model_type,
821
+ "model_args": params,
822
+ "token_limit": max_tokens,
823
+ "model_cost": model_cost,
824
+ }
825
+ else:
826
+ raise ValueError(f"Unknown model type {model_type}")
827
+ with open(model_cfg, "w") as f:
828
+ json.dump(cfg, f, indent=2)
829
+ print(f"Model config written to {model_cfg}")
830
+
831
+
832
+ @embedding.command("add", help="Add an embedding model config to janus")
833
+ def embedding_add(
834
+ model_name: Annotated[
835
+ str, typer.Argument(help="The user's custom name for the model")
836
+ ],
837
+ model_type: Annotated[
838
+ str,
839
+ typer.Option(
840
+ "--type",
841
+ "-t",
842
+ help="The type of the model",
843
+ click_type=click.Choice(list(val.value for val in EmbeddingModelType)),
844
+ ),
845
+ ] = "OpenAI",
846
+ ):
847
+ if not EMBEDDING_MODEL_CONFIG_DIR.exists():
848
+ EMBEDDING_MODEL_CONFIG_DIR.mkdir(parents=True)
849
+ model_cfg = EMBEDDING_MODEL_CONFIG_DIR / f"{model_name}.json"
850
+ if model_type in EmbeddingModelType.HuggingFaceInferenceAPI.values:
851
+ hf = typer.style("HuggingFaceInferenceAPI", fg="yellow")
852
+ url = typer.prompt(f"Enter the {hf} model's URL", type=str, value_proc=AnyHttpUrl)
853
+ api_model_name = typer.prompt("Enter the model's name", type=str, default="")
854
+ api_key = typer.prompt("Enter the API key", type=str, default="")
855
+ max_tokens = typer.prompt(
856
+ "Enter the model's maximum tokens", default=8191, type=int
857
+ )
858
+ in_cost = typer.prompt("Enter the cost per input token", default=0, type=float)
859
+ out_cost = typer.prompt("Enter the cost per output token", default=0, type=float)
860
+ params = dict(
861
+ model_name=api_model_name,
862
+ api_key=api_key,
863
+ )
864
+ cfg = {
865
+ "model_type": model_type,
866
+ "model_identifier": str(url),
867
+ "model_args": params,
868
+ "token_limit": max_tokens,
869
+ "model_cost": {"input": in_cost, "output": out_cost},
870
+ }
871
+ elif model_type in EmbeddingModelType.HuggingFaceLocal.values:
872
+ hf = typer.style("HuggingFace", fg="yellow")
873
+ model_id = typer.prompt(
874
+ f"Enter the {hf} model ID",
875
+ default="sentence-transformers/all-MiniLM-L6-v2",
876
+ type=str,
877
+ )
878
+ cache_folder = str(
879
+ Path(
880
+ typer.prompt(
881
+ "Enter the model's cache folder",
882
+ default=EMBEDDING_MODEL_CONFIG_DIR / "cache",
883
+ type=str,
884
+ )
885
+ )
886
+ )
887
+ max_tokens = typer.prompt(
888
+ "Enter the model's maximum tokens", default=8191, type=int
889
+ )
890
+ params = dict(
891
+ cache_folder=str(cache_folder),
892
+ )
893
+ cfg = {
894
+ "model_type": model_type,
895
+ "model_identifier": model_id,
896
+ "model_args": params,
897
+ "token_limit": max_tokens,
898
+ "model_cost": {"input": 0, "output": 0},
899
+ }
900
+ elif model_type in EmbeddingModelType.OpenAI.values:
901
+ available_models = list(EMBEDDING_COST_PER_MODEL.keys())
902
+
903
+ open_ai = typer.style("OpenAI", fg="green")
904
+ prompt = f"Enter the {open_ai} model name"
905
+
906
+ model_name = typer.prompt(
907
+ prompt,
908
+ default="text-embedding-3-small",
909
+ type=click.types.Choice(available_models),
910
+ show_choices=False,
911
+ )
912
+ params = dict(
913
+ model=model_name,
914
+ )
915
+ max_tokens = EMBEDDING_TOKEN_LIMITS[model_name]
916
+ model_cost = EMBEDDING_COST_PER_MODEL[model_name]
436
917
  cfg = {
437
918
  "model_type": model_type,
919
+ "model_identifier": model_name,
438
920
  "model_args": params,
439
921
  "token_limit": max_tokens,
440
922
  "model_cost": model_cost,
@@ -448,6 +930,8 @@ def llm_add(
448
930
 
449
931
  app.add_typer(db, name="db")
450
932
  app.add_typer(llm, name="llm")
933
+ app.add_typer(evaluate, name="evaluate")
934
+ app.add_typer(embedding, name="embedding")
451
935
 
452
936
 
453
937
  if __name__ == "__main__":