janus-llm 3.5.3__tar.gz → 4.0.0__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (112) hide show
  1. {janus_llm-3.5.3 → janus_llm-4.0.0}/PKG-INFO +1 -1
  2. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/__init__.py +1 -1
  3. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/cli.py +66 -47
  4. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/converter/converter.py +111 -142
  5. janus_llm-4.0.0/janus/converter/diagram.py +51 -0
  6. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/converter/translate.py +1 -1
  7. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/alc/_tests/test_alc.py +1 -1
  8. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/alc/alc.py +15 -10
  9. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/binary/_tests/test_binary.py +1 -1
  10. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/binary/binary.py +2 -2
  11. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/mumps/_tests/test_mumps.py +1 -1
  12. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/mumps/mumps.py +2 -3
  13. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/splitter.py +2 -2
  14. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/treesitter/_tests/test_treesitter.py +1 -1
  15. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/treesitter/treesitter.py +2 -2
  16. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/llm/model_callbacks.py +13 -0
  17. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/llm/models_info.py +111 -71
  18. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/metric.py +15 -14
  19. janus_llm-4.0.0/janus/parsers/uml.py +88 -0
  20. janus_llm-4.0.0/janus/refiners/refiner.py +115 -0
  21. janus_llm-4.0.0/janus/retrievers/retriever.py +42 -0
  22. {janus_llm-3.5.3 → janus_llm-4.0.0}/pyproject.toml +1 -1
  23. janus_llm-3.5.3/janus/converter/diagram.py +0 -139
  24. janus_llm-3.5.3/janus/parsers/refiner_parser.py +0 -46
  25. janus_llm-3.5.3/janus/parsers/uml.py +0 -51
  26. janus_llm-3.5.3/janus/refiners/refiner.py +0 -73
  27. {janus_llm-3.5.3 → janus_llm-4.0.0}/LICENSE +0 -0
  28. {janus_llm-3.5.3 → janus_llm-4.0.0}/README.md +0 -0
  29. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/__main__.py +0 -0
  30. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/_tests/__init__.py +0 -0
  31. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/_tests/conftest.py +0 -0
  32. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/_tests/test_cli.py +0 -0
  33. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/converter/__init__.py +0 -0
  34. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/converter/_tests/__init__.py +0 -0
  35. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/converter/_tests/test_translate.py +0 -0
  36. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/converter/aggregator.py +0 -0
  37. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/converter/document.py +0 -0
  38. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/converter/evaluate.py +0 -0
  39. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/converter/requirements.py +0 -0
  40. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/embedding/__init__.py +0 -0
  41. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/embedding/_tests/__init__.py +0 -0
  42. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/embedding/_tests/test_collections.py +0 -0
  43. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/embedding/_tests/test_database.py +0 -0
  44. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/embedding/_tests/test_vectorize.py +0 -0
  45. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/embedding/collections.py +0 -0
  46. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/embedding/database.py +0 -0
  47. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/embedding/embedding_models_info.py +0 -0
  48. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/embedding/vectorize.py +0 -0
  49. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/__init__.py +0 -0
  50. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/_tests/__init__.py +0 -0
  51. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/_tests/test_combine.py +0 -0
  52. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/_tests/test_splitter.py +0 -0
  53. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/alc/__init__.py +0 -0
  54. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/alc/_tests/__init__.py +0 -0
  55. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/binary/__init__.py +0 -0
  56. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/binary/_tests/__init__.py +0 -0
  57. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/binary/reveng/decompile_script.py +0 -0
  58. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/block.py +0 -0
  59. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/combine.py +0 -0
  60. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/file.py +0 -0
  61. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/mumps/__init__.py +0 -0
  62. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/mumps/_tests/__init__.py +0 -0
  63. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/mumps/patterns.py +0 -0
  64. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/naive/__init__.py +0 -0
  65. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/naive/basic_splitter.py +0 -0
  66. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/naive/chunk_splitter.py +0 -0
  67. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/naive/registry.py +0 -0
  68. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/naive/simple_ast.py +0 -0
  69. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/naive/tag_splitter.py +0 -0
  70. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/node.py +0 -0
  71. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/treesitter/__init__.py +0 -0
  72. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/language/treesitter/_tests/__init__.py +0 -0
  73. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/llm/__init__.py +0 -0
  74. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/__init__.py +0 -0
  75. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/_tests/__init__.py +0 -0
  76. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/_tests/reference.py +0 -0
  77. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/_tests/target.py +0 -0
  78. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/_tests/test_bleu.py +0 -0
  79. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/_tests/test_chrf.py +0 -0
  80. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/_tests/test_file_pairing.py +0 -0
  81. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/_tests/test_llm.py +0 -0
  82. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/_tests/test_reading.py +0 -0
  83. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/_tests/test_rouge_score.py +0 -0
  84. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/_tests/test_similarity_score.py +0 -0
  85. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/_tests/test_treesitter_metrics.py +0 -0
  86. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/bleu.py +0 -0
  87. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/chrf.py +0 -0
  88. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/cli.py +0 -0
  89. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/complexity_metrics.py +0 -0
  90. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/file_pairing.py +0 -0
  91. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/llm_metrics.py +0 -0
  92. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/reading.py +0 -0
  93. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/rouge_score.py +0 -0
  94. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/similarity.py +0 -0
  95. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/metrics/splitting.py +0 -0
  96. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/parsers/__init__.py +0 -0
  97. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/parsers/_tests/__init__.py +0 -0
  98. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/parsers/_tests/test_code_parser.py +0 -0
  99. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/parsers/code_parser.py +0 -0
  100. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/parsers/doc_parser.py +0 -0
  101. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/parsers/eval_parser.py +0 -0
  102. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/parsers/parser.py +0 -0
  103. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/parsers/reqs_parser.py +0 -0
  104. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/prompts/__init__.py +0 -0
  105. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/prompts/prompt.py +0 -0
  106. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/utils/__init__.py +0 -0
  107. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/utils/_tests/__init__.py +0 -0
  108. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/utils/_tests/test_logger.py +0 -0
  109. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/utils/_tests/test_progress.py +0 -0
  110. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/utils/enums.py +0 -0
  111. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/utils/logger.py +0 -0
  112. {janus_llm-3.5.3 → janus_llm-4.0.0}/janus/utils/progress.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: janus-llm
