janus-llm 3.5.3__tar.gz → 4.1.0__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {janus_llm-3.5.3 → janus_llm-4.1.0}/PKG-INFO +1 -1
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/__init__.py +1 -1
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/cli.py +91 -48
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/converter/_tests/test_translate.py +2 -2
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/converter/converter.py +111 -142
- janus_llm-4.1.0/janus/converter/diagram.py +51 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/converter/translate.py +1 -1
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/alc/_tests/test_alc.py +1 -1
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/alc/alc.py +15 -10
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/binary/_tests/test_binary.py +1 -1
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/binary/binary.py +2 -2
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/mumps/_tests/test_mumps.py +1 -1
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/mumps/mumps.py +2 -3
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/splitter.py +2 -2
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/treesitter/_tests/test_treesitter.py +1 -1
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/treesitter/treesitter.py +2 -2
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/llm/model_callbacks.py +22 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/llm/models_info.py +142 -81
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/metric.py +15 -14
- janus_llm-4.1.0/janus/parsers/uml.py +88 -0
- janus_llm-4.1.0/janus/refiners/refiner.py +115 -0
- janus_llm-4.1.0/janus/retrievers/retriever.py +42 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/pyproject.toml +1 -1
- janus_llm-3.5.3/janus/converter/diagram.py +0 -139
- janus_llm-3.5.3/janus/parsers/refiner_parser.py +0 -46
- janus_llm-3.5.3/janus/parsers/uml.py +0 -51
- janus_llm-3.5.3/janus/refiners/refiner.py +0 -73
- {janus_llm-3.5.3 → janus_llm-4.1.0}/LICENSE +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/README.md +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/__main__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/_tests/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/_tests/conftest.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/_tests/test_cli.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/converter/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/converter/_tests/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/converter/aggregator.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/converter/document.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/converter/evaluate.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/converter/requirements.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/embedding/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/embedding/_tests/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/embedding/_tests/test_collections.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/embedding/_tests/test_database.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/embedding/_tests/test_vectorize.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/embedding/collections.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/embedding/database.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/embedding/embedding_models_info.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/embedding/vectorize.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/_tests/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/_tests/test_combine.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/_tests/test_splitter.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/alc/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/alc/_tests/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/binary/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/binary/_tests/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/binary/reveng/decompile_script.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/block.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/combine.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/file.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/mumps/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/mumps/_tests/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/mumps/patterns.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/naive/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/naive/basic_splitter.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/naive/chunk_splitter.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/naive/registry.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/naive/simple_ast.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/naive/tag_splitter.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/node.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/treesitter/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/language/treesitter/_tests/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/llm/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/_tests/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/_tests/reference.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/_tests/target.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/_tests/test_bleu.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/_tests/test_chrf.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/_tests/test_file_pairing.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/_tests/test_llm.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/_tests/test_reading.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/_tests/test_rouge_score.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/_tests/test_similarity_score.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/_tests/test_treesitter_metrics.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/bleu.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/chrf.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/cli.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/complexity_metrics.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/file_pairing.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/llm_metrics.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/reading.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/rouge_score.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/similarity.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/metrics/splitting.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/parsers/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/parsers/_tests/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/parsers/_tests/test_code_parser.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/parsers/code_parser.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/parsers/doc_parser.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/parsers/eval_parser.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/parsers/parser.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/parsers/reqs_parser.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/prompts/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/prompts/prompt.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/utils/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/utils/_tests/__init__.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/utils/_tests/test_logger.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/utils/_tests/test_progress.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/utils/enums.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/utils/logger.py +0 -0
- {janus_llm-3.5.3 → janus_llm-4.1.0}/janus/utils/progress.py +0 -0
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
|
|
5
5
|
from janus.converter.translate import Translator
|
6
6
|
from janus.metrics import * # noqa: F403
|
7
7
|
|
8
|
-
__version__ = "
|
8
|
+
__version__ = "4.1.0"
|
9
9
|
|
10
10
|
# Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
|
11
11
|
warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
|
@@ -39,10 +39,12 @@ from janus.llm.models_info import (
|
|
39
39
|
MODEL_TYPE_CONSTRUCTORS,
|
40
40
|
MODEL_TYPES,
|
41
41
|
TOKEN_LIMITS,
|
42
|
+
azure_models,
|
42
43
|
bedrock_models,
|
43
44
|
openai_models,
|
44
45
|
)
|
45
46
|
from janus.metrics.cli import evaluate
|
47
|
+
from janus.refiners.refiner import REFINERS
|
46
48
|
from janus.utils.enums import LANGUAGES
|
47
49
|
from janus.utils.logger import create_logger
|
48
50
|
|
@@ -242,6 +244,24 @@ def translate(
|
|
242
244
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
243
245
|
),
|
244
246
|
] = "file",
|
247
|
+
refiner_type: Annotated[
|
248
|
+
str,
|
249
|
+
typer.Option(
|
250
|
+
"-r",
|
251
|
+
"--refiner",
|
252
|
+
help="Name of custom refiner to use",
|
253
|
+
click_type=click.Choice(list(REFINERS.keys())),
|
254
|
+
),
|
255
|
+
] = "none",
|
256
|
+
retriever_type: Annotated[
|
257
|
+
str,
|
258
|
+
typer.Option(
|
259
|
+
"-R",
|
260
|
+
"--retriever",
|
261
|
+
help="Name of custom retriever to use",
|
262
|
+
click_type=click.Choice(["active_usings"]),
|
263
|
+
),
|
264
|
+
] = None,
|
245
265
|
max_tokens: Annotated[
|
246
266
|
int,
|
247
267
|
typer.Option(
|
@@ -251,13 +271,6 @@ def translate(
|
|
251
271
|
"If unspecificed, model's default max will be used.",
|
252
272
|
),
|
253
273
|
] = None,
|
254
|
-
skip_refiner: Annotated[
|
255
|
-
bool,
|
256
|
-
typer.Option(
|
257
|
-
"--skip-refiner",
|
258
|
-
help="Whether to skip the refiner for generating output",
|
259
|
-
),
|
260
|
-
] = True,
|
261
274
|
):
|
262
275
|
try:
|
263
276
|
target_language, target_version = target_lang.split("-")
|
@@ -283,8 +296,8 @@ def translate(
|
|
283
296
|
db_path=db_loc,
|
284
297
|
db_config=collections_config,
|
285
298
|
splitter_type=splitter_type,
|
286
|
-
|
287
|
-
|
299
|
+
refiner_type=refiner_type,
|
300
|
+
retriever_type=retriever_type,
|
288
301
|
)
|
289
302
|
translator.translate(input_dir, output_dir, overwrite, collection)
|
290
303
|
|
@@ -342,14 +355,6 @@ def document(
|
|
342
355
|
help="Whether to overwrite existing files in the output directory",
|
343
356
|
),
|
344
357
|
] = False,
|
345
|
-
skip_context: Annotated[
|
346
|
-
bool,
|
347
|
-
typer.Option(
|
348
|
-
"--skip-context",
|
349
|
-
help="Prompts will include any context information associated with source"
|
350
|
-
" code blocks, unless this option is specified",
|
351
|
-
),
|
352
|
-
] = False,
|
353
358
|
doc_mode: Annotated[
|
354
359
|
str,
|
355
360
|
typer.Option(
|
@@ -397,6 +402,24 @@ def document(
|
|
397
402
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
398
403
|
),
|
399
404
|
] = "file",
|
405
|
+
refiner_type: Annotated[
|
406
|
+
str,
|
407
|
+
typer.Option(
|
408
|
+
"-r",
|
409
|
+
"--refiner",
|
410
|
+
help="Name of custom refiner to use",
|
411
|
+
click_type=click.Choice(list(REFINERS.keys())),
|
412
|
+
),
|
413
|
+
] = "none",
|
414
|
+
retriever_type: Annotated[
|
415
|
+
str,
|
416
|
+
typer.Option(
|
417
|
+
"-R",
|
418
|
+
"--retriever",
|
419
|
+
help="Name of custom retriever to use",
|
420
|
+
click_type=click.Choice(["active_usings"]),
|
421
|
+
),
|
422
|
+
] = None,
|
400
423
|
max_tokens: Annotated[
|
401
424
|
int,
|
402
425
|
typer.Option(
|
@@ -406,13 +429,6 @@ def document(
|
|
406
429
|
"If unspecificed, model's default max will be used.",
|
407
430
|
),
|
408
431
|
] = None,
|
409
|
-
skip_refiner: Annotated[
|
410
|
-
bool,
|
411
|
-
typer.Option(
|
412
|
-
"--skip-refiner",
|
413
|
-
help="Whether to skip the refiner for generating output",
|
414
|
-
),
|
415
|
-
] = True,
|
416
432
|
):
|
417
433
|
model_arguments = dict(temperature=temperature)
|
418
434
|
collections_config = get_collections_config()
|
@@ -425,8 +441,8 @@ def document(
|
|
425
441
|
db_path=db_loc,
|
426
442
|
db_config=collections_config,
|
427
443
|
splitter_type=splitter_type,
|
428
|
-
|
429
|
-
|
444
|
+
refiner_type=refiner_type,
|
445
|
+
retriever_type=retriever_type,
|
430
446
|
)
|
431
447
|
if doc_mode == "madlibs":
|
432
448
|
documenter = MadLibsDocumenter(
|
@@ -615,14 +631,6 @@ def diagram(
|
|
615
631
|
help="Whether to overwrite existing files in the output directory",
|
616
632
|
),
|
617
633
|
] = False,
|
618
|
-
skip_context: Annotated[
|
619
|
-
bool,
|
620
|
-
typer.Option(
|
621
|
-
"--skip-context",
|
622
|
-
help="Prompts will include any context information associated with source"
|
623
|
-
" code blocks, unless this option is specified",
|
624
|
-
),
|
625
|
-
] = False,
|
626
634
|
temperature: Annotated[
|
627
635
|
float,
|
628
636
|
typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
|
@@ -659,13 +667,24 @@ def diagram(
|
|
659
667
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
660
668
|
),
|
661
669
|
] = "file",
|
662
|
-
|
663
|
-
|
670
|
+
refiner_type: Annotated[
|
671
|
+
str,
|
664
672
|
typer.Option(
|
665
|
-
"
|
666
|
-
|
673
|
+
"-r",
|
674
|
+
"--refiner",
|
675
|
+
help="Name of custom refiner to use",
|
676
|
+
click_type=click.Choice(list(REFINERS.keys())),
|
667
677
|
),
|
668
|
-
] =
|
678
|
+
] = "none",
|
679
|
+
retriever_type: Annotated[
|
680
|
+
str,
|
681
|
+
typer.Option(
|
682
|
+
"-R",
|
683
|
+
"--retriever",
|
684
|
+
help="Name of custom retriever to use",
|
685
|
+
click_type=click.Choice(["active_usings"]),
|
686
|
+
),
|
687
|
+
] = None,
|
669
688
|
):
|
670
689
|
model_arguments = dict(temperature=temperature)
|
671
690
|
collections_config = get_collections_config()
|
@@ -676,11 +695,11 @@ def diagram(
|
|
676
695
|
max_prompts=max_prompts,
|
677
696
|
db_path=db_loc,
|
678
697
|
db_config=collections_config,
|
698
|
+
splitter_type=splitter_type,
|
699
|
+
refiner_type=refiner_type,
|
700
|
+
retriever_type=retriever_type,
|
679
701
|
diagram_type=diagram_type,
|
680
702
|
add_documentation=add_documentation,
|
681
|
-
splitter_type=splitter_type,
|
682
|
-
skip_refiner=skip_refiner,
|
683
|
-
skip_context=skip_context,
|
684
703
|
)
|
685
704
|
diagram_generator.translate(input_dir, output_dir, overwrite, collection)
|
686
705
|
|
@@ -934,7 +953,7 @@ def llm_add(
|
|
934
953
|
help="The type of the model",
|
935
954
|
click_type=click.Choice(sorted(list(MODEL_TYPE_CONSTRUCTORS.keys()))),
|
936
955
|
),
|
937
|
-
] = "
|
956
|
+
] = "Azure",
|
938
957
|
):
|
939
958
|
if not MODEL_CONFIG_DIR.exists():
|
940
959
|
MODEL_CONFIG_DIR.mkdir(parents=True)
|
@@ -978,6 +997,7 @@ def llm_add(
|
|
978
997
|
"model_cost": {"input": in_cost, "output": out_cost},
|
979
998
|
}
|
980
999
|
elif model_type == "OpenAI":
|
1000
|
+
print("DEPRECATED: Use 'Azure' instead. CTRL+C to exit.")
|
981
1001
|
model_id = typer.prompt(
|
982
1002
|
"Enter the model ID (list model IDs with `janus llm ls -a`)",
|
983
1003
|
default="gpt-4o",
|
@@ -999,6 +1019,28 @@ def llm_add(
|
|
999
1019
|
"token_limit": max_tokens,
|
1000
1020
|
"model_cost": model_cost,
|
1001
1021
|
}
|
1022
|
+
elif model_type == "Azure":
|
1023
|
+
model_id = typer.prompt(
|
1024
|
+
"Enter the model ID (list model IDs with `janus llm ls -a`)",
|
1025
|
+
default="gpt-4o",
|
1026
|
+
type=click.Choice(azure_models),
|
1027
|
+
show_choices=False,
|
1028
|
+
)
|
1029
|
+
params = dict(
|
1030
|
+
# Azure uses the "azure_deployment" key for what we're calling "long_model_id"
|
1031
|
+
azure_deployment=MODEL_ID_TO_LONG_ID[model_id],
|
1032
|
+
temperature=0.7,
|
1033
|
+
n=1,
|
1034
|
+
)
|
1035
|
+
max_tokens = TOKEN_LIMITS[MODEL_ID_TO_LONG_ID[model_id]]
|
1036
|
+
model_cost = COST_PER_1K_TOKENS[MODEL_ID_TO_LONG_ID[model_id]]
|
1037
|
+
cfg = {
|
1038
|
+
"model_type": model_type,
|
1039
|
+
"model_id": model_id,
|
1040
|
+
"model_args": params,
|
1041
|
+
"token_limit": max_tokens,
|
1042
|
+
"model_cost": model_cost,
|
1043
|
+
}
|
1002
1044
|
elif model_type == "BedrockChat" or model_type == "Bedrock":
|
1003
1045
|
model_id = typer.prompt(
|
1004
1046
|
"Enter the model ID (list model IDs with `janus llm ls -a`)",
|
@@ -1173,13 +1215,14 @@ def render(
|
|
1173
1215
|
for input_file in input_dir.rglob("*.json"):
|
1174
1216
|
with open(input_file, "r") as f:
|
1175
1217
|
data = json.load(f)
|
1176
|
-
|
1177
|
-
output_file = output_dir /
|
1178
|
-
output_file = output_file.with_suffix(".txt")
|
1218
|
+
|
1219
|
+
output_file = output_dir / input_file.relative_to(input_dir).with_suffix(".txt")
|
1179
1220
|
if not output_file.parent.exists():
|
1180
1221
|
output_file.parent.mkdir()
|
1181
|
-
|
1182
|
-
|
1222
|
+
|
1223
|
+
text = data["output"].replace("\\n", "\n").strip()
|
1224
|
+
output_file.write_text(text)
|
1225
|
+
|
1183
1226
|
jar_path = homedir / ".janus/lib/plantuml.jar"
|
1184
1227
|
subprocess.run(["java", "-jar", jar_path, output_file]) # nosec
|
1185
1228
|
output_file.unlink()
|
@@ -90,14 +90,14 @@ class TestDiagramGenerator(unittest.TestCase):
|
|
90
90
|
def setUp(self):
|
91
91
|
"""Set up the tests."""
|
92
92
|
self.diagram_generator = DiagramGenerator(
|
93
|
-
model="gpt-4o",
|
93
|
+
model="gpt-4o-mini",
|
94
94
|
source_language="fortran",
|
95
95
|
diagram_type="Activity",
|
96
96
|
)
|
97
97
|
|
98
98
|
def test_init(self):
|
99
99
|
"""Test __init__ method."""
|
100
|
-
self.assertEqual(self.diagram_generator._model_name, "gpt-4o")
|
100
|
+
self.assertEqual(self.diagram_generator._model_name, "gpt-4o-mini")
|
101
101
|
self.assertEqual(self.diagram_generator._source_language, "fortran")
|
102
102
|
self.assertEqual(self.diagram_generator._diagram_type, "Activity")
|
103
103
|
|
@@ -2,13 +2,11 @@ import functools
|
|
2
2
|
import json
|
3
3
|
import time
|
4
4
|
from pathlib import Path
|
5
|
-
from typing import Any
|
5
|
+
from typing import Any
|
6
6
|
|
7
|
-
from langchain.output_parsers import RetryWithErrorOutputParser
|
8
7
|
from langchain_core.exceptions import OutputParserException
|
9
|
-
from langchain_core.
|
10
|
-
from langchain_core.
|
11
|
-
from langchain_core.runnables import RunnableLambda, RunnableParallel
|
8
|
+
from langchain_core.prompts import ChatPromptTemplate
|
9
|
+
from langchain_core.runnables import Runnable, RunnableParallel, RunnablePassthrough
|
12
10
|
from openai import BadRequestError, RateLimitError
|
13
11
|
from pydantic import ValidationError
|
14
12
|
|
@@ -22,12 +20,18 @@ from janus.language.splitter import (
|
|
22
20
|
Splitter,
|
23
21
|
TokenLimitError,
|
24
22
|
)
|
25
|
-
from janus.llm import load_model
|
26
23
|
from janus.llm.model_callbacks import get_model_callback
|
27
|
-
from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
24
|
+
from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel, load_model
|
28
25
|
from janus.parsers.parser import GenericParser, JanusParser
|
29
|
-
from janus.
|
30
|
-
|
26
|
+
from janus.refiners.refiner import (
|
27
|
+
FixParserExceptions,
|
28
|
+
HallucinationRefiner,
|
29
|
+
JanusRefiner,
|
30
|
+
ReflectionRefiner,
|
31
|
+
)
|
32
|
+
|
33
|
+
# from janus.refiners.refiner import BasicRefiner, Refiner
|
34
|
+
from janus.retrievers.retriever import ActiveUsingsRetriever, JanusRetriever
|
31
35
|
from janus.utils.enums import LANGUAGES
|
32
36
|
from janus.utils.logger import create_logger
|
33
37
|
|
@@ -74,9 +78,8 @@ class Converter:
|
|
74
78
|
protected_node_types: tuple[str, ...] = (),
|
75
79
|
prune_node_types: tuple[str, ...] = (),
|
76
80
|
splitter_type: str = "file",
|
77
|
-
refiner_type: str =
|
78
|
-
|
79
|
-
skip_context: bool = False,
|
81
|
+
refiner_type: str | None = None,
|
82
|
+
retriever_type: str | None = None,
|
80
83
|
) -> None:
|
81
84
|
"""Initialize a Converter instance.
|
82
85
|
|
@@ -96,9 +99,13 @@ class Converter:
|
|
96
99
|
prune_node_types: A set of node types which should be pruned.
|
97
100
|
splitter_type: The type of splitter to use. Valid values are `"file"`,
|
98
101
|
`"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
|
99
|
-
refiner_type: The type of refiner to use. Valid values
|
100
|
-
|
101
|
-
|
102
|
+
refiner_type: The type of refiner to use. Valid values:
|
103
|
+
- "parser"
|
104
|
+
- "reflection"
|
105
|
+
- None
|
106
|
+
retriever_type: The type of retriever to use. Valid values:
|
107
|
+
- "active_usings"
|
108
|
+
- None
|
102
109
|
"""
|
103
110
|
self._changed_attrs: set = set()
|
104
111
|
|
@@ -107,7 +114,6 @@ class Converter:
|
|
107
114
|
self.override_token_limit: bool = max_tokens is not None
|
108
115
|
|
109
116
|
self._model_name: str
|
110
|
-
self._model_id: str
|
111
117
|
self._custom_model_arguments: dict[str, Any]
|
112
118
|
|
113
119
|
self._source_language: str
|
@@ -120,24 +126,26 @@ class Converter:
|
|
120
126
|
self._prune_node_types: tuple[str, ...] = ()
|
121
127
|
self._max_tokens: int | None = max_tokens
|
122
128
|
self._prompt_template_name: str
|
123
|
-
self._splitter_type: str
|
124
129
|
self._db_path: str | None
|
125
130
|
self._db_config: dict[str, Any] | None
|
126
131
|
|
127
|
-
self.
|
128
|
-
self._llm: BaseLanguageModel
|
132
|
+
self._llm: JanusModel
|
129
133
|
self._prompt: ChatPromptTemplate
|
130
134
|
|
131
135
|
self._parser: JanusParser = GenericParser()
|
132
136
|
self._combiner: Combiner = Combiner()
|
133
137
|
|
134
|
-
self.
|
135
|
-
self.
|
138
|
+
self._splitter_type: str
|
139
|
+
self._refiner_type: str | None
|
140
|
+
self._retriever_type: str | None
|
136
141
|
|
137
|
-
self.
|
142
|
+
self._splitter: Splitter
|
143
|
+
self._refiner: JanusRefiner
|
144
|
+
self._retriever: JanusRetriever
|
138
145
|
|
139
146
|
self.set_splitter(splitter_type=splitter_type)
|
140
147
|
self.set_refiner(refiner_type=refiner_type)
|
148
|
+
self.set_retriever(retriever_type=retriever_type)
|
141
149
|
self.set_model(model_name=model, **model_arguments)
|
142
150
|
self.set_prompt(prompt_template=prompt_template)
|
143
151
|
self.set_source_language(source_language)
|
@@ -146,8 +154,6 @@ class Converter:
|
|
146
154
|
self.set_db_path(db_path=db_path)
|
147
155
|
self.set_db_config(db_config=db_config)
|
148
156
|
|
149
|
-
self.skip_context = skip_context
|
150
|
-
|
151
157
|
# Child class must call this. Should we enforce somehow?
|
152
158
|
# self._load_parameters()
|
153
159
|
|
@@ -163,9 +169,11 @@ class Converter:
|
|
163
169
|
def _load_parameters(self) -> None:
|
164
170
|
self._load_model()
|
165
171
|
self._load_prompt()
|
172
|
+
self._load_retriever()
|
173
|
+
self._load_refiner()
|
166
174
|
self._load_splitter()
|
167
175
|
self._load_vectorizer()
|
168
|
-
self.
|
176
|
+
self._load_chain()
|
169
177
|
self._changed_attrs.clear()
|
170
178
|
|
171
179
|
def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
|
@@ -184,8 +192,6 @@ class Converter:
|
|
184
192
|
def set_prompt(self, prompt_template: str) -> None:
|
185
193
|
"""Validate and set the prompt template name.
|
186
194
|
|
187
|
-
The affected objects will not be updated until translate() is called.
|
188
|
-
|
189
195
|
Arguments:
|
190
196
|
prompt_template: name of prompt template directory
|
191
197
|
(see janus/prompts/templates) or path to a directory.
|
@@ -195,29 +201,34 @@ class Converter:
|
|
195
201
|
def set_splitter(self, splitter_type: str) -> None:
|
196
202
|
"""Validate and set the prompt template name.
|
197
203
|
|
198
|
-
The affected objects will not be updated until translate() is called.
|
199
|
-
|
200
204
|
Arguments:
|
201
205
|
prompt_template: name of prompt template directory
|
202
206
|
(see janus/prompts/templates) or path to a directory.
|
203
207
|
"""
|
204
|
-
|
208
|
+
if splitter_type not in CUSTOM_SPLITTERS:
|
209
|
+
raise ValueError(f'Splitter type "{splitter_type}" does not exist.')
|
205
210
|
|
206
|
-
|
207
|
-
"""Validate and set the refiner name
|
211
|
+
self._splitter_type = splitter_type
|
208
212
|
|
209
|
-
|
213
|
+
def set_refiner(self, refiner_type: str | None) -> None:
|
214
|
+
"""Validate and set the refiner type
|
210
215
|
|
211
216
|
Arguments:
|
212
|
-
refiner_type: the
|
217
|
+
refiner_type: the type of refiner to use
|
213
218
|
"""
|
214
219
|
self._refiner_type = refiner_type
|
215
220
|
|
221
|
+
def set_retriever(self, retriever_type: str | None) -> None:
|
222
|
+
"""Validate and set the retriever type
|
223
|
+
|
224
|
+
Arguments:
|
225
|
+
retriever_type: the type of retriever to use
|
226
|
+
"""
|
227
|
+
self._retriever_type = retriever_type
|
228
|
+
|
216
229
|
def set_source_language(self, source_language: str) -> None:
|
217
230
|
"""Validate and set the source language.
|
218
231
|
|
219
|
-
The affected objects will not be updated until _load_parameters() is called.
|
220
|
-
|
221
232
|
Arguments:
|
222
233
|
source_language: The source programming language.
|
223
234
|
"""
|
@@ -287,20 +298,6 @@ class Converter:
|
|
287
298
|
|
288
299
|
self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)
|
289
300
|
|
290
|
-
@run_if_changed("_refiner_type", "_model_name")
|
291
|
-
def _load_refiner(self) -> None:
|
292
|
-
"""Load the refiner according to this instance's attributes.
|
293
|
-
|
294
|
-
If the relevant fields have not been changed since the last time this method was
|
295
|
-
called, nothing happens.
|
296
|
-
"""
|
297
|
-
if self._refiner_type == "basic":
|
298
|
-
self._refiner = BasicRefiner(
|
299
|
-
"basic_refinement", self._model_id, self._source_language
|
300
|
-
)
|
301
|
-
else:
|
302
|
-
raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
|
303
|
-
|
304
301
|
@run_if_changed("_model_name", "_custom_model_arguments")
|
305
302
|
def _load_model(self) -> None:
|
306
303
|
"""Load the model according to this instance's attributes.
|
@@ -314,9 +311,9 @@ class Converter:
|
|
314
311
|
# model_arguments.update(self._custom_model_arguments)
|
315
312
|
|
316
313
|
# Load the model
|
317
|
-
self._llm
|
318
|
-
|
319
|
-
|
314
|
+
self._llm = load_model(self._model_name)
|
315
|
+
token_limit = self._llm.token_limit
|
316
|
+
|
320
317
|
# Set the max_tokens to less than half the model's limit to allow for enough
|
321
318
|
# tokens at output
|
322
319
|
# Only modify max_tokens if it is not specified by user
|
@@ -335,7 +332,7 @@ class Converter:
|
|
335
332
|
If the relevant fields have not been changed since the last time this
|
336
333
|
method was called, nothing happens.
|
337
334
|
"""
|
338
|
-
prompt_engine = MODEL_PROMPT_ENGINES[self.
|
335
|
+
prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
|
339
336
|
source_language=self._source_language,
|
340
337
|
prompt_template=self._prompt_template_name,
|
341
338
|
)
|
@@ -354,6 +351,59 @@ class Converter:
|
|
354
351
|
self._db_path, self._db_config
|
355
352
|
)
|
356
353
|
|
354
|
+
@run_if_changed("_retriever_type")
|
355
|
+
def _load_retriever(self):
|
356
|
+
if self._retriever_type == "active_usings":
|
357
|
+
self._retriever = ActiveUsingsRetriever()
|
358
|
+
else:
|
359
|
+
self._retriever = JanusRetriever()
|
360
|
+
|
361
|
+
@run_if_changed("_refiner_type", "_model_name", "max_prompts", "_parser", "_llm")
|
362
|
+
def _load_refiner(self) -> None:
|
363
|
+
"""Load the refiner according to this instance's attributes.
|
364
|
+
|
365
|
+
If the relevant fields have not been changed since the last time this method was
|
366
|
+
called, nothing happens.
|
367
|
+
"""
|
368
|
+
if self._refiner_type == "parser":
|
369
|
+
self._refiner = FixParserExceptions(
|
370
|
+
llm=self._llm,
|
371
|
+
parser=self._parser,
|
372
|
+
max_retries=self.max_prompts,
|
373
|
+
)
|
374
|
+
elif self._refiner_type == "reflection":
|
375
|
+
self._refiner = ReflectionRefiner(
|
376
|
+
llm=self._llm,
|
377
|
+
parser=self._parser,
|
378
|
+
max_retries=self.max_prompts,
|
379
|
+
)
|
380
|
+
elif self._refiner_type == "hallucination":
|
381
|
+
self._refiner = HallucinationRefiner(
|
382
|
+
llm=self._llm,
|
383
|
+
parser=self._parser,
|
384
|
+
max_retries=self.max_prompts,
|
385
|
+
)
|
386
|
+
else:
|
387
|
+
self._refiner = JanusRefiner(parser=self._parser)
|
388
|
+
|
389
|
+
@run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner")
|
390
|
+
def _load_chain(self):
|
391
|
+
self.chain = (
|
392
|
+
self._input_runnable()
|
393
|
+
| self._prompt
|
394
|
+
| RunnableParallel(
|
395
|
+
completion=self._llm,
|
396
|
+
prompt_value=RunnablePassthrough(),
|
397
|
+
)
|
398
|
+
| self._refiner.parse_runnable
|
399
|
+
)
|
400
|
+
|
401
|
+
def _input_runnable(self) -> Runnable:
|
402
|
+
return RunnableParallel(
|
403
|
+
SOURCE_CODE=self._parser.parse_input,
|
404
|
+
context=self._retriever,
|
405
|
+
)
|
406
|
+
|
357
407
|
def translate(
|
358
408
|
self,
|
359
409
|
input_directory: str | Path,
|
@@ -598,110 +648,29 @@ class Converter:
|
|
598
648
|
return root
|
599
649
|
|
600
650
|
def _run_chain(self, block: TranslatedCodeBlock) -> str:
|
601
|
-
|
602
|
-
First, try to fix simple formatting errors by giving the model just
|
603
|
-
the output and the parsing error. After a number of attempts, try
|
604
|
-
giving the model the output, the parsing error, and the original
|
605
|
-
input. Again check/retry this output to solve for formatting errors.
|
606
|
-
If we still haven't succeeded after several attempts, the model may
|
607
|
-
be getting thrown off by a bad initial output; start from scratch
|
608
|
-
and try again.
|
609
|
-
|
610
|
-
The number of tries for each layer of this scheme is roughly equal
|
611
|
-
to the cube root of self.max_retries, so the total calls to the
|
612
|
-
LLM will be roughly as expected (up to sqrt(self.max_retries) over)
|
613
|
-
"""
|
614
|
-
input = self._parser.parse_input(block.original)
|
615
|
-
|
616
|
-
# Retries with just the output and the error
|
617
|
-
n1 = round(self.max_prompts ** (1 / 2))
|
618
|
-
|
619
|
-
# Retries with the input, output, and error
|
620
|
-
n2 = round(self.max_prompts // n1)
|
621
|
-
|
622
|
-
if not self.skip_context:
|
623
|
-
self._make_prompt_additions(block)
|
624
|
-
if not self.skip_refiner: # Make replacements in the prompt
|
625
|
-
refine_output = RefinerParser(
|
626
|
-
parser=self._parser,
|
627
|
-
initial_prompt=self._prompt.format(**{"SOURCE_CODE": input}),
|
628
|
-
refiner=self._refiner,
|
629
|
-
max_retries=n1,
|
630
|
-
llm=self._llm,
|
631
|
-
)
|
632
|
-
else:
|
633
|
-
refine_output = RetryWithErrorOutputParser.from_llm(
|
634
|
-
llm=self._llm,
|
635
|
-
parser=self._parser,
|
636
|
-
max_retries=n1,
|
637
|
-
)
|
638
|
-
|
639
|
-
completion_chain = self._prompt | self._llm
|
640
|
-
chain = RunnableParallel(
|
641
|
-
completion=completion_chain, prompt_value=self._prompt
|
642
|
-
) | RunnableLambda(lambda x: refine_output.parse_with_prompt(**x))
|
643
|
-
for _ in range(n2):
|
644
|
-
try:
|
645
|
-
return chain.invoke({"SOURCE_CODE": input})
|
646
|
-
except OutputParserException:
|
647
|
-
pass
|
648
|
-
|
649
|
-
raise OutputParserException(f"Failed to parse after {n1*n2} retries")
|
651
|
+
return self.chain.invoke(block.original)
|
650
652
|
|
651
653
|
def _get_output_obj(
|
652
654
|
self, block: TranslatedCodeBlock
|
653
|
-
) -> dict[str, int | float | str | dict[str, str]]:
|
655
|
+
) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
|
654
656
|
output_str = self._parser.parse_combined_output(block.complete_text)
|
655
657
|
|
656
|
-
|
658
|
+
output_obj: str | dict[str, str]
|
657
659
|
try:
|
658
|
-
|
660
|
+
output_obj = json.loads(output_str)
|
659
661
|
except json.JSONDecodeError:
|
660
|
-
|
662
|
+
output_obj = output_str
|
661
663
|
|
662
664
|
return dict(
|
663
|
-
input=block.original.text,
|
665
|
+
input=block.original.text or "",
|
664
666
|
metadata=dict(
|
665
667
|
retries=block.total_retries,
|
666
668
|
cost=block.total_cost,
|
667
669
|
processing_time=block.processing_time,
|
668
670
|
),
|
669
|
-
output=
|
670
|
-
)
|
671
|
-
|
672
|
-
@staticmethod
|
673
|
-
def _get_prompt_additions(block) -> Optional[List[Tuple[str, str]]]:
|
674
|
-
"""Get a list of strings to append to the prompt.
|
675
|
-
|
676
|
-
Arguments:
|
677
|
-
block: The `TranslatedCodeBlock` to save to a file.
|
678
|
-
"""
|
679
|
-
return [(key, item) for key, item in block.context_tags.items()]
|
680
|
-
|
681
|
-
def _make_prompt_additions(self, block: CodeBlock):
|
682
|
-
# Prepare the additional context to prepend
|
683
|
-
additional_context = "".join(
|
684
|
-
[
|
685
|
-
f"{context_tag}: {context}\n"
|
686
|
-
for context_tag, context in self._get_prompt_additions(block)
|
687
|
-
]
|
671
|
+
output=output_obj,
|
688
672
|
)
|
689
673
|
|
690
|
-
if not hasattr(self._prompt, "messages"):
|
691
|
-
log.debug("Skipping additions to prompt, no messages found on prompt object!")
|
692
|
-
return
|
693
|
-
|
694
|
-
# Iterate through existing messages to find and update the system message
|
695
|
-
for i, message in enumerate(self._prompt.messages):
|
696
|
-
if isinstance(message, SystemMessagePromptTemplate):
|
697
|
-
# Prepend the additional context to the system message
|
698
|
-
updated_system_message = SystemMessagePromptTemplate.from_template(
|
699
|
-
additional_context + message.prompt.template
|
700
|
-
)
|
701
|
-
# Directly modify the message in the list
|
702
|
-
self._prompt.messages[i] = updated_system_message
|
703
|
-
break # Assuming there's only one system message to update
|
704
|
-
|
705
674
|
def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
|
706
675
|
"""Save a file to disk.
|
707
676
|
|