janus-llm 3.5.2__py3-none-any.whl → 4.0.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- janus/__init__.py +1 -1
- janus/cli.py +90 -42
- janus/converter/converter.py +111 -142
- janus/converter/diagram.py +21 -109
- janus/converter/translate.py +1 -1
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +16 -11
- janus/language/binary/_tests/test_binary.py +1 -1
- janus/language/binary/binary.py +2 -2
- janus/language/mumps/_tests/test_mumps.py +1 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/naive/simple_ast.py +3 -2
- janus/language/splitter.py +7 -4
- janus/language/treesitter/_tests/test_treesitter.py +1 -1
- janus/language/treesitter/treesitter.py +2 -2
- janus/llm/model_callbacks.py +13 -0
- janus/llm/models_info.py +118 -71
- janus/metrics/metric.py +15 -14
- janus/parsers/uml.py +60 -23
- janus/refiners/refiner.py +106 -64
- janus/retrievers/retriever.py +42 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/METADATA +1 -1
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/RECORD +26 -26
- janus/parsers/refiner_parser.py +0 -46
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/WHEEL +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/entry_points.txt +0 -0
janus/llm/models_info.py
CHANGED
@@ -1,14 +1,15 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
|
+
import time
|
3
4
|
from pathlib import Path
|
4
|
-
from typing import
|
5
|
+
from typing import Protocol, TypeVar
|
5
6
|
|
6
7
|
from dotenv import load_dotenv
|
7
8
|
from langchain_community.llms import HuggingFaceTextGenInference
|
8
|
-
from langchain_core.
|
9
|
+
from langchain_core.runnables import Runnable
|
9
10
|
from langchain_openai import ChatOpenAI
|
10
11
|
|
11
|
-
from janus.llm.model_callbacks import COST_PER_1K_TOKENS
|
12
|
+
from janus.llm.model_callbacks import COST_PER_1K_TOKENS, openai_model_reroutes
|
12
13
|
from janus.prompts.prompt import (
|
13
14
|
ChatGptPromptEngine,
|
14
15
|
ClaudePromptEngine,
|
@@ -43,17 +44,34 @@ except ImportError:
|
|
43
44
|
)
|
44
45
|
|
45
46
|
|
47
|
+
ModelType = TypeVar(
|
48
|
+
"ModelType",
|
49
|
+
ChatOpenAI,
|
50
|
+
HuggingFaceTextGenInference,
|
51
|
+
Bedrock,
|
52
|
+
BedrockChat,
|
53
|
+
HuggingFacePipeline,
|
54
|
+
)
|
55
|
+
|
56
|
+
|
57
|
+
class JanusModelProtocol(Protocol):
    """Structural interface describing the metadata carried by a Janus model."""

    # Identification
    model_id: str
    model_type_name: str

    # Prompt construction: the engine class itself, not an instance
    prompt_engine: type[PromptEngine]

    # Size limit and per-token pricing
    token_limit: int
    input_token_cost: float
    output_token_cost: float

    def get_num_tokens(self, text: str) -> int:
        """Return the number of tokens in ``text``."""
|
67
|
+
|
68
|
+
|
69
|
+
class JanusModel(Runnable, JanusModelProtocol):
    """Typing placeholder: a ``Runnable`` that also satisfies ``JanusModelProtocol``."""
