janus-llm 3.2.0__py3-none-any.whl → 3.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (74)
  1. janus/__init__.py +3 -3
  2. janus/_tests/test_cli.py +3 -3
  3. janus/cli.py +1 -1
  4. janus/converter/__init__.py +6 -6
  5. janus/converter/_tests/test_translate.py +6 -233
  6. janus/converter/converter.py +49 -7
  7. janus/converter/diagram.py +68 -55
  8. janus/embedding/_tests/test_collections.py +2 -2
  9. janus/embedding/_tests/test_database.py +1 -1
  10. janus/embedding/_tests/test_vectorize.py +3 -3
  11. janus/embedding/collections.py +2 -2
  12. janus/embedding/database.py +1 -1
  13. janus/embedding/embedding_models_info.py +1 -1
  14. janus/embedding/vectorize.py +5 -5
  15. janus/language/_tests/test_combine.py +1 -1
  16. janus/language/_tests/test_splitter.py +1 -1
  17. janus/language/alc/_tests/test_alc.py +3 -3
  18. janus/language/alc/alc.py +5 -5
  19. janus/language/binary/_tests/test_binary.py +2 -2
  20. janus/language/binary/binary.py +5 -5
  21. janus/language/block.py +2 -2
  22. janus/language/combine.py +3 -3
  23. janus/language/file.py +2 -2
  24. janus/language/mumps/_tests/test_mumps.py +3 -3
  25. janus/language/mumps/mumps.py +5 -5
  26. janus/language/mumps/patterns.py +1 -1
  27. janus/language/naive/__init__.py +4 -4
  28. janus/language/naive/basic_splitter.py +4 -4
  29. janus/language/naive/chunk_splitter.py +4 -4
  30. janus/language/naive/registry.py +1 -1
  31. janus/language/naive/simple_ast.py +5 -5
  32. janus/language/naive/tag_splitter.py +4 -4
  33. janus/language/node.py +1 -1
  34. janus/language/splitter.py +4 -4
  35. janus/language/treesitter/_tests/test_treesitter.py +3 -3
  36. janus/language/treesitter/treesitter.py +4 -4
  37. janus/llm/__init__.py +1 -1
  38. janus/llm/model_callbacks.py +1 -1
  39. janus/llm/models_info.py +5 -3
  40. janus/metrics/_tests/test_bleu.py +1 -1
  41. janus/metrics/_tests/test_chrf.py +1 -1
  42. janus/metrics/_tests/test_file_pairing.py +1 -1
  43. janus/metrics/_tests/test_llm.py +2 -2
  44. janus/metrics/_tests/test_reading.py +1 -1
  45. janus/metrics/_tests/test_rouge_score.py +1 -1
  46. janus/metrics/_tests/test_similarity_score.py +1 -1
  47. janus/metrics/_tests/test_treesitter_metrics.py +2 -2
  48. janus/metrics/bleu.py +1 -1
  49. janus/metrics/chrf.py +1 -1
  50. janus/metrics/complexity_metrics.py +4 -4
  51. janus/metrics/file_pairing.py +5 -5
  52. janus/metrics/llm_metrics.py +1 -1
  53. janus/metrics/metric.py +7 -7
  54. janus/metrics/reading.py +1 -1
  55. janus/metrics/rouge_score.py +1 -1
  56. janus/metrics/similarity.py +2 -2
  57. janus/parsers/_tests/test_code_parser.py +1 -1
  58. janus/parsers/code_parser.py +2 -2
  59. janus/parsers/doc_parser.py +3 -3
  60. janus/parsers/eval_parser.py +2 -2
  61. janus/parsers/refiner_parser.py +49 -0
  62. janus/parsers/reqs_parser.py +3 -3
  63. janus/parsers/uml.py +1 -2
  64. janus/prompts/prompt.py +2 -2
  65. janus/refiners/refiner.py +63 -0
  66. janus/utils/_tests/test_logger.py +1 -1
  67. janus/utils/_tests/test_progress.py +1 -1
  68. janus/utils/progress.py +1 -1
  69. {janus_llm-3.2.0.dist-info → janus_llm-3.3.0.dist-info}/METADATA +1 -1
  70. janus_llm-3.3.0.dist-info/RECORD +107 -0
  71. janus_llm-3.2.0.dist-info/RECORD +0 -105
  72. {janus_llm-3.2.0.dist-info → janus_llm-3.3.0.dist-info}/LICENSE +0 -0
  73. {janus_llm-3.2.0.dist-info → janus_llm-3.3.0.dist-info}/WHEEL +0 -0
  74. {janus_llm-3.2.0.dist-info → janus_llm-3.3.0.dist-info}/entry_points.txt +0 -0
janus/__init__.py CHANGED
@@ -2,10 +2,10 @@ import warnings
 
 from langchain_core._api.deprecation import LangChainDeprecationWarning
 
-from .converter.translate import Translator
-from .metrics import *  # noqa: F403
+from janus.converter.translate import Translator
+from janus.metrics import *  # noqa: F403
 
-__version__ = "3.2.0"
+__version__ = "3.3.0"
 
 # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
 warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
janus/_tests/test_cli.py CHANGED
@@ -4,9 +4,9 @@ from unittest.mock import ANY, patch
 
 from typer.testing import CliRunner
 
-from ..cli import app, translate
-from ..embedding.embedding_models_info import EMBEDDING_MODEL_CONFIG_DIR
-from ..llm.models_info import MODEL_CONFIG_DIR
+from janus.cli import app, translate
+from janus.embedding.embedding_models_info import EMBEDDING_MODEL_CONFIG_DIR
+from janus.llm.models_info import MODEL_CONFIG_DIR
 
 
 class TestCli(unittest.TestCase):
janus/cli.py CHANGED
@@ -108,7 +108,7 @@ embedding = typer.Typer(
 
 def version_callback(value: bool) -> None:
     if value:
-        from . import __version__ as version
+        from janus import __version__ as version
 
         print(f"Janus CLI [blue]v{version}[/blue]")
         raise typer.Exit()
janus/converter/__init__.py CHANGED
@@ -1,6 +1,6 @@
-from .converter import Converter
-from .diagram import DiagramGenerator
-from .document import Documenter, MadLibsDocumenter, MultiDocumenter
-from .evaluate import Evaluator
-from .requirements import RequirementsDocumenter
-from .translate import Translator
+from janus.converter.converter import Converter
+from janus.converter.diagram import DiagramGenerator
+from janus.converter.document import Documenter, MadLibsDocumenter, MultiDocumenter
+from janus.converter.evaluate import Evaluator
+from janus.converter.requirements import RequirementsDocumenter
+from janus.converter.translate import Translator
janus/converter/_tests/test_translate.py CHANGED
@@ -7,37 +7,11 @@ from langchain.schema import Document
 from langchain.schema.embeddings import Embeddings
 from langchain.schema.vectorstore import VST, VectorStore
 
+from janus.converter.diagram import DiagramGenerator
+from janus.converter.requirements import RequirementsDocumenter
+from janus.converter.translate import Translator
 from janus.language.block import CodeBlock, TranslatedCodeBlock
 
-from ..diagram import DiagramGenerator
-from ..requirements import RequirementsDocumenter
-from ..translate import Translator
-
-# from langchain.vectorstores import Chroma
-
-
-# from ..utils.enums import EmbeddingType
-
-
-def print_query_results(query, n_results):
-    # print(f"\n{query}")
-    # count = 1
-    # for t in n_results:
-    #     short_code = (
-    #         (t[0].page_content[0:50] + "..")
-    #         if (len(t[0].page_content) > 50)
-    #         else t[0].page_content
-    #     )
-    #     return_index = short_code.find("\n")
-    #     if -1 != return_index:
-    #         short_code = short_code[0:return_index] + ".."
-    #     print(
-    #         f"{count}. @ {t[0].metadata['start_line']}-{t[0].metadata['end_line']}"
-    #         f" -- {t[1]} -- {short_code}"
-    #     )
-    #     count += 1
-    pass
-
 
 class MockCollection(VectorStore):
     """Vector store for testing"""
@@ -65,30 +39,23 @@ class MockCollection(VectorStore):
         raise NotImplementedError("from_texts() not implemented!")
 
 
-# class MockEmbeddingsFactory(EmbeddingsFactory):
-#     """Embeddings for testing - uses MockCollection"""
-#
-#     def get_embeddings(self) -> Embeddings:
-#         return MockCollection()
-#
-
-
 class TestTranslator(unittest.TestCase):
     """Tests for the Translator class."""
 
     def setUp(self):
         """Set up the tests."""
         self.translator = Translator(
-            model="gpt-4o",
+            model="gpt-4o-mini",
             source_language="fortran",
             target_language="python",
             target_version="3.10",
+            splitter_type="ast-flex",
         )
         self.test_file = Path("janus/language/treesitter/_tests/languages/fortran.f90")
        self.TEST_FILE_EMBEDDING_COUNT = 14
 
        self.req_translator = RequirementsDocumenter(
-            model="gpt-4o",
+            model="gpt-4o-mini",
             source_language="fortran",
             prompt_template="requirements",
         )
@@ -105,200 +72,6 @@ class TestTranslator(unittest.TestCase):
         # unit tests anyway
         self.assertTrue(python_file.exists())
 
-    # def test_embeddings(self):
-    #     """Testing access to embeddings"""
-    #     vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
-    #     self.assertIsInstance(vector_store, Chroma, "Unexpected vector store type!")
-    #     self.assertEqual(
-    #         0, vector_store._collection.count(), "Non-empty initial vector store?"
-    #     )
-    #
-    #     self.translator.set_model("llama")
-    #     self.translator._load_parameters()
-    #     vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
-    #     self.assertIsInstance(vector_store, Chroma)
-    #     self.assertEqual(
-    #         0, vector_store._collection.count(), "Non-empty initial vector store?"
-    #     )
-
-    # def test_embed_split_source(self):
-    #     """Characterize _embed method"""
-    #     mock_embeddings = MockEmbeddingsFactory()
-    #     self.translator.set_embeddings(mock_embeddings)
-    #     self.translator._load_parameters()
-    #     input_block = self.translator.splitter.split(self.test_file)
-    #     self.assertIsNone(
-    #         input_block.text, "Root node of input text shouldn't contain text"
-    #     )
-    #     self.assertIsNone(input_block.embedding_id, "Precondition failed")
-    #
-    #     result = self.translator._embed(
-    #         input_block, EmbeddingType.SOURCE, self.test_file.name
-    #     )
-    #
-    #     self.assertFalse(result, "Nothing to embed, so should have no result")
-    #     self.assertIsNone(
-    #         input_block.embedding_id, "Embeddings should not have changed")
-
-    # def test_embed_has_values_for_each_non_empty_node(self):
-    #     """Characterize our sample fortran file"""
-    #     mock_embeddings = MockEmbeddingsFactory()
-    #     self.translator.set_embeddings(mock_embeddings)
-    #     self.translator._load_parameters()
-    #     input_block = self.translator.splitter.split(self.test_file)
-    #     self.translator._embed_nodes_recursively(
-    #         input_block, EmbeddingType.SOURCE, self.test_file.name
-    #     )
-    #     has_text_count = 0
-    #     has_embeddings_count = 0
-    #     nodes = [input_block]
-    #     while nodes:
-    #         node = nodes.pop(0)
-    #         if node.text:
-    #             has_text_count += 1
-    #         if node.embedding_id:
-    #             has_embeddings_count += 1
-    #         nodes.extend(node.children)
-    #     self.assertEqual(
-    #         self.TEST_FILE_EMBEDDING_COUNT,
-    #         has_text_count,
-    #         "Parsing of test_file has changed!",
-    #     )
-    #     self.assertEqual(
-    #         self.TEST_FILE_EMBEDDING_COUNT,
-    #         has_embeddings_count,
-    #         "Not all non-empty nodes have embeddings!",
-    #     )
-
-    # def test_embed_nodes_recursively(self):
-    #     mock_embeddings = MockEmbeddingsFactory()
-    #     self.translator.set_embeddings(mock_embeddings)
-    #     self.translator._load_parameters()
-    #     input_block = self.translator.splitter.split(self.test_file)
-    #     self.translator._embed_nodes_recursively(
-    #         input_block, EmbeddingType.SOURCE, self.test_file.name
-    #     )
-    #     nodes = [input_block]
-    #     while nodes:
-    #         node = nodes.pop(0)
-    #         self.assertEqual(node.text is not None, node.embedding_id is not None)
-    #         nodes.extend(node.children)
-
-    # @pytest.mark.slow
-    # def test_translate_file_adds_source_embeddings(self):
-    #     mock_embeddings = MockEmbeddingsFactory()
-    #     self.translator.set_embeddings(mock_embeddings)
-    #     self.translator._load_parameters()
-    #     vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
-    #     self.assertEqual(0, vector_store._add_texts_calls, "precondition")
-    #
-    #     self.translator.translate_file(self.test_file)
-    #
-    #     self.assertEqual(
-    #         self.TEST_FILE_EMBEDDING_COUNT,
-    #         vector_store._add_texts_calls,
-    #         "Did not find expected source embeddings",
-    #     )
-
-    # @pytest.mark.slow
-    # def test_embeddings_usage(self):
-    #     """Noodling on use of embeddings
-    #     To see results have to uncomment print_query_results() above
-    #     """
-    #     input_block = self.translator.splitter.split(self.test_file)
-    #     self.translator._embed_nodes_recursively(
-    #         input_block, EmbeddingType.SOURCE, self.test_file.name
-    #     )
-    #     vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
-    #
-    #     # this symbol has the lowest relevance scores of any in this test, but
-    #     # still not very low; multiple embedded nodes contain it
-    #     QUERY_STRING = "IWX_BAND_START"
-    #     query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
-    #     n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-    #         embedding=query,
-    #         k=10,
-    #         where_document={"$contains": QUERY_STRING},
-    #     )
-    #     self.assertTrue(len(n_results) > 1, "Why was valid symbol not found?")
-    #     print_query_results(QUERY_STRING, n_results)
-
-    # in the XYZZY test, the least dissimilar results were the start and finish lines
-    # 0, and 415, which produced a similarity score of 0.47:
-
-    # QUERY_STRING = "XYZZY"
-    # query = self.translator._embeddings.embed_query(QUERY_STRING)
-    # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-    #     embedding=query,
-    #     k=10,
-    #     # filter={"end_line": 15},
-    #     # filter={"$and": [{"end_line": 15}, {"tokens": {"$gte": 21}}]},
-    #     # where_document={"$contains": QUERY_STRING},
-    # )
-    # print_query_results(QUERY_STRING, n_results)
-    # # self.assertTrue(len(n_results) == 0, "Invalid symbol was found?")

-    # # only returns a single result because only 1 embedded node contains
-    # # CSV_ICASEARR:
-    # QUERY_STRING = "What is the use of CSV_ICASEARR?"
-    # query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
-    # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-    #     embedding=query,
-    #     k=10,
-    #     # where_document={"$contains": QUERY_STRING},
-    #     where_document={"$contains": "CSV_ICASEARR"},
-    # )
-    # print_query_results(QUERY_STRING, n_results)
-    # self.assertTrue(len(n_results) == 1, "Was splitting changed?")
-    #
-    # # trimmed out some characters from line 43, and still not very similar scoring
-    # QUERY_STRING = "IYL_EDGEBUFFER EDGEBUFFER IGN_MASK CELLSIZE"
-    # query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
-    # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-    #     embedding=query,
-    #     k=10,
-    #     # where_document={"$contains": QUERY_STRING},
-    # )
-    # print_query_results(QUERY_STRING, n_results)
-    #
-    # # random string (as bad as XYZZY), but searching for a specific line
-    # QUERY_STRING = "ghost in the invisible moon"
-    # query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
-    # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-    #     embedding=query,
-    #     k=10,
-    #     filter={"$and": [{"end_line": 90}, {"tokens": {"$gte": 21}}]},
-    # )
-    # print_query_results(QUERY_STRING, n_results)
-    # self.assertTrue(len(n_results) == 1, "Was splitting changed?")
-
-    # @pytest.mark.slow
-    # def test_document_embeddings_added_by_translate(self):
-    #     vector_store = self.req_translator.embeddings(EmbeddingType.REQUIREMENT)
-    #     self.assertEqual(0, vector_store._add_texts_calls, "Precondition failed")
-    #     self.req_translator.translate(self.test_file.parent, self.test_file.parent,
-    #                                   True)
-    #     self.assertTrue(vector_store._add_texts_calls > 0, "Why no documentation?")
-
-    # @pytest.mark.slow
-    # def test_embed_requirements(self):
-    #     vector_store = self.req_translator.embeddings(EmbeddingType.REQUIREMENT)
-    #     translated = self.req_translator.translate_file(self.test_file)
-    #     self.assertEqual(
-    #         0,
-    #         vector_store._add_texts_calls,
-    #         "Unexpected requirements added in translate_file",
-    #     )
-    #     result = self.req_translator._embed(
-    #         translated, EmbeddingType.REQUIREMENT, self.test_file.name
-    #     )
-    #     self.assertFalse(result, "No text in root node, so should generate no docs")
-    #     self.assertIsNotNone(translated.children[0].text, "Data changed?")
-    #     result = self.req_translator._embed(
-    #         translated.children[0], EmbeddingType.REQUIREMENT, self.test_file.name
-    #     )
-    #     self.assertTrue(result, "No docs generated for first child node?")
-
     def test_invalid_selections(self) -> None:
         """Tests that settings values for the translator will raise exceptions"""
         self.assertRaises(
janus/converter/converter.py CHANGED
@@ -6,7 +6,6 @@ from pathlib import Path
 from typing import Any
 
 from langchain.output_parsers import RetryWithErrorOutputParser
-from langchain.output_parsers.fix import OutputFixingParser
 from langchain_core.exceptions import OutputParserException
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.output_parsers import BaseOutputParser
@@ -29,6 +28,8 @@ from janus.llm import load_model
 from janus.llm.model_callbacks import get_model_callback
 from janus.llm.models_info import MODEL_PROMPT_ENGINES
 from janus.parsers.code_parser import GenericParser
+from janus.parsers.refiner_parser import RefinerParser
+from janus.refiners.refiner import BasicRefiner, Refiner
 from janus.utils.enums import LANGUAGES
 from janus.utils.logger import create_logger
 
@@ -75,6 +76,7 @@ class Converter:
         protected_node_types: tuple[str, ...] = (),
         prune_node_types: tuple[str, ...] = (),
         splitter_type: str = "file",
+        refiner_type: str = "basic",
     ) -> None:
         """Initialize a Converter instance.
 
@@ -84,6 +86,17 @@
                 values are `"code"`, `"text"`, `"eval"`, and `None` (default). If `None`,
                 the `Converter` assumes you won't be parsing an output (i.e., adding to an
                 embedding DB).
+            max_prompts: The maximum number of prompts to try before giving up.
+            max_tokens: The maximum number of tokens to use in the LLM. If `None`, the
+                converter will use half the model's token limit.
+            prompt_template: The name of the prompt template to use.
+            db_path: The path to the database to use for vectorization.
+            db_config: The configuration for the database.
+            protected_node_types: A set of node types that aren't to be merged.
+            prune_node_types: A set of node types which should be pruned.
+            splitter_type: The type of splitter to use. Valid values are `"file"`,
+                `"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
+            refiner_type: The type of refiner to use. Valid values are `"basic"`.
         """
         self._changed_attrs: set = set()
 
@@ -116,7 +129,11 @@
         self._parser: BaseOutputParser = GenericParser()
         self._combiner: Combiner = Combiner()
 
+        self._refiner_type: str
+        self._refiner: Refiner
+
         self.set_splitter(splitter_type=splitter_type)
+        self.set_refiner(refiner_type=refiner_type)
         self.set_model(model_name=model, **model_arguments)
         self.set_prompt(prompt_template=prompt_template)
         self.set_source_language(source_language)
@@ -142,6 +159,7 @@
         self._load_prompt()
         self._load_splitter()
         self._load_vectorizer()
+        self._load_refiner()
         self._changed_attrs.clear()
 
     def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
@@ -179,6 +197,16 @@
         """
         self._splitter_type = splitter_type
 
+    def set_refiner(self, refiner_type: str) -> None:
+        """Validate and set the refiner name
+
+        The affected objects will not be updated until translate is called
+
+        Arguments:
+            refiner_type: the name of the refiner to use
+        """
+        self._refiner_type = refiner_type
+
     def set_source_language(self, source_language: str) -> None:
         """Validate and set the source language.
 
@@ -249,10 +277,24 @@
             )
 
         if self._splitter_type == "tag":
-            kwargs["tag"] = "<ITMOD_ALC_SPLIT>"
+            kwargs["tag"] = "<ITMOD_ALC_SPLIT>"  # Hardcoded for now
 
         self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)
 
+    @run_if_changed("_refiner_type", "_model_name")
+    def _load_refiner(self) -> None:
+        """Load the refiner according to this instance's attributes.
+
+        If the relevant fields have not been changed since the last time this method was
+        called, nothing happens.
+        """
+        if self._refiner_type == "basic":
+            self._refiner = BasicRefiner(
+                "basic_refinement", self._model_name, self._source_language
+            )
+        else:
+            raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
+
     @run_if_changed("_model_name", "_custom_model_arguments")
     def _load_model(self) -> None:
         """Load the model according to this instance's attributes.
@@ -561,22 +603,22 @@
         # Retries with just the input
         n3 = math.ceil(self.max_prompts / (n1 * n2))
 
-        fix_format = OutputFixingParser.from_llm(
-            llm=self._llm,
+        refine_output = RefinerParser(
             parser=self._parser,
+            initial_prompt=self._prompt.format(**{"SOURCE_CODE": block.original.text}),
+            refiner=self._refiner,
             max_retries=n1,
+            llm=self._llm,
         )
         retry = RetryWithErrorOutputParser.from_llm(
             llm=self._llm,
-            parser=fix_format,
+            parser=refine_output,
             max_retries=n2,
         )
-
         completion_chain = self._prompt | self._llm
         chain = RunnableParallel(
             completion=completion_chain, prompt_value=self._prompt
         ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
-
         for _ in range(n3):
             try:
                 return chain.invoke({"SOURCE_CODE": block.original.text})
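For reference, this retry budget factors `self.max_prompts` into three nested loop counts so that n1 · n2 · n3 ≈ max_prompts: n1 format-refinement retries inside `RefinerParser`, n2 retries with input, output, and error in `RetryWithErrorOutputParser`, and n3 outer retries with just the input. A quick check of the arithmetic with a hypothetical `max_prompts = 10`:

    import math

    max_prompts = 10
    n1 = round(max_prompts ** (1 / 3))          # 10 ** (1/3) ~= 2.15 -> 2
    n2 = round((max_prompts // n1) ** (1 / 2))  # (10 // 2) ** 0.5 ~= 2.24 -> 2
    n3 = math.ceil(max_prompts / (n1 * n2))     # ceil(10 / 4) = 3
    print(n1, n2, n3)                           # 2 2 3 -> up to 12 attempts total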
janus/converter/diagram.py CHANGED
@@ -1,10 +1,14 @@
-import json
-from copy import deepcopy
+import math
+
+from langchain.output_parsers import RetryWithErrorOutputParser
+from langchain_core.exceptions import OutputParserException
+from langchain_core.runnables import RunnableLambda, RunnableParallel
 
 from janus.converter.converter import run_if_changed
 from janus.converter.document import Documenter
 from janus.language.block import TranslatedCodeBlock
 from janus.llm.models_info import MODEL_PROMPT_ENGINES
+from janus.parsers.refiner_parser import RefinerParser
 from janus.parsers.uml import UMLSyntaxParser
 from janus.utils.logger import create_logger
 
@@ -47,65 +51,74 @@ class DiagramGenerator(Documenter):
         self._diagram_prompt_template_name = "diagram"
         self._load_diagram_prompt_engine()
 
-    def _add_translation(self, block: TranslatedCodeBlock) -> None:
-        """Given an "empty" `TranslatedCodeBlock`, translate the code represented in
-        `block.original`, setting the relevant fields in the translated block. The
-        `TranslatedCodeBlock` is updated in-pace, nothing is returned. Note that this
-        translates *only* the code for this block, not its children.
-
-        Arguments:
-            block: An empty `TranslatedCodeBlock`
-        """
-        if block.translated:
-            return
-
-        if block.original.text is None:
-            block.translated = True
-            return
-
-        if self._add_documentation:
-            documentation_block = deepcopy(block)
-            super()._add_translation(documentation_block)
-            if not documentation_block.translated:
-                message = "Error: unable to produce documentation for code block"
-                log.info(message)
-                raise ValueError(message)
-            documentation = json.loads(documentation_block.text)["docstring"]
-
-        if self._llm is None:
-            message = (
-                "Model not configured correctly, cannot translate. Try setting "
-                "the model"
-            )
-            log.error(message)
-            raise ValueError(message)
-
-        log.debug(f"[{block.name}] Translating...")
-        log.debug(f"[{block.name}] Input text:\n{block.original.text}")
-
+    def _run_chain(self, block: TranslatedCodeBlock) -> str:
         self._parser.set_reference(block.original)
+        n1 = round(self.max_prompts ** (1 / 3))
 
-        query_and_parse = self.diagram_prompt | self._llm | self._diagram_parser
+        # Retries with the input, output, and error
+        n2 = round((self.max_prompts // n1) ** (1 / 2))
+
+        # Retries with just the input
+        n3 = math.ceil(self.max_prompts / (n1 * n2))
 
         if self._add_documentation:
-            block.text = query_and_parse.invoke(
-                {
-                    "SOURCE_CODE": block.original.text,
-                    "DIAGRAM_TYPE": self._diagram_type,
-                    "DOCUMENTATION": documentation,
-                }
+            documentation_text = super()._run_chain(block)
+            refine_output = RefinerParser(
+                parser=self._diagram_parser,
+                initial_prompt=self._diagram_prompt.format(
+                    **{
+                        "SOURCE_CODE": block.original.text,
+                        "DOCUMENTATION": documentation_text,
+                        "DIAGRAM_TYPE": self._diagram_type,
+                    }
+                ),
+                refiner=self._refiner,
+                max_retries=n1,
+                llm=self._llm,
             )
         else:
-            block.text = query_and_parse.invoke(
-                {
-                    "SOURCE_CODE": block.original.text,
-                    "DIAGRAM_TYPE": self._diagram_type,
-                }
+            refine_output = RefinerParser(
+                parser=self._diagram_parser,
+                initial_prompt=self._diagram_prompt.format(
+                    **{
+                        "SOURCE_CODE": block.original.text,
+                        "DIAGRAM_TYPE": self._diagram_type,
+                    }
+                ),
+                refiner=self._refiner,
+                max_retries=n1,
+                llm=self._llm,
             )
-        block.tokens = self._llm.get_num_tokens(block.text)
-        block.translated = True
-
-        log.debug(f"[{block.name}] Output code:\n{block.text}")
+        retry = RetryWithErrorOutputParser.from_llm(
+            llm=self._llm,
+            parser=refine_output,
+            max_retries=n2,
+        )
+        completion_chain = self._prompt | self._llm
+        chain = RunnableParallel(
+            completion=completion_chain, prompt_value=self._diagram_prompt
+        ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
+        for _ in range(n3):
+            try:
+                if self._add_documentation:
+                    return chain.invoke(
+                        {
+                            "SOURCE_CODE": block.original.text,
+                            "DOCUMENTATION": documentation_text,
+                            "DIAGRAM_TYPE": self._diagram_type,
+                        }
+                    )
+                else:
+                    return chain.invoke(
+                        {
+                            "SOURCE_CODE": block.original.text,
+                            "DIAGRAM_TYPE": self._diagram_type,
+                        }
+                    )
+            except OutputParserException:
+                pass
+
+        raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
 
     @run_if_changed(
         "_diagram_prompt_template_name",
@@ -123,4 +136,4 @@ class DiagramGenerator(Documenter):
             target_version=None,
             prompt_template=self._diagram_prompt_template_name,
         )
-        self.diagram_prompt = self._diagram_prompt_engine.prompt
+        self._diagram_prompt = self._diagram_prompt_engine.prompt
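Both `_run_chain` implementations above depend on the new janus/refiners/refiner.py (+63) and janus/parsers/refiner_parser.py (+49), whose contents are not shown in this diff. From the call sites, a `Refiner` turns a failed prompt/completion/error triple into a revised prompt, and `RefinerParser` wraps another parser, re-querying the LLM with the refined prompt on each parse failure. A speculative sketch consistent with those call sites; everything beyond the constructor arguments seen above is an assumption, not the released implementation:

    from langchain_core.exceptions import OutputParserException

    class Refiner:
        # Interface guess: produce a revised prompt from a failed attempt
        def refine(self, prompt: str, completion: str, error: str) -> str:
            raise NotImplementedError

    class RefinerParser:
        # Sketch of the wrapper wired up in converter.py and diagram.py
        def __init__(self, parser, initial_prompt, refiner, llm, max_retries):
            self.parser = parser
            self.initial_prompt = initial_prompt
            self.refiner = refiner
            self.llm = llm
            self.max_retries = max_retries

        def parse(self, text: str) -> str:
            prompt = self.initial_prompt
            for _ in range(self.max_retries):
                try:
                    return self.parser.parse(text)
                except OutputParserException as err:
                    # Let the refiner rewrite the prompt, then re-query the LLM
                    prompt = self.refiner.refine(prompt, text, str(err))
                    text = self.llm.invoke(prompt).content
            raise OutputParserException(
                f"Failed to parse after {self.max_retries} retries"
            )

The `BasicRefiner("basic_refinement", self._model_name, self._source_language)` call in converter.py suggests the basic refiner is driven by a named prompt template ("basic_refinement") rendered for the configured model and source language, but that is likewise an inference from the constructor arguments.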
janus/embedding/_tests/test_collections.py CHANGED
@@ -4,8 +4,8 @@ from unittest.mock import MagicMock
 
 import pytest
 
-from ...utils.enums import EmbeddingType
-from ..collections import Collections
+from janus.embedding.collections import Collections
+from janus.utils.enums import EmbeddingType
 
 
 class TestCollections(unittest.TestCase):
janus/embedding/_tests/test_database.py CHANGED
@@ -2,7 +2,7 @@ import unittest
 from pathlib import Path
 from unittest.mock import patch
 
-from ..database import ChromaEmbeddingDatabase, uri_to_path
+from janus.embedding.database import ChromaEmbeddingDatabase, uri_to_path
 
 
 class TestDatabase(unittest.TestCase):
janus/embedding/_tests/test_vectorize.py CHANGED
@@ -5,9 +5,9 @@ from unittest.mock import MagicMock
 
 from chromadb.api.client import Client
 
-from ...language.treesitter import TreeSitterSplitter
-from ...utils.enums import EmbeddingType
-from ..vectorize import Vectorizer, VectorizerFactory
+from janus.embedding.vectorize import Vectorizer, VectorizerFactory
+from janus.language.treesitter import TreeSitterSplitter
+from janus.utils.enums import EmbeddingType
 
 
 class MockDBVectorizer(VectorizerFactory):
janus/embedding/collections.py CHANGED
@@ -5,8 +5,8 @@ from typing import Dict, Optional, Sequence
 from chromadb import Client, Collection
 from langchain_community.vectorstores import Chroma
 
-from ..utils.enums import EmbeddingType
-from .embedding_models_info import load_embedding_model
+from janus.embedding.embedding_models_info import load_embedding_model
+from janus.utils.enums import EmbeddingType
 
 # See https://docs.trychroma.com/telemetry#in-chromas-backend-using-environment-variables
 os.environ["ANONYMIZED_TELEMETRY"] = "False"
janus/embedding/database.py CHANGED
@@ -5,7 +5,7 @@ from urllib.request import url2pathname
 
 import chromadb
 
-from ..utils.logger import create_logger
+from janus.utils.logger import create_logger
 
 log = create_logger(__name__)
 
janus/embedding/embedding_models_info.py CHANGED
@@ -8,7 +8,7 @@ from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEm
 from langchain_core.embeddings import Embeddings
 from langchain_openai import OpenAIEmbeddings
 
-from ..utils.logger import create_logger
+from janus.utils.logger import create_logger
 
 load_dotenv()
 