3
- Version: 3.5.3
3
+ Version: 4.0.0
4
4
  Summary: A transcoding library using LLMs.
5
5
  Home-page: https://github.com/janus-llm/janus-llm
6
6
  License: Apache 2.0
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
5
5
  from janus.converter.translate import Translator
6
6
  from janus.metrics import * # noqa: F403
7
7
 
8
- __version__ = "3.5.3"
8
+ __version__ = "4.0.0"
9
9
 
10
10
  # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
11
11
  warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
@@ -43,6 +43,7 @@ from janus.llm.models_info import (
43
43
  openai_models,
44
44
  )
45
45
  from janus.metrics.cli import evaluate
46
+ from janus.refiners.refiner import REFINERS
46
47
  from janus.utils.enums import LANGUAGES
47
48
  from janus.utils.logger import create_logger
48
49
 
@@ -242,6 +243,24 @@ def translate(
242
243
  click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
243
244
  ),
244
245
  ] = "file",
246
+ refiner_type: Annotated[
247
+ str,
248
+ typer.Option(
249
+ "-r",
250
+ "--refiner",
251
+ help="Name of custom refiner to use",
252
+ click_type=click.Choice(list(REFINERS.keys())),
253
+ ),
254
+ ] = "none",
255
+ retriever_type: Annotated[
256
+ str,
257
+ typer.Option(
258
+ "-R",
259
+ "--retriever",
260
+ help="Name of custom retriever to use",
261
+ click_type=click.Choice(["active_usings"]),
262
+ ),
263
+ ] = None,
245
264
  max_tokens: Annotated[
246
265
  int,
247
266
  typer.Option(
@@ -251,13 +270,6 @@ def translate(
251
270
  "If unspecificed, model's default max will be used.",
252
271
  ),
253
272
  ] = None,
254
- skip_refiner: Annotated[
255
- bool,
256
- typer.Option(
257
- "--skip-refiner",
258
- help="Whether to skip the refiner for generating output",
259
- ),
260
- ] = True,
261
273
  ):
