janus-llm 2.1.0__py3-none-any.whl → 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +2 -2
- janus/__main__.py +1 -1
- janus/_tests/test_cli.py +1 -2
- janus/cli.py +43 -50
- janus/converter/__init__.py +6 -0
- janus/converter/_tests/__init__.py +0 -0
- janus/{_tests → converter/_tests}/test_translate.py +11 -22
- janus/converter/converter.py +614 -0
- janus/converter/diagram.py +124 -0
- janus/converter/document.py +131 -0
- janus/converter/evaluate.py +15 -0
- janus/converter/requirements.py +51 -0
- janus/converter/translate.py +108 -0
- janus/language/block.py +1 -1
- janus/language/combine.py +0 -1
- janus/language/treesitter/treesitter.py +20 -1
- janus/llm/model_callbacks.py +33 -36
- janus/llm/models_info.py +14 -0
- janus/metrics/reading.py +27 -5
- janus/prompts/prompt.py +37 -11
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/METADATA +1 -1
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/RECORD +25 -19
- janus/converter.py +0 -161
- janus/translate.py +0 -987
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/LICENSE +0 -0
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/WHEEL +0 -0
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,614 @@
|
|
1
|
+
import functools
|
2
|
+
import json
|
3
|
+
import math
|
4
|
+
import time
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
from langchain.output_parsers import RetryWithErrorOutputParser
|
9
|
+
from langchain.output_parsers.fix import OutputFixingParser
|
10
|
+
from langchain_core.exceptions import OutputParserException
|
11
|
+
from langchain_core.language_models import BaseLanguageModel
|
12
|
+
from langchain_core.output_parsers import BaseOutputParser
|
13
|
+
from langchain_core.prompts import ChatPromptTemplate
|
14
|
+
from langchain_core.runnables import RunnableLambda, RunnableParallel
|
15
|
+
from openai import BadRequestError, RateLimitError
|
16
|
+
from pydantic import ValidationError
|
17
|
+
|
18
|
+
from janus.embedding.vectorize import ChromaDBVectorizer
|
19
|
+
from janus.language.block import CodeBlock, TranslatedCodeBlock
|
20
|
+
from janus.language.combine import Combiner
|
21
|
+
from janus.language.naive.registry import CUSTOM_SPLITTERS
|
22
|
+
from janus.language.splitter import (
|
23
|
+
EmptyTreeError,
|
24
|
+
FileSizeError,
|
25
|
+
Splitter,
|
26
|
+
TokenLimitError,
|
27
|
+
)
|
28
|
+
from janus.llm import load_model
|
29
|
+
from janus.llm.model_callbacks import get_model_callback
|
30
|
+
from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
31
|
+
from janus.parsers.code_parser import GenericParser
|
32
|
+
from janus.utils.enums import LANGUAGES
|
33
|
+
from janus.utils.logger import create_logger
|
34
|
+
|
35
|
+
# Module-level logger used by the Converter class for progress, warning and
# error reporting.
log = create_logger(__name__)
|
36
|
+
|
37
|
+
|
38
|
+
def run_if_changed(*tracked_vars):
    """Decorator factory that skips a method call unless tracked attributes changed.

    The decorated method runs only if at least one of ``tracked_vars`` is present in
    the instance's ``_changed_attrs`` set, or if no attributes are tracked at all.
    Otherwise the call is skipped. Requires the ``_changed_attrs`` set to exist on
    the instance, and the ``__setattr__`` method to be overridden so that attribute
    updates are recorded in ``_changed_attrs``.

    Arguments:
        tracked_vars: Names of instance attributes whose change should trigger the
            decorated method. With no names, the method always runs.

    Returns:
        A decorator for instance methods. The wrapped method returns the original
        method's return value when it runs, and ``None`` when the call is skipped.
    """

    def wrapper(func):
        @functools.wraps(func)
        def wrapped(self, *args, **kwargs):
            # If there is overlap between the tracked variables and the changed
            # ones (or nothing is tracked), call the function as normal and pass
            # its result through instead of silently dropping it
            if not tracked_vars or self._changed_attrs.intersection(tracked_vars):
                return func(self, *args, **kwargs)
            return None

        return wrapped

    return wrapper
