janus-llm 1.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (74)
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +130 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +19 -14
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +165 -72
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/METADATA +27 -11
  70. janus_llm-2.0.1.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/entry_points.txt +0 -0
janus/translate.py CHANGED
@@ -1,34 +1,62 @@
1
- import uuid
1
+ import json
2
+ import math
3
+ import re
4
+ import time
5
+ from copy import deepcopy
2
6
  from pathlib import Path
3
- from typing import Any, Dict
7
+ from typing import Any
4
8
 
5
- from chromadb.api.models.Collection import Collection
6
- from langchain.callbacks import get_openai_callback
9
+ from langchain.output_parsers import RetryWithErrorOutputParser
10
+ from langchain.output_parsers.fix import OutputFixingParser
11
+ from langchain_core.exceptions import OutputParserException
12
+ from langchain_core.language_models import BaseLanguageModel
13
+ from langchain_core.output_parsers import BaseOutputParser
14
+ from langchain_core.prompts import ChatPromptTemplate
15
+ from langchain_core.runnables import RunnableLambda, RunnableParallel
16
+ from openai import BadRequestError, RateLimitError
17
+ from text_generation.errors import ValidationError
18
+
19
+ from janus.language.naive.registry import CUSTOM_SPLITTERS
7
20
 
8
21
  from .converter import Converter, run_if_changed
22
+ from .embedding.vectorize import ChromaDBVectorizer
9
23
  from .language.block import CodeBlock, TranslatedCodeBlock
24
+ from .language.combine import ChunkCombiner, Combiner, JsonCombiner
25
+ from .language.splitter import EmptyTreeError, FileSizeError, TokenLimitError
10
26
  from .llm import load_model
11
- from .parsers.code_parser import PARSER_TYPES, CodeParser, EvaluationParser, JanusParser
12
- from .prompts.prompt import SAME_OUTPUT, TEXT_OUTPUT, PromptEngine
27
+ from .llm.model_callbacks import get_model_callback
28
+ from .llm.models_info import MODEL_PROMPT_ENGINES
29
+ from .parsers.code_parser import CodeParser, GenericParser
30
+ from .parsers.doc_parser import MadlibsDocumentationParser, MultiDocumentationParser
31
+ from .parsers.eval_parser import EvaluationParser
32
+ from .parsers.reqs_parser import RequirementsParser
33
+ from .prompts.prompt import SAME_OUTPUT, TEXT_OUTPUT
13
34
  from .utils.enums import LANGUAGES
14
35
  from .utils.logger import create_logger
15
36
 
16
37
  log = create_logger(__name__)
17
38
 
18
39
 
40
+ PARSER_TYPES: set[str] = {"code", "text", "eval", "madlibs", "multidoc", "requirements"}
41
+
42
+
19
43
  class Translator(Converter):
20
44
  """A class that translates code from one programming language to another."""
21
45
 
22
46
  def __init__(
23
47
  self,
24
- model: str = "gpt-3.5-turbo",
25
- model_arguments: Dict[str, Any] = {},
48
+ model: str = "gpt-3.5-turbo-0125",
49
+ model_arguments: dict[str, Any] = {},
26
50
  source_language: str = "fortran",
27
51
  target_language: str = "python",
28
- target_version: str = "3.10",
52
+ target_version: str | None = "3.10",
29
53
  max_prompts: int = 10,
54
+ max_tokens: int | None = None,
30
55
  prompt_template: str | Path = "simple",
31
56
  parser_type: str = "code",
57
+ db_path: str | None = None,
58
+ db_config: dict[str, Any] | None = None,
59
+ custom_splitter: str | None = None,
32
60
  ) -> None:
33
61
  """Initialize a Translator instance.
34
62
 
@@ -41,37 +69,53 @@ class Translator(Converter):
41
69
  target_language: The target programming language.
42
70
  target_version: The target version of the target programming language.
43
71
  max_prompts: The maximum number of prompts to try before giving up.
72
+ max_tokens: The maximum number of tokens the model will take in.
73
+ If unspecificed, model's default max will be used.
44
74
  prompt_template: name of prompt template directory
45
75
  (see janus/prompts/templates) or path to a directory.
46
76
  parser_type: The type of parser to use for parsing the LLM output. Valid
47
77
  values are "code" (default), "text", and "eval".
48
78
  """
79
+ self._custom_splitter = custom_splitter
49
80
  super().__init__(source_language=source_language)
50
81
 
51
- self._parser_type: None | str
52
- self._parser: None | JanusParser
53
- self._model_name: None | str
54
- self._custom_model_arguments: None | Dict[str, Any]
55
- self._target_language: None | str
56
- self._target_version: None | str
57
- self._target_glob: None | str
58
- self._prompt_template_name: None | str
82
+ self._parser_type: str | None
83
+ self._model_name: str | None
84
+ self._custom_model_arguments: dict[str, Any] | None
85
+ self._target_language: str | None
86
+ self._target_version: str | None
87
+ self._target_glob: str | None
88
+ self._prompt_template_name: str | None
89
+ self._db_path: str | None
90
+ self._db_config: dict[str, Any] | None
91
+
92
+ self._llm: BaseLanguageModel | None
93
+ self._parser: BaseOutputParser | None
94
+ self._combiner: Combiner | None
95
+ self._prompt: ChatPromptTemplate | None
96
+
97
+ self.max_prompts = max_prompts
98
+ self.override_token_limit = False if max_tokens is None else True
99
+ self._max_tokens = max_tokens
59
100
 
