janus-llm 3.2.0__py3-none-any.whl → 3.3.0__py3-none-any.whl

Files changed (74)
  1. janus/__init__.py +3 -3
  2. janus/_tests/test_cli.py +3 -3
  3. janus/cli.py +1 -1
  4. janus/converter/__init__.py +6 -6
  5. janus/converter/_tests/test_translate.py +6 -233
  6. janus/converter/converter.py +49 -7
  7. janus/converter/diagram.py +68 -55
  8. janus/embedding/_tests/test_collections.py +2 -2
  9. janus/embedding/_tests/test_database.py +1 -1
  10. janus/embedding/_tests/test_vectorize.py +3 -3
  11. janus/embedding/collections.py +2 -2
  12. janus/embedding/database.py +1 -1
  13. janus/embedding/embedding_models_info.py +1 -1
  14. janus/embedding/vectorize.py +5 -5
  15. janus/language/_tests/test_combine.py +1 -1
  16. janus/language/_tests/test_splitter.py +1 -1
  17. janus/language/alc/_tests/test_alc.py +3 -3
  18. janus/language/alc/alc.py +5 -5
  19. janus/language/binary/_tests/test_binary.py +2 -2
  20. janus/language/binary/binary.py +5 -5
  21. janus/language/block.py +2 -2
  22. janus/language/combine.py +3 -3
  23. janus/language/file.py +2 -2
  24. janus/language/mumps/_tests/test_mumps.py +3 -3
  25. janus/language/mumps/mumps.py +5 -5
  26. janus/language/mumps/patterns.py +1 -1
  27. janus/language/naive/__init__.py +4 -4
  28. janus/language/naive/basic_splitter.py +4 -4
  29. janus/language/naive/chunk_splitter.py +4 -4
  30. janus/language/naive/registry.py +1 -1
  31. janus/language/naive/simple_ast.py +5 -5
  32. janus/language/naive/tag_splitter.py +4 -4
  33. janus/language/node.py +1 -1
  34. janus/language/splitter.py +4 -4
  35. janus/language/treesitter/_tests/test_treesitter.py +3 -3
  36. janus/language/treesitter/treesitter.py +4 -4
  37. janus/llm/__init__.py +1 -1
  38. janus/llm/model_callbacks.py +1 -1
  39. janus/llm/models_info.py +5 -3
  40. janus/metrics/_tests/test_bleu.py +1 -1
  41. janus/metrics/_tests/test_chrf.py +1 -1
  42. janus/metrics/_tests/test_file_pairing.py +1 -1
  43. janus/metrics/_tests/test_llm.py +2 -2
  44. janus/metrics/_tests/test_reading.py +1 -1
  45. janus/metrics/_tests/test_rouge_score.py +1 -1
  46. janus/metrics/_tests/test_similarity_score.py +1 -1
  47. janus/metrics/_tests/test_treesitter_metrics.py +2 -2
  48. janus/metrics/bleu.py +1 -1
  49. janus/metrics/chrf.py +1 -1
  50. janus/metrics/complexity_metrics.py +4 -4
  51. janus/metrics/file_pairing.py +5 -5
  52. janus/metrics/llm_metrics.py +1 -1
  53. janus/metrics/metric.py +7 -7
  54. janus/metrics/reading.py +1 -1
  55. janus/metrics/rouge_score.py +1 -1
  56. janus/metrics/similarity.py +2 -2
  57. janus/parsers/_tests/test_code_parser.py +1 -1
  58. janus/parsers/code_parser.py +2 -2
  59. janus/parsers/doc_parser.py +3 -3
  60. janus/parsers/eval_parser.py +2 -2
  61. janus/parsers/refiner_parser.py +49 -0
  62. janus/parsers/reqs_parser.py +3 -3
  63. janus/parsers/uml.py +1 -2
  64. janus/prompts/prompt.py +2 -2
  65. janus/refiners/refiner.py +63 -0
  66. janus/utils/_tests/test_logger.py +1 -1
  67. janus/utils/_tests/test_progress.py +1 -1
  68. janus/utils/progress.py +1 -1
  69. {janus_llm-3.2.0.dist-info → janus_llm-3.3.0.dist-info}/METADATA +1 -1
  70. janus_llm-3.3.0.dist-info/RECORD +107 -0
  71. janus_llm-3.2.0.dist-info/RECORD +0 -105
  72. {janus_llm-3.2.0.dist-info → janus_llm-3.3.0.dist-info}/LICENSE +0 -0
  73. {janus_llm-3.2.0.dist-info → janus_llm-3.3.0.dist-info}/WHEEL +0 -0
  74. {janus_llm-3.2.0.dist-info → janus_llm-3.3.0.dist-info}/entry_points.txt +0 -0
janus/__init__.py CHANGED
@@ -2,10 +2,10 @@ import warnings
 
 from langchain_core._api.deprecation import LangChainDeprecationWarning
 
-from .converter.translate import Translator
-from .metrics import *  # noqa: F403
+from janus.converter.translate import Translator
+from janus.metrics import *  # noqa: F403
 
-__version__ = "3.2.0"
+__version__ = "3.3.0"
 
 # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
 warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
janus/_tests/test_cli.py CHANGED
@@ -4,9 +4,9 @@ from unittest.mock import ANY, patch
 
 from typer.testing import CliRunner
 
-from ..cli import app, translate
-from ..embedding.embedding_models_info import EMBEDDING_MODEL_CONFIG_DIR
-from ..llm.models_info import MODEL_CONFIG_DIR
+from janus.cli import app, translate
+from janus.embedding.embedding_models_info import EMBEDDING_MODEL_CONFIG_DIR
+from janus.llm.models_info import MODEL_CONFIG_DIR
 
 
 class TestCli(unittest.TestCase):
janus/cli.py CHANGED
@@ -108,7 +108,7 @@ embedding = typer.Typer(
 
 def version_callback(value: bool) -> None:
     if value:
-        from . import __version__ as version
+        from janus import __version__ as version
 
         print(f"Janus CLI [blue]v{version}[/blue]")
         raise typer.Exit()
janus/converter/__init__.py CHANGED
@@ -1,6 +1,6 @@
-from .converter import Converter
-from .diagram import DiagramGenerator
-from .document import Documenter, MadLibsDocumenter, MultiDocumenter
-from .evaluate import Evaluator
-from .requirements import RequirementsDocumenter
-from .translate import Translator
+from janus.converter.converter import Converter
+from janus.converter.diagram import DiagramGenerator
+from janus.converter.document import Documenter, MadLibsDocumenter, MultiDocumenter
+from janus.converter.evaluate import Evaluator
+from janus.converter.requirements import RequirementsDocumenter
+from janus.converter.translate import Translator
janus/converter/_tests/test_translate.py CHANGED
@@ -7,37 +7,11 @@ from langchain.schema import Document
 from langchain.schema.embeddings import Embeddings
 from langchain.schema.vectorstore import VST, VectorStore
 
+from janus.converter.diagram import DiagramGenerator
+from janus.converter.requirements import RequirementsDocumenter
+from janus.converter.translate import Translator
 from janus.language.block import CodeBlock, TranslatedCodeBlock
 
-from ..diagram import DiagramGenerator
-from ..requirements import RequirementsDocumenter
-from ..translate import Translator
-
-# from langchain.vectorstores import Chroma
-
-
-# from ..utils.enums import EmbeddingType
-
-
-def print_query_results(query, n_results):
-    # print(f"\n{query}")
-    # count = 1
-    # for t in n_results:
-    #     short_code = (
-    #         (t[0].page_content[0:50] + "..")
-    #         if (len(t[0].page_content) > 50)
-    #         else t[0].page_content
-    #     )
-    #     return_index = short_code.find("\n")
-    #     if -1 != return_index:
-    #         short_code = short_code[0:return_index] + ".."
-    #     print(
-    #         f"{count}. @ {t[0].metadata['start_line']}-{t[0].metadata['end_line']}"
-    #         f" -- {t[1]} -- {short_code}"
-    #     )
-    #     count += 1
-    pass
-
 
 class MockCollection(VectorStore):
     """Vector store for testing"""
@@ -65,30 +39,23 @@ class MockCollection(VectorStore):
         raise NotImplementedError("from_texts() not implemented!")
 
 
-# class MockEmbeddingsFactory(EmbeddingsFactory):
-#     """Embeddings for testing - uses MockCollection"""
-#
-#     def get_embeddings(self) -> Embeddings:
-#         return MockCollection()
-#
-
-
 class TestTranslator(unittest.TestCase):
     """Tests for the Translator class."""
 
     def setUp(self):
         """Set up the tests."""
         self.translator = Translator(
-            model="gpt-4o",
+            model="gpt-4o-mini",
             source_language="fortran",
             target_language="python",
             target_version="3.10",
+            splitter_type="ast-flex",
         )
         self.test_file = Path("janus/language/treesitter/_tests/languages/fortran.f90")
         self.TEST_FILE_EMBEDDING_COUNT = 14
 
         self.req_translator = RequirementsDocumenter(
-            model="gpt-4o",
+            model="gpt-4o-mini",
             source_language="fortran",
             prompt_template="requirements",
        )
@@ -105,200 +72,6 @@ class TestTranslator(unittest.TestCase):
         # unit tests anyway
         self.assertTrue(python_file.exists())
 
-    # def test_embeddings(self):
-    #     """Testing access to embeddings"""
-    #     vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
-    #     self.assertIsInstance(vector_store, Chroma, "Unexpected vector store type!")
-    #     self.assertEqual(
-    #         0, vector_store._collection.count(), "Non-empty initial vector store?"
-    #     )
-    #
-    #     self.translator.set_model("llama")
-    #     self.translator._load_parameters()
-    #     vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
-    #     self.assertIsInstance(vector_store, Chroma)
-    #     self.assertEqual(
-    #         0, vector_store._collection.count(), "Non-empty initial vector store?"
-    #     )
-
-    # def test_embed_split_source(self):
-    #     """Characterize _embed method"""
-    #     mock_embeddings = MockEmbeddingsFactory()
-    #     self.translator.set_embeddings(mock_embeddings)
-    #     self.translator._load_parameters()
-    #     input_block = self.translator.splitter.split(self.test_file)
-    #     self.assertIsNone(
-    #         input_block.text, "Root node of input text shouldn't contain text"
-    #     )
-    #     self.assertIsNone(input_block.embedding_id, "Precondition failed")
-    #
-    #     result = self.translator._embed(
-    #         input_block, EmbeddingType.SOURCE, self.test_file.name
-    #     )
-    #
-    #     self.assertFalse(result, "Nothing to embed, so should have no result")
-    #     self.assertIsNone(
-    #         input_block.embedding_id, "Embeddings should not have changed")
-
-    # def test_embed_has_values_for_each_non_empty_node(self):
-    #     """Characterize our sample fortran file"""
-    #     mock_embeddings = MockEmbeddingsFactory()
-    #     self.translator.set_embeddings(mock_embeddings)
-    #     self.translator._load_parameters()
-    #     input_block = self.translator.splitter.split(self.test_file)
-    #     self.translator._embed_nodes_recursively(
-    #         input_block, EmbeddingType.SOURCE, self.test_file.name
-    #     )
-    #     has_text_count = 0
-    #     has_embeddings_count = 0
-    #     nodes = [input_block]
-    #     while nodes:
-    #         node = nodes.pop(0)
-    #         if node.text:
-    #             has_text_count += 1
-    #         if node.embedding_id:
-    #             has_embeddings_count += 1
-    #         nodes.extend(node.children)
-    #     self.assertEqual(
-    #         self.TEST_FILE_EMBEDDING_COUNT,
-    #         has_text_count,
-    #         "Parsing of test_file has changed!",
-    #     )
-    #     self.assertEqual(
-    #         self.TEST_FILE_EMBEDDING_COUNT,
-    #         has_embeddings_count,
-    #         "Not all non-empty nodes have embeddings!",
-    #     )
-
-    # def test_embed_nodes_recursively(self):
-    #     mock_embeddings = MockEmbeddingsFactory()
-    #     self.translator.set_embeddings(mock_embeddings)
-    #     self.translator._load_parameters()
-    #     input_block = self.translator.splitter.split(self.test_file)
-    #     self.translator._embed_nodes_recursively(
-    #         input_block, EmbeddingType.SOURCE, self.test_file.name
-    #     )
-    #     nodes = [input_block]
-    #     while nodes:
-    #         node = nodes.pop(0)
-    #         self.assertEqual(node.text is not None, node.embedding_id is not None)
-    #         nodes.extend(node.children)
-
-    # @pytest.mark.slow
-    # def test_translate_file_adds_source_embeddings(self):
-    #     mock_embeddings = MockEmbeddingsFactory()
-    #     self.translator.set_embeddings(mock_embeddings)
-    #     self.translator._load_parameters()
-    #     vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
-    #     self.assertEqual(0, vector_store._add_texts_calls, "precondition")
-    #
-    #     self.translator.translate_file(self.test_file)
-    #
-    #     self.assertEqual(
-    #         self.TEST_FILE_EMBEDDING_COUNT,
-    #         vector_store._add_texts_calls,
-    #         "Did not find expected source embeddings",
-    #     )
-
-    # @pytest.mark.slow
-    # def test_embeddings_usage(self):
-    #     """Noodling on use of embeddings
-    #     To see results have to uncomment print_query_results() above
-    #     """
-    #     input_block = self.translator.splitter.split(self.test_file)
-    #     self.translator._embed_nodes_recursively(
-    #         input_block, EmbeddingType.SOURCE, self.test_file.name
-    #     )
-    #     vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
-    #
-    #     # this symbol has the lowest relevance scores of any in this test, but
-    #     # still not very low; multiple embedded nodes contain it
-    #     QUERY_STRING = "IWX_BAND_START"
-    #     query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
-    #     n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-    #         embedding=query,
-    #         k=10,
-    #         where_document={"$contains": QUERY_STRING},
-    #     )
-    #     self.assertTrue(len(n_results) > 1, "Why was valid symbol not found?")
-    #     print_query_results(QUERY_STRING, n_results)
-
-    # in the XYZZY test, the least dissimilar results were the start and finish lines
-    # 0, and 415, which produced a similarity score of 0.47:
-
-    # QUERY_STRING = "XYZZY"
-    # query = self.translator._embeddings.embed_query(QUERY_STRING)
-    # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-    #     embedding=query,
-    #     k=10,
-    #     # filter={"end_line": 15},
-    #     # filter={"$and": [{"end_line": 15}, {"tokens": {"$gte": 21}}]},
-    #     # where_document={"$contains": QUERY_STRING},
-    # )
-    # print_query_results(QUERY_STRING, n_results)
-    # # self.assertTrue(len(n_results) == 0, "Invalid symbol was found?")
-
-    # # only returns a single result because only 1 embedded node contains
-    # # CSV_ICASEARR:
-    # QUERY_STRING = "What is the use of CSV_ICASEARR?"
-    # query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
-    # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-    #     embedding=query,
-    #     k=10,
-    #     # where_document={"$contains": QUERY_STRING},
-    #     where_document={"$contains": "CSV_ICASEARR"},
-    # )
-    # print_query_results(QUERY_STRING, n_results)
-    # self.assertTrue(len(n_results) == 1, "Was splitting changed?")
-    #
-    # # trimmed out some characters from line 43, and still not very similar scoring
-    # QUERY_STRING = "IYL_EDGEBUFFER EDGEBUFFER IGN_MASK CELLSIZE"
-    # query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
-    # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-    #     embedding=query,
-    #     k=10,
-    #     # where_document={"$contains": QUERY_STRING},
-    # )
-    # print_query_results(QUERY_STRING, n_results)
-    #
-    # # random string (as bad as XYZZY), but searching for a specific line
-    # QUERY_STRING = "ghost in the invisible moon"
-    # query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
-    # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-    #     embedding=query,
-    #     k=10,
-    #     filter={"$and": [{"end_line": 90}, {"tokens": {"$gte": 21}}]},
-    # )
-    # print_query_results(QUERY_STRING, n_results)
-    # self.assertTrue(len(n_results) == 1, "Was splitting changed?")
-
-    # @pytest.mark.slow
-    # def test_document_embeddings_added_by_translate(self):
-    #     vector_store = self.req_translator.embeddings(EmbeddingType.REQUIREMENT)
-    #     self.assertEqual(0, vector_store._add_texts_calls, "Precondition failed")
-    #     self.req_translator.translate(self.test_file.parent, self.test_file.parent,
-    #                                   True)
-    #     self.assertTrue(vector_store._add_texts_calls > 0, "Why no documentation?")
-
-    # @pytest.mark.slow
-    # def test_embed_requirements(self):
-    #     vector_store = self.req_translator.embeddings(EmbeddingType.REQUIREMENT)
-    #     translated = self.req_translator.translate_file(self.test_file)
-    #     self.assertEqual(
-    #         0,
-    #         vector_store._add_texts_calls,
-    #         "Unexpected requirements added in translate_file",
-    #     )
-    #     result = self.req_translator._embed(
-    #         translated, EmbeddingType.REQUIREMENT, self.test_file.name
-    #     )
-    #     self.assertFalse(result, "No text in root node, so should generate no docs")
-    #     self.assertIsNotNone(translated.children[0].text, "Data changed?")
-    #     result = self.req_translator._embed(
-    #         translated.children[0], EmbeddingType.REQUIREMENT, self.test_file.name
-    #     )
-    #     self.assertTrue(result, "No docs generated for first child node?")
-
     def test_invalid_selections(self) -> None:
         """Tests that settings values for the translator will raise exceptions"""
         self.assertRaises(
janus/converter/converter.py CHANGED
@@ -6,7 +6,6 @@ from pathlib import Path
 from typing import Any
 
 from langchain.output_parsers import RetryWithErrorOutputParser
-from langchain.output_parsers.fix import OutputFixingParser
 from langchain_core.exceptions import OutputParserException
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.output_parsers import BaseOutputParser
@@ -29,6 +28,8 @@ from janus.llm import load_model
 from janus.llm.model_callbacks import get_model_callback
 from janus.llm.models_info import MODEL_PROMPT_ENGINES
 from janus.parsers.code_parser import GenericParser
+from janus.parsers.refiner_parser import RefinerParser
+from janus.refiners.refiner import BasicRefiner, Refiner
 from janus.utils.enums import LANGUAGES
 from janus.utils.logger import create_logger
 
@@ -75,6 +76,7 @@ class Converter:
         protected_node_types: tuple[str, ...] = (),
         prune_node_types: tuple[str, ...] = (),
         splitter_type: str = "file",
+        refiner_type: str = "basic",
     ) -> None:
         """Initialize a Converter instance.
 
@@ -84,6 +86,17 @@ class Converter:
                 values are `"code"`, `"text"`, `"eval"`, and `None` (default). If `None`,
                 the `Converter` assumes you won't be parsing an output (i.e., adding to an
                 embedding DB).
+            max_prompts: The maximum number of prompts to try before giving up.
+            max_tokens: The maximum number of tokens to use in the LLM. If `None`, the
+                converter will use half the model's token limit.
+            prompt_template: The name of the prompt template to use.
+            db_path: The path to the database to use for vectorization.
+            db_config: The configuration for the database.
+            protected_node_types: A set of node types that aren't to be merged.
+            prune_node_types: A set of node types which should be pruned.
+            splitter_type: The type of splitter to use. Valid values are `"file"`,
+                `"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
+            refiner_type: The type of refiner to use. Valid values are `"basic"`.
         """
         self._changed_attrs: set = set()
 
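Note: with refiner_type added above, a converter can now be constructed roughly as follows (a minimal sketch; the values mirror the updated test_translate.py setUp, and passing refiner_type explicitly is an assumption here, since "basic" is also the default):

    from janus.converter.translate import Translator

    translator = Translator(
        model="gpt-4o-mini",
        source_language="fortran",
        target_language="python",
        target_version="3.10",
        splitter_type="ast-flex",  # "file", "tag", "chunk", "ast-strict", or "ast-flex"
        refiner_type="basic",      # currently the only valid refiner
    )
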
@@ -116,7 +129,11 @@
         self._parser: BaseOutputParser = GenericParser()
         self._combiner: Combiner = Combiner()
 
+        self._refiner_type: str
+        self._refiner: Refiner
+
         self.set_splitter(splitter_type=splitter_type)
+        self.set_refiner(refiner_type=refiner_type)
         self.set_model(model_name=model, **model_arguments)
         self.set_prompt(prompt_template=prompt_template)
         self.set_source_language(source_language)
@@ -142,6 +159,7 @@
         self._load_prompt()
         self._load_splitter()
         self._load_vectorizer()
+        self._load_refiner()
         self._changed_attrs.clear()
 
     def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
@@ -179,6 +197,16 @@
         """
         self._splitter_type = splitter_type
 
+    def set_refiner(self, refiner_type: str) -> None:
+        """Validate and set the refiner name
+
+        The affected objects will not be updated until translate is called
+
+        Arguments:
+            refiner_type: the name of the refiner to use
+        """
+        self._refiner_type = refiner_type
+
     def set_source_language(self, source_language: str) -> None:
         """Validate and set the source language.
 
@@ -249,10 +277,24 @@
         )
 
         if self._splitter_type == "tag":
-            kwargs["tag"] = "<ITMOD_ALC_SPLIT>"
+            kwargs["tag"] = "<ITMOD_ALC_SPLIT>"  # Hardcoded for now
 
         self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)
 
+    @run_if_changed("_refiner_type", "_model_name")
+    def _load_refiner(self) -> None:
+        """Load the refiner according to this instance's attributes.
+
+        If the relevant fields have not been changed since the last time this method was
+        called, nothing happens.
+        """
+        if self._refiner_type == "basic":
+            self._refiner = BasicRefiner(
+                "basic_refinement", self._model_name, self._source_language
+            )
+        else:
+            raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
+
     @run_if_changed("_model_name", "_custom_model_arguments")
     def _load_model(self) -> None:
         """Load the model according to this instance's attributes.
@@ -561,22 +603,22 @@
         # Retries with just the input
         n3 = math.ceil(self.max_prompts / (n1 * n2))
 
-        fix_format = OutputFixingParser.from_llm(
-            llm=self._llm,
+        refine_output = RefinerParser(
             parser=self._parser,
+            initial_prompt=self._prompt.format(**{"SOURCE_CODE": block.original.text}),
+            refiner=self._refiner,
             max_retries=n1,
+            llm=self._llm,
         )
         retry = RetryWithErrorOutputParser.from_llm(
             llm=self._llm,
-            parser=fix_format,
+            parser=refine_output,
             max_retries=n2,
         )
-
         completion_chain = self._prompt | self._llm
         chain = RunnableParallel(
             completion=completion_chain, prompt_value=self._prompt
         ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
-
         for _ in range(n3):
             try:
                 return chain.invoke({"SOURCE_CODE": block.original.text})
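Note: the retry budget here is split multiplicatively across three nested layers: RefinerParser retries up to n1 times, RetryWithErrorOutputParser wraps it with n2 retries, and the outer loop re-invokes the chain n3 times, so the total attempt count is at least max_prompts. A quick check of the arithmetic (max_prompts = 10 is an arbitrary example value):

    import math

    max_prompts = 10
    n1 = round(max_prompts ** (1 / 3))          # 2: refiner retries
    n2 = round((max_prompts // n1) ** (1 / 2))  # 2: retry-with-error retries
    n3 = math.ceil(max_prompts / (n1 * n2))     # 3: outer re-invocations
    assert n1 * n2 * n3 >= max_prompts          # 12 >= 10
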
janus/converter/diagram.py CHANGED
@@ -1,10 +1,14 @@
-import json
-from copy import deepcopy
+import math
+
+from langchain.output_parsers import RetryWithErrorOutputParser
+from langchain_core.exceptions import OutputParserException
+from langchain_core.runnables import RunnableLambda, RunnableParallel
 
 from janus.converter.converter import run_if_changed
 from janus.converter.document import Documenter
 from janus.language.block import TranslatedCodeBlock
 from janus.llm.models_info import MODEL_PROMPT_ENGINES
+from janus.parsers.refiner_parser import RefinerParser
 from janus.parsers.uml import UMLSyntaxParser
 from janus.utils.logger import create_logger
 
@@ -47,65 +51,74 @@ class DiagramGenerator(Documenter):
         self._diagram_prompt_template_name = "diagram"
         self._load_diagram_prompt_engine()
 
-    def _add_translation(self, block: TranslatedCodeBlock) -> None:
-        """Given an "empty" `TranslatedCodeBlock`, translate the code represented in
-        `block.original`, setting the relevant fields in the translated block. The
-        `TranslatedCodeBlock` is updated in-pace, nothing is returned. Note that this
-        translates *only* the code for this block, not its children.
-
-        Arguments:
-            block: An empty `TranslatedCodeBlock`
-        """
-        if block.translated:
-            return
-
-        if block.original.text is None:
-            block.translated = True
-            return
-
-        if self._add_documentation:
-            documentation_block = deepcopy(block)
-            super()._add_translation(documentation_block)
-            if not documentation_block.translated:
-                message = "Error: unable to produce documentation for code block"
-                log.info(message)
-                raise ValueError(message)
-            documentation = json.loads(documentation_block.text)["docstring"]
-
-        if self._llm is None:
-            message = (
-                "Model not configured correctly, cannot translate. Try setting "
-                "the model"
-            )
-            log.error(message)
-            raise ValueError(message)
-
-        log.debug(f"[{block.name}] Translating...")
-        log.debug(f"[{block.name}] Input text:\n{block.original.text}")
-
+    def _run_chain(self, block: TranslatedCodeBlock) -> str:
         self._parser.set_reference(block.original)
+        n1 = round(self.max_prompts ** (1 / 3))
 
-        query_and_parse = self.diagram_prompt | self._llm | self._diagram_parser
+        # Retries with the input, output, and error
+        n2 = round((self.max_prompts // n1) ** (1 / 2))
+
+        # Retries with just the input
+        n3 = math.ceil(self.max_prompts / (n1 * n2))
 
         if self._add_documentation:
-            block.text = query_and_parse.invoke(
-                {
-                    "SOURCE_CODE": block.original.text,
-                    "DIAGRAM_TYPE": self._diagram_type,
-                    "DOCUMENTATION": documentation,
-                }
+            documentation_text = super()._run_chain(block)
+            refine_output = RefinerParser(
+                parser=self._diagram_parser,
+                initial_prompt=self._diagram_prompt.format(
+                    **{
+                        "SOURCE_CODE": block.original.text,
+                        "DOCUMENTATION": documentation_text,
+                        "DIAGRAM_TYPE": self._diagram_type,
+                    }
+                ),
+                refiner=self._refiner,
+                max_retries=n1,
+                llm=self._llm,
             )
         else:
-            block.text = query_and_parse.invoke(
-                {
-                    "SOURCE_CODE": block.original.text,
-                    "DIAGRAM_TYPE": self._diagram_type,
-                }
+            refine_output = RefinerParser(
+                parser=self._diagram_parser,
+                initial_prompt=self._diagram_prompt.format(
+                    **{
+                        "SOURCE_CODE": block.original.text,
+                        "DIAGRAM_TYPE": self._diagram_type,
+                    }
+                ),
+                refiner=self._refiner,
+                max_retries=n1,
+                llm=self._llm,
             )
-        block.tokens = self._llm.get_num_tokens(block.text)
-        block.translated = True
-
-        log.debug(f"[{block.name}] Output code:\n{block.text}")
+        retry = RetryWithErrorOutputParser.from_llm(
+            llm=self._llm,
+            parser=refine_output,
+            max_retries=n2,
+        )
+        completion_chain = self._prompt | self._llm
+        chain = RunnableParallel(
+            completion=completion_chain, prompt_value=self._diagram_prompt
+        ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
+        for _ in range(n3):
+            try:
+                if self._add_documentation:
+                    return chain.invoke(
+                        {
+                            "SOURCE_CODE": block.original.text,
+                            "DOCUMENTATION": documentation_text,
+                            "DIAGRAM_TYPE": self._diagram_type,
+                        }
+                    )
+                else:
+                    return chain.invoke(
+                        {
+                            "SOURCE_CODE": block.original.text,
+                            "DIAGRAM_TYPE": self._diagram_type,
+                        }
+                    )
+            except OutputParserException:
+                pass
+
+        raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
 
     @run_if_changed(
         "_diagram_prompt_template_name",
@@ -123,4 +136,4 @@
             target_version=None,
             prompt_template=self._diagram_prompt_template_name,
         )
-        self.diagram_prompt = self._diagram_prompt_engine.prompt
+        self._diagram_prompt = self._diagram_prompt_engine.prompt
janus/embedding/_tests/test_collections.py CHANGED
@@ -4,8 +4,8 @@ from unittest.mock import MagicMock
 
 import pytest
 
-from ...utils.enums import EmbeddingType
-from ..collections import Collections
+from janus.embedding.collections import Collections
+from janus.utils.enums import EmbeddingType
 
 
 class TestCollections(unittest.TestCase):
janus/embedding/_tests/test_database.py CHANGED
@@ -2,7 +2,7 @@ import unittest
 from pathlib import Path
 from unittest.mock import patch
 
-from ..database import ChromaEmbeddingDatabase, uri_to_path
+from janus.embedding.database import ChromaEmbeddingDatabase, uri_to_path
 
 
 class TestDatabase(unittest.TestCase):
janus/embedding/_tests/test_vectorize.py CHANGED
@@ -5,9 +5,9 @@ from unittest.mock import MagicMock
 
 from chromadb.api.client import Client
 
-from ...language.treesitter import TreeSitterSplitter
-from ...utils.enums import EmbeddingType
-from ..vectorize import Vectorizer, VectorizerFactory
+from janus.embedding.vectorize import Vectorizer, VectorizerFactory
+from janus.language.treesitter import TreeSitterSplitter
+from janus.utils.enums import EmbeddingType
 
 
 class MockDBVectorizer(VectorizerFactory):
janus/embedding/collections.py CHANGED
@@ -5,8 +5,8 @@ from typing import Dict, Optional, Sequence
 from chromadb import Client, Collection
 from langchain_community.vectorstores import Chroma
 
-from ..utils.enums import EmbeddingType
-from .embedding_models_info import load_embedding_model
+from janus.embedding.embedding_models_info import load_embedding_model
+from janus.utils.enums import EmbeddingType
 
 # See https://docs.trychroma.com/telemetry#in-chromas-backend-using-environment-variables
 os.environ["ANONYMIZED_TELEMETRY"] = "False"
janus/embedding/database.py CHANGED
@@ -5,7 +5,7 @@ from urllib.request import url2pathname
 
 import chromadb
 
-from ..utils.logger import create_logger
+from janus.utils.logger import create_logger
 
 log = create_logger(__name__)
 
janus/embedding/embedding_models_info.py CHANGED
@@ -8,7 +8,7 @@ from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEm
 from langchain_core.embeddings import Embeddings
 from langchain_openai import OpenAIEmbeddings
 
-from ..utils.logger import create_logger
+from janus.utils.logger import create_logger
 
 load_dotenv()