janus-llm 3.1.1__py3-none-any.whl → 3.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. janus/__init__.py +3 -3
  2. janus/_tests/test_cli.py +3 -3
  3. janus/cli.py +65 -8
  4. janus/converter/__init__.py +6 -6
  5. janus/converter/_tests/test_translate.py +10 -238
  6. janus/converter/converter.py +6 -3
  7. janus/converter/translate.py +1 -1
  8. janus/embedding/_tests/test_collections.py +2 -2
  9. janus/embedding/_tests/test_database.py +1 -1
  10. janus/embedding/_tests/test_vectorize.py +3 -3
  11. janus/embedding/collections.py +2 -2
  12. janus/embedding/database.py +1 -1
  13. janus/embedding/embedding_models_info.py +1 -1
  14. janus/embedding/vectorize.py +5 -5
  15. janus/language/_tests/test_combine.py +1 -1
  16. janus/language/_tests/test_splitter.py +1 -1
  17. janus/language/alc/_tests/test_alc.py +6 -6
  18. janus/language/alc/alc.py +5 -5
  19. janus/language/binary/_tests/test_binary.py +4 -4
  20. janus/language/binary/binary.py +5 -5
  21. janus/language/block.py +2 -2
  22. janus/language/combine.py +3 -3
  23. janus/language/file.py +2 -2
  24. janus/language/mumps/_tests/test_mumps.py +5 -5
  25. janus/language/mumps/mumps.py +5 -5
  26. janus/language/mumps/patterns.py +1 -1
  27. janus/language/naive/__init__.py +4 -4
  28. janus/language/naive/basic_splitter.py +4 -4
  29. janus/language/naive/chunk_splitter.py +4 -4
  30. janus/language/naive/registry.py +1 -1
  31. janus/language/naive/simple_ast.py +5 -5
  32. janus/language/naive/tag_splitter.py +4 -4
  33. janus/language/node.py +1 -1
  34. janus/language/splitter.py +4 -4
  35. janus/language/treesitter/_tests/test_treesitter.py +5 -5
  36. janus/language/treesitter/treesitter.py +4 -4
  37. janus/llm/__init__.py +1 -1
  38. janus/llm/model_callbacks.py +1 -1
  39. janus/llm/models_info.py +45 -23
  40. janus/metrics/_tests/test_bleu.py +1 -1
  41. janus/metrics/_tests/test_chrf.py +1 -1
  42. janus/metrics/_tests/test_file_pairing.py +1 -1
  43. janus/metrics/_tests/test_llm.py +5 -5
  44. janus/metrics/_tests/test_reading.py +1 -1
  45. janus/metrics/_tests/test_rouge_score.py +1 -1
  46. janus/metrics/_tests/test_similarity_score.py +1 -1
  47. janus/metrics/_tests/test_treesitter_metrics.py +2 -2
  48. janus/metrics/bleu.py +1 -1
  49. janus/metrics/chrf.py +1 -1
  50. janus/metrics/complexity_metrics.py +4 -4
  51. janus/metrics/file_pairing.py +5 -5
  52. janus/metrics/llm_metrics.py +1 -1
  53. janus/metrics/metric.py +11 -11
  54. janus/metrics/reading.py +1 -1
  55. janus/metrics/rouge_score.py +1 -1
  56. janus/metrics/similarity.py +2 -2
  57. janus/parsers/_tests/test_code_parser.py +1 -1
  58. janus/parsers/code_parser.py +2 -2
  59. janus/parsers/doc_parser.py +3 -3
  60. janus/parsers/eval_parser.py +2 -2
  61. janus/parsers/reqs_parser.py +3 -3
  62. janus/parsers/uml.py +1 -2
  63. janus/prompts/prompt.py +2 -2
  64. janus/utils/_tests/test_logger.py +1 -1
  65. janus/utils/_tests/test_progress.py +1 -1
  66. janus/utils/progress.py +1 -1
  67. {janus_llm-3.1.1.dist-info → janus_llm-3.2.1.dist-info}/METADATA +1 -1
  68. janus_llm-3.2.1.dist-info/RECORD +105 -0
  69. janus_llm-3.1.1.dist-info/RECORD +0 -105
  70. {janus_llm-3.1.1.dist-info → janus_llm-3.2.1.dist-info}/LICENSE +0 -0
  71. {janus_llm-3.1.1.dist-info → janus_llm-3.2.1.dist-info}/WHEEL +0 -0
  72. {janus_llm-3.1.1.dist-info → janus_llm-3.2.1.dist-info}/entry_points.txt +0 -0
janus/__init__.py CHANGED
@@ -2,10 +2,10 @@ import warnings
  
  from langchain_core._api.deprecation import LangChainDeprecationWarning
  
- from .converter.translate import Translator
- from .metrics import *  # noqa: F403
+ from janus.converter.translate import Translator
+ from janus.metrics import *  # noqa: F403
  
- __version__ = "3.1.1"
+ __version__ = "3.2.1"
  
  # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
  warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
janus/_tests/test_cli.py CHANGED
@@ -4,9 +4,9 @@ from unittest.mock import ANY, patch
  
  from typer.testing import CliRunner
  
- from ..cli import app, translate
- from ..embedding.embedding_models_info import EMBEDDING_MODEL_CONFIG_DIR
- from ..llm.models_info import MODEL_CONFIG_DIR
+ from janus.cli import app, translate
+ from janus.embedding.embedding_models_info import EMBEDDING_MODEL_CONFIG_DIR
+ from janus.llm.models_info import MODEL_CONFIG_DIR
  
  
  class TestCli(unittest.TestCase):
janus/cli.py CHANGED
@@ -32,8 +32,12 @@ from janus.language.treesitter import TreeSitterSplitter
  from janus.llm.model_callbacks import COST_PER_1K_TOKENS
  from janus.llm.models_info import (
      MODEL_CONFIG_DIR,
+     MODEL_ID_TO_LONG_ID,
      MODEL_TYPE_CONSTRUCTORS,
+     MODEL_TYPES,
      TOKEN_LIMITS,
+     bedrock_models,
+     openai_models,
  )
  from janus.metrics.cli import evaluate
  from janus.utils.enums import LANGUAGES
@@ -104,7 +108,7 @@ embedding = typer.Typer(
  
  def version_callback(value: bool) -> None:
      if value:
-         from . import __version__ as version
+         from janus import __version__ as version
  
          print(f"Janus CLI [blue]v{version}[/blue]")
          raise typer.Exit()
@@ -179,7 +183,7 @@ def translate(
          "-L",
          help="The custom name of the model set with 'janus llm add'.",
      ),
-     ] = "gpt-3.5-turbo-0125",
+     ] = "gpt-4o",
      max_prompts: Annotated[
          int,
          typer.Option(
@@ -301,7 +305,7 @@ def document(
          "-L",
          help="The custom name of the model set with 'janus llm add'.",
      ),
-     ] = "gpt-3.5-turbo-0125",
+     ] = "gpt-4o",
      max_prompts: Annotated[
          int,
          typer.Option(
@@ -437,7 +441,7 @@ def diagram(
          "-L",
          help="The custom name of the model set with 'janus llm add'.",
      ),
-     ] = "gpt-3.5-turbo-0125",
+     ] = "gpt-4o",
      max_prompts: Annotated[
          int,
          typer.Option(
@@ -800,16 +804,44 @@ def llm_add(
              "model_cost": {"input": in_cost, "output": out_cost},
          }
      elif model_type == "OpenAI":
-         model_name = typer.prompt("Enter the model name", default="gpt-3.5-turbo-0125")
+         model_id = typer.prompt(
+             "Enter the model ID (list model IDs with `janus llm ls -a`)",
+             default="gpt-4o",
+             type=click.Choice(openai_models),
+             show_choices=False,
+         )
          params = dict(
-             model_name=model_name,
+             # OpenAI uses the "model_name" key for what we're calling "long_model_id"
+             model_name=MODEL_ID_TO_LONG_ID[model_id],
              temperature=0.7,
              n=1,
          )
-         max_tokens = TOKEN_LIMITS[model_name]
-         model_cost = COST_PER_1K_TOKENS[model_name]
+         max_tokens = TOKEN_LIMITS[MODEL_ID_TO_LONG_ID[model_id]]
+         model_cost = COST_PER_1K_TOKENS[MODEL_ID_TO_LONG_ID[model_id]]
+         cfg = {
+             "model_type": model_type,
+             "model_id": model_id,
+             "model_args": params,
+             "token_limit": max_tokens,
+             "model_cost": model_cost,
+         }
+     elif model_type == "BedrockChat" or model_type == "Bedrock":
+         model_id = typer.prompt(
+             "Enter the model ID (list model IDs with `janus llm ls -a`)",
+             default="bedrock-claude-sonnet",
+             type=click.Choice(bedrock_models),
+             show_choices=False,
+         )
+         params = dict(
+             # Bedrock uses the "model_id" key for what we're calling "long_model_id"
+             model_id=MODEL_ID_TO_LONG_ID[model_id],
+             model_kwargs={"temperature": 0.7},
+         )
+         max_tokens = TOKEN_LIMITS[MODEL_ID_TO_LONG_ID[model_id]]
+         model_cost = COST_PER_1K_TOKENS[MODEL_ID_TO_LONG_ID[model_id]]
          cfg = {
              "model_type": model_type,
+             "model_id": model_id,
              "model_args": params,
              "token_limit": max_tokens,
              "model_cost": model_cost,
@@ -821,6 +853,31 @@ def llm_add(
      print(f"Model config written to {model_cfg}")
  
  
+ @llm.command("ls", help="List all of the user-configured models")
+ def llm_ls(
+     all: Annotated[
+         bool,
+         typer.Option(
+             "--all",
+             "-a",
+             is_flag=True,
+             help="List all models, including the default model IDs.",
+             click_type=click.Choice(sorted(list(MODEL_TYPE_CONSTRUCTORS.keys()))),
+         ),
+     ] = False,
+ ):
+     print("\n[green]User-configured models[/green]:")
+     for model_cfg in MODEL_CONFIG_DIR.glob("*.json"):
+         with open(model_cfg, "r") as f:
+             cfg = json.load(f)
+         print(f"\t[blue]{model_cfg.stem}[/blue]: [purple]{cfg['model_type']}[/purple]")
+ 
+     if all:
+         print("\n[green]Available model IDs[/green]:")
+         for model_id, model_type in MODEL_TYPES.items():
+             print(f"\t[blue]{model_id}[/blue]: [purple]{model_type}[/purple]")
+ 
+ 
  @embedding.command("add", help="Add an embedding model config to janus")
  def embedding_add(
      model_name: Annotated[
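
Taken together, the `llm add` changes above mean a user-chosen short model ID is now stored alongside the provider-specific long ID. A minimal sketch of the kind of JSON config this flow writes (the long ID, token limit, cost figures, and output filename below are hypothetical placeholders, not values from this diff):

    import json
    from pathlib import Path

    # Hypothetical stand-ins for the MODEL_ID_TO_LONG_ID, TOKEN_LIMITS, and
    # COST_PER_1K_TOKENS lookups that live in janus.llm.
    cfg = {
        "model_type": "OpenAI",
        "model_id": "gpt-4o",  # short ID chosen at the prompt
        "model_args": {
            "model_name": "gpt-4o-long-id-placeholder",  # long ID, under OpenAI's "model_name" key
            "temperature": 0.7,
            "n": 1,
        },
        "token_limit": 128_000,  # placeholder
        "model_cost": {"input": 0.005, "output": 0.015},  # placeholder, per 1K tokens
    }
    # Write the config where `janus llm ls` can find it (hypothetical path)
    Path("my-gpt-4o.json").write_text(json.dumps(cfg, indent=2))
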
janus/converter/__init__.py CHANGED
@@ -1,6 +1,6 @@
- from .converter import Converter
- from .diagram import DiagramGenerator
- from .document import Documenter, MadLibsDocumenter, MultiDocumenter
- from .evaluate import Evaluator
- from .requirements import RequirementsDocumenter
- from .translate import Translator
+ from janus.converter.converter import Converter
+ from janus.converter.diagram import DiagramGenerator
+ from janus.converter.document import Documenter, MadLibsDocumenter, MultiDocumenter
+ from janus.converter.evaluate import Evaluator
+ from janus.converter.requirements import RequirementsDocumenter
+ from janus.converter.translate import Translator
janus/converter/_tests/test_translate.py CHANGED
@@ -7,37 +7,11 @@ from langchain.schema import Document
  from langchain.schema.embeddings import Embeddings
  from langchain.schema.vectorstore import VST, VectorStore
  
+ from janus.converter.diagram import DiagramGenerator
+ from janus.converter.requirements import RequirementsDocumenter
+ from janus.converter.translate import Translator
  from janus.language.block import CodeBlock, TranslatedCodeBlock
  
- from ..diagram import DiagramGenerator
- from ..requirements import RequirementsDocumenter
- from ..translate import Translator
- 
- # from langchain.vectorstores import Chroma
- 
- 
- # from ..utils.enums import EmbeddingType
- 
- 
- def print_query_results(query, n_results):
-     # print(f"\n{query}")
-     # count = 1
-     # for t in n_results:
-     #     short_code = (
-     #         (t[0].page_content[0:50] + "..")
-     #         if (len(t[0].page_content) > 50)
-     #         else t[0].page_content
-     #     )
-     #     return_index = short_code.find("\n")
-     #     if -1 != return_index:
-     #         short_code = short_code[0:return_index] + ".."
-     #     print(
-     #         f"{count}. @ {t[0].metadata['start_line']}-{t[0].metadata['end_line']}"
-     #         f" -- {t[1]} -- {short_code}"
-     #     )
-     #     count += 1
-     pass
- 
  
  class MockCollection(VectorStore):
      """Vector store for testing"""
@@ -65,21 +39,13 @@ class MockCollection(VectorStore):
          raise NotImplementedError("from_texts() not implemented!")
  
  
- # class MockEmbeddingsFactory(EmbeddingsFactory):
- #     """Embeddings for testing - uses MockCollection"""
- #
- #     def get_embeddings(self) -> Embeddings:
- #         return MockCollection()
- #
- 
- 
  class TestTranslator(unittest.TestCase):
      """Tests for the Translator class."""
  
      def setUp(self):
          """Set up the tests."""
          self.translator = Translator(
-             model="gpt-3.5-turbo-0125",
+             model="gpt-4o",
              source_language="fortran",
              target_language="python",
              target_version="3.10",
@@ -88,7 +54,7 @@ class TestTranslator(unittest.TestCase):
          self.TEST_FILE_EMBEDDING_COUNT = 14
  
          self.req_translator = RequirementsDocumenter(
-             model="gpt-3.5-turbo-0125",
+             model="gpt-4o",
              source_language="fortran",
              prompt_template="requirements",
          )
@@ -105,200 +71,6 @@ class TestTranslator(unittest.TestCase):
          # unit tests anyway
          self.assertTrue(python_file.exists())
  
-     # def test_embeddings(self):
-     #     """Testing access to embeddings"""
-     #     vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
-     #     self.assertIsInstance(vector_store, Chroma, "Unexpected vector store type!")
-     #     self.assertEqual(
-     #         0, vector_store._collection.count(), "Non-empty initial vector store?"
-     #     )
-     #
-     #     self.translator.set_model("llama")
-     #     self.translator._load_parameters()
-     #     vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
-     #     self.assertIsInstance(vector_store, Chroma)
-     #     self.assertEqual(
-     #         0, vector_store._collection.count(), "Non-empty initial vector store?"
-     #     )
- 
-     # def test_embed_split_source(self):
-     #     """Characterize _embed method"""
-     #     mock_embeddings = MockEmbeddingsFactory()
-     #     self.translator.set_embeddings(mock_embeddings)
-     #     self.translator._load_parameters()
-     #     input_block = self.translator.splitter.split(self.test_file)
-     #     self.assertIsNone(
-     #         input_block.text, "Root node of input text shouldn't contain text"
-     #     )
-     #     self.assertIsNone(input_block.embedding_id, "Precondition failed")
-     #
-     #     result = self.translator._embed(
-     #         input_block, EmbeddingType.SOURCE, self.test_file.name
-     #     )
-     #
-     #     self.assertFalse(result, "Nothing to embed, so should have no result")
-     #     self.assertIsNone(
-     #         input_block.embedding_id, "Embeddings should not have changed")
- 
-     # def test_embed_has_values_for_each_non_empty_node(self):
-     #     """Characterize our sample fortran file"""
-     #     mock_embeddings = MockEmbeddingsFactory()
-     #     self.translator.set_embeddings(mock_embeddings)
-     #     self.translator._load_parameters()
-     #     input_block = self.translator.splitter.split(self.test_file)
-     #     self.translator._embed_nodes_recursively(
-     #         input_block, EmbeddingType.SOURCE, self.test_file.name
-     #     )
-     #     has_text_count = 0
-     #     has_embeddings_count = 0
-     #     nodes = [input_block]
-     #     while nodes:
-     #         node = nodes.pop(0)
-     #         if node.text:
-     #             has_text_count += 1
-     #         if node.embedding_id:
-     #             has_embeddings_count += 1
-     #         nodes.extend(node.children)
-     #     self.assertEqual(
-     #         self.TEST_FILE_EMBEDDING_COUNT,
-     #         has_text_count,
-     #         "Parsing of test_file has changed!",
-     #     )
-     #     self.assertEqual(
-     #         self.TEST_FILE_EMBEDDING_COUNT,
-     #         has_embeddings_count,
-     #         "Not all non-empty nodes have embeddings!",
-     #     )
- 
-     # def test_embed_nodes_recursively(self):
-     #     mock_embeddings = MockEmbeddingsFactory()
-     #     self.translator.set_embeddings(mock_embeddings)
-     #     self.translator._load_parameters()
-     #     input_block = self.translator.splitter.split(self.test_file)
-     #     self.translator._embed_nodes_recursively(
-     #         input_block, EmbeddingType.SOURCE, self.test_file.name
-     #     )
-     #     nodes = [input_block]
-     #     while nodes:
-     #         node = nodes.pop(0)
-     #         self.assertEqual(node.text is not None, node.embedding_id is not None)
-     #         nodes.extend(node.children)
- 
-     # @pytest.mark.slow
-     # def test_translate_file_adds_source_embeddings(self):
-     #     mock_embeddings = MockEmbeddingsFactory()
-     #     self.translator.set_embeddings(mock_embeddings)
-     #     self.translator._load_parameters()
-     #     vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
-     #     self.assertEqual(0, vector_store._add_texts_calls, "precondition")
-     #
-     #     self.translator.translate_file(self.test_file)
-     #
-     #     self.assertEqual(
-     #         self.TEST_FILE_EMBEDDING_COUNT,
-     #         vector_store._add_texts_calls,
-     #         "Did not find expected source embeddings",
-     #     )
- 
-     # @pytest.mark.slow
-     # def test_embeddings_usage(self):
-     #     """Noodling on use of embeddings
-     #     To see results have to uncomment print_query_results() above
-     #     """
-     #     input_block = self.translator.splitter.split(self.test_file)
-     #     self.translator._embed_nodes_recursively(
-     #         input_block, EmbeddingType.SOURCE, self.test_file.name
-     #     )
-     #     vector_store = self.translator.embeddings(EmbeddingType.SOURCE)
-     #
-     #     # this symbol has the lowest relevance scores of any in this test, but
-     #     # still not very low; multiple embedded nodes contain it
-     #     QUERY_STRING = "IWX_BAND_START"
-     #     query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
-     #     n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-     #         embedding=query,
-     #         k=10,
-     #         where_document={"$contains": QUERY_STRING},
-     #     )
-     #     self.assertTrue(len(n_results) > 1, "Why was valid symbol not found?")
-     #     print_query_results(QUERY_STRING, n_results)
- 
-     # in the XYZZY test, the least dissimilar results were the start and finish lines
-     # 0, and 415, which produced a similarity score of 0.47:
- 
-     # QUERY_STRING = "XYZZY"
-     # query = self.translator._embeddings.embed_query(QUERY_STRING)
-     # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-     #     embedding=query,
-     #     k=10,
-     #     # filter={"end_line": 15},
-     #     # filter={"$and": [{"end_line": 15}, {"tokens": {"$gte": 21}}]},
-     #     # where_document={"$contains": QUERY_STRING},
-     # )
-     # print_query_results(QUERY_STRING, n_results)
-     # # self.assertTrue(len(n_results) == 0, "Invalid symbol was found?")
- 
-     # # only returns a single result because only 1 embedded node contains
-     # # CSV_ICASEARR:
-     # QUERY_STRING = "What is the use of CSV_ICASEARR?"
-     # query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
-     # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-     #     embedding=query,
-     #     k=10,
-     #     # where_document={"$contains": QUERY_STRING},
-     #     where_document={"$contains": "CSV_ICASEARR"},
-     # )
-     # print_query_results(QUERY_STRING, n_results)
-     # self.assertTrue(len(n_results) == 1, "Was splitting changed?")
-     #
-     # # trimmed out some characters from line 43, and still not very similar scoring
-     # QUERY_STRING = "IYL_EDGEBUFFER EDGEBUFFER IGN_MASK CELLSIZE"
-     # query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
-     # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-     #     embedding=query,
-     #     k=10,
-     #     # where_document={"$contains": QUERY_STRING},
-     # )
-     # print_query_results(QUERY_STRING, n_results)
-     #
-     # # random string (as bad as XYZZY), but searching for a specific line
-     # QUERY_STRING = "ghost in the invisible moon"
-     # query = self.translator._embeddings._embeddings.embed_query(QUERY_STRING)
-     # n_results = vector_store.similarity_search_by_vector_with_relevance_scores(
-     #     embedding=query,
-     #     k=10,
-     #     filter={"$and": [{"end_line": 90}, {"tokens": {"$gte": 21}}]},
-     # )
-     # print_query_results(QUERY_STRING, n_results)
-     # self.assertTrue(len(n_results) == 1, "Was splitting changed?")
- 
-     # @pytest.mark.slow
-     # def test_document_embeddings_added_by_translate(self):
-     #     vector_store = self.req_translator.embeddings(EmbeddingType.REQUIREMENT)
-     #     self.assertEqual(0, vector_store._add_texts_calls, "Precondition failed")
-     #     self.req_translator.translate(self.test_file.parent, self.test_file.parent,
-     #                                   True)
-     #     self.assertTrue(vector_store._add_texts_calls > 0, "Why no documentation?")
- 
-     # @pytest.mark.slow
-     # def test_embed_requirements(self):
-     #     vector_store = self.req_translator.embeddings(EmbeddingType.REQUIREMENT)
-     #     translated = self.req_translator.translate_file(self.test_file)
-     #     self.assertEqual(
-     #         0,
-     #         vector_store._add_texts_calls,
-     #         "Unexpected requirements added in translate_file",
-     #     )
-     #     result = self.req_translator._embed(
-     #         translated, EmbeddingType.REQUIREMENT, self.test_file.name
-     #     )
-     #     self.assertFalse(result, "No text in root node, so should generate no docs")
-     #     self.assertIsNotNone(translated.children[0].text, "Data changed?")
-     #     result = self.req_translator._embed(
-     #         translated.children[0], EmbeddingType.REQUIREMENT, self.test_file.name
-     #     )
-     #     self.assertTrue(result, "No docs generated for first child node?")
- 
      def test_invalid_selections(self) -> None:
          """Tests that settings values for the translator will raise exceptions"""
          self.assertRaises(
@@ -317,14 +89,14 @@ class TestDiagramGenerator(unittest.TestCase):
      def setUp(self):
          """Set up the tests."""
          self.diagram_generator = DiagramGenerator(
-             model="gpt-3.5-turbo-0125",
+             model="gpt-4o",
              source_language="fortran",
              diagram_type="Activity",
          )
  
      def test_init(self):
          """Test __init__ method."""
-         self.assertEqual(self.diagram_generator._model_name, "gpt-3.5-turbo-0125")
+         self.assertEqual(self.diagram_generator._model_name, "gpt-4o")
          self.assertEqual(self.diagram_generator._source_language, "fortran")
          self.assertEqual(self.diagram_generator._diagram_type, "Activity")
  
@@ -370,8 +142,8 @@ def test_language_combinations(
      """Tests that translator target language settings are consistent
      with prompt template expectations.
      """
-     translator = Translator(model="gpt-3.5-turbo-0125")
-     translator.set_model("gpt-3.5-turbo-0125")
+     translator = Translator(model="gpt-4o")
+     translator.set_model("gpt-4o")
      translator.set_source_language(source_language)
      translator.set_target_language(expected_target_language, expected_target_version)
      translator.set_prompt(prompt_template)
@@ -379,5 +151,5 @@ def test_language_combinations(
      assert translator._target_language == expected_target_language  # nosec
      assert translator._target_version == expected_target_version  # nosec
      assert translator._splitter.language == source_language  # nosec
-     assert translator._splitter.model.model_name == "gpt-3.5-turbo-0125"  # nosec
+     assert translator._splitter.model.model_name == "gpt-4o"  # nosec
      assert translator._prompt_template_name == prompt_template  # nosec
janus/converter/converter.py CHANGED
@@ -64,7 +64,7 @@ class Converter:
  
      def __init__(
          self,
-         model: str = "gpt-3.5-turbo-0125",
+         model: str = "gpt-4o",
          model_arguments: dict[str, Any] = {},
          source_language: str = "fortran",
          max_prompts: int = 10,
@@ -92,6 +92,7 @@ class Converter:
          self.override_token_limit: bool = max_tokens is not None
  
          self._model_name: str
+         self._model_id: str
          self._custom_model_arguments: dict[str, Any]
  
          self._source_language: str
@@ -265,7 +266,9 @@ class Converter:
          # model_arguments.update(self._custom_model_arguments)
  
          # Load the model
-         self._llm, token_limit, self.model_cost = load_model(self._model_name)
+         self._llm, self._model_id, token_limit, self.model_cost = load_model(
+             self._model_name
+         )
          # Set the max_tokens to less than half the model's limit to allow for enough
          # tokens at output
          # Only modify max_tokens if it is not specified by user
@@ -283,7 +286,7 @@ class Converter:
          If the relevant fields have not been changed since the last time this
          method was called, nothing happens.
          """
-         prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
+         prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
              source_language=self._source_language,
              prompt_template=self._prompt_template_name,
          )
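
The hunk above changes `load_model` from returning three values to four, inserting the short model ID after the LLM object. A minimal sketch of the new unpacking, assuming only the return order shown in this diff:

    from janus.llm import load_model

    # 3.1.1 returned (llm, token_limit, model_cost); 3.2.1 adds the model ID.
    llm, model_id, token_limit, model_cost = load_model("gpt-4o")

    # Downstream lookups such as MODEL_PROMPT_ENGINES are now keyed by the
    # model ID rather than by the user-facing model name.

Callers that unpacked three values, like the ALC splitter test further down, must switch to four-value unpacking.
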
janus/converter/translate.py CHANGED
@@ -90,7 +90,7 @@ class Translator(Converter):
                  f"({self._source_language} != {self._target_language})"
              )
  
-         prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
+         prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
              source_language=self._source_language,
              target_language=self._target_language,
              target_version=self._target_version,
janus/embedding/_tests/test_collections.py CHANGED
@@ -4,8 +4,8 @@ from unittest.mock import MagicMock
  
  import pytest
  
- from ...utils.enums import EmbeddingType
- from ..collections import Collections
+ from janus.embedding.collections import Collections
+ from janus.utils.enums import EmbeddingType
  
  
  class TestCollections(unittest.TestCase):
janus/embedding/_tests/test_database.py CHANGED
@@ -2,7 +2,7 @@ import unittest
  from pathlib import Path
  from unittest.mock import patch
  
- from ..database import ChromaEmbeddingDatabase, uri_to_path
+ from janus.embedding.database import ChromaEmbeddingDatabase, uri_to_path
  
  
  class TestDatabase(unittest.TestCase):
janus/embedding/_tests/test_vectorize.py CHANGED
@@ -5,9 +5,9 @@ from unittest.mock import MagicMock
  
  from chromadb.api.client import Client
  
- from ...language.treesitter import TreeSitterSplitter
- from ...utils.enums import EmbeddingType
- from ..vectorize import Vectorizer, VectorizerFactory
+ from janus.embedding.vectorize import Vectorizer, VectorizerFactory
+ from janus.language.treesitter import TreeSitterSplitter
+ from janus.utils.enums import EmbeddingType
  
  
  class MockDBVectorizer(VectorizerFactory):
janus/embedding/collections.py CHANGED
@@ -5,8 +5,8 @@ from typing import Dict, Optional, Sequence
  from chromadb import Client, Collection
  from langchain_community.vectorstores import Chroma
  
- from ..utils.enums import EmbeddingType
- from .embedding_models_info import load_embedding_model
+ from janus.embedding.embedding_models_info import load_embedding_model
+ from janus.utils.enums import EmbeddingType
  
  # See https://docs.trychroma.com/telemetry#in-chromas-backend-using-environment-variables
  os.environ["ANONYMIZED_TELEMETRY"] = "False"
janus/embedding/database.py CHANGED
@@ -5,7 +5,7 @@ from urllib.request import url2pathname
  
  import chromadb
  
- from ..utils.logger import create_logger
+ from janus.utils.logger import create_logger
  
  log = create_logger(__name__)
  
janus/embedding/embedding_models_info.py CHANGED
@@ -8,7 +8,7 @@ from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEm
  from langchain_core.embeddings import Embeddings
  from langchain_openai import OpenAIEmbeddings
  
- from ..utils.logger import create_logger
+ from janus.utils.logger import create_logger
  
  load_dotenv()
  
janus/embedding/vectorize.py CHANGED
@@ -6,10 +6,10 @@ from typing import Any, Dict, Optional, Sequence
  from chromadb import Client, Collection
  from langchain_community.vectorstores import Chroma
  
- from ..language.block import CodeBlock, TranslatedCodeBlock
- from ..utils.enums import EmbeddingType
- from .collections import Collections
- from .database import ChromaEmbeddingDatabase
+ from janus.embedding.collections import Collections
+ from janus.embedding.database import ChromaEmbeddingDatabase
+ from janus.language.block import CodeBlock, TranslatedCodeBlock
+ from janus.utils.enums import EmbeddingType
  
  
  class Vectorizer(object):
@@ -59,7 +59,7 @@ class Vectorizer(object):
          self,
          code_block: CodeBlock,
          collection_name: EmbeddingType | str,
-         filename: str  # perhaps this should be a relative path from the source, but for
+         filename: str,  # perhaps this should be a relative path from the source, but for
          # now we're all in 1 directory
      ) -> None:
          """Calculate `code_block` embedding, returning success & storing in `embedding_id`
janus/language/_tests/test_combine.py CHANGED
@@ -1,6 +1,6 @@
  import unittest
  
- from ..combine import CodeBlock, Combiner, TranslatedCodeBlock
+ from janus.language.combine import CodeBlock, Combiner, TranslatedCodeBlock
  
  
  class TestCombiner(unittest.TestCase):
janus/language/_tests/test_splitter.py CHANGED
@@ -1,6 +1,6 @@
  import unittest
  
- from ..splitter import Splitter
+ from janus.language.splitter import Splitter
  
  
  class TestSplitter(unittest.TestCase):
janus/language/alc/_tests/test_alc.py CHANGED
@@ -1,9 +1,9 @@
  import unittest
  from pathlib import Path
  
- from ....llm import load_model
- from ...combine import Combiner
- from ..alc import AlcSplitter
+ from janus.language.alc import AlcSplitter
+ from janus.language.combine import Combiner
+ from janus.llm import load_model
  
  
  class TestAlcSplitter(unittest.TestCase):
@@ -11,8 +11,8 @@ class TestAlcSplitter(unittest.TestCase):
  
      def setUp(self):
          """Set up the tests."""
-         model_name = "gpt-3.5-turbo-0125"
-         llm, _, _ = load_model(model_name)
+         model_name = "gpt-4o"
+         llm, _, _, _ = load_model(model_name)
          self.splitter = AlcSplitter(model=llm)
          self.combiner = Combiner(language="ibmhlasm")
          self.test_file = Path("janus/language/alc/_tests/alc.asm")
@@ -20,7 +20,7 @@ class TestAlcSplitter(unittest.TestCase):
      def test_split(self):
          """Test the split method."""
          tree_root = self.splitter.split(self.test_file)
-         self.assertEqual(tree_root.n_descendents, 34)
+         self.assertAlmostEqual(tree_root.n_descendents, 32, delta=5)
          self.assertLessEqual(tree_root.max_tokens, self.splitter.max_tokens)
          self.assertFalse(tree_root.complete)
          self.combiner.combine_children(tree_root)
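
The loosened assertion in `test_split` relies on `unittest.TestCase.assertAlmostEqual` with a `delta` argument: the check passes whenever `abs(first - second) <= delta`, so any descendant count from 27 through 37 is now accepted. A standalone sketch of that semantics:

    import unittest

    class DeltaExample(unittest.TestCase):
        def test_delta(self):
            # Passes because abs(34 - 32) <= 5; values outside [27, 37] would fail.
            self.assertAlmostEqual(34, 32, delta=5)

    if __name__ == "__main__":
        unittest.main()
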