janus-llm 3.2.0__tar.gz → 3.3.0__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (111) hide show
  1. {janus_llm-3.2.0 → janus_llm-3.3.0}/PKG-INFO +1 -1
  2. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/__init__.py +3 -3
  3. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/_tests/test_cli.py +3 -3
  4. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/cli.py +1 -1
  5. janus_llm-3.3.0/janus/converter/__init__.py +6 -0
  6. janus_llm-3.3.0/janus/converter/_tests/test_translate.py +156 -0
  7. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/converter/converter.py +49 -7
  8. janus_llm-3.3.0/janus/converter/diagram.py +139 -0
  9. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/embedding/_tests/test_collections.py +2 -2
  10. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/embedding/_tests/test_database.py +1 -1
  11. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/embedding/_tests/test_vectorize.py +3 -3
  12. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/embedding/collections.py +2 -2
  13. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/embedding/database.py +1 -1
  14. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/embedding/embedding_models_info.py +1 -1
  15. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/embedding/vectorize.py +5 -5
  16. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/_tests/test_combine.py +1 -1
  17. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/_tests/test_splitter.py +1 -1
  18. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/alc/_tests/test_alc.py +3 -3
  19. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/alc/alc.py +5 -5
  20. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/binary/_tests/test_binary.py +2 -2
  21. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/binary/binary.py +5 -5
  22. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/block.py +2 -2
  23. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/combine.py +3 -3
  24. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/file.py +2 -2
  25. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/mumps/_tests/test_mumps.py +3 -3
  26. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/mumps/mumps.py +5 -5
  27. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/mumps/patterns.py +1 -1
  28. janus_llm-3.3.0/janus/language/naive/__init__.py +4 -0
  29. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/naive/basic_splitter.py +4 -4
  30. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/naive/chunk_splitter.py +4 -4
  31. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/naive/registry.py +1 -1
  32. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/naive/simple_ast.py +5 -5
  33. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/naive/tag_splitter.py +4 -4
  34. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/node.py +1 -1
  35. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/splitter.py +4 -4
  36. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/treesitter/_tests/test_treesitter.py +3 -3
  37. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/treesitter/treesitter.py +4 -4
  38. janus_llm-3.3.0/janus/llm/__init__.py +1 -0
  39. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/llm/model_callbacks.py +1 -1
  40. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/llm/models_info.py +5 -3
  41. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/_tests/test_bleu.py +1 -1
  42. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/_tests/test_chrf.py +1 -1
  43. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/_tests/test_file_pairing.py +1 -1
  44. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/_tests/test_llm.py +2 -2
  45. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/_tests/test_reading.py +1 -1
  46. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/_tests/test_rouge_score.py +1 -1
  47. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/_tests/test_similarity_score.py +1 -1
  48. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/_tests/test_treesitter_metrics.py +2 -2
  49. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/bleu.py +1 -1
  50. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/chrf.py +1 -1
  51. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/complexity_metrics.py +4 -4
  52. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/file_pairing.py +5 -5
  53. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/llm_metrics.py +1 -1
  54. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/metric.py +7 -7
  55. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/reading.py +1 -1
  56. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/rouge_score.py +1 -1
  57. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/similarity.py +2 -2
  58. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/parsers/_tests/test_code_parser.py +1 -1
  59. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/parsers/code_parser.py +2 -2
  60. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/parsers/doc_parser.py +3 -3
  61. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/parsers/eval_parser.py +2 -2
  62. janus_llm-3.3.0/janus/parsers/refiner_parser.py +49 -0
  63. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/parsers/reqs_parser.py +3 -3
  64. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/parsers/uml.py +1 -2
  65. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/prompts/prompt.py +2 -2
  66. janus_llm-3.3.0/janus/refiners/refiner.py +63 -0
  67. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/utils/_tests/test_logger.py +1 -1
  68. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/utils/_tests/test_progress.py +1 -1
  69. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/utils/progress.py +1 -1
  70. {janus_llm-3.2.0 → janus_llm-3.3.0}/pyproject.toml +1 -1
  71. janus_llm-3.2.0/janus/converter/__init__.py +0 -6
  72. janus_llm-3.2.0/janus/converter/_tests/test_translate.py +0 -383
  73. janus_llm-3.2.0/janus/converter/diagram.py +0 -126
  74. janus_llm-3.2.0/janus/language/naive/__init__.py +0 -4
  75. janus_llm-3.2.0/janus/llm/__init__.py +0 -1
  76. {janus_llm-3.2.0 → janus_llm-3.3.0}/LICENSE +0 -0
  77. {janus_llm-3.2.0 → janus_llm-3.3.0}/README.md +0 -0
  78. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/__main__.py +0 -0
  79. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/_tests/__init__.py +0 -0
  80. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/_tests/conftest.py +0 -0
  81. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/converter/_tests/__init__.py +0 -0
  82. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/converter/document.py +0 -0
  83. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/converter/evaluate.py +0 -0
  84. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/converter/requirements.py +0 -0
  85. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/converter/translate.py +0 -0
  86. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/embedding/__init__.py +0 -0
  87. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/embedding/_tests/__init__.py +0 -0
  88. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/__init__.py +0 -0
  89. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/_tests/__init__.py +0 -0
  90. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/alc/__init__.py +0 -0
  91. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/alc/_tests/__init__.py +0 -0
  92. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/binary/__init__.py +0 -0
  93. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/binary/_tests/__init__.py +0 -0
  94. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/binary/reveng/decompile_script.py +0 -0
  95. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/mumps/__init__.py +0 -0
  96. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/mumps/_tests/__init__.py +0 -0
  97. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/treesitter/__init__.py +0 -0
  98. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/language/treesitter/_tests/__init__.py +0 -0
  99. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/__init__.py +0 -0
  100. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/_tests/__init__.py +0 -0
  101. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/_tests/reference.py +0 -0
  102. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/_tests/target.py +0 -0
  103. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/cli.py +0 -0
  104. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/metrics/splitting.py +0 -0
  105. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/parsers/__init__.py +0 -0
  106. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/parsers/_tests/__init__.py +0 -0
  107. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/prompts/__init__.py +0 -0
  108. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/utils/__init__.py +0 -0
  109. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/utils/_tests/__init__.py +0 -0
  110. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/utils/enums.py +0 -0
  111. {janus_llm-3.2.0 → janus_llm-3.3.0}/janus/utils/logger.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: janus-llm
