janus-llm 4.2.0__py3-none-any.whl → 4.3.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (134)
  1. janus/__init__.py +1 -1
  2. janus/__main__.py +1 -1
  3. janus/_tests/evaluator_tests/EvalReadMe.md +85 -0
  4. janus/_tests/evaluator_tests/incose_tests/incose_large_test.json +39 -0
  5. janus/_tests/evaluator_tests/incose_tests/incose_small_test.json +17 -0
  6. janus/_tests/evaluator_tests/inline_comment_tests/mumps_inline_comment_test.m +71 -0
  7. janus/_tests/test_cli.py +3 -2
  8. janus/cli/aggregate.py +135 -0
  9. janus/cli/cli.py +111 -0
  10. janus/cli/constants.py +43 -0
  11. janus/cli/database.py +289 -0
  12. janus/cli/diagram.py +178 -0
  13. janus/cli/document.py +174 -0
  14. janus/cli/embedding.py +122 -0
  15. janus/cli/llm.py +187 -0
  16. janus/cli/partition.py +125 -0
  17. janus/cli/self_eval.py +149 -0
  18. janus/cli/translate.py +183 -0
  19. janus/converter/__init__.py +1 -1
  20. janus/converter/_tests/test_translate.py +2 -0
  21. janus/converter/converter.py +129 -92
  22. janus/converter/document.py +21 -14
  23. janus/converter/evaluate.py +237 -4
  24. janus/converter/translate.py +3 -3
  25. janus/embedding/collections.py +1 -1
  26. janus/language/alc/_tests/alc.asm +3779 -0
  27. janus/language/alc/_tests/test_alc.py +1 -1
  28. janus/language/alc/alc.py +9 -4
  29. janus/language/binary/_tests/hello.bin +0 -0
  30. janus/language/block.py +47 -12
  31. janus/language/file.py +1 -1
  32. janus/language/mumps/_tests/mumps.m +235 -0
  33. janus/language/splitter.py +31 -23
  34. janus/language/treesitter/_tests/languages/fortran.f90 +416 -0
  35. janus/language/treesitter/_tests/languages/ibmhlasm.asm +16 -0
  36. janus/language/treesitter/_tests/languages/matlab.m +225 -0
  37. janus/language/treesitter/treesitter.py +9 -1
  38. janus/llm/models_info.py +26 -13
  39. janus/metrics/_tests/asm_test_file.asm +10 -0
  40. janus/metrics/_tests/mumps_test_file.m +6 -0
  41. janus/metrics/_tests/test_treesitter_metrics.py +1 -1
  42. janus/metrics/prompts/clarity.txt +8 -0
  43. janus/metrics/prompts/completeness.txt +16 -0
  44. janus/metrics/prompts/faithfulness.txt +10 -0
  45. janus/metrics/prompts/hallucination.txt +16 -0
  46. janus/metrics/prompts/quality.txt +8 -0
  47. janus/metrics/prompts/readability.txt +16 -0
  48. janus/metrics/prompts/usefulness.txt +16 -0
  49. janus/parsers/code_parser.py +4 -4
  50. janus/parsers/doc_parser.py +12 -9
  51. janus/parsers/eval_parsers/incose_parser.py +134 -0
  52. janus/parsers/eval_parsers/inline_comment_parser.py +112 -0
  53. janus/parsers/parser.py +7 -0
  54. janus/parsers/partition_parser.py +47 -13
  55. janus/parsers/reqs_parser.py +8 -5
  56. janus/parsers/uml.py +5 -4
  57. janus/prompts/prompt.py +2 -2
  58. janus/prompts/templates/README.md +30 -0
  59. janus/prompts/templates/basic_aggregation/human.txt +6 -0
  60. janus/prompts/templates/basic_aggregation/system.txt +1 -0
  61. janus/prompts/templates/basic_refinement/human.txt +14 -0
  62. janus/prompts/templates/basic_refinement/system.txt +1 -0
  63. janus/prompts/templates/diagram/human.txt +9 -0
  64. janus/prompts/templates/diagram/system.txt +1 -0
  65. janus/prompts/templates/diagram_with_documentation/human.txt +15 -0
  66. janus/prompts/templates/diagram_with_documentation/system.txt +1 -0
  67. janus/prompts/templates/document/human.txt +10 -0
  68. janus/prompts/templates/document/system.txt +1 -0
  69. janus/prompts/templates/document_cloze/human.txt +11 -0
  70. janus/prompts/templates/document_cloze/system.txt +1 -0
  71. janus/prompts/templates/document_cloze/variables.json +4 -0
  72. janus/prompts/templates/document_cloze/variables_asm.json +4 -0
  73. janus/prompts/templates/document_inline/human.txt +13 -0
  74. janus/prompts/templates/eval_prompts/incose/human.txt +32 -0
  75. janus/prompts/templates/eval_prompts/incose/system.txt +1 -0
  76. janus/prompts/templates/eval_prompts/incose/variables.json +3 -0
  77. janus/prompts/templates/eval_prompts/inline_comments/human.txt +49 -0
  78. janus/prompts/templates/eval_prompts/inline_comments/system.txt +1 -0
  79. janus/prompts/templates/eval_prompts/inline_comments/variables.json +3 -0
  80. janus/prompts/templates/micromanaged_mumps_v1.0/human.txt +23 -0
  81. janus/prompts/templates/micromanaged_mumps_v1.0/system.txt +3 -0
  82. janus/prompts/templates/micromanaged_mumps_v2.0/human.txt +28 -0
  83. janus/prompts/templates/micromanaged_mumps_v2.0/system.txt +3 -0
  84. janus/prompts/templates/micromanaged_mumps_v2.1/human.txt +29 -0
  85. janus/prompts/templates/micromanaged_mumps_v2.1/system.txt +3 -0
  86. janus/prompts/templates/multidocument/human.txt +15 -0
  87. janus/prompts/templates/multidocument/system.txt +1 -0
  88. janus/prompts/templates/partition/human.txt +22 -0
  89. janus/prompts/templates/partition/system.txt +1 -0
  90. janus/prompts/templates/partition/variables.json +4 -0
  91. janus/prompts/templates/pseudocode/human.txt +7 -0
  92. janus/prompts/templates/pseudocode/system.txt +7 -0
  93. janus/prompts/templates/refinement/fix_exceptions/human.txt +19 -0
  94. janus/prompts/templates/refinement/fix_exceptions/system.txt +1 -0
  95. janus/prompts/templates/refinement/format/code_format/human.txt +12 -0
  96. janus/prompts/templates/refinement/format/code_format/system.txt +1 -0
  97. janus/prompts/templates/refinement/format/requirements_format/human.txt +14 -0
  98. janus/prompts/templates/refinement/format/requirements_format/system.txt +1 -0
  99. janus/prompts/templates/refinement/hallucination/human.txt +13 -0
  100. janus/prompts/templates/refinement/hallucination/system.txt +1 -0
  101. janus/prompts/templates/refinement/reflection/human.txt +15 -0
  102. janus/prompts/templates/refinement/reflection/incose/human.txt +26 -0
  103. janus/prompts/templates/refinement/reflection/incose/system.txt +1 -0
  104. janus/prompts/templates/refinement/reflection/incose_deduplicate/human.txt +16 -0
  105. janus/prompts/templates/refinement/reflection/incose_deduplicate/system.txt +1 -0
  106. janus/prompts/templates/refinement/reflection/system.txt +1 -0
  107. janus/prompts/templates/refinement/revision/human.txt +16 -0
  108. janus/prompts/templates/refinement/revision/incose/human.txt +16 -0
  109. janus/prompts/templates/refinement/revision/incose/system.txt +1 -0
  110. janus/prompts/templates/refinement/revision/incose_deduplicate/human.txt +17 -0
  111. janus/prompts/templates/refinement/revision/incose_deduplicate/system.txt +1 -0
  112. janus/prompts/templates/refinement/revision/system.txt +1 -0
  113. janus/prompts/templates/refinement/uml/alc_fix_variables/human.txt +15 -0
  114. janus/prompts/templates/refinement/uml/alc_fix_variables/system.txt +2 -0
  115. janus/prompts/templates/refinement/uml/fix_connections/human.txt +15 -0
  116. janus/prompts/templates/refinement/uml/fix_connections/system.txt +2 -0
  117. janus/prompts/templates/requirements/human.txt +13 -0
  118. janus/prompts/templates/requirements/system.txt +2 -0
  119. janus/prompts/templates/retrieval/language_docs/human.txt +10 -0
  120. janus/prompts/templates/retrieval/language_docs/system.txt +1 -0
  121. janus/prompts/templates/simple/human.txt +16 -0
  122. janus/prompts/templates/simple/system.txt +3 -0
  123. janus/refiners/format.py +49 -0
  124. janus/refiners/refiner.py +143 -4
  125. janus/utils/enums.py +140 -111
  126. janus/utils/logger.py +2 -0
  127. {janus_llm-4.2.0.dist-info → janus_llm-4.3.5.dist-info}/METADATA +7 -7
  128. janus_llm-4.3.5.dist-info/RECORD +210 -0
  129. {janus_llm-4.2.0.dist-info → janus_llm-4.3.5.dist-info}/WHEEL +1 -1
  130. janus_llm-4.3.5.dist-info/entry_points.txt +3 -0
  131. janus/cli.py +0 -1343
  132. janus_llm-4.2.0.dist-info/RECORD +0 -113
  133. janus_llm-4.2.0.dist-info/entry_points.txt +0 -3
  134. {janus_llm-4.2.0.dist-info → janus_llm-4.3.5.dist-info}/LICENSE +0 -0
