janus-llm 3.5.2__py3-none-any.whl → 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +1 -1
- janus/cli.py +90 -42
- janus/converter/converter.py +111 -142
- janus/converter/diagram.py +21 -109
- janus/converter/translate.py +1 -1
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +16 -11
- janus/language/binary/_tests/test_binary.py +1 -1
- janus/language/binary/binary.py +2 -2
- janus/language/mumps/_tests/test_mumps.py +1 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/naive/simple_ast.py +3 -2
- janus/language/splitter.py +7 -4
- janus/language/treesitter/_tests/test_treesitter.py +1 -1
- janus/language/treesitter/treesitter.py +2 -2
- janus/llm/model_callbacks.py +13 -0
- janus/llm/models_info.py +118 -71
- janus/metrics/metric.py +15 -14
- janus/parsers/uml.py +60 -23
- janus/refiners/refiner.py +106 -64
- janus/retrievers/retriever.py +42 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/METADATA +1 -1
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/RECORD +26 -26
- janus/parsers/refiner_parser.py +0 -46
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/WHEEL +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/entry_points.txt +0 -0
janus/llm/models_info.py
CHANGED
@@ -1,14 +1,15 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
|
+
import time
|
3
4
|
from pathlib import Path
|
4
|
-
from typing import
|
5
|
+
from typing import Protocol, TypeVar
|
5
6
|
|
6
7
|
from dotenv import load_dotenv
|
7
8
|
from langchain_community.llms import HuggingFaceTextGenInference
|
8
|
-
from langchain_core.
|
9
|
+
from langchain_core.runnables import Runnable
|
9
10
|
from langchain_openai import ChatOpenAI
|
10
11
|
|
11
|
-
from janus.llm.model_callbacks import COST_PER_1K_TOKENS
|
12
|
+
from janus.llm.model_callbacks import COST_PER_1K_TOKENS, openai_model_reroutes
|
12
13
|
from janus.prompts.prompt import (
|
13
14
|
ChatGptPromptEngine,
|
14
15
|
ClaudePromptEngine,
|
@@ -43,17 +44,34 @@ except ImportError:
|
|
43
44
|
)
|
44
45
|
|
45
46
|
|
47
|
+
ModelType = TypeVar(
|
48
|
+
"ModelType",
|
49
|
+
ChatOpenAI,
|
50
|
+
HuggingFaceTextGenInference,
|
51
|
+
Bedrock,
|
52
|
+
BedrockChat,
|
53
|
+
HuggingFacePipeline,
|
54
|
+
)
|
55
|
+
|
56
|
+
|
57
|
+
class JanusModelProtocol(Protocol):
|
58
|
+
model_id: str
|
59
|
+
model_type_name: str
|
60
|
+
token_limit: int
|
61
|
+
input_token_cost: float
|
62
|
+
output_token_cost: float
|
63
|
+
prompt_engine: type[PromptEngine]
|
64
|
+
|
65
|
+
def get_num_tokens(self, text: str) -> int:
|
66
|
+
...
|
67
|
+
|
68
|
+
|
69
|
+
class JanusModel(Runnable, JanusModelProtocol):
|
70
|
+
...
|
71
|
+
|
72
|
+
|
46
73
|
load_dotenv()
|
47
74
|
|
48
|
-
openai_model_reroutes = {
|
49
|
-
"gpt-4o": "gpt-4o-2024-05-13",
|
50
|
-
"gpt-4o-mini": "gpt-4o-mini",
|
51
|
-
"gpt-4": "gpt-4-0613",
|
52
|
-
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
|
53
|
-
"gpt-4-turbo-preview": "gpt-4-0125-preview",
|
54
|
-
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
|
55
|
-
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
|
56
|
-
}
|
57
75
|
|
58
76
|
openai_models = [
|
59
77
|
"gpt-4o",
|
@@ -104,24 +122,15 @@ bedrock_models = [
|
|
104
122
|
]
|
105
123
|
all_models = [*openai_models, *bedrock_models]
|
106
124
|
|
107
|
-
MODEL_TYPE_CONSTRUCTORS: dict[str,
|
125
|
+
MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
|
108
126
|
"OpenAI": ChatOpenAI,
|
109
127
|
"HuggingFace": HuggingFaceTextGenInference,
|
128
|
+
"Bedrock": Bedrock,
|
129
|
+
"BedrockChat": BedrockChat,
|
130
|
+
"HuggingFaceLocal": HuggingFacePipeline,
|
110
131
|
}
|
111
132
|
|
112
|
-
|
113
|
-
MODEL_TYPE_CONSTRUCTORS.update(
|
114
|
-
{
|
115
|
-
"HuggingFaceLocal": HuggingFacePipeline.from_model_id,
|
116
|
-
"Bedrock": Bedrock,
|
117
|
-
"BedrockChat": BedrockChat,
|
118
|
-
}
|
119
|
-
)
|
120
|
-
except NameError:
|
121
|
-
pass
|
122
|
-
|
123
|
-
|
124
|
-
MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
133
|
+
MODEL_PROMPT_ENGINES: dict[str, type[PromptEngine]] = {
|
125
134
|
**{m: ChatGptPromptEngine for m in openai_models},
|
126
135
|
**{m: ClaudePromptEngine for m in claude_models},
|
127
136
|
**{m: Llama2PromptEngine for m in llama2_models},
|
@@ -131,11 +140,6 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
|
131
140
|
**{m: MistralPromptEngine for m in mistral_models},
|
132
141
|
}
|
133
142
|
|
134
|
-
_open_ai_defaults: dict[str, str] = {
|
135
|
-
"openai_api_key": os.getenv("OPENAI_API_KEY"),
|
136
|
-
"openai_organization": os.getenv("OPENAI_ORG_ID"),
|
137
|
-
}
|
138
|
-
|
139
143
|
MODEL_ID_TO_LONG_ID = {
|
140
144
|
**{m: mr for m, mr in openai_model_reroutes.items()},
|
141
145
|
"bedrock-claude-v2": "anthropic.claude-v2",
|
@@ -167,7 +171,7 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
|
|
167
171
|
|
168
172
|
MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
|
169
173
|
|
170
|
-
MODEL_TYPES: dict[str,
|
174
|
+
MODEL_TYPES: dict[str, str] = {
|
171
175
|
**{m: "OpenAI" for m in openai_models},
|
172
176
|
**{m: "BedrockChat" for m in bedrock_models},
|
173
177
|
}
|
@@ -210,47 +214,90 @@ def get_available_model_names() -> list[str]:
|
|
210
214
|
return avaialable_models
|
211
215
|
|
212
216
|
|
213
|
-
def load_model(
|
214
|
-
user_model_name: str,
|
215
|
-
) -> tuple[BaseLanguageModel, str, int, dict[str, float]]:
|
217
|
+
def load_model(model_id) -> JanusModel:
|
216
218
|
if not MODEL_CONFIG_DIR.exists():
|
217
219
|
MODEL_CONFIG_DIR.mkdir(parents=True)
|
218
|
-
model_config_file = MODEL_CONFIG_DIR / f"{
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
f"default models for {user_model_name}."
|
223
|
-
)
|
224
|
-
model_id = user_model_name
|
225
|
-
if user_model_name not in DEFAULT_MODELS:
|
226
|
-
message = (
|
227
|
-
f"Model {user_model_name} not found in default models. Make sure to run "
|
228
|
-
"`janus llm add` first."
|
229
|
-
)
|
230
|
-
log.error(message)
|
231
|
-
raise ValueError(message)
|
232
|
-
model_config = {
|
233
|
-
"model_type": MODEL_TYPES[model_id],
|
234
|
-
"model_id": model_id,
|
235
|
-
"model_args": MODEL_DEFAULT_ARGUMENTS[model_id],
|
236
|
-
"token_limit": TOKEN_LIMITS.get(MODEL_ID_TO_LONG_ID[model_id], 4096),
|
237
|
-
"model_cost": COST_PER_1K_TOKENS.get(
|
238
|
-
MODEL_ID_TO_LONG_ID[model_id], {"input": 0, "output": 0}
|
239
|
-
),
|
240
|
-
}
|
241
|
-
with open(model_config_file, "w") as f:
|
242
|
-
json.dump(model_config, f)
|
243
|
-
else:
|
220
|
+
model_config_file = MODEL_CONFIG_DIR / f"{model_id}.json"
|
221
|
+
|
222
|
+
if model_config_file.exists():
|
223
|
+
log.info(f"Loading {model_id} from {model_config_file}.")
|
244
224
|
with open(model_config_file, "r") as f:
|
245
225
|
model_config = json.load(f)
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
226
|
+
model_type_name = model_config["model_type"]
|
227
|
+
model_id = model_config["model_id"]
|
228
|
+
model_args = model_config["model_args"]
|
229
|
+
token_limit = model_config["token_limit"]
|
230
|
+
input_token_cost = model_config["model_cost"]["input"]
|
231
|
+
output_token_cost = model_config["model_cost"]["output"]
|
232
|
+
|
233
|
+
elif model_id in DEFAULT_MODELS:
|
234
|
+
model_id = model_id
|
235
|
+
model_long_id = MODEL_ID_TO_LONG_ID[model_id]
|
236
|
+
model_type_name = MODEL_TYPES[model_id]
|
237
|
+
model_args = MODEL_DEFAULT_ARGUMENTS[model_id]
|
238
|
+
|
239
|
+
token_limit = 0
|
240
|
+
input_token_cost = 0.0
|
241
|
+
output_token_cost = 0.0
|
242
|
+
if model_long_id in TOKEN_LIMITS:
|
243
|
+
token_limit = TOKEN_LIMITS[model_long_id]
|
244
|
+
if model_long_id in COST_PER_1K_TOKENS:
|
245
|
+
token_limits = COST_PER_1K_TOKENS[model_long_id]
|
246
|
+
input_token_cost = token_limits["input"]
|
247
|
+
output_token_cost = token_limits["output"]
|
248
|
+
|
249
|
+
else:
|
250
|
+
model_list = "\n\t".join(DEFAULT_MODELS)
|
251
|
+
message = (
|
252
|
+
f"Model {model_id} not found in user-defined model directory "
|
253
|
+
f"({MODEL_CONFIG_DIR}), and is not a default model. Valid default "
|
254
|
+
f"models:\n\t{model_list}\n"
|
255
|
+
f"To use a custom model, first run `janus llm add`."
|
256
|
+
)
|
257
|
+
log.error(message)
|
258
|
+
raise ValueError(message)
|
259
|
+
|
260
|
+
if model_type_name == "HuggingFaceLocal":
|
261
|
+
model = HuggingFacePipeline.from_model_id(
|
262
|
+
model_id=model_id,
|
263
|
+
task="text-generation",
|
264
|
+
model_kwargs=model_args,
|
265
|
+
)
|
266
|
+
model_args.update(pipeline=model.pipeline)
|
267
|
+
|
268
|
+
elif model_type_name == "OpenAI":
|
269
|
+
model_args.update(
|
270
|
+
openai_api_key=str(os.getenv("OPENAI_API_KEY")),
|
271
|
+
openai_organization=str(os.getenv("OPENAI_ORG_ID")),
|
272
|
+
)
|
273
|
+
log.warning("Do NOT use this model in sensitive environments!")
|
274
|
+
log.warning("If you would like to cancel, please press Ctrl+C.")
|
275
|
+
log.warning("Waiting 10 seconds...")
|
276
|
+
# Give enough time for the user to read the warnings and cancel
|
277
|
+
time.sleep(10)
|
278
|
+
|
279
|
+
model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
|
280
|
+
prompt_engine = MODEL_PROMPT_ENGINES[model_id]
|
281
|
+
|
282
|
+
class JanusModel(model_type):
|
283
|
+
model_id: str
|
284
|
+
short_model_id: str
|
285
|
+
model_type_name: str
|
286
|
+
token_limit: int
|
287
|
+
input_token_cost: float
|
288
|
+
output_token_cost: float
|
289
|
+
prompt_engine: type[PromptEngine]
|
290
|
+
|
291
|
+
model_args.update(
|
292
|
+
model_id=MODEL_ID_TO_LONG_ID[model_id],
|
293
|
+
short_model_id=model_id,
|
294
|
+
)
|
295
|
+
|
296
|
+
return JanusModel(
|
297
|
+
model_type_name=model_type_name,
|
298
|
+
token_limit=token_limit,
|
299
|
+
input_token_cost=input_token_cost,
|
300
|
+
output_token_cost=output_token_cost,
|
301
|
+
prompt_engine=prompt_engine,
|
302
|
+
**model_args,
|
256
303
|
)
|
janus/metrics/metric.py
CHANGED
@@ -8,6 +8,7 @@ import typer
|
|
8
8
|
from typing_extensions import Annotated
|
9
9
|
|
10
10
|
from janus.llm import load_model
|
11
|
+
from janus.llm.model_callbacks import COST_PER_1K_TOKENS
|
11
12
|
from janus.metrics.cli import evaluate
|
12
13
|
from janus.metrics.file_pairing import FILE_PAIRING_METHODS
|
13
14
|
from janus.metrics.splitting import SPLITTING_METHODS
|
@@ -135,7 +136,7 @@ def metric(
|
|
135
136
|
**kwargs,
|
136
137
|
):
|
137
138
|
out = []
|
138
|
-
llm
|
139
|
+
llm = load_model(llm_name)
|
139
140
|
if json_file_name is not None:
|
140
141
|
with open(json_file_name, "r") as f:
|
141
142
|
json_obj = json.load(f)
|
@@ -171,8 +172,8 @@ def metric(
|
|
171
172
|
out_file=out_file,
|
172
173
|
lang=language,
|
173
174
|
llm=llm,
|
174
|
-
token_limit=token_limit,
|
175
|
-
model_cost=
|
175
|
+
token_limit=llm.token_limit,
|
176
|
+
model_cost=COST_PER_1K_TOKENS[llm.model_id],
|
176
177
|
)
|
177
178
|
else:
|
178
179
|
raise ValueError(
|
@@ -187,8 +188,8 @@ def metric(
|
|
187
188
|
progress,
|
188
189
|
language,
|
189
190
|
llm,
|
190
|
-
token_limit,
|
191
|
-
|
191
|
+
llm.token_limit,
|
192
|
+
COST_PER_1K_TOKENS[llm.model_id],
|
192
193
|
*args,
|
193
194
|
**kwargs,
|
194
195
|
)
|
@@ -199,8 +200,8 @@ def metric(
|
|
199
200
|
progress,
|
200
201
|
language,
|
201
202
|
llm,
|
202
|
-
token_limit,
|
203
|
-
|
203
|
+
llm.token_limit,
|
204
|
+
COST_PER_1K_TOKENS[llm.model_id],
|
204
205
|
*args,
|
205
206
|
**kwargs,
|
206
207
|
)
|
@@ -296,7 +297,7 @@ def metric(
|
|
296
297
|
*args,
|
297
298
|
**kwargs,
|
298
299
|
):
|
299
|
-
llm
|
300
|
+
llm = load_model(llm_name)
|
300
301
|
if json_file_name is not None:
|
301
302
|
with open(json_file_name, "r") as f:
|
302
303
|
json_obj = json.load(f)
|
@@ -328,8 +329,8 @@ def metric(
|
|
328
329
|
out_file=out_file,
|
329
330
|
lang=language,
|
330
331
|
llm=llm,
|
331
|
-
token_limit=token_limit,
|
332
|
-
model_cost=
|
332
|
+
token_limit=llm.token_limit,
|
333
|
+
model_cost=COST_PER_1K_TOKENS[llm.model_id],
|
333
334
|
)
|
334
335
|
else:
|
335
336
|
raise ValueError(
|
@@ -344,8 +345,8 @@ def metric(
|
|
344
345
|
progress,
|
345
346
|
language,
|
346
347
|
llm,
|
347
|
-
token_limit,
|
348
|
-
|
348
|
+
llm.token_limit,
|
349
|
+
COST_PER_1K_TOKENS[llm.model_id],
|
349
350
|
*args,
|
350
351
|
**kwargs,
|
351
352
|
)
|
@@ -356,8 +357,8 @@ def metric(
|
|
356
357
|
progress,
|
357
358
|
language,
|
358
359
|
llm,
|
359
|
-
token_limit,
|
360
|
-
|
360
|
+
llm.token_limit,
|
361
|
+
COST_PER_1K_TOKENS[llm.model_id],
|
361
362
|
*args,
|
362
363
|
**kwargs,
|
363
364
|
)
|
janus/parsers/uml.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import re
|
2
2
|
import subprocess # nosec
|
3
3
|
from pathlib import Path
|
4
|
-
from
|
4
|
+
from tempfile import NamedTemporaryFile
|
5
5
|
|
6
6
|
from langchain_core.exceptions import OutputParserException
|
7
7
|
from langchain_core.messages import BaseMessage
|
@@ -13,39 +13,76 @@ log = create_logger(__name__)
|
|
13
13
|
|
14
14
|
|
15
15
|
class UMLSyntaxParser(CodeParser):
|
16
|
-
def
|
17
|
-
#
|
18
|
-
|
16
|
+
def _check_plantuml(self, text: str) -> None:
|
17
|
+
# Leading newlines can break the parser, remove them
|
18
|
+
text = text.replace("\\n", "\n").strip()
|
19
|
+
|
20
|
+
# Write the text to a temporary file (automatically deleted)
|
21
|
+
file = NamedTemporaryFile()
|
22
|
+
fname = file.name
|
23
|
+
with open(fname, "w") as fin:
|
24
|
+
fin.write(text)
|
25
|
+
|
19
26
|
try:
|
20
27
|
plantuml_path = Path.home().expanduser() / ".janus/lib/plantuml.jar"
|
28
|
+
# NOTE: running subprocess with shell=False, added nosec to
|
29
|
+
# label that we know risk exists
|
21
30
|
res = subprocess.run(
|
22
|
-
["java", "-jar", plantuml_path,
|
31
|
+
["java", "-jar", plantuml_path, fname],
|
23
32
|
stdout=subprocess.PIPE,
|
24
33
|
stderr=subprocess.PIPE,
|
25
34
|
) # nosec
|
26
35
|
stdout = res.stdout.decode("utf-8")
|
27
36
|
stderr = res.stderr.decode("utf-8")
|
28
37
|
except FileNotFoundError:
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
38
|
+
err_txt = (
|
39
|
+
"Plant UML executable not found. Either choose a different parser"
|
40
|
+
" or install with `bash scripts/install_plantuml.sh`. Java and"
|
41
|
+
" graphviz are dependencies for the tool, they must also be installed."
|
42
|
+
)
|
43
|
+
log.error(err_txt)
|
44
|
+
raise Exception(err_txt)
|
45
|
+
|
46
|
+
# Check for bad outputs, raise OutputParserExceptions if so
|
47
|
+
if "Error" in stderr or "Error" in stdout:
|
48
|
+
err_txt = "Recieved UML parsing error(s)."
|
49
|
+
|
50
|
+
line_nos = self._get_error_lines(stderr) + self._get_error_lines(stdout)
|
51
|
+
lines = text.split("\n")
|
52
|
+
for i in line_nos:
|
53
|
+
i0 = max(0, i - 3)
|
54
|
+
i1 = min(len(lines) - 1, i + 2)
|
55
|
+
err_lines = [
|
56
|
+
f"> {lines[j]}" if j == i - 1 else f" {lines[j]}"
|
57
|
+
for j in range(i0, i1)
|
58
|
+
]
|
59
|
+
if i0:
|
60
|
+
err_lines.insert(0, " ...")
|
61
|
+
if i1 < (len(lines) - 1):
|
62
|
+
err_lines.append(" ...")
|
33
63
|
|
34
|
-
|
35
|
-
|
64
|
+
err_txt += f"\nError located at line {i} must be fixed:\n"
|
65
|
+
err_txt += "\n".join(err_lines)
|
66
|
+
log.warning(err_txt)
|
67
|
+
raise OutputParserException(err_txt)
|
68
|
+
|
69
|
+
if "Warning" in stdout or "Warning" in stderr:
|
70
|
+
err_txt = "Recieved UML parsing warning (often due to missing PLANTUML)."
|
71
|
+
if stderr:
|
72
|
+
err_txt += f"\nSTDERR:\n```\n{stderr.strip()}\n```\n"
|
73
|
+
if stdout:
|
74
|
+
err_txt += f"\nSTDOUT:\n```\n{stdout.strip()}\n```\n"
|
75
|
+
|
76
|
+
log.warning(err_txt)
|
77
|
+
raise OutputParserException(err_txt)
|
78
|
+
|
79
|
+
def _get_error_lines(self, s: str) -> list[int]:
|
80
|
+
return [int(x.group(1)) for x in re.finditer(r"Error line (\d+) in file:", s)]
|
81
|
+
|
82
|
+
def _get_warns(self, s: str) -> list[str]:
|
83
|
+
return [x.group() for x in re.finditer(r"Warning: (.*)\n", s)]
|
36
84
|
|
37
85
|
def parse(self, text: str | BaseMessage) -> str:
|
38
86
|
text = super().parse(text)
|
39
|
-
|
40
|
-
if not janus_path.exists():
|
41
|
-
janus_path.mkdir()
|
42
|
-
temp_file_path = janus_path / "tmp.txt"
|
43
|
-
with open(temp_file_path, "w") as f:
|
44
|
-
f.write(text)
|
45
|
-
uml_std_out, uml_std_err = self._get_uml_output(temp_file_path)
|
46
|
-
uml_errs = self._get_errs(uml_std_out) + self._get_errs(uml_std_err)
|
47
|
-
if len(uml_errs) > 0:
|
48
|
-
raise OutputParserException(
|
49
|
-
"Error: Received UML Errors:\n" + "\n".join(uml_errs)
|
50
|
-
)
|
87
|
+
self._check_plantuml(text)
|
51
88
|
return text
|
janus/refiners/refiner.py
CHANGED
@@ -1,73 +1,115 @@
|
|
1
|
-
from
|
1
|
+
from typing import Any
|
2
2
|
|
3
|
-
from
|
3
|
+
from langchain.output_parsers import RetryWithErrorOutputParser
|
4
|
+
from langchain_core.output_parsers import StrOutputParser
|
5
|
+
from langchain_core.prompt_values import PromptValue
|
6
|
+
from langchain_core.runnables import RunnableSerializable
|
4
7
|
|
8
|
+
from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
|
9
|
+
from janus.parsers.parser import JanusParser
|
10
|
+
from janus.utils.logger import create_logger
|
5
11
|
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
**
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
original_output: origial output of llm
|
20
|
-
errors: list of errors detected by parser
|
21
|
-
|
22
|
-
Returns:
|
23
|
-
Tuple of new prompt and prompt arguments
|
24
|
-
"""
|
12
|
+
log = create_logger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class JanusRefiner(JanusParser):
|
16
|
+
parser: JanusParser
|
17
|
+
|
18
|
+
def parse_runnable(self, input: dict[str, Any]) -> Any:
|
19
|
+
return self.parse_completion(**input)
|
20
|
+
|
21
|
+
def parse_completion(self, completion: str, **kwargs) -> Any:
|
22
|
+
return self.parser.parse(completion)
|
23
|
+
|
24
|
+
def parse(self, text: str) -> str:
|
25
25
|
raise NotImplementedError
|
26
26
|
|
27
27
|
|
28
|
-
class
|
28
|
+
class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
|
29
|
+
def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
|
30
|
+
retry_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
|
31
|
+
source_language="text",
|
32
|
+
prompt_template="refinement/fix_exceptions",
|
33
|
+
).prompt
|
34
|
+
chain = retry_prompt | llm | StrOutputParser()
|
35
|
+
RetryWithErrorOutputParser.__init__(
|
36
|
+
self, parser=parser, retry_chain=chain, max_retries=max_retries
|
37
|
+
)
|
38
|
+
|
39
|
+
def parse_completion(
|
40
|
+
self, completion: str, prompt_value: PromptValue, **kwargs
|
41
|
+
) -> Any:
|
42
|
+
return self.parse_with_prompt(completion, prompt_value=prompt_value)
|
43
|
+
|
44
|
+
|
45
|
+
class ReflectionRefiner(JanusRefiner):
|
46
|
+
max_retries: int
|
47
|
+
reflection_chain: RunnableSerializable
|
48
|
+
revision_chain: RunnableSerializable
|
49
|
+
|
29
50
|
def __init__(
|
30
51
|
self,
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
52
|
+
llm: JanusModel,
|
53
|
+
parser: JanusParser,
|
54
|
+
max_retries: int,
|
55
|
+
prompt_template_name: str = "refinement/reflection",
|
56
|
+
):
|
57
|
+
reflection_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
|
58
|
+
source_language="text",
|
59
|
+
prompt_template=prompt_template_name,
|
60
|
+
).prompt
|
61
|
+
revision_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
|
62
|
+
source_language="text",
|
63
|
+
prompt_template="refinement/revision",
|
64
|
+
).prompt
|
65
|
+
|
66
|
+
reflection_chain = reflection_prompt | llm | StrOutputParser()
|
67
|
+
revision_chain = revision_prompt | llm | StrOutputParser()
|
68
|
+
super().__init__(
|
69
|
+
reflection_chain=reflection_chain,
|
70
|
+
revision_chain=revision_chain,
|
71
|
+
parser=parser,
|
72
|
+
max_retries=max_retries,
|
73
|
+
)
|
74
|
+
|
75
|
+
def parse_completion(
|
76
|
+
self, completion: str, prompt_value: PromptValue, **kwargs
|
77
|
+
) -> Any:
|
78
|
+
for retry_number in range(self.max_retries):
|
79
|
+
reflection = self.reflection_chain.invoke(
|
80
|
+
dict(
|
81
|
+
prompt=prompt_value.to_string(),
|
82
|
+
completion=completion,
|
83
|
+
)
|
84
|
+
)
|
85
|
+
if reflection.strip() == "LGTM":
|
86
|
+
return self.parser.parse(completion)
|
87
|
+
if not retry_number:
|
88
|
+
log.info(f"Completion:\n{completion}")
|
89
|
+
log.info(f"Reflection:\n{reflection}")
|
90
|
+
completion = self.revision_chain.invoke(
|
91
|
+
dict(
|
92
|
+
prompt=prompt_value.to_string(),
|
93
|
+
completion=completion,
|
94
|
+
reflection=reflection,
|
95
|
+
)
|
96
|
+
)
|
97
|
+
log.info(f"Revision:\n{completion}")
|
98
|
+
|
99
|
+
return self.parser.parse(completion)
|
100
|
+
|
101
|
+
|
102
|
+
class HallucinationRefiner(ReflectionRefiner):
|
103
|
+
def __init__(self, **kwargs):
|
104
|
+
super().__init__(
|
105
|
+
prompt_template_name="refinement/hallucination",
|
106
|
+
**kwargs,
|
67
107
|
)
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
108
|
+
|
109
|
+
|
110
|
+
REFINERS = dict(
|
111
|
+
none=JanusRefiner,
|
112
|
+
parser=FixParserExceptions,
|
113
|
+
reflection=ReflectionRefiner,
|
114
|
+
hallucination=HallucinationRefiner,
|
115
|
+
)
|
@@ -0,0 +1,42 @@
|
|
1
|
+
from langchain_core.retrievers import BaseRetriever
|
2
|
+
from langchain_core.runnables import Runnable, RunnableConfig
|
3
|
+
|
4
|
+
from janus.language.block import CodeBlock
|
5
|
+
|
6
|
+
|
7
|
+
class JanusRetriever(Runnable):
|
8
|
+
def __init__(self) -> None:
|
9
|
+
super().__init__()
|
10
|
+
|
11
|
+
def invoke(
|
12
|
+
self, input: CodeBlock, config: RunnableConfig | None = None, **kwargs
|
13
|
+
) -> dict:
|
14
|
+
kwargs.update(context=self.get_context(input))
|
15
|
+
return kwargs
|
16
|
+
|
17
|
+
def get_context(self, code_block: CodeBlock) -> str:
|
18
|
+
return ""
|
19
|
+
|
20
|
+
|
21
|
+
class ActiveUsingsRetriever(JanusRetriever):
|
22
|
+
def get_context(self, code_block: CodeBlock) -> str:
|
23
|
+
context = "\n".join(
|
24
|
+
f"{context_tag}: {context}"
|
25
|
+
for context_tag, context in code_block.context_tags.items()
|
26
|
+
)
|
27
|
+
return f"You may use the following additional context: {context}"
|
28
|
+
|
29
|
+
|
30
|
+
class TextSearchRetriever(JanusRetriever):
|
31
|
+
retriever: BaseRetriever
|
32
|
+
|
33
|
+
def __init__(self, retriever: BaseRetriever):
|
34
|
+
super().__init__()
|
35
|
+
self.retriever = retriever
|
36
|
+
|
37
|
+
def get_context(self, code_block: CodeBlock) -> str:
|
38
|
+
if code_block.text is None:
|
39
|
+
return ""
|
40
|
+
docs = self.retriever.invoke(code_block.text)
|
41
|
+
context = "\n\n".join(doc.page_content for doc in docs)
|
42
|
+
return f"You may use the following additional context: {context}"
|