60
101
  self.set_model(model_name=model, **model_arguments)
61
102
  self.set_parser_type(parser_type=parser_type)
62
103
  self.set_prompt(prompt_template=prompt_template)
63
104
  self.set_target_language(
64
- target_language=target_language, target_version=target_version
105
+ target_language=target_language,
106
+ target_version=target_version,
65
107
  )
108
+ self.set_db_path(db_path=db_path)
109
+ self.set_db_config(db_config=db_config)
66
110
 
67
111
  self._load_parameters()
68
112
 
69
- self.max_prompts = max_prompts
70
-
71
113
  def _load_parameters(self) -> None:
72
114
  self._load_model()
73
- self._load_prompt_engine()
115
+ self._load_prompt()
74
116
  self._load_parser()
117
+ self._load_combiner()
118
+ self._load_vectorizer()
75
119
  super()._load_parameters() # will call self._changed_attrs.clear()
76
120
 
77
121
  def translate(
@@ -79,7 +123,7 @@ class Translator(Converter):
79
123
  input_directory: str | Path,
80
124
  output_directory: str | Path | None = None,
81
125
  overwrite: bool = False,
82
- output_collection: Collection | None = None,
126
+ collection_name: str | None = None,
83
127
  ) -> None:
84
128
  """Translate code in the input directory from the source language to the target
85
129
  language, and write the resulting files to the output directory.
@@ -88,6 +132,7 @@ class Translator(Converter):
88
132
  input_directory: The directory containing the code to translate.
89
133
  output_directory: The directory to write the translated code to.
90
134
  overwrite: Whether to overwrite existing files (vs skip them)
135
+ collection_name: Collection to add to
91
136
  """
92
137
  # Convert paths to pathlib Paths if needed
93
138
  if isinstance(input_directory, str):
@@ -99,25 +144,77 @@ class Translator(Converter):
99
144
  if output_directory is not None and not output_directory.exists():
100
145
  output_directory.mkdir(parents=True)
101
146
 
147
+ source_suffix = LANGUAGES[self._source_language]["suffix"]
102
148
  target_suffix = LANGUAGES[self._target_language]["suffix"]
103
149
 
104
- input_paths = input_directory.rglob(self._source_glob)
150
+ input_paths = [p for p in input_directory.rglob(self._source_glob)]
151
+
152
+ log.info(f"Input directory: {input_directory.absolute()}")
153
+ log.info(
154
+ f"{self._source_language.capitalize()} '*.{source_suffix}' files: "
155
+ f"{len(input_paths)}"
156
+ )
157
+ log.info(
158
+ "Other files (skipped): "
159
+ f"{len(list(input_directory.iterdir())) - len(input_paths)}\n"
160
+ )
161
+ if output_directory is not None:
162
+ output_paths = [
163
+ output_directory
164
+ / p.relative_to(input_directory).with_suffix(f".{target_suffix}")
165
+ for p in input_paths
166
+ ]
167
+ in_out_pairs = list(zip(input_paths, output_paths))
168
+ if not overwrite:
169
+ n_files = len(in_out_pairs)
170
+ in_out_pairs = [
171
+ (inp, outp) for inp, outp in in_out_pairs if not outp.exists()
172
+ ]
173
+ log.info(
174
+ f"Skipping {n_files - len(in_out_pairs)} existing "
175
+ f"'*.{source_suffix}' files"
176
+ )
177
+ else:
178
+ in_out_pairs = [(f, None) for f in input_paths]
179
+ log.info(f"Translating {len(in_out_pairs)} '*.{source_suffix}' files")
105
180
 
106
181
  # Now, loop through every code block in every file and translate it with an LLM
107
182
  total_cost = 0.0
108
- for in_path in input_paths:
109
- relative = in_path.relative_to(input_directory)
110
- # output_name = relative.with_suffix(f".{target_suffix}").name
111
- if output_directory is not None:
112
- out_path = output_directory / relative.with_suffix(f".{target_suffix}")
113
- else:
114
- out_path = None
115
- # Track the cost of translating the file
116
- # TODO: If non-OpenAI models with prices are added, this will need
117
- # to be updated.
118
- with get_openai_callback() as cb:
183
+ for in_path, out_path in in_out_pairs:
184
+ # Translate the file, skip it if there's a rate limit error
185
+ try:
119
186
  out_block = self.translate_file(in_path)