3
- Version: 3.2.0
3
+ Version: 3.3.0
4
4
  Summary: A transcoding library using LLMs.
5
5
  Home-page: https://github.com/janus-llm/janus-llm
6
6
  License: Apache 2.0
@@ -2,10 +2,10 @@ import warnings
2
2
 
3
3
  from langchain_core._api.deprecation import LangChainDeprecationWarning
4
4
 
5
- from .converter.translate import Translator
6
- from .metrics import * # noqa: F403
5
+ from janus.converter.translate import Translator
6
+ from janus.metrics import * # noqa: F403
7
7
 
8
- __version__ = "3.2.0"
8
+ __version__ = "3.3.0"
9
9
 
10
10
  # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
11
11
  warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
@@ -4,9 +4,9 @@ from unittest.mock import ANY, patch
4
4
 
5
5
  from typer.testing import CliRunner
6
6
 
7
- from ..cli import app, translate
8
- from ..embedding.embedding_models_info import EMBEDDING_MODEL_CONFIG_DIR
9
- from ..llm.models_info import MODEL_CONFIG_DIR
7
+ from janus.cli import app, translate
8
+ from janus.embedding.embedding_models_info import EMBEDDING_MODEL_CONFIG_DIR
9
+ from janus.llm.models_info import MODEL_CONFIG_DIR
10
10
 
11
11
 
12
12
  class TestCli(unittest.TestCase):
