janus-llm 3.1.1__py3-none-any.whl → 3.2.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (72) hide show
  1. janus/__init__.py +3 -3
  2. janus/_tests/test_cli.py +3 -3
  3. janus/cli.py +65 -8
  4. janus/converter/__init__.py +6 -6
  5. janus/converter/_tests/test_translate.py +10 -238
  6. janus/converter/converter.py +6 -3
  7. janus/converter/translate.py +1 -1
  8. janus/embedding/_tests/test_collections.py +2 -2
  9. janus/embedding/_tests/test_database.py +1 -1
  10. janus/embedding/_tests/test_vectorize.py +3 -3
  11. janus/embedding/collections.py +2 -2
  12. janus/embedding/database.py +1 -1
  13. janus/embedding/embedding_models_info.py +1 -1
  14. janus/embedding/vectorize.py +5 -5
  15. janus/language/_tests/test_combine.py +1 -1
  16. janus/language/_tests/test_splitter.py +1 -1
  17. janus/language/alc/_tests/test_alc.py +6 -6
  18. janus/language/alc/alc.py +5 -5
  19. janus/language/binary/_tests/test_binary.py +4 -4
  20. janus/language/binary/binary.py +5 -5
  21. janus/language/block.py +2 -2
  22. janus/language/combine.py +3 -3
  23. janus/language/file.py +2 -2
  24. janus/language/mumps/_tests/test_mumps.py +5 -5
  25. janus/language/mumps/mumps.py +5 -5
  26. janus/language/mumps/patterns.py +1 -1
  27. janus/language/naive/__init__.py +4 -4
  28. janus/language/naive/basic_splitter.py +4 -4
  29. janus/language/naive/chunk_splitter.py +4 -4
  30. janus/language/naive/registry.py +1 -1
  31. janus/language/naive/simple_ast.py +5 -5
  32. janus/language/naive/tag_splitter.py +4 -4
  33. janus/language/node.py +1 -1
  34. janus/language/splitter.py +4 -4
  35. janus/language/treesitter/_tests/test_treesitter.py +5 -5
  36. janus/language/treesitter/treesitter.py +4 -4
  37. janus/llm/__init__.py +1 -1
  38. janus/llm/model_callbacks.py +1 -1
  39. janus/llm/models_info.py +45 -23
  40. janus/metrics/_tests/test_bleu.py +1 -1
  41. janus/metrics/_tests/test_chrf.py +1 -1
  42. janus/metrics/_tests/test_file_pairing.py +1 -1
  43. janus/metrics/_tests/test_llm.py +5 -5
  44. janus/metrics/_tests/test_reading.py +1 -1
  45. janus/metrics/_tests/test_rouge_score.py +1 -1
  46. janus/metrics/_tests/test_similarity_score.py +1 -1
  47. janus/metrics/_tests/test_treesitter_metrics.py +2 -2
  48. janus/metrics/bleu.py +1 -1
  49. janus/metrics/chrf.py +1 -1
  50. janus/metrics/complexity_metrics.py +4 -4
  51. janus/metrics/file_pairing.py +5 -5
  52. janus/metrics/llm_metrics.py +1 -1
  53. janus/metrics/metric.py +11 -11
  54. janus/metrics/reading.py +1 -1
  55. janus/metrics/rouge_score.py +1 -1
  56. janus/metrics/similarity.py +2 -2
  57. janus/parsers/_tests/test_code_parser.py +1 -1
  58. janus/parsers/code_parser.py +2 -2
  59. janus/parsers/doc_parser.py +3 -3
  60. janus/parsers/eval_parser.py +2 -2
  61. janus/parsers/reqs_parser.py +3 -3
  62. janus/parsers/uml.py +1 -2
  63. janus/prompts/prompt.py +2 -2
  64. janus/utils/_tests/test_logger.py +1 -1
  65. janus/utils/_tests/test_progress.py +1 -1
  66. janus/utils/progress.py +1 -1
  67. {janus_llm-3.1.1.dist-info → janus_llm-3.2.1.dist-info}/METADATA +1 -1
  68. janus_llm-3.2.1.dist-info/RECORD +105 -0
  69. janus_llm-3.1.1.dist-info/RECORD +0 -105
  70. {janus_llm-3.1.1.dist-info → janus_llm-3.2.1.dist-info}/LICENSE +0 -0
  71. {janus_llm-3.1.1.dist-info → janus_llm-3.2.1.dist-info}/WHEEL +0 -0
  72. {janus_llm-3.1.1.dist-info → janus_llm-3.2.1.dist-info}/entry_points.txt +0 -0