janus/cli.py DELETED
@@ -1,1343 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import subprocess # nosec
5
- from pathlib import Path
6
- from typing import List, Optional
7
-
8
- import click
9
- import typer
10
- from pydantic import AnyHttpUrl
11
- from rich import print
12
- from rich.console import Console
13
- from rich.prompt import Confirm
14
- from typing_extensions import Annotated
15
-
16
- import janus.refiners.refiner
17
- import janus.refiners.uml
18
- from janus.converter.aggregator import Aggregator
19
- from janus.converter.converter import Converter
20
- from janus.converter.diagram import DiagramGenerator
21
- from janus.converter.document import Documenter, MadLibsDocumenter, MultiDocumenter
22
- from janus.converter.partition import Partitioner
23
- from janus.converter.requirements import RequirementsDocumenter
24
- from janus.converter.translate import Translator
25
- from janus.embedding.collections import Collections
26
- from janus.embedding.database import ChromaEmbeddingDatabase
27
- from janus.embedding.embedding_models_info import (
28
- EMBEDDING_COST_PER_MODEL,
29
- EMBEDDING_MODEL_CONFIG_DIR,
30
- EMBEDDING_TOKEN_LIMITS,
31
- EmbeddingModelType,
32
- )
33
- from janus.embedding.vectorize import ChromaDBVectorizer
34
- from janus.language.binary import BinarySplitter
35
- from janus.language.mumps import MumpsSplitter
36
- from janus.language.naive.registry import CUSTOM_SPLITTERS
37
- from janus.language.treesitter import TreeSitterSplitter
38
- from janus.llm.model_callbacks import COST_PER_1K_TOKENS
39
- from janus.llm.models_info import (
40
- MODEL_CONFIG_DIR,
41
- MODEL_ID_TO_LONG_ID,
42
- MODEL_TYPE_CONSTRUCTORS,
43
- MODEL_TYPES,
44
- TOKEN_LIMITS,
45
- azure_models,
46
- bedrock_models,
47
- openai_models,
48
- )
49
- from janus.metrics.cli import evaluate
50
- from janus.utils.enums import LANGUAGES
51
- from janus.utils.logger import create_logger
52
-
53
# Quiet httpx's chatty request logging before any HTTP traffic happens.
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)

log = create_logger(__name__)

# Per-user Janus state lives under ~/.janus; create it on first run.
homedir = Path.home().expanduser()
janus_dir = homedir / ".janus"
janus_dir.mkdir(parents=True, exist_ok=True)

# The ".db" file records where the Chroma database lives (a path or a URL).
# Seed it with the default on-disk location the first time the CLI runs.
db_file = janus_dir / ".db"
if not db_file.exists():
    db_file.write_text(str(janus_dir / "chroma.db"))
db_loc = db_file.read_text()

# Collection metadata is stored alongside the Chroma database itself.
collections_config_file = Path(db_loc) / "collections.json"
73
-
74
def get_subclasses(cls):
    """Return the set of all (direct and indirect) subclasses of *cls*."""
    found = set()
    pending = list(cls.__subclasses__())
    while pending:
        subclass = pending.pop()
        if subclass not in found:
            found.add(subclass)
            pending.extend(subclass.__subclasses__())
    return found
78
-
79
-
80
# Every JanusRefiner subclass (plus the base class itself) is selectable
# from the CLI by its class name.
REFINER_TYPES = get_subclasses(janus.refiners.refiner.JanusRefiner) | {
    janus.refiners.refiner.JanusRefiner
}
REFINERS = {refiner.__name__: refiner for refiner in REFINER_TYPES}
84
-
85
-
86
def get_collections_config():
    """Load the collections metadata JSON, or return {} if none exists yet."""
    if not collections_config_file.exists():
        return {}
    with open(collections_config_file, "r") as f:
        return json.load(f)
93
-
94
-
95
# Top-level Typer application; the db/llm/embedding groups below are
# presumably attached to it later in this module — confirm via add_typer calls.
app = typer.Typer(
    help=(
        "[bold][dark_orange]Janus[/dark_orange] is a CLI for translating, "
        "documenting, and diagramming code using large language models.[/bold]"
    ),
    add_completion=False,
    no_args_is_help=True,
    context_settings={"help_option_names": ["-h", "--help"]},
    rich_markup_mode="rich",
)


