janus-llm 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +120 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +9 -6
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +134 -70
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/METADATA +23 -10
  70. janus_llm-2.0.0.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/entry_points.txt +0 -0
janus/translate.py CHANGED
@@ -1,34 +1,62 @@
1
- import uuid
1
+ import json
2
+ import math
3
+ import re
4
+ import time
5
+ from copy import deepcopy
2
6
  from pathlib import Path
3
- from typing import Any, Dict
7
+ from typing import Any
4
8
 
5
- from chromadb.api.models.Collection import Collection
6
- from langchain.callbacks import get_openai_callback
9
+ from langchain.output_parsers import RetryWithErrorOutputParser
10
+ from langchain.output_parsers.fix import OutputFixingParser
11
+ from langchain_core.exceptions import OutputParserException
12
+ from langchain_core.language_models import BaseLanguageModel
13
+ from langchain_core.output_parsers import BaseOutputParser
14
+ from langchain_core.prompts import ChatPromptTemplate
15
+ from langchain_core.runnables import RunnableLambda, RunnableParallel
16
+ from openai import BadRequestError, RateLimitError
17
+ from text_generation.errors import ValidationError
18
+
19
+ from janus.language.naive.registry import CUSTOM_SPLITTERS
7
20
 
8
21
  from .converter import Converter, run_if_changed
22
+ from .embedding.vectorize import ChromaDBVectorizer
9
23
  from .language.block import CodeBlock, TranslatedCodeBlock
24
+ from .language.combine import ChunkCombiner, Combiner, JsonCombiner
25
+ from .language.splitter import EmptyTreeError, FileSizeError, TokenLimitError
10
26
  from .llm import load_model
11
- from .parsers.code_parser import PARSER_TYPES, CodeParser, EvaluationParser, JanusParser
12
- from .prompts.prompt import SAME_OUTPUT, TEXT_OUTPUT, PromptEngine
27
+ from .llm.model_callbacks import get_model_callback
28
+ from .llm.models_info import MODEL_PROMPT_ENGINES
29
+ from .parsers.code_parser import CodeParser, GenericParser
30
+ from .parsers.doc_parser import MadlibsDocumentationParser, MultiDocumentationParser
31
+ from .parsers.eval_parser import EvaluationParser
32
+ from .parsers.reqs_parser import RequirementsParser
33
+ from .prompts.prompt import SAME_OUTPUT, TEXT_OUTPUT
13
34
  from .utils.enums import LANGUAGES
14
35
  from .utils.logger import create_logger
15
36
 
16
37
  log = create_logger(__name__)
17
38
 
18
39
 
40
+ PARSER_TYPES: set[str] = {"code", "text", "eval", "madlibs", "multidoc", "requirements"}
41
+
42
+
19
43
  class Translator(Converter):
20
44
  """A class that translates code from one programming language to another."""
21
45
 
22
46
  def __init__(
23
47
  self,
24
- model: str = "gpt-3.5-turbo",
25
- model_arguments: Dict[str, Any] = {},
48
+ model: str = "gpt-3.5-turbo-0125",
49
+ model_arguments: dict[str, Any] = {},
26
50
  source_language: str = "fortran",
27
51
  target_language: str = "python",
28
- target_version: str = "3.10",
52
+ target_version: str | None = "3.10",
29
53
  max_prompts: int = 10,
54
+ max_tokens: int | None = None,
30
55
  prompt_template: str | Path = "simple",
31
56
  parser_type: str = "code",
57
+ db_path: str | None = None,
58
+ db_config: dict[str, Any] | None = None,
59
+ custom_splitter: str | None = None,
32
60
  ) -> None:
33
61
  """Initialize a Translator instance.
34
62
 
@@ -41,37 +69,53 @@ class Translator(Converter):
41
69
  target_language: The target programming language.
42
70
  target_version: The target version of the target programming language.
43
71
  max_prompts: The maximum number of prompts to try before giving up.
72
+ max_tokens: The maximum number of tokens the model will take in.
73
+ If unspecificed, model's default max will be used.
44
74
  prompt_template: name of prompt template directory
45
75
  (see janus/prompts/templates) or path to a directory.
46
76
  parser_type: The type of parser to use for parsing the LLM output. Valid
47
77
  values are "code" (default), "text", and "eval".
48
78
  """
79
+ self._custom_splitter = custom_splitter
49
80
  super().__init__(source_language=source_language)
50
81
 
51
- self._parser_type: None | str
52
- self._parser: None | JanusParser
53
- self._model_name: None | str
54
- self._custom_model_arguments: None | Dict[str, Any]
55
- self._target_language: None | str
56
- self._target_version: None | str
57
- self._target_glob: None | str
58
- self._prompt_template_name: None | str
82
+ self._parser_type: str | None
83
+ self._model_name: str | None
84
+ self._custom_model_arguments: dict[str, Any] | None
85
+ self._target_language: str | None
86
+ self._target_version: str | None
87
+ self._target_glob: str | None
88
+ self._prompt_template_name: str | None
89
+ self._db_path: str | None
90
+ self._db_config: dict[str, Any] | None
91
+
92
+ self._llm: BaseLanguageModel | None
93
+ self._parser: BaseOutputParser | None
94
+ self._combiner: Combiner | None
95
+ self._prompt: ChatPromptTemplate | None
96
+
97
+ self.max_prompts = max_prompts
98
+ self.override_token_limit = False if max_tokens is None else True
99
+ self._max_tokens = max_tokens
59
100
 
