janus-llm 4.3.1__py3-none-any.whl → 4.4.5__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (136)
  1. janus/__init__.py +1 -1
  2. janus/__main__.py +1 -1
  3. janus/_tests/evaluator_tests/EvalReadMe.md +85 -0
  4. janus/_tests/evaluator_tests/incose_tests/incose_large_test.json +39 -0
  5. janus/_tests/evaluator_tests/incose_tests/incose_small_test.json +17 -0
  6. janus/_tests/evaluator_tests/inline_comment_tests/mumps_inline_comment_test.m +71 -0
  7. janus/_tests/test_cli.py +3 -2
  8. janus/cli/aggregate.py +135 -0
  9. janus/cli/cli.py +117 -0
  10. janus/cli/constants.py +49 -0
  11. janus/cli/database.py +289 -0
  12. janus/cli/diagram.py +207 -0
  13. janus/cli/document.py +183 -0
  14. janus/cli/embedding.py +122 -0
  15. janus/cli/llm.py +191 -0
  16. janus/cli/partition.py +134 -0
  17. janus/cli/pipeline.py +123 -0
  18. janus/cli/self_eval.py +147 -0
  19. janus/cli/translate.py +192 -0
  20. janus/converter/__init__.py +1 -1
  21. janus/converter/_tests/test_translate.py +7 -5
  22. janus/converter/chain.py +180 -0
  23. janus/converter/converter.py +444 -153
  24. janus/converter/diagram.py +8 -6
  25. janus/converter/document.py +27 -16
  26. janus/converter/evaluate.py +143 -144
  27. janus/converter/partition.py +2 -10
  28. janus/converter/requirements.py +4 -40
  29. janus/converter/translate.py +3 -59
  30. janus/embedding/collections.py +1 -1
  31. janus/language/alc/_tests/alc.asm +3779 -0
  32. janus/language/binary/_tests/hello.bin +0 -0
  33. janus/language/block.py +78 -14
  34. janus/language/file.py +1 -1
  35. janus/language/mumps/_tests/mumps.m +235 -0
  36. janus/language/treesitter/_tests/languages/fortran.f90 +416 -0
  37. janus/language/treesitter/_tests/languages/ibmhlasm.asm +16 -0
  38. janus/language/treesitter/_tests/languages/matlab.m +225 -0
  39. janus/llm/models_info.py +9 -1
  40. janus/metrics/_tests/asm_test_file.asm +10 -0
  41. janus/metrics/_tests/mumps_test_file.m +6 -0
  42. janus/metrics/_tests/test_treesitter_metrics.py +1 -1
  43. janus/metrics/metric.py +47 -124
  44. janus/metrics/prompts/clarity.txt +8 -0
  45. janus/metrics/prompts/completeness.txt +16 -0
  46. janus/metrics/prompts/faithfulness.txt +10 -0
  47. janus/metrics/prompts/hallucination.txt +16 -0
  48. janus/metrics/prompts/quality.txt +8 -0
  49. janus/metrics/prompts/readability.txt +16 -0
  50. janus/metrics/prompts/usefulness.txt +16 -0
  51. janus/parsers/code_parser.py +4 -4
  52. janus/parsers/doc_parser.py +12 -9
  53. janus/parsers/parser.py +7 -0
  54. janus/parsers/partition_parser.py +6 -4
  55. janus/parsers/reqs_parser.py +11 -8
  56. janus/parsers/uml.py +5 -4
  57. janus/prompts/prompt.py +2 -2
  58. janus/prompts/templates/README.md +30 -0
  59. janus/prompts/templates/basic_aggregation/human.txt +6 -0
  60. janus/prompts/templates/basic_aggregation/system.txt +1 -0
  61. janus/prompts/templates/basic_refinement/human.txt +14 -0
  62. janus/prompts/templates/basic_refinement/system.txt +1 -0
  63. janus/prompts/templates/diagram/human.txt +9 -0
  64. janus/prompts/templates/diagram/system.txt +1 -0
  65. janus/prompts/templates/diagram_with_documentation/human.txt +15 -0
  66. janus/prompts/templates/diagram_with_documentation/system.txt +1 -0
  67. janus/prompts/templates/document/human.txt +10 -0
  68. janus/prompts/templates/document/system.txt +1 -0
  69. janus/prompts/templates/document_cloze/human.txt +11 -0
  70. janus/prompts/templates/document_cloze/system.txt +1 -0
  71. janus/prompts/templates/document_cloze/variables.json +4 -0
  72. janus/prompts/templates/document_cloze/variables_asm.json +4 -0
  73. janus/prompts/templates/document_inline/human.txt +13 -0
  74. janus/prompts/templates/eval_prompts/incose/human.txt +32 -0
  75. janus/prompts/templates/eval_prompts/incose/system.txt +1 -0
  76. janus/prompts/templates/eval_prompts/incose/variables.json +3 -0
  77. janus/prompts/templates/eval_prompts/inline_comments/human.txt +49 -0
  78. janus/prompts/templates/eval_prompts/inline_comments/system.txt +1 -0
  79. janus/prompts/templates/eval_prompts/inline_comments/variables.json +3 -0
  80. janus/prompts/templates/micromanaged_mumps_v1.0/human.txt +23 -0
  81. janus/prompts/templates/micromanaged_mumps_v1.0/system.txt +3 -0
  82. janus/prompts/templates/micromanaged_mumps_v2.0/human.txt +28 -0
  83. janus/prompts/templates/micromanaged_mumps_v2.0/system.txt +3 -0
  84. janus/prompts/templates/micromanaged_mumps_v2.1/human.txt +29 -0
  85. janus/prompts/templates/micromanaged_mumps_v2.1/system.txt +3 -0
  86. janus/prompts/templates/multidocument/human.txt +15 -0
  87. janus/prompts/templates/multidocument/system.txt +1 -0
  88. janus/prompts/templates/partition/human.txt +22 -0
  89. janus/prompts/templates/partition/system.txt +1 -0
  90. janus/prompts/templates/partition/variables.json +4 -0
  91. janus/prompts/templates/pseudocode/human.txt +7 -0
  92. janus/prompts/templates/pseudocode/system.txt +7 -0
  93. janus/prompts/templates/refinement/fix_exceptions/human.txt +19 -0
  94. janus/prompts/templates/refinement/fix_exceptions/system.txt +1 -0
  95. janus/prompts/templates/refinement/format/code_format/human.txt +12 -0
  96. janus/prompts/templates/refinement/format/code_format/system.txt +1 -0
  97. janus/prompts/templates/refinement/format/requirements_format/human.txt +14 -0
  98. janus/prompts/templates/refinement/format/requirements_format/system.txt +1 -0
  99. janus/prompts/templates/refinement/hallucination/human.txt +13 -0
  100. janus/prompts/templates/refinement/hallucination/system.txt +1 -0
  101. janus/prompts/templates/refinement/reflection/human.txt +15 -0
  102. janus/prompts/templates/refinement/reflection/incose/human.txt +26 -0
  103. janus/prompts/templates/refinement/reflection/incose/system.txt +1 -0
  104. janus/prompts/templates/refinement/reflection/incose_deduplicate/human.txt +16 -0
  105. janus/prompts/templates/refinement/reflection/incose_deduplicate/system.txt +1 -0
  106. janus/prompts/templates/refinement/reflection/system.txt +1 -0
  107. janus/prompts/templates/refinement/revision/human.txt +16 -0
  108. janus/prompts/templates/refinement/revision/incose/human.txt +16 -0
  109. janus/prompts/templates/refinement/revision/incose/system.txt +1 -0
  110. janus/prompts/templates/refinement/revision/incose_deduplicate/human.txt +17 -0
  111. janus/prompts/templates/refinement/revision/incose_deduplicate/system.txt +1 -0
  112. janus/prompts/templates/refinement/revision/system.txt +1 -0
  113. janus/prompts/templates/refinement/uml/alc_fix_variables/human.txt +15 -0
  114. janus/prompts/templates/refinement/uml/alc_fix_variables/system.txt +2 -0
  115. janus/prompts/templates/refinement/uml/fix_connections/human.txt +15 -0
  116. janus/prompts/templates/refinement/uml/fix_connections/system.txt +2 -0
  117. janus/prompts/templates/requirements/human.txt +13 -0
  118. janus/prompts/templates/requirements/system.txt +2 -0
  119. janus/prompts/templates/retrieval/language_docs/human.txt +10 -0
  120. janus/prompts/templates/retrieval/language_docs/system.txt +1 -0
  121. janus/prompts/templates/simple/human.txt +16 -0
  122. janus/prompts/templates/simple/system.txt +3 -0
  123. janus/refiners/format.py +49 -0
  124. janus/refiners/refiner.py +113 -4
  125. janus/utils/enums.py +127 -112
  126. janus/utils/logger.py +2 -0
  127. {janus_llm-4.3.1.dist-info → janus_llm-4.4.5.dist-info}/METADATA +18 -18
  128. janus_llm-4.4.5.dist-info/RECORD +210 -0
  129. {janus_llm-4.3.1.dist-info → janus_llm-4.4.5.dist-info}/WHEEL +1 -1
  130. janus_llm-4.4.5.dist-info/entry_points.txt +3 -0
  131. janus/cli.py +0 -1488
  132. janus/metrics/_tests/test_llm.py +0 -90
  133. janus/metrics/llm_metrics.py +0 -202
  134. janus_llm-4.3.1.dist-info/RECORD +0 -115
  135. janus_llm-4.3.1.dist-info/entry_points.txt +0 -3
  136. {janus_llm-4.3.1.dist-info → janus_llm-4.4.5.dist-info}/LICENSE +0 -0
@@ -27,7 +27,7 @@ from janus.language.splitter import (
 )
 from janus.llm.model_callbacks import get_model_callback
 from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel, load_model
-from janus.parsers.parser import GenericParser, JanusParser
+from janus.parsers.parser import GenericParser, JanusParser, JanusParserException
 from janus.refiners.refiner import JanusRefiner
 
 # from janus.refiners.refiner import BasicRefiner, Refiner
@@ -76,7 +76,7 @@ class Converter:
         source_language: str = "fortran",
         max_prompts: int = 10,
         max_tokens: int | None = None,
-        prompt_template: str = "simple",
+        prompt_templates: list[str] | str = ["simple"],
         db_path: str | None = None,
         db_config: dict[str, Any] | None = None,
         protected_node_types: tuple[str, ...] = (),
@@ -84,6 +84,10 @@ class Converter:
         splitter_type: str = "file",
         refiner_types: list[type[JanusRefiner]] = [JanusRefiner],
         retriever_type: str | None = None,
+        combine_output: bool = True,
+        use_janus_inputs: bool = False,
+        target_language: str = "json",
+        target_version: str | None = None,
     ) -> None:
         """Initialize a Converter instance.
 
@@ -96,7 +100,7 @@ class Converter:
             max_prompts: The maximum number of prompts to try before giving up.
             max_tokens: The maximum number of tokens to use in the LLM. If `None`, the
                 converter will use half the model's token limit.
-            prompt_template: The name of the prompt template to use.
+            prompt_templates: The name of the prompt templates to use.
             db_path: The path to the database to use for vectorization.
             db_config: The configuration for the database.
            protected_node_types: A set of node types that aren't to be merged.
@@ -111,26 +115,34 @@ class Converter:
                 - "active_usings"
                 - "language_docs"
                 - None
+            combine_output: Whether to combine the output into a single file or not.
+            use_janus_inputs: Whether to use janus inputs or not.
+            target_language: The target programming language.
+            target_version: The target programming language version.
         """
         self._changed_attrs: set = set()
 
         self.max_prompts: int = max_prompts
         self._max_tokens: int | None = max_tokens
         self.override_token_limit: bool = max_tokens is not None
+        self._combine_output = combine_output
 
         self._model_name: str
         self._custom_model_arguments: dict[str, Any]
 
         self._source_language: str
-        self._source_suffix: str
+        self._source_suffixes: list[str]
 
-        self._target_language = "json"
-        self._target_suffix = ".json"
+        self._target_language: str
+        self._target_suffix: str
+        self._target_version: str | None
+        self.set_target_language(target_language, target_version)
+        self._use_janus_inputs = use_janus_inputs
 
         self._protected_node_types: tuple[str, ...] = ()
         self._prune_node_types: tuple[str, ...] = ()
         self._max_tokens: int | None = max_tokens
-        self._prompt_template_name: str
+        self._prompt_template_names: list[str]
         self._db_path: str | None
         self._db_config: dict[str, Any] | None
 
@@ -153,7 +165,7 @@ class Converter:
         self.set_refiner_types(refiner_types=refiner_types)
         self.set_retriever(retriever_type=retriever_type)
         self.set_model(model_name=model, **model_arguments)
-        self.set_prompt(prompt_template=prompt_template)
+        self.set_prompts(prompt_templates=prompt_templates)
         self.set_source_language(source_language)
         self.set_protected_node_types(protected_node_types)
         self.set_prune_node_types(prune_node_types)
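A minimal usage sketch of the new constructor signature shown above. The model name, template names, and target settings below are illustrative placeholders rather than values taken from this diff:

    from janus.converter.converter import Converter

    converter = Converter(
        model="gpt-4o",                           # placeholder model name
        source_language="fortran",
        prompt_templates=["document", "simple"],  # list form replaces prompt_template
        target_language="python",                 # no longer hard-coded to "json"
        target_version="3.10",
        combine_output=False,                     # keep per-child outputs separate
        use_janus_inputs=False,                   # read raw source, not janus JSON
    )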
@@ -174,7 +186,7 @@ class Converter:
 
     def _load_parameters(self) -> None:
         self._load_model()
-        self._load_prompt()
+        self._load_translation_chain()
         self._load_retriever()
         self._load_refiner_chain()
         self._load_splitter()
@@ -195,21 +207,23 @@ class Converter:
         self._model_name = model_name
         self._custom_model_arguments = custom_arguments
 
-    def set_prompt(self, prompt_template: str) -> None:
+    def set_prompts(self, prompt_templates: list[str] | str) -> None:
         """Validate and set the prompt template name.
 
         Arguments:
-            prompt_template: name of prompt template directory
-                (see janus/prompts/templates) or path to a directory.
+            prompt_templates: name of prompt template directories
+                (see janus/prompts/templates) or paths to directories.
         """
-        self._prompt_template_name = prompt_template
+        if isinstance(prompt_templates, str):
+            self._prompt_template_names = [prompt_templates]
+        else:
+            self._prompt_template_names = prompt_templates
 
     def set_splitter(self, splitter_type: str) -> None:
         """Validate and set the prompt template name.
 
         Arguments:
-            prompt_template: name of prompt template directory
-                (see janus/prompts/templates) or path to a directory.
+            splitter_type: the type of splitter to use
         """
         if splitter_type not in CUSTOM_SPLITTERS:
             raise ValueError(f'Splitter type "{splitter_type}" does not exist.')
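Because set_prompts normalizes its argument, a plain string and a one-element list configure the same single template, while a longer list sets up a chained translation. A short sketch (template names are placeholders):

    # Equivalent: both leave _prompt_template_names == ["simple"]
    converter.set_prompts("simple")
    converter.set_prompts(["simple"])

    # Several names build a chained pipeline; each later prompt receives the
    # previous completion (see _load_translation_chain below).
    converter.set_prompts(["document", "simple"])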
@@ -245,8 +259,10 @@ class Converter:
                 "Valid source languages are found in `janus.utils.enums.LANGUAGES`."
             )
 
-        ext = LANGUAGES[source_language]["suffix"]
-        self._source_suffix = f".{ext}"
+        self._source_suffixes = [
+            f".{ext}" for ext in LANGUAGES[source_language]["suffixes"]
+        ]
+
         self._source_language = source_language
 
     def set_protected_node_types(self, protected_node_types: tuple[str, ...]) -> None:
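Because a source language now maps to a list of suffixes, input discovery (see the translate() hunk further down) globs once per extension. A small sketch; the Fortran suffix list and directory name are assumptions, not values read from janus.utils.enums.LANGUAGES:

    from pathlib import Path

    source_suffixes = [".f", ".f90"]  # assumed example suffixes
    input_paths = []
    for ext in source_suffixes:
        input_paths.extend(Path("legacy_src").rglob(f"**/*{ext}"))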
@@ -324,28 +340,48 @@ class Converter:
         # tokens at output
         # Only modify max_tokens if it is not specified by user
         if not self.override_token_limit:
-            self._max_tokens = int(token_limit // 2.5)
+            self._max_tokens = int(token_limit * self._llm.input_token_proportion)
 
-    @run_if_changed(
-        "_prompt_template_name",
-        "_source_language",
-        "_model_name",
-        "_parser",
-    )
-    def _load_prompt(self) -> None:
-        """Load the prompt according to this instance's attributes.
-
-        If the relevant fields have not been changed since the last time this
-        method was called, nothing happens.
-        """
+    @run_if_changed("_prompt_template_names", "_source_language", "_model_name")
+    def _load_translation_chain(self) -> None:
+        prompt_template_name = self._prompt_template_names[0]
         prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
             source_language=self._source_language,
-            prompt_template=self._prompt_template_name,
+            prompt_template=prompt_template_name,
+            target_language=self._target_language,
+            target_version=self._target_version,
         )
-        self._prompt = prompt_engine.prompt
-        self._prompt = self._prompt.partial(
-            format_instructions=self._parser.get_format_instructions()
+        prompt = prompt_engine.prompt
+        self._translation_chain = RunnableParallel(
+            prompt_value=lambda x, prompt=prompt: prompt.invoke(x),
+            original_inputs=RunnablePassthrough(),
+        ) | RunnableParallel(
+            completion=lambda x: self._llm.invoke(x["prompt_value"]),
+            original_inputs=lambda x: x["original_inputs"],
+            prompt_value=lambda x: x["prompt_value"],
         )
+        for prompt_template_name in self._prompt_template_names[1:]:
+            prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
+                source_language=self._source_language,
+                prompt_template=prompt_template_name,
+                target_language=self._target_language,
+                target_version=self._target_version,
+            )
+            prompt = prompt_engine.prompt
+            self._translation_chain = (
+                self._translation_chain
+                | RunnableParallel(
+                    prompt_value=lambda x, prompt=prompt: prompt.invoke(
+                        dict(completion=x["completion"], **x["original_inputs"])
+                    ),
+                    original_inputs=lambda x: x["original_inputs"],
+                )
+                | RunnableParallel(
+                    completion=lambda x: self._llm.invoke(x["prompt_value"]),
+                    original_inputs=lambda x: x["original_inputs"],
+                    prompt_value=lambda x: x["prompt_value"],
+                )
+            )
 
     @run_if_changed("_db_path", "_db_config")
     def _load_vectorizer(self) -> None:
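The chain above threads two values through every stage: the latest completion and the untouched original inputs. A standalone sketch of that data flow (this is not janus code; the prompts and the LLM are stand-in lambdas):

    from langchain_core.runnables import RunnableParallel, RunnablePassthrough

    fake_llm = lambda prompt_value: f"LLM({prompt_value})"
    prompt_1 = lambda x: f"translate: {x['code']}"
    prompt_2 = lambda x: f"refine: {x['completion']} (source: {x['code']})"

    chain = RunnableParallel(
        prompt_value=prompt_1,
        original_inputs=RunnablePassthrough(),
    ) | RunnableParallel(
        completion=lambda x: fake_llm(x["prompt_value"]),
        original_inputs=lambda x: x["original_inputs"],
        prompt_value=lambda x: x["prompt_value"],
    ) | RunnableParallel(
        # The second prompt sees the first completion plus the original inputs,
        # mirroring dict(completion=..., **x["original_inputs"]) above.
        prompt_value=lambda x: prompt_2(
            dict(completion=x["completion"], **x["original_inputs"])
        ),
        original_inputs=lambda x: x["original_inputs"],
    ) | RunnableParallel(
        completion=lambda x: fake_llm(x["prompt_value"]),
        original_inputs=lambda x: x["original_inputs"],
        prompt_value=lambda x: x["prompt_value"],
    )

    print(chain.invoke({"code": "PRINT *, 'HI'"})["completion"])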
@@ -368,11 +404,31 @@ class Converter:
 
     @run_if_changed("_refiner_types", "_model_name", "max_prompts", "_parser")
     def _load_refiner_chain(self) -> None:
-        self._refiner_chain = RunnableParallel(
-            completion=self._llm,
-            prompt_value=RunnablePassthrough(),
-        )
-        for refiner_type in self._refiner_types[:-1]:
+        if len(self._refiner_types) == 0:
+            self._refiner_chain = RunnableLambda(
+                lambda x: self._parser.parse(x["completion"])
+            )
+            return
+        refiner_type = self._refiner_types[0]
+        if len(self._refiner_types) == 1:
+            self._refiner_chain = RunnableLambda(
+                lambda x, refiner_type=refiner_type: refiner_type(
+                    llm=self._llm,
+                    parser=self._parser,
+                    max_retries=self.max_prompts,
+                ).parse_completion(**x)
+            )
+            return
+        else:
+            self._refiner_chain = RunnableParallel(
+                completion=lambda x, refiner_type=refiner_type: refiner_type(
+                    llm=self._llm,
+                    parser=self._base_parser,
+                    max_retries=self.max_prompts,
+                ).parse_completion(**x),
+                prompt_value=lambda x: x["prompt_value"],
+            )
+        for refiner_type in self._refiner_types[1:-1]:
             # NOTE: Do NOT remove refiner_type=refiner_type from lambda.
             # Due to lambda capture, must be present or chain will not
             # be correctly constructed.
@@ -394,7 +450,7 @@ class Converter:
 
     @run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner_chain")
     def _load_chain(self):
-        self.chain = self._input_runnable() | self._prompt | self._refiner_chain
+        self.chain = self.get_chain()
 
     def _input_runnable(self) -> Runnable:
         return RunnableParallel(
@@ -402,10 +458,17 @@ class Converter:
             context=self._retriever,
         )
 
+    def get_chain(self) -> Runnable:
+        """
+        Gets a chain that can be executed by langchain
+        """
+        return self._input_runnable() | self._translation_chain | self._refiner_chain
+
     def translate(
         self,
         input_directory: str | Path,
         output_directory: str | Path | None = None,
+        failure_directory: str | Path | None = None,
         overwrite: bool = False,
         collection_name: str | None = None,
     ) -> None:
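A hedged sketch of the extended translate() call; the directory names are placeholders. Files whose translation fails are written under failure_directory rather than output_directory:

    converter.translate(
        input_directory="legacy_src",
        output_directory="translated",
        failure_directory="failed",
        overwrite=False,
    )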
@@ -423,92 +486,91 @@ class Converter:
             input_directory = Path(input_directory)
         if isinstance(output_directory, str):
             output_directory = Path(output_directory)
+        if isinstance(failure_directory, str):
+            failure_directory = Path(failure_directory)
 
         # Make sure the output directory exists
         if output_directory is not None and not output_directory.exists():
             output_directory.mkdir(parents=True)
+        if failure_directory is not None and not failure_directory.exists():
+            failure_directory.mkdir(parents=True)
 
-        input_paths = [p for p in input_directory.rglob(f"**/*{self._source_suffix}")]
+        input_paths = []
+        if self._use_janus_inputs:
+            source_language = "janus"
+            source_suffixes = [".json"]
+        else:
+            source_language = self._source_language
+            source_suffixes = self._source_suffixes
+        for ext in source_suffixes:
+            input_paths.extend(input_directory.rglob(f"**/*{ext}"))
 
         log.info(f"Input directory: {input_directory.absolute()}")
-        log.info(
-            f"{self._source_language} '*{self._source_suffix}' files: "
-            f"{len(input_paths)}"
-        )
+        log.info(f"{source_language} {source_suffixes} files: " f"{len(input_paths)}")
         log.info(
             "Other files (skipped): "
             f"{len(list(input_directory.iterdir())) - len(input_paths)}\n"
         )
         if output_directory is not None:
             output_paths = [
-                output_directory
-                / p.relative_to(input_directory).with_suffix(self._target_suffix)
+                output_directory / p.relative_to(input_directory).with_suffix(".json")
                 for p in input_paths
             ]
-            in_out_pairs = list(zip(input_paths, output_paths))
-            if not overwrite:
-                n_files = len(in_out_pairs)
-                in_out_pairs = [
-                    (inp, outp) for inp, outp in in_out_pairs if not outp.exists()
-                ]
-                log.info(
-                    f"Skipping {n_files - len(in_out_pairs)} existing "
-                    f"'*{self._source_suffix}' files"
-                )
         else:
-            in_out_pairs = [(f, None) for f in input_paths]
-        log.info(f"Translating {len(in_out_pairs)} '*{self._source_suffix}' files")
+            output_paths = [None for _ in input_paths]
+
+        if failure_directory is not None:
+            failure_paths = [
+                failure_directory / p.relative_to(input_directory).with_suffix(".json")
+                for p in input_paths
+            ]
+        else:
+            failure_paths = [None for _ in input_paths]
+        in_out_pairs = list(zip(input_paths, output_paths, failure_paths))
+        if not overwrite:
+            n_files = len(in_out_pairs)
+            in_out_pairs = [
+                (inp, outp, failp)
+                for inp, outp, failp in in_out_pairs
+                if outp is None or not outp.exists()
+            ]
+            log.info(
+                f"Skipping {n_files - len(in_out_pairs)} existing "
+                f"{self._source_suffixes} files"
+            )
+        log.info(f"Translating {len(in_out_pairs)} {self._source_suffixes} files")
 
         # Loop through each input file, convert and save it
         total_cost = 0.0
-        for in_path, out_path in in_out_pairs:
+        for in_path, out_path, fail_path in in_out_pairs:
             # Translate the file, skip it if there's a rate limit error
-            try:
-                log.info(f"Processing {in_path.relative_to(input_directory)}")
-                out_block = self.translate_file(in_path)
-                total_cost += out_block.total_cost
-            except RateLimitError:
-                continue
-            except OutputParserException as e:
-                log.error(f"Skipping {in_path.name}, failed to parse output: {e}.")
-                continue
-            except BadRequestError as e:
-                if str(e).startswith("Detected an error in the prompt"):
-                    log.warning("Malformed input, skipping")
-                    continue
-                raise e
-            except ValidationError as e:
-                # Only allow ValidationError to pass if token limit is manually set
-                if self.override_token_limit:
-                    log.warning(
-                        "Current file and manually set token "
-                        "limit is too large for this model, skipping"
-                    )
-                    continue
-                raise e
-            except TokenLimitError:
-                log.warning("Ran into irreducible node too large for context, skipping")
-                continue
-            except EmptyTreeError:
-                log.warning(
-                    f'Input file "{in_path.name}" has no nodes of interest, skipping'
-                )
-                continue
-            except FileSizeError:
-                log.warning("Current tile is too large for basic splitter, skipping")
-                continue
-            except ValueError as e:
-                if str(e).startswith(
-                    "Error raised by bedrock service"
-                ) and "maximum context length" in str(e):
-                    log.warning(
-                        "Input is too large for this model's context length, skipping"
-                    )
-                    continue
-                raise e
+            log.info(f"Processing {in_path.relative_to(input_directory)}")
+            if self._use_janus_inputs:
+                out_block = self.translate_janus_file(in_path, fail_path)
+            else:
+                out_block = self.translate_file(in_path, fail_path)
+
+            def _get_total_cost(block):
+                if isinstance(block, list):
+                    return sum(_get_total_cost(b) for b in block)
+                return block.total_cost
+
+            total_cost += _get_total_cost(out_block)
+            log.info(f"Current Running Cost: {total_cost}")
+
+            # For files where translation failed, write to failure path instead
 
-            # Don't attempt to write files for which translation failed
-            if not out_block.translated:
+            def _has_empty(block):
+                if isinstance(block, list):
+                    return len(block) == 0 or any(_has_empty(b) for b in block)
+                return not block.translated
+
+            while isinstance(out_block, list) and len(out_block) == 1:
+                out_block = out_block[0]
+
+            if _has_empty(out_block):
+                if fail_path is not None:
+                    self._save_to_file(out_block, fail_path)
                 continue
 
             if collection_name is not None:
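Judging from _get_output_obj and _janus_object_to_codeblock later in this diff, the janus-format JSON consumed when use_janus_inputs=True looks roughly like the structure below; the field values are illustrative only:

    janus_obj = {
        "input": "      PRINT *, 'HELLO'",
        "metadata": {
            "cost": 0.0123,
            "processing_time": 4.2,
            "num_requests": 1,
            "input_tokens": 250,
            "output_tokens": 120,
            "converter_name": "Translator",
        },
        "outputs": [
            "print('HELLO')",  # leaf strings are re-split and fed back in
            # nested objects of this same shape may appear here as well
        ],
    }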
@@ -526,47 +588,83 @@ class Converter:
 
         log.info(f"Total cost: ${total_cost:,.2f}")
 
-    def translate_file(self, file: Path) -> TranslatedCodeBlock:
-        """Translate a single file.
-
-        Arguments:
-            file: Input path to file
-
-        Returns:
-            A `TranslatedCodeBlock` object. This block does not have a path set, and its
-            code is not guaranteed to be consolidated. To amend this, run
-            `Combiner.combine_children` on the block.
-        """
+    def translate_block(
+        self,
+        input_block: CodeBlock | list[CodeBlock],
+        name: str,
+        failure_path: Path | None = None,
+    ):
         self._load_parameters()
-        filename = file.name
-
-        input_block = self._split_file(file)
+        if isinstance(input_block, list):
+            return [self.translate_block(b, name, failure_path) for b in input_block]
         t0 = time.time()
-        output_block = self._iterative_translate(input_block)
+        output_block = self._iterative_translate(input_block, failure_path)
         output_block.processing_time = time.time() - t0
         if output_block.translated:
             completeness = output_block.translation_completeness
             log.info(
-                f"[{filename}] Translation complete\n"
+                f"[{name}] Translation complete\n"
                 f" {completeness:.2%} of input successfully translated\n"
                 f" Total cost: ${output_block.total_cost:,.2f}\n"
-                f" Total retries: {output_block.total_retries:,d}\n"
                 f" Output CodeBlock Structure:\n{input_block.tree_str()}\n"
             )
 
         else:
             log.error(
-                f"[{filename}] Translation failed\n"
+                f"[{name}] Translation failed\n"
                 f" Total cost: ${output_block.total_cost:,.2f}\n"
-                f" Total retries: {output_block.total_retries:,d}\n"
             )
         return output_block
 
-    def _iterative_translate(self, root: CodeBlock) -> TranslatedCodeBlock:
+    def translate_file(
+        self,
+        file: Path,
+        failure_path: Path | None = None,
+    ) -> TranslatedCodeBlock:
+        """Translate a single file.
+
+        Arguments:
+            file: Input path to file
+            failure_path: path to directory to store failure summaries`
+
+        Returns:
+            A `TranslatedCodeBlock` object. This block does not have a path set, and its
+            code is not guaranteed to be consolidated. To amend this, run
+            `Combiner.combine_children` on the block.
+        """
+        filename = file.name
+        input_block = self._split_file(file)
+        return self.translate_block(input_block, filename, failure_path)
+
+    def translate_janus_file(self, file: Path, failure_path: Path | None = None):
+        filename = file.name
+        with open(file, "r") as f:
+            file_obj = json.load(f)
+        return self.translate_janus_obj(file_obj, filename, failure_path)
+
+    def translate_janus_obj(self, obj: Any, name: str, failure_path: Path | None = None):
+        block = self._janus_object_to_codeblock(obj, name)
+        return self.translate_block(block)
+
+    def translate_text(self, text: str, name: str, failure_path: Path | None = None):
+        """
+        Translates given text
+        Arguments:
+            text: text to translate
+            name: the name of the text (filename if from a file)
+            failure_path: path to write failure file if translation is not successful
+        """
+        input_block = self._split_text(text, name)
+        return self.translate_block(input_block, name, failure_path)
+
+    def _iterative_translate(
+        self, root: CodeBlock, failure_path: Path | None = None
+    ) -> TranslatedCodeBlock:
         """Translate the passed CodeBlock representing a full file.
 
         Arguments:
             root: A root block representing the top-level block of a file
+            failure_path: path to store data files for failed translations
 
         Returns:
             A `TranslatedCodeBlock`
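Alongside translate_file, the new entry points can be driven directly. A usage sketch (the path, file name, and source snippet are made up):

    from pathlib import Path

    block = converter.translate_text(
        text="      PRINT *, 'HELLO'",
        name="hello.f",                          # used for logging and block naming
        failure_path=Path("failed/hello.json"),
    )
    print(block.translated, block.total_cost)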
@@ -574,22 +672,59 @@ class Converter:
         translated_root = TranslatedCodeBlock(root, self._target_language)
         last_prog, prog_delta = 0, 0.1
         stack = [translated_root]
-        while stack:
-            translated_block = stack.pop()
-
-            self._add_translation(translated_block)
+        try:
+            while stack:
+                translated_block = stack.pop()
 
-            # If translating this block was unsuccessful, don't bother with its
-            # children (they wouldn't show up in the final text anyway)
-            if not translated_block.translated:
-                continue
+                self._add_translation(translated_block)
 
-            stack.extend(translated_block.children)
+                # If translating this block was unsuccessful, don't bother with its
+                # children (they wouldn't show up in the final text anyway)
+                if not translated_block.translated:
+                    continue
 
-            progress = translated_root.translation_completeness
-            if progress - last_prog > prog_delta:
-                last_prog = int(progress / prog_delta) * prog_delta
-                log.info(f"[{root.name}] progress: {progress:.2%}")
+                stack.extend(translated_block.children)
+
+                progress = translated_root.translation_completeness
+                if progress - last_prog > prog_delta:
+                    last_prog = int(progress / prog_delta) * prog_delta
+                    log.info(f"[{root.name}] progress: {progress:.2%}")
+        except RateLimitError:
+            pass
+        except OutputParserException as e:
+            log.error(f"Skipping file, failed to parse output: {e}.")
+        except BadRequestError as e:
+            if str(e).startswith("Detected an error in the prompt"):
+                log.warning("Malformed input, skipping")
+            raise e
+        except ValidationError as e:
+            # Only allow ValidationError to pass if token limit is manually set
+            if self.override_token_limit:
+                log.warning(
+                    "Current file and manually set token "
+                    "limit is too large for this model, skipping"
+                )
+            raise e
+        except TokenLimitError:
+            log.warning("Ran into irreducible node too large for context, skipping")
+        except EmptyTreeError:
+            log.warning("Input file has no nodes of interest, skipping")
+        except FileSizeError:
+            log.warning("Current tile is too large for basic splitter, skipping")
+        except ValueError as e:
+            if str(e).startswith(
+                "Error raised by bedrock service"
+            ) and "maximum context length" in str(e):
+                log.warning(
+                    "Input is too large for this model's context length, skipping"
+                )
+            raise e
+        finally:
+            out_obj = self._get_output_obj(translated_root, self._combine_output)
+            log.debug(f"Resulting Block:" f"{json.dumps(out_obj)}")
+            if not translated_root.translated:
+                if failure_path is not None:
+                    self._save_to_file(translated_root, failure_path)
 
         return translated_root
 
@@ -624,17 +759,35 @@ class Converter:
         # TODO: If non-OpenAI models with prices are added, this will need
         # to be updated.
         with get_model_callback() as cb:
-            t0 = time.time()
-            block.text = self._run_chain(block)
-            block.processing_time = time.time() - t0
-            block.cost = cb.total_cost
-            block.retries = max(0, cb.successful_requests - 1)
+            try:
+                t0 = time.time()
+                block.text = self._run_chain(block)
+            except JanusParserException as e:
+                block.text = e.unparsed_output
+                block.tokens = self._llm.get_num_tokens(block.text)
+                raise e
+            finally:
+                block.processing_time = time.time() - t0
+                block.cost = cb.total_cost
+                block.request_input_tokens = cb.prompt_tokens
+                block.request_output_tokens = cb.completion_tokens
+                block.num_requests = cb.successful_requests
 
         block.tokens = self._llm.get_num_tokens(block.text)
         block.translated = True
 
         log.debug(f"[{block.name}] Output code:\n{block.text}")
 
+    def _split_text(self, text: str, name: str) -> CodeBlock:
+        log.info(f"[{name}] Splitting text")
+        root = self._splitter.split_string(text, name)
+        log.info(
+            f"[{name}] Text split into {root.n_descendents:,} blocks,"
+            f"tree of height {root.height}"
+        )
+        log.info(f"[{name}] Input CodeBlock Structure:\n{root.tree_str()}")
+        return root
+
     def _split_file(self, file: Path) -> CodeBlock:
         filename = file.name
         log.info(f"[{filename}] Splitting file")
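A standalone illustration of the accounting pattern in _add_translation above (this is not janus code): the finally block records metadata even when parsing fails, and the raw completion survives on the exception:

    class FakeParserException(Exception):
        def __init__(self, unparsed_output: str):
            super().__init__("failed to parse completion")
            self.unparsed_output = unparsed_output

    def run_chain() -> str:
        raise FakeParserException("raw, unparseable model output")

    text = ""
    try:
        try:
            text = run_chain()
        except FakeParserException as e:
            text = e.unparsed_output  # keep the raw text for the failure file
            raise
        finally:
            print(f"metadata recorded; text so far: {text!r}")
    except FakeParserException:
        print("exception propagates up to _iterative_translate, as in the diff")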
@@ -649,33 +802,171 @@ class Converter:
     def _run_chain(self, block: TranslatedCodeBlock) -> str:
         return self.chain.invoke(block.original)
 
+    def _combine_metadata(self, metadatas: list[dict]):
+        return dict(
+            cost=sum(m["cost"] for m in metadatas),
+            processing_time=sum(m["processing_time"] for m in metadatas),
+            num_requests=sum(m["num_requests"] for m in metadatas),
+            input_tokens=sum(m["input_tokens"] for m in metadatas),
+            output_tokens=sum(m["output_tokens"] for m in metadatas),
+            converter_name=self.__class__.__name__,
+        )
+
+    def _combine_inputs(self, inputs: list[str]):
+        s = ""
+        for i in inputs:
+            s += i
+        return s
+
     def _get_output_obj(
-        self, block: TranslatedCodeBlock
+        self, block: TranslatedCodeBlock | list, combine_children: bool = True
     ) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
-        output_str = self._parser.parse_combined_output(block.complete_text)
-
+        if isinstance(block, list):
+            # TODO: run on all items in list
+            outputs = [self._get_output_obj(b, combine_children) for b in block]
+            metadata = self._combine_metadata([o["metadata"] for o in outputs])
+            input_agg = self._combine_inputs(o["input"] for o in outputs)
+            return dict(
+                input=input_agg,
+                metadata=metadata,
+                outputs=outputs,
+            )
+        if not combine_children and len(block.children) > 0:
+            outputs = self._get_output_obj_children(block)
+            metadata = self._combine_metadata([o["metadata"] for o in outputs])
+            input_agg = self._combine_inputs(o["input"] for o in outputs)
+            return dict(
+                input=input_agg,
+                metadata=metadata,
+                outputs=outputs,
+            )
         output_obj: str | dict[str, str]
-        try:
-            output_obj = json.loads(output_str)
-        except json.JSONDecodeError:
-            output_obj = output_str
+        if not block.translation_completed:
+            # translation wasn't completed, so combined parsing will likely fail
+            output_obj = [block.complete_text]
+        else:
+            output_str = self._parser.parse_combined_output(block.complete_text)
+            output_obj = [output_str]
 
         return dict(
             input=block.original.text or "",
             metadata=dict(
-                retries=block.total_retries,
                 cost=block.total_cost,
                 processing_time=block.processing_time,
+                num_requests=block.total_num_requests,
+                input_tokens=block.total_request_input_tokens,
+                output_tokens=block.total_request_output_tokens,
+                converter_name=self.__class__.__name__,
             ),
-            output=output_obj,
+            outputs=output_obj,
         )
 
+    def _get_output_obj_children(self, block: TranslatedCodeBlock):
+        if len(block.children) > 0:
+            res = []
+            for c in block.children:
+                res += self._get_output_obj_children(c)
+            return res
+        else:
+            return [self._get_output_obj(block, combine_children=True)]
+
     def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
         """Save a file to disk.
 
         Arguments:
             block: The `TranslatedCodeBlock` to save to a file.
         """
-        obj = self._get_output_obj(block)
+        obj = self._get_output_obj(block, combine_children=self._combine_output)
         out_path.parent.mkdir(parents=True, exist_ok=True)
         out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
+
+    def _janus_object_to_codeblock(self, janus_obj: dict, name: str):
+        results = []
+        for o in janus_obj["outputs"]:
+            if isinstance(o, str):
+                code_block = self._split_text(o, name)
+                meta_data = janus_obj["metadata"]
+                code_block.initial_cost = meta_data["cost"]
+                code_block.initial_input_tokens = meta_data["input_tokens"]
+                code_block.initial_output_tokens = meta_data["output_tokens"]
+                code_block.initial_num_requests = meta_data["num_requests"]
+                code_block.initial_processing_time = meta_data["processing_time"]
+                code_block.previous_generations = janus_obj.get(
+                    "intermediate_outputs", []
+                ) + [janus_obj]
+                results.append(code_block)
+            else:
+                results.append(self._janus_object_to_codeblock(o))
+        while isinstance(results, list) and len(results) == 1:
+            results = results[0]
+        return results
+
+    def __or__(self, other: "Converter"):
+        from janus.converter.chain import ConverterChain
+
+        return ConverterChain(self, other)
+
+    @property
+    def source_language(self):
+        return self._source_language
+
+    @property
+    def target_language(self):
+        return self._target_language
+
+    @property
+    def target_version(self):
+        return self._target_version
+
+    def set_target_language(
+        self, target_language: str, target_version: str | None
+    ) -> None:
+        """Validate and set the target language.
+
+        The affected objects will not be updated until translate() is called.
+
+        Arguments:
+            target_language: The target programming language.
+            target_version: The target version of the target programming language.
+        """
+        target_language = target_language.lower()
+        if target_language not in LANGUAGES:
+            raise ValueError(
+                f"Invalid target language: {target_language}. "
+                "Valid target languages are found in `janus.utils.enums.LANGUAGES`."
+            )
+        self._target_language = target_language
+        self._target_version = target_version
+        # Taking the first suffix as the default for output files
+        self._target_suffix = f".{LANGUAGES[target_language]['suffixes'][0]}"
+
+    @classmethod
+    def eval_obj(cls, target, metric_func, *args, **kwargs):
+        if "reference" in kwargs:
+            return cls.eval_obj_reference(target, metric_func, *args, **kwargs)
+        else:
+            return cls.eval_obj_noreference(target, metric_func, *args, **kwargs)
+
+    @classmethod
+    def eval_obj_noreference(cls, target, metric_func, *args, **kwargs):
+        results = []
+        for o in target["outputs"]:
+            if isinstance(o, dict):
+                results += cls.eval_obj_noreference(o, metric_func, *args, **kwargs)
+            else:
+                results.append(metric_func(o, *args, **kwargs))
+        return results
+
+    @classmethod
+    def eval_obj_reference(cls, target, metric_func, reference, *args, **kwargs):
+        results = []
+        for o, r in zip(target["outputs"], reference["outputs"]):
+            if isinstance(o, dict):
+                if not isinstance(r, dict):
+                    raise ValueError("Error: format of reference doesn't match target")
+                results += cls.eval_obj_reference(o, metric_func, r, *args, **kwargs)
+            else:
+                if isinstance(r, dict):
+                    raise ValueError("Error: format of reference doesn't match target")
+                results.append(metric_func(o, r, *args, **kwargs))
+        return results
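The new __or__ operator and the eval_obj helpers round out the release. A hedged sketch of how they might be used; the model and template names are placeholders and the metric is a trivial stand-in:

    from janus.converter.converter import Converter

    documenter = Converter(
        model="gpt-4o", source_language="fortran", prompt_templates="document"
    )
    translator = Converter(
        model="gpt-4o",
        source_language="fortran",
        prompt_templates="simple",
        target_language="python",
    )

    # __or__ wires the two converters into a ConverterChain
    # (janus/converter/chain.py, added in this release).
    pipeline = documenter | translator

    # eval_obj walks the nested "outputs" lists of a janus-format result and
    # applies a metric function to every leaf string.
    word_count = lambda text: len(text.split())
    scores = Converter.eval_obj({"outputs": ["print('HI')", "x = 1"]}, word_count)
    print(scores)  # [1, 3]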