262
274
  try:
263
275
  target_language, target_version = target_lang.split("-")
@@ -283,8 +295,8 @@ def translate(
283
295
  db_path=db_loc,
284
296
  db_config=collections_config,
285
297
  splitter_type=splitter_type,
286
- skip_context=skip_context,
287
- skip_refiner=skip_refiner,
298
+ refiner_type=refiner_type,
299
+ retriever_type=retriever_type,
288
300
  )
289
301
  translator.translate(input_dir, output_dir, overwrite, collection)
290
302
 
@@ -342,14 +354,6 @@ def document(
342
354
  help="Whether to overwrite existing files in the output directory",
343
355
  ),
344
356
  ] = False,
345
- skip_context: Annotated[
346
- bool,
347
- typer.Option(
348
- "--skip-context",
349
- help="Prompts will include any context information associated with source"
350
- " code blocks, unless this option is specified",
351
- ),
352
- ] = False,
353
357
  doc_mode: Annotated[
354
358
  str,
355
359
  typer.Option(
@@ -397,6 +401,24 @@ def document(
397
401
  click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
398
402
  ),
399
403
  ] = "file",
404
+ refiner_type: Annotated[
405
+ str,
406
+ typer.Option(
407
+ "-r",
408
+ "--refiner",
409
+ help="Name of custom refiner to use",
410
+ click_type=click.Choice(list(REFINERS.keys())),
411
+ ),
412
+ ] = "none",
413
+ retriever_type: Annotated[
414
+ str,
415
+ typer.Option(
416
+ "-R",
417
+ "--retriever",
418
+ help="Name of custom retriever to use",
419
+ click_type=click.Choice(["active_usings"]),
420
+ ),
421
+ ] = None,
400
422
  max_tokens: Annotated[
401
423
  int,
402
424
  typer.Option(
@@ -406,13 +428,6 @@ def document(
406
428
  "If unspecificed, model's default max will be used.",
407
429
  ),
408
430
  ] = None,
409
- skip_refiner: Annotated[
410
- bool,
411
- typer.Option(
412
- "--skip-refiner",
413
- help="Whether to skip the refiner for generating output",
414
- ),
415
- ] = True,
416
431
  ):
417
432
  model_arguments = dict(temperature=temperature)
418
433
  collections_config = get_collections_config()
@@ -425,8 +440,8 @@ def document(
425
440
  db_path=db_loc,
426
441
  db_config=collections_config,
427
442
  splitter_type=splitter_type,
428
- skip_refiner=skip_refiner,
429
- skip_context=skip_context,
443
+ refiner_type=refiner_type,
444
+ retriever_type=retriever_type,
430
445
  )
431
446
  if doc_mode == "madlibs":
432
447
  documenter = MadLibsDocumenter(
@@ -615,14 +630,6 @@ def diagram(
615
630
  help="Whether to overwrite existing files in the output directory",
616
631
  ),
617
632
  ] = False,
618
- skip_context: Annotated[
619
- bool,
620
- typer.Option(
621
- "--skip-context",
622
- help="Prompts will include any context information associated with source"
623
- " code blocks, unless this option is specified",
624
- ),
625
- ] = False,
626
633
  temperature: Annotated[
627
634
  float,
628
635
  typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
@@ -659,13 +666,24 @@ def diagram(
659
666
  click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
660
667
  ),
661
668
  ] = "file",
662
- skip_refiner: Annotated[
663
- bool,
669
+ refiner_type: Annotated[
670
+ str,
671
+ typer.Option(
672
+ "-r",
673
+ "--refiner",
674
+ help="Name of custom refiner to use",
675
+ click_type=click.Choice(list(REFINERS.keys())),
676
+ ),
677
+ ] = "none",
678
+ retriever_type: Annotated[
679
+ str,
664
680
  typer.Option(
665
- "--skip-refiner",
666
- help="Whether to skip the refiner for generating output",
681
+ "-R",
682
+ "--retriever",
683
+ help="Name of custom retriever to use",
684
+ click_type=click.Choice(["active_usings"]),
667
685
  ),
668
- ] = True,
686
+ ] = None,
669
687
  ):
