janus-llm 4.3.5__py3-none-any.whl → 4.5.4__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. janus/__init__.py +1 -1
  2. janus/cli/aggregate.py +2 -2
  3. janus/cli/cli.py +6 -0
  4. janus/cli/constants.py +6 -0
  5. janus/cli/diagram.py +36 -7
  6. janus/cli/document.py +10 -1
  7. janus/cli/llm.py +7 -3
  8. janus/cli/partition.py +10 -1
  9. janus/cli/pipeline.py +126 -0
  10. janus/cli/self_eval.py +10 -3
  11. janus/cli/translate.py +10 -1
  12. janus/converter/__init__.py +2 -0
  13. janus/converter/_tests/test_translate.py +6 -5
  14. janus/converter/chain.py +100 -0
  15. janus/converter/converter.py +467 -90
  16. janus/converter/diagram.py +12 -8
  17. janus/converter/document.py +17 -7
  18. janus/converter/evaluate.py +174 -147
  19. janus/converter/partition.py +6 -11
  20. janus/converter/passthrough.py +29 -0
  21. janus/converter/pool.py +74 -0
  22. janus/converter/requirements.py +7 -40
  23. janus/converter/translate.py +2 -58
  24. janus/language/_tests/test_combine.py +1 -0
  25. janus/language/block.py +115 -5
  26. janus/llm/model_callbacks.py +6 -0
  27. janus/llm/models_info.py +19 -0
  28. janus/metrics/_tests/test_reading.py +48 -4
  29. janus/metrics/_tests/test_rouge_score.py +5 -11
  30. janus/metrics/metric.py +47 -124
  31. janus/metrics/reading.py +48 -28
  32. janus/metrics/rouge_score.py +21 -34
  33. janus/parsers/_tests/test_code_parser.py +1 -1
  34. janus/parsers/code_parser.py +2 -2
  35. janus/parsers/eval_parsers/incose_parser.py +3 -3
  36. janus/parsers/reqs_parser.py +3 -3
  37. janus/prompts/templates/cyclic/human.txt +16 -0
  38. janus/prompts/templates/cyclic/system.txt +1 -0
  39. janus/prompts/templates/eval_prompts/incose/human.txt +1 -1
  40. janus/prompts/templates/extract_variables/human.txt +5 -0
  41. janus/prompts/templates/extract_variables/system.txt +1 -0
  42. {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/METADATA +14 -15
  43. {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/RECORD +46 -40
  44. {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/WHEEL +1 -1
  45. janus/metrics/_tests/test_llm.py +0 -90
  46. janus/metrics/llm_metrics.py +0 -202
  47. {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/LICENSE +0 -0
  48. {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/entry_points.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  import functools
2
2
  import json
3
3
  import time
4
+ from copy import deepcopy
4
5
  from pathlib import Path
5
6
  from typing import Any
6
7
 
@@ -16,7 +17,7 @@ from openai import BadRequestError, RateLimitError
16
17
  from pydantic import ValidationError
17
18
 
18
19
  from janus.embedding.vectorize import ChromaDBVectorizer
19
- from janus.language.block import CodeBlock, TranslatedCodeBlock
20
+ from janus.language.block import BlockCollection, CodeBlock, TranslatedCodeBlock
20
21
  from janus.language.combine import Combiner
21
22
  from janus.language.naive.registry import CUSTOM_SPLITTERS
22
23
  from janus.language.splitter import (
@@ -76,14 +77,22 @@ class Converter:
76
77
  source_language: str = "fortran",
77
78
  max_prompts: int = 10,
78
79
  max_tokens: int | None = None,
79
- prompt_template: str = "simple",
80
+ prompt_templates: list[str] | str = ["simple"],
80
81
  db_path: str | None = None,
81
82
  db_config: dict[str, Any] | None = None,
82
83
  protected_node_types: tuple[str, ...] = (),
83
84
  prune_node_types: tuple[str, ...] = (),
84
85
  splitter_type: str = "file",
85
- refiner_types: list[type[JanusRefiner]] = [JanusRefiner],
86
+ refiner_types: list[type[JanusRefiner] | str] = [JanusRefiner],
86
87
  retriever_type: str | None = None,
88
+ combine_output: bool = True,
89
+ use_janus_inputs: bool = False,
90
+ target_language: str = "json",
91
+ target_version: str | None = None,
92
+ input_types: set[str] | str | None = None,
93
+ input_labels: set[str] | str | None = None,
94
+ output_type: str | None = None,
95
+ output_label: str | None = None,
87
96
  ) -> None:
88
97
  """Initialize a Converter instance.
89
98
 
@@ -96,7 +105,7 @@ class Converter:
96
105
  max_prompts: The maximum number of prompts to try before giving up.
97
106
  max_tokens: The maximum number of tokens to use in the LLM. If `None`, the
98
107
  converter will use half the model's token limit.
99
- prompt_template: The name of the prompt template to use.
108
+ prompt_templates: The name of the prompt templates to use.
100
109
  db_path: The path to the database to use for vectorization.
101
110
  db_config: The configuration for the database.
102
111
  protected_node_types: A set of node types that aren't to be merged.
@@ -111,12 +120,21 @@ class Converter:
111
120
  - "active_usings"
112
121
  - "language_docs"
113
122
  - None
123
+ combine_output: Whether to combine the output into a single file or not.
124
+ use_janus_inputs: Whether to use janus inputs or not.
125
+ target_language: The target programming language.
126
+ target_version: The target programming language version.
127
+ input_types: The types of input to accept.
128
+ input_labels: The labels of input to accept.
129
+ output_type: The type of output to produce.
130
+ output_label: The label of output to produce.
114
131
  """
115
132
  self._changed_attrs: set = set()
116
133
 
117
134
  self.max_prompts: int = max_prompts
118
135
  self._max_tokens: int | None = max_tokens
119
136
  self.override_token_limit: bool = max_tokens is not None
137
+ self._combine_output = combine_output
120
138
 
121
139
  self._model_name: str
122
140
  self._custom_model_arguments: dict[str, Any]
@@ -124,13 +142,16 @@ class Converter:
124
142
  self._source_language: str
125
143
  self._source_suffixes: list[str]
126
144
 
127
- self._target_language = "json"
128
- self._target_suffix = ".json"
145
+ self._target_language: str
146
+ self._target_suffix: str
147
+ self._target_version: str | None
148
+ self.set_target_language(target_language, target_version)
149
+ self._use_janus_inputs = use_janus_inputs
129
150
 
130
151
  self._protected_node_types: tuple[str, ...] = ()
131
152
  self._prune_node_types: tuple[str, ...] = ()
132
153
  self._max_tokens: int | None = max_tokens
133
- self._prompt_template_name: str
154
+ self._prompt_template_names: list[str]
134
155
  self._db_path: str | None
135
156
  self._db_config: dict[str, Any] | None
136
157
 
@@ -142,7 +163,7 @@ class Converter:
142
163
  self._combiner: Combiner = Combiner()
143
164
 
144
165
  self._splitter_type: str
145
- self._refiner_types: list[type[JanusRefiner]]
166
+ self._refiner_types: list[type[JanusRefiner] | str]
146
167
  self._retriever_type: str | None
147
168
 
148
169
  self._splitter: Splitter
@@ -153,13 +174,20 @@ class Converter:
153
174
  self.set_refiner_types(refiner_types=refiner_types)
154
175
  self.set_retriever(retriever_type=retriever_type)
155
176
  self.set_model(model_name=model, **model_arguments)
156
- self.set_prompt(prompt_template=prompt_template)
177
+ self.set_prompts(prompt_templates=prompt_templates)
157
178
  self.set_source_language(source_language)
158
179
  self.set_protected_node_types(protected_node_types)
159
180
  self.set_prune_node_types(prune_node_types)
160
181
  self.set_db_path(db_path=db_path)
161
182
  self.set_db_config(db_config=db_config)
162
183
 
184
+ self._input_types = input_types
185
+ self._input_labels = input_labels
186
+ self._output_type = output_type
187
+ self._output_label = output_label
188
+
189
+ self._load_parameters()
190
+
163
191
  # Child class must call this. Should we enforce somehow?
164
192
  # self._load_parameters()
165
193
 
@@ -174,7 +202,7 @@ class Converter:
174
202
 
175
203
  def _load_parameters(self) -> None:
176
204
  self._load_model()
177
- self._load_prompt()
205
+ self._load_translation_chain()
178
206
  self._load_retriever()
179
207
  self._load_refiner_chain()
180
208
  self._load_splitter()
@@ -195,28 +223,30 @@ class Converter:
195
223
  self._model_name = model_name
196
224
  self._custom_model_arguments = custom_arguments
197
225
 
198
- def set_prompt(self, prompt_template: str) -> None:
226
+ def set_prompts(self, prompt_templates: list[str] | str) -> None:
199
227
  """Validate and set the prompt template name.
200
228
 
201
229
  Arguments:
202
- prompt_template: name of prompt template directory
203
- (see janus/prompts/templates) or path to a directory.
230
+ prompt_templates: name of prompt template directories
231
+ (see janus/prompts/templates) or paths to directories.
204
232
  """
205
- self._prompt_template_name = prompt_template
233
+ if isinstance(prompt_templates, str):
234
+ self._prompt_template_names = [prompt_templates]
235
+ else:
236
+ self._prompt_template_names = prompt_templates
206
237
 
207
238
  def set_splitter(self, splitter_type: str) -> None:
208
239
  """Validate and set the prompt template name.
209
240
 
210
241
  Arguments:
211
- prompt_template: name of prompt template directory
212
- (see janus/prompts/templates) or path to a directory.
242
+ splitter_type: the type of splitter to use
213
243
  """
214
244
  if splitter_type not in CUSTOM_SPLITTERS:
215
245
  raise ValueError(f'Splitter type "{splitter_type}" does not exist.')
216
246
 
217
247
  self._splitter_type = splitter_type
218
248
 
219
- def set_refiner_types(self, refiner_types: list[type[JanusRefiner]]) -> None:
249
+ def set_refiner_types(self, refiner_types: list[type[JanusRefiner] | str]) -> None:
220
250
  """Validate and set the refiner type
221
251
 
222
252
  Arguments:
@@ -329,25 +359,51 @@ class Converter:
329
359
  self._max_tokens = int(token_limit * self._llm.input_token_proportion)
330
360
 
331
361
  @run_if_changed(
332
- "_prompt_template_name",
362
+ "_prompt_template_names",
333
363
  "_source_language",
334
364
  "_model_name",
335
- "_parser",
365
+ "_target_language",
366
+ "_target_version",
336
367
  )
337
- def _load_prompt(self) -> None:
338
- """Load the prompt according to this instance's attributes.
339
-
340
- If the relevant fields have not been changed since the last time this
341
- method was called, nothing happens.
342
- """
368
+ def _load_translation_chain(self) -> None:
369
+ prompt_template_name = self._prompt_template_names[0]
343
370
  prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
344
371
  source_language=self._source_language,
345
- prompt_template=self._prompt_template_name,
372
+ prompt_template=prompt_template_name,
373
+ target_language=self._target_language,
374
+ target_version=self._target_version,
346
375
  )
347
- self._prompt = prompt_engine.prompt
348
- self._prompt = self._prompt.partial(
349
- format_instructions=self._parser.get_format_instructions()
376
+ prompt = prompt_engine.prompt
377
+ self._translation_chain = RunnableParallel(
378
+ prompt_value=lambda x, prompt=prompt: prompt.invoke(x),
379
+ original_inputs=RunnablePassthrough(),
380
+ ) | RunnableParallel(
381
+ completion=lambda x: self._llm.invoke(x["prompt_value"]),
382
+ original_inputs=lambda x: x["original_inputs"],
383
+ prompt_value=lambda x: x["prompt_value"],
350
384
  )
385
+ for prompt_template_name in self._prompt_template_names[1:]:
386
+ prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
387
+ source_language=self._source_language,
388
+ prompt_template=prompt_template_name,
389
+ target_language=self._target_language,
390
+ target_version=self._target_version,
391
+ )
392
+ prompt = prompt_engine.prompt
393
+ self._translation_chain = (
394
+ self._translation_chain
395
+ | RunnableParallel(
396
+ prompt_value=lambda x, prompt=prompt: prompt.invoke(
397
+ dict(completion=x["completion"], **x["original_inputs"])
398
+ ),
399
+ original_inputs=lambda x: x["original_inputs"],
400
+ )
401
+ | RunnableParallel(
402
+ completion=lambda x: self._llm.invoke(x["prompt_value"]),
403
+ original_inputs=lambda x: x["original_inputs"],
404
+ prompt_value=lambda x: x["prompt_value"],
405
+ )
406
+ )
351
407
 
352
408
  @run_if_changed("_db_path", "_db_config")
353
409
  def _load_vectorizer(self) -> None:
@@ -370,11 +426,41 @@ class Converter:
370
426
 
371
427
  @run_if_changed("_refiner_types", "_model_name", "max_prompts", "_parser")
372
428
  def _load_refiner_chain(self) -> None:
373
- self._refiner_chain = RunnableParallel(
374
- completion=self._llm,
375
- prompt_value=RunnablePassthrough(),
376
- )
377
- for refiner_type in self._refiner_types[:-1]:
429
+ from janus.cli.constants import REFINERS
430
+
431
+ if len(self._refiner_types) == 0:
432
+ self._refiner_chain = RunnableLambda(
433
+ lambda x: self._parser.parse(x["completion"])
434
+ )
435
+ return
436
+ refiner_type = self._refiner_types[0]
437
+ if isinstance(refiner_type, str):
438
+ if refiner_type not in REFINERS:
439
+ raise ValueError(f"Error: unable to find refiner type {refiner_type}")
440
+ refiner_type = REFINERS[refiner_type]
441
+ if len(self._refiner_types) == 1:
442
+ self._refiner_chain = RunnableLambda(
443
+ lambda x, refiner_type=refiner_type: refiner_type(
444
+ llm=self._llm,
445
+ parser=self._parser,
446
+ max_retries=self.max_prompts,
447
+ ).parse_completion(**x)
448
+ )
449
+ return
450
+ else:
451
+ self._refiner_chain = RunnableParallel(
452
+ completion=lambda x, refiner_type=refiner_type: refiner_type(
453
+ llm=self._llm,
454
+ parser=self._base_parser,
455
+ max_retries=self.max_prompts,
456
+ ).parse_completion(**x),
457
+ prompt_value=lambda x: x["prompt_value"],
458
+ )
459
+ for refiner_type in self._refiner_types[1:-1]:
460
+ if isinstance(refiner_type, str):
461
+ if refiner_type not in REFINERS:
462
+ raise ValueError(f"Error: unable to find refiner type {refiner_type}")
463
+ refiner_type = REFINERS[refiner_type]
378
464
  # NOTE: Do NOT remove refiner_type=refiner_type from lambda.
379
465
  # Due to lambda capture, must be present or chain will not
380
466
  # be correctly constructed.
@@ -394,9 +480,17 @@ class Converter:
394
480
  ).parse_completion(**x)
395
481
  )
396
482
 
397
- @run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner_chain")
483
+ @run_if_changed(
484
+ "_parser",
485
+ "_retriever",
486
+ "_prompt",
487
+ "_llm",
488
+ "_refiner_chain",
489
+ "_target_language",
490
+ "_target_version",
491
+ )
398
492
  def _load_chain(self):
399
- self.chain = self._input_runnable() | self._prompt | self._refiner_chain
493
+ self.chain = self.get_chain()
400
494
 
401
495
  def _input_runnable(self) -> Runnable:
402
496
  return RunnableParallel(
@@ -404,6 +498,12 @@ class Converter:
404
498
  context=self._retriever,
405
499
  )
406
500
 
501
+ def get_chain(self) -> Runnable:
502
+ """
503
+ Gets a chain that can be executed by langchain
504
+ """
505
+ return self._input_runnable() | self._translation_chain | self._refiner_chain
506
+
407
507
  def translate(
408
508
  self,
409
509
  input_directory: str | Path,
@@ -436,22 +536,24 @@ class Converter:
436
536
  failure_directory.mkdir(parents=True)
437
537
 
438
538
  input_paths = []
439
- for ext in self._source_suffixes:
539
+ if self._use_janus_inputs:
540
+ source_language = "janus"
541
+ source_suffixes = [".json"]
542
+ else:
543
+ source_language = self._source_language
544
+ source_suffixes = self._source_suffixes
545
+ for ext in source_suffixes:
440
546
  input_paths.extend(input_directory.rglob(f"**/*{ext}"))
441
547
 
442
548
  log.info(f"Input directory: {input_directory.absolute()}")
443
- log.info(
444
- f"{self._source_language} {self._source_suffixes} files: "
445
- f"{len(input_paths)}"
446
- )
549
+ log.info(f"{source_language} {source_suffixes} files: " f"{len(input_paths)}")
447
550
  log.info(
448
551
  "Other files (skipped): "
449
552
  f"{len(list(input_directory.iterdir())) - len(input_paths)}\n"
450
553
  )
451
554
  if output_directory is not None:
452
555
  output_paths = [
453
- output_directory
454
- / p.relative_to(input_directory).with_suffix(self._target_suffix)
556
+ output_directory / p.relative_to(input_directory).with_suffix(".json")
455
557
  for p in input_paths
456
558
  ]
457
559
  else:
@@ -459,8 +561,7 @@ class Converter:
459
561
 
460
562
  if failure_directory is not None:
461
563
  failure_paths = [
462
- failure_directory
463
- / p.relative_to(input_directory).with_suffix(self._target_suffix)
564
+ failure_directory / p.relative_to(input_directory).with_suffix(".json")
464
565
  for p in input_paths
465
566
  ]
466
567
  else:
@@ -484,12 +585,31 @@ class Converter:
484
585
  for in_path, out_path, fail_path in in_out_pairs:
485
586
  # Translate the file, skip it if there's a rate limit error
486
587
  log.info(f"Processing {in_path.relative_to(input_directory)}")
487
- out_block = self.translate_file(in_path, fail_path)
488
- total_cost += out_block.total_cost
588
+ if self._use_janus_inputs:
589
+ out_block = self.translate_janus_file(in_path, fail_path)
590
+ else:
591
+ out_block = self.translate_file(in_path, fail_path)
592
+
593
+ def _get_total_cost(block):
594
+ if isinstance(block, list):
595
+ return sum(_get_total_cost(b) for b in block)
596
+ return block.total_cost
597
+
598
+ total_cost += _get_total_cost(out_block)
489
599
  log.info(f"Current Running Cost: {total_cost}")
490
600
 
491
- # Don't attempt to write files for which translation failed
492
- if not out_block.translated:
601
+ # For files where translation failed, write to failure path instead
602
+
603
+ def _has_empty(block):
604
+ if isinstance(block, BlockCollection):
605
+ return len(block.blocks) == 0 or any(
606
+ _has_empty(b) for b in block.blocks
607
+ )
608
+ return not block.translated
609
+
610
+ if _has_empty(out_block):
611
+ if fail_path is not None:
612
+ self._save_to_file(out_block, fail_path)
493
613
  continue
494
614
 
495
615
  if collection_name is not None:
@@ -501,37 +621,58 @@ class Converter:
501
621
 
502
622
  # Make sure the tree's code has been consolidated at the top level
503
623
  # before writing to file
504
- self._combiner.combine(out_block)
624
+ for b in out_block.blocks:
625
+ self._combiner.combine(b)
505
626
  if out_path is not None and (overwrite or not out_path.exists()):
506
627
  self._save_to_file(out_block, out_path)
507
628
 
508
629
  log.info(f"Total cost: ${total_cost:,.2f}")
509
630
 
510
- def translate_file(
511
- self, file: Path, failure_path: Path | None = None
512
- ) -> TranslatedCodeBlock:
513
- """Translate a single file.
514
-
515
- Arguments:
516
- file: Input path to file
517
- failure_path: path to directory to store failure summaries`
631
+ def _filter_blocks(self, code_block):
632
+ if isinstance(code_block, BlockCollection):
633
+ input_blocks = list(code_block.blocks)
634
+ else:
635
+ input_blocks = [code_block]
636
+ if self._input_types is not None:
637
+ if isinstance(self._input_types, str):
638
+ self._input_types = set([self._input_types])
639
+ input_blocks = [
640
+ b
641
+ for b in input_blocks
642
+ if isinstance(b, BlockCollection) or b.block_type in self._input_types
643
+ ]
644
+ if self._input_labels is not None:
645
+ if isinstance(self._input_labels, str):
646
+ self._input_labels = set([self._input_labels])
647
+ input_blocks = [
648
+ b
649
+ for b in input_blocks
650
+ if isinstance(b, BlockCollection) or b.block_label in self._input_labels
651
+ ]
652
+ return input_blocks
518
653
 
519
- Returns:
520
- A `TranslatedCodeBlock` object. This block does not have a path set, and its
521
- code is not guaranteed to be consolidated. To amend this, run
522
- `Combiner.combine_children` on the block.
523
- """
654
+ def translate_blocks(
655
+ self,
656
+ code_block: CodeBlock | BlockCollection,
657
+ failure_path: Path | None = None,
658
+ ) -> BlockCollection | TranslatedCodeBlock:
659
+ input_blocks = self._filter_blocks(code_block)
660
+ output_blocks = []
661
+ for b in input_blocks:
662
+ output_blocks.append(self.translate_block(b, failure_path))
663
+ return BlockCollection(output_blocks, code_block.previous_generations)
664
+
665
+ def translate_block(
666
+ self,
667
+ input_block: CodeBlock,
668
+ failure_path: Path | None = None,
669
+ ) -> TranslatedCodeBlock:
524
670
  self._load_parameters()
525
- filename = file.name
526
-
527
- input_block = self._split_file(file)
528
- t0 = time.time()
529
671
  output_block = self._iterative_translate(input_block, failure_path)
530
- output_block.processing_time = time.time() - t0
531
672
  if output_block.translated:
532
673
  completeness = output_block.translation_completeness
533
674
  log.info(
534
- f"[{filename}] Translation complete\n"
675
+ f"[{output_block.name}] Translation complete\n"
535
676
  f" {completeness:.2%} of input successfully translated\n"
536
677
  f" Total cost: ${output_block.total_cost:,.2f}\n"
537
678
  f" Output CodeBlock Structure:\n{input_block.tree_str()}\n"
@@ -539,11 +680,51 @@ class Converter:
539
680
 
540
681
  else:
541
682
  log.error(
542
- f"[{filename}] Translation failed\n"
683
+ f"[{output_block.name}] Translation failed\n"
543
684
  f" Total cost: ${output_block.total_cost:,.2f}\n"
544
685
  )
545
686
  return output_block
546
687
 
688
+ def translate_file(
689
+ self,
690
+ file: Path,
691
+ failure_path: Path | None = None,
692
+ ) -> TranslatedCodeBlock:
693
+ """Translate a single file.
694
+
695
+ Arguments:
696
+ file: Input path to file
697
+ failure_path: path to directory to store failure summaries`
698
+
699
+ Returns:
700
+ A `TranslatedCodeBlock` object. This block does not have a path set, and its
701
+ code is not guaranteed to be consolidated. To amend this, run
702
+ `Combiner.combine_children` on the block.
703
+ """
704
+ input_block = self._split_file(file)
705
+ return self.translate_blocks(input_block, failure_path)
706
+
707
+ def translate_janus_file(self, file: Path, failure_path: Path | None = None):
708
+ filename = file.name
709
+ with open(file, "r") as f:
710
+ file_obj = json.load(f)
711
+ return self.translate_janus_obj(file_obj, filename, failure_path)
712
+
713
+ def translate_janus_obj(self, obj: Any, name: str, failure_path: Path | None = None):
714
+ block = self._janus_object_to_codeblock(obj, name)
715
+ return self.translate_blocks(block, failure_path)
716
+
717
+ def translate_text(self, text: str, name: str, failure_path: Path | None = None):
718
+ """
719
+ Translates given text
720
+ Arguments:
721
+ text: text to translate
722
+ name: the name of the text (filename if from a file)
723
+ failure_path: path to write failure file if translation is not successful
724
+ """
725
+ input_block = self._split_text(text, name)
726
+ return self.translate_blocks(input_block, failure_path)
727
+
547
728
  def _iterative_translate(
548
729
  self, root: CodeBlock, failure_path: Path | None = None
549
730
  ) -> TranslatedCodeBlock:
@@ -556,7 +737,13 @@ class Converter:
556
737
  Returns:
557
738
  A `TranslatedCodeBlock`
558
739
  """
559
- translated_root = TranslatedCodeBlock(root, self._target_language)
740
+ translated_root = TranslatedCodeBlock(
741
+ root,
742
+ self._target_language,
743
+ self,
744
+ block_type=self._output_type,
745
+ block_label=self._output_label,
746
+ )
560
747
  last_prog, prog_delta = 0, 0.1
561
748
  stack = [translated_root]
562
749
  try:
@@ -579,7 +766,7 @@ class Converter:
579
766
  except RateLimitError:
580
767
  pass
581
768
  except OutputParserException as e:
582
- log.error(f"Skipping file, failed to parse output: {e}.")
769
+ log.error(f"Skipping file, failed to parse output: {e}")
583
770
  except BadRequestError as e:
584
771
  if str(e).startswith("Detected an error in the prompt"):
585
772
  log.warning("Malformed input, skipping")
@@ -607,9 +794,10 @@ class Converter:
607
794
  )
608
795
  raise e
609
796
  finally:
610
- log.debug(
611
- f"Resulting Block: {json.dumps(self._get_output_obj(translated_root))}"
797
+ out_obj = self._get_output_obj(
798
+ translated_root, self._combine_output, include_previous_outputs=True
612
799
  )
800
+ log.debug(f"Resulting Block:" f"{json.dumps(out_obj)}")
613
801
  if not translated_root.translated:
614
802
  if failure_path is not None:
615
803
  self._save_to_file(translated_root, failure_path)
@@ -666,6 +854,16 @@ class Converter:
666
854
 
667
855
  log.debug(f"[{block.name}] Output code:\n{block.text}")
668
856
 
857
+ def _split_text(self, text: str, name: str) -> CodeBlock:
858
+ log.info(f"[{name}] Splitting text")
859
+ root = self._splitter.split_string(text, name)
860
+ log.info(
861
+ f"[{name}] Text split into {root.n_descendents:,} blocks,"
862
+ f"tree of height {root.height}"
863
+ )
864
+ log.info(f"[{name}] Input CodeBlock Structure:\n{root.tree_str()}")
865
+ return root
866
+
669
867
  def _split_file(self, file: Path) -> CodeBlock:
670
868
  filename = file.name
671
869
  log.info(f"[{filename}] Splitting file")
@@ -680,31 +878,113 @@ class Converter:
680
878
  def _run_chain(self, block: TranslatedCodeBlock) -> str:
681
879
  return self.chain.invoke(block.original)
682
880
 
881
+ def _combine_metadata(self, metadatas: list[dict]):
882
+ return dict(
883
+ cost=sum(m["cost"] for m in metadatas),
884
+ processing_time=sum(m["processing_time"] for m in metadatas),
885
+ num_requests=sum(m["num_requests"] for m in metadatas),
886
+ input_tokens=sum(m["input_tokens"] for m in metadatas),
887
+ output_tokens=sum(m["output_tokens"] for m in metadatas),
888
+ converter_name=self.__class__.__name__,
889
+ type=[m["type"] for m in metadatas],
890
+ label=[m["label"] for m in metadatas],
891
+ )
892
+
893
+ def _combine_inputs(self, inputs: list[str]):
894
+ return json.dumps(inputs)
895
+
683
896
  def _get_output_obj(
684
- self, block: TranslatedCodeBlock
897
+ self,
898
+ block: TranslatedCodeBlock | BlockCollection | dict,
899
+ combine_children: bool = True,
900
+ include_previous_outputs: bool = True,
685
901
  ) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
686
- output_obj: str | dict[str, str]
687
- if not block.translation_completed:
688
- # translation wasn't completed, so combined parsing will likely fail
689
- output_obj = block.complete_text
902
+ block_type = None
903
+ block_label = None
904
+ if isinstance(block, dict):
905
+ # output object has already been generated
906
+ new_block = deepcopy(block)
907
+ if "intermediate_outputs" in new_block:
908
+ del new_block["intermediate_outputs"]
909
+ return new_block
910
+ if isinstance(block, BlockCollection):
911
+ if len(block.blocks) == 1:
912
+ outputs = self._get_output_obj(block.blocks[0], combine_children, False)[
913
+ "outputs"
914
+ ]
915
+ block_type = block.blocks[0].block_type
916
+ block_label = block.blocks[0].block_label
917
+ else:
918
+ outputs = [
919
+ self._get_output_obj(b, combine_children, False) for b in block.blocks
920
+ ]
921
+ elif (
922
+ not isinstance(block, BlockCollection)
923
+ and not combine_children
924
+ and len(block.children) > 0
925
+ ):
926
+ outputs = self._get_output_obj_children(block, False)
690
927
  else:
691
- output_str = self._parser.parse_combined_output(block.complete_text)
692
- try:
693
- output_obj = json.loads(output_str)
694
- except json.JSONDecodeError:
695
- output_obj = output_str
696
-
697
- return dict(
698
- input=block.original.text or "",
928
+ block_type = block.block_type
929
+ block_label = block.block_label
930
+ if not block.translation_completed:
931
+ # translation wasn't completed, so combined parsing will likely fail
932
+ outputs = [block.complete_text]
933
+ else:
934
+ output_str = self._parser.parse_combined_output(block.complete_text)
935
+ outputs = [output_str]
936
+
937
+ def _get_input(block):
938
+ if isinstance(block, BlockCollection):
939
+ return self._combine_inputs([_get_input(b) for b in block.blocks])
940
+ return block.original.text or ""
941
+
942
+ out = dict(
943
+ input=_get_input(block),
699
944
  metadata=dict(
700
945
  cost=block.total_cost,
701
- processing_time=block.processing_time,
946
+ processing_time=block.total_processing_time,
702
947
  num_requests=block.total_num_requests,
703
948
  input_tokens=block.total_request_input_tokens,
704
949
  output_tokens=block.total_request_output_tokens,
950
+ converter_name=self.__class__.__name__,
951
+ type=block_type,
952
+ label=block_label,
705
953
  ),
706
- output=output_obj,
954
+ outputs=outputs,
707
955
  )
956
+ if (
957
+ include_previous_outputs
958
+ and isinstance(block, BlockCollection)
959
+ and len(block.previous_generations) > 0
960
+ ):
961
+ intermediate_outputs = []
962
+ for p in block.previous_generations:
963
+ if isinstance(p, dict):
964
+ # preserve intermediate outputs from previous runs
965
+ intermediate_outputs.append(
966
+ self._get_output_obj(p, combine_children, False)
967
+ )
968
+ if len(intermediate_outputs) > 0:
969
+ out["intermediate_outputs"] = intermediate_outputs
970
+ return out
971
+
972
+ def _get_output_obj_children(
973
+ self, block: TranslatedCodeBlock, include_previous_outputs: bool = True
974
+ ):
975
+ if len(block.children) > 0:
976
+ res = []
977
+ for c in block.children:
978
+ res += self._get_output_obj_children(c, include_previous_outputs)
979
+ return res
980
+ else:
981
+ return [
982
+ self._get_output_obj(
983
+ block,
984
+ combine_children=True,
985
+ include_previous_outputs=include_previous_outputs,
986
+ )
987
+ ]
708
988
 
709
989
  def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
710
990
  """Save a file to disk.
@@ -712,6 +992,103 @@ class Converter:
712
992
  Arguments:
713
993
  block: The `TranslatedCodeBlock` to save to a file.
714
994
  """
715
- obj = self._get_output_obj(block)
995
+ obj = self._get_output_obj(
996
+ block, combine_children=self._combine_output, include_previous_outputs=True
997
+ )
716
998
  out_path.parent.mkdir(parents=True, exist_ok=True)
717
999
  out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
1000
+
1001
+ def _janus_object_to_codeblock(self, janus_obj: dict, name: str):
1002
+ results = []
1003
+ for o in janus_obj["outputs"]:
1004
+ metadata = janus_obj["metadata"]
1005
+ if isinstance(o, str):
1006
+ block_label = metadata["label"]
1007
+ if isinstance(block_label, list):
1008
+ block_label = block_label[0]
1009
+ block_type = metadata["type"]
1010
+ if isinstance(block_type, list):
1011
+ block_type = block_type[0]
1012
+ code_block = self._split_text(o, name)
1013
+ code_block.previous_generations = janus_obj.get(
1014
+ "intermediate_outputs", []
1015
+ ) + [janus_obj]
1016
+ code_block.block_type = block_type
1017
+ code_block.block_label = block_label
1018
+ results.append(code_block)
1019
+ else:
1020
+ results += self._janus_object_to_codeblock(o, name).blocks
1021
+ previous_generations = janus_obj.get("intermediate_outputs", [])
1022
+ if janus_obj["metadata"]["converter_name"] != "ConverterChain":
1023
+ previous_generations += [janus_obj]
1024
+ return BlockCollection(results, previous_generations)
1025
+
1026
+ def __or__(self, other: "Converter"):
1027
+ from janus.converter.chain import ConverterChain
1028
+
1029
+ return ConverterChain(self, other)
1030
+
1031
+ @property
1032
+ def source_language(self):
1033
+ return self._source_language
1034
+
1035
+ @property
1036
+ def target_language(self):
1037
+ return self._target_language
1038
+
1039
+ @property
1040
+ def target_version(self):
1041
+ return self._target_version
1042
+
1043
+ def set_target_language(
1044
+ self, target_language: str, target_version: str | None
1045
+ ) -> None:
1046
+ """Validate and set the target language.
1047
+
1048
+ The affected objects will not be updated until translate() is called.
1049
+
1050
+ Arguments:
1051
+ target_language: The target programming language.
1052
+ target_version: The target version of the target programming language.
1053
+ """
1054
+ target_language = target_language.lower()
1055
+ if target_language not in LANGUAGES:
1056
+ raise ValueError(
1057
+ f"Invalid target language: {target_language}. "
1058
+ "Valid target languages are found in `janus.utils.enums.LANGUAGES`."
1059
+ )
1060
+ self._target_language = target_language
1061
+ self._target_version = target_version
1062
+ # Taking the first suffix as the default for output files
1063
+ self._target_suffix = f".{LANGUAGES[target_language]['suffixes'][0]}"
1064
+
1065
+ @classmethod
1066
+ def eval_obj(cls, target, metric_func, *args, **kwargs):
1067
+ if "reference" in kwargs:
1068
+ return cls.eval_obj_reference(target, metric_func, *args, **kwargs)
1069
+ else:
1070
+ return cls.eval_obj_noreference(target, metric_func, *args, **kwargs)
1071
+
1072
+ @classmethod
1073
+ def eval_obj_noreference(cls, target, metric_func, *args, **kwargs):
1074
+ results = []
1075
+ for o in target["outputs"]:
1076
+ if isinstance(o, dict):
1077
+ results += cls.eval_obj_noreference(o, metric_func, *args, **kwargs)
1078
+ else:
1079
+ results.append(metric_func(o, *args, **kwargs))
1080
+ return results
1081
+
1082
+ @classmethod
1083
+ def eval_obj_reference(cls, target, metric_func, reference, *args, **kwargs):
1084
+ results = []
1085
+ for o, r in zip(target["outputs"], reference["outputs"]):
1086
+ if isinstance(o, dict):
1087
+ if not isinstance(r, dict):
1088
+ raise ValueError("Error: format of reference doesn't match target")
1089
+ results += cls.eval_obj_reference(o, metric_func, r, *args, **kwargs)
1090
+ else:
1091
+ if isinstance(r, dict):
1092
+ raise ValueError("Error: format of reference doesn't match target")
1093
+ results.append(metric_func(o, r, *args, **kwargs))
1094
+ return results