janus-llm 3.5.3__py3-none-any.whl → 4.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +1 -1
- janus/cli.py +91 -48
- janus/converter/_tests/test_translate.py +2 -2
- janus/converter/converter.py +111 -142
- janus/converter/diagram.py +21 -109
- janus/converter/translate.py +1 -1
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +15 -10
- janus/language/binary/_tests/test_binary.py +1 -1
- janus/language/binary/binary.py +2 -2
- janus/language/mumps/_tests/test_mumps.py +1 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/splitter.py +2 -2
- janus/language/treesitter/_tests/test_treesitter.py +1 -1
- janus/language/treesitter/treesitter.py +2 -2
- janus/llm/model_callbacks.py +22 -0
- janus/llm/models_info.py +142 -81
- janus/metrics/metric.py +15 -14
- janus/parsers/uml.py +60 -23
- janus/refiners/refiner.py +106 -64
- janus/retrievers/retriever.py +42 -0
- {janus_llm-3.5.3.dist-info → janus_llm-4.1.0.dist-info}/METADATA +1 -1
- {janus_llm-3.5.3.dist-info → janus_llm-4.1.0.dist-info}/RECORD +26 -26
- janus/parsers/refiner_parser.py +0 -46
- {janus_llm-3.5.3.dist-info → janus_llm-4.1.0.dist-info}/LICENSE +0 -0
- {janus_llm-3.5.3.dist-info → janus_llm-4.1.0.dist-info}/WHEEL +0 -0
- {janus_llm-3.5.3.dist-info → janus_llm-4.1.0.dist-info}/entry_points.txt +0 -0
janus/llm/models_info.py
CHANGED
@@ -1,15 +1,14 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
|
-
import time
|
4
3
|
from pathlib import Path
|
5
|
-
from typing import
|
4
|
+
from typing import Callable, Protocol, TypeVar
|
6
5
|
|
7
6
|
from dotenv import load_dotenv
|
8
7
|
from langchain_community.llms import HuggingFaceTextGenInference
|
9
|
-
from langchain_core.
|
10
|
-
from langchain_openai import
|
8
|
+
from langchain_core.runnables import Runnable
|
9
|
+
from langchain_openai import AzureChatOpenAI
|
11
10
|
|
12
|
-
from janus.llm.model_callbacks import COST_PER_1K_TOKENS
|
11
|
+
from janus.llm.model_callbacks import COST_PER_1K_TOKENS, azure_model_reroutes
|
13
12
|
from janus.prompts.prompt import (
|
14
13
|
ChatGptPromptEngine,
|
15
14
|
ClaudePromptEngine,
|
@@ -44,17 +43,33 @@ except ImportError:
|
|
44
43
|
)
|
45
44
|
|
46
45
|
|
47
|
-
|
46
|
+
ModelType = TypeVar(
|
47
|
+
"ModelType",
|
48
|
+
AzureChatOpenAI,
|
49
|
+
HuggingFaceTextGenInference,
|
50
|
+
Bedrock,
|
51
|
+
BedrockChat,
|
52
|
+
HuggingFacePipeline,
|
53
|
+
)
|
48
54
|
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
55
|
+
|
56
|
+
class JanusModelProtocol(Protocol):
|
57
|
+
model_id: str
|
58
|
+
model_type_name: str
|
59
|
+
token_limit: int
|
60
|
+
input_token_cost: float
|
61
|
+
output_token_cost: float
|
62
|
+
prompt_engine: type[PromptEngine]
|
63
|
+
|
64
|
+
def get_num_tokens(self, text: str) -> int:
|
65
|
+
...
|
66
|
+
|
67
|
+
|
68
|
+
class JanusModel(Runnable, JanusModelProtocol):
|
69
|
+
...
|
70
|
+
|
71
|
+
|
72
|
+
load_dotenv()
|
58
73
|
|
59
74
|
openai_models = [
|
60
75
|
"gpt-4o",
|
@@ -65,6 +80,11 @@ openai_models = [
|
|
65
80
|
"gpt-3.5-turbo",
|
66
81
|
"gpt-3.5-turbo-16k",
|
67
82
|
]
|
83
|
+
azure_models = [
|
84
|
+
"gpt-4o",
|
85
|
+
"gpt-4o-mini",
|
86
|
+
"gpt-3.5-turbo-16k",
|
87
|
+
]
|
68
88
|
claude_models = [
|
69
89
|
"bedrock-claude-v2",
|
70
90
|
"bedrock-claude-instant-v1",
|
@@ -103,27 +123,21 @@ bedrock_models = [
|
|
103
123
|
*cohere_models,
|
104
124
|
*mistral_models,
|
105
125
|
]
|
106
|
-
all_models = [*
|
126
|
+
all_models = [*azure_models, *bedrock_models]
|
107
127
|
|
108
|
-
MODEL_TYPE_CONSTRUCTORS: dict[str,
|
109
|
-
"OpenAI": ChatOpenAI,
|
128
|
+
MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
|
129
|
+
# "OpenAI": ChatOpenAI,
|
110
130
|
"HuggingFace": HuggingFaceTextGenInference,
|
131
|
+
"Azure": AzureChatOpenAI,
|
132
|
+
"Bedrock": Bedrock,
|
133
|
+
"BedrockChat": BedrockChat,
|
134
|
+
"HuggingFaceLocal": HuggingFacePipeline,
|
111
135
|
}
|
112
136
|
|
113
|
-
try:
|
114
|
-
MODEL_TYPE_CONSTRUCTORS.update(
|
115
|
-
{
|
116
|
-
"HuggingFaceLocal": HuggingFacePipeline.from_model_id,
|
117
|
-
"Bedrock": Bedrock,
|
118
|
-
"BedrockChat": BedrockChat,
|
119
|
-
}
|
120
|
-
)
|
121
|
-
except NameError:
|
122
|
-
pass
|
123
|
-
|
124
137
|
|
125
138
|
MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
126
|
-
**{m: ChatGptPromptEngine for m in openai_models},
|
139
|
+
# **{m: ChatGptPromptEngine for m in openai_models},
|
140
|
+
**{m: ChatGptPromptEngine for m in azure_models},
|
127
141
|
**{m: ClaudePromptEngine for m in claude_models},
|
128
142
|
**{m: Llama2PromptEngine for m in llama2_models},
|
129
143
|
**{m: Llama3PromptEngine for m in llama3_models},
|
@@ -132,13 +146,9 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
|
132
146
|
**{m: MistralPromptEngine for m in mistral_models},
|
133
147
|
}
|
134
148
|
|
135
|
-
_open_ai_defaults: dict[str, str] = {
|
136
|
-
"openai_api_key": os.getenv("OPENAI_API_KEY"),
|
137
|
-
"openai_organization": os.getenv("OPENAI_ORG_ID"),
|
138
|
-
}
|
139
|
-
|
140
149
|
MODEL_ID_TO_LONG_ID = {
|
141
|
-
**{m: mr for m, mr in openai_model_reroutes.items()},
|
150
|
+
# **{m: mr for m, mr in openai_model_reroutes.items()},
|
151
|
+
**{m: mr for m, mr in azure_model_reroutes.items()},
|
142
152
|
"bedrock-claude-v2": "anthropic.claude-v2",
|
143
153
|
"bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
|
144
154
|
"bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
|
@@ -169,7 +179,8 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
|
|
169
179
|
MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
|
170
180
|
|
171
181
|
MODEL_TYPES: dict[str, PromptEngine] = {
|
172
|
-
**{m: "OpenAI" for m in openai_models},
|
182
|
+
# **{m: "OpenAI" for m in openai_models},
|
183
|
+
**{m: "Azure" for m in azure_models},
|
173
184
|
**{m: "BedrockChat" for m in bedrock_models},
|
174
185
|
}
|
175
186
|
|
@@ -179,7 +190,10 @@ TOKEN_LIMITS: dict[str, int] = {
|
|
179
190
|
"gpt-4-1106-preview": 128_000,
|
180
191
|
"gpt-4-0125-preview": 128_000,
|
181
192
|
"gpt-4o-2024-05-13": 128_000,
|
193
|
+
"gpt-4o-2024-08-06": 128_000,
|
194
|
+
"gpt-4o-mini": 128_000,
|
182
195
|
"gpt-3.5-turbo-0125": 16_384,
|
196
|
+
"gpt35-turbo-16k": 16_384,
|
183
197
|
"text-embedding-ada-002": 8191,
|
184
198
|
"gpt4all": 16_384,
|
185
199
|
"anthropic.claude-v2": 100_000,
|
@@ -211,53 +225,100 @@ def get_available_model_names() -> list[str]:
|
|
211
225
|
return avaialable_models
|
212
226
|
|
213
227
|
|
214
|
-
def load_model(
|
215
|
-
user_model_name: str,
|
216
|
-
) -> tuple[BaseLanguageModel, str, int, dict[str, float]]:
|
228
|
+
def load_model(model_id) -> JanusModel:
|
217
229
|
if not MODEL_CONFIG_DIR.exists():
|
218
230
|
MODEL_CONFIG_DIR.mkdir(parents=True)
|
219
|
-
model_config_file = MODEL_CONFIG_DIR / f"{
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
f"default models for {user_model_name}."
|
224
|
-
)
|
225
|
-
model_id = user_model_name
|
226
|
-
if user_model_name not in DEFAULT_MODELS:
|
227
|
-
message = (
|
228
|
-
f"Model {user_model_name} not found in default models. Make sure to run "
|
229
|
-
"`janus llm add` first."
|
230
|
-
)
|
231
|
-
log.error(message)
|
232
|
-
raise ValueError(message)
|
233
|
-
model_config = {
|
234
|
-
"model_type": MODEL_TYPES[model_id],
|
235
|
-
"model_id": model_id,
|
236
|
-
"model_args": MODEL_DEFAULT_ARGUMENTS[model_id],
|
237
|
-
"token_limit": TOKEN_LIMITS.get(MODEL_ID_TO_LONG_ID[model_id], 4096),
|
238
|
-
"model_cost": COST_PER_1K_TOKENS.get(
|
239
|
-
MODEL_ID_TO_LONG_ID[model_id], {"input": 0, "output": 0}
|
240
|
-
),
|
241
|
-
}
|
242
|
-
with open(model_config_file, "w") as f:
|
243
|
-
json.dump(model_config, f)
|
244
|
-
else:
|
231
|
+
model_config_file = MODEL_CONFIG_DIR / f"{model_id}.json"
|
232
|
+
|
233
|
+
if model_config_file.exists():
|
234
|
+
log.info(f"Loading {model_id} from {model_config_file}.")
|
245
235
|
with open(model_config_file, "r") as f:
|
246
236
|
model_config = json.load(f)
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
237
|
+
model_type_name = model_config["model_type"]
|
238
|
+
model_id = model_config["model_id"]
|
239
|
+
model_args = model_config["model_args"]
|
240
|
+
token_limit = model_config["token_limit"]
|
241
|
+
input_token_cost = model_config["model_cost"]["input"]
|
242
|
+
output_token_cost = model_config["model_cost"]["output"]
|
243
|
+
|
244
|
+
elif model_id in DEFAULT_MODELS:
|
245
|
+
model_id = model_id
|
246
|
+
model_long_id = MODEL_ID_TO_LONG_ID[model_id]
|
247
|
+
model_type_name = MODEL_TYPES[model_id]
|
248
|
+
model_args = MODEL_DEFAULT_ARGUMENTS[model_id]
|
249
|
+
|
250
|
+
token_limit = 0
|
251
|
+
input_token_cost = 0.0
|
252
|
+
output_token_cost = 0.0
|
253
|
+
if model_long_id in TOKEN_LIMITS:
|
254
|
+
token_limit = TOKEN_LIMITS[model_long_id]
|
255
|
+
if model_long_id in COST_PER_1K_TOKENS:
|
256
|
+
token_limits = COST_PER_1K_TOKENS[model_long_id]
|
257
|
+
input_token_cost = token_limits["input"]
|
258
|
+
output_token_cost = token_limits["output"]
|
259
|
+
|
260
|
+
else:
|
261
|
+
model_list = "\n\t".join(DEFAULT_MODELS)
|
262
|
+
message = (
|
263
|
+
f"Model {model_id} not found in user-defined model directory "
|
264
|
+
f"({MODEL_CONFIG_DIR}), and is not a default model. Valid default "
|
265
|
+
f"models:\n\t{model_list}\n"
|
266
|
+
f"To use a custom model, first run `janus llm add`."
|
267
|
+
)
|
268
|
+
log.error(message)
|
269
|
+
raise ValueError(message)
|
270
|
+
|
271
|
+
if model_type_name == "HuggingFaceLocal":
|
272
|
+
model = HuggingFacePipeline.from_model_id(
|
273
|
+
model_id=model_id,
|
274
|
+
task="text-generation",
|
275
|
+
model_kwargs=model_args,
|
276
|
+
)
|
277
|
+
model_args.update(pipeline=model.pipeline)
|
278
|
+
|
279
|
+
elif model_type_name == "OpenAI":
|
280
|
+
model_args.update(
|
281
|
+
openai_api_key=str(os.getenv("OPENAI_API_KEY")),
|
282
|
+
openai_organization=str(os.getenv("OPENAI_ORG_ID")),
|
283
|
+
)
|
284
|
+
# log.warning("Do NOT use this model in sensitive environments!")
|
285
|
+
# log.warning("If you would like to cancel, please press Ctrl+C.")
|
286
|
+
# log.warning("Waiting 10 seconds...")
|
254
287
|
# Give enough time for the user to read the warnings and cancel
|
255
|
-
time.sleep(10)
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
288
|
+
# time.sleep(10)
|
289
|
+
raise DeprecationWarning("OpenAI models are no longer supported.")
|
290
|
+
|
291
|
+
elif model_type_name == "Azure":
|
292
|
+
model_args.update(
|
293
|
+
{
|
294
|
+
"api_key": os.getenv("AZURE_OPENAI_API_KEY"),
|
295
|
+
"azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
|
296
|
+
"api_version": os.getenv("OPENAI_API_VERSION", "2024-02-01"),
|
297
|
+
}
|
298
|
+
)
|
299
|
+
|
300
|
+
model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
|
301
|
+
prompt_engine = MODEL_PROMPT_ENGINES[model_id]
|
302
|
+
|
303
|
+
class JanusModel(model_type):
|
304
|
+
model_id: str
|
305
|
+
short_model_id: str
|
306
|
+
model_type_name: str
|
307
|
+
token_limit: int
|
308
|
+
input_token_cost: float
|
309
|
+
output_token_cost: float
|
310
|
+
prompt_engine: type[PromptEngine]
|
311
|
+
|
312
|
+
model_args.update(
|
313
|
+
model_id=MODEL_ID_TO_LONG_ID[model_id],
|
314
|
+
short_model_id=model_id,
|
315
|
+
)
|
316
|
+
|
317
|
+
return JanusModel(
|
318
|
+
model_type_name=model_type_name,
|
319
|
+
token_limit=token_limit,
|
320
|
+
input_token_cost=input_token_cost,
|
321
|
+
output_token_cost=output_token_cost,
|
322
|
+
prompt_engine=prompt_engine,
|
323
|
+
**model_args,
|
263
324
|
)
|
janus/metrics/metric.py
CHANGED
@@ -8,6 +8,7 @@ import typer
|
|
8
8
|
from typing_extensions import Annotated
|
9
9
|
|
10
10
|
from janus.llm import load_model
|
11
|
+
from janus.llm.model_callbacks import COST_PER_1K_TOKENS
|
11
12
|
from janus.metrics.cli import evaluate
|
12
13
|
from janus.metrics.file_pairing import FILE_PAIRING_METHODS
|
13
14
|
from janus.metrics.splitting import SPLITTING_METHODS
|
@@ -135,7 +136,7 @@ def metric(
|
|
135
136
|
**kwargs,
|
136
137
|
):
|
137
138
|
out = []
|
138
|
-
llm
|
139
|
+
llm = load_model(llm_name)
|
139
140
|
if json_file_name is not None:
|
140
141
|
with open(json_file_name, "r") as f:
|
141
142
|
json_obj = json.load(f)
|
@@ -171,8 +172,8 @@ def metric(
|
|
171
172
|
out_file=out_file,
|
172
173
|
lang=language,
|
173
174
|
llm=llm,
|
174
|
-
token_limit=token_limit,
|
175
|
-
model_cost=
|
175
|
+
token_limit=llm.token_limit,
|
176
|
+
model_cost=COST_PER_1K_TOKENS[llm.model_id],
|
176
177
|
)
|
177
178
|
else:
|
178
179
|
raise ValueError(
|
@@ -187,8 +188,8 @@ def metric(
|
|
187
188
|
progress,
|
188
189
|
language,
|
189
190
|
llm,
|
190
|
-
token_limit,
|
191
|
-
|
191
|
+
llm.token_limit,
|
192
|
+
COST_PER_1K_TOKENS[llm.model_id],
|
192
193
|
*args,
|
193
194
|
**kwargs,
|
194
195
|
)
|
@@ -199,8 +200,8 @@ def metric(
|
|
199
200
|
progress,
|
200
201
|
language,
|
201
202
|
llm,
|
202
|
-
token_limit,
|
203
|
-
|
203
|
+
llm.token_limit,
|
204
|
+
COST_PER_1K_TOKENS[llm.model_id],
|
204
205
|
*args,
|
205
206
|
**kwargs,
|
206
207
|
)
|
@@ -296,7 +297,7 @@ def metric(
|
|
296
297
|
*args,
|
297
298
|
**kwargs,
|
298
299
|
):
|
299
|
-
llm
|
300
|
+
llm = load_model(llm_name)
|
300
301
|
if json_file_name is not None:
|
301
302
|
with open(json_file_name, "r") as f:
|
302
303
|
json_obj = json.load(f)
|
@@ -328,8 +329,8 @@ def metric(
|
|
328
329
|
out_file=out_file,
|
329
330
|
lang=language,
|
330
331
|
llm=llm,
|
331
|
-
token_limit=token_limit,
|
332
|
-
model_cost=
|
332
|
+
token_limit=llm.token_limit,
|
333
|
+
model_cost=COST_PER_1K_TOKENS[llm.model_id],
|
333
334
|
)
|
334
335
|
else:
|
335
336
|
raise ValueError(
|
@@ -344,8 +345,8 @@ def metric(
|
|
344
345
|
progress,
|
345
346
|
language,
|
346
347
|
llm,
|
347
|
-
token_limit,
|
348
|
-
|
348
|
+
llm.token_limit,
|
349
|
+
COST_PER_1K_TOKENS[llm.model_id],
|
349
350
|
*args,
|
350
351
|
**kwargs,
|
351
352
|
)
|
@@ -356,8 +357,8 @@ def metric(
|
|
356
357
|
progress,
|
357
358
|
language,
|
358
359
|
llm,
|
359
|
-
token_limit,
|
360
|
-
|
360
|
+
llm.token_limit,
|
361
|
+
COST_PER_1K_TOKENS[llm.model_id],
|
361
362
|
*args,
|
362
363
|
**kwargs,
|
363
364
|
)
|
janus/parsers/uml.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import re
|
2
2
|
import subprocess # nosec
|
3
3
|
from pathlib import Path
|
4
|
-
from
|
4
|
+
from tempfile import NamedTemporaryFile
|
5
5
|
|
6
6
|
from langchain_core.exceptions import OutputParserException
|
7
7
|
from langchain_core.messages import BaseMessage
|
@@ -13,39 +13,76 @@ log = create_logger(__name__)
|
|
13
13
|
|
14
14
|
|
15
15
|
class UMLSyntaxParser(CodeParser):
|
16
|
-
def
|
17
|
-
#
|
18
|
-
|
16
|
+
def _check_plantuml(self, text: str) -> None:
|
17
|
+
# Leading newlines can break the parser, remove them
|
18
|
+
text = text.replace("\\n", "\n").strip()
|
19
|
+
|
20
|
+
# Write the text to a temporary file (automatically deleted)
|
21
|
+
file = NamedTemporaryFile()
|
22
|
+
fname = file.name
|
23
|
+
with open(fname, "w") as fin:
|
24
|
+
fin.write(text)
|
25
|
+
|
19
26
|
try:
|
20
27
|
plantuml_path = Path.home().expanduser() / ".janus/lib/plantuml.jar"
|
28
|
+
# NOTE: running subprocess with shell=False, added nosec to
|
29
|
+
# label that we know risk exists
|
21
30
|
res = subprocess.run(
|
22
|
-
["java", "-jar", plantuml_path,
|
31
|
+
["java", "-jar", plantuml_path, fname],
|
23
32
|
stdout=subprocess.PIPE,
|
24
33
|
stderr=subprocess.PIPE,
|
25
34
|
) # nosec
|
26
35
|
stdout = res.stdout.decode("utf-8")
|
27
36
|
stderr = res.stderr.decode("utf-8")
|
28
37
|
except FileNotFoundError:
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
38
|
+
err_txt = (
|
39
|
+
"Plant UML executable not found. Either choose a different parser"
|
40
|
+
" or install with `bash scripts/install_plantuml.sh`. Java and"
|
41
|
+
" graphviz are dependencies for the tool, they must also be installed."
|
42
|
+
)
|
43
|
+
log.error(err_txt)
|
44
|
+
raise Exception(err_txt)
|
45
|
+
|
46
|
+
# Check for bad outputs, raise OutputParserExceptions if so
|
47
|
+
if "Error" in stderr or "Error" in stdout:
|
48
|
+
err_txt = "Recieved UML parsing error(s)."
|
49
|
+
|
50
|
+
line_nos = self._get_error_lines(stderr) + self._get_error_lines(stdout)
|
51
|
+
lines = text.split("\n")
|
52
|
+
for i in line_nos:
|
53
|
+
i0 = max(0, i - 3)
|
54
|
+
i1 = min(len(lines) - 1, i + 2)
|
55
|
+
err_lines = [
|
56
|
+
f"> {lines[j]}" if j == i - 1 else f" {lines[j]}"
|
57
|
+
for j in range(i0, i1)
|
58
|
+
]
|
59
|
+
if i0:
|
60
|
+
err_lines.insert(0, " ...")
|
61
|
+
if i1 < (len(lines) - 1):
|
62
|
+
err_lines.append(" ...")
|
33
63
|
|
34
|
-
|
35
|
-
|
64
|
+
err_txt += f"\nError located at line {i} must be fixed:\n"
|
65
|
+
err_txt += "\n".join(err_lines)
|
66
|
+
log.warning(err_txt)
|
67
|
+
raise OutputParserException(err_txt)
|
68
|
+
|
69
|
+
if "Warning" in stdout or "Warning" in stderr:
|
70
|
+
err_txt = "Recieved UML parsing warning (often due to missing PLANTUML)."
|
71
|
+
if stderr:
|
72
|
+
err_txt += f"\nSTDERR:\n```\n{stderr.strip()}\n```\n"
|
73
|
+
if stdout:
|
74
|
+
err_txt += f"\nSTDOUT:\n```\n{stdout.strip()}\n```\n"
|
75
|
+
|
76
|
+
log.warning(err_txt)
|
77
|
+
raise OutputParserException(err_txt)
|
78
|
+
|
79
|
+
def _get_error_lines(self, s: str) -> list[int]:
|
80
|
+
return [int(x.group(1)) for x in re.finditer(r"Error line (\d+) in file:", s)]
|
81
|
+
|
82
|
+
def _get_warns(self, s: str) -> list[str]:
|
83
|
+
return [x.group() for x in re.finditer(r"Warning: (.*)\n", s)]
|
36
84
|
|
37
85
|
def parse(self, text: str | BaseMessage) -> str:
|
38
86
|
text = super().parse(text)
|
39
|
-
|
40
|
-
if not janus_path.exists():
|
41
|
-
janus_path.mkdir()
|
42
|
-
temp_file_path = janus_path / "tmp.txt"
|
43
|
-
with open(temp_file_path, "w") as f:
|
44
|
-
f.write(text)
|
45
|
-
uml_std_out, uml_std_err = self._get_uml_output(temp_file_path)
|
46
|
-
uml_errs = self._get_errs(uml_std_out) + self._get_errs(uml_std_err)
|
47
|
-
if len(uml_errs) > 0:
|
48
|
-
raise OutputParserException(
|
49
|
-
"Error: Received UML Errors:\n" + "\n".join(uml_errs)
|
50
|
-
)
|
87
|
+
self._check_plantuml(text)
|
51
88
|
return text
|
janus/refiners/refiner.py
CHANGED
@@ -1,73 +1,115 @@
|
|
1
|
-
from
|
1
|
+
from typing import Any
|
2
2
|
|
3
|
-
from
|
3
|
+
from langchain.output_parsers import RetryWithErrorOutputParser
|
4
|
+
from langchain_core.output_parsers import StrOutputParser
|
5
|
+
from langchain_core.prompt_values import PromptValue
|
6
|
+
from langchain_core.runnables import RunnableSerializable
|
4
7
|
|
8
|
+
from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
|
9
|
+
from janus.parsers.parser import JanusParser
|
10
|
+
from janus.utils.logger import create_logger
|
5
11
|
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
**
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
original_output: origial output of llm
|
20
|
-
errors: list of errors detected by parser
|
21
|
-
|
22
|
-
Returns:
|
23
|
-
Tuple of new prompt and prompt arguments
|
24
|
-
"""
|
12
|
+
log = create_logger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class JanusRefiner(JanusParser):
|
16
|
+
parser: JanusParser
|
17
|
+
|
18
|
+
def parse_runnable(self, input: dict[str, Any]) -> Any:
|
19
|
+
return self.parse_completion(**input)
|
20
|
+
|
21
|
+
def parse_completion(self, completion: str, **kwargs) -> Any:
|
22
|
+
return self.parser.parse(completion)
|
23
|
+
|
24
|
+
def parse(self, text: str) -> str:
|
25
25
|
raise NotImplementedError
|
26
26
|
|
27
27
|
|
28
|
-
class
|
28
|
+
class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
|
29
|
+
def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
|
30
|
+
retry_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
|
31
|
+
source_language="text",
|
32
|
+
prompt_template="refinement/fix_exceptions",
|
33
|
+
).prompt
|
34
|
+
chain = retry_prompt | llm | StrOutputParser()
|
35
|
+
RetryWithErrorOutputParser.__init__(
|
36
|
+
self, parser=parser, retry_chain=chain, max_retries=max_retries
|
37
|
+
)
|
38
|
+
|
39
|
+
def parse_completion(
|
40
|
+
self, completion: str, prompt_value: PromptValue, **kwargs
|
41
|
+
) -> Any:
|
42
|
+
return self.parse_with_prompt(completion, prompt_value=prompt_value)
|
43
|
+
|
44
|
+
|
45
|
+
class ReflectionRefiner(JanusRefiner):
|
46
|
+
max_retries: int
|
47
|
+
reflection_chain: RunnableSerializable
|
48
|
+
revision_chain: RunnableSerializable
|
49
|
+
|
29
50
|
def __init__(
|
30
51
|
self,
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
52
|
+
llm: JanusModel,
|
53
|
+
parser: JanusParser,
|
54
|
+
max_retries: int,
|
55
|
+
prompt_template_name: str = "refinement/reflection",
|
56
|
+
):
|
57
|
+
reflection_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
|
58
|
+
source_language="text",
|
59
|
+
prompt_template=prompt_template_name,
|
60
|
+
).prompt
|
61
|
+
revision_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
|
62
|
+
source_language="text",
|
63
|
+
prompt_template="refinement/revision",
|
64
|
+
).prompt
|
65
|
+
|
66
|
+
reflection_chain = reflection_prompt | llm | StrOutputParser()
|
67
|
+
revision_chain = revision_prompt | llm | StrOutputParser()
|
68
|
+
super().__init__(
|
69
|
+
reflection_chain=reflection_chain,
|
70
|
+
revision_chain=revision_chain,
|
71
|
+
parser=parser,
|
72
|
+
max_retries=max_retries,
|
73
|
+
)
|
74
|
+
|
75
|
+
def parse_completion(
|
76
|
+
self, completion: str, prompt_value: PromptValue, **kwargs
|
77
|
+
) -> Any:
|
78
|
+
for retry_number in range(self.max_retries):
|
79
|
+
reflection = self.reflection_chain.invoke(
|
80
|
+
dict(
|
81
|
+
prompt=prompt_value.to_string(),
|
82
|
+
completion=completion,
|
83
|
+
)
|
84
|
+
)
|
85
|
+
if reflection.strip() == "LGTM":
|
86
|
+
return self.parser.parse(completion)
|
87
|
+
if not retry_number:
|
88
|
+
log.info(f"Completion:\n{completion}")
|
89
|
+
log.info(f"Reflection:\n{reflection}")
|
90
|
+
completion = self.revision_chain.invoke(
|
91
|
+
dict(
|
92
|
+
prompt=prompt_value.to_string(),
|
93
|
+
completion=completion,
|
94
|
+
reflection=reflection,
|
95
|
+
)
|
96
|
+
)
|
97
|
+
log.info(f"Revision:\n{completion}")
|
98
|
+
|
99
|
+
return self.parser.parse(completion)
|
100
|
+
|
101
|
+
|
102
|
+
class HallucinationRefiner(ReflectionRefiner):
|
103
|
+
def __init__(self, **kwargs):
|
104
|
+
super().__init__(
|
105
|
+
prompt_template_name="refinement/hallucination",
|
106
|
+
**kwargs,
|
67
107
|
)
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
108
|
+
|
109
|
+
|
110
|
+
REFINERS = dict(
|
111
|
+
none=JanusRefiner,
|
112
|
+
parser=FixParserExceptions,
|
113
|
+
reflection=ReflectionRefiner,
|
114
|
+
hallucination=HallucinationRefiner,
|
115
|
+
)
|