janus-llm 3.5.3__py3-none-any.whl → 4.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +1 -1
- janus/cli.py +91 -48
- janus/converter/_tests/test_translate.py +2 -2
- janus/converter/converter.py +111 -142
- janus/converter/diagram.py +21 -109
- janus/converter/translate.py +1 -1
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +15 -10
- janus/language/binary/_tests/test_binary.py +1 -1
- janus/language/binary/binary.py +2 -2
- janus/language/mumps/_tests/test_mumps.py +1 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/splitter.py +2 -2
- janus/language/treesitter/_tests/test_treesitter.py +1 -1
- janus/language/treesitter/treesitter.py +2 -2
- janus/llm/model_callbacks.py +22 -0
- janus/llm/models_info.py +142 -81
- janus/metrics/metric.py +15 -14
- janus/parsers/uml.py +60 -23
- janus/refiners/refiner.py +106 -64
- janus/retrievers/retriever.py +42 -0
- {janus_llm-3.5.3.dist-info → janus_llm-4.1.0.dist-info}/METADATA +1 -1
- {janus_llm-3.5.3.dist-info → janus_llm-4.1.0.dist-info}/RECORD +26 -26
- janus/parsers/refiner_parser.py +0 -46
- {janus_llm-3.5.3.dist-info → janus_llm-4.1.0.dist-info}/LICENSE +0 -0
- {janus_llm-3.5.3.dist-info → janus_llm-4.1.0.dist-info}/WHEEL +0 -0
- {janus_llm-3.5.3.dist-info → janus_llm-4.1.0.dist-info}/entry_points.txt +0 -0
janus/__init__.py
CHANGED
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
|
|
5
5
|
from janus.converter.translate import Translator
|
6
6
|
from janus.metrics import * # noqa: F403
|
7
7
|
|
8
|
-
__version__ = "
|
8
|
+
__version__ = "4.1.0"
|
9
9
|
|
10
10
|
# Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
|
11
11
|
warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
|
janus/cli.py
CHANGED
@@ -39,10 +39,12 @@ from janus.llm.models_info import (
|
|
39
39
|
MODEL_TYPE_CONSTRUCTORS,
|
40
40
|
MODEL_TYPES,
|
41
41
|
TOKEN_LIMITS,
|
42
|
+
azure_models,
|
42
43
|
bedrock_models,
|
43
44
|
openai_models,
|
44
45
|
)
|
45
46
|
from janus.metrics.cli import evaluate
|
47
|
+
from janus.refiners.refiner import REFINERS
|
46
48
|
from janus.utils.enums import LANGUAGES
|
47
49
|
from janus.utils.logger import create_logger
|
48
50
|
|
@@ -242,6 +244,24 @@ def translate(
|
|
242
244
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
243
245
|
),
|
244
246
|
] = "file",
|
247
|
+
refiner_type: Annotated[
|
248
|
+
str,
|
249
|
+
typer.Option(
|
250
|
+
"-r",
|
251
|
+
"--refiner",
|
252
|
+
help="Name of custom refiner to use",
|
253
|
+
click_type=click.Choice(list(REFINERS.keys())),
|
254
|
+
),
|
255
|
+
] = "none",
|
256
|
+
retriever_type: Annotated[
|
257
|
+
str,
|
258
|
+
typer.Option(
|
259
|
+
"-R",
|
260
|
+
"--retriever",
|
261
|
+
help="Name of custom retriever to use",
|
262
|
+
click_type=click.Choice(["active_usings"]),
|
263
|
+
),
|
264
|
+
] = None,
|
245
265
|
max_tokens: Annotated[
|
246
266
|
int,
|
247
267
|
typer.Option(
|
@@ -251,13 +271,6 @@ def translate(
|
|
251
271
|
"If unspecificed, model's default max will be used.",
|
252
272
|
),
|
253
273
|
] = None,
|
254
|
-
skip_refiner: Annotated[
|
255
|
-
bool,
|
256
|
-
typer.Option(
|
257
|
-
"--skip-refiner",
|
258
|
-
help="Whether to skip the refiner for generating output",
|
259
|
-
),
|
260
|
-
] = True,
|
261
274
|
):
|
262
275
|
try:
|
263
276
|
target_language, target_version = target_lang.split("-")
|
@@ -283,8 +296,8 @@ def translate(
|
|
283
296
|
db_path=db_loc,
|
284
297
|
db_config=collections_config,
|
285
298
|
splitter_type=splitter_type,
|
286
|
-
|
287
|
-
|
299
|
+
refiner_type=refiner_type,
|
300
|
+
retriever_type=retriever_type,
|
288
301
|
)
|
289
302
|
translator.translate(input_dir, output_dir, overwrite, collection)
|
290
303
|
|
@@ -342,14 +355,6 @@ def document(
|
|
342
355
|
help="Whether to overwrite existing files in the output directory",
|
343
356
|
),
|
344
357
|
] = False,
|
345
|
-
skip_context: Annotated[
|
346
|
-
bool,
|
347
|
-
typer.Option(
|
348
|
-
"--skip-context",
|
349
|
-
help="Prompts will include any context information associated with source"
|
350
|
-
" code blocks, unless this option is specified",
|
351
|
-
),
|
352
|
-
] = False,
|
353
358
|
doc_mode: Annotated[
|
354
359
|
str,
|
355
360
|
typer.Option(
|
@@ -397,6 +402,24 @@ def document(
|
|
397
402
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
398
403
|
),
|
399
404
|
] = "file",
|
405
|
+
refiner_type: Annotated[
|
406
|
+
str,
|
407
|
+
typer.Option(
|
408
|
+
"-r",
|
409
|
+
"--refiner",
|
410
|
+
help="Name of custom refiner to use",
|
411
|
+
click_type=click.Choice(list(REFINERS.keys())),
|
412
|
+
),
|
413
|
+
] = "none",
|
414
|
+
retriever_type: Annotated[
|
415
|
+
str,
|
416
|
+
typer.Option(
|
417
|
+
"-R",
|
418
|
+
"--retriever",
|
419
|
+
help="Name of custom retriever to use",
|
420
|
+
click_type=click.Choice(["active_usings"]),
|
421
|
+
),
|
422
|
+
] = None,
|
400
423
|
max_tokens: Annotated[
|
401
424
|
int,
|
402
425
|
typer.Option(
|
@@ -406,13 +429,6 @@ def document(
|
|
406
429
|
"If unspecificed, model's default max will be used.",
|
407
430
|
),
|
408
431
|
] = None,
|
409
|
-
skip_refiner: Annotated[
|
410
|
-
bool,
|
411
|
-
typer.Option(
|
412
|
-
"--skip-refiner",
|
413
|
-
help="Whether to skip the refiner for generating output",
|
414
|
-
),
|
415
|
-
] = True,
|
416
432
|
):
|
417
433
|
model_arguments = dict(temperature=temperature)
|
418
434
|
collections_config = get_collections_config()
|
@@ -425,8 +441,8 @@ def document(
|
|
425
441
|
db_path=db_loc,
|
426
442
|
db_config=collections_config,
|
427
443
|
splitter_type=splitter_type,
|
428
|
-
|
429
|
-
|
444
|
+
refiner_type=refiner_type,
|
445
|
+
retriever_type=retriever_type,
|
430
446
|
)
|
431
447
|
if doc_mode == "madlibs":
|
432
448
|
documenter = MadLibsDocumenter(
|
@@ -615,14 +631,6 @@ def diagram(
|
|
615
631
|
help="Whether to overwrite existing files in the output directory",
|
616
632
|
),
|
617
633
|
] = False,
|
618
|
-
skip_context: Annotated[
|
619
|
-
bool,
|
620
|
-
typer.Option(
|
621
|
-
"--skip-context",
|
622
|
-
help="Prompts will include any context information associated with source"
|
623
|
-
" code blocks, unless this option is specified",
|
624
|
-
),
|
625
|
-
] = False,
|
626
634
|
temperature: Annotated[
|
627
635
|
float,
|
628
636
|
typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
|
@@ -659,13 +667,24 @@ def diagram(
|
|
659
667
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
660
668
|
),
|
661
669
|
] = "file",
|
662
|
-
|
663
|
-
|
670
|
+
refiner_type: Annotated[
|
671
|
+
str,
|
664
672
|
typer.Option(
|
665
|
-
"
|
666
|
-
|
673
|
+
"-r",
|
674
|
+
"--refiner",
|
675
|
+
help="Name of custom refiner to use",
|
676
|
+
click_type=click.Choice(list(REFINERS.keys())),
|
667
677
|
),
|
668
|
-
] =
|
678
|
+
] = "none",
|
679
|
+
retriever_type: Annotated[
|
680
|
+
str,
|
681
|
+
typer.Option(
|
682
|
+
"-R",
|
683
|
+
"--retriever",
|
684
|
+
help="Name of custom retriever to use",
|
685
|
+
click_type=click.Choice(["active_usings"]),
|
686
|
+
),
|
687
|
+
] = None,
|
669
688
|
):
|
670
689
|
model_arguments = dict(temperature=temperature)
|
671
690
|
collections_config = get_collections_config()
|
@@ -676,11 +695,11 @@ def diagram(
|
|
676
695
|
max_prompts=max_prompts,
|
677
696
|
db_path=db_loc,
|
678
697
|
db_config=collections_config,
|
698
|
+
splitter_type=splitter_type,
|
699
|
+
refiner_type=refiner_type,
|
700
|
+
retriever_type=retriever_type,
|
679
701
|
diagram_type=diagram_type,
|
680
702
|
add_documentation=add_documentation,
|
681
|
-
splitter_type=splitter_type,
|
682
|
-
skip_refiner=skip_refiner,
|
683
|
-
skip_context=skip_context,
|
684
703
|
)
|
685
704
|
diagram_generator.translate(input_dir, output_dir, overwrite, collection)
|
686
705
|
|
@@ -934,7 +953,7 @@ def llm_add(
|
|
934
953
|
help="The type of the model",
|
935
954
|
click_type=click.Choice(sorted(list(MODEL_TYPE_CONSTRUCTORS.keys()))),
|
936
955
|
),
|
937
|
-
] = "
|
956
|
+
] = "Azure",
|
938
957
|
):
|
939
958
|
if not MODEL_CONFIG_DIR.exists():
|
940
959
|
MODEL_CONFIG_DIR.mkdir(parents=True)
|
@@ -978,6 +997,7 @@ def llm_add(
|
|
978
997
|
"model_cost": {"input": in_cost, "output": out_cost},
|
979
998
|
}
|
980
999
|
elif model_type == "OpenAI":
|
1000
|
+
print("DEPRECATED: Use 'Azure' instead. CTRL+C to exit.")
|
981
1001
|
model_id = typer.prompt(
|
982
1002
|
"Enter the model ID (list model IDs with `janus llm ls -a`)",
|
983
1003
|
default="gpt-4o",
|
@@ -999,6 +1019,28 @@ def llm_add(
|
|
999
1019
|
"token_limit": max_tokens,
|
1000
1020
|
"model_cost": model_cost,
|
1001
1021
|
}
|
1022
|
+
elif model_type == "Azure":
|
1023
|
+
model_id = typer.prompt(
|
1024
|
+
"Enter the model ID (list model IDs with `janus llm ls -a`)",
|
1025
|
+
default="gpt-4o",
|
1026
|
+
type=click.Choice(azure_models),
|
1027
|
+
show_choices=False,
|
1028
|
+
)
|
1029
|
+
params = dict(
|
1030
|
+
# Azure uses the "azure_deployment" key for what we're calling "long_model_id"
|
1031
|
+
azure_deployment=MODEL_ID_TO_LONG_ID[model_id],
|
1032
|
+
temperature=0.7,
|
1033
|
+
n=1,
|
1034
|
+
)
|
1035
|
+
max_tokens = TOKEN_LIMITS[MODEL_ID_TO_LONG_ID[model_id]]
|
1036
|
+
model_cost = COST_PER_1K_TOKENS[MODEL_ID_TO_LONG_ID[model_id]]
|
1037
|
+
cfg = {
|
1038
|
+
"model_type": model_type,
|
1039
|
+
"model_id": model_id,
|
1040
|
+
"model_args": params,
|
1041
|
+
"token_limit": max_tokens,
|
1042
|
+
"model_cost": model_cost,
|
1043
|
+
}
|
1002
1044
|
elif model_type == "BedrockChat" or model_type == "Bedrock":
|
1003
1045
|
model_id = typer.prompt(
|
1004
1046
|
"Enter the model ID (list model IDs with `janus llm ls -a`)",
|
@@ -1173,13 +1215,14 @@ def render(
|
|
1173
1215
|
for input_file in input_dir.rglob("*.json"):
|
1174
1216
|
with open(input_file, "r") as f:
|
1175
1217
|
data = json.load(f)
|
1176
|
-
|
1177
|
-
output_file = output_dir /
|
1178
|
-
output_file = output_file.with_suffix(".txt")
|
1218
|
+
|
1219
|
+
output_file = output_dir / input_file.relative_to(input_dir).with_suffix(".txt")
|
1179
1220
|
if not output_file.parent.exists():
|
1180
1221
|
output_file.parent.mkdir()
|
1181
|
-
|
1182
|
-
|
1222
|
+
|
1223
|
+
text = data["output"].replace("\\n", "\n").strip()
|
1224
|
+
output_file.write_text(text)
|
1225
|
+
|
1183
1226
|
jar_path = homedir / ".janus/lib/plantuml.jar"
|
1184
1227
|
subprocess.run(["java", "-jar", jar_path, output_file]) # nosec
|
1185
1228
|
output_file.unlink()
|
@@ -90,14 +90,14 @@ class TestDiagramGenerator(unittest.TestCase):
|
|
90
90
|
def setUp(self):
|
91
91
|
"""Set up the tests."""
|
92
92
|
self.diagram_generator = DiagramGenerator(
|
93
|
-
model="gpt-4o",
|
93
|
+
model="gpt-4o-mini",
|
94
94
|
source_language="fortran",
|
95
95
|
diagram_type="Activity",
|
96
96
|
)
|
97
97
|
|
98
98
|
def test_init(self):
|
99
99
|
"""Test __init__ method."""
|
100
|
-
self.assertEqual(self.diagram_generator._model_name, "gpt-4o")
|
100
|
+
self.assertEqual(self.diagram_generator._model_name, "gpt-4o-mini")
|
101
101
|
self.assertEqual(self.diagram_generator._source_language, "fortran")
|
102
102
|
self.assertEqual(self.diagram_generator._diagram_type, "Activity")
|
103
103
|
|
janus/converter/converter.py
CHANGED
@@ -2,13 +2,11 @@ import functools
|
|
2
2
|
import json
|
3
3
|
import time
|
4
4
|
from pathlib import Path
|
5
|
-
from typing import Any
|
5
|
+
from typing import Any
|
6
6
|
|
7
|
-
from langchain.output_parsers import RetryWithErrorOutputParser
|
8
7
|
from langchain_core.exceptions import OutputParserException
|
9
|
-
from langchain_core.
|
10
|
-
from langchain_core.
|
11
|
-
from langchain_core.runnables import RunnableLambda, RunnableParallel
|
8
|
+
from langchain_core.prompts import ChatPromptTemplate
|
9
|
+
from langchain_core.runnables import Runnable, RunnableParallel, RunnablePassthrough
|
12
10
|
from openai import BadRequestError, RateLimitError
|
13
11
|
from pydantic import ValidationError
|
14
12
|
|
@@ -22,12 +20,18 @@ from janus.language.splitter import (
|
|
22
20
|
Splitter,
|
23
21
|
TokenLimitError,
|
24
22
|
)
|
25
|
-
from janus.llm import load_model
|
26
23
|
from janus.llm.model_callbacks import get_model_callback
|
27
|
-
from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
24
|
+
from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel, load_model
|
28
25
|
from janus.parsers.parser import GenericParser, JanusParser
|
29
|
-
from janus.
|
30
|
-
|
26
|
+
from janus.refiners.refiner import (
|
27
|
+
FixParserExceptions,
|
28
|
+
HallucinationRefiner,
|
29
|
+
JanusRefiner,
|
30
|
+
ReflectionRefiner,
|
31
|
+
)
|
32
|
+
|
33
|
+
# from janus.refiners.refiner import BasicRefiner, Refiner
|
34
|
+
from janus.retrievers.retriever import ActiveUsingsRetriever, JanusRetriever
|
31
35
|
from janus.utils.enums import LANGUAGES
|
32
36
|
from janus.utils.logger import create_logger
|
33
37
|
|
@@ -74,9 +78,8 @@ class Converter:
|
|
74
78
|
protected_node_types: tuple[str, ...] = (),
|
75
79
|
prune_node_types: tuple[str, ...] = (),
|
76
80
|
splitter_type: str = "file",
|
77
|
-
refiner_type: str =
|
78
|
-
|
79
|
-
skip_context: bool = False,
|
81
|
+
refiner_type: str | None = None,
|
82
|
+
retriever_type: str | None = None,
|
80
83
|
) -> None:
|
81
84
|
"""Initialize a Converter instance.
|
82
85
|
|
@@ -96,9 +99,13 @@ class Converter:
|
|
96
99
|
prune_node_types: A set of node types which should be pruned.
|
97
100
|
splitter_type: The type of splitter to use. Valid values are `"file"`,
|
98
101
|
`"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
|
99
|
-
refiner_type: The type of refiner to use. Valid values
|
100
|
-
|
101
|
-
|
102
|
+
refiner_type: The type of refiner to use. Valid values:
|
103
|
+
- "parser"
|
104
|
+
- "reflection"
|
105
|
+
- None
|
106
|
+
retriever_type: The type of retriever to use. Valid values:
|
107
|
+
- "active_usings"
|
108
|
+
- None
|
102
109
|
"""
|
103
110
|
self._changed_attrs: set = set()
|
104
111
|
|
@@ -107,7 +114,6 @@ class Converter:
|
|
107
114
|
self.override_token_limit: bool = max_tokens is not None
|
108
115
|
|
109
116
|
self._model_name: str
|
110
|
-
self._model_id: str
|
111
117
|
self._custom_model_arguments: dict[str, Any]
|
112
118
|
|
113
119
|
self._source_language: str
|
@@ -120,24 +126,26 @@ class Converter:
|
|
120
126
|
self._prune_node_types: tuple[str, ...] = ()
|
121
127
|
self._max_tokens: int | None = max_tokens
|
122
128
|
self._prompt_template_name: str
|
123
|
-
self._splitter_type: str
|
124
129
|
self._db_path: str | None
|
125
130
|
self._db_config: dict[str, Any] | None
|
126
131
|
|
127
|
-
self.
|
128
|
-
self._llm: BaseLanguageModel
|
132
|
+
self._llm: JanusModel
|
129
133
|
self._prompt: ChatPromptTemplate
|
130
134
|
|
131
135
|
self._parser: JanusParser = GenericParser()
|
132
136
|
self._combiner: Combiner = Combiner()
|
133
137
|
|
134
|
-
self.
|
135
|
-
self.
|
138
|
+
self._splitter_type: str
|
139
|
+
self._refiner_type: str | None
|
140
|
+
self._retriever_type: str | None
|
136
141
|
|
137
|
-
self.
|
142
|
+
self._splitter: Splitter
|
143
|
+
self._refiner: JanusRefiner
|
144
|
+
self._retriever: JanusRetriever
|
138
145
|
|
139
146
|
self.set_splitter(splitter_type=splitter_type)
|
140
147
|
self.set_refiner(refiner_type=refiner_type)
|
148
|
+
self.set_retriever(retriever_type=retriever_type)
|
141
149
|
self.set_model(model_name=model, **model_arguments)
|
142
150
|
self.set_prompt(prompt_template=prompt_template)
|
143
151
|
self.set_source_language(source_language)
|
@@ -146,8 +154,6 @@ class Converter:
|
|
146
154
|
self.set_db_path(db_path=db_path)
|
147
155
|
self.set_db_config(db_config=db_config)
|
148
156
|
|
149
|
-
self.skip_context = skip_context
|
150
|
-
|
151
157
|
# Child class must call this. Should we enforce somehow?
|
152
158
|
# self._load_parameters()
|
153
159
|
|
@@ -163,9 +169,11 @@ class Converter:
|
|
163
169
|
def _load_parameters(self) -> None:
|
164
170
|
self._load_model()
|
165
171
|
self._load_prompt()
|
172
|
+
self._load_retriever()
|
173
|
+
self._load_refiner()
|
166
174
|
self._load_splitter()
|
167
175
|
self._load_vectorizer()
|
168
|
-
self.
|
176
|
+
self._load_chain()
|
169
177
|
self._changed_attrs.clear()
|
170
178
|
|
171
179
|
def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
|
@@ -184,8 +192,6 @@ class Converter:
|
|
184
192
|
def set_prompt(self, prompt_template: str) -> None:
|
185
193
|
"""Validate and set the prompt template name.
|
186
194
|
|
187
|
-
The affected objects will not be updated until translate() is called.
|
188
|
-
|
189
195
|
Arguments:
|
190
196
|
prompt_template: name of prompt template directory
|
191
197
|
(see janus/prompts/templates) or path to a directory.
|
@@ -195,29 +201,34 @@ class Converter:
|
|
195
201
|
def set_splitter(self, splitter_type: str) -> None:
|
196
202
|
"""Validate and set the prompt template name.
|
197
203
|
|
198
|
-
The affected objects will not be updated until translate() is called.
|
199
|
-
|
200
204
|
Arguments:
|
201
205
|
prompt_template: name of prompt template directory
|
202
206
|
(see janus/prompts/templates) or path to a directory.
|
203
207
|
"""
|
204
|
-
|
208
|
+
if splitter_type not in CUSTOM_SPLITTERS:
|
209
|
+
raise ValueError(f'Splitter type "{splitter_type}" does not exist.')
|
205
210
|
|
206
|
-
|
207
|
-
"""Validate and set the refiner name
|
211
|
+
self._splitter_type = splitter_type
|
208
212
|
|
209
|
-
|
213
|
+
def set_refiner(self, refiner_type: str | None) -> None:
|
214
|
+
"""Validate and set the refiner type
|
210
215
|
|
211
216
|
Arguments:
|
212
|
-
refiner_type: the
|
217
|
+
refiner_type: the type of refiner to use
|
213
218
|
"""
|
214
219
|
self._refiner_type = refiner_type
|
215
220
|
|
221
|
+
def set_retriever(self, retriever_type: str | None) -> None:
|
222
|
+
"""Validate and set the retriever type
|
223
|
+
|
224
|
+
Arguments:
|
225
|
+
retriever_type: the type of retriever to use
|
226
|
+
"""
|
227
|
+
self._retriever_type = retriever_type
|
228
|
+
|
216
229
|
def set_source_language(self, source_language: str) -> None:
|
217
230
|
"""Validate and set the source language.
|
218
231
|
|
219
|
-
The affected objects will not be updated until _load_parameters() is called.
|
220
|
-
|
221
232
|
Arguments:
|
222
233
|
source_language: The source programming language.
|
223
234
|
"""
|
@@ -287,20 +298,6 @@ class Converter:
|
|
287
298
|
|
288
299
|
self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)
|
289
300
|
|
290
|
-
@run_if_changed("_refiner_type", "_model_name")
|
291
|
-
def _load_refiner(self) -> None:
|
292
|
-
"""Load the refiner according to this instance's attributes.
|
293
|
-
|
294
|
-
If the relevant fields have not been changed since the last time this method was
|
295
|
-
called, nothing happens.
|
296
|
-
"""
|
297
|
-
if self._refiner_type == "basic":
|
298
|
-
self._refiner = BasicRefiner(
|
299
|
-
"basic_refinement", self._model_id, self._source_language
|
300
|
-
)
|
301
|
-
else:
|
302
|
-
raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
|
303
|
-
|
304
301
|
@run_if_changed("_model_name", "_custom_model_arguments")
|
305
302
|
def _load_model(self) -> None:
|
306
303
|
"""Load the model according to this instance's attributes.
|
@@ -314,9 +311,9 @@ class Converter:
|
|
314
311
|
# model_arguments.update(self._custom_model_arguments)
|
315
312
|
|
316
313
|
# Load the model
|
317
|
-
self._llm
|
318
|
-
|
319
|
-
|
314
|
+
self._llm = load_model(self._model_name)
|
315
|
+
token_limit = self._llm.token_limit
|
316
|
+
|
320
317
|
# Set the max_tokens to less than half the model's limit to allow for enough
|
321
318
|
# tokens at output
|
322
319
|
# Only modify max_tokens if it is not specified by user
|
@@ -335,7 +332,7 @@ class Converter:
|
|
335
332
|
If the relevant fields have not been changed since the last time this
|
336
333
|
method was called, nothing happens.
|
337
334
|
"""
|
338
|
-
prompt_engine = MODEL_PROMPT_ENGINES[self.
|
335
|
+
prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
|
339
336
|
source_language=self._source_language,
|
340
337
|
prompt_template=self._prompt_template_name,
|
341
338
|
)
|
@@ -354,6 +351,59 @@ class Converter:
|
|
354
351
|
self._db_path, self._db_config
|
355
352
|
)
|
356
353
|
|
354
|
+
@run_if_changed("_retriever_type")
|
355
|
+
def _load_retriever(self):
|
356
|
+
if self._retriever_type == "active_usings":
|
357
|
+
self._retriever = ActiveUsingsRetriever()
|
358
|
+
else:
|
359
|
+
self._retriever = JanusRetriever()
|
360
|
+
|
361
|
+
@run_if_changed("_refiner_type", "_model_name", "max_prompts", "_parser", "_llm")
|
362
|
+
def _load_refiner(self) -> None:
|
363
|
+
"""Load the refiner according to this instance's attributes.
|
364
|
+
|
365
|
+
If the relevant fields have not been changed since the last time this method was
|
366
|
+
called, nothing happens.
|
367
|
+
"""
|
368
|
+
if self._refiner_type == "parser":
|
369
|
+
self._refiner = FixParserExceptions(
|
370
|
+
llm=self._llm,
|
371
|
+
parser=self._parser,
|
372
|
+
max_retries=self.max_prompts,
|
373
|
+
)
|
374
|
+
elif self._refiner_type == "reflection":
|
375
|
+
self._refiner = ReflectionRefiner(
|
376
|
+
llm=self._llm,
|
377
|
+
parser=self._parser,
|
378
|
+
max_retries=self.max_prompts,
|
379
|
+
)
|
380
|
+
elif self._refiner_type == "hallucination":
|
381
|
+
self._refiner = HallucinationRefiner(
|
382
|
+
llm=self._llm,
|
383
|
+
parser=self._parser,
|
384
|
+
max_retries=self.max_prompts,
|
385
|
+
)
|
386
|
+
else:
|
387
|
+
self._refiner = JanusRefiner(parser=self._parser)
|
388
|
+
|
389
|
+
@run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner")
|
390
|
+
def _load_chain(self):
|
391
|
+
self.chain = (
|
392
|
+
self._input_runnable()
|
393
|
+
| self._prompt
|
394
|
+
| RunnableParallel(
|
395
|
+
completion=self._llm,
|
396
|
+
prompt_value=RunnablePassthrough(),
|
397
|
+
)
|
398
|
+
| self._refiner.parse_runnable
|
399
|
+
)
|
400
|
+
|
401
|
+
def _input_runnable(self) -> Runnable:
|
402
|
+
return RunnableParallel(
|
403
|
+
SOURCE_CODE=self._parser.parse_input,
|
404
|
+
context=self._retriever,
|
405
|
+
)
|
406
|
+
|
357
407
|
def translate(
|
358
408
|
self,
|
359
409
|
input_directory: str | Path,
|
@@ -598,110 +648,29 @@ class Converter:
|
|
598
648
|
return root
|
599
649
|
|
600
650
|
def _run_chain(self, block: TranslatedCodeBlock) -> str:
|
601
|
-
|
602
|
-
First, try to fix simple formatting errors by giving the model just
|
603
|
-
the output and the parsing error. After a number of attempts, try
|
604
|
-
giving the model the output, the parsing error, and the original
|
605
|
-
input. Again check/retry this output to solve for formatting errors.
|
606
|
-
If we still haven't succeeded after several attempts, the model may
|
607
|
-
be getting thrown off by a bad initial output; start from scratch
|
608
|
-
and try again.
|
609
|
-
|
610
|
-
The number of tries for each layer of this scheme is roughly equal
|
611
|
-
to the cube root of self.max_retries, so the total calls to the
|
612
|
-
LLM will be roughly as expected (up to sqrt(self.max_retries) over)
|
613
|
-
"""
|
614
|
-
input = self._parser.parse_input(block.original)
|
615
|
-
|
616
|
-
# Retries with just the output and the error
|
617
|
-
n1 = round(self.max_prompts ** (1 / 2))
|
618
|
-
|
619
|
-
# Retries with the input, output, and error
|
620
|
-
n2 = round(self.max_prompts // n1)
|
621
|
-
|
622
|
-
if not self.skip_context:
|
623
|
-
self._make_prompt_additions(block)
|
624
|
-
if not self.skip_refiner: # Make replacements in the prompt
|
625
|
-
refine_output = RefinerParser(
|
626
|
-
parser=self._parser,
|
627
|
-
initial_prompt=self._prompt.format(**{"SOURCE_CODE": input}),
|
628
|
-
refiner=self._refiner,
|
629
|
-
max_retries=n1,
|
630
|
-
llm=self._llm,
|
631
|
-
)
|
632
|
-
else:
|
633
|
-
refine_output = RetryWithErrorOutputParser.from_llm(
|
634
|
-
llm=self._llm,
|
635
|
-
parser=self._parser,
|
636
|
-
max_retries=n1,
|
637
|
-
)
|
638
|
-
|
639
|
-
completion_chain = self._prompt | self._llm
|
640
|
-
chain = RunnableParallel(
|
641
|
-
completion=completion_chain, prompt_value=self._prompt
|
642
|
-
) | RunnableLambda(lambda x: refine_output.parse_with_prompt(**x))
|
643
|
-
for _ in range(n2):
|
644
|
-
try:
|
645
|
-
return chain.invoke({"SOURCE_CODE": input})
|
646
|
-
except OutputParserException:
|
647
|
-
pass
|
648
|
-
|
649
|
-
raise OutputParserException(f"Failed to parse after {n1*n2} retries")
|
651
|
+
return self.chain.invoke(block.original)
|
650
652
|
|
651
653
|
def _get_output_obj(
|
652
654
|
self, block: TranslatedCodeBlock
|
653
|
-
) -> dict[str, int | float | str | dict[str, str]]:
|
655
|
+
) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
|
654
656
|
output_str = self._parser.parse_combined_output(block.complete_text)
|
655
657
|
|
656
|
-
|
658
|
+
output_obj: str | dict[str, str]
|
657
659
|
try:
|
658
|
-
|
660
|
+
output_obj = json.loads(output_str)
|
659
661
|
except json.JSONDecodeError:
|
660
|
-
|
662
|
+
output_obj = output_str
|
661
663
|
|
662
664
|
return dict(
|
663
|
-
input=block.original.text,
|
665
|
+
input=block.original.text or "",
|
664
666
|
metadata=dict(
|
665
667
|
retries=block.total_retries,
|
666
668
|
cost=block.total_cost,
|
667
669
|
processing_time=block.processing_time,
|
668
670
|
),
|
669
|
-
output=
|
670
|
-
)
|
671
|
-
|
672
|
-
@staticmethod
|
673
|
-
def _get_prompt_additions(block) -> Optional[List[Tuple[str, str]]]:
|
674
|
-
"""Get a list of strings to append to the prompt.
|
675
|
-
|
676
|
-
Arguments:
|
677
|
-
block: The `TranslatedCodeBlock` to save to a file.
|
678
|
-
"""
|
679
|
-
return [(key, item) for key, item in block.context_tags.items()]
|
680
|
-
|
681
|
-
def _make_prompt_additions(self, block: CodeBlock):
|
682
|
-
# Prepare the additional context to prepend
|
683
|
-
additional_context = "".join(
|
684
|
-
[
|
685
|
-
f"{context_tag}: {context}\n"
|
686
|
-
for context_tag, context in self._get_prompt_additions(block)
|
687
|
-
]
|
671
|
+
output=output_obj,
|
688
672
|
)
|
689
673
|
|
690
|
-
if not hasattr(self._prompt, "messages"):
|
691
|
-
log.debug("Skipping additions to prompt, no messages found on prompt object!")
|
692
|
-
return
|
693
|
-
|
694
|
-
# Iterate through existing messages to find and update the system message
|
695
|
-
for i, message in enumerate(self._prompt.messages):
|
696
|
-
if isinstance(message, SystemMessagePromptTemplate):
|
697
|
-
# Prepend the additional context to the system message
|
698
|
-
updated_system_message = SystemMessagePromptTemplate.from_template(
|
699
|
-
additional_context + message.prompt.template
|
700
|
-
)
|
701
|
-
# Directly modify the message in the list
|
702
|
-
self._prompt.messages[i] = updated_system_message
|
703
|
-
break # Assuming there's only one system message to update
|
704
|
-
|
705
674
|
def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
|
706
675
|
"""Save a file to disk.
|
707
676
|
|