120
- total_cost += cb.total_cost
187
+ total_cost += out_block.total_cost
188
+ except RateLimitError:
189
+ continue
190
+ except OutputParserException as e:
191
+ log.error(f"Skipping {in_path.name}, failed to parse output: {e}.")
192
+ continue
193
+ except BadRequestError as e:
194
+ if str(e).startswith("Detected an error in the prompt"):
195
+ log.warning("Malformed input, skipping")
196
+ continue
197
+ raise e
198
+ except ValidationError as e:
199
+ # Only allow ValidationError to pass if token limit is manually set
200
+ if self.override_token_limit:
201
+ log.warning(
202
+ "Current file and manually set token "
203
+ "limit is too large for this model, skipping"
204
+ )
205
+ continue
206
+ raise e
207
+ except TokenLimitError:
208
+ log.warning("Ran into irreducible node too large for context, skipping")
209
+ continue
210
+ except EmptyTreeError:
211
+ log.warning(
212
+ f'Input file "{in_path.name}" has no nodes of interest, skipping'
213
+ )
214
+ continue
215
+ except FileSizeError:
216
+ log.warning("Current tile is too large for basic splitter, skipping")
217
+ continue
121
218
 
122
219
  # Don't attempt to write files for which translation failed
123
220
  if not out_block.translated:
@@ -139,19 +236,36 @@ class Translator(Converter):
139
236
  #
140
237
  # self._embed_nodes_recursively(out_block, embedding_type, filename)
141
238
 
239
+ if collection_name is not None:
240
+ self._vectorizer.add_nodes_recursively(
241
+ out_block,
242
+ collection_name,
243
+ in_path.name,
244
+ )
245
+ # out_text = self.parser.parse_combined_output(out_block.complete_text)
246
+ # # Using same id naming convention from vectorize.py
247
+ # ids = [str(uuid.uuid3(uuid.NAMESPACE_DNS, out_text))]
248
+ # output_collection.upsert(ids=ids, documents=[out_text])
249
+
142
250
  # Make sure the tree's code has been consolidated at the top level
143
251
  # before writing to file
144
252
  self._combiner.combine(out_block)
145
253
  if out_path is not None and (overwrite or not out_path.exists()):
146
254
  self._save_to_file(out_block, out_path)
147
- if output_collection is not None:
148
- out_text = self.parser.parse_combined_output(out_block.complete_text)
149
- # Using same id naming convention from vectorize.py
150
- ids = [str(uuid.uuid3(uuid.NAMESPACE_DNS, out_text))]
151
- output_collection.upsert(ids=ids, documents=[out_text])
152
255
 
153
256
  log.info(f"Total cost: ${total_cost:,.2f}")
154
257
 
258
+ def _split_file(self, file: Path) -> CodeBlock:
259
+ filename = file.name
260
+ log.info(f"[{filename}] Splitting file")
261
+ root = self._splitter.split(file)
262
+ log.info(
263
+ f"[{filename}] File split into {root.n_descendents:,} blocks, "
264
+ f"tree of height {root.height}"
265
+ )
266
+ log.info(f"[{filename}] Input CodeBlock Structure:\n{root.tree_str()}")
267
+ return root
268
+
155
269
  def translate_file(self, file: Path) -> TranslatedCodeBlock:
156
270
  """Translate a single file.
157
271
 
@@ -164,18 +278,12 @@ class Translator(Converter):
164
278
  `Combiner.combine_children` on the block.
165
279
  """
166
280
  self._load_parameters()
167
-
168
281
  filename = file.name
169
- log.info(f"[{filename}] Splitting file")
170
- input_block = self._splitter.split(file)
171
- log.info(
172
- f"[{filename}] File split into {input_block.n_descendents:,} blocks, "
173
- f"tree of height {input_block.height}"
174
- )
175
- log.info(f"[{filename}] Input CodeBlock Structure:\n{input_block.tree_str()}")
176
- # (temporarily?) comment-out adding embeddings; will be moved
177
- # self._embed_nodes_recursively(input_block, EmbeddingType.SOURCE, filename)
282
+
283
+ input_block = self._split_file(file)
284
+ t0 = time.time()
178
285
  output_block = self._iterative_translate(input_block)
286
+ output_block.processing_time = time.time() - t0
179
287
  if output_block.translated:
180
288
  completeness = output_block.translation_completeness
181
289
  log.info(
@@ -183,7 +291,9 @@ class Translator(Converter):
183
291
  f" {completeness:.2%} of input successfully translated\n"
184
292
  f" Total cost: ${output_block.total_cost:,.2f}\n"
185
293
  f" Total retries: {output_block.total_retries:,d}\n"
294
+ f" Output CodeBlock Structure:\n{input_block.tree_str()}\n"
186
295
  )
296
+
187
297
  else:
