janus-llm 4.3.5__py3-none-any.whl → 4.5.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
- janus/__init__.py +1 -1
- janus/cli/aggregate.py +2 -2
- janus/cli/cli.py +6 -0
- janus/cli/constants.py +6 -0
- janus/cli/diagram.py +36 -7
- janus/cli/document.py +10 -1
- janus/cli/llm.py +7 -3
- janus/cli/partition.py +10 -1
- janus/cli/pipeline.py +126 -0
- janus/cli/self_eval.py +10 -3
- janus/cli/translate.py +10 -1
- janus/converter/__init__.py +2 -0
- janus/converter/_tests/test_translate.py +6 -5
- janus/converter/chain.py +100 -0
- janus/converter/converter.py +467 -90
- janus/converter/diagram.py +12 -8
- janus/converter/document.py +17 -7
- janus/converter/evaluate.py +174 -147
- janus/converter/partition.py +6 -11
- janus/converter/passthrough.py +29 -0
- janus/converter/pool.py +74 -0
- janus/converter/requirements.py +7 -40
- janus/converter/translate.py +2 -58
- janus/language/_tests/test_combine.py +1 -0
- janus/language/block.py +115 -5
- janus/llm/model_callbacks.py +6 -0
- janus/llm/models_info.py +19 -0
- janus/metrics/_tests/test_reading.py +48 -4
- janus/metrics/_tests/test_rouge_score.py +5 -11
- janus/metrics/metric.py +47 -124
- janus/metrics/reading.py +48 -28
- janus/metrics/rouge_score.py +21 -34
- janus/parsers/_tests/test_code_parser.py +1 -1
- janus/parsers/code_parser.py +2 -2
- janus/parsers/eval_parsers/incose_parser.py +3 -3
- janus/parsers/reqs_parser.py +3 -3
- janus/prompts/templates/cyclic/human.txt +16 -0
- janus/prompts/templates/cyclic/system.txt +1 -0
- janus/prompts/templates/eval_prompts/incose/human.txt +1 -1
- janus/prompts/templates/extract_variables/human.txt +5 -0
- janus/prompts/templates/extract_variables/system.txt +1 -0
- {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/METADATA +14 -15
- {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/RECORD +46 -40
- {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/WHEEL +1 -1
- janus/metrics/_tests/test_llm.py +0 -90
- janus/metrics/llm_metrics.py +0 -202
- {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/LICENSE +0 -0
- {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/entry_points.txt +0 -0
janus/converter/converter.py
CHANGED
@@ -1,6 +1,7 @@
 import functools
 import json
 import time
+from copy import deepcopy
 from pathlib import Path
 from typing import Any
 
@@ -16,7 +17,7 @@ from openai import BadRequestError, RateLimitError
 from pydantic import ValidationError
 
 from janus.embedding.vectorize import ChromaDBVectorizer
-from janus.language.block import CodeBlock, TranslatedCodeBlock
+from janus.language.block import BlockCollection, CodeBlock, TranslatedCodeBlock
 from janus.language.combine import Combiner
 from janus.language.naive.registry import CUSTOM_SPLITTERS
 from janus.language.splitter import (
@@ -76,14 +77,22 @@ class Converter:
         source_language: str = "fortran",
         max_prompts: int = 10,
         max_tokens: int | None = None,
-
+        prompt_templates: list[str] | str = ["simple"],
         db_path: str | None = None,
         db_config: dict[str, Any] | None = None,
         protected_node_types: tuple[str, ...] = (),
         prune_node_types: tuple[str, ...] = (),
         splitter_type: str = "file",
-        refiner_types: list[type[JanusRefiner]] = [JanusRefiner],
+        refiner_types: list[type[JanusRefiner] | str] = [JanusRefiner],
         retriever_type: str | None = None,
+        combine_output: bool = True,
+        use_janus_inputs: bool = False,
+        target_language: str = "json",
+        target_version: str | None = None,
+        input_types: set[str] | str | None = None,
+        input_labels: set[str] | str | None = None,
+        output_type: str | None = None,
+        output_label: str | None = None,
     ) -> None:
         """Initialize a Converter instance.
 
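The expanded `__init__` signature above adds chaining- and routing-related options: named prompt templates, refiner types given as classes or registry names, janus-JSON input mode, an explicit target language/version, and input/output type and label filters. A minimal configuration sketch, assuming a concrete Converter (or subclass) is constructed directly with these keyword arguments; the model and template values are placeholders, not defaults shipped with janus:

    # Hypothetical configuration sketch; argument names mirror the signature above.
    converter = Converter(
        model="gpt-4o",                       # placeholder; passed through to set_model()
        source_language="fortran",
        prompt_templates=["simple"],          # str or list[str]; later templates refine earlier output
        refiner_types=[],                     # classes, or string keys from janus.cli.constants.REFINERS
        target_language="python",
        target_version="3.10",
        use_janus_inputs=False,               # True: read prior janus .json outputs instead of source files
        combine_output=True,                  # False: keep per-child outputs separate in the saved JSON
        input_types=None, input_labels=None,  # optional filters applied by _filter_blocks()
        output_type=None, output_label=None,  # tags recorded on the produced blocks
    )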
@@ -96,7 +105,7 @@ class Converter:
             max_prompts: The maximum number of prompts to try before giving up.
             max_tokens: The maximum number of tokens to use in the LLM. If `None`, the
                 converter will use half the model's token limit.
-
+            prompt_templates: The name of the prompt templates to use.
             db_path: The path to the database to use for vectorization.
             db_config: The configuration for the database.
             protected_node_types: A set of node types that aren't to be merged.
@@ -111,12 +120,21 @@ class Converter:
                 - "active_usings"
                 - "language_docs"
                 - None
+            combine_output: Whether to combine the output into a single file or not.
+            use_janus_inputs: Whether to use janus inputs or not.
+            target_language: The target programming language.
+            target_version: The target programming language version.
+            input_types: The types of input to accept.
+            input_labels: The labels of input to accept.
+            output_type: The type of output to produce.
+            output_label: The label of output to produce.
         """
         self._changed_attrs: set = set()
 
         self.max_prompts: int = max_prompts
         self._max_tokens: int | None = max_tokens
         self.override_token_limit: bool = max_tokens is not None
+        self._combine_output = combine_output
 
         self._model_name: str
         self._custom_model_arguments: dict[str, Any]
@@ -124,13 +142,16 @@ class Converter:
         self._source_language: str
         self._source_suffixes: list[str]
 
-        self._target_language
-        self._target_suffix
+        self._target_language: str
+        self._target_suffix: str
+        self._target_version: str | None
+        self.set_target_language(target_language, target_version)
+        self._use_janus_inputs = use_janus_inputs
 
         self._protected_node_types: tuple[str, ...] = ()
         self._prune_node_types: tuple[str, ...] = ()
         self._max_tokens: int | None = max_tokens
-        self.
+        self._prompt_template_names: list[str]
         self._db_path: str | None
         self._db_config: dict[str, Any] | None
 
@@ -142,7 +163,7 @@ class Converter:
         self._combiner: Combiner = Combiner()
 
         self._splitter_type: str
-        self._refiner_types: list[type[JanusRefiner]]
+        self._refiner_types: list[type[JanusRefiner] | str]
         self._retriever_type: str | None
 
         self._splitter: Splitter
@@ -153,13 +174,20 @@ class Converter:
         self.set_refiner_types(refiner_types=refiner_types)
         self.set_retriever(retriever_type=retriever_type)
         self.set_model(model_name=model, **model_arguments)
-        self.
+        self.set_prompts(prompt_templates=prompt_templates)
         self.set_source_language(source_language)
         self.set_protected_node_types(protected_node_types)
         self.set_prune_node_types(prune_node_types)
         self.set_db_path(db_path=db_path)
         self.set_db_config(db_config=db_config)
 
+        self._input_types = input_types
+        self._input_labels = input_labels
+        self._output_type = output_type
+        self._output_label = output_label
+
+        self._load_parameters()
+
         # Child class must call this. Should we enforce somehow?
         # self._load_parameters()
 
@@ -174,7 +202,7 @@ class Converter:
 
     def _load_parameters(self) -> None:
         self._load_model()
-        self.
+        self._load_translation_chain()
         self._load_retriever()
         self._load_refiner_chain()
         self._load_splitter()
@@ -195,28 +223,30 @@ class Converter:
         self._model_name = model_name
         self._custom_model_arguments = custom_arguments
 
-    def
+    def set_prompts(self, prompt_templates: list[str] | str) -> None:
         """Validate and set the prompt template name.
 
         Arguments:
-
-                (see janus/prompts/templates) or
+            prompt_templates: name of prompt template directories
+                (see janus/prompts/templates) or paths to directories.
         """
-
+        if isinstance(prompt_templates, str):
+            self._prompt_template_names = [prompt_templates]
+        else:
+            self._prompt_template_names = prompt_templates
 
     def set_splitter(self, splitter_type: str) -> None:
         """Validate and set the prompt template name.
 
         Arguments:
-
-                (see janus/prompts/templates) or path to a directory.
+            splitter_type: the type of splitter to use
         """
         if splitter_type not in CUSTOM_SPLITTERS:
             raise ValueError(f'Splitter type "{splitter_type}" does not exist.')
 
         self._splitter_type = splitter_type
 
-    def set_refiner_types(self, refiner_types: list[type[JanusRefiner]]) -> None:
+    def set_refiner_types(self, refiner_types: list[type[JanusRefiner] | str]) -> None:
        """Validate and set the refiner type
 
        Arguments:
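`set_prompts` now accepts either a single template name or a list; a string is normalized to a one-element list, and `_load_translation_chain` (next hunk) applies the templates in order, feeding each completion into the following template. A small sketch of the calling convention, with template names taken from the package's templates directory as an assumption:

    converter.set_prompts("simple")               # stored internally as ["simple"]
    converter.set_prompts(["simple", "cyclic"])   # the second template sees the first completion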
@@ -329,25 +359,51 @@ class Converter:
         self._max_tokens = int(token_limit * self._llm.input_token_proportion)
 
     @run_if_changed(
-        "
+        "_prompt_template_names",
         "_source_language",
         "_model_name",
-        "
+        "_target_language",
+        "_target_version",
     )
-    def
-
-
-        If the relevant fields have not been changed since the last time this
-        method was called, nothing happens.
-        """
+    def _load_translation_chain(self) -> None:
+        prompt_template_name = self._prompt_template_names[0]
         prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
             source_language=self._source_language,
-            prompt_template=
+            prompt_template=prompt_template_name,
+            target_language=self._target_language,
+            target_version=self._target_version,
         )
-
-        self.
-
+        prompt = prompt_engine.prompt
+        self._translation_chain = RunnableParallel(
+            prompt_value=lambda x, prompt=prompt: prompt.invoke(x),
+            original_inputs=RunnablePassthrough(),
+        ) | RunnableParallel(
+            completion=lambda x: self._llm.invoke(x["prompt_value"]),
+            original_inputs=lambda x: x["original_inputs"],
+            prompt_value=lambda x: x["prompt_value"],
         )
+        for prompt_template_name in self._prompt_template_names[1:]:
+            prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
+                source_language=self._source_language,
+                prompt_template=prompt_template_name,
+                target_language=self._target_language,
+                target_version=self._target_version,
+            )
+            prompt = prompt_engine.prompt
+            self._translation_chain = (
+                self._translation_chain
+                | RunnableParallel(
+                    prompt_value=lambda x, prompt=prompt: prompt.invoke(
+                        dict(completion=x["completion"], **x["original_inputs"])
+                    ),
+                    original_inputs=lambda x: x["original_inputs"],
+                )
+                | RunnableParallel(
+                    completion=lambda x: self._llm.invoke(x["prompt_value"]),
+                    original_inputs=lambda x: x["original_inputs"],
+                    prompt_value=lambda x: x["prompt_value"],
+                )
+            )
 
     @run_if_changed("_db_path", "_db_config")
     def _load_vectorizer(self) -> None:
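The new `_load_translation_chain` threads three keys through every stage: `prompt_value` (the rendered prompt), `completion` (the LLM response), and `original_inputs` (the untouched inputs, so a follow-up template can reference both). A stripped-down sketch of the same `RunnableParallel` pattern with a stand-in prompt renderer and LLM; the variable name `SOURCE_CODE` and the helpers are illustrative, not janus APIs:

    from langchain_core.runnables import RunnableParallel, RunnablePassthrough

    def render_prompt(inputs: dict) -> str:
        # Stand-in for the prompt engine used above.
        return f"Translate this code:\n{inputs['SOURCE_CODE']}"

    def fake_llm(prompt_value: str) -> str:
        # Stand-in for self._llm.invoke(...).
        return f"<completion of: {prompt_value!r}>"

    chain = RunnableParallel(
        prompt_value=lambda x: render_prompt(x),
        original_inputs=RunnablePassthrough(),
    ) | RunnableParallel(
        completion=lambda x: fake_llm(x["prompt_value"]),
        original_inputs=lambda x: x["original_inputs"],
        prompt_value=lambda x: x["prompt_value"],
    )

    result = chain.invoke({"SOURCE_CODE": "print('hi')"})
    # result keeps all three keys, so a second template can be piped on with
    # access to both the completion and the original inputs.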
@@ -370,11 +426,41 @@ class Converter:
 
     @run_if_changed("_refiner_types", "_model_name", "max_prompts", "_parser")
     def _load_refiner_chain(self) -> None:
-
-
-
-
-
+        from janus.cli.constants import REFINERS
+
+        if len(self._refiner_types) == 0:
+            self._refiner_chain = RunnableLambda(
+                lambda x: self._parser.parse(x["completion"])
+            )
+            return
+        refiner_type = self._refiner_types[0]
+        if isinstance(refiner_type, str):
+            if refiner_type not in REFINERS:
+                raise ValueError(f"Error: unable to find refiner type {refiner_type}")
+            refiner_type = REFINERS[refiner_type]
+        if len(self._refiner_types) == 1:
+            self._refiner_chain = RunnableLambda(
+                lambda x, refiner_type=refiner_type: refiner_type(
+                    llm=self._llm,
+                    parser=self._parser,
+                    max_retries=self.max_prompts,
+                ).parse_completion(**x)
+            )
+            return
+        else:
+            self._refiner_chain = RunnableParallel(
+                completion=lambda x, refiner_type=refiner_type: refiner_type(
+                    llm=self._llm,
+                    parser=self._base_parser,
+                    max_retries=self.max_prompts,
+                ).parse_completion(**x),
+                prompt_value=lambda x: x["prompt_value"],
+            )
+        for refiner_type in self._refiner_types[1:-1]:
+            if isinstance(refiner_type, str):
+                if refiner_type not in REFINERS:
+                    raise ValueError(f"Error: unable to find refiner type {refiner_type}")
+                refiner_type = REFINERS[refiner_type]
            # NOTE: Do NOT remove refiner_type=refiner_type from lambda.
            # Due to lambda capture, must be present or chain will not
            # be correctly constructed.
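The NOTE about `refiner_type=refiner_type` guards against Python's late-binding closures: a lambda created in a loop looks the loop variable up when it is called, not when it is defined, so without the default-argument trick every stage of the chain would end up using the last refiner. A minimal, self-contained illustration of the pitfall and the fix:

    # Late binding: every lambda sees the final value of i.
    broken = [lambda x: x + i for i in range(3)]
    print([f(0) for f in broken])   # [2, 2, 2]

    # A default argument captures the current value at definition time.
    fixed = [lambda x, i=i: x + i for i in range(3)]
    print([f(0) for f in fixed])    # [0, 1, 2]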
@@ -394,9 +480,17 @@ class Converter:
                 ).parse_completion(**x)
             )
 
-    @run_if_changed(
+    @run_if_changed(
+        "_parser",
+        "_retriever",
+        "_prompt",
+        "_llm",
+        "_refiner_chain",
+        "_target_language",
+        "_target_version",
+    )
     def _load_chain(self):
-        self.chain = self.
+        self.chain = self.get_chain()
 
     def _input_runnable(self) -> Runnable:
         return RunnableParallel(
@@ -404,6 +498,12 @@ class Converter:
             context=self._retriever,
         )
 
+    def get_chain(self) -> Runnable:
+        """
+        Gets a chain that can be executed by langchain
+        """
+        return self._input_runnable() | self._translation_chain | self._refiner_chain
+
     def translate(
         self,
         input_directory: str | Path,
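`get_chain` exposes the full runnable (input preparation, translation chain, refiner chain), and `_run_chain` further down invokes it on a block's original `CodeBlock`. A hedged usage sketch, assuming `converter` is fully configured and `block` is a block produced by the splitter:

    # Sketch only; mirrors what _run_chain does internally.
    chain = converter.get_chain()
    parsed_output = chain.invoke(block.original)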
@@ -436,22 +536,24 @@ class Converter:
             failure_directory.mkdir(parents=True)
 
         input_paths = []
-
+        if self._use_janus_inputs:
+            source_language = "janus"
+            source_suffixes = [".json"]
+        else:
+            source_language = self._source_language
+            source_suffixes = self._source_suffixes
+        for ext in source_suffixes:
             input_paths.extend(input_directory.rglob(f"**/*{ext}"))
 
         log.info(f"Input directory: {input_directory.absolute()}")
-        log.info(
-            f"{self._source_language} {self._source_suffixes} files: "
-            f"{len(input_paths)}"
-        )
+        log.info(f"{source_language} {source_suffixes} files: " f"{len(input_paths)}")
         log.info(
             "Other files (skipped): "
             f"{len(list(input_directory.iterdir())) - len(input_paths)}\n"
         )
         if output_directory is not None:
             output_paths = [
-                output_directory
-                / p.relative_to(input_directory).with_suffix(self._target_suffix)
+                output_directory / p.relative_to(input_directory).with_suffix(".json")
                 for p in input_paths
             ]
         else:
@@ -459,8 +561,7 @@ class Converter:
 
         if failure_directory is not None:
             failure_paths = [
-                failure_directory
-                / p.relative_to(input_directory).with_suffix(self._target_suffix)
+                failure_directory / p.relative_to(input_directory).with_suffix(".json")
                 for p in input_paths
             ]
         else:
@@ -484,12 +585,31 @@ class Converter:
         for in_path, out_path, fail_path in in_out_pairs:
             # Translate the file, skip it if there's a rate limit error
             log.info(f"Processing {in_path.relative_to(input_directory)}")
-
-
+            if self._use_janus_inputs:
+                out_block = self.translate_janus_file(in_path, fail_path)
+            else:
+                out_block = self.translate_file(in_path, fail_path)
+
+            def _get_total_cost(block):
+                if isinstance(block, list):
+                    return sum(_get_total_cost(b) for b in block)
+                return block.total_cost
+
+            total_cost += _get_total_cost(out_block)
             log.info(f"Current Running Cost: {total_cost}")
 
-            #
-
+            # For files where translation failed, write to failure path instead
+
+            def _has_empty(block):
+                if isinstance(block, BlockCollection):
+                    return len(block.blocks) == 0 or any(
+                        _has_empty(b) for b in block.blocks
+                    )
+                return not block.translated
+
+            if _has_empty(out_block):
+                if fail_path is not None:
+                    self._save_to_file(out_block, fail_path)
                 continue
 
             if collection_name is not None:
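Because a single input can now yield nested `BlockCollection`s (or lists of blocks), both cost accounting and the failure check recurse over the result instead of reading one attribute. A self-contained sketch of the same accumulation pattern over a stand-in block type (not a janus class):

    from dataclasses import dataclass

    @dataclass
    class FakeBlock:
        total_cost: float

    def get_total_cost(block):
        # Mirrors the recursive helper above: lists are summed, leaves report their own cost.
        if isinstance(block, list):
            return sum(get_total_cost(b) for b in block)
        return block.total_cost

    nested = [FakeBlock(1.0), [FakeBlock(2.5), FakeBlock(0.5)]]
    print(get_total_cost(nested))   # 4.0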
@@ -501,37 +621,58 @@ class Converter:
 
             # Make sure the tree's code has been consolidated at the top level
             # before writing to file
-
+            for b in out_block.blocks:
+                self._combiner.combine(b)
             if out_path is not None and (overwrite or not out_path.exists()):
                 self._save_to_file(out_block, out_path)
 
         log.info(f"Total cost: ${total_cost:,.2f}")
 
-    def
-
-
-
-
-
-
-
+    def _filter_blocks(self, code_block):
+        if isinstance(code_block, BlockCollection):
+            input_blocks = list(code_block.blocks)
+        else:
+            input_blocks = [code_block]
+        if self._input_types is not None:
+            if isinstance(self._input_types, str):
+                self._input_types = set([self._input_types])
+            input_blocks = [
+                b
+                for b in input_blocks
+                if isinstance(b, BlockCollection) or b.block_type in self._input_types
+            ]
+        if self._input_labels is not None:
+            if isinstance(self._input_labels, str):
+                self._input_labels = set([self._input_labels])
+            input_blocks = [
+                b
+                for b in input_blocks
+                if isinstance(b, BlockCollection) or b.block_label in self._input_labels
+            ]
+        return input_blocks
 
-
-
-
-
-
+    def translate_blocks(
+        self,
+        code_block: CodeBlock | BlockCollection,
+        failure_path: Path | None = None,
+    ) -> BlockCollection | TranslatedCodeBlock:
+        input_blocks = self._filter_blocks(code_block)
+        output_blocks = []
+        for b in input_blocks:
+            output_blocks.append(self.translate_block(b, failure_path))
+        return BlockCollection(output_blocks, code_block.previous_generations)
+
+    def translate_block(
+        self,
+        input_block: CodeBlock,
+        failure_path: Path | None = None,
+    ) -> TranslatedCodeBlock:
         self._load_parameters()
-        filename = file.name
-
-        input_block = self._split_file(file)
-        t0 = time.time()
         output_block = self._iterative_translate(input_block, failure_path)
-        output_block.processing_time = time.time() - t0
         if output_block.translated:
             completeness = output_block.translation_completeness
             log.info(
-                f"[{
+                f"[{output_block.name}] Translation complete\n"
                 f" {completeness:.2%} of input successfully translated\n"
                 f" Total cost: ${output_block.total_cost:,.2f}\n"
                 f" Output CodeBlock Structure:\n{input_block.tree_str()}\n"
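`_filter_blocks` narrows the work to blocks whose `block_type`/`block_label` match the `input_types`/`input_labels` given at construction, which is how a converter in a chain picks up only the outputs meant for it. A sketch of that selection logic with stand-in objects (no janus block classes are imported here; type and label values are illustrative):

    from types import SimpleNamespace

    blocks = [
        SimpleNamespace(block_type="requirements", block_label="module_a"),
        SimpleNamespace(block_type="code", block_label="module_a"),
        SimpleNamespace(block_type="code", block_label="module_b"),
    ]

    input_types = {"code"}
    input_labels = {"module_a"}

    selected = [
        b
        for b in blocks
        if b.block_type in input_types and b.block_label in input_labels
    ]
    print([(b.block_type, b.block_label) for b in selected])  # [('code', 'module_a')]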
@@ -539,11 +680,51 @@ class Converter:
 
         else:
             log.error(
-                f"[{
+                f"[{output_block.name}] Translation failed\n"
                 f" Total cost: ${output_block.total_cost:,.2f}\n"
             )
         return output_block
 
+    def translate_file(
+        self,
+        file: Path,
+        failure_path: Path | None = None,
+    ) -> TranslatedCodeBlock:
+        """Translate a single file.
+
+        Arguments:
+            file: Input path to file
+            failure_path: path to directory to store failure summaries`
+
+        Returns:
+            A `TranslatedCodeBlock` object. This block does not have a path set, and its
+            code is not guaranteed to be consolidated. To amend this, run
+            `Combiner.combine_children` on the block.
+        """
+        input_block = self._split_file(file)
+        return self.translate_blocks(input_block, failure_path)
+
+    def translate_janus_file(self, file: Path, failure_path: Path | None = None):
+        filename = file.name
+        with open(file, "r") as f:
+            file_obj = json.load(f)
+        return self.translate_janus_obj(file_obj, filename, failure_path)
+
+    def translate_janus_obj(self, obj: Any, name: str, failure_path: Path | None = None):
+        block = self._janus_object_to_codeblock(obj, name)
+        return self.translate_blocks(block, failure_path)
+
+    def translate_text(self, text: str, name: str, failure_path: Path | None = None):
+        """
+        Translates given text
+        Arguments:
+            text: text to translate
+            name: the name of the text (filename if from a file)
+            failure_path: path to write failure file if translation is not successful
+        """
+        input_block = self._split_text(text, name)
+        return self.translate_blocks(input_block, failure_path)
+
     def _iterative_translate(
         self, root: CodeBlock, failure_path: Path | None = None
     ) -> TranslatedCodeBlock:
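These methods give a converter three entry points besides the directory-level `translate()`: a source file, a previously produced janus JSON file, or a raw string. A hedged sketch of how they might be called; the paths and source text are placeholders:

    from pathlib import Path

    # Each call returns a block collection per the signatures above.
    out1 = converter.translate_file(Path("src/legacy.f90"))
    out2 = converter.translate_janus_file(Path("out/legacy.json"))   # output of an earlier run
    out3 = converter.translate_text("PROGRAM HELLO\nEND PROGRAM", "hello.f90")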
@@ -556,7 +737,13 @@ class Converter:
         Returns:
             A `TranslatedCodeBlock`
         """
-        translated_root = TranslatedCodeBlock(
+        translated_root = TranslatedCodeBlock(
+            root,
+            self._target_language,
+            self,
+            block_type=self._output_type,
+            block_label=self._output_label,
+        )
         last_prog, prog_delta = 0, 0.1
         stack = [translated_root]
         try:
@@ -579,7 +766,7 @@ class Converter:
             except RateLimitError:
                 pass
             except OutputParserException as e:
-                log.error(f"Skipping file, failed to parse output: {e}
+                log.error(f"Skipping file, failed to parse output: {e}")
             except BadRequestError as e:
                 if str(e).startswith("Detected an error in the prompt"):
                     log.warning("Malformed input, skipping")
@@ -607,9 +794,10 @@ class Converter:
                 )
                 raise e
         finally:
-
-
+            out_obj = self._get_output_obj(
+                translated_root, self._combine_output, include_previous_outputs=True
             )
+            log.debug(f"Resulting Block:" f"{json.dumps(out_obj)}")
             if not translated_root.translated:
                 if failure_path is not None:
                     self._save_to_file(translated_root, failure_path)
@@ -666,6 +854,16 @@ class Converter:
 
         log.debug(f"[{block.name}] Output code:\n{block.text}")
 
+    def _split_text(self, text: str, name: str) -> CodeBlock:
+        log.info(f"[{name}] Splitting text")
+        root = self._splitter.split_string(text, name)
+        log.info(
+            f"[{name}] Text split into {root.n_descendents:,} blocks,"
+            f"tree of height {root.height}"
+        )
+        log.info(f"[{name}] Input CodeBlock Structure:\n{root.tree_str()}")
+        return root
+
     def _split_file(self, file: Path) -> CodeBlock:
         filename = file.name
         log.info(f"[{filename}] Splitting file")
@@ -680,31 +878,113 @@ class Converter:
     def _run_chain(self, block: TranslatedCodeBlock) -> str:
         return self.chain.invoke(block.original)
 
+    def _combine_metadata(self, metadatas: list[dict]):
+        return dict(
+            cost=sum(m["cost"] for m in metadatas),
+            processing_time=sum(m["processing_time"] for m in metadatas),
+            num_requests=sum(m["num_requests"] for m in metadatas),
+            input_tokens=sum(m["input_tokens"] for m in metadatas),
+            output_tokens=sum(m["output_tokens"] for m in metadatas),
+            converter_name=self.__class__.__name__,
+            type=[m["type"] for m in metadatas],
+            label=[m["label"] for m in metadatas],
+        )
+
+    def _combine_inputs(self, inputs: list[str]):
+        return json.dumps(inputs)
+
     def _get_output_obj(
-        self,
+        self,
+        block: TranslatedCodeBlock | BlockCollection | dict,
+        combine_children: bool = True,
+        include_previous_outputs: bool = True,
     ) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
-
-
-
-
+        block_type = None
+        block_label = None
+        if isinstance(block, dict):
+            # output object has already been generated
+            new_block = deepcopy(block)
+            if "intermediate_outputs" in new_block:
+                del new_block["intermediate_outputs"]
+            return new_block
+        if isinstance(block, BlockCollection):
+            if len(block.blocks) == 1:
+                outputs = self._get_output_obj(block.blocks[0], combine_children, False)[
+                    "outputs"
+                ]
+                block_type = block.blocks[0].block_type
+                block_label = block.blocks[0].block_label
+            else:
+                outputs = [
+                    self._get_output_obj(b, combine_children, False) for b in block.blocks
+                ]
+        elif (
+            not isinstance(block, BlockCollection)
+            and not combine_children
+            and len(block.children) > 0
+        ):
+            outputs = self._get_output_obj_children(block, False)
         else:
-
-
-
-
-
-
-
-
+            block_type = block.block_type
+            block_label = block.block_label
+            if not block.translation_completed:
+                # translation wasn't completed, so combined parsing will likely fail
+                outputs = [block.complete_text]
+            else:
+                output_str = self._parser.parse_combined_output(block.complete_text)
+                outputs = [output_str]
+
+        def _get_input(block):
+            if isinstance(block, BlockCollection):
+                return self._combine_inputs([_get_input(b) for b in block.blocks])
+            return block.original.text or ""
+
+        out = dict(
+            input=_get_input(block),
             metadata=dict(
                 cost=block.total_cost,
-                processing_time=block.
+                processing_time=block.total_processing_time,
                 num_requests=block.total_num_requests,
                 input_tokens=block.total_request_input_tokens,
                 output_tokens=block.total_request_output_tokens,
+                converter_name=self.__class__.__name__,
+                type=block_type,
+                label=block_label,
             ),
-
+            outputs=outputs,
         )
+        if (
+            include_previous_outputs
+            and isinstance(block, BlockCollection)
+            and len(block.previous_generations) > 0
+        ):
+            intermediate_outputs = []
+            for p in block.previous_generations:
+                if isinstance(p, dict):
+                    # preserve intermediate outputs from previous runs
+                    intermediate_outputs.append(
+                        self._get_output_obj(p, combine_children, False)
+                    )
+            if len(intermediate_outputs) > 0:
+                out["intermediate_outputs"] = intermediate_outputs
+        return out
+
+    def _get_output_obj_children(
+        self, block: TranslatedCodeBlock, include_previous_outputs: bool = True
+    ):
+        if len(block.children) > 0:
+            res = []
+            for c in block.children:
+                res += self._get_output_obj_children(c, include_previous_outputs)
+            return res
+        else:
+            return [
+                self._get_output_obj(
+                    block,
+                    combine_children=True,
+                    include_previous_outputs=include_previous_outputs,
+                )
+            ]
 
     def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
         """Save a file to disk.
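`_get_output_obj` is what `_save_to_file` serializes, so the on-disk janus JSON now carries the converter name, block type/label, and any `intermediate_outputs` from earlier converters in a chain. An illustrative (not exhaustive) shape of that object, assembled from the fields visible in the hunk above with made-up values:

    example_output_obj = {
        "input": "PROGRAM HELLO\n...\nEND PROGRAM",
        "metadata": {
            "cost": 0.0123,
            "processing_time": 4.2,
            "num_requests": 1,
            "input_tokens": 350,
            "output_tokens": 210,
            "converter_name": "Converter",   # self.__class__.__name__
            "type": "code",
            "label": None,
        },
        "outputs": ["def hello():\n    print('Hello')"],
        # present only when previous generations exist (e.g. from a chained run)
        "intermediate_outputs": [],
    }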
@@ -712,6 +992,103 @@ class Converter:
         Arguments:
             block: The `TranslatedCodeBlock` to save to a file.
         """
-        obj = self._get_output_obj(
+        obj = self._get_output_obj(
+            block, combine_children=self._combine_output, include_previous_outputs=True
+        )
         out_path.parent.mkdir(parents=True, exist_ok=True)
         out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
+
+    def _janus_object_to_codeblock(self, janus_obj: dict, name: str):
+        results = []
+        for o in janus_obj["outputs"]:
+            metadata = janus_obj["metadata"]
+            if isinstance(o, str):
+                block_label = metadata["label"]
+                if isinstance(block_label, list):
+                    block_label = block_label[0]
+                block_type = metadata["type"]
+                if isinstance(block_type, list):
+                    block_type = block_type[0]
+                code_block = self._split_text(o, name)
+                code_block.previous_generations = janus_obj.get(
+                    "intermediate_outputs", []
+                ) + [janus_obj]
+                code_block.block_type = block_type
+                code_block.block_label = block_label
+                results.append(code_block)
+            else:
+                results += self._janus_object_to_codeblock(o, name).blocks
+        previous_generations = janus_obj.get("intermediate_outputs", [])
+        if janus_obj["metadata"]["converter_name"] != "ConverterChain":
+            previous_generations += [janus_obj]
+        return BlockCollection(results, previous_generations)
+
+    def __or__(self, other: "Converter"):
+        from janus.converter.chain import ConverterChain
+
+        return ConverterChain(self, other)
+
+    @property
+    def source_language(self):
+        return self._source_language
+
+    @property
+    def target_language(self):
+        return self._target_language
+
+    @property
+    def target_version(self):
+        return self._target_version
+
+    def set_target_language(
+        self, target_language: str, target_version: str | None
+    ) -> None:
+        """Validate and set the target language.
+
+        The affected objects will not be updated until translate() is called.
+
+        Arguments:
+            target_language: The target programming language.
+            target_version: The target version of the target programming language.
+        """
+        target_language = target_language.lower()
+        if target_language not in LANGUAGES:
+            raise ValueError(
+                f"Invalid target language: {target_language}. "
+                "Valid target languages are found in `janus.utils.enums.LANGUAGES`."
+            )
+        self._target_language = target_language
+        self._target_version = target_version
+        # Taking the first suffix as the default for output files
+        self._target_suffix = f".{LANGUAGES[target_language]['suffixes'][0]}"
+
+    @classmethod
+    def eval_obj(cls, target, metric_func, *args, **kwargs):
+        if "reference" in kwargs:
+            return cls.eval_obj_reference(target, metric_func, *args, **kwargs)
+        else:
+            return cls.eval_obj_noreference(target, metric_func, *args, **kwargs)
+
+    @classmethod
+    def eval_obj_noreference(cls, target, metric_func, *args, **kwargs):
+        results = []
+        for o in target["outputs"]:
+            if isinstance(o, dict):
+                results += cls.eval_obj_noreference(o, metric_func, *args, **kwargs)
+            else:
+                results.append(metric_func(o, *args, **kwargs))
+        return results
+
+    @classmethod
+    def eval_obj_reference(cls, target, metric_func, reference, *args, **kwargs):
+        results = []
+        for o, r in zip(target["outputs"], reference["outputs"]):
+            if isinstance(o, dict):
+                if not isinstance(r, dict):
+                    raise ValueError("Error: format of reference doesn't match target")
+                results += cls.eval_obj_reference(o, metric_func, r, *args, **kwargs)
+            else:
+                if isinstance(r, dict):
+                    raise ValueError("Error: format of reference doesn't match target")
+                results.append(metric_func(o, r, *args, **kwargs))
+        return results
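The `__or__` overload means two converters can be composed with `|` into a `ConverterChain` (defined in the new janus/converter/chain.py), and the `eval_obj*` classmethods walk a saved output object and apply a metric to every leaf output string. A hedged sketch of both; the converters, the chain's `translate()` call, the metric, and the paths are placeholders and assumptions, not documented API:

    import json
    from pathlib import Path

    # Composition: equivalent to ConverterChain(first_converter, second_converter).
    pipeline = first_converter | second_converter
    pipeline.translate("in_dir", "out_dir")   # assumes the chain exposes translate()

    # Evaluation: apply a simple metric (here, output length) to every output string.
    saved = json.loads(Path("out_dir/legacy.json").read_text())
    scores = Converter.eval_obj_noreference(saved, len)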