670
688
  model_arguments = dict(temperature=temperature)
671
689
  collections_config = get_collections_config()
@@ -676,11 +694,11 @@ def diagram(
676
694
  max_prompts=max_prompts,
677
695
  db_path=db_loc,
678
696
  db_config=collections_config,
697
+ splitter_type=splitter_type,
698
+ refiner_type=refiner_type,
699
+ retriever_type=retriever_type,
679
700
  diagram_type=diagram_type,
680
701
  add_documentation=add_documentation,
681
- splitter_type=splitter_type,
682
- skip_refiner=skip_refiner,
683
- skip_context=skip_context,
684
702
  )
685
703
  diagram_generator.translate(input_dir, output_dir, overwrite, collection)
686
704
 
@@ -1173,13 +1191,14 @@ def render(
1173
1191
  for input_file in input_dir.rglob("*.json"):
1174
1192
  with open(input_file, "r") as f:
1175
1193
  data = json.load(f)
1176
- input_tail = input_file.relative_to(input_dir)
1177
- output_file = output_dir / input_tail
1178
- output_file = output_file.with_suffix(".txt")
1194
+
1195
+ output_file = output_dir / input_file.relative_to(input_dir).with_suffix(".txt")
1179
1196
  if not output_file.parent.exists():
1180
1197
  output_file.parent.mkdir()
1181
- with open(output_file, "w") as f:
1182
- f.write(data["output"])
1198
+
1199
+ text = data["output"].replace("\\n", "\n").strip()
1200
+ output_file.write_text(text)
1201
+
1183
1202
  jar_path = homedir / ".janus/lib/plantuml.jar"
1184
1203
  subprocess.run(["java", "-jar", jar_path, output_file]) # nosec
1185
1204
  output_file.unlink()
@@ -2,13 +2,11 @@ import functools
2
2
  import json
3
3
  import time
4
4
  from pathlib import Path
5
- from typing import Any, List, Optional, Tuple
5
+ from typing import Any
6
6
 
7
- from langchain.output_parsers import RetryWithErrorOutputParser
8
7
  from langchain_core.exceptions import OutputParserException
9
- from langchain_core.language_models import BaseLanguageModel
10
- from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
11
- from langchain_core.runnables import RunnableLambda, RunnableParallel
8
+ from langchain_core.prompts import ChatPromptTemplate
9
+ from langchain_core.runnables import Runnable, RunnableParallel, RunnablePassthrough
12
10
  from openai import BadRequestError, RateLimitError
13
11
  from pydantic import ValidationError
14
12
 
@@ -22,12 +20,18 @@ from janus.language.splitter import (
22
20
  Splitter,
23
21
  TokenLimitError,
24
22
  )
25
- from janus.llm import load_model
26
23
  from janus.llm.model_callbacks import get_model_callback
27
- from janus.llm.models_info import MODEL_PROMPT_ENGINES
24
+ from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel, load_model
28
25
  from janus.parsers.parser import GenericParser, JanusParser
29
- from janus.parsers.refiner_parser import RefinerParser
30
- from janus.refiners.refiner import BasicRefiner, Refiner
26
+ from janus.refiners.refiner import (
27
+ FixParserExceptions,
28
+ HallucinationRefiner,
29
+ JanusRefiner,
30
+ ReflectionRefiner,
31
+ )
32
+
33
+ # from janus.refiners.refiner import BasicRefiner, Refiner
34
+ from janus.retrievers.retriever import ActiveUsingsRetriever, JanusRetriever
31
35
  from janus.utils.enums import LANGUAGES
32
36
  from janus.utils.logger import create_logger
33
37
 
@@ -74,9 +78,8 @@ class Converter:
74
78
  protected_node_types: tuple[str, ...] = (),
75
79
  prune_node_types: tuple[str, ...] = (),
76
80
  splitter_type: str = "file",
77
- refiner_type: str = "basic",
78
- skip_refiner: bool = True,
79
- skip_context: bool = False,
81
+ refiner_type: str | None = None,
82
+ retriever_type: str | None = None,
80
83
  ) -> None:
