janus-llm 3.5.2__py3-none-any.whl → 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +1 -1
- janus/cli.py +90 -42
- janus/converter/converter.py +111 -142
- janus/converter/diagram.py +21 -109
- janus/converter/translate.py +1 -1
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +16 -11
- janus/language/binary/_tests/test_binary.py +1 -1
- janus/language/binary/binary.py +2 -2
- janus/language/mumps/_tests/test_mumps.py +1 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/naive/simple_ast.py +3 -2
- janus/language/splitter.py +7 -4
- janus/language/treesitter/_tests/test_treesitter.py +1 -1
- janus/language/treesitter/treesitter.py +2 -2
- janus/llm/model_callbacks.py +13 -0
- janus/llm/models_info.py +118 -71
- janus/metrics/metric.py +15 -14
- janus/parsers/uml.py +60 -23
- janus/refiners/refiner.py +106 -64
- janus/retrievers/retriever.py +42 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/METADATA +1 -1
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/RECORD +26 -26
- janus/parsers/refiner_parser.py +0 -46
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/WHEEL +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/entry_points.txt +0 -0
janus/__init__.py
CHANGED
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
|
|
5
5
|
from janus.converter.translate import Translator
|
6
6
|
from janus.metrics import * # noqa: F403
|
7
7
|
|
8
|
-
__version__ = "3.5.2"
|
8
|
+
__version__ = "4.0.0"
|
9
9
|
|
10
10
|
# Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
|
11
11
|
warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
|
janus/cli.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import json
|
2
2
|
import logging
|
3
3
|
import os
|
4
|
+
import subprocess # nosec
|
4
5
|
from pathlib import Path
|
5
6
|
from typing import List, Optional
|
6
7
|
|
@@ -42,6 +43,7 @@ from janus.llm.models_info import (
|
|
42
43
|
openai_models,
|
43
44
|
)
|
44
45
|
from janus.metrics.cli import evaluate
|
46
|
+
from janus.refiners.refiner import REFINERS
|
45
47
|
from janus.utils.enums import LANGUAGES
|
46
48
|
from janus.utils.logger import create_logger
|
47
49
|
|
@@ -241,6 +243,24 @@ def translate(
|
|
241
243
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
242
244
|
),
|
243
245
|
] = "file",
|
246
|
+
refiner_type: Annotated[
|
247
|
+
str,
|
248
|
+
typer.Option(
|
249
|
+
"-r",
|
250
|
+
"--refiner",
|
251
|
+
help="Name of custom refiner to use",
|
252
|
+
click_type=click.Choice(list(REFINERS.keys())),
|
253
|
+
),
|
254
|
+
] = "none",
|
255
|
+
retriever_type: Annotated[
|
256
|
+
str,
|
257
|
+
typer.Option(
|
258
|
+
"-R",
|
259
|
+
"--retriever",
|
260
|
+
help="Name of custom retriever to use",
|
261
|
+
click_type=click.Choice(["active_usings"]),
|
262
|
+
),
|
263
|
+
] = None,
|
244
264
|
max_tokens: Annotated[
|
245
265
|
int,
|
246
266
|
typer.Option(
|
@@ -250,13 +270,6 @@ def translate(
|
|
250
270
|
"If unspecificed, model's default max will be used.",
|
251
271
|
),
|
252
272
|
] = None,
|
253
|
-
skip_refiner: Annotated[
|
254
|
-
bool,
|
255
|
-
typer.Option(
|
256
|
-
"--skip-refiner",
|
257
|
-
help="Whether to skip the refiner for generating output",
|
258
|
-
),
|
259
|
-
] = True,
|
260
273
|
):
|
261
274
|
try:
|
262
275
|
target_language, target_version = target_lang.split("-")
|
@@ -282,8 +295,8 @@ def translate(
|
|
282
295
|
db_path=db_loc,
|
283
296
|
db_config=collections_config,
|
284
297
|
splitter_type=splitter_type,
|
285
|
-
|
286
|
-
|
298
|
+
refiner_type=refiner_type,
|
299
|
+
retriever_type=retriever_type,
|
287
300
|
)
|
288
301
|
translator.translate(input_dir, output_dir, overwrite, collection)
|
289
302
|
|
@@ -341,14 +354,6 @@ def document(
|
|
341
354
|
help="Whether to overwrite existing files in the output directory",
|
342
355
|
),
|
343
356
|
] = False,
|
344
|
-
skip_context: Annotated[
|
345
|
-
bool,
|
346
|
-
typer.Option(
|
347
|
-
"--skip-context",
|
348
|
-
help="Prompts will include any context information associated with source"
|
349
|
-
" code blocks, unless this option is specified",
|
350
|
-
),
|
351
|
-
] = False,
|
352
357
|
doc_mode: Annotated[
|
353
358
|
str,
|
354
359
|
typer.Option(
|
@@ -396,6 +401,24 @@ def document(
|
|
396
401
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
397
402
|
),
|
398
403
|
] = "file",
|
404
|
+
refiner_type: Annotated[
|
405
|
+
str,
|
406
|
+
typer.Option(
|
407
|
+
"-r",
|
408
|
+
"--refiner",
|
409
|
+
help="Name of custom refiner to use",
|
410
|
+
click_type=click.Choice(list(REFINERS.keys())),
|
411
|
+
),
|
412
|
+
] = "none",
|
413
|
+
retriever_type: Annotated[
|
414
|
+
str,
|
415
|
+
typer.Option(
|
416
|
+
"-R",
|
417
|
+
"--retriever",
|
418
|
+
help="Name of custom retriever to use",
|
419
|
+
click_type=click.Choice(["active_usings"]),
|
420
|
+
),
|
421
|
+
] = None,
|
399
422
|
max_tokens: Annotated[
|
400
423
|
int,
|
401
424
|
typer.Option(
|
@@ -405,13 +428,6 @@ def document(
|
|
405
428
|
"If unspecificed, model's default max will be used.",
|
406
429
|
),
|
407
430
|
] = None,
|
408
|
-
skip_refiner: Annotated[
|
409
|
-
bool,
|
410
|
-
typer.Option(
|
411
|
-
"--skip-refiner",
|
412
|
-
help="Whether to skip the refiner for generating output",
|
413
|
-
),
|
414
|
-
] = True,
|
415
431
|
):
|
416
432
|
model_arguments = dict(temperature=temperature)
|
417
433
|
collections_config = get_collections_config()
|
@@ -424,8 +440,8 @@ def document(
|
|
424
440
|
db_path=db_loc,
|
425
441
|
db_config=collections_config,
|
426
442
|
splitter_type=splitter_type,
|
427
|
-
|
428
|
-
|
443
|
+
refiner_type=refiner_type,
|
444
|
+
retriever_type=retriever_type,
|
429
445
|
)
|
430
446
|
if doc_mode == "madlibs":
|
431
447
|
documenter = MadLibsDocumenter(
|
@@ -614,14 +630,6 @@ def diagram(
|
|
614
630
|
help="Whether to overwrite existing files in the output directory",
|
615
631
|
),
|
616
632
|
] = False,
|
617
|
-
skip_context: Annotated[
|
618
|
-
bool,
|
619
|
-
typer.Option(
|
620
|
-
"--skip-context",
|
621
|
-
help="Prompts will include any context information associated with source"
|
622
|
-
" code blocks, unless this option is specified",
|
623
|
-
),
|
624
|
-
] = False,
|
625
633
|
temperature: Annotated[
|
626
634
|
float,
|
627
635
|
typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
|
@@ -658,13 +666,24 @@ def diagram(
|
|
658
666
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
659
667
|
),
|
660
668
|
] = "file",
|
661
|
-
|
662
|
-
|
669
|
+
refiner_type: Annotated[
|
670
|
+
str,
|
663
671
|
typer.Option(
|
664
|
-
"
|
665
|
-
|
672
|
+
"-r",
|
673
|
+
"--refiner",
|
674
|
+
help="Name of custom refiner to use",
|
675
|
+
click_type=click.Choice(list(REFINERS.keys())),
|
666
676
|
),
|
667
|
-
] =
|
677
|
+
] = "none",
|
678
|
+
retriever_type: Annotated[
|
679
|
+
str,
|
680
|
+
typer.Option(
|
681
|
+
"-R",
|
682
|
+
"--retriever",
|
683
|
+
help="Name of custom retriever to use",
|
684
|
+
click_type=click.Choice(["active_usings"]),
|
685
|
+
),
|
686
|
+
] = None,
|
668
687
|
):
|
669
688
|
model_arguments = dict(temperature=temperature)
|
670
689
|
collections_config = get_collections_config()
|
@@ -675,11 +694,11 @@ def diagram(
|
|
675
694
|
max_prompts=max_prompts,
|
676
695
|
db_path=db_loc,
|
677
696
|
db_config=collections_config,
|
697
|
+
splitter_type=splitter_type,
|
698
|
+
refiner_type=refiner_type,
|
699
|
+
retriever_type=retriever_type,
|
678
700
|
diagram_type=diagram_type,
|
679
701
|
add_documentation=add_documentation,
|
680
|
-
splitter_type=splitter_type,
|
681
|
-
skip_refiner=skip_refiner,
|
682
|
-
skip_context=skip_context,
|
683
702
|
)
|
684
703
|
diagram_generator.translate(input_dir, output_dir, overwrite, collection)
|
685
704
|
|
@@ -1156,5 +1175,34 @@ app.add_typer(evaluate, name="evaluate")
|
|
1156
1175
|
app.add_typer(embedding, name="embedding")
|
1157
1176
|
|
1158
1177
|
|
1178
|
+
@app.command()
|
1179
|
+
def render(
|
1180
|
+
input_dir: Annotated[
|
1181
|
+
str,
|
1182
|
+
typer.Option(
|
1183
|
+
"--input",
|
1184
|
+
"-i",
|
1185
|
+
),
|
1186
|
+
],
|
1187
|
+
output_dir: Annotated[str, typer.Option("--output", "-o")],
|
1188
|
+
):
|
1189
|
+
input_dir = Path(input_dir)
|
1190
|
+
output_dir = Path(output_dir)
|
1191
|
+
for input_file in input_dir.rglob("*.json"):
|
1192
|
+
with open(input_file, "r") as f:
|
1193
|
+
data = json.load(f)
|
1194
|
+
|
1195
|
+
output_file = output_dir / input_file.relative_to(input_dir).with_suffix(".txt")
|
1196
|
+
if not output_file.parent.exists():
|
1197
|
+
output_file.parent.mkdir()
|
1198
|
+
|
1199
|
+
text = data["output"].replace("\\n", "\n").strip()
|
1200
|
+
output_file.write_text(text)
|
1201
|
+
|
1202
|
+
jar_path = homedir / ".janus/lib/plantuml.jar"
|
1203
|
+
subprocess.run(["java", "-jar", jar_path, output_file]) # nosec
|
1204
|
+
output_file.unlink()
|
1205
|
+
|
1206
|
+
|
1159
1207
|
if __name__ == "__main__":
|
1160
1208
|
app()
|
janus/converter/converter.py
CHANGED
@@ -2,13 +2,11 @@ import functools
|
|
2
2
|
import json
|
3
3
|
import time
|
4
4
|
from pathlib import Path
|
5
|
-
from typing import Any, List, Optional, Tuple
|
5
|
+
from typing import Any
|
6
6
|
|
7
|
-
from langchain.output_parsers import RetryWithErrorOutputParser
|
8
7
|
from langchain_core.exceptions import OutputParserException
|
9
|
-
from langchain_core.
|
10
|
-
from langchain_core.
|
11
|
-
from langchain_core.runnables import RunnableLambda, RunnableParallel
|
8
|
+
from langchain_core.prompts import ChatPromptTemplate
|
9
|
+
from langchain_core.runnables import Runnable, RunnableParallel, RunnablePassthrough
|
12
10
|
from openai import BadRequestError, RateLimitError
|
13
11
|
from pydantic import ValidationError
|
14
12
|
|
@@ -22,12 +20,18 @@ from janus.language.splitter import (
|
|
22
20
|
Splitter,
|
23
21
|
TokenLimitError,
|
24
22
|
)
|
25
|
-
from janus.llm import load_model
|
26
23
|
from janus.llm.model_callbacks import get_model_callback
|
27
|
-
from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
24
|
+
from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel, load_model
|
28
25
|
from janus.parsers.parser import GenericParser, JanusParser
|
29
|
-
from janus.
|
30
|
-
|
26
|
+
from janus.refiners.refiner import (
|
27
|
+
FixParserExceptions,
|
28
|
+
HallucinationRefiner,
|
29
|
+
JanusRefiner,
|
30
|
+
ReflectionRefiner,
|
31
|
+
)
|
32
|
+
|
33
|
+
# from janus.refiners.refiner import BasicRefiner, Refiner
|
34
|
+
from janus.retrievers.retriever import ActiveUsingsRetriever, JanusRetriever
|
31
35
|
from janus.utils.enums import LANGUAGES
|
32
36
|
from janus.utils.logger import create_logger
|
33
37
|
|
@@ -74,9 +78,8 @@ class Converter:
|
|
74
78
|
protected_node_types: tuple[str, ...] = (),
|
75
79
|
prune_node_types: tuple[str, ...] = (),
|
76
80
|
splitter_type: str = "file",
|
77
|
-
refiner_type: str =
|
78
|
-
|
79
|
-
skip_context: bool = False,
|
81
|
+
refiner_type: str | None = None,
|
82
|
+
retriever_type: str | None = None,
|
80
83
|
) -> None:
|
81
84
|
"""Initialize a Converter instance.
|
82
85
|
|
@@ -96,9 +99,13 @@ class Converter:
|
|
96
99
|
prune_node_types: A set of node types which should be pruned.
|
97
100
|
splitter_type: The type of splitter to use. Valid values are `"file"`,
|
98
101
|
`"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
|
99
|
-
refiner_type: The type of refiner to use. Valid values
|
100
|
-
|
101
|
-
|
102
|
+
refiner_type: The type of refiner to use. Valid values:
|
103
|
+
- "parser"
|
104
|
+
- "reflection"
|
105
|
+
- None
|
106
|
+
retriever_type: The type of retriever to use. Valid values:
|
107
|
+
- "active_usings"
|
108
|
+
- None
|
102
109
|
"""
|
103
110
|
self._changed_attrs: set = set()
|
104
111
|
|
@@ -107,7 +114,6 @@ class Converter:
|
|
107
114
|
self.override_token_limit: bool = max_tokens is not None
|
108
115
|
|
109
116
|
self._model_name: str
|
110
|
-
self._model_id: str
|
111
117
|
self._custom_model_arguments: dict[str, Any]
|
112
118
|
|
113
119
|
self._source_language: str
|
@@ -120,24 +126,26 @@ class Converter:
|
|
120
126
|
self._prune_node_types: tuple[str, ...] = ()
|
121
127
|
self._max_tokens: int | None = max_tokens
|
122
128
|
self._prompt_template_name: str
|
123
|
-
self._splitter_type: str
|
124
129
|
self._db_path: str | None
|
125
130
|
self._db_config: dict[str, Any] | None
|
126
131
|
|
127
|
-
self.
|
128
|
-
self._llm: BaseLanguageModel
|
132
|
+
self._llm: JanusModel
|
129
133
|
self._prompt: ChatPromptTemplate
|
130
134
|
|
131
135
|
self._parser: JanusParser = GenericParser()
|
132
136
|
self._combiner: Combiner = Combiner()
|
133
137
|
|
134
|
-
self.
|
135
|
-
self.
|
138
|
+
self._splitter_type: str
|
139
|
+
self._refiner_type: str | None
|
140
|
+
self._retriever_type: str | None
|
136
141
|
|
137
|
-
self.
|
142
|
+
self._splitter: Splitter
|
143
|
+
self._refiner: JanusRefiner
|
144
|
+
self._retriever: JanusRetriever
|
138
145
|
|
139
146
|
self.set_splitter(splitter_type=splitter_type)
|
140
147
|
self.set_refiner(refiner_type=refiner_type)
|
148
|
+
self.set_retriever(retriever_type=retriever_type)
|
141
149
|
self.set_model(model_name=model, **model_arguments)
|
142
150
|
self.set_prompt(prompt_template=prompt_template)
|
143
151
|
self.set_source_language(source_language)
|
@@ -146,8 +154,6 @@ class Converter:
|
|
146
154
|
self.set_db_path(db_path=db_path)
|
147
155
|
self.set_db_config(db_config=db_config)
|
148
156
|
|
149
|
-
self.skip_context = skip_context
|
150
|
-
|
151
157
|
# Child class must call this. Should we enforce somehow?
|
152
158
|
# self._load_parameters()
|
153
159
|
|
@@ -163,9 +169,11 @@ class Converter:
|
|
163
169
|
def _load_parameters(self) -> None:
|
164
170
|
self._load_model()
|
165
171
|
self._load_prompt()
|
172
|
+
self._load_retriever()
|
173
|
+
self._load_refiner()
|
166
174
|
self._load_splitter()
|
167
175
|
self._load_vectorizer()
|
168
|
-
self.
|
176
|
+
self._load_chain()
|
169
177
|
self._changed_attrs.clear()
|
170
178
|
|
171
179
|
def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
|
@@ -184,8 +192,6 @@ class Converter:
|
|
184
192
|
def set_prompt(self, prompt_template: str) -> None:
|
185
193
|
"""Validate and set the prompt template name.
|
186
194
|
|
187
|
-
The affected objects will not be updated until translate() is called.
|
188
|
-
|
189
195
|
Arguments:
|
190
196
|
prompt_template: name of prompt template directory
|
191
197
|
(see janus/prompts/templates) or path to a directory.
|
@@ -195,29 +201,34 @@ class Converter:
|
|
195
201
|
def set_splitter(self, splitter_type: str) -> None:
|
196
202
|
"""Validate and set the prompt template name.
|
197
203
|
|
198
|
-
The affected objects will not be updated until translate() is called.
|
199
|
-
|
200
204
|
Arguments:
|
201
205
|
prompt_template: name of prompt template directory
|
202
206
|
(see janus/prompts/templates) or path to a directory.
|
203
207
|
"""
|
204
|
-
|
208
|
+
if splitter_type not in CUSTOM_SPLITTERS:
|
209
|
+
raise ValueError(f'Splitter type "{splitter_type}" does not exist.')
|
205
210
|
|
206
|
-
|
207
|
-
"""Validate and set the refiner name
|
211
|
+
self._splitter_type = splitter_type
|
208
212
|
|
209
|
-
|
213
|
+
def set_refiner(self, refiner_type: str | None) -> None:
|
214
|
+
"""Validate and set the refiner type
|
210
215
|
|
211
216
|
Arguments:
|
212
|
-
refiner_type: the
|
217
|
+
refiner_type: the type of refiner to use
|
213
218
|
"""
|
214
219
|
self._refiner_type = refiner_type
|
215
220
|
|
221
|
+
def set_retriever(self, retriever_type: str | None) -> None:
|
222
|
+
"""Validate and set the retriever type
|
223
|
+
|
224
|
+
Arguments:
|
225
|
+
retriever_type: the type of retriever to use
|
226
|
+
"""
|
227
|
+
self._retriever_type = retriever_type
|
228
|
+
|
216
229
|
def set_source_language(self, source_language: str) -> None:
|
217
230
|
"""Validate and set the source language.
|
218
231
|
|
219
|
-
The affected objects will not be updated until _load_parameters() is called.
|
220
|
-
|
221
232
|
Arguments:
|
222
233
|
source_language: The source programming language.
|
223
234
|
"""
|
@@ -287,20 +298,6 @@ class Converter:
|
|
287
298
|
|
288
299
|
self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)
|
289
300
|
|
290
|
-
@run_if_changed("_refiner_type", "_model_name")
|
291
|
-
def _load_refiner(self) -> None:
|
292
|
-
"""Load the refiner according to this instance's attributes.
|
293
|
-
|
294
|
-
If the relevant fields have not been changed since the last time this method was
|
295
|
-
called, nothing happens.
|
296
|
-
"""
|
297
|
-
if self._refiner_type == "basic":
|
298
|
-
self._refiner = BasicRefiner(
|
299
|
-
"basic_refinement", self._model_id, self._source_language
|
300
|
-
)
|
301
|
-
else:
|
302
|
-
raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
|
303
|
-
|
304
301
|
@run_if_changed("_model_name", "_custom_model_arguments")
|
305
302
|
def _load_model(self) -> None:
|
306
303
|
"""Load the model according to this instance's attributes.
|
@@ -314,9 +311,9 @@ class Converter:
|
|
314
311
|
# model_arguments.update(self._custom_model_arguments)
|
315
312
|
|
316
313
|
# Load the model
|
317
|
-
self._llm
|
318
|
-
|
319
|
-
|
314
|
+
self._llm = load_model(self._model_name)
|
315
|
+
token_limit = self._llm.token_limit
|
316
|
+
|
320
317
|
# Set the max_tokens to less than half the model's limit to allow for enough
|
321
318
|
# tokens at output
|
322
319
|
# Only modify max_tokens if it is not specified by user
|
@@ -335,7 +332,7 @@ class Converter:
|
|
335
332
|
If the relevant fields have not been changed since the last time this
|
336
333
|
method was called, nothing happens.
|
337
334
|
"""
|
338
|
-
prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
|
335
|
+
prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
|
339
336
|
source_language=self._source_language,
|
340
337
|
prompt_template=self._prompt_template_name,
|
341
338
|
)
|
@@ -354,6 +351,59 @@ class Converter:
|
|
354
351
|
self._db_path, self._db_config
|
355
352
|
)
|
356
353
|
|
354
|
+
@run_if_changed("_retriever_type")
|
355
|
+
def _load_retriever(self):
|
356
|
+
if self._retriever_type == "active_usings":
|
357
|
+
self._retriever = ActiveUsingsRetriever()
|
358
|
+
else:
|
359
|
+
self._retriever = JanusRetriever()
|
360
|
+
|
361
|
+
@run_if_changed("_refiner_type", "_model_name", "max_prompts", "_parser", "_llm")
|
362
|
+
def _load_refiner(self) -> None:
|
363
|
+
"""Load the refiner according to this instance's attributes.
|
364
|
+
|
365
|
+
If the relevant fields have not been changed since the last time this method was
|
366
|
+
called, nothing happens.
|
367
|
+
"""
|
368
|
+
if self._refiner_type == "parser":
|
369
|
+
self._refiner = FixParserExceptions(
|
370
|
+
llm=self._llm,
|
371
|
+
parser=self._parser,
|
372
|
+
max_retries=self.max_prompts,
|
373
|
+
)
|
374
|
+
elif self._refiner_type == "reflection":
|
375
|
+
self._refiner = ReflectionRefiner(
|
376
|
+
llm=self._llm,
|
377
|
+
parser=self._parser,
|
378
|
+
max_retries=self.max_prompts,
|
379
|
+
)
|
380
|
+
elif self._refiner_type == "hallucination":
|
381
|
+
self._refiner = HallucinationRefiner(
|
382
|
+
llm=self._llm,
|
383
|
+
parser=self._parser,
|
384
|
+
max_retries=self.max_prompts,
|
385
|
+
)
|
386
|
+
else:
|
387
|
+
self._refiner = JanusRefiner(parser=self._parser)
|
388
|
+
|
389
|
+
@run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner")
|
390
|
+
def _load_chain(self):
|
391
|
+
self.chain = (
|
392
|
+
self._input_runnable()
|
393
|
+
| self._prompt
|
394
|
+
| RunnableParallel(
|
395
|
+
completion=self._llm,
|
396
|
+
prompt_value=RunnablePassthrough(),
|
397
|
+
)
|
398
|
+
| self._refiner.parse_runnable
|
399
|
+
)
|
400
|
+
|
401
|
+
def _input_runnable(self) -> Runnable:
|
402
|
+
return RunnableParallel(
|
403
|
+
SOURCE_CODE=self._parser.parse_input,
|
404
|
+
context=self._retriever,
|
405
|
+
)
|
406
|
+
|
357
407
|
def translate(
|
358
408
|
self,
|
359
409
|
input_directory: str | Path,
|
@@ -598,110 +648,29 @@ class Converter:
|
|
598
648
|
return root
|
599
649
|
|
600
650
|
def _run_chain(self, block: TranslatedCodeBlock) -> str:
|
601
|
-
|
602
|
-
First, try to fix simple formatting errors by giving the model just
|
603
|
-
the output and the parsing error. After a number of attempts, try
|
604
|
-
giving the model the output, the parsing error, and the original
|
605
|
-
input. Again check/retry this output to solve for formatting errors.
|
606
|
-
If we still haven't succeeded after several attempts, the model may
|
607
|
-
be getting thrown off by a bad initial output; start from scratch
|
608
|
-
and try again.
|
609
|
-
|
610
|
-
The number of tries for each layer of this scheme is roughly equal
|
611
|
-
to the cube root of self.max_retries, so the total calls to the
|
612
|
-
LLM will be roughly as expected (up to sqrt(self.max_retries) over)
|
613
|
-
"""
|
614
|
-
input = self._parser.parse_input(block.original)
|
615
|
-
|
616
|
-
# Retries with just the output and the error
|
617
|
-
n1 = round(self.max_prompts ** (1 / 2))
|
618
|
-
|
619
|
-
# Retries with the input, output, and error
|
620
|
-
n2 = round(self.max_prompts // n1)
|
621
|
-
|
622
|
-
if not self.skip_context:
|
623
|
-
self._make_prompt_additions(block)
|
624
|
-
if not self.skip_refiner: # Make replacements in the prompt
|
625
|
-
refine_output = RefinerParser(
|
626
|
-
parser=self._parser,
|
627
|
-
initial_prompt=self._prompt.format(**{"SOURCE_CODE": input}),
|
628
|
-
refiner=self._refiner,
|
629
|
-
max_retries=n1,
|
630
|
-
llm=self._llm,
|
631
|
-
)
|
632
|
-
else:
|
633
|
-
refine_output = RetryWithErrorOutputParser.from_llm(
|
634
|
-
llm=self._llm,
|
635
|
-
parser=self._parser,
|
636
|
-
max_retries=n1,
|
637
|
-
)
|
638
|
-
|
639
|
-
completion_chain = self._prompt | self._llm
|
640
|
-
chain = RunnableParallel(
|
641
|
-
completion=completion_chain, prompt_value=self._prompt
|
642
|
-
) | RunnableLambda(lambda x: refine_output.parse_with_prompt(**x))
|
643
|
-
for _ in range(n2):
|
644
|
-
try:
|
645
|
-
return chain.invoke({"SOURCE_CODE": input})
|
646
|
-
except OutputParserException:
|
647
|
-
pass
|
648
|
-
|
649
|
-
raise OutputParserException(f"Failed to parse after {n1*n2} retries")
|
651
|
+
return self.chain.invoke(block.original)
|
650
652
|
|
651
653
|
def _get_output_obj(
|
652
654
|
self, block: TranslatedCodeBlock
|
653
|
-
) -> dict[str, int | float | str | dict[str, str]]:
|
655
|
+
) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
|
654
656
|
output_str = self._parser.parse_combined_output(block.complete_text)
|
655
657
|
|
656
|
-
|
658
|
+
output_obj: str | dict[str, str]
|
657
659
|
try:
|
658
|
-
|
660
|
+
output_obj = json.loads(output_str)
|
659
661
|
except json.JSONDecodeError:
|
660
|
-
|
662
|
+
output_obj = output_str
|
661
663
|
|
662
664
|
return dict(
|
663
|
-
input=block.original.text,
|
665
|
+
input=block.original.text or "",
|
664
666
|
metadata=dict(
|
665
667
|
retries=block.total_retries,
|
666
668
|
cost=block.total_cost,
|
667
669
|
processing_time=block.processing_time,
|
668
670
|
),
|
669
|
-
output=
|
670
|
-
)
|
671
|
-
|
672
|
-
@staticmethod
|
673
|
-
def _get_prompt_additions(block) -> Optional[List[Tuple[str, str]]]:
|
674
|
-
"""Get a list of strings to append to the prompt.
|
675
|
-
|
676
|
-
Arguments:
|
677
|
-
block: The `TranslatedCodeBlock` to save to a file.
|
678
|
-
"""
|
679
|
-
return [(key, item) for key, item in block.context_tags.items()]
|
680
|
-
|
681
|
-
def _make_prompt_additions(self, block: CodeBlock):
|
682
|
-
# Prepare the additional context to prepend
|
683
|
-
additional_context = "".join(
|
684
|
-
[
|
685
|
-
f"{context_tag}: {context}\n"
|
686
|
-
for context_tag, context in self._get_prompt_additions(block)
|
687
|
-
]
|
671
|
+
output=output_obj,
|
688
672
|
)
|
689
673
|
|
690
|
-
if not hasattr(self._prompt, "messages"):
|
691
|
-
log.debug("Skipping additions to prompt, no messages found on prompt object!")
|
692
|
-
return
|
693
|
-
|
694
|
-
# Iterate through existing messages to find and update the system message
|
695
|
-
for i, message in enumerate(self._prompt.messages):
|
696
|
-
if isinstance(message, SystemMessagePromptTemplate):
|
697
|
-
# Prepend the additional context to the system message
|
698
|
-
updated_system_message = SystemMessagePromptTemplate.from_template(
|
699
|
-
additional_context + message.prompt.template
|
700
|
-
)
|
701
|
-
# Directly modify the message in the list
|
702
|
-
self._prompt.messages[i] = updated_system_message
|
703
|
-
break # Assuming there's only one system message to update
|
704
|
-
|
705
674
|
def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
|
706
675
|
"""Save a file to disk.
|
707
676
|
|