# Subcommand group: embedding-database management ("janus db ...").
db = typer.Typer(
    help="Database commands",
    add_completion=False,
    no_args_is_help=True,
    context_settings={"help_option_names": ["-h", "--help"]},
)
# Subcommand group: LLM configuration ("janus llm ...").
llm = typer.Typer(
    help="LLM commands",
    add_completion=False,
    no_args_is_help=True,
    context_settings={"help_option_names": ["-h", "--help"]},
)

# Subcommand group: embedding-model configuration ("janus embedding ...").
embedding = typer.Typer(
    help="Embedding model commands",
    add_completion=False,
    no_args_is_help=True,
    context_settings={"help_option_names": ["-h", "--help"]},
)
126
-
127
-
128
def version_callback(value: bool) -> None:
    """Print the installed Janus version and exit when --version is passed."""
    if not value:
        return
    # Deferred import: only needed when the flag is actually given.
    from janus import __version__ as version

    print(f"Janus CLI [blue]v{version}[/blue]")
    raise typer.Exit()
134
-
135
-
136
@app.callback()
def common(
    ctx: typer.Context,
    version: bool = typer.Option(
        None,
        "--version",
        "-v",
        callback=version_callback,
        help="Print the version and exit.",
    ),
) -> None:
    """Top-level callback providing the global --version/-v flag.

    The flag's ``version_callback`` prints the installed version and exits;
    this function body itself intentionally does nothing.

    Arguments:
        ctx: The typer context
        version: A boolean flag for the version
    """
    pass
156
-
157
-
158
@app.command(
    help="Translate code from one language to another using an LLM.",
    no_args_is_help=True,
)
def translate(
    input_dir: Annotated[
        Path,
        typer.Option(
            "--input",
            "-i",
            help="The directory containing the source code to be translated. "
            "The files should all be in one flat directory.",
        ),
    ],
    source_lang: Annotated[
        str,
        typer.Option(
            "--source-language",
            "-s",
            help="The language of the source code.",
            click_type=click.Choice(sorted(LANGUAGES)),
        ),
    ],
    output_dir: Annotated[
        Path,
        typer.Option(
            "--output", "-o", help="The directory to store the translated code in."
        ),
    ],
    target_lang: Annotated[
        str,
        typer.Option(
            "--target-language",
            "-t",
            # Fixed: adjacent literals previously lacked separating spaces,
            # rendering as "plaintextresults" and "`java-10`,text." in --help.
            help="The desired output language to translate the source code to. The "
            "format can follow a 'language-version' syntax. Use 'text' to get "
            "plaintext results as returned by the LLM. Examples: `python-3.10`, "
            "`mumps`, `java-10`, text.",
        ),
    ],
    llm_name: Annotated[
        str,
        typer.Option(
            "--llm",
            "-L",
            help="The custom name of the model set with 'janus llm add'.",
        ),
    ],
    max_prompts: Annotated[
        int,
        typer.Option(
            "--max-prompts",
            "-m",
            help="The maximum number of times to prompt a model on one functional block "
            "before exiting the application. This is to prevent wasting too much money.",
        ),
    ] = 10,
    overwrite: Annotated[
        bool,
        typer.Option(
            "--overwrite/--preserve",
            help="Whether to overwrite existing files in the output directory",
        ),
    ] = False,
    skip_context: Annotated[
        bool,
        typer.Option(
            "--skip-context",
            help="Prompts will include any context information associated with source"
            " code blocks, unless this option is specified",
        ),
    ] = False,
    temp: Annotated[
        float,
        typer.Option("--temperature", "-T", help="Sampling temperature.", min=0, max=2),
    ] = 0.7,
    prompt_template: Annotated[
        str,
        typer.Option(
            "--prompt-template",
            "-p",
            help="Name of the Janus prompt template directory or "
            "path to a directory containing those template files.",
        ),
    ] = "simple",
    collection: Annotated[
        str,
        typer.Option(
            "--collection",
            "-c",
            help="If set, will put the translated result into a Chroma DB "
            "collection with the name provided.",
        ),
    ] = None,
    splitter_type: Annotated[
        str,
        typer.Option(
            "-S",
            "--splitter",
            help="Name of custom splitter to use",
            click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
        ),
    ] = "file",
    refiner_types: Annotated[
        list[str],
        typer.Option(
            "-r",
            "--refiner",
            # Fixed: a backslash continuation inside this literal previously
            # embedded a run of indentation spaces into the help text.
            help="List of refiner types to use. Add -r for each refiner to use "
            "in refinement chain",
            click_type=click.Choice(list(REFINERS.keys())),
        ),
    ] = ["JanusRefiner"],
    retriever_type: Annotated[
        str,
        typer.Option(
            "-R",
            "--retriever",
            help="Name of custom retriever to use",
            click_type=click.Choice(["active_usings", "language_docs"]),
        ),
    ] = None,
    max_tokens: Annotated[
        int,
        typer.Option(
            "--max-tokens",
            "-M",
            help="The maximum number of tokens the model will take in. "
            "If unspecified, model's default max will be used.",
        ),
    ] = None,
):
    """Translate source code from one language to another with an LLM.

    Builds a :class:`Translator` from the CLI options and runs it over every
    file in ``input_dir``, writing results to ``output_dir``.

    Raises:
        ValueError: If the translation would overwrite its own input (same
            language and same directory).

    NOTE(review): ``skip_context`` is accepted but never forwarded to the
    Translator here — confirm whether it should be.
    """
    # Map refiner names (CLI strings) to their classes.
    refiner_types = [REFINERS[r] for r in refiner_types]
    # Accept either "language-version" (e.g. "python-3.10") or a bare language.
    try:
        target_language, target_version = target_lang.split("-")
    except ValueError:
        target_language = target_lang
        target_version = None
    # Make sure we are not overwriting the input files.
    if source_lang.lower() == target_language.lower() and input_dir == output_dir:
        log.error("Output files would overwrite input! Aborting...")
        raise ValueError("Output files would overwrite input")

    model_arguments = dict(temperature=temp)
    collections_config = get_collections_config()
    translator = Translator(
        model=llm_name,
        model_arguments=model_arguments,
        source_language=source_lang,
        target_language=target_language,
        target_version=target_version,
        max_prompts=max_prompts,
        max_tokens=max_tokens,
        prompt_template=prompt_template,
        db_path=db_loc,
        db_config=collections_config,
        splitter_type=splitter_type,
        refiner_types=refiner_types,
        retriever_type=retriever_type,
    )
    translator.translate(input_dir, output_dir, overwrite, collection)