60
101
  self.set_model(model_name=model, **model_arguments)
61
102
  self.set_parser_type(parser_type=parser_type)
62
103
  self.set_prompt(prompt_template=prompt_template)
63
104
  self.set_target_language(
64
- target_language=target_language, target_version=target_version
105
+ target_language=target_language,
106
+ target_version=target_version,
65
107
  )
108
+ self.set_db_path(db_path=db_path)
109
+ self.set_db_config(db_config=db_config)
66
110
 
67
111
  self._load_parameters()
68
112
 
69
- self.max_prompts = max_prompts
70
-
71
113
  def _load_parameters(self) -> None:
72
114
  self._load_model()
73
- self._load_prompt_engine()
115
+ self._load_prompt()
74
116
  self._load_parser()
117
+ self._load_combiner()
118
+ self._load_vectorizer()
75
119
  super()._load_parameters() # will call self._changed_attrs.clear()
76
120
 
77
121
  def translate(
@@ -79,7 +123,7 @@ class Translator(Converter):
79
123
  input_directory: str | Path,
80
124
  output_directory: str | Path | None = None,
81
125
  overwrite: bool = False,
82
- output_collection: Collection | None = None,
126
+ collection_name: str | None = None,
83
127
  ) -> None:
84
128
  """Translate code in the input directory from the source language to the target
85
129
  language, and write the resulting files to the output directory.
@@ -88,6 +132,7 @@ class Translator(Converter):
88
132
  input_directory: The directory containing the code to translate.
89
133
  output_directory: The directory to write the translated code to.
90
134
  overwrite: Whether to overwrite existing files (vs skip them)
135
+ collection_name: Collection to add to
91
136
  """
92
137
  # Convert paths to pathlib Paths if needed
93
138
  if isinstance(input_directory, str):
@@ -99,25 +144,77 @@ class Translator(Converter):
99
144
  if output_directory is not None and not output_directory.exists():
100
145
  output_directory.mkdir(parents=True)
101
146
 
147
+ source_suffix = LANGUAGES[self._source_language]["suffix"]
102
148
  target_suffix = LANGUAGES[self._target_language]["suffix"]
103
149
 
104
- input_paths = input_directory.rglob(self._source_glob)
150
+ input_paths = [p for p in input_directory.rglob(self._source_glob)]
151
+
152
+ log.info(f"Input directory: {input_directory.absolute()}")
153
+ log.info(
154
+ f"{self._source_language.capitalize()} '*.{source_suffix}' files: "
155
+ f"{len(input_paths)}"
156
+ )
157
+ log.info(
158
+ "Other files (skipped): "
159
+ f"{len(list(input_directory.iterdir())) - len(input_paths)}\n"
160
+ )
161
+ if output_directory is not None:
162
+ output_paths = [
163
+ output_directory
164
+ / p.relative_to(input_directory).with_suffix(f".{target_suffix}")
165
+ for p in input_paths
166
+ ]
167
+ in_out_pairs = list(zip(input_paths, output_paths))
168
+ if not overwrite:
169
+ n_files = len(in_out_pairs)
170
+ in_out_pairs = [
171
+ (inp, outp) for inp, outp in in_out_pairs if not outp.exists()
172
+ ]
173
+ log.info(
174
+ f"Skipping {n_files - len(in_out_pairs)} existing "
175
+ f"'*.{source_suffix}' files"
176
+ )
177
+ else:
178
+ in_out_pairs = [(f, None) for f in input_paths]
179
+ log.info(f"Translating {len(in_out_pairs)} '*.{source_suffix}' files")
105
180
 
106
181
  # Now, loop through every code block in every file and translate it with an LLM
107
182
  total_cost = 0.0
108
- for in_path in input_paths:
109
- relative = in_path.relative_to(input_directory)
110
- # output_name = relative.with_suffix(f".{target_suffix}").name
111
- if output_directory is not None:
112
- out_path = output_directory / relative.with_suffix(f".{target_suffix}")
113
- else:
114
- out_path = None
115
- # Track the cost of translating the file
116
- # TODO: If non-OpenAI models with prices are added, this will need
117
- # to be updated.
118
- with get_openai_callback() as cb:
183
+ for in_path, out_path in in_out_pairs:
184
+ # Translate the file, skip it if there's a rate limit error
185
+ try:
119
186
  out_block = self.translate_file(in_path)
120
- total_cost += cb.total_cost
187
+ total_cost += out_block.total_cost
188
+ except RateLimitError:
189
+ continue
190
+ except OutputParserException as e:
191
+ log.error(f"Skipping {in_path.name}, failed to parse output: {e}.")
192
+ continue
193
+ except BadRequestError as e:
194
+ if str(e).startswith("Detected an error in the prompt"):
195
+ log.warning("Malformed input, skipping")
196
+ continue
197
+ raise e
198
+ except ValidationError as e:
199
+ # Only allow ValidationError to pass if token limit is manually set
200
+ if self.override_token_limit:
201
+ log.warning(
202
+ "Current file and manually set token "
203
+ "limit is too large for this model, skipping"
204
+ )
205
+ continue
206
+ raise e
207
+ except TokenLimitError:
208
+ log.warning("Ran into irreducible node too large for context, skipping")
209
+ continue
210
+ except EmptyTreeError:
211
+ log.warning(
212
+ f'Input file "{in_path.name}" has no nodes of interest, skipping'
213
+ )
214
+ continue
215
+ except FileSizeError:
216
+ log.warning("Current tile is too large for basic splitter, skipping")
217
+ continue
121
218
 