janus/__init__.py CHANGED
@@ -2,10 +2,10 @@ import warnings
2
2
 
3
3
  from langchain_core._api.deprecation import LangChainDeprecationWarning
4
4
 
5
- from .converter.translate import Translator
6
- from .metrics import * # noqa: F403
5
+ from janus.converter.translate import Translator
6
+ from janus.metrics import * # noqa: F403
7
7
 
8
- __version__ = "3.1.1"
8
+ __version__ = "3.2.1"
9
9
 
10
10
  # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
11
11
  warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
janus/_tests/test_cli.py CHANGED
@@ -4,9 +4,9 @@ from unittest.mock import ANY, patch
4
4
 
5
5
  from typer.testing import CliRunner
6
6
 
7
- from ..cli import app, translate
8
- from ..embedding.embedding_models_info import EMBEDDING_MODEL_CONFIG_DIR
9
- from ..llm.models_info import MODEL_CONFIG_DIR
7
+ from janus.cli import app, translate
8
+ from janus.embedding.embedding_models_info import EMBEDDING_MODEL_CONFIG_DIR
9
+ from janus.llm.models_info import MODEL_CONFIG_DIR
10
10
 
11
11
 
12
12
  class TestCli(unittest.TestCase):
janus/cli.py CHANGED
@@ -32,8 +32,12 @@ from janus.language.treesitter import TreeSitterSplitter
32
32
  from janus.llm.model_callbacks import COST_PER_1K_TOKENS
33
33
  from janus.llm.models_info import (
34
34
  MODEL_CONFIG_DIR,
35
+ MODEL_ID_TO_LONG_ID,
35
36
  MODEL_TYPE_CONSTRUCTORS,
37
+ MODEL_TYPES,
36
38
  TOKEN_LIMITS,
39
+ bedrock_models,
40
+ openai_models,
37
41
  )
38
42
  from janus.metrics.cli import evaluate
39
43
  from janus.utils.enums import LANGUAGES