319
-
320
-
321
@app.command(
    help="Document input code using an LLM.",
    no_args_is_help=True,
)
def document(
    input_dir: Annotated[
        Path,
        typer.Option(
            "--input",
            "-i",
            help="The directory containing the source code to be translated. "
            "The files should all be in one flat directory.",
        ),
    ],
    language: Annotated[
        str,
        typer.Option(
            "--language",
            "-l",
            help="The language of the source code.",
            click_type=click.Choice(sorted(LANGUAGES)),
        ),
    ],
    output_dir: Annotated[
        Path,
        typer.Option(
            "--output-dir", "-o", help="The directory to store the translated code in."
        ),
    ],
    llm_name: Annotated[
        str,
        typer.Option(
            "--llm",
            "-L",
            help="The custom name of the model set with 'janus llm add'.",
        ),
    ],
    max_prompts: Annotated[
        int,
        typer.Option(
            "--max-prompts",
            "-m",
            help="The maximum number of times to prompt a model on one functional block "
            "before exiting the application. This is to prevent wasting too much money.",
        ),
    ] = 10,
    overwrite: Annotated[
        bool,
        typer.Option(
            "--overwrite/--preserve",
            help="Whether to overwrite existing files in the output directory",
        ),
    ] = False,
    doc_mode: Annotated[
        str,
        typer.Option(
            "--doc-mode",
            "-d",
            help="The documentation mode.",
            click_type=click.Choice(["madlibs", "summary", "multidoc", "requirements"]),
        ),
    ] = "madlibs",
    comments_per_request: Annotated[
        int,
        typer.Option(
            "--comments-per-request",
            "-rc",
            help="The maximum number of comments to generate per request when using "
            "MadLibs documentation mode.",
        ),
    ] = None,
    drop_comments: Annotated[
        bool,
        typer.Option(
            "--drop-comments/--keep-comments",
            help="Whether to drop or keep comments in the code sent to the LLM",
        ),
    ] = False,
    temperature: Annotated[
        float,
        typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
    ] = 0.7,
    collection: Annotated[
        str,
        typer.Option(
            "--collection",
            "-c",
            help="If set, will put the translated result into a Chroma DB "
            "collection with the name provided.",
        ),
    ] = None,
    splitter_type: Annotated[
        str,
        typer.Option(
            "-S",
            "--splitter",
            help="Name of custom splitter to use",
            click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
        ),
    ] = "file",
    refiner_types: Annotated[
        list[str],
        typer.Option(
            "-r",
            "--refiner",
            # Fixed: a backslash continuation inside this literal previously
            # embedded a run of indentation spaces into the help text.
            help="List of refiner types to use. Add -r for each refiner to use "
            "in refinement chain",
            click_type=click.Choice(list(REFINERS.keys())),
        ),
    ] = ["JanusRefiner"],
    retriever_type: Annotated[
        str,
        typer.Option(
            "-R",
            "--retriever",
            help="Name of custom retriever to use",
            click_type=click.Choice(["active_usings", "language_docs"]),
        ),
    ] = None,
    max_tokens: Annotated[
        int,
        typer.Option(
            "--max-tokens",
            "-M",
            help="The maximum number of tokens the model will take in. "
            "If unspecified, model's default max will be used.",
        ),
    ] = None,
):
    """Generate documentation for the source code in ``input_dir``.

    Chooses a documenter class based on ``--doc-mode`` and runs it over the
    input directory, writing results to ``output_dir``. The "summary" mode
    (and any unrecognized mode) falls through to the plain ``Documenter``.
    """
    # Map refiner names (CLI strings) to their classes.
    refiner_types = [REFINERS[r] for r in refiner_types]
    model_arguments = dict(temperature=temperature)
    collections_config = get_collections_config()
    kwargs = dict(
        model=llm_name,
        model_arguments=model_arguments,
        source_language=language,
        max_prompts=max_prompts,
        max_tokens=max_tokens,
        db_path=db_loc,
        db_config=collections_config,
        splitter_type=splitter_type,
        refiner_types=refiner_types,
        retriever_type=retriever_type,
    )
    if doc_mode == "madlibs":
        documenter = MadLibsDocumenter(
            comments_per_request=comments_per_request, **kwargs
        )
    elif doc_mode == "multidoc":
        documenter = MultiDocumenter(drop_comments=drop_comments, **kwargs)
    elif doc_mode == "requirements":
        documenter = RequirementsDocumenter(drop_comments=drop_comments, **kwargs)
    else:
        # "summary" (and anything else) uses the base documenter.
        documenter = Documenter(drop_comments=drop_comments, **kwargs)

    documenter.translate(input_dir, output_dir, overwrite, collection)
477
-
478
-
479
@app.command()
def aggregate(
    input_dir: Annotated[
        Path,
        typer.Option(
            "--input",
            "-i",
            help="The directory containing the source code to be translated. "
            "The files should all be in one flat directory.",
        ),
    ],
    language: Annotated[
        str,
        typer.Option(
            "--language",
            "-l",
            help="The language of the source code.",
            click_type=click.Choice(sorted(LANGUAGES)),
        ),
    ],
    output_dir: Annotated[
        Path,
        typer.Option(
            "--output-dir", "-o", help="The directory to store the translated code in."
        ),
    ],
    llm_name: Annotated[
        str,
        typer.Option(
            "--llm",
            "-L",
            help="The custom name of the model set with 'janus llm add'.",
        ),
    ],
    max_prompts: Annotated[
        int,
        typer.Option(
            "--max-prompts",
            "-m",
            help="The maximum number of times to prompt a model on one functional block "
            "before exiting the application. This is to prevent wasting too much money.",
        ),
    ] = 10,
    overwrite: Annotated[
        bool,
        typer.Option(
            "--overwrite/--preserve",
            help="Whether to overwrite existing files in the output directory",
        ),
    ] = False,
    temperature: Annotated[
        float,
        typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
    ] = 0.7,
    collection: Annotated[
        str,
        typer.Option(
            "--collection",
            "-c",
            help="If set, will put the translated result into a Chroma DB "
            "collection with the name provided.",
        ),
    ] = None,
    splitter_type: Annotated[
        str,
        typer.Option(
            "-S",
            "--splitter",
            help="Name of custom splitter to use",
            click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
        ),
    ] = "file",
    intermediate_converters: Annotated[
        List[str],
        typer.Option(
            "-C",
            "--converter",
            help="Name of an intermediate converter to use",
            click_type=click.Choice([c.__name__ for c in get_subclasses(Converter)]),
        ),
    ] = ["Documenter"],
):
    """Run intermediate converters over the input, then aggregate their output.

    Each name in ``--converter`` is resolved to a Converter subclass; the
    resulting converters feed an ``Aggregator`` that produces the final output.
    """
    # Resolve converter class names (CLI strings) to their classes.
    converter_by_name = {sub.__name__: sub for sub in get_subclasses(Converter)}
    model_arguments = dict(temperature=temperature)
    collections_config = get_collections_config()
    # The same construction kwargs are shared by every intermediate converter
    # and by the aggregator itself.
    shared_kwargs = dict(
        model=llm_name,
        model_arguments=model_arguments,
        source_language=language,
        max_prompts=max_prompts,
        db_path=db_loc,
        db_config=collections_config,
        splitter_type=splitter_type,
    )
    converters = [
        converter_by_name[name](**shared_kwargs) for name in intermediate_converters
    ]

    aggregator = Aggregator(
        intermediate_converters=converters,
        prompt_template="basic_aggregation",
        **shared_kwargs,
    )
    aggregator.translate(input_dir, output_dir, overwrite, collection)