122
219
  # Don't attempt to write files for which translation failed
123
220
  if not out_block.translated:
@@ -139,19 +236,36 @@ class Translator(Converter):
139
236
  #
140
237
  # self._embed_nodes_recursively(out_block, embedding_type, filename)
141
238
 
239
+ if collection_name is not None:
240
+ self._vectorizer.add_nodes_recursively(
241
+ out_block,
242
+ collection_name,
243
+ in_path.name,
244
+ )
245
+ # out_text = self.parser.parse_combined_output(out_block.complete_text)
246
+ # # Using same id naming convention from vectorize.py
247
+ # ids = [str(uuid.uuid3(uuid.NAMESPACE_DNS, out_text))]
248
+ # output_collection.upsert(ids=ids, documents=[out_text])
249
+
142
250
  # Make sure the tree's code has been consolidated at the top level
143
251
  # before writing to file
144
252
  self._combiner.combine(out_block)
145
253
  if out_path is not None and (overwrite or not out_path.exists()):
146
254
  self._save_to_file(out_block, out_path)
147
- if output_collection is not None:
148
- out_text = self.parser.parse_combined_output(out_block.complete_text)
149
- # Using same id naming convention from vectorize.py
150
- ids = [str(uuid.uuid3(uuid.NAMESPACE_DNS, out_text))]
151
- output_collection.upsert(ids=ids, documents=[out_text])
152
255
 
153
256
  log.info(f"Total cost: ${total_cost:,.2f}")
154
257
 
258
+ def _split_file(self, file: Path) -> CodeBlock:
259
+ filename = file.name
260
+ log.info(f"[{filename}] Splitting file")
261
+ root = self._splitter.split(file)
262
+ log.info(
263
+ f"[{filename}] File split into {root.n_descendents:,} blocks, "
264
+ f"tree of height {root.height}"
265
+ )
266
+ log.info(f"[{filename}] Input CodeBlock Structure:\n{root.tree_str()}")
267
+ return root
268
+
155
269
  def translate_file(self, file: Path) -> TranslatedCodeBlock:
156
270
  """Translate a single file.
157
271
 
@@ -164,18 +278,12 @@ class Translator(Converter):
164
278
  `Combiner.combine_children` on the block.
165
279
  """
166
280
  self._load_parameters()
167
-
168
281
  filename = file.name
169
- log.info(f"[{filename}] Splitting file")
170
- input_block = self._splitter.split(file)
171
- log.info(
172
- f"[{filename}] File split into {input_block.n_descendents:,} blocks, "
173
- f"tree of height {input_block.height}"
174
- )
175
- log.info(f"[{filename}] Input CodeBlock Structure:\n{input_block.tree_str()}")
176
- # (temporarily?) comment-out adding embeddings; will be moved
177
- # self._embed_nodes_recursively(input_block, EmbeddingType.SOURCE, filename)
282
+
283
+ input_block = self._split_file(file)
284
+ t0 = time.time()
178
285
  output_block = self._iterative_translate(input_block)
286
+ output_block.processing_time = time.time() - t0
179
287
  if output_block.translated:
180
288
  completeness = output_block.translation_completeness
181
289
  log.info(
@@ -183,7 +291,9 @@ class Translator(Converter):
183
291
  f" {completeness:.2%} of input successfully translated\n"
184
292
  f" Total cost: ${output_block.total_cost:,.2f}\n"
185
293
  f" Total retries: {output_block.total_retries:,d}\n"
294
+ f" Output CodeBlock Structure:\n{input_block.tree_str()}\n"
186
295
  )
296
+
187
297
  else:
188
298
  log.error(
189
299
  f"[{filename}] Translation failed\n"
@@ -223,25 +333,14 @@ class Translator(Converter):
223
333
  while stack:
224
334
  translated_block = stack.pop()
225
335
 
226
- # Track the cost of translating this block
227
- # TODO: If non-OpenAI models with prices are added, this will need
228
- # to be updated.
229
- with get_openai_callback() as cb:
230
- self._add_translation(translated_block)
231
- translated_block.cost = cb.total_cost
232
- translated_block.retries = max(0, cb.successful_requests - 1)
336
+ self._add_translation(translated_block)
233
337
 
234
338
  # If translating this block was unsuccessful, don't bother with its
235
339
  # children (they wouldn't show up in the final text anyway)
236
340
  if not translated_block.translated:
237
341
  continue
238
342
 
239
- for child in translated_block.children:
240
- # Don't bother translating children if they aren't used
241
- if self._combiner.contains_child(translated_block.text, child):
242
- stack.append(child)
243
- else:
244
- log.warning(f"Skipping {child.id} (not referenced in parent code)")
343
+ stack.extend(translated_block.children)
245
344
 
246
345
  progress = translated_root.translation_completeness
247
346
  if progress - last_prog > prog_delta:
@@ -266,11 +365,6 @@ class Translator(Converter):
266
365
  block.translated = True
267
366
  return
268
367
 
269
- log.debug(f"[{block.name}] Translating...")
270
- log.debug(f"[{block.name}] Input text:\n{block.original.text}")
271
- prompt = self._prompt_engine.create(block.original)
272
- top_score = -1.0
273
-
274
368
  if self._llm is None:
275
369
  message = (
276
370
  "Model not configured correctly, cannot translate. Try setting "
@@ -279,51 +373,85 @@ class Translator(Converter):
279
373
  log.error(message)
280
374
  raise ValueError(message)
281
375
 
282
- # Retry the request up to max_prompts times before failing
283
- for _ in range(self.max_prompts + 1):
284
- output = self._llm.predict_messages(prompt)
285
- try:
286
- parsed_output = self.parser.parse(output.content)
287
- except ValueError as e:
288
- log.warning(f"[{block.name}] Failed to parse output: {e}")
289
- log.debug(f"[{block.name}] Failed output:\n{output.content}")
290
- continue
291
-
292
- score = self.parser.score(block.original, parsed_output)
293
- if score > top_score:
294
- block.text = parsed_output
295
- top_score = score
296
-
297
- if score >= 1.0:
298
- break
299
-
300
- else:
301
- if block.text is None:
302
- error_msg = (
303
- f"[{block.name}] Failed to parse output after "
304
- f"{self.max_prompts} retries. Marking as untranslated."
305
- )
306
- log.warning(error_msg)
307
- return
376
+ log.debug(f"[{block.name}] Translating...")
377
+ log.debug(f"[{block.name}] Input text:\n{block.original.text}")
308
378
 
309
- log.warning(f"[{block.name}] Output not complete")
379
+ # Track the cost of translating this block
380
+ # TODO: If non-OpenAI models with prices are added, this will need
381
+ # to be updated.
382
+ with get_model_callback() as cb:
383
+ t0 = time.time()
384
+ block.text = self._run_chain(block)
385
+ block.processing_time = time.time() - t0
386
+ block.cost = cb.total_cost
387
+ block.retries = max(0, cb.successful_requests - 1)
310
388
 
311
389
  block.tokens = self._llm.get_num_tokens(block.text)
312
390
  block.translated = True
313
391
 
314
392
  log.debug(f"[{block.name}] Output code:\n{block.text}")
315
393
 
394
+ def _run_chain(self, block: TranslatedCodeBlock) -> str:
395
+ """Run the model with three nested error fixing schemes.
396
+ First, try to fix simple formatting errors by giving the model just
397
+ the output and the parsing error. After a number of attempts, try
398
+ giving the model the output, the parsing error, and the original
399
+ input. Again check/retry this output to solve for formatting errors.
400
+ If we still haven't succeeded after several attempts, the model may
401
+ be getting thrown off by a bad initial output; start from scratch
402
+ and try again.
403
+
404
+ The number of tries for each layer of this scheme is roughly equal
405
+ to the cube root of self.max_retries, so the total calls to the
406
+ LLM will be roughly as expected (up to sqrt(self.max_retries) over)
407
+ """
408
+ self._parser.set_reference(block.original)
409
+
410
+ # Retries with just the output and the error
411
+ n1 = round(self.max_prompts ** (1 / 3))
412
+
413
+ # Retries with the input, output, and error
414
+ n2 = round((self.max_prompts // n1) ** (1 / 2))
415
+
416
+ # Retries with just the input
417
+ n3 = math.ceil(self.max_prompts / (n1 * n2))
418
+
419
+ fix_format = OutputFixingParser.from_llm(
420
+ llm=self._llm,
421
+ parser=self._parser,
422
+ max_retries=n1,
423
+ )
424
+ retry = RetryWithErrorOutputParser.from_llm(
425
+ llm=self._llm,
426
+ parser=fix_format,
427
+ max_retries=n2,
428
+ )
429
+
430
+ completion_chain = self._prompt | self._llm
431
+ chain = RunnableParallel(
432
+ completion=completion_chain, prompt_value=self._prompt
433
+ ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
434
+
435
+ for _ in range(n3):
436
+ try:
437
+ return chain.invoke({"SOURCE_CODE": block.original.text})
438
+ except OutputParserException:
439
+ pass
440
+
441
+ raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
442
+
316
443
  def _save_to_file(self, block: CodeBlock, out_path: Path) -> None:
317
444
  """Save a file to disk.
318
445
 
319
446
  Arguments:
320
447
  block: The `CodeBlock` to save to a file.
321
448
  """
322
- out_text = self.parser.parse_combined_output(block.complete_text)
449
+ # TODO: can't use output fixer and this system for combining output
450
+ out_text = self._parser.parse_combined_output(block.complete_text)
323
451
  out_path.parent.mkdir(parents=True, exist_ok=True)
324
452
  out_path.write_text(out_text, encoding="utf-8")
325
453
 
326
- def set_model(self, model_name: str, **custom_arguments: Dict[str, Any]):
454
+ def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
327
455
  """Validate and set the model name.
328
456
 
329
457
  The affected objects will not be updated until translate() is called.
@@ -363,7 +491,9 @@ class Translator(Converter):
363
491
  """
364
492
  self._prompt_template_name = prompt_template
365
493
 
366
- def set_target_language(self, target_language: str, target_version: str) -> None:
494
+ def set_target_language(
495
+ self, target_language: str, target_version: str | None
496
+ ) -> None:
367
497
  """Validate and set the target language.
368
498
 
369
499
  The affected objects will not be updated until translate() is called.
@@ -382,8 +512,14 @@ class Translator(Converter):
382
512
  self._target_language = target_language
383
513
  self._target_version = target_version
384
514
 
515
+ def set_db_path(self, db_path: str) -> None:
516
+ self._db_path = db_path
517
+
518
+ def set_db_config(self, db_config: dict[str, Any] | None) -> None:
519
+ self._db_config = db_config
520
+
385
521
  @run_if_changed("_model_name", "_custom_model_arguments")
386
- def _load_model(self):
522
+ def _load_model(self) -> None:
387
523
  """Load the model according to this instance's attributes.
388
524
 
389
525
  If the relevant fields have not been changed since the last time this method was
@@ -398,7 +534,9 @@ class Translator(Converter):
398
534
  self._llm, token_limit, self.model_cost = load_model(self._model_name)
399
535
  # Set the max_tokens to less than half the model's limit to allow for enough
400
536
  # tokens at output
401
- self._max_tokens = token_limit // 2.5
537
+ # Only modify max_tokens if it is not specified by user
538
+ if not self.override_token_limit:
539
+ self._max_tokens = token_limit // 2.5
402
540
 
403
541
  @run_if_changed("_parser_type", "_target_language")
404
542
  def _load_parser(self) -> None:
@@ -412,14 +550,26 @@ class Translator(Converter):
412
550
  f"Target language ({self._target_language}) suggests target "
413
551
  f"parser should be 'text', but is '{self._parser_type}'"
414
552
  )
553
+ if (
554
+ self._parser_type in {"eval", "multidoc", "madlibs"}
555
+ and "json" != self._target_language
556
+ ):
557
+ raise ValueError(
558
+ f"Parser type ({self._parser_type}) suggests target language"
559
+ f" should be 'json', but is '{self._target_language}'"
560
+ )
415
561
  if "code" == self._parser_type:
416
- self.parser = CodeParser(language=self._target_language)
562
+ self._parser = CodeParser(language=self._target_language)
417
563
  elif "eval" == self._parser_type:
418
- self.parser = EvaluationParser(
419
- expected_keys={"syntax", "style", "completeness", "correctness"}
420
- )
564
+ self._parser = EvaluationParser()
565
+ elif "multidoc" == self._parser_type:
566
+ self._parser = MultiDocumentationParser()
567
+ elif "madlibs" == self._parser_type:
568
+ self._parser = MadlibsDocumentationParser()
421
569
  elif "text" == self._parser_type:
422
- self.parser = JanusParser()
570
+ self._parser = GenericParser()
571
+ elif "requirements" == self._parser_type:
572
+ self._parser = RequirementsParser()
423
573
  else:
424
574
  raise ValueError(
425
575
  f"Unsupported parser type: {self._parser_type}. Can be: "
@@ -427,13 +577,17 @@ class Translator(Converter):
427
577
  )
428
578
 
429
579
  @run_if_changed(
430
- "_prompt_template_name", "_source_language", "_target_language", "_target_version"
580
+ "_prompt_template_name",
581
+ "_source_language",
582
+ "_target_language",
583
+ "_target_version",
584
+ "_model_name",
431
585
  )
432
- def _load_prompt_engine(self) -> None:
433
- """Load the prompt engine according to this instance's attributes.
586
+ def _load_prompt(self) -> None:
587
+ """Load the prompt according to this instance's attributes.
434
588
 
435
- If the relevant fields have not been changed since the last time this method was
436
- called, nothing happens.
589
+ If the relevant fields have not been changed since the last time this
590
+ method was called, nothing happens.
437
591
  """
438
592
  if self._prompt_template_name in SAME_OUTPUT:
439
593
  if self._target_language != self._source_language:
@@ -448,9 +602,380 @@ class Translator(Converter):
448
602
  f"language should be 'text', but is '{self._target_language}'"
449
603
  )
450
604
 
451
- self._prompt_engine = PromptEngine(
605
+ prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
452
606
  source_language=self._source_language,
453
607
  target_language=self._target_language,
454
608
  target_version=self._target_version,
455
609
  prompt_template=self._prompt_template_name,
456
610
  )
611
+ self._prompt = prompt_engine.prompt
612
+
613
@run_if_changed("_db_path", "_db_config")
def _load_vectorizer(self) -> None:
    """(Re)build the vectorizer backing retrieval from the vector database.

    When ``_db_path`` is ``None`` the vectorizer is cleared. Otherwise a new
    vectorizer is created through ``ChromaDBVectorizer().create_vectorizer``
    from both ``_db_path`` and ``_db_config``.

    Both attributes are tracked: the vectorizer depends on the config as
    well as the path, so a config-only change must also trigger a rebuild.
    """
    if self._db_path is None:
        self._vectorizer = None
        return
    vectorizer_factory = ChromaDBVectorizer()
    self._vectorizer = vectorizer_factory.create_vectorizer(
        self._db_path, self._db_config
    )
622
+
623
@run_if_changed(
    "_source_language",
    "_max_tokens",
    "_llm",
    "_protected_node_types",
    "_prune_node_types",
    "_custom_splitter",
)
def _load_splitter(self) -> None:
    """(Re)build the splitter that breaks source files into code blocks.

    With no custom splitter configured, defer to the parent class's
    ``_load_splitter``. Otherwise instantiate
    ``CUSTOM_SPLITTERS[self._custom_splitter]`` for the current source language,
    passing this instance's token limit, model, and protected/pruned node types.

    ``_custom_splitter`` is tracked because the branch taken depends on it.
    """
    if self._custom_splitter is None:
        # NOTE(review): if the parent's _load_splitter is itself wrapped by
        # run_if_changed without "_custom_splitter", switching a custom
        # splitter back to None may be skipped there — verify in converter.py.
        super()._load_splitter()
        return

    splitter_kwargs = dict(
        max_tokens=self._max_tokens,
        model=self._llm,
        protected_node_types=self._protected_node_types,
        prune_node_types=self._prune_node_types,
    )
    # TODO: This should be configurable
    if self._custom_splitter == "tag":
        splitter_kwargs["tag"] = "<ITMOD_ALC_SPLIT>"
    self._splitter = CUSTOM_SPLITTERS[self._custom_splitter](
        language=self._source_language, **splitter_kwargs
    )
646
+
647
@run_if_changed("_target_language", "_parser_type")
def _load_combiner(self) -> None:
    """Select the combiner used to merge translated blocks back together.

    The "requirements" parser type takes precedence and selects
    ``ChunkCombiner``. Otherwise a "json" target language selects
    ``JsonCombiner``. Everything else uses the default ``Combiner``.
    """
    if self._parser_type == "requirements":
        combiner_cls = ChunkCombiner
    elif self._target_language == "json":
        combiner_cls = JsonCombiner
    else:
        combiner_cls = Combiner
    self._combiner = combiner_cls()
655
+
656
+
657
class Documenter(Translator):
    """Translator preset that writes plain-text documentation for source code.

    The target language ("text"), target version (None), prompt template
    ("document") and parser type ("text") are fixed. Any caller-supplied
    values for those keys are overwritten.
    """

    def __init__(
        self, source_language: str = "fortran", drop_comments: bool = True, **kwargs
    ):
        """Initialize the Documenter.

        Arguments:
            source_language: The source programming language.
            drop_comments: If true, prune the language's comment nodes
                (``LANGUAGES[source_language]["comment_node_type"]``, falling
                back to "comment") before documenting.
            **kwargs: Remaining arguments forwarded to ``Translator``.
        """
        fixed_settings = {
            "source_language": source_language,
            "target_language": "text",
            "target_version": None,
            "prompt_template": "document",
            "parser_type": "text",
        }
        super().__init__(**{**kwargs, **fixed_settings})

        if not drop_comments:
            return
        language_info = LANGUAGES[source_language]
        self.set_prune_node_types(
            [language_info.get("comment_node_type", "comment")]
        )
675
+
676
+
677
class MultiDocumenter(Documenter):
    """Documenter that uses the "multidocument" prompt and "multidoc" parser.

    Its output targets the "json" language rather than plain text.
    """

    def __init__(self, **kwargs):
        """Initialize the MultiDocumenter; ``kwargs`` go to ``Documenter``."""
        super().__init__(**kwargs)
        # Replace the plain-text documentation setup from Documenter with the
        # JSON multi-document configuration.
        self.set_prompt("multidocument")
        self.set_parser_type("multidoc")
        self.set_target_language("json", None)
683
+
684
+
685
class MadLibsDocumenter(Documenter):
    """Documenter that fills in comment placeholders found in the source.

    Comments are kept in the code (``drop_comments=False``), and the
    "document_madlibs" prompt is used with the "madlibs" parser, targeting JSON
    output. When ``comments_per_request`` is set, a block with more comment
    placeholders than that limit is documented in several smaller LLM requests.
    Their JSON results are merged into a single object.
    """

    # Placeholder tokens such as "<INLINE_COMMENT abcd1234>" or
    # "<BLOCK_COMMENT abcd1234>"; compiled once and reused for every block.
    _COMMENT_PATTERN = re.compile(r"<(?:INLINE|BLOCK)_COMMENT \w{8}>")

    def __init__(
        self,
        comments_per_request: int | None = None,
        **kwargs,
    ) -> None:
        """Initialize the MadLibsDocumenter.

        Arguments:
            comments_per_request: Maximum number of comment placeholders to
                send in a single request, or ``None`` to always document a
                block in one request.
            **kwargs: Forwarded to ``Documenter``; ``drop_comments`` is always
                forced to ``False`` so the placeholders remain in the code.

        Raises:
            ValueError: If ``comments_per_request`` is not ``None`` and is
                smaller than 1. A step of 0 would crash ``range`` during
                translation, and a negative step would silently produce no
                output.
        """
        if comments_per_request is not None and comments_per_request < 1:
            raise ValueError(
                "comments_per_request must be a positive integer or None, "
                f"got {comments_per_request!r}"
            )
        kwargs.update(drop_comments=False)
        super().__init__(**kwargs)
        self.set_prompt("document_madlibs")
        self.set_parser_type("madlibs")
        self.set_target_language("json", None)
        self.comments_per_request = comments_per_request

    def _add_translation(self, block: TranslatedCodeBlock):
        """Document ``block`` in place, splitting the request if it has many comments.

        - Already-translated blocks are left untouched.
        - Blocks without text are simply marked translated.
        - If ``comments_per_request`` is ``None``, the normal single-request
          path is used.
        - Otherwise, blocks with no placeholders are skipped without calling
          the LLM.
        - Blocks with at most ``comments_per_request`` placeholders use the
          single-request path.
        - Larger blocks are processed in groups. Each request keeps only its
          group's placeholders, and the per-group JSON outputs are merged.

        Arguments:
            block: The `TranslatedCodeBlock` to fill in.
        """
        if block.translated:
            return

        if block.original.text is None:
            block.translated = True
            return

        if self.comments_per_request is None:
            return super()._add_translation(block)

        comments = list(self._COMMENT_PATTERN.finditer(block.original.text))

        if not comments:
            log.info(f"[{block.name}] Skipping commentless block")
            block.translated = True
            block.text = None
            block.complete = True
            return

        if len(comments) <= self.comments_per_request:
            return super()._add_translation(block)

        comment_group_indices = list(range(0, len(comments), self.comments_per_request))
        log.debug(
            f"[{block.name}] Block contains more than {self.comments_per_request}"
            f" comments, splitting {len(comments)} comments into"
            f" {len(comment_group_indices)} groups"
        )

        # Metadata is accumulated across all of the per-group requests below.
        block.processing_time = 0
        block.cost = 0
        block.retries = 0
        obj = {}
        for i in comment_group_indices:
            # Split the text into the section containing comments of interest,
            # all the text prior to those comments, and all the text after them
            working_comments = comments[i : i + self.comments_per_request]
            start_idx = working_comments[0].start()
            end_idx = working_comments[-1].end()
            prefix = block.original.text[:start_idx]
            keeper = block.original.text[start_idx:end_idx]
            suffix = block.original.text[end_idx:]

            # Strip all comment placeholders outside of the section of interest
            prefix = self._COMMENT_PATTERN.sub("", prefix)
            suffix = self._COMMENT_PATTERN.sub("", suffix)

            # Build a new TranslatedBlock using the new working text
            working_copy = deepcopy(block.original)
            working_copy.text = prefix + keeper + suffix
            working_block = TranslatedCodeBlock(working_copy, self._target_language)

            # Run the LLM on the working text
            super()._add_translation(working_block)

            # Update metadata to include for all runs
            block.retries += working_block.retries
            block.cost += working_block.cost
            block.processing_time += working_block.processing_time

            # Update the output text to merge this section's output in
            out_text = self._parser.parse(working_block.text)
            obj.update(json.loads(out_text))

        # Point the parser back at the full original block before formatting
        # the merged result.
        self._parser.set_reference(block.original)
        block.text = self._parser.parse(json.dumps(obj))
        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

    def _get_obj(
        self, block: TranslatedCodeBlock
    ) -> dict[str, int | float | dict[str, str]]:
        """Build the JSON-serializable result for ``block``.

        Returns:
            A dict with the block's total retries, total cost, processing
            time, and the parsed comments (``comments``).
        """
        out_text = self._parser.parse_combined_output(block.complete_text)
        obj = dict(
            retries=block.total_retries,
            cost=block.total_cost,
            processing_time=block.processing_time,
            comments=json.loads(out_text),
        )
        return obj

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Save a file to disk.

        Arguments:
            block: The `CodeBlock` to save to a file.
            out_path: Destination path; missing parent directories are created.
        """
        obj = self._get_obj(block)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
795
+
796
+
797
class DiagramGenerator(Documenter):
    """DiagramGenerator

    A class that translates code from one programming language to a set of diagrams.
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo-0125",
        model_arguments: dict[str, Any] | None = None,
        source_language: str = "fortran",
        max_prompts: int = 10,
        db_path: str | None = None,
        db_config: dict[str, Any] | None = None,
        diagram_type="Activity",
        add_documentation=False,
        custom_splitter: str | None = None,
    ) -> None:
        """Initialize the DiagramGenerator class

        Arguments:
            model: The LLM to use for translation. If an OpenAI model, the
                `OPENAI_API_KEY` environment variable must be set and the
                `OPENAI_ORG_ID` environment variable should be set if needed.
            model_arguments: Additional arguments to pass to the LLM constructor.
                Defaults to an empty dict (``None`` avoids a shared mutable
                default).
            source_language: The source programming language.
            max_prompts: The maximum number of prompts to try before giving up.
            db_path: path to chroma database
            db_config: database configuration
            diagram_type: type of PLANTUML diagram to generate
            add_documentation: If true, first document each block and pass that
                documentation to the diagram prompt
                ("diagram_with_documentation" template).
            custom_splitter: Name of a custom splitter to use, or ``None`` for
                the default splitter.
        """
        if model_arguments is None:
            model_arguments = {}
        super().__init__(
            model=model,
            model_arguments=model_arguments,
            source_language=source_language,
            max_prompts=max_prompts,
            db_path=db_path,
            db_config=db_config,
            custom_splitter=custom_splitter,
        )
        self._diagram_type = diagram_type
        self._add_documentation = add_documentation
        self._documenter = None
        self._model = model
        self._model_arguments = model_arguments
        self._max_prompts = max_prompts
        if add_documentation:
            self._diagram_prompt_template_name = "diagram_with_documentation"
        else:
            self._diagram_prompt_template_name = "diagram"
        self._load_diagram_prompt_engine()

    def _add_translation(self, block: TranslatedCodeBlock) -> None:
        """Given an "empty" `TranslatedCodeBlock`, translate the code represented in
        `block.original`, setting the relevant fields in the translated block. The
        `TranslatedCodeBlock` is updated in-place, nothing is returned. Note that this
        translates *only* the code for this block, not its children.

        Arguments:
            block: An empty `TranslatedCodeBlock`

        Raises:
            ValueError: If no LLM is configured, or if documentation was
                requested but could not be produced.
        """
        if block.translated:
            return

        if block.original.text is None:
            block.translated = True
            return

        # Check the model before any LLM request (including the documentation
        # pass below) is attempted.
        if self._llm is None:
            message = (
                "Model not configured correctly, cannot translate. Try setting "
                "the model"
            )
            log.error(message)
            raise ValueError(message)

        prompt_inputs = {
            "SOURCE_CODE": block.original.text,
            "DIAGRAM_TYPE": self._diagram_type,
        }

        if self._add_documentation:
            # Document a copy so the block being diagrammed stays untranslated.
            documentation_block = deepcopy(block)
            super()._add_translation(documentation_block)
            if not documentation_block.translated:
                message = "Error: unable to produce documentation for code block"
                log.error(message)
                raise ValueError(message)
            # NOTE(review): assumes the documentation output is a JSON object
            # with a "docstring" key — confirm against the "document" prompt.
            prompt_inputs["DOCUMENTATION"] = json.loads(documentation_block.text)[
                "docstring"
            ]

        log.debug(f"[{block.name}] Translating...")
        log.debug(f"[{block.name}] Input text:\n{block.original.text}")

        self._parser.set_reference(block.original)

        query_and_parse = self.diagram_prompt | self._llm | self._parser
        block.text = query_and_parse.invoke(prompt_inputs)
        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

        log.debug(f"[{block.name}] Output code:\n{block.text}")

    @run_if_changed(
        "_diagram_prompt_template_name",
        "_source_language",
        "_target_language",
        "_target_version",
        "_model_name",
    )
    def _load_diagram_prompt_engine(self) -> None:
        """Load the diagram prompt engine according to this instance's attributes.

        Tracks every attribute the engine is built from, including the model
        name, mirroring the translator's own prompt loading. If none of them
        has changed since the last call, nothing happens.

        Raises:
            ValueError: If the diagram template's output constraints (same
                language / text output) conflict with the configured languages.
        """
        if self._diagram_prompt_template_name in SAME_OUTPUT:
            if self._target_language != self._source_language:
                raise ValueError(
                    f"Prompt template ({self._diagram_prompt_template_name}) suggests "
                    f"source and target languages should match, but do not "
                    f"({self._source_language} != {self._target_language})"
                )
        if (
            self._diagram_prompt_template_name in TEXT_OUTPUT
            and self._target_language != "text"
        ):
            raise ValueError(
                f"Prompt template ({self._diagram_prompt_template_name}) suggests "
                f"target language should be 'text', but is "
                f"'{self._target_language}'"
            )

        self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
            source_language=self._source_language,
            target_language=self._target_language,
            target_version=self._target_version,
            prompt_template=self._diagram_prompt_template_name,
        )
        self.diagram_prompt = self._diagram_prompt_engine.prompt
942
+
943
+
944
class RequirementsDocumenter(Documenter):
    """RequirementsDocumenter

    A Documenter that extracts requirements from each chunk of source code.
    It uses the "chunk_requirements" prompt with the "requirements" parser and
    a JSON target language.
    """

    def __init__(self, **kwargs):
        """Initialize the RequirementsDocumenter; ``kwargs`` go to ``Documenter``."""
        super().__init__(**kwargs)
        self.set_prompt("chunk_requirements")
        self.set_target_language("json", None)
        self.set_parser_type("requirements")

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Write the per-chunk requirements of ``block`` to ``out_path`` as JSON.

        The output has a single top-level "output" key. It holds one entry per
        child chunk with its generation metadata, its original code, and the
        requirements parsed from its generated text.

        Arguments:
            block: The translated root `CodeBlock` whose children are saved.
            out_path: Destination path; missing parent directories are created.
        """
        chunk_entries = [
            {
                "metadata": {
                    "retries": chunk.total_retries,
                    "cost": chunk.total_cost,
                    "processing_time": chunk.processing_time,
                },
                "code": chunk.original.text,
                "requirements": self._parser.parse_combined_output(
                    chunk.complete_text
                ),
            }
            for chunk in block.children
        ]
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(
            json.dumps({"output": chunk_entries}, indent=2), encoding="utf-8"
        )