janus-llm 3.2.0__py3-none-any.whl → 3.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +3 -3
- janus/_tests/test_cli.py +3 -3
- janus/cli.py +1 -1
- janus/converter/__init__.py +6 -6
- janus/converter/_tests/test_translate.py +6 -233
- janus/converter/converter.py +49 -7
- janus/converter/diagram.py +68 -55
- janus/embedding/_tests/test_collections.py +2 -2
- janus/embedding/_tests/test_database.py +1 -1
- janus/embedding/_tests/test_vectorize.py +3 -3
- janus/embedding/collections.py +2 -2
- janus/embedding/database.py +1 -1
- janus/embedding/embedding_models_info.py +1 -1
- janus/embedding/vectorize.py +5 -5
- janus/language/_tests/test_combine.py +1 -1
- janus/language/_tests/test_splitter.py +1 -1
- janus/language/alc/_tests/test_alc.py +3 -3
- janus/language/alc/alc.py +5 -5
- janus/language/binary/_tests/test_binary.py +2 -2
- janus/language/binary/binary.py +5 -5
- janus/language/block.py +2 -2
- janus/language/combine.py +3 -3
- janus/language/file.py +2 -2
- janus/language/mumps/_tests/test_mumps.py +3 -3
- janus/language/mumps/mumps.py +5 -5
- janus/language/mumps/patterns.py +1 -1
- janus/language/naive/__init__.py +4 -4
- janus/language/naive/basic_splitter.py +4 -4
- janus/language/naive/chunk_splitter.py +4 -4
- janus/language/naive/registry.py +1 -1
- janus/language/naive/simple_ast.py +5 -5
- janus/language/naive/tag_splitter.py +4 -4
- janus/language/node.py +1 -1
- janus/language/splitter.py +4 -4
- janus/language/treesitter/_tests/test_treesitter.py +3 -3
- janus/language/treesitter/treesitter.py +4 -4
- janus/llm/__init__.py +1 -1
- janus/llm/model_callbacks.py +1 -1
- janus/llm/models_info.py +5 -3
- janus/metrics/_tests/test_bleu.py +1 -1
- janus/metrics/_tests/test_chrf.py +1 -1
- janus/metrics/_tests/test_file_pairing.py +1 -1
- janus/metrics/_tests/test_llm.py +2 -2
- janus/metrics/_tests/test_reading.py +1 -1
- janus/metrics/_tests/test_rouge_score.py +1 -1
- janus/metrics/_tests/test_similarity_score.py +1 -1
- janus/metrics/_tests/test_treesitter_metrics.py +2 -2
- janus/metrics/bleu.py +1 -1
- janus/metrics/chrf.py +1 -1
- janus/metrics/complexity_metrics.py +4 -4
- janus/metrics/file_pairing.py +5 -5
- janus/metrics/llm_metrics.py +1 -1
- janus/metrics/metric.py +7 -7
- janus/metrics/reading.py +1 -1
- janus/metrics/rouge_score.py +1 -1
- janus/metrics/similarity.py +2 -2
- janus/parsers/_tests/test_code_parser.py +1 -1
- janus/parsers/code_parser.py +2 -2
- janus/parsers/doc_parser.py +3 -3
- janus/parsers/eval_parser.py +2 -2
- janus/parsers/refiner_parser.py +49 -0
- janus/parsers/reqs_parser.py +3 -3
- janus/parsers/uml.py +1 -2
- janus/prompts/prompt.py +2 -2
- janus/refiners/refiner.py +63 -0
- janus/utils/_tests/test_logger.py +1 -1
- janus/utils/_tests/test_progress.py +1 -1
- janus/utils/progress.py +1 -1
- {janus_llm-3.2.0.dist-info → janus_llm-3.3.0.dist-info}/METADATA +1 -1
- janus_llm-3.3.0.dist-info/RECORD +107 -0
- janus_llm-3.2.0.dist-info/RECORD +0 -105
- {janus_llm-3.2.0.dist-info → janus_llm-3.3.0.dist-info}/LICENSE +0 -0
- {janus_llm-3.2.0.dist-info → janus_llm-3.3.0.dist-info}/WHEEL +0 -0
- {janus_llm-3.2.0.dist-info → janus_llm-3.3.0.dist-info}/entry_points.txt +0 -0
janus/__init__.py
CHANGED
@@ -2,10 +2,10 @@ import warnings
|
|
2
2
|
|
3
3
|
from langchain_core._api.deprecation import LangChainDeprecationWarning
|
4
4
|
|
5
|
-
from .converter.translate import Translator
|
6
|
-
from .metrics import * # noqa: F403
|
5
|
+
from janus.converter.translate import Translator
|
6
|
+
from janus.metrics import * # noqa: F403
|
7
7
|
|
8
|
-
__version__ = "3.
|
8
|
+
__version__ = "3.3.0"
|
9
9
|
|
10
10
|
# Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
|
11
11
|
warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
|
janus/_tests/test_cli.py
CHANGED
@@ -4,9 +4,9 @@ from unittest.mock import ANY, patch
|
|
4
4
|
|
5
5
|
from typer.testing import CliRunner
|
6
6
|
|
7
|
-
from
|
8
|
-
from
|
9
|
-
from
|
7
|
+
from janus.cli import app, translate
|
8
|
+
from janus.embedding.embedding_models_info import EMBEDDING_MODEL_CONFIG_DIR
|
9
|
+
from janus.llm.models_info import MODEL_CONFIG_DIR
|
10
10
|
|
11
11
|
|
12
12
|
class TestCli(unittest.TestCase):
|
janus/cli.py
CHANGED
janus/converter/__init__.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
|
-
from .converter import Converter
|
2
|
-
from .diagram import DiagramGenerator
|
3
|
-
from .document import Documenter, MadLibsDocumenter, MultiDocumenter
|
4
|
-
from .evaluate import Evaluator
|
5
|
-
from .requirements import RequirementsDocumenter
|
6
|
-
from .translate import Translator
|
1
|
+
from janus.converter.converter import Converter
|
2
|
+
from janus.converter.diagram import DiagramGenerator
|
3
|
+
from janus.converter.document import Documenter, MadLibsDocumenter, MultiDocumenter
|
4
|
+
from janus.converter.evaluate import Evaluator
|
5
|
+
from janus.converter.requirements import RequirementsDocumenter
|
6
|
+
from janus.converter.translate import Translator
|
@@ -7,37 +7,11 @@ from langchain.schema import Document
|
|
7
7
|
from langchain.schema.embeddings import Embeddings
|
8
8
|
from langchain.schema.vectorstore import VST, VectorStore
|
9
9
|
|
10
|
+
from janus.converter.diagram import DiagramGenerator
|
11
|
+
from janus.converter.requirements import RequirementsDocumenter
|
12
|
+
from janus.converter.translate import Translator
|
10
13
|
from janus.language.block import CodeBlock, TranslatedCodeBlock
|
11
14
|
|
12
|
-
from ..diagram import DiagramGenerator
|
13
|
-
from ..requirements import RequirementsDocumenter
|
14
|
-
from ..translate import Translator
|
15
|
-
|
16
|
-
# from langchain.vectorstores import Chroma
|
17
|
-
|
18
|
-
|
19
|
-
# from ..utils.enums import EmbeddingType
|
20
|
-
|
21
|
-
|
22
|
-
def print_query_results(query, n_results):
|
23
|
-
# print(f"\n{query}")
|
24
|
-
# count = 1
|
25
|
-
# for t in n_results:
|
26
|
-
# short_code = (
|
27
|
-
# (t[0].page_content[0:50] + "..")
|
28
|
-
# if (len(t[0].page_content) > 50)
|
29
|
-
# else t[0].page_content
|
30
|
-
# )
|
31
|
-
# return_index = short_code.find("\n")
|
32
|
-
# if -1 != return_index:
|
33
|
-
# short_code = short_code[0:return_index] + ".."
|
34
|
-
# print(
|
35
|
-
# f"{count}. @ {t[0].metadata['start_line']}-{t[0].metadata['end_line']}"
|
36
|
-
# f" -- {t[1]} -- {short_code}"
|
37
|
-
# )
|
38
|
-
# count += 1
|
39
|
-
pass
|
40
|
-
|
41
15
|
|
42
16
|
class MockCollection(VectorStore):
|
43
17
|
"""Vector store for testing"""
|
@@ -65,30 +39,23 @@ class MockCollection(VectorStore):
|
|
65
39
|
raise NotImplementedError("from_texts() not implemented!")
|
66
40
|
|
67
41
|
|
68
|
-
# class MockEmbeddingsFactory(EmbeddingsFactory):
|
69
|
-
# """Embeddings for testing - uses MockCollection"""
|
70
|
-
#
|
71
|
-
# def get_embeddings(self) -> Embeddings:
|
72
|
-
# return MockCollection()
|
73
|
-
#
|
74
|
-
|
75
|
-
|
76
42
|
class TestTranslator(unittest.TestCase):
|
77
43
|
"""Tests for the Translator class."""
|
78
44
|
|
79
45
|
def setUp(self):
|
80
46
|
"""Set up the tests."""
|
81
47
|
self.translator = Translator(
|
82
|
-
model="gpt-4o",
|
48
|
+
model="gpt-4o-mini",
|
83
49
|
source_language="fortran",
|
84
50
|
target_language="python",
|
85
51
|
target_version="3.10",
|
52
|
+
splitter_type="ast-flex",
|
86
53
|
)
|
87
54
|
self.test_file = Path("janus/language/treesitter/_tests/languages/fortran.f90")
|
88
55
|
self.TEST_FILE_EMBEDDING_COUNT = 14
|
89
56
|
|
90
57
|
self.req_translator = RequirementsDocumenter(
|
91
|
-
model="gpt-4o",
|
58
|
+
model="gpt-4o-mini",
|
92
59
|
source_language="fortran",
|
93
60
|
prompt_template="requirements",
|
94
61
|
)
|
@@ -105,200 +72,6 @@ class TestTranslator(unittest.TestCase):
|
|
105
72
|
# unit tests anyway
|
106
73
|
self.assertTrue(python_file.exists())
|
107
74
|
|
108
|
-
# def test_embeddings(self):
|
109
|
-
# """Testing access to embeddings"""
|
110
|
-
# vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
|
111
|
-
# self.assertIsInstance(vector_store, Chroma, "Unexpected vector store type!")
|
112
|
-
# self.assertEqual(
|
113
|
-
# 0, vector_store._collection.count(), "Non-empty initial vector store?"
|
114
|
-
# )
|
115
|
-
#
|
116
|
-
# self.translator.set_model("llama")
|
117
|
-
# self.translator._load_parameters()
|
118
|
-
# vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
|
119
|
-
# self.assertIsInstance(vector_store, Chroma)
|
120
|
-
# self.assertEqual(
|
121
|
-
# 0, vector_store._collection.count(), "Non-empty initial vector store?"
|
122
|
-
# )
|
123
|
-
|
124
|
-
# def test_embed_split_source(self):
|
125
|
-
# """Characterize _embed method"""
|
126
|
-
# mock_embeddings = MockEmbeddingsFactory()
|
127
|
-
# self.translator.set_embeddings(mock_embeddings)
|
128
|
-
# self.translator._load_parameters()
|
129
|
-
# input_block = self.translator.splitter.split(self.test_file)
|
130
|
-
# self.assertIsNone(
|
131
|
-
# input_block.text, "Root node of input text shouldn't contain text"
|
132
|
-
# )
|
133
|
-
# self.assertIsNone(input_block.embedding_id, "Precondition failed")
|
134
|
-
#
|
135
|
-
# result = self.translator._embed(
|
136
|
-
# input_block, EmbeddingType.SOURCE, self.test_file.name
|
137
|
-
# )
|
138
|
-
#
|
139
|
-
# self.assertFalse(result, "Nothing to embed, so should have no result")
|
140
|
-
# self.assertIsNone(
|
141
|
-
# input_block.embedding_id, "Embeddings should not have changed")
|
142
|
-
|
143
|
-
# def test_embed_has_values_for_each_non_empty_node(self):
|
144
|
-
# """Characterize our sample fortran file"""
|
145
|
-
# mock_embeddings = MockEmbeddingsFactory()
|
146
|
-
# self.translator.set_embeddings(mock_embeddings)
|
147
|
-
# self.translator._load_parameters()
|
148
|
-
# input_block = self.translator.splitter.split(self.test_file)
|
149
|
-
# self.translator._embed_nodes_recursively(
|
150
|
-
# input_block, EmbeddingType.SOURCE, self.test_file.name
|
151
|
-
# )
|
152
|
-
# has_text_count = 0
|
153
|
-
# has_embeddings_count = 0
|
154
|
-
# nodes = [input_block]
|
155
|
-
# while nodes:
|
156
|
-
# node = nodes.pop(0)
|
157
|
-
# if node.text:
|
158
|
-
# has_text_count += 1
|
159
|
-
# if node.embedding_id:
|
160
|
-
# has_embeddings_count += 1
|
161
|
-
# nodes.extend(node.children)
|
162
|
-
# self.assertEqual(
|
163
|
-
# self.TEST_FILE_EMBEDDING_COUNT,
|
164
|
-
# has_text_count,
|
165
|
-
# "Parsing of test_file has changed!",
|
166
|
-
# )
|
167
|
-
# self.assertEqual(
|
168
|
-
# self.TEST_FILE_EMBEDDING_COUNT,
|
169
|
-
# has_embeddings_count,
|
170
|
-
# "Not all non-empty nodes have embeddings!",
|
171
|
-
# )
|
172
|
-
|
173
|
-
# def test_embed_nodes_recursively(self):
|
174
|
-
# mock_embeddings = MockEmbeddingsFactory()
|
175
|
-
# self.translator.set_embeddings(mock_embeddings)
|
176
|
-
# self.translator._load_parameters()
|
177
|
-
# input_block = self.translator.splitter.split(self.test_file)
|
178
|
-
# self.translator._embed_nodes_recursively(
|
179
|
-
# input_block, EmbeddingType.SOURCE, self.test_file.name
|
180
|
-
# )
|
181
|
-
# nodes = [input_block]
|
182
|
-
# while nodes:
|
183
|
-
# node = nodes.pop(0)
|
184
|
-
# self.assertEqual(node.text is not None, node.embedding_id is not None)
|
185
|
-
# nodes.extend(node.children)
|
186
|
-
|
187
|
-
# @pytest.mark.slow
|
188
|
-
# def test_translate_file_adds_source_embeddings(self):
|
189
|
-
# mock_embeddings = MockEmbeddingsFactory()
|
190
|
-
# self.translator.set_embeddings(mock_embeddings)
|
191
|
-
# self.translator._load_parameters()
|
192
|
-
# vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
|
193
|
-
# self.assertEqual(0, vector_store._add_texts_calls, "precondition")
|
194
|
-
#
|
195
|
-
# self.translator.translate_file(self.test_file)
|
196
|
-
#
|
197
|
-
# self.assertEqual(
|
198
|
-
# self.TEST_FILE_EMBEDDING_COUNT,
|
199
|
-
# vector_store._add_texts_calls,
|
200
|
-
# "Did not find expected source embeddings",
|
201
|
-
# )
|
202
|
-
|
203
|
-
# @pytest.mark.slow
|
204
|
-
# def test_embeddings_usage(self):
|
205
|
-
# """Noodling on use of embeddings
|
206
|
-
# To see results have to uncomment print_query_results() above
|
207
|
-
# """
|
208
|
-
# input_block = self.translator.splitter.split(self.test_file)
|
209
|
-
# self.translator._embed_nodes_recursively(
|
210
|
-
# input_block, EmbeddingType.SOURCE, self.test_file.name
|
211
|
-
# )
|
212
|
-
# vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
|
213
|
-
#
|
214
|
-
# # this symbol has the lowest relevance scores of any in this test, but
|
215
|
-
# # still not very low; multiple embedded nodes contain it
|
216
|
-
# QUERY_STRING = "IWX_BAND_START"
|
217
|
-
# query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
|
218
|
-
# n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
|
219
|
-
# embedding=query,
|
220
|
-
# k=10,
|
221
|
-
# where_document={"$contains": QUERY_STRING},
|
222
|
-
# )
|
223
|
-
# self.assertTrue(len(n_results) > 1, "Why was valid symbol not found?")
|
224
|
-
# print_query_results(QUERY_STRING, n_results)
|
225
|
-
|
226
|
-
# in the XYZZY test, the least dissimilar results were the start and finish lines
|
227
|
-
# 0, and 415, which produced a similarity score of 0.47:
|
228
|
-
|
229
|
-
# QUERY_STRING = "XYZZY"
|
230
|
-
# query = self.translator._embeddings.embed_query(QUERY_STRING)
|
231
|
-
# n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
|
232
|
-
# embedding=query,
|
233
|
-
# k=10,
|
234
|
-
# # filter={"end_line": 15},
|
235
|
-
# # filter={"$and": [{"end_line": 15}, {"tokens": {"$gte": 21}}]},
|
236
|
-
# # where_document={"$contains": QUERY_STRING},
|
237
|
-
# )
|
238
|
-
# print_query_results(QUERY_STRING, n_results)
|
239
|
-
# # self.assertTrue(len(n_results) == 0, "Invalid symbol was found?")
|
240
|
-
|
241
|
-
# # only returns a single result because only 1 embedded node contains
|
242
|
-
# # CSV_ICASEARR:
|
243
|
-
# QUERY_STRING = "What is the use of CSV_ICASEARR?"
|
244
|
-
# query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
|
245
|
-
# n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
|
246
|
-
# embedding=query,
|
247
|
-
# k=10,
|
248
|
-
# # where_document={"$contains": QUERY_STRING},
|
249
|
-
# where_document={"$contains": "CSV_ICASEARR"},
|
250
|
-
# )
|
251
|
-
# print_query_results(QUERY_STRING, n_results)
|
252
|
-
# self.assertTrue(len(n_results) == 1, "Was splitting changed?")
|
253
|
-
#
|
254
|
-
# # trimmed out some characters from line 43, and still not very similar scoring
|
255
|
-
# QUERY_STRING = "IYL_EDGEBUFFER EDGEBUFFER IGN_MASK CELLSIZE"
|
256
|
-
# query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
|
257
|
-
# n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
|
258
|
-
# embedding=query,
|
259
|
-
# k=10,
|
260
|
-
# # where_document={"$contains": QUERY_STRING},
|
261
|
-
# )
|
262
|
-
# print_query_results(QUERY_STRING, n_results)
|
263
|
-
#
|
264
|
-
# # random string (as bad as XYZZY), but searching for a specific line
|
265
|
-
# QUERY_STRING = "ghost in the invisible moon"
|
266
|
-
# query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
|
267
|
-
# n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
|
268
|
-
# embedding=query,
|
269
|
-
# k=10,
|
270
|
-
# filter={"$and": [{"end_line": 90}, {"tokens": {"$gte": 21}}]},
|
271
|
-
# )
|
272
|
-
# print_query_results(QUERY_STRING, n_results)
|
273
|
-
# self.assertTrue(len(n_results) == 1, "Was splitting changed?")
|
274
|
-
|
275
|
-
# @pytest.mark.slow
|
276
|
-
# def test_document_embeddings_added_by_translate(self):
|
277
|
-
# vector_store = self.req_translator.embeddings(EmbeddingType.REQUIREMENT)
|
278
|
-
# self.assertEqual(0, vector_store._add_texts_calls, "Precondition failed")
|
279
|
-
# self.req_translator.translate(self.test_file.parent, self.test_file.parent,
|
280
|
-
# True)
|
281
|
-
# self.assertTrue(vector_store._add_texts_calls > 0, "Why no documentation?")
|
282
|
-
|
283
|
-
# @pytest.mark.slow
|
284
|
-
# def test_embed_requirements(self):
|
285
|
-
# vector_store = self.req_translator.embeddings(EmbeddingType.REQUIREMENT)
|
286
|
-
# translated = self.req_translator.translate_file(self.test_file)
|
287
|
-
# self.assertEqual(
|
288
|
-
# 0,
|
289
|
-
# vector_store._add_texts_calls,
|
290
|
-
# "Unexpected requirements added in translate_file",
|
291
|
-
# )
|
292
|
-
# result = self.req_translator._embed(
|
293
|
-
# translated, EmbeddingType.REQUIREMENT, self.test_file.name
|
294
|
-
# )
|
295
|
-
# self.assertFalse(result, "No text in root node, so should generate no docs")
|
296
|
-
# self.assertIsNotNone(translated.children[0].text, "Data changed?")
|
297
|
-
# result = self.req_translator._embed(
|
298
|
-
# translated.children[0], EmbeddingType.REQUIREMENT, self.test_file.name
|
299
|
-
# )
|
300
|
-
# self.assertTrue(result, "No docs generated for first child node?")
|
301
|
-
|
302
75
|
def test_invalid_selections(self) -> None:
|
303
76
|
"""Tests that settings values for the translator will raise exceptions"""
|
304
77
|
self.assertRaises(
|
janus/converter/converter.py
CHANGED
@@ -6,7 +6,6 @@ from pathlib import Path
|
|
6
6
|
from typing import Any
|
7
7
|
|
8
8
|
from langchain.output_parsers import RetryWithErrorOutputParser
|
9
|
-
from langchain.output_parsers.fix import OutputFixingParser
|
10
9
|
from langchain_core.exceptions import OutputParserException
|
11
10
|
from langchain_core.language_models import BaseLanguageModel
|
12
11
|
from langchain_core.output_parsers import BaseOutputParser
|
@@ -29,6 +28,8 @@ from janus.llm import load_model
|
|
29
28
|
from janus.llm.model_callbacks import get_model_callback
|
30
29
|
from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
31
30
|
from janus.parsers.code_parser import GenericParser
|
31
|
+
from janus.parsers.refiner_parser import RefinerParser
|
32
|
+
from janus.refiners.refiner import BasicRefiner, Refiner
|
32
33
|
from janus.utils.enums import LANGUAGES
|
33
34
|
from janus.utils.logger import create_logger
|
34
35
|
|
@@ -75,6 +76,7 @@ class Converter:
|
|
75
76
|
protected_node_types: tuple[str, ...] = (),
|
76
77
|
prune_node_types: tuple[str, ...] = (),
|
77
78
|
splitter_type: str = "file",
|
79
|
+
refiner_type: str = "basic",
|
78
80
|
) -> None:
|
79
81
|
"""Initialize a Converter instance.
|
80
82
|
|
@@ -84,6 +86,17 @@ class Converter:
|
|
84
86
|
values are `"code"`, `"text"`, `"eval"`, and `None` (default). If `None`,
|
85
87
|
the `Converter` assumes you won't be parsing an output (i.e., adding to an
|
86
88
|
embedding DB).
|
89
|
+
max_prompts: The maximum number of prompts to try before giving up.
|
90
|
+
max_tokens: The maximum number of tokens to use in the LLM. If `None`, the
|
91
|
+
converter will use half the model's token limit.
|
92
|
+
prompt_template: The name of the prompt template to use.
|
93
|
+
db_path: The path to the database to use for vectorization.
|
94
|
+
db_config: The configuration for the database.
|
95
|
+
protected_node_types: A set of node types that aren't to be merged.
|
96
|
+
prune_node_types: A set of node types which should be pruned.
|
97
|
+
splitter_type: The type of splitter to use. Valid values are `"file"`,
|
98
|
+
`"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
|
99
|
+
refiner_type: The type of refiner to use. Valid values are `"basic"`.
|
87
100
|
"""
|
88
101
|
self._changed_attrs: set = set()
|
89
102
|
|
@@ -116,7 +129,11 @@ class Converter:
|
|
116
129
|
self._parser: BaseOutputParser = GenericParser()
|
117
130
|
self._combiner: Combiner = Combiner()
|
118
131
|
|
132
|
+
self._refiner_type: str
|
133
|
+
self._refiner: Refiner
|
134
|
+
|
119
135
|
self.set_splitter(splitter_type=splitter_type)
|
136
|
+
self.set_refiner(refiner_type=refiner_type)
|
120
137
|
self.set_model(model_name=model, **model_arguments)
|
121
138
|
self.set_prompt(prompt_template=prompt_template)
|
122
139
|
self.set_source_language(source_language)
|
@@ -142,6 +159,7 @@ class Converter:
|
|
142
159
|
self._load_prompt()
|
143
160
|
self._load_splitter()
|
144
161
|
self._load_vectorizer()
|
162
|
+
self._load_refiner()
|
145
163
|
self._changed_attrs.clear()
|
146
164
|
|
147
165
|
def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
|
@@ -179,6 +197,16 @@ class Converter:
|
|
179
197
|
"""
|
180
198
|
self._splitter_type = splitter_type
|
181
199
|
|
200
|
+
def set_refiner(self, refiner_type: str) -> None:
|
201
|
+
"""Validate and set the refiner name
|
202
|
+
|
203
|
+
The affected objects will not be updated until translate is called
|
204
|
+
|
205
|
+
Arguments:
|
206
|
+
refiner_type: the name of the refiner to use
|
207
|
+
"""
|
208
|
+
self._refiner_type = refiner_type
|
209
|
+
|
182
210
|
def set_source_language(self, source_language: str) -> None:
|
183
211
|
"""Validate and set the source language.
|
184
212
|
|
@@ -249,10 +277,24 @@ class Converter:
|
|
249
277
|
)
|
250
278
|
|
251
279
|
if self._splitter_type == "tag":
|
252
|
-
kwargs["tag"] = "<ITMOD_ALC_SPLIT>"
|
280
|
+
kwargs["tag"] = "<ITMOD_ALC_SPLIT>" # Hardcoded for now
|
253
281
|
|
254
282
|
self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)
|
255
283
|
|
284
|
+
@run_if_changed("_refiner_type", "_model_name")
|
285
|
+
def _load_refiner(self) -> None:
|
286
|
+
"""Load the refiner according to this instance's attributes.
|
287
|
+
|
288
|
+
If the relevant fields have not been changed since the last time this method was
|
289
|
+
called, nothing happens.
|
290
|
+
"""
|
291
|
+
if self._refiner_type == "basic":
|
292
|
+
self._refiner = BasicRefiner(
|
293
|
+
"basic_refinement", self._model_name, self._source_language
|
294
|
+
)
|
295
|
+
else:
|
296
|
+
raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
|
297
|
+
|
256
298
|
@run_if_changed("_model_name", "_custom_model_arguments")
|
257
299
|
def _load_model(self) -> None:
|
258
300
|
"""Load the model according to this instance's attributes.
|
@@ -561,22 +603,22 @@ class Converter:
|
|
561
603
|
# Retries with just the input
|
562
604
|
n3 = math.ceil(self.max_prompts / (n1 * n2))
|
563
605
|
|
564
|
-
|
565
|
-
llm=self._llm,
|
606
|
+
refine_output = RefinerParser(
|
566
607
|
parser=self._parser,
|
608
|
+
initial_prompt=self._prompt.format(**{"SOURCE_CODE": block.original.text}),
|
609
|
+
refiner=self._refiner,
|
567
610
|
max_retries=n1,
|
611
|
+
llm=self._llm,
|
568
612
|
)
|
569
613
|
retry = RetryWithErrorOutputParser.from_llm(
|
570
614
|
llm=self._llm,
|
571
|
-
parser=
|
615
|
+
parser=refine_output,
|
572
616
|
max_retries=n2,
|
573
617
|
)
|
574
|
-
|
575
618
|
completion_chain = self._prompt | self._llm
|
576
619
|
chain = RunnableParallel(
|
577
620
|
completion=completion_chain, prompt_value=self._prompt
|
578
621
|
) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
|
579
|
-
|
580
622
|
for _ in range(n3):
|
581
623
|
try:
|
582
624
|
return chain.invoke({"SOURCE_CODE": block.original.text})
|
janus/converter/diagram.py
CHANGED
@@ -1,10 +1,14 @@
|
|
1
|
-
import
|
2
|
-
|
1
|
+
import math
|
2
|
+
|
3
|
+
from langchain.output_parsers import RetryWithErrorOutputParser
|
4
|
+
from langchain_core.exceptions import OutputParserException
|
5
|
+
from langchain_core.runnables import RunnableLambda, RunnableParallel
|
3
6
|
|
4
7
|
from janus.converter.converter import run_if_changed
|
5
8
|
from janus.converter.document import Documenter
|
6
9
|
from janus.language.block import TranslatedCodeBlock
|
7
10
|
from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
11
|
+
from janus.parsers.refiner_parser import RefinerParser
|
8
12
|
from janus.parsers.uml import UMLSyntaxParser
|
9
13
|
from janus.utils.logger import create_logger
|
10
14
|
|
@@ -47,65 +51,74 @@ class DiagramGenerator(Documenter):
|
|
47
51
|
self._diagram_prompt_template_name = "diagram"
|
48
52
|
self._load_diagram_prompt_engine()
|
49
53
|
|
50
|
-
def
|
51
|
-
"""Given an "empty" `TranslatedCodeBlock`, translate the code represented in
|
52
|
-
`block.original`, setting the relevant fields in the translated block. The
|
53
|
-
`TranslatedCodeBlock` is updated in-pace, nothing is returned. Note that this
|
54
|
-
translates *only* the code for this block, not its children.
|
55
|
-
|
56
|
-
Arguments:
|
57
|
-
block: An empty `TranslatedCodeBlock`
|
58
|
-
"""
|
59
|
-
if block.translated:
|
60
|
-
return
|
61
|
-
|
62
|
-
if block.original.text is None:
|
63
|
-
block.translated = True
|
64
|
-
return
|
65
|
-
|
66
|
-
if self._add_documentation:
|
67
|
-
documentation_block = deepcopy(block)
|
68
|
-
super()._add_translation(documentation_block)
|
69
|
-
if not documentation_block.translated:
|
70
|
-
message = "Error: unable to produce documentation for code block"
|
71
|
-
log.info(message)
|
72
|
-
raise ValueError(message)
|
73
|
-
documentation = json.loads(documentation_block.text)["docstring"]
|
74
|
-
|
75
|
-
if self._llm is None:
|
76
|
-
message = (
|
77
|
-
"Model not configured correctly, cannot translate. Try setting "
|
78
|
-
"the model"
|
79
|
-
)
|
80
|
-
log.error(message)
|
81
|
-
raise ValueError(message)
|
82
|
-
|
83
|
-
log.debug(f"[{block.name}] Translating...")
|
84
|
-
log.debug(f"[{block.name}] Input text:\n{block.original.text}")
|
85
|
-
|
54
|
+
def _run_chain(self, block: TranslatedCodeBlock) -> str:
|
86
55
|
self._parser.set_reference(block.original)
|
56
|
+
n1 = round(self.max_prompts ** (1 / 3))
|
87
57
|
|
88
|
-
|
58
|
+
# Retries with the input, output, and error
|
59
|
+
n2 = round((self.max_prompts // n1) ** (1 / 2))
|
60
|
+
|
61
|
+
# Retries with just the input
|
62
|
+
n3 = math.ceil(self.max_prompts / (n1 * n2))
|
89
63
|
|
90
64
|
if self._add_documentation:
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
65
|
+
documentation_text = super()._run_chain(block)
|
66
|
+
refine_output = RefinerParser(
|
67
|
+
parser=self._diagram_parser,
|
68
|
+
initial_prompt=self._diagram_prompt.format(
|
69
|
+
**{
|
70
|
+
"SOURCE_CODE": block.original.text,
|
71
|
+
"DOCUMENTATION": documentation_text,
|
72
|
+
"DIAGRAM_TYPE": self._diagram_type,
|
73
|
+
}
|
74
|
+
),
|
75
|
+
refiner=self._refiner,
|
76
|
+
max_retries=n1,
|
77
|
+
llm=self._llm,
|
97
78
|
)
|
98
79
|
else:
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
80
|
+
refine_output = RefinerParser(
|
81
|
+
parser=self._diagram_parser,
|
82
|
+
initial_prompt=self._diagram_prompt.format(
|
83
|
+
**{
|
84
|
+
"SOURCE_CODE": block.original.text,
|
85
|
+
"DIAGRAM_TYPE": self._diagram_type,
|
86
|
+
}
|
87
|
+
),
|
88
|
+
refiner=self._refiner,
|
89
|
+
max_retries=n1,
|
90
|
+
llm=self._llm,
|
104
91
|
)
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
92
|
+
retry = RetryWithErrorOutputParser.from_llm(
|
93
|
+
llm=self._llm,
|
94
|
+
parser=refine_output,
|
95
|
+
max_retries=n2,
|
96
|
+
)
|
97
|
+
completion_chain = self._prompt | self._llm
|
98
|
+
chain = RunnableParallel(
|
99
|
+
completion=completion_chain, prompt_value=self._diagram_prompt
|
100
|
+
) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
|
101
|
+
for _ in range(n3):
|
102
|
+
try:
|
103
|
+
if self._add_documentation:
|
104
|
+
return chain.invoke(
|
105
|
+
{
|
106
|
+
"SOURCE_CODE": block.original.text,
|
107
|
+
"DOCUMENTATION": documentation_text,
|
108
|
+
"DIAGRAM_TYPE": self._diagram_type,
|
109
|
+
}
|
110
|
+
)
|
111
|
+
else:
|
112
|
+
return chain.invoke(
|
113
|
+
{
|
114
|
+
"SOURCE_CODE": block.original.text,
|
115
|
+
"DIAGRAM_TYPE": self._diagram_type,
|
116
|
+
}
|
117
|
+
)
|
118
|
+
except OutputParserException:
|
119
|
+
pass
|
120
|
+
|
121
|
+
raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
|
109
122
|
|
110
123
|
@run_if_changed(
|
111
124
|
"_diagram_prompt_template_name",
|
@@ -123,4 +136,4 @@ class DiagramGenerator(Documenter):
|
|
123
136
|
target_version=None,
|
124
137
|
prompt_template=self._diagram_prompt_template_name,
|
125
138
|
)
|
126
|
-
self.
|
139
|
+
self._diagram_prompt = self._diagram_prompt_engine.prompt
|
@@ -4,8 +4,8 @@ from unittest.mock import MagicMock
|
|
4
4
|
|
5
5
|
import pytest
|
6
6
|
|
7
|
-
from
|
8
|
-
from
|
7
|
+
from janus.embedding.collections import Collections
|
8
|
+
from janus.utils.enums import EmbeddingType
|
9
9
|
|
10
10
|
|
11
11
|
class TestCollections(unittest.TestCase):
|
@@ -5,9 +5,9 @@ from unittest.mock import MagicMock
|
|
5
5
|
|
6
6
|
from chromadb.api.client import Client
|
7
7
|
|
8
|
-
from
|
9
|
-
from
|
10
|
-
from
|
8
|
+
from janus.embedding.vectorize import Vectorizer, VectorizerFactory
|
9
|
+
from janus.language.treesitter import TreeSitterSplitter
|
10
|
+
from janus.utils.enums import EmbeddingType
|
11
11
|
|
12
12
|
|
13
13
|
class MockDBVectorizer(VectorizerFactory):
|
janus/embedding/collections.py
CHANGED
@@ -5,8 +5,8 @@ from typing import Dict, Optional, Sequence
|
|
5
5
|
from chromadb import Client, Collection
|
6
6
|
from langchain_community.vectorstores import Chroma
|
7
7
|
|
8
|
-
from
|
9
|
-
from .
|
8
|
+
from janus.embedding.embedding_models_info import load_embedding_model
|
9
|
+
from janus.utils.enums import EmbeddingType
|
10
10
|
|
11
11
|
# See https://docs.trychroma.com/telemetry#in-chromas-backend-using-environment-variables
|
12
12
|
os.environ["ANONYMIZED_TELEMETRY"] = "False"
|
janus/embedding/database.py
CHANGED
@@ -8,7 +8,7 @@ from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEm
|
|
8
8
|
from langchain_core.embeddings import Embeddings
|
9
9
|
from langchain_openai import OpenAIEmbeddings
|
10
10
|
|
11
|
-
from
|
11
|
+
from janus.utils.logger import create_logger
|
12
12
|
|
13
13
|
load_dotenv()
|
14
14
|
|