81
84
  """Initialize a Converter instance.
82
85
 
@@ -96,9 +99,13 @@ class Converter:
96
99
  prune_node_types: A set of node types which should be pruned.
97
100
  splitter_type: The type of splitter to use. Valid values are `"file"`,
98
101
  `"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
99
- refiner_type: The type of refiner to use. Valid values are `"basic"`.
100
- skip_refiner: Whether to skip the refiner.
101
- skip_context: Whether to skip adding context to the prompt.
102
+ refiner_type: The type of refiner to use. Valid values:
103
+ - "parser"
104
+ - "reflection"
105
+ - None
106
+ retriever_type: The type of retriever to use. Valid values:
107
+ - "active_usings"
108
+ - None
102
109
  """
103
110
  self._changed_attrs: set = set()
104
111
 
@@ -107,7 +114,6 @@ class Converter:
107
114
  self.override_token_limit: bool = max_tokens is not None
108
115
 
109
116
  self._model_name: str
110
- self._model_id: str
111
117
  self._custom_model_arguments: dict[str, Any]
112
118
 
113
119
  self._source_language: str
@@ -120,24 +126,26 @@ class Converter:
120
126
  self._prune_node_types: tuple[str, ...] = ()
121
127
  self._max_tokens: int | None = max_tokens
122
128
  self._prompt_template_name: str
123
- self._splitter_type: str
124
129
  self._db_path: str | None
125
130
  self._db_config: dict[str, Any] | None
126
131
 
127
- self._splitter: Splitter
128
- self._llm: BaseLanguageModel
132
+ self._llm: JanusModel
129
133
  self._prompt: ChatPromptTemplate
130
134
 
131
135
  self._parser: JanusParser = GenericParser()
132
136
  self._combiner: Combiner = Combiner()
133
137
 
134
- self._refiner_type: str
135
- self._refiner: Refiner
138
+ self._splitter_type: str
139
+ self._refiner_type: str | None
140
+ self._retriever_type: str | None
136
141
 
137
- self.skip_refiner = skip_refiner
142
+ self._splitter: Splitter
143
+ self._refiner: JanusRefiner
144
+ self._retriever: JanusRetriever
138
145
 
139
146
  self.set_splitter(splitter_type=splitter_type)
140
147
  self.set_refiner(refiner_type=refiner_type)
148
+ self.set_retriever(retriever_type=retriever_type)
141
149
  self.set_model(model_name=model, **model_arguments)
142
150
  self.set_prompt(prompt_template=prompt_template)
143
151
  self.set_source_language(source_language)
@@ -146,8 +154,6 @@ class Converter:
146
154
  self.set_db_path(db_path=db_path)
147
155
  self.set_db_config(db_config=db_config)
148
156
 
149
- self.skip_context = skip_context
150
-
151
157
  # Child class must call this. Should we enforce somehow?
152
158
  # self._load_parameters()
153
159
 
@@ -163,9 +169,11 @@ class Converter:
163
169
  def _load_parameters(self) -> None:
164
170
  self._load_model()
165
171
  self._load_prompt()
172
+ self._load_retriever()
173
+ self._load_refiner()
166
174
  self._load_splitter()
167
175
  self._load_vectorizer()
168
- self._load_refiner()
176
+ self._load_chain()
169
177
  self._changed_attrs.clear()
170
178
 
171
179
  def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
@@ -184,8 +192,6 @@ class Converter:
184
192
  def set_prompt(self, prompt_template: str) -> None:
185
193
  """Validate and set the prompt template name.
186
194
 
187
- The affected objects will not be updated until translate() is called.
188
-
189
195
  Arguments:
190
196
  prompt_template: name of prompt template directory
191
197
  (see janus/prompts/templates) or path to a directory.
@@ -195,29 +201,34 @@ class Converter:
195
201
  def set_splitter(self, splitter_type: str) -> None:
196
202
  """Validate and set the prompt template name.
197
203
 
198
- The affected objects will not be updated until translate() is called.
199
-
200
204
  Arguments:
201
205
  prompt_template: name of prompt template directory
202
206
  (see janus/prompts/templates) or path to a directory.