591
-
592
-
593
@app.command(
    help="Partition input code using an LLM.",
    no_args_is_help=True,
)
def partition(
    input_dir: Annotated[
        Path,
        typer.Option(
            "--input",
            "-i",
            help="The directory containing the source code to be partitioned. ",
        ),
    ],
    language: Annotated[
        str,
        typer.Option(
            "--language",
            "-l",
            help="The language of the source code.",
            click_type=click.Choice(sorted(LANGUAGES)),
        ),
    ],
    output_dir: Annotated[
        Path,
        typer.Option(
            "--output-dir", "-o", help="The directory to store the partitioned code in."
        ),
    ],
    llm_name: Annotated[
        str,
        typer.Option(
            "--llm",
            "-L",
            help="The custom name of the model set with 'janus llm add'.",
        ),
    ] = "gpt-4o",
    max_prompts: Annotated[
        int,
        typer.Option(
            "--max-prompts",
            "-m",
            help="The maximum number of times to prompt a model on one functional block "
            "before exiting the application. This is to prevent wasting too much money.",
        ),
    ] = 10,
    overwrite: Annotated[
        bool,
        typer.Option(
            "--overwrite/--preserve",
            help="Whether to overwrite existing files in the output directory",
        ),
    ] = False,
    temperature: Annotated[
        float,
        typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
    ] = 0.7,
    splitter_type: Annotated[
        str,
        typer.Option(
            "-S",
            "--splitter",
            help="Name of custom splitter to use",
            click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
        ),
    ] = "file",
    max_tokens: Annotated[
        int,
        typer.Option(
            "--max-tokens",
            "-M",
            help="The maximum number of tokens the model will take in. "
            "If unspecified, model's default max will be used.",
        ),
    ] = None,
    partition_token_limit: Annotated[
        int,
        typer.Option(
            "--partition-tokens",
            "-pt",
            help="The limit on the number of tokens per partition.",
        ),
    ] = 8192,
):
    """Partition the source code in ``input_dir`` into token-bounded chunks.

    Builds a ``Partitioner`` from the CLI options and writes the partitioned
    output to ``output_dir``.
    """
    model_arguments = dict(temperature=temperature)
    kwargs = dict(
        model=llm_name,
        model_arguments=model_arguments,
        source_language=language,
        max_prompts=max_prompts,
        max_tokens=max_tokens,
        splitter_type=splitter_type,
        partition_token_limit=partition_token_limit,
    )
    partitioner = Partitioner(**kwargs)
    # Note: partition does not store results in a Chroma collection.
    partitioner.translate(input_dir, output_dir, overwrite)
688
-
689
-
690
@app.command(
    help="Diagram input code using an LLM.",
    no_args_is_help=True,
)
def diagram(
    input_dir: Annotated[
        Path,
        typer.Option(
            "--input",
            "-i",
            help="The directory containing the source code to be translated. "
            "The files should all be in one flat directory.",
        ),
    ],
    language: Annotated[
        str,
        typer.Option(
            "--language",
            "-l",
            help="The language of the source code.",
            click_type=click.Choice(sorted(LANGUAGES)),
        ),
    ],
    output_dir: Annotated[
        Path,
        typer.Option(
            "--output-dir", "-o", help="The directory to store the translated code in."
        ),
    ],
    llm_name: Annotated[
        str,
        typer.Option(
            "--llm",
            "-L",
            help="The custom name of the model set with 'janus llm add'.",
        ),
    ],
    max_prompts: Annotated[
        int,
        typer.Option(
            "--max-prompts",
            "-m",
            help="The maximum number of times to prompt a model on one functional block "
            "before exiting the application. This is to prevent wasting too much money.",
        ),
    ] = 10,
    overwrite: Annotated[
        bool,
        typer.Option(
            "--overwrite/--preserve",
            help="Whether to overwrite existing files in the output directory",
        ),
    ] = False,
    temperature: Annotated[
        float,
        typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
    ] = 0.7,
    collection: Annotated[
        str,
        typer.Option(
            "--collection",
            "-c",
            help="If set, will put the translated result into a Chroma DB "
            "collection with the name provided.",
        ),
    ] = None,
    diagram_type: Annotated[
        str,
        typer.Option(
            "--diagram-type", "-dg", help="Diagram type to generate in PLANTUML"
        ),
    ] = "Activity",
    add_documentation: Annotated[
        bool,
        typer.Option(
            "--add-documentation/--no-documentation",
            "-ad",
            help="Whether to use documentation in generation",
        ),
    ] = False,
    splitter_type: Annotated[
        str,
        typer.Option(
            "-S",
            "--splitter",
            help="Name of custom splitter to use",
            click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
        ),
    ] = "file",
    refiner_types: Annotated[
        list[str],
        typer.Option(
            "-r",
            "--refiner",
            # Fixed: a backslash continuation inside this literal previously
            # embedded a run of indentation spaces into the help text.
            help="List of refiner types to use. Add -r for each refiner to use "
            "in refinement chain",
            click_type=click.Choice(list(REFINERS.keys())),
        ),
    ] = ["JanusRefiner"],
    retriever_type: Annotated[
        str,
        typer.Option(
            "-R",
            "--retriever",
            help="Name of custom retriever to use",
            click_type=click.Choice(["active_usings", "language_docs"]),
        ),
    ] = None,
):
    """Generate PlantUML diagrams for the source code in ``input_dir``.

    Builds a ``DiagramGenerator`` from the CLI options and writes the
    generated diagrams to ``output_dir``.
    """
    # Map refiner names (CLI strings) to their classes.
    refiner_types = [REFINERS[r] for r in refiner_types]
    model_arguments = dict(temperature=temperature)
    collections_config = get_collections_config()
    diagram_generator = DiagramGenerator(
        model=llm_name,
        model_arguments=model_arguments,
        source_language=language,
        max_prompts=max_prompts,
        db_path=db_loc,
        db_config=collections_config,
        splitter_type=splitter_type,
        refiner_types=refiner_types,
        retriever_type=retriever_type,
        diagram_type=diagram_type,
        add_documentation=add_documentation,
    )
    diagram_generator.translate(input_dir, output_dir, overwrite, collection)
816
-
817
-
818
@db.command("init", help="Connect to or create a database.")
def db_init(
    path: Annotated[
        str, typer.Option("--path", "-p", help="The path to the database file.")
    ] = str(janus_dir / "chroma.db"),
    url: Annotated[
        str,
        typer.Option(
            "--url",
            "-u",
            help="The URL of the database if the database is running externally.",
        ),
    ] = "",
) -> None:
    """Record the Chroma DB location (URL or local path) and connect to it."""
    global db_loc, embedding_db
    if url != "":
        # An external database takes precedence over any local path.
        print(f"Pointing to Chroma DB at {url}")
        target = url
    else:
        path = os.path.abspath(path)
        print(f"Setting up Chroma DB at {path}")
        target = path
    # Persist the chosen location so later invocations pick it up.
    with open(db_file, "w") as f:
        f.write(target)
    db_loc = target
    embedding_db = ChromaEmbeddingDatabase(db_loc)