@@ -108,7 +108,7 @@ embedding = typer.Typer(
108
108
 
109
109
  def version_callback(value: bool) -> None:
110
110
  if value:
111
- from . import __version__ as version
111
+ from janus import __version__ as version
112
112
 
113
113
  print(f"Janus CLI [blue]v{version}[/blue]")
114
114
  raise typer.Exit()
@@ -0,0 +1,6 @@
1
+ from janus.converter.converter import Converter
2
+ from janus.converter.diagram import DiagramGenerator
3
+ from janus.converter.document import Documenter, MadLibsDocumenter, MultiDocumenter
4
+ from janus.converter.evaluate import Evaluator
5
+ from janus.converter.requirements import RequirementsDocumenter
6
+ from janus.converter.translate import Translator
@@ -0,0 +1,156 @@
1
+ import unittest
2
+ from pathlib import Path
3
+ from typing import Any, Iterable, List, Optional, Type
4
+
5
+ import pytest
6
+ from langchain.schema import Document
7
+ from langchain.schema.embeddings import Embeddings
8
+ from langchain.schema.vectorstore import VST, VectorStore
9
+
10
+ from janus.converter.diagram import DiagramGenerator
11
+ from janus.converter.requirements import RequirementsDocumenter
12
+ from janus.converter.translate import Translator
13
+ from janus.language.block import CodeBlock, TranslatedCodeBlock
14
+
15
+
16
+ class MockCollection(VectorStore):
17
+ """Vector store for testing"""
18
+
19
+ def __init__(self):
20
+ self._add_texts_calls = 0
21
+
22
+ def add_texts(
23
+ self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, **kwargs: Any
24
+ ) -> List[str]:
25
+ self._add_texts_calls += 1
26
+ return ["id"]
27
+
28
+ def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
29
+ raise NotImplementedError("similarity_search() not implemented!")
30
+
31
+ @classmethod
32
+ def from_texts(
33
+ cls: Type[VST],
34
+ texts: List[str],
35
+ embedding: Embeddings,
36
+ metadatas: Optional[List[dict]] = None,
37
+ **kwargs: Any,
38
+ ) -> VST:
39
+ raise NotImplementedError("from_texts() not implemented!")
40
+
41
+
42
+ class TestTranslator(unittest.TestCase):
43
+ """Tests for the Translator class."""
44
+
45
+ def setUp(self):
46
+ """Set up the tests."""
47
+ self.translator = Translator(
48
+ model="gpt-4o-mini",
49
+ source_language="fortran",
50
+ target_language="python",
51
+ target_version="3.10",
52
+ splitter_type="ast-flex",
53
+ )
54
+ self.test_file = Path("janus/language/treesitter/_tests/languages/fortran.f90")
55
+ self.TEST_FILE_EMBEDDING_COUNT = 14
56
+
57
+ self.req_translator = RequirementsDocumenter(
58
+ model="gpt-4o-mini",
59
+ source_language="fortran",
60
+ prompt_template="requirements",
61
+ )
62
+
63
+ @pytest.mark.translate
64
+ def test_translate(self):
65
+ """Test translate method."""
66
+ # Delete a file if it's already there
67
+ python_file = self.test_file.parent / "python" / f"{self.test_file.stem}.py"
68
+ python_file.unlink(missing_ok=True)
69
+ python_file.parent.rmdir() if python_file.parent.is_dir() else None
70
+ self.translator.translate(self.test_file.parent, self.test_file.parent / "python")
71
+ # Only check the top-most level functionality, since it should be handled by other
72
+ # unit tests anyway
73
+ self.assertTrue(python_file.exists())
74
+
75
+ def test_invalid_selections(self) -> None:
76
+ """Tests that settings values for the translator will raise exceptions"""
77
+ self.assertRaises(
78
+ ValueError, self.translator.set_target_language, "gobbledy", "goobledy"
79
+ )
80
+ self.assertRaises(
81
+ ValueError, self.translator.set_source_language, "scribbledy-doop"
82
+ )
83
+ self.translator.set_prompt("pish posh")
84
+ self.assertRaises(ValueError, self.translator._load_parameters)
85
+
86
+
87
+ class TestDiagramGenerator(unittest.TestCase):
88
+ """Tests for the DiagramGenerator class."""
89
+
90
+ def setUp(self):
91
+ """Set up the tests."""
92
+ self.diagram_generator = DiagramGenerator(
93
+ model="gpt-4o",
94
+ source_language="fortran",
95
+ diagram_type="Activity",
96
+ )
97
+
98
+ def test_init(self):
99
+ """Test __init__ method."""
100
+ self.assertEqual(self.diagram_generator._model_name, "gpt-4o")
101
+ self.assertEqual(self.diagram_generator._source_language, "fortran")
102
+ self.assertEqual(self.diagram_generator._diagram_type, "Activity")
103
+
104
+ def test_add_translation(self):
105
+ """Test _add_translation method."""
106
+ block = TranslatedCodeBlock(
107
+ original=CodeBlock(
108
+ id="test",
109
+ name="Test Block",
110
+ node_type="function",
111
+ language="python",
112
+ text="print('Hello, World!')",
113
+ start_point=(0, 0),
114
+ end_point=(1, 0),
115
+ start_byte=0,
116
+ end_byte=1,
117
+ tokens=5,
118
+ children=[],
119
+ ),
120
+ language="python",
121
+ )
122
+ self.diagram_generator._add_translation(block)
123
+ self.assertTrue(block.translated)
124
+ self.assertIsNotNone(block.text)
125
+ self.assertIsNotNone(block.tokens)
126
+
127
+
128
+ @pytest.mark.parametrize(
129
+ "source_language,prompt_template,expected_target_language,expected_target_version,",
130
+ [
131
+ ("python", "document_inline", "python", "3.10"),
132
+ ("fortran", "document", "text", None),
133
+ ("mumps", "requirements", "text", None),
134
+ ("python", "simple", "javascript", "es6"),
135
+ ],
136
+ )
137
+ def test_language_combinations(
138
+ source_language: str,
139
+ prompt_template: str,
140
+ expected_target_language: str,
141
+ expected_target_version: str,
142
+ ):
143
+ """Tests that translator target language settings are consistent
144
+ with prompt template expectations.
145
+ """
146
+ translator = Translator(model="gpt-4o")
147
+ translator.set_model("gpt-4o")
148
+ translator.set_source_language(source_language)
149
+ translator.set_target_language(expected_target_language, expected_target_version)
150
+ translator.set_prompt(prompt_template)
151
+ translator._load_parameters()
152
+ assert translator._target_language == expected_target_language # nosec
153
+ assert translator._target_version == expected_target_version # nosec
154
+ assert translator._splitter.language == source_language # nosec
155
+ assert translator._splitter.model.model_name == "gpt-4o" # nosec
156
+ assert translator._prompt_template_name == prompt_template # nosec
@@ -6,7 +6,6 @@ from pathlib import Path
6
6
  from typing import Any
7
7
 
8
8
  from langchain.output_parsers import RetryWithErrorOutputParser
9
- from langchain.output_parsers.fix import OutputFixingParser
10
9
  from langchain_core.exceptions import OutputParserException
11
10
  from langchain_core.language_models import BaseLanguageModel
12
11
  from langchain_core.output_parsers import BaseOutputParser
@@ -29,6 +28,8 @@ from janus.llm import load_model
29
28
  from janus.llm.model_callbacks import get_model_callback
30
29
  from janus.llm.models_info import MODEL_PROMPT_ENGINES
31
30
  from janus.parsers.code_parser import GenericParser
31
+ from janus.parsers.refiner_parser import RefinerParser
32
+ from janus.refiners.refiner import BasicRefiner, Refiner
32
33
  from janus.utils.enums import LANGUAGES
33
34
  from janus.utils.logger import create_logger
34
35
 
@@ -75,6 +76,7 @@ class Converter:
75
76
  protected_node_types: tuple[str, ...] = (),
76
77
  prune_node_types: tuple[str, ...] = (),
77
78
  splitter_type: str = "file",
79
+ refiner_type: str = "basic",
78
80
  ) -> None:
79
81
  """Initialize a Converter instance.
80
82
 
@@ -84,6 +86,17 @@ class Converter:
84
86
  values are `"code"`, `"text"`, `"eval"`, and `None` (default). If `None`,
85
87
  the `Converter` assumes you won't be parsing an output (i.e., adding to an
86
88
  embedding DB).
89
+ max_prompts: The maximum number of prompts to try before giving up.
90
+ max_tokens: The maximum number of tokens to use in the LLM. If `None`, the
91
+ converter will use half the model's token limit.
92
+ prompt_template: The name of the prompt template to use.
93
+ db_path: The path to the database to use for vectorization.
94
+ db_config: The configuration for the database.
95
+ protected_node_types: A set of node types that aren't to be merged.
96
+ prune_node_types: A set of node types which should be pruned.
97
+ splitter_type: The type of splitter to use. Valid values are `"file"`,
98
+ `"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
99
+ refiner_type: The type of refiner to use. Valid values are `"basic"`.
87
100
  """
88
101
  self._changed_attrs: set = set()
89
102
 
@@ -116,7 +129,11 @@ class Converter:
116
129
  self._parser: BaseOutputParser = GenericParser()
117
130
  self._combiner: Combiner = Combiner()
118
131
 
132
+ self._refiner_type: str
133
+ self._refiner: Refiner
134
+
119
135
  self.set_splitter(splitter_type=splitter_type)
