janus-llm 3.5.2__py3-none-any.whl → 4.0.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- janus/__init__.py +1 -1
- janus/cli.py +90 -42
- janus/converter/converter.py +111 -142
- janus/converter/diagram.py +21 -109
- janus/converter/translate.py +1 -1
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +16 -11
- janus/language/binary/_tests/test_binary.py +1 -1
- janus/language/binary/binary.py +2 -2
- janus/language/mumps/_tests/test_mumps.py +1 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/naive/simple_ast.py +3 -2
- janus/language/splitter.py +7 -4
- janus/language/treesitter/_tests/test_treesitter.py +1 -1
- janus/language/treesitter/treesitter.py +2 -2
- janus/llm/model_callbacks.py +13 -0
- janus/llm/models_info.py +118 -71
- janus/metrics/metric.py +15 -14
- janus/parsers/uml.py +60 -23
- janus/refiners/refiner.py +106 -64
- janus/retrievers/retriever.py +42 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/METADATA +1 -1
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/RECORD +26 -26
- janus/parsers/refiner_parser.py +0 -46
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/WHEEL +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/entry_points.txt +0 -0
janus/__init__.py
CHANGED
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
|
|
5
5
|
from janus.converter.translate import Translator
|
6
6
|
from janus.metrics import * # noqa: F403
|
7
7
|
|
8
|
-
__version__ = "3.5.2"
|
8
|
+
__version__ = "4.0.0"
|
9
9
|
|
10
10
|
# Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
|
11
11
|
warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
|
janus/cli.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import json
|
2
2
|
import logging
|
3
3
|
import os
|
4
|
+
import subprocess # nosec
|
4
5
|
from pathlib import Path
|
5
6
|
from typing import List, Optional
|
6
7
|
|
@@ -42,6 +43,7 @@ from janus.llm.models_info import (
|
|
42
43
|
openai_models,
|
43
44
|
)
|
44
45
|
from janus.metrics.cli import evaluate
|
46
|
+
from janus.refiners.refiner import REFINERS
|
45
47
|
from janus.utils.enums import LANGUAGES
|
46
48
|
from janus.utils.logger import create_logger
|
47
49
|
|
@@ -241,6 +243,24 @@ def translate(
|
|
241
243
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
242
244
|
),
|
243
245
|
] = "file",
|
246
|
+
refiner_type: Annotated[
|
247
|
+
str,
|
248
|
+
typer.Option(
|
249
|
+
"-r",
|
250
|
+
"--refiner",
|
251
|
+
help="Name of custom refiner to use",
|
252
|
+
click_type=click.Choice(list(REFINERS.keys())),
|
253
|
+
),
|
254
|
+
] = "none",
|
255
|
+
retriever_type: Annotated[
|
256
|
+
str,
|
257
|
+
typer.Option(
|
258
|
+
"-R",
|
259
|
+
"--retriever",
|
260
|
+
help="Name of custom retriever to use",
|
261
|
+
click_type=click.Choice(["active_usings"]),
|
262
|
+
),
|
263
|
+
] = None,
|
244
264
|
max_tokens: Annotated[
|
245
265
|
int,
|
246
266
|
typer.Option(
|
@@ -250,13 +270,6 @@ def translate(
|
|
250
270
|
"If unspecificed, model's default max will be used.",
|
251
271
|
),
|
252
272
|
] = None,
|
253
|
-
skip_refiner: Annotated[
|
254
|
-
bool,
|
255
|
-
typer.Option(
|
256
|
-
"--skip-refiner",
|
257
|
-
help="Whether to skip the refiner for generating output",
|
258
|
-
),
|
259
|
-
] = True,
|
260
273
|
):
|
261
274
|
try:
|
262
275
|
target_language, target_version = target_lang.split("-")
|
@@ -282,8 +295,8 @@ def translate(
|
|
282
295
|
db_path=db_loc,
|
283
296
|
db_config=collections_config,
|
284
297
|
splitter_type=splitter_type,
|
285
|
-
|
286
|
-
|
298
|
+
refiner_type=refiner_type,
|
299
|
+
retriever_type=retriever_type,
|
287
300
|
)
|
288
301
|
translator.translate(input_dir, output_dir, overwrite, collection)
|
289
302
|
|
@@ -341,14 +354,6 @@ def document(
|
|
341
354
|
help="Whether to overwrite existing files in the output directory",
|
342
355
|
),
|
343
356
|
] = False,
|
344
|
-
skip_context: Annotated[
|
345
|
-
bool,
|
346
|
-
typer.Option(
|
347
|
-
"--skip-context",
|
348
|
-
help="Prompts will include any context information associated with source"
|
349
|
-
" code blocks, unless this option is specified",
|
350
|
-
),
|
351
|
-
] = False,
|
352
357
|
doc_mode: Annotated[
|
353
358
|
str,
|
354
359
|
typer.Option(
|
@@ -396,6 +401,24 @@ def document(
|
|
396
401
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
397
402
|
),
|
398
403
|
] = "file",
|
404
|
+
refiner_type: Annotated[
|
405
|
+
str,
|
406
|
+
typer.Option(
|
407
|
+
"-r",
|
408
|
+
"--refiner",
|
409
|
+
help="Name of custom refiner to use",
|
410
|
+
click_type=click.Choice(list(REFINERS.keys())),
|
411
|
+
),
|
412
|
+
] = "none",
|
413
|
+
retriever_type: Annotated[
|
414
|
+
str,
|
415
|
+
typer.Option(
|
416
|
+
"-R",
|
417
|
+
"--retriever",
|
418
|
+
help="Name of custom retriever to use",
|
419
|
+
click_type=click.Choice(["active_usings"]),
|
420
|
+
),
|
421
|
+
] = None,
|
399
422
|
max_tokens: Annotated[
|
400
423
|
int,
|
401
424
|
typer.Option(
|
@@ -405,13 +428,6 @@ def document(
|
|
405
428
|
"If unspecificed, model's default max will be used.",
|
406
429
|
),
|
407
430
|
] = None,
|
408
|
-
skip_refiner: Annotated[
|
409
|
-
bool,
|
410
|
-
typer.Option(
|
411
|
-
"--skip-refiner",
|
412
|
-
help="Whether to skip the refiner for generating output",
|
413
|
-
),
|
414
|
-
] = True,
|
415
431
|
):
|
416
432
|
model_arguments = dict(temperature=temperature)
|
417
433
|
collections_config = get_collections_config()
|
@@ -424,8 +440,8 @@ def document(
|
|
424
440
|
db_path=db_loc,
|
425
441
|
db_config=collections_config,
|
426
442
|
splitter_type=splitter_type,
|
427
|
-
|
428
|
-
|
443
|
+
refiner_type=refiner_type,
|
444
|
+
retriever_type=retriever_type,
|
429
445
|
)
|
430
446
|
if doc_mode == "madlibs":
|
431
447
|
documenter = MadLibsDocumenter(
|
@@ -614,14 +630,6 @@ def diagram(
|
|
614
630
|
help="Whether to overwrite existing files in the output directory",
|
615
631
|
),
|
616
632
|
] = False,
|
617
|
-
skip_context: Annotated[
|
618
|
-
bool,
|
619
|
-
typer.Option(
|
620
|
-
"--skip-context",
|
621
|
-
help="Prompts will include any context information associated with source"
|
622
|
-
" code blocks, unless this option is specified",
|
623
|
-
),
|
624
|
-
] = False,
|
625
633
|
temperature: Annotated[
|
626
634
|
float,
|
627
635
|
typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
|
@@ -658,13 +666,24 @@ def diagram(
|
|
658
666
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
659
667
|
),
|
660
668
|
] = "file",
|
661
|
-
|
662
|
-
|
669
|
+
refiner_type: Annotated[
|
670
|
+
str,
|
663
671
|
typer.Option(
|
664
|
-
"
|
665
|
-
|
672
|
+
"-r",
|
673
|
+
"--refiner",
|
674
|
+
help="Name of custom refiner to use",
|
675
|
+
click_type=click.Choice(list(REFINERS.keys())),
|
666
676
|
),
|
667
|
-
] =
|
677
|
+
] = "none",
|
678
|
+
retriever_type: Annotated[
|
679
|
+
str,
|
680
|
+
typer.Option(
|
681
|
+
"-R",
|
682
|
+
"--retriever",
|
683
|
+
help="Name of custom retriever to use",
|
684
|
+
click_type=click.Choice(["active_usings"]),
|
685
|
+
),
|
686
|
+
] = None,
|
668
687
|
):
|
669
688
|
model_arguments = dict(temperature=temperature)
|
670
689
|
collections_config = get_collections_config()
|
@@ -675,11 +694,11 @@ def diagram(
|
|
675
694
|
max_prompts=max_prompts,
|
676
695
|
db_path=db_loc,
|
677
696
|
db_config=collections_config,
|
697
|
+
splitter_type=splitter_type,
|
698
|
+
refiner_type=refiner_type,
|
699
|
+
retriever_type=retriever_type,
|
678
700
|
diagram_type=diagram_type,
|
679
701
|
add_documentation=add_documentation,
|
680
|
-
splitter_type=splitter_type,
|
681
|
-
skip_refiner=skip_refiner,
|
682
|
-
skip_context=skip_context,
|
683
702
|
)
|
684
703
|
diagram_generator.translate(input_dir, output_dir, overwrite, collection)
|
685
704
|
|
@@ -1156,5 +1175,34 @@ app.add_typer(evaluate, name="evaluate")
|
|
1156
1175
|
app.add_typer(embedding, name="embedding")
|
1157
1176
|
|
1158
1177
|
|
1178
|
+
@app.command()
|
1179
|
+
def render(
|
1180
|
+
input_dir: Annotated[
|
1181
|
+
str,
|
1182
|
+
typer.Option(
|
1183
|
+
"--input",
|
1184
|
+
"-i",
|
1185
|
+
),
|
1186
|
+
],
|
1187
|
+
output_dir: Annotated[str, typer.Option("--output", "-o")],
|
1188
|
+
):
|
1189
|
+
input_dir = Path(input_dir)
|
1190
|
+
output_dir = Path(output_dir)
|
1191
|
+
for input_file in input_dir.rglob("*.json"):
|
1192
|
+
with open(input_file, "r") as f:
|
1193
|
+
data = json.load(f)
|
1194
|
+
|
1195
|
+
output_file = output_dir / input_file.relative_to(input_dir).with_suffix(".txt")
|
1196
|
+
if not output_file.parent.exists():
|
1197
|
+
output_file.parent.mkdir()
|
1198
|
+
|
1199
|
+
text = data["output"].replace("\\n", "\n").strip()
|
1200
|
+
output_file.write_text(text)
|
1201
|
+
|
1202
|
+
jar_path = homedir / ".janus/lib/plantuml.jar"
|
1203
|
+
subprocess.run(["java", "-jar", jar_path, output_file]) # nosec
|
1204
|
+
output_file.unlink()
|
1205
|
+
|
1206
|
+
|
1159
1207
|
if __name__ == "__main__":
|
1160
1208
|
app()
|
janus/converter/converter.py
CHANGED
@@ -2,13 +2,11 @@ import functools
|
|
2
2
|
import json
|
3
3
|
import time
|
4
4
|
from pathlib import Path
|
5
|
-
from typing import Any, List, Optional, Tuple
|
5
|
+
from typing import Any
|
6
6
|
|
7
|
-
from langchain.output_parsers import RetryWithErrorOutputParser
|
8
7
|
from langchain_core.exceptions import OutputParserException
|
9
|
-
from langchain_core.
|
10
|
-
from langchain_core.
|
11
|
-
from langchain_core.runnables import RunnableLambda, RunnableParallel
|
8
|
+
from langchain_core.prompts import ChatPromptTemplate
|
9
|
+
from langchain_core.runnables import Runnable, RunnableParallel, RunnablePassthrough
|
12
10
|
from openai import BadRequestError, RateLimitError
|
13
11
|
from pydantic import ValidationError
|
14
12
|
|
@@ -22,12 +20,18 @@ from janus.language.splitter import (
|
|
22
20
|
Splitter,
|
23
21
|
TokenLimitError,
|
24
22
|
)
|
25
|
-
from janus.llm import load_model
|
26
23
|
from janus.llm.model_callbacks import get_model_callback
|
27
|
-
from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
24
|
+
from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel, load_model
|
28
25
|
from janus.parsers.parser import GenericParser, JanusParser
|
29
|
-
from janus.
|
30
|
-
|
26
|
+
from janus.refiners.refiner import (
|
27
|
+
FixParserExceptions,
|
28
|
+
HallucinationRefiner,
|
29
|
+
JanusRefiner,
|
30
|
+
ReflectionRefiner,
|
31
|
+
)
|
32
|
+
|
33
|
+
# from janus.refiners.refiner import BasicRefiner, Refiner
|
34
|
+
from janus.retrievers.retriever import ActiveUsingsRetriever, JanusRetriever
|
31
35
|
from janus.utils.enums import LANGUAGES
|
32
36
|
from janus.utils.logger import create_logger
|
33
37
|
|
@@ -74,9 +78,8 @@ class Converter:
|
|
74
78
|
protected_node_types: tuple[str, ...] = (),
|
75
79
|
prune_node_types: tuple[str, ...] = (),
|
76
80
|
splitter_type: str = "file",
|
77
|
-
refiner_type: str = "basic",
|
78
|
-
|
79
|
-
skip_context: bool = False,
|
81
|
+
refiner_type: str | None = None,
|
82
|
+
retriever_type: str | None = None,
|
80
83
|
) -> None:
|
81
84
|
"""Initialize a Converter instance.
|
82
85
|
|
@@ -96,9 +99,13 @@ class Converter:
|
|
96
99
|
prune_node_types: A set of node types which should be pruned.
|
97
100
|
splitter_type: The type of splitter to use. Valid values are `"file"`,
|
98
101
|
`"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
|
99
|
-
refiner_type: The type of refiner to use. Valid values
|
100
|
-
|
101
|
-
|
102
|
+
refiner_type: The type of refiner to use. Valid values:
|
103
|
+
- "parser"
|
104
|
+
- "reflection"
|
105
|
+
- None
|
106
|
+
retriever_type: The type of retriever to use. Valid values:
|
107
|
+
- "active_usings"
|
108
|
+
- None
|
102
109
|
"""
|
103
110
|
self._changed_attrs: set = set()
|
104
111
|
|
@@ -107,7 +114,6 @@ class Converter:
|
|
107
114
|
self.override_token_limit: bool = max_tokens is not None
|
108
115
|
|
109
116
|
self._model_name: str
|
110
|
-
self._model_id: str
|
111
117
|
self._custom_model_arguments: dict[str, Any]
|
112
118
|
|
113
119
|
self._source_language: str
|
@@ -120,24 +126,26 @@ class Converter:
|
|
120
126
|
self._prune_node_types: tuple[str, ...] = ()
|
121
127
|
self._max_tokens: int | None = max_tokens
|
122
128
|
self._prompt_template_name: str
|
123
|
-
self._splitter_type: str
|
124
129
|
self._db_path: str | None
|
125
130
|
self._db_config: dict[str, Any] | None
|
126
131
|
|
127
|
-
self.
|
128
|
-
self._llm: BaseLanguageModel
|
132
|
+
self._llm: JanusModel
|
129
133
|
self._prompt: ChatPromptTemplate
|
130
134
|
|
131
135
|
self._parser: JanusParser = GenericParser()
|
132
136
|
self._combiner: Combiner = Combiner()
|
133
137
|
|
134
|
-
self.
|
135
|
-
self.
|
138
|
+
self._splitter_type: str
|
139
|
+
self._refiner_type: str | None
|
140
|
+
self._retriever_type: str | None
|
136
141
|
|
137
|
-
self.
|
142
|
+
self._splitter: Splitter
|
143
|
+
self._refiner: JanusRefiner
|
144
|
+
self._retriever: JanusRetriever
|
138
145
|
|
139
146
|
self.set_splitter(splitter_type=splitter_type)
|
140
147
|
self.set_refiner(refiner_type=refiner_type)
|
148
|
+
self.set_retriever(retriever_type=retriever_type)
|
141
149
|
self.set_model(model_name=model, **model_arguments)
|
142
150
|
self.set_prompt(prompt_template=prompt_template)
|
143
151
|
self.set_source_language(source_language)
|
@@ -146,8 +154,6 @@ class Converter:
|
|
146
154
|
self.set_db_path(db_path=db_path)
|
147
155
|
self.set_db_config(db_config=db_config)
|
148
156
|
|
149
|
-
self.skip_context = skip_context
|
150
|
-
|
151
157
|
# Child class must call this. Should we enforce somehow?
|
152
158
|
# self._load_parameters()
|
153
159
|
|
@@ -163,9 +169,11 @@ class Converter:
|
|
163
169
|
def _load_parameters(self) -> None:
|
164
170
|
self._load_model()
|
165
171
|
self._load_prompt()
|
172
|
+
self._load_retriever()
|
173
|
+
self._load_refiner()
|
166
174
|
self._load_splitter()
|
167
175
|
self._load_vectorizer()
|
168
|
-
self._load_refiner()
|
176
|
+
self._load_chain()
|
169
177
|
self._changed_attrs.clear()
|
170
178
|
|
171
179
|
def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
|
@@ -184,8 +192,6 @@ class Converter:
|
|
184
192
|
def set_prompt(self, prompt_template: str) -> None:
|
185
193
|
"""Validate and set the prompt template name.
|
186
194
|
|
187
|
-
The affected objects will not be updated until translate() is called.
|
188
|
-
|
189
195
|
Arguments:
|
190
196
|
prompt_template: name of prompt template directory
|
191
197
|
(see janus/prompts/templates) or path to a directory.
|
@@ -195,29 +201,34 @@ class Converter:
|
|
195
201
|
def set_splitter(self, splitter_type: str) -> None:
|
196
202
|
"""Validate and set the prompt template name.
|
197
203
|
|
198
|
-
The affected objects will not be updated until translate() is called.
|
199
|
-
|
200
204
|
Arguments:
|
201
205
|
prompt_template: name of prompt template directory
|
202
206
|
(see janus/prompts/templates) or path to a directory.
|
203
207
|
"""
|
204
|
-
|
208
|
+
if splitter_type not in CUSTOM_SPLITTERS:
|
209
|
+
raise ValueError(f'Splitter type "{splitter_type}" does not exist.')
|
205
210
|
|
206
|
-
|
207
|
-
"""Validate and set the refiner name
|
211
|
+
self._splitter_type = splitter_type
|
208
212
|
|
209
|
-
|
213
|
+
def set_refiner(self, refiner_type: str | None) -> None:
|
214
|
+
"""Validate and set the refiner type
|
210
215
|
|
211
216
|
Arguments:
|
212
|
-
refiner_type: the
|
217
|
+
refiner_type: the type of refiner to use
|
213
218
|
"""
|
214
219
|
self._refiner_type = refiner_type
|
215
220
|
|
221
|
+
def set_retriever(self, retriever_type: str | None) -> None:
|
222
|
+
"""Validate and set the retriever type
|
223
|
+
|
224
|
+
Arguments:
|
225
|
+
retriever_type: the type of retriever to use
|
226
|
+
"""
|
227
|
+
self._retriever_type = retriever_type
|
228
|
+
|
216
229
|
def set_source_language(self, source_language: str) -> None:
|
217
230
|
"""Validate and set the source language.
|
218
231
|
|
219
|
-
The affected objects will not be updated until _load_parameters() is called.
|
220
|
-
|
221
232
|
Arguments:
|
222
233
|
source_language: The source programming language.
|
223
234
|
"""
|
@@ -287,20 +298,6 @@ class Converter:
|
|
287
298
|
|
288
299
|
self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)
|
289
300
|
|
290
|
-
@run_if_changed("_refiner_type", "_model_name")
|
291
|
-
def _load_refiner(self) -> None:
|
292
|
-
"""Load the refiner according to this instance's attributes.
|
293
|
-
|
294
|
-
If the relevant fields have not been changed since the last time this method was
|
295
|
-
called, nothing happens.
|
296
|
-
"""
|
297
|
-
if self._refiner_type == "basic":
|
298
|
-
self._refiner = BasicRefiner(
|
299
|
-
"basic_refinement", self._model_id, self._source_language
|
300
|
-
)
|
301
|
-
else:
|
302
|
-
raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
|
303
|
-
|
304
301
|
@run_if_changed("_model_name", "_custom_model_arguments")
|
305
302
|
def _load_model(self) -> None:
|
306
303
|
"""Load the model according to this instance's attributes.
|
@@ -314,9 +311,9 @@ class Converter:
|
|
314
311
|
# model_arguments.update(self._custom_model_arguments)
|
315
312
|
|
316
313
|
# Load the model
|
317
|
-
self._llm
|
318
|
-
|
319
|
-
|
314
|
+
self._llm = load_model(self._model_name)
|
315
|
+
token_limit = self._llm.token_limit
|
316
|
+
|
320
317
|
# Set the max_tokens to less than half the model's limit to allow for enough
|
321
318
|
# tokens at output
|
322
319
|
# Only modify max_tokens if it is not specified by user
|
@@ -335,7 +332,7 @@ class Converter:
|
|
335
332
|
If the relevant fields have not been changed since the last time this
|
336
333
|
method was called, nothing happens.
|
337
334
|
"""
|
338
|
-
prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
|
335
|
+
prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
|
339
336
|
source_language=self._source_language,
|
340
337
|
prompt_template=self._prompt_template_name,
|
341
338
|
)
|
@@ -354,6 +351,59 @@ class Converter:
|
|
354
351
|
self._db_path, self._db_config
|
355
352
|
)
|
356
353
|
|
354
|
+
@run_if_changed("_retriever_type")
|
355
|
+
def _load_retriever(self):
|
356
|
+
if self._retriever_type == "active_usings":
|
357
|
+
self._retriever = ActiveUsingsRetriever()
|
358
|
+
else:
|
359
|
+
self._retriever = JanusRetriever()
|
360
|
+
|
361
|
+
@run_if_changed("_refiner_type", "_model_name", "max_prompts", "_parser", "_llm")
|
362
|
+
def _load_refiner(self) -> None:
|
363
|
+
"""Load the refiner according to this instance's attributes.
|
364
|
+
|
365
|
+
If the relevant fields have not been changed since the last time this method was
|
366
|
+
called, nothing happens.
|
367
|
+
"""
|
368
|
+
if self._refiner_type == "parser":
|
369
|
+
self._refiner = FixParserExceptions(
|
370
|
+
llm=self._llm,
|
371
|
+
parser=self._parser,
|
372
|
+
max_retries=self.max_prompts,
|
373
|
+
)
|
374
|
+
elif self._refiner_type == "reflection":
|
375
|
+
self._refiner = ReflectionRefiner(
|
376
|
+
llm=self._llm,
|
377
|
+
parser=self._parser,
|
378
|
+
max_retries=self.max_prompts,
|
379
|
+
)
|
380
|
+
elif self._refiner_type == "hallucination":
|
381
|
+
self._refiner = HallucinationRefiner(
|
382
|
+
llm=self._llm,
|
383
|
+
parser=self._parser,
|
384
|
+
max_retries=self.max_prompts,
|
385
|
+
)
|
386
|
+
else:
|
387
|
+
self._refiner = JanusRefiner(parser=self._parser)
|
388
|
+
|
389
|
+
@run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner")
|
390
|
+
def _load_chain(self):
|
391
|
+
self.chain = (
|
392
|
+
self._input_runnable()
|
393
|
+
| self._prompt
|
394
|
+
| RunnableParallel(
|
395
|
+
completion=self._llm,
|
396
|
+
prompt_value=RunnablePassthrough(),
|
397
|
+
)
|
398
|
+
| self._refiner.parse_runnable
|
399
|
+
)
|
400
|
+
|
401
|
+
def _input_runnable(self) -> Runnable:
|
402
|
+
return RunnableParallel(
|
403
|
+
SOURCE_CODE=self._parser.parse_input,
|
404
|
+
context=self._retriever,
|
405
|
+
)
|
406
|
+
|
357
407
|
def translate(
|
358
408
|
self,
|
359
409
|
input_directory: str | Path,
|
@@ -598,110 +648,29 @@ class Converter:
|
|
598
648
|
return root
|
599
649
|
|
600
650
|
def _run_chain(self, block: TranslatedCodeBlock) -> str:
|
601
|
-
|
602
|
-
First, try to fix simple formatting errors by giving the model just
|
603
|
-
the output and the parsing error. After a number of attempts, try
|
604
|
-
giving the model the output, the parsing error, and the original
|
605
|
-
input. Again check/retry this output to solve for formatting errors.
|
606
|
-
If we still haven't succeeded after several attempts, the model may
|
607
|
-
be getting thrown off by a bad initial output; start from scratch
|
608
|
-
and try again.
|
609
|
-
|
610
|
-
The number of tries for each layer of this scheme is roughly equal
|
611
|
-
to the cube root of self.max_retries, so the total calls to the
|
612
|
-
LLM will be roughly as expected (up to sqrt(self.max_retries) over)
|
613
|
-
"""
|
614
|
-
input = self._parser.parse_input(block.original)
|
615
|
-
|
616
|
-
# Retries with just the output and the error
|
617
|
-
n1 = round(self.max_prompts ** (1 / 2))
|
618
|
-
|
619
|
-
# Retries with the input, output, and error
|
620
|
-
n2 = round(self.max_prompts // n1)
|
621
|
-
|
622
|
-
if not self.skip_context:
|
623
|
-
self._make_prompt_additions(block)
|
624
|
-
if not self.skip_refiner: # Make replacements in the prompt
|
625
|
-
refine_output = RefinerParser(
|
626
|
-
parser=self._parser,
|
627
|
-
initial_prompt=self._prompt.format(**{"SOURCE_CODE": input}),
|
628
|
-
refiner=self._refiner,
|
629
|
-
max_retries=n1,
|
630
|
-
llm=self._llm,
|
631
|
-
)
|
632
|
-
else:
|
633
|
-
refine_output = RetryWithErrorOutputParser.from_llm(
|
634
|
-
llm=self._llm,
|
635
|
-
parser=self._parser,
|
636
|
-
max_retries=n1,
|
637
|
-
)
|
638
|
-
|
639
|
-
completion_chain = self._prompt | self._llm
|
640
|
-
chain = RunnableParallel(
|
641
|
-
completion=completion_chain, prompt_value=self._prompt
|
642
|
-
) | RunnableLambda(lambda x: refine_output.parse_with_prompt(**x))
|
643
|
-
for _ in range(n2):
|
644
|
-
try:
|
645
|
-
return chain.invoke({"SOURCE_CODE": input})
|
646
|
-
except OutputParserException:
|
647
|
-
pass
|
648
|
-
|
649
|
-
raise OutputParserException(f"Failed to parse after {n1*n2} retries")
|
651
|
+
return self.chain.invoke(block.original)
|
650
652
|
|
651
653
|
def _get_output_obj(
|
652
654
|
self, block: TranslatedCodeBlock
|
653
|
-
) -> dict[str, int | float | str | dict[str, str]]:
|
655
|
+
) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
|
654
656
|
output_str = self._parser.parse_combined_output(block.complete_text)
|
655
657
|
|
656
|
-
|
658
|
+
output_obj: str | dict[str, str]
|
657
659
|
try:
|
658
|
-
|
660
|
+
output_obj = json.loads(output_str)
|
659
661
|
except json.JSONDecodeError:
|
660
|
-
|
662
|
+
output_obj = output_str
|
661
663
|
|
662
664
|
return dict(
|
663
|
-
input=block.original.text,
|
665
|
+
input=block.original.text or "",
|
664
666
|
metadata=dict(
|
665
667
|
retries=block.total_retries,
|
666
668
|
cost=block.total_cost,
|
667
669
|
processing_time=block.processing_time,
|
668
670
|
),
|
669
|
-
output=
|
670
|
-
)
|
671
|
-
|
672
|
-
@staticmethod
|
673
|
-
def _get_prompt_additions(block) -> Optional[List[Tuple[str, str]]]:
|
674
|
-
"""Get a list of strings to append to the prompt.
|
675
|
-
|
676
|
-
Arguments:
|
677
|
-
block: The `TranslatedCodeBlock` to save to a file.
|
678
|
-
"""
|
679
|
-
return [(key, item) for key, item in block.context_tags.items()]
|
680
|
-
|
681
|
-
def _make_prompt_additions(self, block: CodeBlock):
|
682
|
-
# Prepare the additional context to prepend
|
683
|
-
additional_context = "".join(
|
684
|
-
[
|
685
|
-
f"{context_tag}: {context}\n"
|
686
|
-
for context_tag, context in self._get_prompt_additions(block)
|
687
|
-
]
|
671
|
+
output=output_obj,
|
688
672
|
)
|
689
673
|
|
690
|
-
if not hasattr(self._prompt, "messages"):
|
691
|
-
log.debug("Skipping additions to prompt, no messages found on prompt object!")
|
692
|
-
return
|
693
|
-
|
694
|
-
# Iterate through existing messages to find and update the system message
|
695
|
-
for i, message in enumerate(self._prompt.messages):
|
696
|
-
if isinstance(message, SystemMessagePromptTemplate):
|
697
|
-
# Prepend the additional context to the system message
|
698
|
-
updated_system_message = SystemMessagePromptTemplate.from_template(
|
699
|
-
additional_context + message.prompt.template
|
700
|
-
)
|
701
|
-
# Directly modify the message in the list
|
702
|
-
self._prompt.messages[i] = updated_system_message
|
703
|
-
break # Assuming there's only one system message to update
|
704
|
-
|
705
674
|
def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
|
706
675
|
"""Save a file to disk.
|
707
676
|
|