janus-llm 4.3.5__py3-none-any.whl → 4.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +1 -1
- janus/cli/aggregate.py +2 -2
- janus/cli/cli.py +6 -0
- janus/cli/constants.py +6 -0
- janus/cli/diagram.py +36 -7
- janus/cli/document.py +10 -1
- janus/cli/llm.py +7 -3
- janus/cli/partition.py +10 -1
- janus/cli/pipeline.py +123 -0
- janus/cli/self_eval.py +1 -3
- janus/cli/translate.py +10 -1
- janus/converter/_tests/test_translate.py +5 -5
- janus/converter/chain.py +180 -0
- janus/converter/converter.py +333 -78
- janus/converter/diagram.py +8 -6
- janus/converter/document.py +7 -3
- janus/converter/evaluate.py +140 -148
- janus/converter/partition.py +2 -10
- janus/converter/requirements.py +4 -40
- janus/converter/translate.py +2 -58
- janus/language/block.py +31 -2
- janus/metrics/metric.py +47 -124
- janus/parsers/reqs_parser.py +3 -3
- {janus_llm-4.3.5.dist-info → janus_llm-4.4.5.dist-info}/METADATA +12 -12
- {janus_llm-4.3.5.dist-info → janus_llm-4.4.5.dist-info}/RECORD +28 -28
- janus/metrics/_tests/test_llm.py +0 -90
- janus/metrics/llm_metrics.py +0 -202
- {janus_llm-4.3.5.dist-info → janus_llm-4.4.5.dist-info}/LICENSE +0 -0
- {janus_llm-4.3.5.dist-info → janus_llm-4.4.5.dist-info}/WHEEL +0 -0
- {janus_llm-4.3.5.dist-info → janus_llm-4.4.5.dist-info}/entry_points.txt +0 -0
janus/converter/converter.py
CHANGED
@@ -76,7 +76,7 @@ class Converter:
         source_language: str = "fortran",
         max_prompts: int = 10,
         max_tokens: int | None = None,
-
+        prompt_templates: list[str] | str = ["simple"],
         db_path: str | None = None,
         db_config: dict[str, Any] | None = None,
         protected_node_types: tuple[str, ...] = (),
@@ -84,6 +84,10 @@ class Converter:
         splitter_type: str = "file",
         refiner_types: list[type[JanusRefiner]] = [JanusRefiner],
         retriever_type: str | None = None,
+        combine_output: bool = True,
+        use_janus_inputs: bool = False,
+        target_language: str = "json",
+        target_version: str | None = None,
     ) -> None:
         """Initialize a Converter instance.

@@ -96,7 +100,7 @@ class Converter:
             max_prompts: The maximum number of prompts to try before giving up.
             max_tokens: The maximum number of tokens to use in the LLM. If `None`, the
                 converter will use half the model's token limit.
-
+            prompt_templates: The name of the prompt templates to use.
             db_path: The path to the database to use for vectorization.
             db_config: The configuration for the database.
             protected_node_types: A set of node types that aren't to be merged.
@@ -111,12 +115,17 @@ class Converter:
                 - "active_usings"
                 - "language_docs"
                 - None
+            combine_output: Whether to combine the output into a single file or not.
+            use_janus_inputs: Whether to use janus inputs or not.
+            target_language: The target programming language.
+            target_version: The target programming language version.
         """
         self._changed_attrs: set = set()

         self.max_prompts: int = max_prompts
         self._max_tokens: int | None = max_tokens
         self.override_token_limit: bool = max_tokens is not None
+        self._combine_output = combine_output

         self._model_name: str
         self._custom_model_arguments: dict[str, Any]
@@ -124,13 +133,16 @@ class Converter:
         self._source_language: str
         self._source_suffixes: list[str]

-        self._target_language
-        self._target_suffix
+        self._target_language: str
+        self._target_suffix: str
+        self._target_version: str | None
+        self.set_target_language(target_language, target_version)
+        self._use_janus_inputs = use_janus_inputs

         self._protected_node_types: tuple[str, ...] = ()
         self._prune_node_types: tuple[str, ...] = ()
         self._max_tokens: int | None = max_tokens
-        self.
+        self._prompt_template_names: list[str]
         self._db_path: str | None
         self._db_config: dict[str, Any] | None

@@ -153,7 +165,7 @@ class Converter:
         self.set_refiner_types(refiner_types=refiner_types)
         self.set_retriever(retriever_type=retriever_type)
         self.set_model(model_name=model, **model_arguments)
-        self.
+        self.set_prompts(prompt_templates=prompt_templates)
         self.set_source_language(source_language)
         self.set_protected_node_types(protected_node_types)
         self.set_prune_node_types(prune_node_types)
@@ -174,7 +186,7 @@ class Converter:

     def _load_parameters(self) -> None:
         self._load_model()
-        self.
+        self._load_translation_chain()
         self._load_retriever()
         self._load_refiner_chain()
         self._load_splitter()
@@ -195,21 +207,23 @@ class Converter:
         self._model_name = model_name
         self._custom_model_arguments = custom_arguments

-    def
+    def set_prompts(self, prompt_templates: list[str] | str) -> None:
         """Validate and set the prompt template name.

         Arguments:
-
-                (see janus/prompts/templates) or
+            prompt_templates: name of prompt template directories
+                (see janus/prompts/templates) or paths to directories.
         """
-
+        if isinstance(prompt_templates, str):
+            self._prompt_template_names = [prompt_templates]
+        else:
+            self._prompt_template_names = prompt_templates

     def set_splitter(self, splitter_type: str) -> None:
         """Validate and set the prompt template name.

         Arguments:
-
-                (see janus/prompts/templates) or path to a directory.
+            splitter_type: the type of splitter to use
         """
         if splitter_type not in CUSTOM_SPLITTERS:
             raise ValueError(f'Splitter type "{splitter_type}" does not exist.')
@@ -328,26 +342,46 @@ class Converter:
         if not self.override_token_limit:
             self._max_tokens = int(token_limit * self._llm.input_token_proportion)

-    @run_if_changed(
-
-
-        "_model_name",
-        "_parser",
-    )
-    def _load_prompt(self) -> None:
-        """Load the prompt according to this instance's attributes.
-
-        If the relevant fields have not been changed since the last time this
-        method was called, nothing happens.
-        """
+    @run_if_changed("_prompt_template_names", "_source_language", "_model_name")
+    def _load_translation_chain(self) -> None:
+        prompt_template_name = self._prompt_template_names[0]
         prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
             source_language=self._source_language,
-            prompt_template=
+            prompt_template=prompt_template_name,
+            target_language=self._target_language,
+            target_version=self._target_version,
         )
-
-        self.
-
+        prompt = prompt_engine.prompt
+        self._translation_chain = RunnableParallel(
+            prompt_value=lambda x, prompt=prompt: prompt.invoke(x),
+            original_inputs=RunnablePassthrough(),
+        ) | RunnableParallel(
+            completion=lambda x: self._llm.invoke(x["prompt_value"]),
+            original_inputs=lambda x: x["original_inputs"],
+            prompt_value=lambda x: x["prompt_value"],
         )
+        for prompt_template_name in self._prompt_template_names[1:]:
+            prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
+                source_language=self._source_language,
+                prompt_template=prompt_template_name,
+                target_language=self._target_language,
+                target_version=self._target_version,
+            )
+            prompt = prompt_engine.prompt
+            self._translation_chain = (
+                self._translation_chain
+                | RunnableParallel(
+                    prompt_value=lambda x, prompt=prompt: prompt.invoke(
+                        dict(completion=x["completion"], **x["original_inputs"])
+                    ),
+                    original_inputs=lambda x: x["original_inputs"],
+                )
+                | RunnableParallel(
+                    completion=lambda x: self._llm.invoke(x["prompt_value"]),
+                    original_inputs=lambda x: x["original_inputs"],
+                    prompt_value=lambda x: x["prompt_value"],
+                )
+            )

     @run_if_changed("_db_path", "_db_config")
     def _load_vectorizer(self) -> None:
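
Note: the new `_load_translation_chain` above replaces `_load_prompt`. When more than one entry is given in `prompt_templates`, each later prompt is invoked with the previous stage's `completion` alongside the original inputs, and the LLM is called once per stage. The following is a minimal standalone sketch of that piping pattern (not janus code); it assumes only `langchain_core`, with plain lambdas standing in for janus's prompt engines and LLM client:

from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough

llm = RunnableLambda(lambda prompt_value: f"LLM({prompt_value})")   # stand-in for self._llm
prompt_1 = RunnableLambda(lambda x: f"P1({x['SOURCE_CODE']})")      # stand-in for the first prompt template
prompt_2 = RunnableLambda(lambda x: f"P2({x['completion']})")       # second template sees the first completion

chain = RunnableParallel(
    prompt_value=prompt_1,
    original_inputs=RunnablePassthrough(),
) | RunnableParallel(
    completion=lambda x: llm.invoke(x["prompt_value"]),
    original_inputs=lambda x: x["original_inputs"],
    prompt_value=lambda x: x["prompt_value"],
) | RunnableParallel(
    prompt_value=lambda x: prompt_2.invoke(dict(completion=x["completion"], **x["original_inputs"])),
    original_inputs=lambda x: x["original_inputs"],
) | RunnableParallel(
    completion=lambda x: llm.invoke(x["prompt_value"]),
    original_inputs=lambda x: x["original_inputs"],
    prompt_value=lambda x: x["prompt_value"],
)

print(chain.invoke({"SOURCE_CODE": "program hello"})["completion"])
# LLM(P2(LLM(P1(program hello))))
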
@@ -370,11 +404,31 @@ class Converter:

     @run_if_changed("_refiner_types", "_model_name", "max_prompts", "_parser")
     def _load_refiner_chain(self) -> None:
-        self.
-
-
-
-
+        if len(self._refiner_types) == 0:
+            self._refiner_chain = RunnableLambda(
+                lambda x: self._parser.parse(x["completion"])
+            )
+            return
+        refiner_type = self._refiner_types[0]
+        if len(self._refiner_types) == 1:
+            self._refiner_chain = RunnableLambda(
+                lambda x, refiner_type=refiner_type: refiner_type(
+                    llm=self._llm,
+                    parser=self._parser,
+                    max_retries=self.max_prompts,
+                ).parse_completion(**x)
+            )
+            return
+        else:
+            self._refiner_chain = RunnableParallel(
+                completion=lambda x, refiner_type=refiner_type: refiner_type(
+                    llm=self._llm,
+                    parser=self._base_parser,
+                    max_retries=self.max_prompts,
+                ).parse_completion(**x),
+                prompt_value=lambda x: x["prompt_value"],
+            )
+        for refiner_type in self._refiner_types[1:-1]:
            # NOTE: Do NOT remove refiner_type=refiner_type from lambda.
            # Due to lambda capture, must be present or chain will not
            # be correctly constructed.
@@ -396,7 +450,7 @@ class Converter:

     @run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner_chain")
     def _load_chain(self):
-        self.chain = self.
+        self.chain = self.get_chain()

     def _input_runnable(self) -> Runnable:
         return RunnableParallel(
@@ -404,6 +458,12 @@ class Converter:
             context=self._retriever,
         )

+    def get_chain(self) -> Runnable:
+        """
+        Gets a chain that can be executed by langchain
+        """
+        return self._input_runnable() | self._translation_chain | self._refiner_chain
+
     def translate(
         self,
         input_directory: str | Path,
@@ -436,22 +496,24 @@ class Converter:
             failure_directory.mkdir(parents=True)

         input_paths = []
-
+        if self._use_janus_inputs:
+            source_language = "janus"
+            source_suffixes = [".json"]
+        else:
+            source_language = self._source_language
+            source_suffixes = self._source_suffixes
+        for ext in source_suffixes:
             input_paths.extend(input_directory.rglob(f"**/*{ext}"))

         log.info(f"Input directory: {input_directory.absolute()}")
-        log.info(
-            f"{self._source_language} {self._source_suffixes} files: "
-            f"{len(input_paths)}"
-        )
+        log.info(f"{source_language} {source_suffixes} files: " f"{len(input_paths)}")
         log.info(
             "Other files (skipped): "
             f"{len(list(input_directory.iterdir())) - len(input_paths)}\n"
         )
         if output_directory is not None:
             output_paths = [
-                output_directory
-                / p.relative_to(input_directory).with_suffix(self._target_suffix)
+                output_directory / p.relative_to(input_directory).with_suffix(".json")
                 for p in input_paths
             ]
         else:
@@ -459,8 +521,7 @@ class Converter:

         if failure_directory is not None:
             failure_paths = [
-                failure_directory
-                / p.relative_to(input_directory).with_suffix(self._target_suffix)
+                failure_directory / p.relative_to(input_directory).with_suffix(".json")
                 for p in input_paths
             ]
         else:
@@ -484,12 +545,32 @@ class Converter:
         for in_path, out_path, fail_path in in_out_pairs:
             # Translate the file, skip it if there's a rate limit error
             log.info(f"Processing {in_path.relative_to(input_directory)}")
-
-
+            if self._use_janus_inputs:
+                out_block = self.translate_janus_file(in_path, fail_path)
+            else:
+                out_block = self.translate_file(in_path, fail_path)
+
+            def _get_total_cost(block):
+                if isinstance(block, list):
+                    return sum(_get_total_cost(b) for b in block)
+                return block.total_cost
+
+            total_cost += _get_total_cost(out_block)
             log.info(f"Current Running Cost: {total_cost}")

-            #
-
+            # For files where translation failed, write to failure path instead
+
+            def _has_empty(block):
+                if isinstance(block, list):
+                    return len(block) == 0 or any(_has_empty(b) for b in block)
+                return not block.translated
+
+            while isinstance(out_block, list) and len(out_block) == 1:
+                out_block = out_block[0]
+
+            if _has_empty(out_block):
+                if fail_path is not None:
+                    self._save_to_file(out_block, fail_path)
                 continue

             if collection_name is not None:
@@ -507,31 +588,22 @@ class Converter:

         log.info(f"Total cost: ${total_cost:,.2f}")

-    def
-        self,
-
-
-
-
-            file: Input path to file
-            failure_path: path to directory to store failure summaries`
-
-        Returns:
-            A `TranslatedCodeBlock` object. This block does not have a path set, and its
-            code is not guaranteed to be consolidated. To amend this, run
-            `Combiner.combine_children` on the block.
-        """
+    def translate_block(
+        self,
+        input_block: CodeBlock | list[CodeBlock],
+        name: str,
+        failure_path: Path | None = None,
+    ):
         self._load_parameters()
-
-
-        input_block = self._split_file(file)
+        if isinstance(input_block, list):
+            return [self.translate_block(b, name, failure_path) for b in input_block]
         t0 = time.time()
         output_block = self._iterative_translate(input_block, failure_path)
         output_block.processing_time = time.time() - t0
         if output_block.translated:
             completeness = output_block.translation_completeness
             log.info(
-                f"[{
+                f"[{name}] Translation complete\n"
                 f"  {completeness:.2%} of input successfully translated\n"
                 f"  Total cost: ${output_block.total_cost:,.2f}\n"
                 f"  Output CodeBlock Structure:\n{input_block.tree_str()}\n"
@@ -539,11 +611,52 @@ class Converter:

         else:
             log.error(
-                f"[{
+                f"[{name}] Translation failed\n"
                 f"  Total cost: ${output_block.total_cost:,.2f}\n"
             )
         return output_block

+    def translate_file(
+        self,
+        file: Path,
+        failure_path: Path | None = None,
+    ) -> TranslatedCodeBlock:
+        """Translate a single file.
+
+        Arguments:
+            file: Input path to file
+            failure_path: path to directory to store failure summaries`
+
+        Returns:
+            A `TranslatedCodeBlock` object. This block does not have a path set, and its
+            code is not guaranteed to be consolidated. To amend this, run
+            `Combiner.combine_children` on the block.
+        """
+        filename = file.name
+        input_block = self._split_file(file)
+        return self.translate_block(input_block, filename, failure_path)
+
+    def translate_janus_file(self, file: Path, failure_path: Path | None = None):
+        filename = file.name
+        with open(file, "r") as f:
+            file_obj = json.load(f)
+        return self.translate_janus_obj(file_obj, filename, failure_path)
+
+    def translate_janus_obj(self, obj: Any, name: str, failure_path: Path | None = None):
+        block = self._janus_object_to_codeblock(obj, name)
+        return self.translate_block(block)
+
+    def translate_text(self, text: str, name: str, failure_path: Path | None = None):
+        """
+        Translates given text
+        Arguments:
+            text: text to translate
+            name: the name of the text (filename if from a file)
+            failure_path: path to write failure file if translation is not successful
+        """
+        input_block = self._split_text(text, name)
+        return self.translate_block(input_block, name, failure_path)
+
     def _iterative_translate(
         self, root: CodeBlock, failure_path: Path | None = None
     ) -> TranslatedCodeBlock:
@@ -607,9 +720,8 @@ class Converter:
             )
             raise e
         finally:
-
-
-            )
+            out_obj = self._get_output_obj(translated_root, self._combine_output)
+            log.debug(f"Resulting Block:" f"{json.dumps(out_obj)}")
         if not translated_root.translated:
             if failure_path is not None:
                 self._save_to_file(translated_root, failure_path)
@@ -666,6 +778,16 @@ class Converter:

         log.debug(f"[{block.name}] Output code:\n{block.text}")

+    def _split_text(self, text: str, name: str) -> CodeBlock:
+        log.info(f"[{name}] Splitting text")
+        root = self._splitter.split_string(text, name)
+        log.info(
+            f"[{name}] Text split into {root.n_descendents:,} blocks,"
+            f"tree of height {root.height}"
+        )
+        log.info(f"[{name}] Input CodeBlock Structure:\n{root.tree_str()}")
+        return root
+
     def _split_file(self, file: Path) -> CodeBlock:
         filename = file.name
         log.info(f"[{filename}] Splitting file")
@@ -680,19 +802,51 @@ class Converter:
     def _run_chain(self, block: TranslatedCodeBlock) -> str:
         return self.chain.invoke(block.original)

+    def _combine_metadata(self, metadatas: list[dict]):
+        return dict(
+            cost=sum(m["cost"] for m in metadatas),
+            processing_time=sum(m["processing_time"] for m in metadatas),
+            num_requests=sum(m["num_requests"] for m in metadatas),
+            input_tokens=sum(m["input_tokens"] for m in metadatas),
+            output_tokens=sum(m["output_tokens"] for m in metadatas),
+            converter_name=self.__class__.__name__,
+        )
+
+    def _combine_inputs(self, inputs: list[str]):
+        s = ""
+        for i in inputs:
+            s += i
+        return s
+
     def _get_output_obj(
-        self, block: TranslatedCodeBlock
+        self, block: TranslatedCodeBlock | list, combine_children: bool = True
     ) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
+        if isinstance(block, list):
+            # TODO: run on all items in list
+            outputs = [self._get_output_obj(b, combine_children) for b in block]
+            metadata = self._combine_metadata([o["metadata"] for o in outputs])
+            input_agg = self._combine_inputs(o["input"] for o in outputs)
+            return dict(
+                input=input_agg,
+                metadata=metadata,
+                outputs=outputs,
+            )
+        if not combine_children and len(block.children) > 0:
+            outputs = self._get_output_obj_children(block)
+            metadata = self._combine_metadata([o["metadata"] for o in outputs])
+            input_agg = self._combine_inputs(o["input"] for o in outputs)
+            return dict(
+                input=input_agg,
+                metadata=metadata,
+                outputs=outputs,
+            )
         output_obj: str | dict[str, str]
         if not block.translation_completed:
             # translation wasn't completed, so combined parsing will likely fail
-            output_obj = block.complete_text
+            output_obj = [block.complete_text]
         else:
             output_str = self._parser.parse_combined_output(block.complete_text)
-
-                output_obj = json.loads(output_str)
-            except json.JSONDecodeError:
-                output_obj = output_str
+            output_obj = [output_str]

         return dict(
             input=block.original.text or "",
@@ -702,16 +856,117 @@ class Converter:
                 num_requests=block.total_num_requests,
                 input_tokens=block.total_request_input_tokens,
                 output_tokens=block.total_request_output_tokens,
+                converter_name=self.__class__.__name__,
             ),
-
+            outputs=output_obj,
         )

+    def _get_output_obj_children(self, block: TranslatedCodeBlock):
+        if len(block.children) > 0:
+            res = []
+            for c in block.children:
+                res += self._get_output_obj_children(c)
+            return res
+        else:
+            return [self._get_output_obj(block, combine_children=True)]
+
     def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
         """Save a file to disk.

         Arguments:
             block: The `TranslatedCodeBlock` to save to a file.
         """
-        obj = self._get_output_obj(block)
+        obj = self._get_output_obj(block, combine_children=self._combine_output)
         out_path.parent.mkdir(parents=True, exist_ok=True)
         out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
+
+    def _janus_object_to_codeblock(self, janus_obj: dict, name: str):
+        results = []
+        for o in janus_obj["outputs"]:
+            if isinstance(o, str):
+                code_block = self._split_text(o, name)
+                meta_data = janus_obj["metadata"]
+                code_block.initial_cost = meta_data["cost"]
+                code_block.initial_input_tokens = meta_data["input_tokens"]
+                code_block.initial_output_tokens = meta_data["output_tokens"]
+                code_block.initial_num_requests = meta_data["num_requests"]
+                code_block.initial_processing_time = meta_data["processing_time"]
+                code_block.previous_generations = janus_obj.get(
+                    "intermediate_outputs", []
+                ) + [janus_obj]
+                results.append(code_block)
+            else:
+                results.append(self._janus_object_to_codeblock(o))
+        while isinstance(results, list) and len(results) == 1:
+            results = results[0]
+        return results
+
+    def __or__(self, other: "Converter"):
+        from janus.converter.chain import ConverterChain
+
+        return ConverterChain(self, other)
+
+    @property
+    def source_language(self):
+        return self._source_language
+
+    @property
+    def target_language(self):
+        return self._target_language
+
+    @property
+    def target_version(self):
+        return self._target_version
+
+    def set_target_language(
+        self, target_language: str, target_version: str | None
+    ) -> None:
+        """Validate and set the target language.
+
+        The affected objects will not be updated until translate() is called.
+
+        Arguments:
+            target_language: The target programming language.
+            target_version: The target version of the target programming language.
+        """
+        target_language = target_language.lower()
+        if target_language not in LANGUAGES:
+            raise ValueError(
+                f"Invalid target language: {target_language}. "
+                "Valid target languages are found in `janus.utils.enums.LANGUAGES`."
+            )
+        self._target_language = target_language
+        self._target_version = target_version
+        # Taking the first suffix as the default for output files
+        self._target_suffix = f".{LANGUAGES[target_language]['suffixes'][0]}"
+
+    @classmethod
+    def eval_obj(cls, target, metric_func, *args, **kwargs):
+        if "reference" in kwargs:
+            return cls.eval_obj_reference(target, metric_func, *args, **kwargs)
+        else:
+            return cls.eval_obj_noreference(target, metric_func, *args, **kwargs)
+
+    @classmethod
+    def eval_obj_noreference(cls, target, metric_func, *args, **kwargs):
+        results = []
+        for o in target["outputs"]:
+            if isinstance(o, dict):
+                results += cls.eval_obj_noreference(o, metric_func, *args, **kwargs)
+            else:
+                results.append(metric_func(o, *args, **kwargs))
+        return results
+
+    @classmethod
+    def eval_obj_reference(cls, target, metric_func, reference, *args, **kwargs):
+        results = []
+        for o, r in zip(target["outputs"], reference["outputs"]):
+            if isinstance(o, dict):
+                if not isinstance(r, dict):
+                    raise ValueError("Error: format of reference doesn't match target")
+                results += cls.eval_obj_reference(o, metric_func, r, *args, **kwargs)
+            else:
+                if isinstance(r, dict):
+                    raise ValueError("Error: format of reference doesn't match target")
+                results.append(metric_func(o, r, *args, **kwargs))
        return results
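
Note: taken together, the converter changes add target-language settings, multi-prompt chains, janus-format (.json) inputs and outputs, composition of converters with `|`, and metric evaluation over saved outputs. Below is a hedged usage sketch based only on the signatures visible in this diff: the model alias, template names, and paths are placeholders, running it requires a configured model, and in practice the Converter subclasses (janus/converter/translate.py, document.py, etc.) are usually instantiated rather than the base class.

import json
from pathlib import Path

from janus.converter.converter import Converter

# Keyword names come from the new __init__ signature above; values are illustrative.
fortran_to_python = Converter(
    model="gpt-4o",                  # placeholder alias, forwarded to set_model(model_name=model, ...)
    source_language="fortran",
    target_language="python",        # validated against janus.utils.enums.LANGUAGES
    target_version="3.10",
    prompt_templates=["simple"],     # a list of names chains several prompts per block
    combine_output=True,             # False keeps one output object per leaf block
    use_janus_inputs=False,          # True makes translate() consume prior janus .json output
)

# New entry points for single inputs, in addition to translate() over a directory.
text_block = fortran_to_python.translate_text("program demo\nend program demo", name="demo.f90")
file_block = fortran_to_python.translate_file(Path("src/demo.f90"))

# Converters now compose with `|`, returning a ConverterChain (janus/converter/chain.py, new in 4.4.5).
pipeline = Converter(model="gpt-4o", source_language="fortran") | fortran_to_python

# eval_obj walks a saved janus output object and applies a metric to every leaf output.
saved = json.loads(Path("translated/demo.json").read_text())
lengths = Converter.eval_obj(saved, metric_func=lambda output: len(output))
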
janus/converter/diagram.py
CHANGED
@@ -14,6 +14,7 @@ class DiagramGenerator(Documenter):
         self,
         diagram_type="Activity",
         add_documentation=False,
+        extract_variables=False,
         **kwargs,
     ) -> None:
         """Initialize the DiagramGenerator class
@@ -28,24 +29,25 @@ class DiagramGenerator(Documenter):
         self._documenter = Documenter(**kwargs)

         super().__init__(**kwargs)
-
-
+        prompts = []
+        if extract_variables:
+            prompts.append("extract_variables")
+        prompts += ["diagram_with_documentation" if add_documentation else "diagram"]
+        self.set_prompts(prompts)
         self._parser = UMLSyntaxParser(language="plantuml")

         self._load_parameters()

-    def _load_prompt(self):
-        super()._load_prompt()
-        self._prompt = self._prompt.partial(DIAGRAM_TYPE=self._diagram_type)
-
     def _input_runnable(self) -> Runnable:
         if self._add_documentation:
             return RunnableParallel(
                 SOURCE_CODE=self._parser.parse_input,
                 DOCUMENTATION=self._documenter.chain,
                 context=self._retriever,
+                DIAGRAM_TYPE=lambda x: self._diagram_type,
             )
         return RunnableParallel(
             SOURCE_CODE=self._parser.parse_input,
             context=self._retriever,
+            DIAGRAM_TYPE=lambda x: self._diagram_type,
         )
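
Note: DiagramGenerator no longer patches the prompt with a DIAGRAM_TYPE partial; the diagram type is now supplied through the input runnable, and the new `extract_variables` flag prepends an "extract_variables" prompt to the diagram prompt via the new `set_prompts`. A hedged constructor sketch follows; the model alias and source snippet are placeholders, and keyword arguments other than the three shown in this diff simply pass through **kwargs to Documenter/Converter.

from janus.converter.diagram import DiagramGenerator

diagrammer = DiagramGenerator(
    model="gpt-4o",              # placeholder; forwarded through **kwargs to Documenter/Converter
    source_language="fortran",
    diagram_type="Activity",
    add_documentation=False,     # True selects the "diagram_with_documentation" prompt
    extract_variables=True,      # new in 4.4.5: prompts become ["extract_variables", "diagram"]
)

# translate_text is inherited from the reworked Converter base class.
plantuml_block = diagrammer.translate_text(
    "program demo\n  integer :: i\nend program demo", name="demo.f90"
)
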