136
+ self.set_refiner(refiner_type=refiner_type)
120
137
  self.set_model(model_name=model, **model_arguments)
121
138
  self.set_prompt(prompt_template=prompt_template)
122
139
  self.set_source_language(source_language)
@@ -142,6 +159,7 @@ class Converter:
142
159
  self._load_prompt()
143
160
  self._load_splitter()
144
161
  self._load_vectorizer()
162
+ self._load_refiner()
145
163
  self._changed_attrs.clear()
146
164
 
147
165
  def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
@@ -179,6 +197,16 @@ class Converter:
179
197
  """
180
198
  self._splitter_type = splitter_type
181
199
 
200
+ def set_refiner(self, refiner_type: str) -> None:
201
+ """Validate and set the refiner name
202
+
203
+ The affected objects will not be updated until translate is called
204
+
205
+ Arguments:
206
+ refiner_type: the name of the refiner to use
207
+ """
208
+ self._refiner_type = refiner_type
209
+
182
210
  def set_source_language(self, source_language: str) -> None:
183
211
  """Validate and set the source language.
184
212
 
@@ -249,10 +277,24 @@ class Converter:
249
277
  )
250
278
 
251
279
  if self._splitter_type == "tag":
252
- kwargs["tag"] = "<ITMOD_ALC_SPLIT>"
280
+ kwargs["tag"] = "<ITMOD_ALC_SPLIT>" # Hardcoded for now
253
281
 
254
282
  self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)
255
283
 
284
+ @run_if_changed("_refiner_type", "_model_name")
285
+ def _load_refiner(self) -> None:
286
+ """Load the refiner according to this instance's attributes.
287
+
288
+ If the relevant fields have not been changed since the last time this method was
289
+ called, nothing happens.
290
+ """
291
+ if self._refiner_type == "basic":
292
+ self._refiner = BasicRefiner(
293
+ "basic_refinement", self._model_name, self._source_language
294
+ )
295
+ else:
296
+ raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
297
+
256
298
  @run_if_changed("_model_name", "_custom_model_arguments")
257
299
  def _load_model(self) -> None:
258
300
  """Load the model according to this instance's attributes.
@@ -561,22 +603,22 @@ class Converter:
561
603
  # Retries with just the input
562
604
  n3 = math.ceil(self.max_prompts / (n1 * n2))
563
605
 
564
- fix_format = OutputFixingParser.from_llm(
565
- llm=self._llm,
606
+ refine_output = RefinerParser(
566
607
  parser=self._parser,
608
+ initial_prompt=self._prompt.format(**{"SOURCE_CODE": block.original.text}),
609
+ refiner=self._refiner,
567
610
  max_retries=n1,
611
+ llm=self._llm,
568
612
  )
569
613
  retry = RetryWithErrorOutputParser.from_llm(
570
614
  llm=self._llm,
571
- parser=fix_format,
615
+ parser=refine_output,
572
616
  max_retries=n2,
573
617
  )
574
-
575
618
  completion_chain = self._prompt | self._llm
576
619
  chain = RunnableParallel(
577
620
  completion=completion_chain, prompt_value=self._prompt
578
621
  ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
579
-
580
622
  for _ in range(n3):
581
623
  try:
582
624
  return chain.invoke({"SOURCE_CODE": block.original.text})
@@ -0,0 +1,139 @@
1
+ import math
2
+
3
+ from langchain.output_parsers import RetryWithErrorOutputParser
4
+ from langchain_core.exceptions import OutputParserException
5
+ from langchain_core.runnables import RunnableLambda, RunnableParallel
6
+
7
+ from janus.converter.converter import run_if_changed
8
+ from janus.converter.document import Documenter
9
+ from janus.language.block import TranslatedCodeBlock
10
+ from janus.llm.models_info import MODEL_PROMPT_ENGINES
11
+ from janus.parsers.refiner_parser import RefinerParser
12
+ from janus.parsers.uml import UMLSyntaxParser
13
+ from janus.utils.logger import create_logger
14
+
15
+ log = create_logger(__name__)
16
+
17
+
18
+ class DiagramGenerator(Documenter):
19
+ """DiagramGenerator
20
+
21
+ A class that translates code from one programming language to a set of diagrams.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ diagram_type="Activity",
27
+ add_documentation=False,
28
+ **kwargs,
29
+ ) -> None:
30
+ """Initialize the DiagramGenerator class
31
+
32
+ Arguments:
33
+ model: The LLM to use for translation. If an OpenAI model, the
34
+ `OPENAI_API_KEY` environment variable must be set and the
35
+ `OPENAI_ORG_ID` environment variable should be set if needed.
36
+ model_arguments: Additional arguments to pass to the LLM constructor.
37
+ source_language: The source programming language.
38
+ max_prompts: The maximum number of prompts to try before giving up.
39
+ db_path: path to chroma database
40
+ db_config: database configuraiton
41
+ diagram_type: type of PLANTUML diagram to generate
42
+ """
43
+ super().__init__(**kwargs)
44
+ self._diagram_type = diagram_type
45
+ self._add_documentation = add_documentation
46
+ self._documenter = None
47
+ self._diagram_parser = UMLSyntaxParser(language="plantuml")
48
+ if add_documentation:
49
+ self._diagram_prompt_template_name = "diagram_with_documentation"
50
+ else:
51
+ self._diagram_prompt_template_name = "diagram"
52
+ self._load_diagram_prompt_engine()
53
+
54
+ def _run_chain(self, block: TranslatedCodeBlock) -> str:
55
+ self._parser.set_reference(block.original)
56
+ n1 = round(self.max_prompts ** (1 / 3))
57
+
58
+ # Retries with the input, output, and error
59
+ n2 = round((self.max_prompts // n1) ** (1 / 2))
60
+
61
+ # Retries with just the input
62
+ n3 = math.ceil(self.max_prompts / (n1 * n2))
63
+
64
+ if self._add_documentation:
65
+ documentation_text = super()._run_chain(block)
66
+ refine_output = RefinerParser(
67
+ parser=self._diagram_parser,
68
+ initial_prompt=self._diagram_prompt.format(
69
+ **{
70
+ "SOURCE_CODE": block.original.text,
71
+ "DOCUMENTATION": documentation_text,
72
+ "DIAGRAM_TYPE": self._diagram_type,
73
+ }
74
+ ),
75
+ refiner=self._refiner,
76
+ max_retries=n1,
77
+ llm=self._llm,
78
+ )
79
+ else:
80
+ refine_output = RefinerParser(
81
+ parser=self._diagram_parser,
82
+ initial_prompt=self._diagram_prompt.format(
83
+ **{
84
+ "SOURCE_CODE": block.original.text,
85
+ "DIAGRAM_TYPE": self._diagram_type,
86
+ }
87
+ ),
88
+ refiner=self._refiner,
89
+ max_retries=n1,
90
+ llm=self._llm,
91
+ )
92
+ retry = RetryWithErrorOutputParser.from_llm(
93
+ llm=self._llm,
94
+ parser=refine_output,
95
+ max_retries=n2,
96
+ )
97
+ completion_chain = self._prompt | self._llm
98
+ chain = RunnableParallel(
99
+ completion=completion_chain, prompt_value=self._diagram_prompt
100
+ ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
101
+ for _ in range(n3):
102
+ try:
103
+ if self._add_documentation:
104
+ return chain.invoke(
105
+ {
106
+ "SOURCE_CODE": block.original.text,
107
+ "DOCUMENTATION": documentation_text,
108
+ "DIAGRAM_TYPE": self._diagram_type,
109
+ }
110
+ )
111
+ else:
112
+ return chain.invoke(
113
+ {
114
+ "SOURCE_CODE": block.original.text,
115
+ "DIAGRAM_TYPE": self._diagram_type,
116
+ }
117
+ )
118
+ except OutputParserException:
119
+ pass
120
+
121
+ raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
122
+
123
+ @run_if_changed(
124
+ "_diagram_prompt_template_name",
125
+ "_source_language",
126
+ )
127
+ def _load_diagram_prompt_engine(self) -> None:
128
+ """Load the prompt engine according to this instance's attributes.
129
+
130
+ If the relevant fields have not been changed since the last time this method was
131
+ called, nothing happens.
132
+ """
133
+ self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
134
+ source_language=self._source_language,
135
+ target_language="text",
136
+ target_version=None,
137
+ prompt_template=self._diagram_prompt_template_name,
138
+ )
139
+ self._diagram_prompt = self._diagram_prompt_engine.prompt
@@ -4,8 +4,8 @@ from unittest.mock import MagicMock
4
4
 
5
5
  import pytest
6
6
 
7
- from ...utils.enums import EmbeddingType
8
- from ..collections import Collections
7
+ from janus.embedding.collections import Collections
8
+ from janus.utils.enums import EmbeddingType
9
9
 
10
10
 
11
11
  class TestCollections(unittest.TestCase):
@@ -2,7 +2,7 @@ import unittest
2
2
  from pathlib import Path
3
3
  from unittest.mock import patch
4
4
 
5
- from ..database import ChromaEmbeddingDatabase, uri_to_path
5
+ from janus.embedding.database import ChromaEmbeddingDatabase, uri_to_path
6
6
 
7
7
 
8
8
  class TestDatabase(unittest.TestCase):
@@ -5,9 +5,9 @@ from unittest.mock import MagicMock
5
5
 
6
6
  from chromadb.api.client import Client
7
7
 
8
- from ...language.treesitter import TreeSitterSplitter
9
- from ...utils.enums import EmbeddingType
10
- from ..vectorize import Vectorizer, VectorizerFactory
8
+ from janus.embedding.vectorize import Vectorizer, VectorizerFactory
9
+ from janus.language.treesitter import TreeSitterSplitter
10
+ from janus.utils.enums import EmbeddingType
11
11
 
12
12
 
13
13
  class MockDBVectorizer(VectorizerFactory):
@@ -5,8 +5,8 @@ from typing import Dict, Optional, Sequence
5
5
  from chromadb import Client, Collection
6
6
  from langchain_community.vectorstores import Chroma
7
7
 
8
- from ..utils.enums import EmbeddingType
9
- from .embedding_models_info import load_embedding_model
8
+ from janus.embedding.embedding_models_info import load_embedding_model
9
+ from janus.utils.enums import EmbeddingType
10
10
 
11
11
  # See https://docs.trychroma.com/telemetry#in-chromas-backend-using-environment-variables
12
12
  os.environ["ANONYMIZED_TELEMETRY"] = "False"
@@ -5,7 +5,7 @@ from urllib.request import url2pathname
5
5
 
6
6
  import chromadb
7
7
 
8
- from ..utils.logger import create_logger
8
+ from janus.utils.logger import create_logger
9
9
 
10
10
  log = create_logger(__name__)
11
11
 
@@ -8,7 +8,7 @@ from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEm
8
8
  from langchain_core.embeddings import Embeddings
9
9
  from langchain_openai import OpenAIEmbeddings
10
10
 
11
- from ..utils.logger import create_logger
11
+ from janus.utils.logger import create_logger
12
12
 
13
13
  load_dotenv()
14
14
 
@@ -6,10 +6,10 @@ from typing import Any, Dict, Optional, Sequence
6
6
  from chromadb import Client, Collection
7
7
  from langchain_community.vectorstores import Chroma
8
8
 
9
- from ..language.block import CodeBlock, TranslatedCodeBlock
10
- from ..utils.enums import EmbeddingType
11
- from .collections import Collections
12
- from .database import ChromaEmbeddingDatabase
9
+ from janus.embedding.collections import Collections
10
+ from janus.embedding.database import ChromaEmbeddingDatabase
11
+ from janus.language.block import CodeBlock, TranslatedCodeBlock
12
+ from janus.utils.enums import EmbeddingType
13
13
 
14
14
 
15
15
  class Vectorizer(object):
@@ -59,7 +59,7 @@ class Vectorizer(object):
59
59
  self,
60
60
  code_block: CodeBlock,
61
61
  collection_name: EmbeddingType | str,
62
- filename: str # perhaps this should be a relative path from the source, but for
62
+ filename: str, # perhaps this should be a relative path from the source, but for
63
63
  # now we're all in 1 directory
64
64
  ) -> None:
65
65
  """Calculate `code_block` embedding, returning success & storing in `embedding_id`
@@ -1,6 +1,6 @@
1
1
  import unittest
2
2
 
3
- from ..combine import CodeBlock, Combiner, TranslatedCodeBlock
3
+ from janus.language.combine import CodeBlock, Combiner, TranslatedCodeBlock
4
4
 
5
5
 
6
6
  class TestCombiner(unittest.TestCase):
@@ -1,6 +1,6 @@
1
1
  import unittest
2
2
 
3
- from ..splitter import Splitter
3
+ from janus.language.splitter import Splitter
4
4
 
5
5
 
6
6
  class TestSplitter(unittest.TestCase):
@@ -1,9 +1,9 @@
1
1
  import unittest
2
2
  from pathlib import Path
3
3
 
4
- from ....llm import load_model
5
- from ...combine import Combiner
6
- from ..alc import AlcSplitter
4
+ from janus.language.alc import AlcSplitter
5
+ from janus.language.combine import Combiner
6
+ from janus.llm import load_model
7
7
 
8
8
 
9
9
  class TestAlcSplitter(unittest.TestCase):
@@ -1,10 +1,10 @@
1
1
  from langchain.schema.language_model import BaseLanguageModel
2
2
 
3
- from ...utils.logger import create_logger
4
- from ..block import CodeBlock
5
- from ..combine import Combiner
6
- from ..node import NodeType
7
- from ..treesitter import TreeSitterSplitter
3
+ from janus.language.block import CodeBlock
4
+ from janus.language.combine import Combiner
5
+ from janus.language.node import NodeType
6
+ from janus.language.treesitter import TreeSitterSplitter
7
+ from janus.utils.logger import create_logger
8
8
 
9
9
  log = create_logger(__name__)
10
10
 
@@ -5,8 +5,8 @@ from unittest.mock import patch
5
5
 
6
6
  import pytest
7
7
 
8
- from ....llm import load_model
9
- from ..binary import BinarySplitter
8
+ from janus.language.binary import BinarySplitter
9
+ from janus.llm import load_model
10
10
 
11
11
 
12
12
  class TestBinarySplitter(unittest.TestCase):
@@ -7,11 +7,11 @@ from pathlib import Path
7
7
  import tree_sitter
8
8
  from langchain.schema.language_model import BaseLanguageModel
9
9
 
10
- from ...utils.enums import LANGUAGES
11
- from ...utils.logger import create_logger
12
- from ..block import CodeBlock
13
- from ..combine import Combiner
14
- from ..treesitter import TreeSitterSplitter
10
+ from janus.language.block import CodeBlock
11
+ from janus.language.combine import Combiner
12
+ from janus.language.treesitter import TreeSitterSplitter
13
+ from janus.utils.enums import LANGUAGES
14
+ from janus.utils.logger import create_logger
15
15
 
16
16
  log = create_logger(__name__)
17
17
 
@@ -1,8 +1,8 @@
1
1
  from functools import total_ordering
2
2
  from typing import ForwardRef, Hashable, Optional, Tuple
3
3
 
4
- from ..utils.logger import create_logger
5
- from .node import NodeType
4
+ from janus.language.node import NodeType
5
+ from janus.utils.logger import create_logger
6
6
 
7
7
  log = create_logger(__name__)
8
8