@@ -104,7 +108,7 @@ embedding = typer.Typer(
104
108
 
105
109
  def version_callback(value: bool) -> None:
106
110
  if value:
107
- from . import __version__ as version
111
+ from janus import __version__ as version
108
112
 
109
113
  print(f"Janus CLI [blue]v{version}[/blue]")
110
114
  raise typer.Exit()
@@ -179,7 +183,7 @@ def translate(
179
183
  "-L",
180
184
  help="The custom name of the model set with 'janus llm add'.",
181
185
  ),
182
- ] = "gpt-3.5-turbo-0125",
186
+ ] = "gpt-4o",
183
187
  max_prompts: Annotated[
184
188
  int,
185
189
  typer.Option(
@@ -301,7 +305,7 @@ def document(
301
305
  "-L",
302
306
  help="The custom name of the model set with 'janus llm add'.",
303
307
  ),
304
- ] = "gpt-3.5-turbo-0125",
308
+ ] = "gpt-4o",
305
309
  max_prompts: Annotated[
306
310
  int,
307
311
  typer.Option(
@@ -437,7 +441,7 @@ def diagram(
437
441
  "-L",
438
442
  help="The custom name of the model set with 'janus llm add'.",
439
443
  ),
440
- ] = "gpt-3.5-turbo-0125",
444
+ ] = "gpt-4o",
441
445
  max_prompts: Annotated[
442
446
  int,
443
447
  typer.Option(
@@ -800,16 +804,44 @@ def llm_add(
800
804
  "model_cost": {"input": in_cost, "output": out_cost},
801
805
  }
802
806
  elif model_type == "OpenAI":
803
- model_name = typer.prompt("Enter the model name", default="gpt-3.5-turbo-0125")
807
+ model_id = typer.prompt(
808
+ "Enter the model ID (list model IDs with `janus llm ls -a`)",
809
+ default="gpt-4o",
810
+ type=click.Choice(openai_models),
811
+ show_choices=False,
812
+ )
804
813
  params = dict(
805
- model_name=model_name,
814
+ # OpenAI uses the "model_name" key for what we're calling "long_model_id"
815
+ model_name=MODEL_ID_TO_LONG_ID[model_id],
806
816
  temperature=0.7,
807
817
  n=1,
808
818
  )
809
- max_tokens = TOKEN_LIMITS[model_name]
810
- model_cost = COST_PER_1K_TOKENS[model_name]
819
+ max_tokens = TOKEN_LIMITS[MODEL_ID_TO_LONG_ID[model_id]]
820
+ model_cost = COST_PER_1K_TOKENS[MODEL_ID_TO_LONG_ID[model_id]]
821
+ cfg = {
822
+ "model_type": model_type,
823
+ "model_id": model_id,
824
+ "model_args": params,
825
+ "token_limit": max_tokens,
826
+ "model_cost": model_cost,
827
+ }
828
+ elif model_type == "BedrockChat" or model_type == "Bedrock":
829
+ model_id = typer.prompt(
830
+ "Enter the model ID (list model IDs with `janus llm ls -a`)",
831
+ default="bedrock-claude-sonnet",
832
+ type=click.Choice(bedrock_models),
833
+ show_choices=False,
834
+ )
835
+ params = dict(
836
+ # Bedrock uses the "model_id" key for what we're calling "long_model_id"
837
+ model_id=MODEL_ID_TO_LONG_ID[model_id],
838
+ model_kwargs={"temperature": 0.7},
839
+ )
840
+ max_tokens = TOKEN_LIMITS[MODEL_ID_TO_LONG_ID[model_id]]
841
+ model_cost = COST_PER_1K_TOKENS[MODEL_ID_TO_LONG_ID[model_id]]
811
842
  cfg = {
812
843
  "model_type": model_type,
844
+ "model_id": model_id,
813
845
  "model_args": params,
814
846
  "token_limit": max_tokens,
815
847
  "model_cost": model_cost,
@@ -821,6 +853,31 @@ def llm_add(
821
853
  print(f"Model config written to {model_cfg}")
822
854
 
823
855
 
856
+ @llm.command("ls", help="List all of the user-configured models")
857
+ def llm_ls(
858
+ all: Annotated[
859
+ bool,
860
+ typer.Option(
861
+ "--all",
862
+ "-a",
863
+ is_flag=True,
864
+ help="List all models, including the default model IDs.",
865
+ click_type=click.Choice(sorted(list(MODEL_TYPE_CONSTRUCTORS.keys()))),
866
+ ),
867
+ ] = False,
868
+ ):
869
+ print("\n[green]User-configured models[/green]:")
870
+ for model_cfg in MODEL_CONFIG_DIR.glob("*.json"):
871
+ with open(model_cfg, "r") as f:
872
+ cfg = json.load(f)
873
+ print(f"\t[blue]{model_cfg.stem}[/blue]: [purple]{cfg['model_type']}[/purple]")
874
+
875
+ if all:
876
+ print("\n[green]Available model IDs[/green]:")
877
+ for model_id, model_type in MODEL_TYPES.items():
878
+ print(f"\t[blue]{model_id}[/blue]: [purple]{model_type}[/purple]")
879
+
880
+
824
881
  @embedding.command("add", help="Add an embedding model config to janus")
825
882
  def embedding_add(
826
883
  model_name: Annotated[
@@ -1,6 +1,6 @@
1
- from .converter import Converter
2
- from .diagram import DiagramGenerator
3
- from .document import Documenter, MadLibsDocumenter, MultiDocumenter
4
- from .evaluate import Evaluator
5
- from .requirements import RequirementsDocumenter
6
- from .translate import Translator
1
+ from janus.converter.converter import Converter
2
+ from janus.converter.diagram import DiagramGenerator
3
+ from janus.converter.document import Documenter, MadLibsDocumenter, MultiDocumenter
4
+ from janus.converter.evaluate import Evaluator
5
+ from janus.converter.requirements import RequirementsDocumenter
6
+ from janus.converter.translate import Translator
@@ -7,37 +7,11 @@ from langchain.schema import Document
7
7
  from langchain.schema.embeddings import Embeddings
8
8
  from langchain.schema.vectorstore import VST, VectorStore
9
9
 
10
+ from janus.converter.diagram import DiagramGenerator
11
+ from janus.converter.requirements import RequirementsDocumenter
12
+ from janus.converter.translate import Translator
10
13
  from janus.language.block import CodeBlock, TranslatedCodeBlock
11
14
 
12
- from ..diagram import DiagramGenerator
13
- from ..requirements import RequirementsDocumenter
14
- from ..translate import Translator
15
-
16
- # from langchain.vectorstores import Chroma
17
-
18
-
19
- # from ..utils.enums import EmbeddingType
20
-
21
-
22
- def print_query_results(query, n_results):
23
- # print(f"\n{query}")
24
- # count = 1
25
- # for t in n_results:
26
- # short_code = (
27
- # (t[0].page_content[0:50] + "..")
28
- # if (len(t[0].page_content) > 50)
29
- # else t[0].page_content
30
- # )
31
- # return_index = short_code.find("\n")
32
- # if -1 != return_index:
33
- # short_code = short_code[0:return_index] + ".."
34
- # print(
35
- # f"{count}. @ {t[0].metadata['start_line']}-{t[0].metadata['end_line']}"
36
- # f" -- {t[1]} -- {short_code}"
37
- # )
38
- # count += 1
39
- pass
40
-
41
15
 
42
16
  class MockCollection(VectorStore):
43
17
  """Vector store for testing"""
@@ -65,21 +39,13 @@ class MockCollection(VectorStore):
65
39
  raise NotImplementedError("from_texts() not implemented!")
66
40
 
67
41
 
68
- # class MockEmbeddingsFactory(EmbeddingsFactory):
69
- # """Embeddings for testing - uses MockCollection"""
70
- #
71
- # def get_embeddings(self) -> Embeddings:
72
- # return MockCollection()
73
- #
74
-
75
-
76
42
  class TestTranslator(unittest.TestCase):
77
43
  """Tests for the Translator class."""
78
44
 
79
45
  def setUp(self):
80
46
  """Set up the tests."""
81
47
  self.translator = Translator(
82
- model="gpt-3.5-turbo-0125",
48
+ model="gpt-4o",
83
49
  source_language="fortran",
84
50
  target_language="python",
85
51
  target_version="3.10",
@@ -88,7 +54,7 @@ class TestTranslator(unittest.TestCase):
88
54
  self.TEST_FILE_EMBEDDING_COUNT = 14
89
55
 
90
56
  self.req_translator = RequirementsDocumenter(
91
- model="gpt-3.5-turbo-0125",
57
+ model="gpt-4o",
92
58
  source_language="fortran",
93
59
  prompt_template="requirements",
94
60
  )
@@ -105,200 +71,6 @@ class TestTranslator(unittest.TestCase):
105
71
  # unit tests anyway
106
72
  self.assertTrue(python_file.exists())
107
73
 
108
- # def test_embeddings(self):
109
- # """Testing access to embeddings"""
110
- # vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
111
- # self.assertIsInstance(vector_store, Chroma, "Unexpected vector store type!")
112
- # self.assertEqual(
113
- # 0, vector_store._collection.count(), "Non-empty initial vector store?"
114
- # )
115
- #
116
- # self.translator.set_model("llama")
117
- # self.translator._load_parameters()
118
- # vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
119
- # self.assertIsInstance(vector_store, Chroma)
120
- # self.assertEqual(
121
- # 0, vector_store._collection.count(), "Non-empty initial vector store?"
122
- # )
123
-
124
- # def test_embed_split_source(self):
125
- # """Characterize _embed method"""
126
- # mock_embeddings = MockEmbeddingsFactory()
127
- # self.translator.set_embeddings(mock_embeddings)
128
- # self.translator._load_parameters()
129
- # input_block = self.translator.splitter.split(self.test_file)
130
- # self.assertIsNone(
131
- # input_block.text, "Root node of input text shouldn't contain text"
132
- # )
133
- # self.assertIsNone(input_block.embedding_id, "Precondition failed")
134
- #
135
- # result = self.translator._embed(
136
- # input_block, EmbeddingType.SOURCE, self.test_file.name
137
- # )
138
- #
139
- # self.assertFalse(result, "Nothing to embed, so should have no result")
140
- # self.assertIsNone(
141
- # input_block.embedding_id, "Embeddings should not have changed")
142
-
143
- # def test_embed_has_values_for_each_non_empty_node(self):
144
- # """Characterize our sample fortran file"""
145
- # mock_embeddings = MockEmbeddingsFactory()
146
- # self.translator.set_embeddings(mock_embeddings)
147
- # self.translator._load_parameters()
148
- # input_block = self.translator.splitter.split(self.test_file)
149
- # self.translator._embed_nodes_recursively(
150
- # input_block, EmbeddingType.SOURCE, self.test_file.name
151
- # )
152
- # has_text_count = 0
153
- # has_embeddings_count = 0
154
- # nodes = [input_block]
155
- # while nodes:
156
- # node = nodes.pop(0)
157
- # if node.text:
158
- # has_text_count += 1
159
- # if node.embedding_id:
160
- # has_embeddings_count += 1
161
- # nodes.extend(node.children)
162
- # self.assertEqual(
163
- # self.TEST_FILE_EMBEDDING_COUNT,
164
- # has_text_count,
165
- # "Parsing of test_file has changed!",
166
- # )
167
- # self.assertEqual(
168
- # self.TEST_FILE_EMBEDDING_COUNT,
169
- # has_embeddings_count,
170
- # "Not all non-empty nodes have embeddings!",
171
- # )
172
-
173
- # def test_embed_nodes_recursively(self):
174
- # mock_embeddings = MockEmbeddingsFactory()
175
- # self.translator.set_embeddings(mock_embeddings)
176
- # self.translator._load_parameters()
177
- # input_block = self.translator.splitter.split(self.test_file)
178
- # self.translator._embed_nodes_recursively(
179
- # input_block, EmbeddingType.SOURCE, self.test_file.name
180
- # )
181
- # nodes = [input_block]
182
- # while nodes:
183
- # node = nodes.pop(0)
184
- # self.assertEqual(node.text is not None, node.embedding_id is not None)
185
- # nodes.extend(node.children)
186
-
187
- # @pytest.mark.slow
188
- # def test_translate_file_adds_source_embeddings(self):
189
- # mock_embeddings = MockEmbeddingsFactory()
190
- # self.translator.set_embeddings(mock_embeddings)
191
- # self.translator._load_parameters()
192
- # vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
193
- # self.assertEqual(0, vector_store._add_texts_calls, "precondition")
194
- #
195
- # self.translator.translate_file(self.test_file)
196
- #
197
- # self.assertEqual(
198
- # self.TEST_FILE_EMBEDDING_COUNT,
199
- # vector_store._add_texts_calls,
200
- # "Did not find expected source embeddings",
201
- # )
202
-
203
- # @pytest.mark.slow
204
- # def test_embeddings_usage(self):
205
- # """Noodling on use of embeddings
206
- # To see results have to uncomment print_query_results() above
207
- # """
208
- # input_block = self.translator.splitter.split(self.test_file)
209
- # self.translator._embed_nodes_recursively(
210
- # input_block, EmbeddingType.SOURCE, self.test_file.name
211
- # )
212
- # vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
213
- #
214
- # # this symbol has the lowest relevance scores of any in this test, but
215
- # # still not very low; multiple embedded nodes contain it
216
- # QUERY_STRING = "IWX_BAND_START"
217
- # query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
218
- # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
219
- # embedding=query,
220
- # k=10,
221
- # where_document={"$contains": QUERY_STRING},
222
- # )
223
- # self.assertTrue(len(n_results) > 1, "Why was valid symbol not found?")
224
- # print_query_results(QUERY_STRING, n_results)
225
-
226
- # in the XYZZY test, the least dissimilar results were the start and finish lines
227
- # 0, and 415, which produced a similarity score of 0.47:
228
-
229
- # QUERY_STRING = "XYZZY"
230
- # query = self.translator._embeddings.embed_query(QUERY_STRING)
231
- # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
232
- # embedding=query,
233
- # k=10,
234
- # # filter={"end_line": 15},
235
- # # filter={"$and": [{"end_line": 15}, {"tokens": {"$gte": 21}}]},
236
- # # where_document={"$contains": QUERY_STRING},
237
- # )
238
- # print_query_results(QUERY_STRING, n_results)
239
- # # self.assertTrue(len(n_results) == 0, "Invalid symbol was found?")
240
-
241
- # # only returns a single result because only 1 embedded node contains
242
- # # CSV_ICASEARR:
243
- # QUERY_STRING = "What is the use of CSV_ICASEARR?"
244
- # query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
245
- # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
246
- # embedding=query,
247
- # k=10,
248
- # # where_document={"$contains": QUERY_STRING},
249
- # where_document={"$contains": "CSV_ICASEARR"},
250
- # )
251
- # print_query_results(QUERY_STRING, n_results)
252
- # self.assertTrue(len(n_results) == 1, "Was splitting changed?")
253
- #
254
- # # trimmed out some characters from line 43, and still not very similar scoring
255
- # QUERY_STRING = "IYL_EDGEBUFFER EDGEBUFFER IGN_MASK CELLSIZE"
256
- # query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
257
- # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
258
- # embedding=query,
259
- # k=10,
260
- # # where_document={"$contains": QUERY_STRING},
261
- # )
262
- # print_query_results(QUERY_STRING, n_results)
263
- #
264
- # # random string (as bad as XYZZY), but searching for a specific line
265
- # QUERY_STRING = "ghost in the invisible moon"
266
- # query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
267
- # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
268
- # embedding=query,
269
- # k=10,
270
- # filter={"$and": [{"end_line": 90}, {"tokens": {"$gte": 21}}]},
271
- # )
272
- # print_query_results(QUERY_STRING, n_results)
273
- # self.assertTrue(len(n_results) == 1, "Was splitting changed?")
274
-
275
- # @pytest.mark.slow
276
- # def test_document_embeddings_added_by_translate(self):
277
- # vector_store = self.req_translator.embeddings(EmbeddingType.REQUIREMENT)
278
- # self.assertEqual(0, vector_store._add_texts_calls, "Precondition failed")
279
- # self.req_translator.translate(self.test_file.parent, self.test_file.parent,
280
- # True)
281
- # self.assertTrue(vector_store._add_texts_calls > 0, "Why no documentation?")
282
-
283
- # @pytest.mark.slow
284
- # def test_embed_requirements(self):
285
- # vector_store = self.req_translator.embeddings(EmbeddingType.REQUIREMENT)
286
- # translated = self.req_translator.translate_file(self.test_file)
287
- # self.assertEqual(
288
- # 0,
289
- # vector_store._add_texts_calls,
290
- # "Unexpected requirements added in translate_file",
291
- # )
292
- # result = self.req_translator._embed(
293
- # translated, EmbeddingType.REQUIREMENT, self.test_file.name
294
- # )
295
- # self.assertFalse(result, "No text in root node, so should generate no docs")
296
- # self.assertIsNotNone(translated.children[0].text, "Data changed?")
297
- # result = self.req_translator._embed(
298
- # translated.children[0], EmbeddingType.REQUIREMENT, self.test_file.name
299
- # )
300
- # self.assertTrue(result, "No docs generated for first child node?")
301
-
302
74
  def test_invalid_selections(self) -> None:
303
75
  """Tests that settings values for the translator will raise exceptions"""
304
76
  self.assertRaises(
@@ -317,14 +89,14 @@ class TestDiagramGenerator(unittest.TestCase):
317
89
  def setUp(self):
318
90
  """Set up the tests."""
319
91
  self.diagram_generator = DiagramGenerator(
320
- model="gpt-3.5-turbo-0125",
92
+ model="gpt-4o",
321
93
  source_language="fortran",
322
94
  diagram_type="Activity",
323
95
  )
324
96
 
325
97
  def test_init(self):
326
98
  """Test __init__ method."""
327
- self.assertEqual(self.diagram_generator._model_name, "gpt-3.5-turbo-0125")
99
+ self.assertEqual(self.diagram_generator._model_name, "gpt-4o")
328
100
  self.assertEqual(self.diagram_generator._source_language, "fortran")
329
101
  self.assertEqual(self.diagram_generator._diagram_type, "Activity")
330
102
 
@@ -370,8 +142,8 @@ def test_language_combinations(
370
142
  """Tests that translator target language settings are consistent
371
143
  with prompt template expectations.
372
144
  """
373
- translator = Translator(model="gpt-3.5-turbo-0125")
374
- translator.set_model("gpt-3.5-turbo-0125")
145
+ translator = Translator(model="gpt-4o")
146
+ translator.set_model("gpt-4o")
375
147
  translator.set_source_language(source_language)
376
148
  translator.set_target_language(expected_target_language, expected_target_version)
377
149
  translator.set_prompt(prompt_template)
@@ -379,5 +151,5 @@ def test_language_combinations(
379
151
  assert translator._target_language == expected_target_language # nosec
380
152
  assert translator._target_version == expected_target_version # nosec
381
153
  assert translator._splitter.language == source_language # nosec
382
- assert translator._splitter.model.model_name == "gpt-3.5-turbo-0125" # nosec
154
+ assert translator._splitter.model.model_name == "gpt-4o" # nosec
383
155
  assert translator._prompt_template_name == prompt_template # nosec
@@ -64,7 +64,7 @@ class Converter:
64
64
 
65
65
  def __init__(
66
66
  self,
67
- model: str = "gpt-3.5-turbo-0125",
67
+ model: str = "gpt-4o",
68
68
  model_arguments: dict[str, Any] = {},
69
69
  source_language: str = "fortran",
70
70
  max_prompts: int = 10,
@@ -92,6 +92,7 @@ class Converter:
92
92
  self.override_token_limit: bool = max_tokens is not None
93
93
 
94
94
  self._model_name: str
95
+ self._model_id: str
95
96
  self._custom_model_arguments: dict[str, Any]
96
97
 
97
98
  self._source_language: str
@@ -265,7 +266,9 @@ class Converter:
265
266
  # model_arguments.update(self._custom_model_arguments)
266
267
 
267
268
  # Load the model
268
- self._llm, token_limit, self.model_cost = load_model(self._model_name)
269
+ self._llm, self._model_id, token_limit, self.model_cost = load_model(
270
+ self._model_name
271
+ )
269
272
  # Set the max_tokens to less than half the model's limit to allow for enough
270
273
  # tokens at output
271
274
  # Only modify max_tokens if it is not specified by user
@@ -283,7 +286,7 @@ class Converter:
283
286
  If the relevant fields have not been changed since the last time this
284
287
  method was called, nothing happens.
285
288
  """
286
- prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
289
+ prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
287
290
  source_language=self._source_language,
288
291
  prompt_template=self._prompt_template_name,
289
292
  )
@@ -90,7 +90,7 @@ class Translator(Converter):
90
90
  f"({self._source_language} != {self._target_language})"
91
91
  )
92
92
 
93
- prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
93
+ prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
94
94
  source_language=self._source_language,
95
95
  target_language=self._target_language,
96
96
  target_version=self._target_version,
@@ -4,8 +4,8 @@ from unittest.mock import MagicMock
4
4
 
5
5
  import pytest
6
6
 
7
- from ...utils.enums import EmbeddingType
8
- from ..collections import Collections
7
+ from janus.embedding.collections import Collections
8
+ from janus.utils.enums import EmbeddingType
9
9
 
10
10
 
11
11
  class TestCollections(unittest.TestCase):
@@ -2,7 +2,7 @@ import unittest
2
2
  from pathlib import Path
3
3
  from unittest.mock import patch
4
4
 
5
- from ..database import ChromaEmbeddingDatabase, uri_to_path
5
+ from janus.embedding.database import ChromaEmbeddingDatabase, uri_to_path
6
6
 
7
7
 
8
8
  class TestDatabase(unittest.TestCase):
@@ -5,9 +5,9 @@ from unittest.mock import MagicMock
5
5
 
6
6
  from chromadb.api.client import Client
7
7
 
8
- from ...language.treesitter import TreeSitterSplitter
9
- from ...utils.enums import EmbeddingType
10
- from ..vectorize import Vectorizer, VectorizerFactory
8
+ from janus.embedding.vectorize import Vectorizer, VectorizerFactory
9
+ from janus.language.treesitter import TreeSitterSplitter
10
+ from janus.utils.enums import EmbeddingType
11
11
 
12
12
 
13
13
  class MockDBVectorizer(VectorizerFactory):
@@ -5,8 +5,8 @@ from typing import Dict, Optional, Sequence
5
5
  from chromadb import Client, Collection
6
6
  from langchain_community.vectorstores import Chroma
7
7
 
8
- from ..utils.enums import EmbeddingType
9
- from .embedding_models_info import load_embedding_model
8
+ from janus.embedding.embedding_models_info import load_embedding_model
9
+ from janus.utils.enums import EmbeddingType
10
10
 
11
11
  # See https://docs.trychroma.com/telemetry#in-chromas-backend-using-environment-variables
12
12
  os.environ["ANONYMIZED_TELEMETRY"] = "False"
@@ -5,7 +5,7 @@ from urllib.request import url2pathname
5
5
 
6
6
  import chromadb
7
7
 
8
- from ..utils.logger import create_logger
8
+ from janus.utils.logger import create_logger
9
9
 
10
10
  log = create_logger(__name__)
11
11
 
@@ -8,7 +8,7 @@ from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEm
8
8
  from langchain_core.embeddings import Embeddings
9
9
  from langchain_openai import OpenAIEmbeddings
10
10
 
11
- from ..utils.logger import create_logger
11
+ from janus.utils.logger import create_logger
12
12
 
13
13
  load_dotenv()
14
14
 
@@ -6,10 +6,10 @@ from typing import Any, Dict, Optional, Sequence
6
6
  from chromadb import Client, Collection
7
7
  from langchain_community.vectorstores import Chroma
8
8
 
9
- from ..language.block import CodeBlock, TranslatedCodeBlock
10
- from ..utils.enums import EmbeddingType
11
- from .collections import Collections
12
- from .database import ChromaEmbeddingDatabase
9
+ from janus.embedding.collections import Collections
10
+ from janus.embedding.database import ChromaEmbeddingDatabase
11
+ from janus.language.block import CodeBlock, TranslatedCodeBlock
12
+ from janus.utils.enums import EmbeddingType
13
13
 
14
14
 
15
15
  class Vectorizer(object):
@@ -59,7 +59,7 @@ class Vectorizer(object):
59
59
  self,
60
60
  code_block: CodeBlock,
61
61
  collection_name: EmbeddingType | str,
62
- filename: str # perhaps this should be a relative path from the source, but for
62
+ filename: str, # perhaps this should be a relative path from the source, but for
63
63
  # now we're all in 1 directory
64
64
  ) -> None:
65
65
  """Calculate `code_block` embedding, returning success & storing in `embedding_id`
@@ -1,6 +1,6 @@
1
1
  import unittest
2
2
 
3
- from ..combine import CodeBlock, Combiner, TranslatedCodeBlock
3
+ from janus.language.combine import CodeBlock, Combiner, TranslatedCodeBlock
4
4
 
5
5
 
6
6
  class TestCombiner(unittest.TestCase):
@@ -1,6 +1,6 @@
1
1
  import unittest
2
2
 
3
- from ..splitter import Splitter
3
+ from janus.language.splitter import Splitter
4
4
 
5
5
 
6
6
  class TestSplitter(unittest.TestCase):
@@ -1,9 +1,9 @@
1
1
  import unittest
2
2
  from pathlib import Path
3
3
 
4
- from ....llm import load_model
5
- from ...combine import Combiner
6
- from ..alc import AlcSplitter
4
+ from janus.language.alc import AlcSplitter
5
+ from janus.language.combine import Combiner
6
+ from janus.llm import load_model
7
7
 
8
8
 
9
9
  class TestAlcSplitter(unittest.TestCase):
@@ -11,8 +11,8 @@ class TestAlcSplitter(unittest.TestCase):
11
11
 
12
12
  def setUp(self):
13
13
  """Set up the tests."""
14
- model_name = "gpt-3.5-turbo-0125"
15
- llm, _, _ = load_model(model_name)
14
+ model_name = "gpt-4o"
15
+ llm, _, _, _ = load_model(model_name)
16
16
  self.splitter = AlcSplitter(model=llm)
17
17
  self.combiner = Combiner(language="ibmhlasm")
18
18
  self.test_file = Path("janus/language/alc/_tests/alc.asm")
@@ -20,7 +20,7 @@ class TestAlcSplitter(unittest.TestCase):
20
20
  def test_split(self):
21
21
  """Test the split method."""
22
22
  tree_root = self.splitter.split(self.test_file)
23
- self.assertEqual(tree_root.n_descendents, 34)
23
+ self.assertAlmostEqual(tree_root.n_descendents, 32, delta=5)
24
24
  self.assertLessEqual(tree_root.max_tokens, self.splitter.max_tokens)
25
25
  self.assertFalse(tree_root.complete)
26
26
  self.combiner.combine_children(tree_root)