203
207
  """
204
- self._splitter_type = splitter_type
208
+ if splitter_type not in CUSTOM_SPLITTERS:
209
+ raise ValueError(f'Splitter type "{splitter_type}" does not exist.')
205
210
 
206
- def set_refiner(self, refiner_type: str) -> None:
207
- """Validate and set the refiner name
211
+ self._splitter_type = splitter_type
208
212
 
209
- The affected objects will not be updated until translate is called
213
+ def set_refiner(self, refiner_type: str | None) -> None:
214
+ """Validate and set the refiner type
210
215
 
211
216
  Arguments:
212
- refiner_type: the name of the refiner to use
217
+ refiner_type: the type of refiner to use
213
218
  """
214
219
  self._refiner_type = refiner_type
215
220
 
221
+ def set_retriever(self, retriever_type: str | None) -> None:
222
+ """Validate and set the retriever type
223
+
224
+ Arguments:
225
+ retriever_type: the type of retriever to use
226
+ """
227
+ self._retriever_type = retriever_type
228
+
216
229
  def set_source_language(self, source_language: str) -> None:
217
230
  """Validate and set the source language.
218
231
 
219
- The affected objects will not be updated until _load_parameters() is called.
220
-
221
232
  Arguments:
222
233
  source_language: The source programming language.
223
234
  """
@@ -287,20 +298,6 @@ class Converter:
287
298
 
288
299
  self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)
289
300
 
290
- @run_if_changed("_refiner_type", "_model_name")
291
- def _load_refiner(self) -> None:
292
- """Load the refiner according to this instance's attributes.
293
-
294
- If the relevant fields have not been changed since the last time this method was
295
- called, nothing happens.
296
- """
297
- if self._refiner_type == "basic":
298
- self._refiner = BasicRefiner(
299
- "basic_refinement", self._model_id, self._source_language
300
- )
301
- else:
302
- raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
303
-
304
301
  @run_if_changed("_model_name", "_custom_model_arguments")
305
302
  def _load_model(self) -> None:
306
303
  """Load the model according to this instance's attributes.
@@ -314,9 +311,9 @@ class Converter:
314
311
  # model_arguments.update(self._custom_model_arguments)
315
312
 
316
313
  # Load the model
317
- self._llm, self._model_id, token_limit, self.model_cost = load_model(
318
- self._model_name
319
- )
314
+ self._llm = load_model(self._model_name)
315
+ token_limit = self._llm.token_limit
316
+
320
317
  # Set the max_tokens to less than half the model's limit to allow for enough
321
318
  # tokens at output
322
319
  # Only modify max_tokens if it is not specified by user
@@ -335,7 +332,7 @@ class Converter:
335
332
  If the relevant fields have not been changed since the last time this
336
333
  method was called, nothing happens.
337
334
  """
338
- prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
335
+ prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
339
336
  source_language=self._source_language,
340
337
  prompt_template=self._prompt_template_name,
341
338
  )
@@ -354,6 +351,59 @@ class Converter:
354
351
  self._db_path, self._db_config
355
352
  )
356
353
 