|
55
|
+
|
56
|
+
|
57
|
+
class Converter:
    """Parent class that converts code into something else.

    Children will determine what the code gets converted into. Whether that's
    translated into another language, into pseudocode, requirements,
    documentation, etc., or converted into embeddings.

    Configuration is recorded by the ``set_*`` methods only. The expensive objects
    (model, prompt, splitter, vectorizer) are built by ``_load_parameters``, and each
    loader only re-runs when one of the attributes it depends on has changed since
    the last load (see ``run_if_changed`` and ``__setattr__``).
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo-0125",
        model_arguments: dict[str, Any] | None = None,
        source_language: str = "fortran",
        max_prompts: int = 10,
        max_tokens: int | None = None,
        prompt_template: str = "simple",
        db_path: str | None = None,
        db_config: dict[str, Any] | None = None,
        protected_node_types: tuple[str, ...] = (),
        prune_node_types: tuple[str, ...] = (),
        splitter_type: str = "file",
    ) -> None:
        """Initialize a Converter instance.

        Arguments:
            model: The name of the model to use (see `set_model`).
            model_arguments: Additional arguments for the model constructor. Note
                that `_load_model` does not currently forward them to `load_model`.
            source_language: The source programming language.
            max_prompts: Budget of prompts per block, spread over the three retry
                layers of `_run_chain`.
            max_tokens: Maximum tokens per block. If `None`, it is derived from the
                model's token limit when the model is loaded.
            prompt_template: Name of a prompt template directory
                (see janus/prompts/templates) or a path to a directory.
            db_path: Path of the embedding database, if any.
            db_config: Configuration of the embedding database, if any.
            protected_node_types: Node types the splitter should not merge.
            prune_node_types: Node types the splitter should prune.
            splitter_type: Key into `CUSTOM_SPLITTERS` selecting the splitter class.
        """
        self._changed_attrs: set = set()

        self.max_prompts: int = max_prompts
        self._max_tokens: int | None = max_tokens
        # An explicit max_tokens pins the limit; otherwise _load_model derives it
        self.override_token_limit: bool = max_tokens is not None

        self._model_name: str
        self._custom_model_arguments: dict[str, Any]

        self._source_language: str
        self._source_suffix: str

        self._target_language = "json"
        self._target_suffix = ".json"

        self._protected_node_types: tuple[str, ...] = ()
        self._prune_node_types: tuple[str, ...] = ()
        self._prompt_template_name: str
        self._splitter_type: str
        self._db_path: str | None
        self._db_config: dict[str, Any] | None

        self._splitter: Splitter
        self._llm: BaseLanguageModel
        self._prompt: ChatPromptTemplate

        self._parser: BaseOutputParser = GenericParser()
        self._combiner: Combiner = Combiner()

        self.set_splitter(splitter_type=splitter_type)
        # `None` instead of a shared mutable `{}` default
        self.set_model(model_name=model, **(model_arguments or {}))
        self.set_prompt(prompt_template=prompt_template)
        self.set_source_language(source_language)
        self.set_protected_node_types(protected_node_types)
        self.set_prune_node_types(prune_node_types)
        self.set_db_path(db_path=db_path)
        self.set_db_config(db_config=db_config)

        # Child class must call this. Should we enforce somehow?
        # self._load_parameters()

    def __setattr__(self, key: Any, value: Any) -> None:
        """Set an attribute, recording its name in `_changed_attrs` if it changed.

        An attribute counts as changed when it did not exist before, or when its
        previous value compares unequal to `value`. This bookkeeping drives the
        `run_if_changed` decorator.
        """
        if hasattr(self, "_changed_attrs"):
            if not hasattr(self, key) or getattr(self, key) != value:
                self._changed_attrs.add(key)
        # Avoid infinite recursion: create the tracking set lazily the first time
        # any other attribute is assigned
        elif key != "_changed_attrs":
            self._changed_attrs = set()
        super().__setattr__(key, value)

    def _load_parameters(self) -> None:
        """(Re)build model, prompt, splitter and vectorizer, then reset tracking.

        Each loader is skipped when its tracked attributes are unchanged.
        """
        self._load_model()
        self._load_prompt()
        self._load_splitter()
        self._load_vectorizer()
        self._changed_attrs.clear()

    def set_model(self, model_name: str, **custom_arguments: Any) -> None:
        """Validate and set the model name.

        The affected objects will not be updated until `_load_parameters()` is
        called (`translate_file()` does this).

        Arguments:
            model_name: The name of the model to use. Valid models are found in
                `janus.llm.models_info.MODEL_CONSTRUCTORS`.
            custom_arguments: Additional arguments to pass to the model constructor.
        """
        self._model_name = model_name
        self._custom_model_arguments = custom_arguments

    def set_prompt(self, prompt_template: str) -> None:
        """Validate and set the prompt template name.

        The affected objects will not be updated until `_load_parameters()` is
        called (`translate_file()` does this).

        Arguments:
            prompt_template: name of prompt template directory
                (see janus/prompts/templates) or path to a directory.
        """
        self._prompt_template_name = prompt_template

    def set_splitter(self, splitter_type: str) -> None:
        """Set the type of splitter used to break input files into blocks.

        The affected objects will not be updated until `_load_parameters()` is
        called (`translate_file()` does this).

        Arguments:
            splitter_type: Key into `janus.language.naive.registry.CUSTOM_SPLITTERS`
                selecting the splitter class.
        """
        self._splitter_type = splitter_type

    def set_source_language(self, source_language: str) -> None:
        """Validate and set the source language.

        The affected objects will not be updated until _load_parameters() is called.

        Arguments:
            source_language: The source programming language (case-insensitive).

        Raises:
            ValueError: If the language is not a key of `LANGUAGES`.
        """
        source_language = source_language.lower()
        if source_language not in LANGUAGES:
            raise ValueError(
                f"Invalid source language: {source_language}. "
                "Valid source languages are found in `janus.utils.enums.LANGUAGES`."
            )

        ext = LANGUAGES[source_language]["suffix"]
        self._source_suffix = f".{ext}"
        self._source_language = source_language

    def set_protected_node_types(self, protected_node_types: tuple[str, ...]) -> None:
        """Set the protected (non-mergeable) node types. This will often be
        structures like functions, classes, or modules which you might want to keep
        separate.

        The affected objects will not be updated until _load_parameters() is called.

        Arguments:
            protected_node_types: A set of node types that aren't to be merged.
                Duplicates are removed; `None` is treated as empty.
        """
        self._protected_node_types = tuple(set(protected_node_types or []))

    def set_prune_node_types(self, prune_node_types: tuple[str, ...]) -> None:
        """Set the node types to prune. This will often be structures like comments
        or whitespace which you might want to keep out of the LLM input.

        The affected objects will not be updated until _load_parameters() is called.

        Arguments:
            prune_node_types: A set of node types which should be pruned.
                Duplicates are removed; `None` is treated as empty.
        """
        self._prune_node_types = tuple(set(prune_node_types or []))

    def set_db_path(self, db_path: str | None) -> None:
        """Set the embedding database path (`None` disables the vectorizer)."""
        self._db_path = db_path

    def set_db_config(self, db_config: dict[str, Any] | None) -> None:
        """Set the embedding database config (`None` disables the vectorizer)."""
        self._db_config = db_config

    @run_if_changed(
        "_source_language",
        "_max_tokens",
        "_llm",
        "_protected_node_types",
        "_prune_node_types",
        # _splitter_type is what set_splitter() writes; without it, changing the
        # splitter type after the first load would never rebuild the splitter
        "_splitter_type",
        "_custom_splitter",
    )
    def _load_splitter(self) -> None:
        """Load the splitter according to this instance's attributes.

        If the relevant fields have not been changed since the last time this method
        was called, nothing happens.
        """
        kwargs: dict[str, Any] = dict(
            language=self._source_language,
            max_tokens=self._max_tokens,
            model=self._llm,
            protected_node_types=self._protected_node_types,
            prune_node_types=self._prune_node_types,
        )

        if self._splitter_type == "tag":
            kwargs["tag"] = "<ITMOD_ALC_SPLIT>"

        self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)

    @run_if_changed("_model_name", "_custom_model_arguments")
    def _load_model(self) -> None:
        """Load the model according to this instance's attributes.

        Sets `self._llm` and `self.model_cost` from `load_model`. Unless the user
        passed an explicit `max_tokens`, also sets `self._max_tokens` from the
        model's token limit.

        If the relevant fields have not been changed since the last time this method
        was called, nothing happens.
        """
        # NOTE: load_model only receives the model name, so the custom arguments
        # recorded by set_model() are not applied. Merging them with the model's
        # default arguments is still TODO; warn rather than ignore them silently.
        if self._custom_model_arguments:
            log.warning(
                f"Ignoring custom model arguments for {self._model_name}: "
                f"{sorted(self._custom_model_arguments)}"
            )

        # Load the model
        self._llm, token_limit, self.model_cost = load_model(self._model_name)
        # Set the max_tokens to less than half the model's limit to allow for enough
        # tokens at output. Only modify max_tokens if it is not specified by user
        if not self.override_token_limit:
            self._max_tokens = int(token_limit // 2.5)

    @run_if_changed(
        "_prompt_template_name",
        "_source_language",
        "_model_name",
    )
    def _load_prompt(self) -> None:
        """Load the prompt according to this instance's attributes.

        If the relevant fields have not been changed since the last time this
        method was called, nothing happens.
        """
        prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
            source_language=self._source_language,
            prompt_template=self._prompt_template_name,
        )
        self._prompt = prompt_engine.prompt

    @run_if_changed("_db_path", "_db_config")
    def _load_vectorizer(self) -> None:
        """Build the ChromaDB vectorizer, or set it to `None` when either the DB
        path or the DB config is missing.

        If the relevant fields have not been changed since the last time this
        method was called, nothing happens.
        """
        if self._db_path is None or self._db_config is None:
            self._vectorizer = None
            return
        vectorizer_factory = ChromaDBVectorizer()
        self._vectorizer = vectorizer_factory.create_vectorizer(
            self._db_path, self._db_config
        )

    def translate(
        self,
        input_directory: str | Path,
        output_directory: str | Path | None = None,
        overwrite: bool = False,
        collection_name: str | None = None,
    ) -> None:
        """Convert code in the input directory from the source language to the target
        language, and write the resulting files to the output directory.

        Files are searched recursively. Files that fail with a recoverable error
        (rate limit, unparseable output, malformed input, oversized input, empty
        tree) are logged and skipped.

        Arguments:
            input_directory: The directory containing the code to translate.
            output_directory: The directory to write the translated code to. If
                `None`, nothing is written to disk.
            overwrite: Whether to overwrite existing files (vs skip them)
            collection_name: Collection to add to. Requires both `db_path` and
                `db_config` to be set.

        Raises:
            ValueError: If `collection_name` is given but no vectorizer is
                configured.
        """
        # Convert paths to pathlib Paths if needed
        if isinstance(input_directory, str):
            input_directory = Path(input_directory)
        if isinstance(output_directory, str):
            output_directory = Path(output_directory)

        if collection_name is not None:
            # Fail fast, before any paid LLM calls, if embeddings were requested
            # but no vectorizer can be built from db_path / db_config
            self._load_parameters()
            if getattr(self, "_vectorizer", None) is None:
                raise ValueError(
                    f"Cannot add to collection '{collection_name}': no vectorizer "
                    "is configured. Set both db_path and db_config."
                )

        # Make sure the output directory exists
        if output_directory is not None:
            output_directory.mkdir(parents=True, exist_ok=True)

        # rglob already searches recursively, so no leading "**/" is needed
        input_paths = list(input_directory.rglob(f"*{self._source_suffix}"))
        # Count files recursively, like rglob above, so "skipped" can't go negative
        n_all_files = sum(1 for p in input_directory.rglob("*") if p.is_file())

        log.info(f"Input directory: {input_directory.absolute()}")
        log.info(
            f"{self._source_language} '*{self._source_suffix}' files: "
            f"{len(input_paths)}"
        )
        log.info(f"Other files (skipped): {n_all_files - len(input_paths)}\n")

        in_out_pairs: list[tuple[Path, Path | None]]
        if output_directory is not None:
            output_paths = [
                output_directory
                / p.relative_to(input_directory).with_suffix(self._target_suffix)
                for p in input_paths
            ]
            in_out_pairs = list(zip(input_paths, output_paths))
            if not overwrite:
                n_files = len(in_out_pairs)
                in_out_pairs = [
                    (inp, outp) for inp, outp in in_out_pairs if not outp.exists()
                ]
                log.info(
                    f"Skipping {n_files - len(in_out_pairs)} existing "
                    f"'*{self._source_suffix}' files"
                )
        else:
            in_out_pairs = [(f, None) for f in input_paths]
        log.info(f"Translating {len(in_out_pairs)} '*{self._source_suffix}' files")

        # Loop through each input file, convert and save it
        total_cost = 0.0
        for in_path, out_path in in_out_pairs:
            # Translate the file, skip it on recoverable errors
            try:
                out_block = self.translate_file(in_path)
                total_cost += out_block.total_cost
            except RateLimitError:
                log.warning(f"Skipping {in_path.name}, hit the model's rate limit")
                continue
            except OutputParserException as e:
                log.error(f"Skipping {in_path.name}, failed to parse output: {e}.")
                continue
            except BadRequestError as e:
                if str(e).startswith("Detected an error in the prompt"):
                    log.warning("Malformed input, skipping")
                    continue
                raise
            except ValidationError:
                # Only allow ValidationError to pass if token limit is manually set
                if self.override_token_limit:
                    log.warning(
                        "Current file and manually set token "
                        "limit is too large for this model, skipping"
                    )
                    continue
                raise
            except TokenLimitError:
                log.warning("Ran into irreducible node too large for context, skipping")
                continue
            except EmptyTreeError:
                log.warning(
                    f'Input file "{in_path.name}" has no nodes of interest, skipping'
                )
                continue
            except FileSizeError:
                log.warning("Current file is too large for basic splitter, skipping")
                continue

            # Don't attempt to write files for which translation failed
            if not out_block.translated:
                continue

            if collection_name is not None:
                self._vectorizer.add_nodes_recursively(
                    out_block,
                    collection_name,
                    in_path.name,
                )

            # Make sure the tree's code has been consolidated at the top level
            # before writing to file
            self._combiner.combine(out_block)
            if out_path is not None and (overwrite or not out_path.exists()):
                self._save_to_file(out_block, out_path)

        log.info(f"Total cost: ${total_cost:,.2f}")

    def translate_file(self, file: Path) -> TranslatedCodeBlock:
        """Translate a single file.

        Arguments:
            file: Input path to file

        Returns:
            A `TranslatedCodeBlock` object. This block does not have a path set, and its
            code is not guaranteed to be consolidated. To amend this, run
            `Combiner.combine` on the block.
        """
        self._load_parameters()
        filename = file.name

        input_block = self._split_file(file)
        t0 = time.time()
        output_block = self._iterative_translate(input_block)
        output_block.processing_time = time.time() - t0
        if output_block.translated:
            completeness = output_block.translation_completeness
            log.info(
                f"[{filename}] Translation complete\n"
                f"  {completeness:.2%} of input successfully translated\n"
                f"  Total cost: ${output_block.total_cost:,.2f}\n"
                f"  Total retries: {output_block.total_retries:,d}\n"
                f"  Output CodeBlock Structure:\n{output_block.tree_str()}\n"
            )

        else:
            log.error(
                f"[{filename}] Translation failed\n"
                f"  Total cost: ${output_block.total_cost:,.2f}\n"
                f"  Total retries: {output_block.total_retries:,d}\n"
            )
        return output_block

    def _iterative_translate(self, root: CodeBlock) -> TranslatedCodeBlock:
        """Translate the passed CodeBlock representing a full file.

        Blocks are processed depth-first with an explicit stack; children of a
        block whose translation failed are not visited.

        Arguments:
            root: A root block representing the top-level block of a file

        Returns:
            A `TranslatedCodeBlock`
        """
        translated_root = TranslatedCodeBlock(root, self._target_language)
        last_prog, prog_delta = 0, 0.1
        stack = [translated_root]
        while stack:
            translated_block = stack.pop()

            self._add_translation(translated_block)

            # If translating this block was unsuccessful, don't bother with its
            # children (they wouldn't show up in the final text anyway)
            if not translated_block.translated:
                continue

            stack.extend(translated_block.children)

            # Log progress only when it advanced by more than prog_delta
            progress = translated_root.translation_completeness
            if progress - last_prog > prog_delta:
                last_prog = int(progress / prog_delta) * prog_delta
                log.info(f"[{root.name}] progress: {progress:.2%}")

        return translated_root

    def _add_translation(self, block: TranslatedCodeBlock) -> None:
        """Given an "empty" `TranslatedCodeBlock`, translate the code represented in
        `block.original`, setting the relevant fields in the translated block. The
        `TranslatedCodeBlock` is updated in-place, nothing is returned. Note that this
        translates *only* the code for this block, not its children.

        Arguments:
            block: An empty `TranslatedCodeBlock`

        Raises:
            ValueError: If no model is configured.
        """
        if block.translated:
            return

        # Nothing to send to the model; mark as done
        if block.original.text is None:
            block.translated = True
            return

        if self._llm is None:
            message = (
                "Model not configured correctly, cannot translate. Try setting "
                "the model"
            )
            log.error(message)
            raise ValueError(message)

        log.debug(f"[{block.name}] Translating...")
        log.debug(f"[{block.name}] Input text:\n{block.original.text}")

        # Track the cost of translating this block
        # TODO: If non-OpenAI models with prices are added, this will need
        # to be updated.
        with get_model_callback() as cb:
            t0 = time.time()
            block.text = self._run_chain(block)
            block.processing_time = time.time() - t0
            block.cost = cb.total_cost
            # Every request beyond the first one counts as a retry
            block.retries = max(0, cb.successful_requests - 1)

        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

        log.debug(f"[{block.name}] Output code:\n{block.text}")

    def _split_file(self, file: Path) -> CodeBlock:
        """Split `file` into a tree of `CodeBlock`s with the configured splitter."""
        filename = file.name
        log.info(f"[{filename}] Splitting file")
        root = self._splitter.split(file)
        log.info(
            f"[{filename}] File split into {root.n_descendents:,} blocks, "
            f"tree of height {root.height}"
        )
        log.info(f"[{filename}] Input CodeBlock Structure:\n{root.tree_str()}")
        return root

    def _run_chain(self, block: TranslatedCodeBlock) -> str:
        """Run the model with three nested error fixing schemes.

        First, try to fix simple formatting errors by giving the model just
        the output and the parsing error. After a number of attempts, try
        giving the model the output, the parsing error, and the original
        input. Again check/retry this output to solve for formatting errors.
        If we still haven't succeeded after several attempts, the model may
        be getting thrown off by a bad initial output; start from scratch
        and try again.

        The number of tries for each layer of this scheme is roughly the cube
        root of `self.max_prompts`; the product of the three (n1 * n2 * n3) is
        at least `self.max_prompts` and usually close to it. Values of
        `self.max_prompts` below 1 are treated as 1.

        Arguments:
            block: The block whose `original.text` is sent to the model.

        Returns:
            The parsed model output.

        Raises:
            OutputParserException: If no attempt produced parseable output.
        """
        self._parser.set_reference(block.original)

        # Guard against 0 (division by zero below) and negative values (a
        # fractional power of a negative number yields a complex number)
        max_prompts = max(1, self.max_prompts)

        # Retries with just the output and the error
        n1 = round(max_prompts ** (1 / 3))

        # Retries with the input, output, and error
        n2 = round((max_prompts // n1) ** (1 / 2))

        # Retries with just the input
        n3 = math.ceil(max_prompts / (n1 * n2))

        fix_format = OutputFixingParser.from_llm(
            llm=self._llm,
            parser=self._parser,
            max_retries=n1,
        )
        retry = RetryWithErrorOutputParser.from_llm(
            llm=self._llm,
            parser=fix_format,
            max_retries=n2,
        )

        # Feed both the raw completion and the rendered prompt to the retry parser
        completion_chain = self._prompt | self._llm
        chain = RunnableParallel(
            completion=completion_chain, prompt_value=self._prompt
        ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))

        for _ in range(n3):
            try:
                return chain.invoke({"SOURCE_CODE": block.original.text})
            except OutputParserException:
                pass

        raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")

    def _get_output_obj(
        self, block: TranslatedCodeBlock
    ) -> dict[str, int | float | str | dict[str, str]]:
        """Build the object written to disk for a translated block.

        The combined output is decoded as JSON when possible; otherwise the raw
        string is kept.
        """
        output_str = self._parser.parse_combined_output(block.complete_text)

        output: str | dict[str, str]
        try:
            output = json.loads(output_str)
        except json.JSONDecodeError:
            output = output_str

        return dict(
            input=block.original.text,
            metadata=dict(
                retries=block.total_retries,
                cost=block.total_cost,
                processing_time=block.processing_time,
            ),
            output=output,
        )

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Save a file to disk as indented JSON.

        Arguments:
            block: The `TranslatedCodeBlock` to save to a file.
            out_path: Destination path; missing parent directories are created.
        """
        obj = self._get_output_obj(block)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
|