188
298
  log.error(
189
299
  f"[{filename}] Translation failed\n"
@@ -223,25 +333,14 @@ class Translator(Converter):
223
333
  while stack:
224
334
  translated_block = stack.pop()
225
335
 
226
- # Track the cost of translating this block
227
- # TODO: If non-OpenAI models with prices are added, this will need
228
- # to be updated.
229
- with get_openai_callback() as cb:
230
- self._add_translation(translated_block)
231
- translated_block.cost = cb.total_cost
232
- translated_block.retries = max(0, cb.successful_requests - 1)
336
+ self._add_translation(translated_block)
233
337
 
234
338
  # If translating this block was unsuccessful, don't bother with its
235
339
  # children (they wouldn't show up in the final text anyway)
236
340
  if not translated_block.translated:
237
341
  continue
238
342
 
239
- for child in translated_block.children:
240
- # Don't bother translating children if they aren't used
241
- if self._combiner.contains_child(translated_block.text, child):
242
- stack.append(child)
243
- else:
244
- log.warning(f"Skipping {child.id} (not referenced in parent code)")
343
+ stack.extend(translated_block.children)
245
344
 
246
345
  progress = translated_root.translation_completeness
247
346
  if progress - last_prog > prog_delta:
@@ -266,11 +365,6 @@ class Translator(Converter):
266
365
  block.translated = True
267
366
  return
268
367
 
269
- log.debug(f"[{block.name}] Translating...")
270
- log.debug(f"[{block.name}] Input text:\n{block.original.text}")
271
- prompt = self._prompt_engine.create(block.original)
272
- top_score = -1.0
273
-
274
368
  if self._llm is None:
275
369
  message = (
276
370
  "Model not configured correctly, cannot translate. Try setting "
@@ -279,51 +373,85 @@ class Translator(Converter):
279
373
  log.error(message)
280
374
  raise ValueError(message)
281
375
 
282
- # Retry the request up to max_prompts times before failing
283
- for _ in range(self.max_prompts + 1):
284
- output = self._llm.predict_messages(prompt)
285
- try:
286
- parsed_output = self.parser.parse(output.content)
287
- except ValueError as e:
288
- log.warning(f"[{block.name}] Failed to parse output: {e}")
289
- log.debug(f"[{block.name}] Failed output:\n{output.content}")
290
- continue
291
-
292
- score = self.parser.score(block.original, parsed_output)
293
- if score > top_score:
294
- block.text = parsed_output
295
- top_score = score
296
-
297
- if score >= 1.0:
298
- break
299
-
300
- else:
301
- if block.text is None:
302
- error_msg = (
303
- f"[{block.name}] Failed to parse output after "
304
- f"{self.max_prompts} retries. Marking as untranslated."
305
- )
306
- log.warning(error_msg)
307
- return
376
+ log.debug(f"[{block.name}] Translating...")
377
+ log.debug(f"[{block.name}] Input text:\n{block.original.text}")
308
378
 
309
- log.warning(f"[{block.name}] Output not complete")
379
+ # Track the cost of translating this block
380
+ # TODO: If non-OpenAI models with prices are added, this will need
381
+ # to be updated.
382
+ with get_model_callback() as cb:
383
+ t0 = time.time()
384
+ block.text = self._run_chain(block)
385
+ block.processing_time = time.time() - t0
386
+ block.cost = cb.total_cost
387
+ block.retries = max(0, cb.successful_requests - 1)
310
388
 
311
389
  block.tokens = self._llm.get_num_tokens(block.text)
312
390
  block.translated = True
313
391
 
314
392
  log.debug(f"[{block.name}] Output code:\n{block.text}")
315
393
 
394
+ def _run_chain(self, block: TranslatedCodeBlock) -> str:
395
+ """Run the model with three nested error fixing schemes.
396
+ First, try to fix simple formatting errors by giving the model just
397
+ the output and the parsing error. After a number of attempts, try
398
+ giving the model the output, the parsing error, and the original
399
+ input. Again check/retry this output to solve for formatting errors.
400
+ If we still haven't succeeded after several attempts, the model may
401
+ be getting thrown off by a bad initial output; start from scratch
402
+ and try again.
403
+
404
+ The number of tries for each layer of this scheme is roughly equal
405
+ to the cube root of self.max_retries, so the total calls to the
406
+ LLM will be roughly as expected (up to sqrt(self.max_retries) over)
407
+ """
408
+ self._parser.set_reference(block.original)
409
+
410
+ # Retries with just the output and the error
411
+ n1 = round(self.max_prompts ** (1 / 3))
412
+
413
+ # Retries with the input, output, and error
414
+ n2 = round((self.max_prompts // n1) ** (1 / 2))
415
+
416
+ # Retries with just the input
417
+ n3 = math.ceil(self.max_prompts / (n1 * n2))
418
+
419
+ fix_format = OutputFixingParser.from_llm(
420
+ llm=self._llm,
421
+ parser=self._parser,
422
+ max_retries=n1,
423
+ )
424
+ retry = RetryWithErrorOutputParser.from_llm(
425
+ llm=self._llm,
426
+ parser=fix_format,
427
+ max_retries=n2,
428
+ )
429
+
430
+ completion_chain = self._prompt | self._llm
431
+ chain = RunnableParallel(
432
+ completion=completion_chain, prompt_value=self._prompt
433
+ ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
434
+
435
+ for _ in range(n3):
436
+ try:
437
+ return chain.invoke({"SOURCE_CODE": block.original.text})
438
+ except OutputParserException:
439
+ pass
440
+
441
+ raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
442
+
316
443
  def _save_to_file(self, block: CodeBlock, out_path: Path) -> None:
317
444
  """Save a file to disk.
318
445
 
319
446
  Arguments:
320
447
  block: The `CodeBlock` to save to a file.
321
448
  """
322
- out_text = self.parser.parse_combined_output(block.complete_text)
449
+ # TODO: can't use output fixer and this system for combining output
450
+ out_text = self._parser.parse_combined_output(block.complete_text)
323
451
  out_path.parent.mkdir(parents=True, exist_ok=True)
324
452
  out_path.write_text(out_text, encoding="utf-8")
325
453
 
326
- def set_model(self, model_name: str, **custom_arguments: Dict[str, Any]):
454
+ def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
327
455
  """Validate and set the model name.
328
456
 
329
457
  The affected objects will not be updated until translate() is called.
@@ -363,7 +491,9 @@ class Translator(Converter):
363
491
  """
364
492
  self._prompt_template_name = prompt_template
365
493
 
366
- def set_target_language(self, target_language: str, target_version: str) -> None:
494
+ def set_target_language(
495
+ self, target_language: str, target_version: str | None
496
+ ) -> None:
367
497
  """Validate and set the target language.
368
498
 
369
499
  The affected objects will not be updated until translate() is called.
@@ -382,8 +512,14 @@ class Translator(Converter):
382
512
  self._target_language = target_language
383
513
  self._target_version = target_version
384
514
 
515
+ def set_db_path(self, db_path: str) -> None:
516
+ self._db_path = db_path
517
+
518
+ def set_db_config(self, db_config: dict[str, Any] | None) -> None:
519
+ self._db_config = db_config
520
+
385
521
  @run_if_changed("_model_name", "_custom_model_arguments")
386
- def _load_model(self):
522
+ def _load_model(self) -> None:
387
523
  """Load the model according to this instance's attributes.
388
524
 
389
525
  If the relevant fields have not been changed since the last time this method was
@@ -398,7 +534,9 @@ class Translator(Converter):
398
534
  self._llm, token_limit, self.model_cost = load_model(self._model_name)
399
535
  # Set the max_tokens to less than half the model's limit to allow for enough
400
536
  # tokens at output
401
- self._max_tokens = token_limit // 2.5
537
+ # Only modify max_tokens if it is not specified by user
538
+ if not self.override_token_limit:
539
+ self._max_tokens = token_limit // 2.5
402
540
 
403
541
  @run_if_changed("_parser_type", "_target_language")
404
542
  def _load_parser(self) -> None:
@@ -412,14 +550,26 @@ class Translator(Converter):
412
550
  f"Target language ({self._target_language}) suggests target "
413
551
  f"parser should be 'text', but is '{self._parser_type}'"
414
552
  )
553
+ if (
554
+ self._parser_type in {"eval", "multidoc", "madlibs"}
555
+ and "json" != self._target_language
556
+ ):
557
+ raise ValueError(
558
+ f"Parser type ({self._parser_type}) suggests target language"
559
+ f" should be 'json', but is '{self._target_language}'"
560
+ )
415
561
  if "code" == self._parser_type:
416
- self.parser = CodeParser(language=self._target_language)
562
+ self._parser = CodeParser(language=self._target_language)
417
563
  elif "eval" == self._parser_type:
418
- self.parser = EvaluationParser(
419
- expected_keys={"syntax", "style", "completeness", "correctness"}
420
- )
564
+ self._parser = EvaluationParser()
565
+ elif "multidoc" == self._parser_type:
566
+ self._parser = MultiDocumentationParser()
567
+ elif "madlibs" == self._parser_type:
568
+ self._parser = MadlibsDocumentationParser()
421
569
  elif "text" == self._parser_type:
422
- self.parser = JanusParser()
570
+ self._parser = GenericParser()
571
+ elif "requirements" == self._parser_type:
572
+ self._parser = RequirementsParser()
423
573
  else:
424
574
  raise ValueError(
425
575
  f"Unsupported parser type: {self._parser_type}. Can be: "
@@ -427,13 +577,17 @@ class Translator(Converter):
427
577
  )
428
578
 
429
579
  @run_if_changed(
430
- "_prompt_template_name", "_source_language", "_target_language", "_target_version"
580
+ "_prompt_template_name",
581
+ "_source_language",
582
+ "_target_language",
583
+ "_target_version",
584
+ "_model_name",
431
585
  )
432
- def _load_prompt_engine(self) -> None:
433
- """Load the prompt engine according to this instance's attributes.
586
+ def _load_prompt(self) -> None:
587
+ """Load the prompt according to this instance's attributes.
434
588
 
435
- If the relevant fields have not been changed since the last time this method was
436
- called, nothing happens.
589
+ If the relevant fields have not been changed since the last time this
590
+ method was called, nothing happens.
437
591
  """
438
592
  if self._prompt_template_name in SAME_OUTPUT:
439
593
  if self._target_language != self._source_language:
@@ -448,9 +602,380 @@ class Translator(Converter):
448
602
  f"language should be 'text', but is '{self._target_language}'"
449
603
  )
450
604
 
451
- self._prompt_engine = PromptEngine(
605
+ prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
452
606
  source_language=self._source_language,
453
607
  target_language=self._target_language,
454
608
  target_version=self._target_version,
455
609
  prompt_template=self._prompt_template_name,
456
610
  )
611
+ self._prompt = prompt_engine.prompt
612
+
613
@run_if_changed("_db_path")
def _load_vectorizer(self) -> None:
    """(Re)build the vectorizer backed by the configured database.

    With no database path the vectorizer is cleared to ``None``; otherwise a
    fresh one is created by ``ChromaDBVectorizer`` from ``self._db_path`` and
    ``self._db_config``. If ``_db_path`` has not changed since the last call,
    nothing happens.
    """
    if self._db_path is not None:
        factory = ChromaDBVectorizer()
        self._vectorizer = factory.create_vectorizer(self._db_path, self._db_config)
    else:
        self._vectorizer = None
622
+
623
@run_if_changed(
    "_source_language",
    "_max_tokens",
    "_llm",
    "_protected_node_types",
    "_prune_node_types",
)
def _load_splitter(self) -> None:
    """Build the splitter used to chunk the source code.

    Without a custom splitter this defers to the parent implementation.
    Otherwise the class registered under ``self._custom_splitter`` in
    ``CUSTOM_SPLITTERS`` is instantiated with this instance's token budget,
    model, and protected/pruned node types.
    """
    if self._custom_splitter is None:
        super()._load_splitter()
        return

    splitter_kwargs = {
        "max_tokens": self._max_tokens,
        "model": self._llm,
        "protected_node_types": self._protected_node_types,
        "prune_node_types": self._prune_node_types,
    }
    # TODO: This should be configurable
    if self._custom_splitter == "tag":
        splitter_kwargs["tag"] = "<ITMOD_ALC_SPLIT>"
    splitter_cls = CUSTOM_SPLITTERS[self._custom_splitter]
    self._splitter = splitter_cls(language=self._source_language, **splitter_kwargs)
646
+
647
@run_if_changed("_target_language", "_parser_type")
def _load_combiner(self) -> None:
    """Select the combiner used to merge translated child blocks.

    The ``requirements`` parser gets a ``ChunkCombiner``, JSON output gets a
    ``JsonCombiner``, and everything else uses the plain ``Combiner``.
    """
    if self._parser_type == "requirements":
        combiner_cls = ChunkCombiner
    elif self._target_language == "json":
        combiner_cls = JsonCombiner
    else:
        combiner_cls = Combiner
    self._combiner = combiner_cls()
655
+
656
+
657
class Documenter(Translator):
    """Translator preset that documents source code as plain text.

    Fixes the target language to ``"text"``, the prompt template to
    ``"document"`` and the parser type to ``"text"``.
    """

    def __init__(
        self, source_language: str = "fortran", drop_comments: bool = True, **kwargs
    ):
        """Configure the documenter.

        Arguments:
            source_language: The source programming language.
            drop_comments: When true, the language's comment node type
                (``"comment"`` unless ``LANGUAGES`` says otherwise) is pruned.
            **kwargs: Forwarded to ``Translator``; the documentation-specific
                settings listed above override any values passed here.
        """
        translator_kwargs = {
            **kwargs,
            "source_language": source_language,
            "target_language": "text",
            "target_version": None,
            "prompt_template": "document",
            "parser_type": "text",
        }
        super().__init__(**translator_kwargs)

        if not drop_comments:
            return
        language_info = LANGUAGES[source_language]
        comment_type = language_info.get("comment_node_type", "comment")
        self.set_prune_node_types([comment_type])
675
+
676
+
677
class MultiDocumenter(Documenter):
    """Documenter using the ``multidocument`` prompt with JSON output."""

    def __init__(self, **kwargs):
        """Build a ``Documenter`` from ``kwargs``, then switch it to multi-doc mode.

        The prompt, parser type and target language are overridden in that
        order after the base class is initialized.
        """
        super().__init__(**kwargs)
        # Replace the plain-text documentation defaults.
        self.set_prompt("multidocument")
        self.set_parser_type("multidoc")
        self.set_target_language("json", None)
683
+
684
+
685
class MadLibsDocumenter(Documenter):
    """Documenter driven by comment placeholders in the source text.

    Placeholders match ``<INLINE_COMMENT xxxxxxxx>`` or
    ``<BLOCK_COMMENT xxxxxxxx>`` (eight word characters). Output is parsed
    with the ``madlibs`` parser into JSON. When ``comments_per_request`` is
    set, a block with more placeholders than that is processed in several
    requests, each one exposing only a single group of placeholders, and the
    resulting JSON objects are merged.
    """

    def __init__(
        self,
        comments_per_request: int | None = None,
        **kwargs,
    ) -> None:
        """Configure the documenter.

        Arguments:
            comments_per_request: Maximum number of comment placeholders sent
                to the LLM in one request. ``None`` sends each block whole.
            **kwargs: Forwarded to ``Documenter``; ``drop_comments`` is always
                forced to ``False`` so comment nodes are not pruned.

        Raises:
            ValueError: If ``comments_per_request`` is not ``None`` and is
                smaller than 1 (0 would later crash ``range`` and negative
                values would silently produce no output).
        """
        if comments_per_request is not None and comments_per_request < 1:
            raise ValueError(
                "comments_per_request must be a positive integer or None, "
                f"got {comments_per_request!r}"
            )
        kwargs.update(drop_comments=False)
        super().__init__(**kwargs)
        self.set_prompt("document_madlibs")
        self.set_parser_type("madlibs")
        self.set_target_language("json", None)
        self.comments_per_request = comments_per_request

    def _add_translation(self, block: TranslatedCodeBlock):
        """Translate ``block`` in place, batching its placeholders if needed.

        Blocks without placeholders are marked translated/complete with no
        text. Blocks with at most ``comments_per_request`` placeholders (or
        when batching is disabled) go through the normal translation. Larger
        blocks are translated once per group; in each request every
        placeholder outside the current group is removed from the text.

        Arguments:
            block: The `TranslatedCodeBlock` to fill in.
        """
        if block.translated:
            return

        if block.original.text is None:
            block.translated = True
            return

        if self.comments_per_request is None:
            return super()._add_translation(block)

        comment_pattern = re.compile(r"<(?:INLINE|BLOCK)_COMMENT \w{8}>")
        text = block.original.text
        comments = list(comment_pattern.finditer(text))

        if not comments:
            log.info(f"[{block.name}] Skipping commentless block")
            block.translated = True
            block.text = None
            block.complete = True
            return

        if len(comments) <= self.comments_per_request:
            return super()._add_translation(block)

        comment_group_indices = list(range(0, len(comments), self.comments_per_request))
        log.debug(
            f"[{block.name}] Block contains more than {self.comments_per_request}"
            f" comments, splitting {len(comments)} comments into"
            f" {len(comment_group_indices)} groups"
        )

        # Metadata is accumulated over every per-group request below.
        block.processing_time = 0
        block.cost = 0
        block.retries = 0
        obj = {}
        for i in comment_group_indices:
            # Split the text into the section containing comments of interest,
            # all the text prior to those comments, and all the text after them
            working_comments = comments[i : i + self.comments_per_request]
            start_idx = working_comments[0].start()
            end_idx = working_comments[-1].end()
            prefix = text[:start_idx]
            keeper = text[start_idx:end_idx]
            suffix = text[end_idx:]

            # Strip all comment placeholders outside of the section of interest
            prefix = comment_pattern.sub("", prefix)
            suffix = comment_pattern.sub("", suffix)

            # Build a new TranslatedBlock using the new working text
            working_copy = deepcopy(block.original)
            working_copy.text = prefix + keeper + suffix
            working_block = TranslatedCodeBlock(working_copy, self._target_language)

            # Run the LLM on the working text
            super()._add_translation(working_block)

            # Update metadata to include for all runs
            block.retries += working_block.retries
            block.cost += working_block.cost
            block.processing_time += working_block.processing_time

            # Update the output text to merge this section's output in
            out_text = self._parser.parse(working_block.text)
            obj.update(json.loads(out_text))

        # Re-parse the merged object against the full original block.
        self._parser.set_reference(block.original)
        block.text = self._parser.parse(json.dumps(obj))
        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

    def _get_obj(
        self, block: TranslatedCodeBlock
    ) -> dict[str, int | float | dict[str, str]]:
        """Collect generation metadata and the parsed comments for ``block``.

        Arguments:
            block: The translated root block.

        Returns:
            A dict with ``retries``, ``cost``, ``processing_time`` and
            ``comments`` (the JSON-decoded combined parser output).
        """
        out_text = self._parser.parse_combined_output(block.complete_text)
        obj = dict(
            retries=block.total_retries,
            cost=block.total_cost,
            processing_time=block.processing_time,
            comments=json.loads(out_text),
        )
        return obj

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Save a file to disk.

        Arguments:
            block: The `CodeBlock` to save to a file.
            out_path: Destination path; missing parent directories are created.
        """
        obj = self._get_obj(block)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
795
+
796
+
797
class DiagramGenerator(Documenter):
    """DiagramGenerator

    A class that translates code from one programming language to a set of diagrams.
    Each block is sent to the LLM with the ``diagram`` prompt, or with
    ``diagram_with_documentation`` when ``add_documentation`` is set. In the
    latter case the block is first documented with the inherited
    ``Documenter`` behavior, and that documentation is passed to the prompt.
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo-0125",
        model_arguments: dict[str, Any] | None = None,
        source_language: str = "fortran",
        max_prompts: int = 10,
        db_path: str | None = None,
        db_config: dict[str, Any] | None = None,
        diagram_type: str = "Activity",
        add_documentation: bool = False,
        custom_splitter: str | None = None,
    ) -> None:
        """Initialize the DiagramGenerator class

        Arguments:
            model: The LLM to use for translation. If an OpenAI model, the
                `OPENAI_API_KEY` environment variable must be set and the
                `OPENAI_ORG_ID` environment variable should be set if needed.
            model_arguments: Additional arguments to pass to the LLM constructor.
                Defaults to a new empty dict for each instance.
            source_language: The source programming language.
            max_prompts: The maximum number of prompts to try before giving up.
            db_path: path to chroma database
            db_config: database configuration
            diagram_type: type of PLANTUML diagram to generate
            add_documentation: If true, document each block first and include
                that documentation in the diagram prompt.
            custom_splitter: Name of a custom splitter, or None for the default.
        """
        # A literal `{}` default would be one dict shared by every instance
        # (and it is stored on self below), so create a fresh one per call.
        if model_arguments is None:
            model_arguments = {}
        super().__init__(
            model=model,
            model_arguments=model_arguments,
            source_language=source_language,
            max_prompts=max_prompts,
            db_path=db_path,
            db_config=db_config,
            custom_splitter=custom_splitter,
        )
        self._diagram_type = diagram_type
        self._add_documentation = add_documentation
        self._documenter = None
        self._model = model
        self._model_arguments = model_arguments
        self._max_prompts = max_prompts
        if add_documentation:
            self._diagram_prompt_template_name = "diagram_with_documentation"
        else:
            self._diagram_prompt_template_name = "diagram"
        self._load_diagram_prompt_engine()

    def _add_translation(self, block: TranslatedCodeBlock) -> None:
        """Given an "empty" `TranslatedCodeBlock`, translate the code represented in
        `block.original`, setting the relevant fields in the translated block. The
        `TranslatedCodeBlock` is updated in-place, nothing is returned. Note that this
        translates *only* the code for this block, not its children.

        Arguments:
            block: An empty `TranslatedCodeBlock`

        Raises:
            ValueError: If no model is configured, or if documentation was
                requested but could not be produced.
        """
        if block.translated:
            return

        if block.original.text is None:
            block.translated = True
            return

        # Check the model before spending an LLM call on documentation.
        if self._llm is None:
            message = (
                "Model not configured correctly, cannot translate. Try setting "
                "the model"
            )
            log.error(message)
            raise ValueError(message)

        prompt_inputs = {
            "SOURCE_CODE": block.original.text,
            "DIAGRAM_TYPE": self._diagram_type,
        }

        if self._add_documentation:
            # Document a copy so the diagram block itself stays untranslated.
            documentation_block = deepcopy(block)
            super()._add_translation(documentation_block)
            if not documentation_block.translated:
                message = "Error: unable to produce documentation for code block"
                log.error(message)
                raise ValueError(message)
            # Expects the documentation output to be JSON with a "docstring" key.
            prompt_inputs["DOCUMENTATION"] = json.loads(documentation_block.text)[
                "docstring"
            ]

        log.debug(f"[{block.name}] Translating...")
        log.debug(f"[{block.name}] Input text:\n{block.original.text}")

        self._parser.set_reference(block.original)

        query_and_parse = self.diagram_prompt | self._llm | self._parser
        block.text = query_and_parse.invoke(prompt_inputs)
        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

        log.debug(f"[{block.name}] Output code:\n{block.text}")

    @run_if_changed(
        "_diagram_prompt_template_name",
        "_source_language",
        "_target_language",
        "_target_version",
        "_model_name",
    )
    def _load_diagram_prompt_engine(self) -> None:
        """Load the diagram prompt according to this instance's attributes.

        Validates the diagram template against the configured languages and
        stores the resulting prompt on ``self.diagram_prompt``.

        If the relevant fields have not been changed since the last time this method was
        called, nothing happens.

        Raises:
            ValueError: If the template's output kind conflicts with the
                source/target languages.
        """
        template_name = self._diagram_prompt_template_name
        if template_name in SAME_OUTPUT:
            if self._target_language != self._source_language:
                raise ValueError(
                    f"Prompt template ({template_name}) suggests "
                    f"source and target languages should match, but do not "
                    f"({self._source_language} != {self._target_language})"
                )
        if template_name in TEXT_OUTPUT and self._target_language != "text":
            raise ValueError(
                f"Prompt template ({template_name}) suggests target "
                f"language should be 'text', but is '{self._target_language}'"
            )

        self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
            source_language=self._source_language,
            target_language=self._target_language,
            target_version=self._target_version,
            prompt_template=template_name,
        )
        self.diagram_prompt = self._diagram_prompt_engine.prompt
942
+
943
+
944
class RequirementsDocumenter(Documenter):
    """RequirementsDocumenter

    Generates requirements for each chunk of source code, using the
    ``chunk_requirements`` prompt, JSON output and the ``requirements`` parser.
    """

    def __init__(self, **kwargs):
        """Build a ``Documenter`` from ``kwargs`` and switch it to requirements mode."""
        super().__init__(**kwargs)
        self.set_prompt("chunk_requirements")
        self.set_target_language("json", None)
        self.set_parser_type("requirements")

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Write per-chunk requirements for ``block`` to ``out_path`` as JSON.

        The file holds one top-level ``output`` list with an entry per child
        of ``block``: its generation ``metadata``, its source ``code`` and the
        parsed ``requirements``.

        Arguments:
            block: The `CodeBlock` to save to a file.
            out_path: Destination path; missing parent directories are created.
        """

        def chunk_record(chunk):
            # Evaluate in the same order as before: code, parse, then metadata.
            source_text = chunk.original.text
            parsed = self._parser.parse_combined_output(chunk.complete_text)
            stats = {
                "retries": chunk.total_retries,
                "cost": chunk.total_cost,
                "processing_time": chunk.processing_time,
            }
            return {"metadata": stats, "code": source_text, "requirements": parsed}

        payload = {"output": [chunk_record(chunk) for chunk in block.children]}
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")