|
71
|
+
|
72
|
+
|
46
73
|
load_dotenv()
|
47
74
|
|
48
|
-
openai_model_reroutes = {
|
49
|
-
"gpt-4o": "gpt-4o-2024-05-13",
|
50
|
-
"gpt-4o-mini": "gpt-4o-mini",
|
51
|
-
"gpt-4": "gpt-4-0613",
|
52
|
-
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
|
53
|
-
"gpt-4-turbo-preview": "gpt-4-0125-preview",
|
54
|
-
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
|
55
|
-
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
|
56
|
-
}
|
57
75
|
|
58
76
|
openai_models = [
|
59
77
|
"gpt-4o",
|
@@ -104,24 +122,15 @@ bedrock_models = [
|
|
104
122
|
]
|
105
123
|
all_models = [*openai_models, *bedrock_models]
|
106
124
|
|
107
|
-
MODEL_TYPE_CONSTRUCTORS: dict[str,
|
125
|
+
MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
|
108
126
|
"OpenAI": ChatOpenAI,
|
109
127
|
"HuggingFace": HuggingFaceTextGenInference,
|
128
|
+
"Bedrock": Bedrock,
|
129
|
+
"BedrockChat": BedrockChat,
|
130
|
+
"HuggingFaceLocal": HuggingFacePipeline,
|
110
131
|
}
|
111
132
|
|
112
|
-
|
113
|
-
MODEL_TYPE_CONSTRUCTORS.update(
|
114
|
-
{
|
115
|
-
"HuggingFaceLocal": HuggingFacePipeline.from_model_id,
|
116
|
-
"Bedrock": Bedrock,
|
117
|
-
"BedrockChat": BedrockChat,
|
118
|
-
}
|
119
|
-
)
|
120
|
-
except NameError:
|
121
|
-
pass
|
122
|
-
|
123
|
-
|
124
|
-
MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
133
|
+
MODEL_PROMPT_ENGINES: dict[str, type[PromptEngine]] = {
|
125
134
|
**{m: ChatGptPromptEngine for m in openai_models},
|
126
135
|
**{m: ClaudePromptEngine for m in claude_models},
|
127
136
|
**{m: Llama2PromptEngine for m in llama2_models},
|
@@ -131,11 +140,6 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
|
131
140
|
**{m: MistralPromptEngine for m in mistral_models},
|
132
141
|
}
|
133
142
|
|
134
|
-
_open_ai_defaults: dict[str, str] = {
|
135
|
-
"openai_api_key": os.getenv("OPENAI_API_KEY"),
|
136
|
-
"openai_organization": os.getenv("OPENAI_ORG_ID"),
|
137
|
-
}
|
138
|
-
|
139
143
|
MODEL_ID_TO_LONG_ID = {
|
140
144
|
**{m: mr for m, mr in openai_model_reroutes.items()},
|
141
145
|
"bedrock-claude-v2": "anthropic.claude-v2",
|
@@ -167,7 +171,7 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
|
|
167
171
|
|
168
172
|
MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
|
169
173
|
|
170
|
-
MODEL_TYPES: dict[str,
|
174
|
+
MODEL_TYPES: dict[str, str] = {
|
171
175
|
**{m: "OpenAI" for m in openai_models},
|
172
176
|
**{m: "BedrockChat" for m in bedrock_models},
|
173
177
|
}
|
@@ -210,47 +214,90 @@ def get_available_model_names() -> list[str]:
|
|
210
214
|
return avaialable_models
|
211
215
|
|
212
216
|
|
213
|
-
def load_model(
|
214
|
-
user_model_name: str,
|
215
|
-
) -> tuple[BaseLanguageModel, str, int, dict[str, float]]:
|
217
|
+
def load_model(model_id: str) -> JanusModel:
    """Build a model object for ``model_id`` with Janus metadata attached.

    The configuration is resolved in this order:

    1. A JSON file named ``<model_id>.json`` inside ``MODEL_CONFIG_DIR``.
    2. The built-in tables, when ``model_id`` is listed in ``DEFAULT_MODELS``.

    Args:
        model_id: Name of the model to load.

    Returns:
        An instance of a locally defined subclass of the constructor
        registered for the model type, carrying ``model_id`` (long id),
        ``short_model_id``, ``model_type_name``, ``token_limit``, the
        input/output token costs and the ``prompt_engine`` class.

    Raises:
        ValueError: If there is no config file for ``model_id`` and it is
            not a default model.
    """
    # exist_ok avoids a race between the existence check and the creation.
    MODEL_CONFIG_DIR.mkdir(parents=True, exist_ok=True)
    model_config_file = MODEL_CONFIG_DIR / f"{model_id}.json"

    if model_config_file.exists():
        log.info(f"Loading {model_id} from {model_config_file}.")
        with open(model_config_file, "r") as f:
            model_config = json.load(f)
        model_type_name = model_config["model_type"]
        model_id = model_config["model_id"]
        model_args = model_config["model_args"]
        token_limit = model_config["token_limit"]
        input_token_cost = model_config["model_cost"]["input"]
        output_token_cost = model_config["model_cost"]["output"]

    elif model_id in DEFAULT_MODELS:
        model_long_id = MODEL_ID_TO_LONG_ID[model_id]
        model_type_name = MODEL_TYPES[model_id]
        # Copy the defaults: model_args is mutated below (API keys, pipeline,
        # model ids) and those values must not leak into the shared
        # module-level table.
        model_args = dict(MODEL_DEFAULT_ARGUMENTS[model_id])

        # Unknown models fall back to a zero limit and zero cost.
        token_limit = 0
        input_token_cost = 0.0
        output_token_cost = 0.0
        if model_long_id in TOKEN_LIMITS:
            token_limit = TOKEN_LIMITS[model_long_id]
        if model_long_id in COST_PER_1K_TOKENS:
            model_cost = COST_PER_1K_TOKENS[model_long_id]
            input_token_cost = model_cost["input"]
            output_token_cost = model_cost["output"]

    else:
        model_list = "\n\t".join(DEFAULT_MODELS)
        message = (
            f"Model {model_id} not found in user-defined model directory "
            f"({MODEL_CONFIG_DIR}), and is not a default model. Valid default "
            f"models:\n\t{model_list}\n"
            f"To use a custom model, first run `janus llm add`."
        )
        log.error(message)
        raise ValueError(message)

    if model_type_name == "HuggingFaceLocal":
        # Build the pipeline once here and hand it to the constructor below.
        model = HuggingFacePipeline.from_model_id(
            model_id=model_id,
            task="text-generation",
            model_kwargs=model_args,
        )
        model_args.update(pipeline=model.pipeline)

    elif model_type_name == "OpenAI":
        # NOTE(review): str() turns an unset environment variable into the
        # literal string "None" -- confirm this is intended.
        model_args.update(
            openai_api_key=str(os.getenv("OPENAI_API_KEY")),
            openai_organization=str(os.getenv("OPENAI_ORG_ID")),
        )
        log.warning("Do NOT use this model in sensitive environments!")
        log.warning("If you would like to cancel, please press Ctrl+C.")
        log.warning("Waiting 10 seconds...")
        # Give enough time for the user to read the warnings and cancel
        time.sleep(10)

    model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
    prompt_engine = MODEL_PROMPT_ENGINES[model_id]

    # Local subclass (shadows the module-level JanusModel name inside this
    # function) that adds the Janus metadata fields to the chosen model type.
    class JanusModel(model_type):
        model_id: str
        short_model_id: str
        model_type_name: str
        token_limit: int
        input_token_cost: float
        output_token_cost: float
        prompt_engine: type[PromptEngine]

    # The constructed model exposes the long id as model_id and keeps the
    # name used for the table lookups as short_model_id.
    model_args.update(
        model_id=MODEL_ID_TO_LONG_ID[model_id],
        short_model_id=model_id,
    )

    return JanusModel(
        model_type_name=model_type_name,
        token_limit=token_limit,
        input_token_cost=input_token_cost,
        output_token_cost=output_token_cost,
        prompt_engine=prompt_engine,
        **model_args,
    )
|
janus/metrics/metric.py
CHANGED
@@ -8,6 +8,7 @@ import typer
|
|
8
8
|
from typing_extensions import Annotated
|
9
9
|
|
10
10
|
from janus.llm import load_model
|
11
|
+
from janus.llm.model_callbacks import COST_PER_1K_TOKENS
|
11
12
|
from janus.metrics.cli import evaluate
|
12
13
|
from janus.metrics.file_pairing import FILE_PAIRING_METHODS
|
13
14
|
from janus.metrics.splitting import SPLITTING_METHODS
|
@@ -135,7 +136,7 @@ def metric(
|
|
135
136
|
**kwargs,
|
136
137
|
):
|
137
138
|
out = []
|
138
|
-
llm
|
139
|
+
llm = load_model(llm_name)
|
139
140
|
if json_file_name is not None:
|
140
141
|
with open(json_file_name, "r") as f:
|
141
142
|
json_obj = json.load(f)
|
@@ -171,8 +172,8 @@ def metric(
|
|
171
172
|
out_file=out_file,
|
172
173
|
lang=language,
|
173
174
|
llm=llm,
|
174
|
-
token_limit=token_limit,
|
175
|
-
model_cost=
|
175
|
+
token_limit=llm.token_limit,
|
176
|
+
model_cost=COST_PER_1K_TOKENS[llm.model_id],
|
176
177
|
)
|
177
178
|
else:
|
178
179
|
raise ValueError(
|
@@ -187,8 +188,8 @@ def metric(
|
|
187
188
|
progress,
|
188
189
|
language,
|
189
190
|
llm,
|
190
|
-
token_limit,
|
191
|
-
|
191
|
+
llm.token_limit,
|
192
|
+
COST_PER_1K_TOKENS[llm.model_id],
|
192
193
|
*args,
|
193
194
|
**kwargs,
|
194
195
|
)
|
@@ -199,8 +200,8 @@ def metric(
|
|
199
200
|
progress,
|
200
201
|
language,
|
201
202
|
llm,
|
202
|
-
token_limit,
|
203
|
-
|
203
|
+
llm.token_limit,
|
204
|
+
COST_PER_1K_TOKENS[llm.model_id],
|
204
205
|
*args,
|
205
206
|
**kwargs,
|
206
207
|
)
|
@@ -296,7 +297,7 @@ def metric(
|
|
296
297
|
*args,
|
297
298
|
**kwargs,
|
298
299
|
):
|
299
|
-
llm
|
300
|
+
llm = load_model(llm_name)
|
300
301
|
if json_file_name is not None:
|
301
302
|
with open(json_file_name, "r") as f:
|
302
303
|
json_obj = json.load(f)
|
@@ -328,8 +329,8 @@ def metric(
|
|
328
329
|
out_file=out_file,
|
329
330
|
lang=language,
|
330
331
|
llm=llm,
|
331
|
-
token_limit=token_limit,
|
332
|
-
model_cost=
|
332
|
+
token_limit=llm.token_limit,
|
333
|
+
model_cost=COST_PER_1K_TOKENS[llm.model_id],
|
333
334
|
)
|
334
335
|
else:
|
335
336
|
raise ValueError(
|
@@ -344,8 +345,8 @@ def metric(
|
|
344
345
|
progress,
|
345
346
|
language,
|
346
347
|
llm,
|
347
|
-
token_limit,
|
348
|
-
|
348
|
+
llm.token_limit,
|
349
|
+
COST_PER_1K_TOKENS[llm.model_id],
|
349
350
|
*args,
|
350
351
|
**kwargs,
|
351
352
|
)
|
@@ -356,8 +357,8 @@ def metric(
|
|
356
357
|
progress,
|
357
358
|
language,
|
358
359
|
llm,
|
359
|
-
token_limit,
|
360
|
-
|
360
|
+
llm.token_limit,
|
361
|
+
COST_PER_1K_TOKENS[llm.model_id],
|
361
362
|
*args,
|
362
363
|
**kwargs,
|
363
364
|
)
|
janus/parsers/uml.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import re
|
2
2
|
import subprocess # nosec
|
3
3
|
from pathlib import Path
|
4
|
-
from
|
4
|
+
from tempfile import NamedTemporaryFile
|
5
5
|
|
6
6
|
from langchain_core.exceptions import OutputParserException
|
7
7
|
from langchain_core.messages import BaseMessage
|
@@ -13,39 +13,76 @@ log = create_logger(__name__)
|
|
13
13
|
|
14
14
|
|
15
15
|
class UMLSyntaxParser(CodeParser):
    """Code parser that validates its output by running it through PlantUML."""

    def _check_plantuml(self, text: str) -> None:
        """Run PlantUML on ``text`` and raise if it reports errors or warnings.

        Args:
            text: The PlantUML source to check.

        Raises:
            Exception: If the subprocess cannot be started
                (``FileNotFoundError`` from ``subprocess.run``).
            OutputParserException: If "Error" or "Warning" appears in the
                tool's stdout or stderr.
        """
        # Leading newlines can break the parser, remove them
        text = text.replace("\\n", "\n").strip()

        # Write the text to a temporary file. The context manager guarantees
        # the handle is closed (and the file deleted) even if the check raises.
        with NamedTemporaryFile(mode="w") as file:
            file.write(text)
            # Make sure the content is on disk before the subprocess reads it.
            file.flush()

            try:
                plantuml_path = Path.home().expanduser() / ".janus/lib/plantuml.jar"
                # NOTE: running subprocess with shell=False, added nosec to
                # label that we know risk exists
                res = subprocess.run(
                    ["java", "-jar", plantuml_path, file.name],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )  # nosec
            except FileNotFoundError as err:
                err_txt = (
                    "Plant UML executable not found. Either choose a different parser"
                    " or install with `bash scripts/install_plantuml.sh`. Java and"
                    " graphviz are dependencies for the tool, they must also be installed."
                )
                log.error(err_txt)
                raise Exception(err_txt) from err

        stdout = res.stdout.decode("utf-8")
        stderr = res.stderr.decode("utf-8")

        # Check for bad outputs, raise OutputParserExceptions if so
        if "Error" in stderr or "Error" in stdout:
            err_txt = "Recieved UML parsing error(s)."

            line_nos = self._get_error_lines(stderr) + self._get_error_lines(stdout)
            lines = text.split("\n")
            for i in line_nos:
                # Reported line numbers are 1-based: the offending line is
                # lines[i - 1]. Show up to two lines of context on each side.
                i0 = max(0, i - 3)
                # Exclusive upper bound, so clamp to len(lines); clamping to
                # len(lines) - 1 would hide the last line of the text.
                i1 = min(len(lines), i + 2)
                err_lines = [
                    f"> {lines[j]}" if j == i - 1 else f" {lines[j]}"
                    for j in range(i0, i1)
                ]
                # Mark truncation at either end of the excerpt.
                if i0:
                    err_lines.insert(0, " ...")
                if i1 < len(lines):
                    err_lines.append(" ...")

                err_txt += f"\nError located at line {i} must be fixed:\n"
                err_txt += "\n".join(err_lines)
            log.warning(err_txt)
            raise OutputParserException(err_txt)

        if "Warning" in stdout or "Warning" in stderr:
            err_txt = "Recieved UML parsing warning (often due to missing PLANTUML)."
            if stderr:
                err_txt += f"\nSTDERR:\n```\n{stderr.strip()}\n```\n"
            if stdout:
                err_txt += f"\nSTDOUT:\n```\n{stdout.strip()}\n```\n"

            log.warning(err_txt)
            raise OutputParserException(err_txt)

    def _get_error_lines(self, s: str) -> list[int]:
        """Return the line numbers from every "Error line N in file:" match in ``s``."""
        return [int(x.group(1)) for x in re.finditer(r"Error line (\d+) in file:", s)]

    def _get_warns(self, s: str) -> list[str]:
        """Return every "Warning: ...\\n" match found in ``s``."""
        return [x.group() for x in re.finditer(r"Warning: (.*)\n", s)]

    def parse(self, text: str | BaseMessage) -> str:
        """Parse ``text`` with the parent parser, then validate it with PlantUML.

        Returns:
            The text produced by the parent parser, unchanged by validation.
        """
        text = super().parse(text)
        self._check_plantuml(text)
        return text
|
janus/refiners/refiner.py
CHANGED
@@ -1,73 +1,115 @@
|
|
1
|
-
from
|
1
|
+
from typing import Any
|
2
2
|
|
3
|
-
from
|
3
|
+
from langchain.output_parsers import RetryWithErrorOutputParser
|
4
|
+
from langchain_core.output_parsers import StrOutputParser
|
5
|
+
from langchain_core.prompt_values import PromptValue
|
6
|
+
from langchain_core.runnables import RunnableSerializable
|
4
7
|
|
8
|
+
from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
|
9
|
+
from janus.parsers.parser import JanusParser
|
10
|
+
from janus.utils.logger import create_logger
|
5
11
|
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
**
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
original_output: origial output of llm
|
20
|
-
errors: list of errors detected by parser
|
21
|
-
|
22
|
-
Returns:
|
23
|
-
Tuple of new prompt and prompt arguments
|
24
|
-
"""
|
12
|
+
log = create_logger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class JanusRefiner(JanusParser):
    """Pass-through refiner: hands the completion straight to the wrapped parser."""

    # Parser that turns the completion into the final output.
    parser: JanusParser

    def parse_runnable(self, input: dict[str, Any]) -> Any:
        """Unpack ``input`` as keyword arguments for ``parse_completion``."""
        parsed = self.parse_completion(**input)
        return parsed

    def parse_completion(self, completion: str, **kwargs) -> Any:
        """Parse ``completion`` with the wrapped parser; extra kwargs are ignored."""
        wrapped = self.parser
        return wrapped.parse(completion)

    def parse(self, text: str) -> str:
        """Not supported on this class; always raises ``NotImplementedError``."""
        raise NotImplementedError
|
26
26
|
|
27
27
|
|
28
|
-
class
|
28
|
+
class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
    """Refiner that retries through an LLM chain when the wrapped parser fails."""

    def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
        """Build the retry chain and initialise the retrying parser.

        Args:
            llm: Model used to produce the corrected completion.
            parser: Parser whose exceptions trigger a retry.
            max_retries: Passed through to ``RetryWithErrorOutputParser``.
        """
        # Use the prompt engine attached to the model instead of indexing
        # MODEL_PROMPT_ENGINES with llm.model_id: that attribute holds the
        # long model id, which is not a key of the table.
        retry_prompt = llm.prompt_engine(
            source_language="text",
            prompt_template="refinement/fix_exceptions",
        ).prompt
        chain = retry_prompt | llm | StrOutputParser()
        RetryWithErrorOutputParser.__init__(
            self, parser=parser, retry_chain=chain, max_retries=max_retries
        )

    def parse_completion(
        self, completion: str, prompt_value: PromptValue, **kwargs
    ) -> Any:
        """Parse ``completion`` via ``parse_with_prompt``, passing ``prompt_value``."""
        return self.parse_with_prompt(completion, prompt_value=prompt_value)
|
43
|
+
|
44
|
+
|
45
|
+
class ReflectionRefiner(JanusRefiner):
    """Refiner that has the LLM critique and revise a completion before parsing."""

    # Maximum number of reflect/revise rounds.
    max_retries: int
    # Chain producing a critique of the completion ("LGTM" means accept).
    reflection_chain: RunnableSerializable
    # Chain producing a revised completion from the critique.
    revision_chain: RunnableSerializable

    def __init__(
        self,
        llm: JanusModel,
        parser: JanusParser,
        max_retries: int,
        prompt_template_name: str = "refinement/reflection",
    ):
        """Build the reflection and revision chains.

        Args:
            llm: Model used for both reflection and revision.
            parser: Parser applied to the accepted completion.
            max_retries: Maximum number of reflect/revise rounds.
            prompt_template_name: Template used for the reflection prompt.
        """
        # Use the prompt engine attached to the model instead of indexing
        # MODEL_PROMPT_ENGINES with llm.model_id: that attribute holds the
        # long model id, which is not a key of the table.
        reflection_prompt = llm.prompt_engine(
            source_language="text",
            prompt_template=prompt_template_name,
        ).prompt
        revision_prompt = llm.prompt_engine(
            source_language="text",
            prompt_template="refinement/revision",
        ).prompt

        reflection_chain = reflection_prompt | llm | StrOutputParser()
        revision_chain = revision_prompt | llm | StrOutputParser()
        super().__init__(
            reflection_chain=reflection_chain,
            revision_chain=revision_chain,
            parser=parser,
            max_retries=max_retries,
        )

    def parse_completion(
        self, completion: str, prompt_value: PromptValue, **kwargs
    ) -> Any:
        """Reflect on and revise ``completion`` up to ``max_retries`` times, then parse.

        Returns the parsed completion as soon as the reflection is exactly
        "LGTM" (ignoring surrounding whitespace); otherwise parses whatever
        the last revision produced.
        """
        # The prompt text does not change between rounds; render it once.
        prompt = prompt_value.to_string()
        for retry_number in range(self.max_retries):
            reflection = self.reflection_chain.invoke(
                dict(
                    prompt=prompt,
                    completion=completion,
                )
            )
            if reflection.strip() == "LGTM":
                return self.parser.parse(completion)
            # Log the untouched completion only on the first round; later
            # rounds have already logged it as the previous revision.
            if not retry_number:
                log.info(f"Completion:\n{completion}")
            log.info(f"Reflection:\n{reflection}")
            completion = self.revision_chain.invoke(
                dict(
                    prompt=prompt,
                    completion=completion,
                    reflection=reflection,
                )
            )
            log.info(f"Revision:\n{completion}")

        return self.parser.parse(completion)
|
100
|
+
|
101
|
+
|
102
|
+
class HallucinationRefiner(ReflectionRefiner):
    """Reflection refiner using the "refinement/hallucination" prompt template."""

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``ReflectionRefiner`` with the template fixed."""
        super().__init__(prompt_template_name="refinement/hallucination", **kwargs)
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
108
|
+
|
109
|
+
|
110
|
+
# Maps refiner names to the classes that implement them.
REFINERS = {
    "none": JanusRefiner,
    "parser": FixParserExceptions,
    "reflection": ReflectionRefiner,
    "hallucination": HallucinationRefiner,
}
|
@@ -0,0 +1,42 @@
|
|
1
|
+
from langchain_core.retrievers import BaseRetriever
|
2
|
+
from langchain_core.runnables import Runnable, RunnableConfig
|
3
|
+
|
4
|
+
from janus.language.block import CodeBlock
|
5
|
+
|
6
|
+
|
7
|
+
class JanusRetriever(Runnable):
    """Base retriever: adds a ``context`` entry to the keyword arguments it is given."""

    def __init__(self) -> None:
        super().__init__()

    def invoke(
        self, input: CodeBlock, config: RunnableConfig | None = None, **kwargs
    ) -> dict:
        """Return ``kwargs`` with ``context`` set from ``get_context(input)``.

        ``config`` is accepted but not used here.
        """
        kwargs["context"] = self.get_context(input)
        return kwargs

    def get_context(self, code_block: CodeBlock) -> str:
        """Return the context for ``code_block``; the base class returns an empty string."""
        return ""
|
19
|
+
|
20
|
+
|
21
|
+
class ActiveUsingsRetriever(JanusRetriever):
    """Retriever whose context is built from the block's ``context_tags``."""

    def get_context(self, code_block: CodeBlock) -> str:
        """Render each context tag as "tag: value", one per line, after a fixed lead-in."""
        tag_lines = [
            f"{tag}: {value}" for tag, value in code_block.context_tags.items()
        ]
        joined = "\n".join(tag_lines)
        return f"You may use the following additional context: {joined}"
|
28
|
+
|
29
|
+
|
30
|
+
class TextSearchRetriever(JanusRetriever):
    """Retriever whose context comes from querying a ``BaseRetriever`` with the block text."""

    # Underlying retriever that is queried with the code block's text.
    retriever: BaseRetriever

    def __init__(self, retriever: BaseRetriever):
        super().__init__()
        self.retriever = retriever

    def get_context(self, code_block: CodeBlock) -> str:
        """Return the retrieved documents' contents, or "" when the block has no text."""
        query = code_block.text
        if query is None:
            return ""
        found = self.retriever.invoke(query)
        joined = "\n\n".join(doc.page_content for doc in found)
        return f"You may use the following additional context: {joined}"
|