846
-
847
-
848
- @db.command("status", help="Print current database location.")
849
- def db_status():
850
- print(f"Chroma DB currently pointing to {db_loc}")
851
-
852
-
853
- @db.command(
854
- "ls",
855
- help="List the current database's collections. Or supply a collection name to list "
856
- "information about its contents.",
857
- )
858
- def db_ls(
859
- collection_name: Annotated[
860
- Optional[str], typer.Argument(help="The name of the collection.")
861
- ] = None,
862
- peek: Annotated[
863
- Optional[int],
864
- typer.Option("--peek", "-p", help="Peek at N entries for a specific collection."),
865
- ] = None,
866
- ) -> None:
867
- """List the current database's collections"""
868
- if peek is not None and collection_name is None:
869
- print(
870
- "\n[bold red]Cannot peek at all collections. Please specify a "
871
- "collection by name.[/bold red]"
872
- )
873
- return
874
- db = ChromaEmbeddingDatabase(db_loc)
875
- collections = Collections(db)
876
- collection_list = collections.get(collection_name)
877
- for collection in collection_list:
878
- print(
879
- f"\n[bold underline]Collection[/bold underline]: "
880
- f"[bold salmon1]{collection.name}[/bold salmon1]"
881
- )
882
- print(f" ID: {collection.id}")
883
- print(f" Metadata: {collection.metadata}")
884
- print(f" Tenant: [green]{collection.tenant}[/green]")
885
- print(f" Database: [green]{collection.database}[/green]")
886
- print(f" Length: {collection.count()}")
887
- if peek:
888
- entry = collection.peek(peek)
889
- entry["embeddings"] = entry["embeddings"][0][:2] + ["..."]
890
- if peek == 1:
891
- print(" [bold]Peeking at first entry[/bold]:")
892
- else:
893
- print(f" [bold]Peeking at first {peek} entries[/bold]:")
894
- print(entry)
895
- print()
896
-
897
-
898
- @db.command("add", help="Add a collection to the current database.")
899
- def db_add(
900
- collection_name: Annotated[str, typer.Argument(help="The name of the collection.")],
901
- model_name: Annotated[str, typer.Argument(help="The name of the embedding model.")],
902
- input_dir: Annotated[
903
- str,
904
- typer.Option(
905
- "--input",
906
- "-i",
907
- help="The directory containing the source code to be added.",
908
- ),
909
- ] = "./",
910
- input_lang: Annotated[
911
- str, typer.Option("--language", "-l", help="The language of the source code.")
912
- ] = "python",
913
- max_tokens: Annotated[
914
- int,
915
- typer.Option(
916
- "--max-tokens",
917
- "-m",
918
- help="The maximum number of tokens for each chunk of input source code.",
919
- ),
920
- ] = 4096,
921
- ) -> None:
922
- """Add a collection to the database
923
-
924
- Arguments:
925
- collection_name: The name of the collection to add
926
- model_name: The name of the embedding model to use
927
- input_dir: The directory containing the source code to be added
928
- input_lang: The language of the source code
929
- max_tokens: The maximum number of tokens for each chunk of input source code
930
- """
931
- # TODO: import factory
932
- console = Console()
933
-
934
- added_to = _check_collection(collection_name, input_dir)
935
- collections_config = get_collections_config()
936
-
937
- with console.status(
938
- f"Adding collection: [bold salmon]{collection_name}[/bold salmon]",
939
- spinner="arrow3",
940
- ):
941
- vectorizer_factory = ChromaDBVectorizer()
942
- vectorizer = vectorizer_factory.create_vectorizer(
943
- path=db_loc, config=collections_config
944
- )
945
- vectorizer.get_or_create_collection(collection_name, model_name=model_name)
946
- input_dir = Path(input_dir)
947
- suffix = LANGUAGES[input_lang]["suffix"]
948
- source_glob = f"**/*.{suffix}"
949
- input_paths = [p for p in input_dir.rglob(source_glob)]
950
- if input_lang in CUSTOM_SPLITTERS:
951
- if input_lang == "mumps":
952
- splitter = MumpsSplitter(
953
- max_tokens=max_tokens,
954
- )
955
- elif input_lang == "binary":
956
- splitter = BinarySplitter(
957
- max_tokens=max_tokens,
958
- )
959
- else:
960
- splitter = TreeSitterSplitter(
961
- language=input_lang,
962
- max_tokens=max_tokens,
963
- )
964
- for input_path in input_paths:
965
- input_block = splitter.split(input_path)
966
- vectorizer.add_nodes_recursively(
967
- input_block,
968
- collection_name,
969
- input_path.name,
970
- )
971
- total_files = len([p for p in Path.glob(input_dir, "**/*") if not p.is_dir()])
972
- if added_to:
973
- print(
974
- f"\nAdded to [bold salmon1]{collection_name}[/bold salmon1]:\n"
975
- f" Embedding Model: [green]{model_name}[/green]\n"
976
- f" Input Directory: {input_dir.absolute()}\n"
977
- f" {input_lang.capitalize()} [green]*.{suffix}[/green] Files: "
978
- f"{len(input_paths)}\n"
979
- " Other Files (skipped): "
980
- f"{total_files - len(input_paths)}\n"
981
- )
982
- [p for p in Path.glob(input_dir, f"**/*.{suffix}") if not p.is_dir()]
983
- else:
984
- print(
985
- f"\nCreated [bold salmon1]{collection_name}[/bold salmon1]:\n"
986
- f" Embedding Model: '{model_name}'\n"
987
- f" Input Directory: {input_dir.absolute()}\n"
988
- f" {input_lang.capitalize()} [green]*.{suffix}[/green] Files: "
989
- f"{len(input_paths)}\n"
990
- " Other Files (skipped): "
991
- f"{total_files - len(input_paths)}\n"
992
- )
993
- with open(collections_config_file, "w") as f:
994
- json.dump(vectorizer.config, f, indent=2)
995
-
996
-
997
- @db.command(
998
- "rm",
999
- help="Remove a collection from the database.",
1000
- )
1001
- def db_rm(
1002
- collection_name: Annotated[str, typer.Argument(help="The name of the collection.")],
1003
- confirm: Annotated[
1004
- bool,
1005
- typer.Option(
1006
- "--yes",
1007
- "-y",
1008
- help="Confirm the removal of the collection.",
1009
- ),
1010
- ],
1011
- ) -> None:
1012
- """Remove a collection from the database
1013
-
1014
- Arguments:
1015
- collection_name: The name of the collection to remove
1016
- """
1017
- if not confirm:
1018
- delete = Confirm.ask(
1019
- f"\nAre you sure you want to [bold red]remove[/bold red] "
1020
- f"[bold salmon1]{collection_name}[/bold salmon1]?",
1021
- )
1022
- else:
1023
- delete = True
1024
- if not delete:
1025
- raise typer.Abort()
1026
- db = ChromaEmbeddingDatabase(db_loc)
1027
- collections = Collections(db)
1028
- collections.delete(collection_name)
1029
- print(
1030
- f"[bold red]Removed[/bold red] collection "
1031
- f"[bold salmon1]{collection_name}[/bold salmon1]"
1032
- )
1033
-
1034
-
1035
- def _check_collection(collection_name: str, input_dir: str | Path) -> bool:
1036
- db = ChromaEmbeddingDatabase(db_loc)
1037
- collections = Collections(db)
1038
- added_to = False
1039
- try:
1040
- collections.get(collection_name)
1041
- # confirm_add = Confirm.ask(
1042
- # f"\nCollection [bold salmon1]{collection_name}[/bold salmon1] exists. Are "
1043
- # "you sure you want to update it with the contents of"
1044
- # f"[bold green]{input_dir}[/bold green]?"
1045
- # )
1046
- added_to = True
1047
- # if not confirm_add:
1048
- # raise typer.Abort()
1049
- except ValueError:
1050
- pass
1051
- return added_to
1052
-
1053
-
1054
- @llm.command("add", help="Add a model config to janus")
1055
- def llm_add(
1056
- model_name: Annotated[
1057
- str, typer.Argument(help="The user's custom name of the model")
1058
- ],
1059
- model_type: Annotated[
1060
- str,
1061
- typer.Option(
1062
- "--type",
1063
- "-t",
1064
- help="The type of the model",
1065
- click_type=click.Choice(sorted(list(MODEL_TYPE_CONSTRUCTORS.keys()))),
1066
- ),
1067
- ] = "Azure",
1068
- ):
1069
- if not MODEL_CONFIG_DIR.exists():
1070
- MODEL_CONFIG_DIR.mkdir(parents=True)
1071
- model_cfg = MODEL_CONFIG_DIR / f"{model_name}.json"
1072
- if model_type == "HuggingFace":
1073
- url = typer.prompt("Enter the model's URL")
1074
- max_tokens = typer.prompt(
1075
- "Enter the model's maximum tokens", default=4096, type=int
1076
- )
1077
- in_cost = typer.prompt("Enter the cost per input token", default=0, type=float)
1078
- out_cost = typer.prompt("Enter the cost per output token", default=0, type=float)
1079
- params = dict(
1080
- inference_server_url=url,
1081
- max_new_tokens=max_tokens,
1082
- top_k=10,
1083
- top_p=0.95,
1084
- typical_p=0.95,
1085
- temperature=0.01,
1086
- repetition_penalty=1.03,
1087
- timeout=240,
1088
- )
1089
- cfg = {
1090
- "model_type": model_type,
1091
- "model_args": params,
1092
- "token_limit": max_tokens,
1093
- "model_cost": {"input": in_cost, "output": out_cost},
1094
- }
1095
- elif model_type == "HuggingFaceLocal":
1096
- model_id = typer.prompt("Enter the model ID")
1097
- task = typer.prompt("Enter the task")
1098
- max_tokens = typer.prompt(
1099
- "Enter the model's maximum tokens", default=4096, type=int
1100
- )
1101
- in_cost = 0
1102
- out_cost = 0
1103
- params = {"model_id": model_id, "task": task}
1104
- cfg = {
1105
- "model_type": model_type,
1106
- "model_args": params,
1107
- "token_limit": max_tokens,
1108
- "model_cost": {"input": in_cost, "output": out_cost},
1109
- }
1110
- elif model_type == "OpenAI":
1111
- print("DEPRECATED: Use 'Azure' instead. CTRL+C to exit.")
1112
- model_id = typer.prompt(
1113
- "Enter the model ID (list model IDs with `janus llm ls -a`)",
1114
- default="gpt-4o",
1115
- type=click.Choice(openai_models),
1116
- show_choices=False,
1117
- )
1118
- params = dict(
1119
- # OpenAI uses the "model_name" key for what we're calling "long_model_id"
1120
- model_name=MODEL_ID_TO_LONG_ID[model_id],
1121
- temperature=0.7,
1122
- n=1,
1123
- )
1124
- max_tokens = TOKEN_LIMITS[MODEL_ID_TO_LONG_ID[model_id]]
1125
- model_cost = COST_PER_1K_TOKENS[MODEL_ID_TO_LONG_ID[model_id]]
1126
- cfg = {
1127
- "model_type": model_type,
1128
- "model_id": model_id,
1129
- "model_args": params,
1130
- "token_limit": max_tokens,
1131
- "model_cost": model_cost,
1132
- }
1133
- elif model_type == "Azure":
1134
- model_id = typer.prompt(
1135
- "Enter the model ID (list model IDs with `janus llm ls -a`)",
1136
- default="gpt-4o",
1137
- type=click.Choice(azure_models),
1138
- show_choices=False,
1139
- )
1140
- params = dict(
1141
- # Azure uses the "azure_deployment" key for what we're calling "long_model_id"
1142
- azure_deployment=MODEL_ID_TO_LONG_ID[model_id],
1143
- temperature=0.7,
1144
- n=1,
1145
- )
1146
- max_tokens = TOKEN_LIMITS[MODEL_ID_TO_LONG_ID[model_id]]
1147
- model_cost = COST_PER_1K_TOKENS[MODEL_ID_TO_LONG_ID[model_id]]
1148
- cfg = {
1149
- "model_type": model_type,
1150
- "model_id": model_id,
1151
- "model_args": params,
1152
- "token_limit": max_tokens,
1153
- "model_cost": model_cost,
1154
- }
1155
- elif model_type == "BedrockChat" or model_type == "Bedrock":
1156
- model_id = typer.prompt(
1157
- "Enter the model ID (list model IDs with `janus llm ls -a`)",
1158
- default="bedrock-claude-sonnet",
1159
- type=click.Choice(bedrock_models),
1160
- show_choices=False,
1161
- )
1162
- params = dict(
1163
- # Bedrock uses the "model_id" key for what we're calling "long_model_id"
1164
- model_id=MODEL_ID_TO_LONG_ID[model_id],
1165
- model_kwargs={"temperature": 0.7},
1166
- )
1167
- max_tokens = TOKEN_LIMITS[MODEL_ID_TO_LONG_ID[model_id]]
1168
- model_cost = COST_PER_1K_TOKENS[MODEL_ID_TO_LONG_ID[model_id]]
1169
- cfg = {
1170
- "model_type": model_type,
1171
- "model_id": model_id,
1172
- "model_args": params,
1173
- "token_limit": max_tokens,
1174
- "model_cost": model_cost,
1175
- }
1176
- else:
1177
- raise ValueError(f"Unknown model type {model_type}")
1178
- with open(model_cfg, "w") as f:
1179
- json.dump(cfg, f, indent=2)
1180
- print(f"Model config written to {model_cfg}")
1181
-
1182
-
1183
- @llm.command("ls", help="List all of the user-configured models")
1184
- def llm_ls(
1185
- all: Annotated[
1186
- bool,
1187
- typer.Option(
1188
- "--all",
1189
- "-a",
1190
- is_flag=True,
1191
- help="List all models, including the default model IDs.",
1192
- click_type=click.Choice(sorted(list(MODEL_TYPE_CONSTRUCTORS.keys()))),
1193
- ),
1194
- ] = False,
1195
- ):
1196
- print("\n[green]User-configured models[/green]:")
1197
- for model_cfg in MODEL_CONFIG_DIR.glob("*.json"):
1198
- with open(model_cfg, "r") as f:
1199
- cfg = json.load(f)
1200
- print(f"\t[blue]{model_cfg.stem}[/blue]: [purple]{cfg['model_type']}[/purple]")
1201
-
1202
- if all:
1203
- print("\n[green]Available model IDs[/green]:")
1204
- for model_id, model_type in MODEL_TYPES.items():
1205
- print(f"\t[blue]{model_id}[/blue]: [purple]{model_type}[/purple]")
1206
-
1207
-
1208
- @embedding.command("add", help="Add an embedding model config to janus")
1209
- def embedding_add(
1210
- model_name: Annotated[
1211
- str, typer.Argument(help="The user's custom name for the model")
1212
- ],
1213
- model_type: Annotated[
1214
- str,
1215
- typer.Option(
1216
- "--type",
1217
- "-t",
1218
- help="The type of the model",
1219
- click_type=click.Choice(list(val.value for val in EmbeddingModelType)),
1220
- ),
1221
- ] = "OpenAI",
1222
- ):
1223
- if not EMBEDDING_MODEL_CONFIG_DIR.exists():
1224
- EMBEDDING_MODEL_CONFIG_DIR.mkdir(parents=True)
1225
- model_cfg = EMBEDDING_MODEL_CONFIG_DIR / f"{model_name}.json"
1226
- if model_type in EmbeddingModelType.HuggingFaceInferenceAPI.values:
1227
- hf = typer.style("HuggingFaceInferenceAPI", fg="yellow")
1228
- url = typer.prompt(f"Enter the {hf} model's URL", type=str, value_proc=AnyHttpUrl)
1229
- api_model_name = typer.prompt("Enter the model's name", type=str, default="")
1230
- api_key = typer.prompt("Enter the API key", type=str, default="")
1231
- max_tokens = typer.prompt(
1232
- "Enter the model's maximum tokens", default=8191, type=int
1233
- )
1234
- in_cost = typer.prompt("Enter the cost per input token", default=0, type=float)
1235
- out_cost = typer.prompt("Enter the cost per output token", default=0, type=float)
1236
- params = dict(
1237
- model_name=api_model_name,
1238
- api_key=api_key,
1239
- )
1240
- cfg = {
1241
- "model_type": model_type,
1242
- "model_identifier": str(url),
1243
- "model_args": params,
1244
- "token_limit": max_tokens,
1245
- "model_cost": {"input": in_cost, "output": out_cost},
1246
- }
1247
- elif model_type in EmbeddingModelType.HuggingFaceLocal.values:
1248
- hf = typer.style("HuggingFace", fg="yellow")
1249
- model_id = typer.prompt(
1250
- f"Enter the {hf} model ID",
1251
- default="sentence-transformers/all-MiniLM-L6-v2",
1252
- type=str,
1253
- )
1254
- cache_folder = str(
1255
- Path(
1256
- typer.prompt(
1257
- "Enter the model's cache folder",
1258
- default=EMBEDDING_MODEL_CONFIG_DIR / "cache",
1259
- type=str,
1260
- )
1261
- )
1262
- )
1263
- max_tokens = typer.prompt(
1264
- "Enter the model's maximum tokens", default=8191, type=int
1265
- )
1266
- params = dict(
1267
- cache_folder=str(cache_folder),
1268
- )
1269
- cfg = {
1270
- "model_type": model_type,
1271
- "model_identifier": model_id,
1272
- "model_args": params,
1273
- "token_limit": max_tokens,
1274
- "model_cost": {"input": 0, "output": 0},
1275
- }
1276
- elif model_type in EmbeddingModelType.OpenAI.values:
1277
- available_models = list(EMBEDDING_COST_PER_MODEL.keys())
1278
-
1279
- open_ai = typer.style("OpenAI", fg="green")
1280
- prompt = f"Enter the {open_ai} model name"
1281
-
1282
- model_name = typer.prompt(
1283
- prompt,
1284
- default="text-embedding-3-small",
1285
- type=click.types.Choice(available_models),
1286
- show_choices=False,
1287
- )
1288
- params = dict(
1289
- model=model_name,
1290
- )
1291
- max_tokens = EMBEDDING_TOKEN_LIMITS[model_name]
1292
- model_cost = EMBEDDING_COST_PER_MODEL[model_name]
1293
- cfg = {
1294
- "model_type": model_type,
1295
- "model_identifier": model_name,
1296
- "model_args": params,
1297
- "token_limit": max_tokens,
1298
- "model_cost": model_cost,
1299
- }
1300
- else:
1301
- raise ValueError(f"Unknown model type {model_type}")
1302
- with open(model_cfg, "w") as f:
1303
- json.dump(cfg, f, indent=2)
1304
- print(f"Model config written to {model_cfg}")
1305
-
1306
-
1307
- app.add_typer(db, name="db")
1308
- app.add_typer(llm, name="llm")
1309
- app.add_typer(evaluate, name="evaluate")
1310
- app.add_typer(embedding, name="embedding")
1311
-
1312
-
1313
- @app.command()
1314
- def render(
1315
- input_dir: Annotated[
1316
- str,
1317
- typer.Option(
1318
- "--input",
1319
- "-i",
1320
- ),
1321
- ],
1322
- output_dir: Annotated[str, typer.Option("--output", "-o")],
1323
- ):
1324
- input_dir = Path(input_dir)
1325
- output_dir = Path(output_dir)
1326
- for input_file in input_dir.rglob("*.json"):
1327
- with open(input_file, "r") as f:
1328
- data = json.load(f)
1329
-
1330
- output_file = output_dir / input_file.relative_to(input_dir).with_suffix(".txt")
1331
- if not output_file.parent.exists():
1332
- output_file.parent.mkdir()
1333
-
1334
- text = data["output"].replace("\\n", "\n").strip()
1335
- output_file.write_text(text)
1336
-
1337
- jar_path = homedir / ".janus/lib/plantuml.jar"
1338
- subprocess.run(["java", "-jar", jar_path, output_file]) # nosec
1339
- output_file.unlink()
1340
-
1341
-
1342
- if __name__ == "__main__":
1343
- app()