354
+ @run_if_changed("_retriever_type")
355
+ def _load_retriever(self):
356
+ if self._retriever_type == "active_usings":
357
+ self._retriever = ActiveUsingsRetriever()
358
+ else:
359
+ self._retriever = JanusRetriever()
360
+
361
+ @run_if_changed("_refiner_type", "_model_name", "max_prompts", "_parser", "_llm")
362
+ def _load_refiner(self) -> None:
363
+ """Load the refiner according to this instance's attributes.
364
+
365
+ If the relevant fields have not been changed since the last time this method was
366
+ called, nothing happens.
367
+ """
368
+ if self._refiner_type == "parser":
369
+ self._refiner = FixParserExceptions(
370
+ llm=self._llm,
371
+ parser=self._parser,
372
+ max_retries=self.max_prompts,
373
+ )
374
+ elif self._refiner_type == "reflection":
375
+ self._refiner = ReflectionRefiner(
376
+ llm=self._llm,
377
+ parser=self._parser,
378
+ max_retries=self.max_prompts,
379
+ )
380
+ elif self._refiner_type == "hallucination":
381
+ self._refiner = HallucinationRefiner(
382
+ llm=self._llm,
383
+ parser=self._parser,
384
+ max_retries=self.max_prompts,
385
+ )
386
+ else:
387
+ self._refiner = JanusRefiner(parser=self._parser)
388
+
389
+ @run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner")
390
+ def _load_chain(self):
391
+ self.chain = (
392
+ self._input_runnable()
393
+ | self._prompt
394
+ | RunnableParallel(
395
+ completion=self._llm,
396
+ prompt_value=RunnablePassthrough(),
397
+ )
398
+ | self._refiner.parse_runnable
399
+ )
400
+
401
+ def _input_runnable(self) -> Runnable:
402
+ return RunnableParallel(
403
+ SOURCE_CODE=self._parser.parse_input,
404
+ context=self._retriever,
405
+ )
406
+
357
407
  def translate(
358
408
  self,
359
409
  input_directory: str | Path,
@@ -598,110 +648,29 @@ class Converter:
598
648
  return root
599
649
 
600
650
  def _run_chain(self, block: TranslatedCodeBlock) -> str:
601
- """Run the model with three nested error fixing schemes.
602
- First, try to fix simple formatting errors by giving the model just
603
- the output and the parsing error. After a number of attempts, try
604
- giving the model the output, the parsing error, and the original
605
- input. Again check/retry this output to solve for formatting errors.
606
- If we still haven't succeeded after several attempts, the model may
607
- be getting thrown off by a bad initial output; start from scratch
608
- and try again.
609
-
610
- The number of tries for each layer of this scheme is roughly equal
611
- to the cube root of self.max_retries, so the total calls to the
612
- LLM will be roughly as expected (up to sqrt(self.max_retries) over)
613
- """
614
- input = self._parser.parse_input(block.original)
615
-
616
- # Retries with just the output and the error
617
- n1 = round(self.max_prompts ** (1 / 2))
618
-
619
- # Retries with the input, output, and error
620
- n2 = round(self.max_prompts // n1)
621
-
622
- if not self.skip_context:
623
- self._make_prompt_additions(block)
624
- if not self.skip_refiner: # Make replacements in the prompt
625
- refine_output = RefinerParser(
626
- parser=self._parser,
627
- initial_prompt=self._prompt.format(**{"SOURCE_CODE": input}),
628
- refiner=self._refiner,
629
- max_retries=n1,
630
- llm=self._llm,
631
- )
632
- else:
633
- refine_output = RetryWithErrorOutputParser.from_llm(
634
- llm=self._llm,
635
- parser=self._parser,
636
- max_retries=n1,
637
- )
638
-
639
- completion_chain = self._prompt | self._llm
640
- chain = RunnableParallel(
641
- completion=completion_chain, prompt_value=self._prompt
642
- ) | RunnableLambda(lambda x: refine_output.parse_with_prompt(**x))
643
- for _ in range(n2):
644
- try:
645
- return chain.invoke({"SOURCE_CODE": input})
646
- except OutputParserException:
647
- pass
648
-
649
- raise OutputParserException(f"Failed to parse after {n1*n2} retries")
651
+ return self.chain.invoke(block.original)
650
652
 
651
653
  def _get_output_obj(
652
654
  self, block: TranslatedCodeBlock
653
- ) -> dict[str, int | float | str | dict[str, str]]:
655
+ ) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
654
656
  output_str = self._parser.parse_combined_output(block.complete_text)
655
657
 
656
- output: str | dict[str, str]
658
+ output_obj: str | dict[str, str]
657
659
  try:
658
- output = json.loads(output_str)
660
+ output_obj = json.loads(output_str)
659
661
  except json.JSONDecodeError:
660
- output = output_str
662
+ output_obj = output_str
661
663
 
662
664
  return dict(
663
- input=block.original.text,
665
+ input=block.original.text or "",
664
666
  metadata=dict(
665
667
  retries=block.total_retries,
666
668
  cost=block.total_cost,
667
669
  processing_time=block.processing_time,
668
670
  ),
669
- output=output,
670
- )
671
-
672
- @staticmethod
673
- def _get_prompt_additions(block) -> Optional[List[Tuple[str, str]]]:
674
- """Get a list of strings to append to the prompt.
675
-
676
- Arguments:
677
- block: The `TranslatedCodeBlock` to save to a file.
678
- """
679
- return [(key, item) for key, item in block.context_tags.items()]
680
-
681
- def _make_prompt_additions(self, block: CodeBlock):
682
- # Prepare the additional context to prepend
683
- additional_context = "".join(
684
- [
685
- f"{context_tag}: {context}\n"
686
- for context_tag, context in self._get_prompt_additions(block)
687
- ]
671
+ output=output_obj,
688
672
  )
689
673
 
690
- if not hasattr(self._prompt, "messages"):
691
- log.debug("Skipping additions to prompt, no messages found on prompt object!")
692
- return
693
-
694
- # Iterate through existing messages to find and update the system message
695
- for i, message in enumerate(self._prompt.messages):
696
- if isinstance(message, SystemMessagePromptTemplate):
697
- # Prepend the additional context to the system message
698
- updated_system_message = SystemMessagePromptTemplate.from_template(
699
- additional_context + message.prompt.template
700
- )
701
- # Directly modify the message in the list
702
- self._prompt.messages[i] = updated_system_message
703
- break # Assuming there's only one system message to update
704
-
705
674
  def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
706
675
  """Save a file to disk.
707
676
 
@@ -0,0 +1,51 @@
1
+ from langchain_core.runnables import Runnable, RunnableParallel
2
+
3
+ from janus.converter.document import Documenter
4
+ from janus.parsers.uml import UMLSyntaxParser
5
+ from janus.utils.logger import create_logger
6
+
7
+ log = create_logger(__name__)
8
+
9
+
10
+ class DiagramGenerator(Documenter):
11
+ """A Converter that translates code into a set of PLANTUML diagrams."""
12
+
13
+ def __init__(
14
+ self,
15
+ diagram_type="Activity",
16
+ add_documentation=False,
17
+ **kwargs,
18
+ ) -> None:
19
+ """Initialize the DiagramGenerator class
20
+
21
+ Arguments:
22
+ diagram_type: type of PLANTUML diagram to generate
23
+ add_documentation: Whether to add a documentation step prior to
24
+ diagram generation.
25
+ """
26
+ self._diagram_type = diagram_type
27
+ self._add_documentation = add_documentation
28
+ self._documenter = Documenter(**kwargs)
29
+
30
+ super().__init__(**kwargs)
31
+
32
+ self.set_prompt("diagram_with_documentation" if add_documentation else "diagram")
33
+ self._parser = UMLSyntaxParser(language="plantuml")
34
+
35
+ self._load_parameters()
36
+
37
+ def _load_prompt(self):
38
+ super()._load_prompt()
39
+ self._prompt = self._prompt.partial(DIAGRAM_TYPE=self._diagram_type)
40
+
41
+ def _input_runnable(self) -> Runnable:
42
+ if self._add_documentation:
43
+ return RunnableParallel(
44
+ SOURCE_CODE=self._parser.parse_input,
45
+ DOCUMENTATION=self._documenter.chain,
46
+ context=self._retriever,
47
+ )
48
+ return RunnableParallel(
49
+ SOURCE_CODE=self._parser.parse_input,
50
+ context=self._retriever,
51
+ )
@@ -90,7 +90,7 @@ class Translator(Converter):
90
90
  f"({self._source_language} != {self._target_language})"
91
91
  )
92
92
 
93
- prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
93
+ prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
94
94
  source_language=self._source_language,
95
95
  target_language=self._target_language,
96
96
  target_version=self._target_version,
@@ -12,7 +12,7 @@ class TestAlcSplitter(unittest.TestCase):
12
12
  def setUp(self):
13
13
  """Set up the tests."""
14
14
  model_name = "gpt-4o"
15
- llm, _, _, _ = load_model(model_name)
15
+ llm = load_model(model_name)
16
16
  self.splitter = AlcSplitter(model=llm)
17
17
  self.combiner = Combiner(language="ibmhlasm")
18
18
  self.test_file = Path("janus/language/alc/_tests/alc.asm")