janus-llm 3.5.3__py3-none-any.whl → 4.1.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only.
janus/llm/models_info.py CHANGED
@@ -1,15 +1,14 @@
  import json
  import os
- import time
  from pathlib import Path
- from typing import Any, Callable
+ from typing import Callable, Protocol, TypeVar

  from dotenv import load_dotenv
  from langchain_community.llms import HuggingFaceTextGenInference
- from langchain_core.language_models import BaseLanguageModel
- from langchain_openai import ChatOpenAI
+ from langchain_core.runnables import Runnable
+ from langchain_openai import AzureChatOpenAI

- from janus.llm.model_callbacks import COST_PER_1K_TOKENS
+ from janus.llm.model_callbacks import COST_PER_1K_TOKENS, azure_model_reroutes
  from janus.prompts.prompt import (
      ChatGptPromptEngine,
      ClaudePromptEngine,
@@ -44,17 +43,33 @@ except ImportError:
      )


- load_dotenv()
+ ModelType = TypeVar(
+     "ModelType",
+     AzureChatOpenAI,
+     HuggingFaceTextGenInference,
+     Bedrock,
+     BedrockChat,
+     HuggingFacePipeline,
+ )

- openai_model_reroutes = {
-     "gpt-4o": "gpt-4o-2024-05-13",
-     "gpt-4o-mini": "gpt-4o-mini",
-     "gpt-4": "gpt-4-0613",
-     "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
-     "gpt-4-turbo-preview": "gpt-4-0125-preview",
-     "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
-     "gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
- }
+
+ class JanusModelProtocol(Protocol):
+     model_id: str
+     model_type_name: str
+     token_limit: int
+     input_token_cost: float
+     output_token_cost: float
+     prompt_engine: type[PromptEngine]
+
+     def get_num_tokens(self, text: str) -> int:
+         ...
+
+
+ class JanusModel(Runnable, JanusModelProtocol):
+     ...
+
+
+ load_dotenv()

  openai_models = [
      "gpt-4o",
@@ -65,6 +80,11 @@ openai_models = [
      "gpt-3.5-turbo",
      "gpt-3.5-turbo-16k",
  ]
+ azure_models = [
+     "gpt-4o",
+     "gpt-4o-mini",
+     "gpt-3.5-turbo-16k",
+ ]
  claude_models = [
      "bedrock-claude-v2",
      "bedrock-claude-instant-v1",
@@ -103,27 +123,21 @@ bedrock_models = [
      *cohere_models,
      *mistral_models,
  ]
- all_models = [*openai_models, *bedrock_models]
+ all_models = [*azure_models, *bedrock_models]

- MODEL_TYPE_CONSTRUCTORS: dict[str, Callable[[Any], BaseLanguageModel]] = {
-     "OpenAI": ChatOpenAI,
+ MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
+     # "OpenAI": ChatOpenAI,
      "HuggingFace": HuggingFaceTextGenInference,
+     "Azure": AzureChatOpenAI,
+     "Bedrock": Bedrock,
+     "BedrockChat": BedrockChat,
+     "HuggingFaceLocal": HuggingFacePipeline,
  }

- try:
-     MODEL_TYPE_CONSTRUCTORS.update(
-         {
-             "HuggingFaceLocal": HuggingFacePipeline.from_model_id,
-             "Bedrock": Bedrock,
-             "BedrockChat": BedrockChat,
-         }
-     )
- except NameError:
-     pass
-

  MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
-     **{m: ChatGptPromptEngine for m in openai_models},
+     # **{m: ChatGptPromptEngine for m in openai_models},
+     **{m: ChatGptPromptEngine for m in azure_models},
      **{m: ClaudePromptEngine for m in claude_models},
      **{m: Llama2PromptEngine for m in llama2_models},
      **{m: Llama3PromptEngine for m in llama3_models},
@@ -132,13 +146,9 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
      **{m: MistralPromptEngine for m in mistral_models},
  }

- _open_ai_defaults: dict[str, str] = {
-     "openai_api_key": os.getenv("OPENAI_API_KEY"),
-     "openai_organization": os.getenv("OPENAI_ORG_ID"),
- }
-
  MODEL_ID_TO_LONG_ID = {
-     **{m: mr for m, mr in openai_model_reroutes.items()},
+     # **{m: mr for m, mr in openai_model_reroutes.items()},
+     **{m: mr for m, mr in azure_model_reroutes.items()},
      "bedrock-claude-v2": "anthropic.claude-v2",
      "bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
      "bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
@@ -169,7 +179,8 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
  MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"

  MODEL_TYPES: dict[str, PromptEngine] = {
-     **{m: "OpenAI" for m in openai_models},
+     # **{m: "OpenAI" for m in openai_models},
+     **{m: "Azure" for m in azure_models},
      **{m: "BedrockChat" for m in bedrock_models},
  }

@@ -179,7 +190,10 @@ TOKEN_LIMITS: dict[str, int] = {
      "gpt-4-1106-preview": 128_000,
      "gpt-4-0125-preview": 128_000,
      "gpt-4o-2024-05-13": 128_000,
+     "gpt-4o-2024-08-06": 128_000,
+     "gpt-4o-mini": 128_000,
      "gpt-3.5-turbo-0125": 16_384,
+     "gpt35-turbo-16k": 16_384,
      "text-embedding-ada-002": 8191,
      "gpt4all": 16_384,
      "anthropic.claude-v2": 100_000,
@@ -211,53 +225,100 @@ def get_available_model_names() -> list[str]:
      return avaialable_models


- def load_model(
-     user_model_name: str,
- ) -> tuple[BaseLanguageModel, str, int, dict[str, float]]:
+ def load_model(model_id) -> JanusModel:
      if not MODEL_CONFIG_DIR.exists():
          MODEL_CONFIG_DIR.mkdir(parents=True)
-     model_config_file = MODEL_CONFIG_DIR / f"{user_model_name}.json"
-     if not model_config_file.exists():
-         log.warning(
-             f"Model {user_model_name} not found in user-defined models, searching "
-             f"default models for {user_model_name}."
-         )
-         model_id = user_model_name
-         if user_model_name not in DEFAULT_MODELS:
-             message = (
-                 f"Model {user_model_name} not found in default models. Make sure to run "
-                 "`janus llm add` first."
-             )
-             log.error(message)
-             raise ValueError(message)
-         model_config = {
-             "model_type": MODEL_TYPES[model_id],
-             "model_id": model_id,
-             "model_args": MODEL_DEFAULT_ARGUMENTS[model_id],
-             "token_limit": TOKEN_LIMITS.get(MODEL_ID_TO_LONG_ID[model_id], 4096),
-             "model_cost": COST_PER_1K_TOKENS.get(
-                 MODEL_ID_TO_LONG_ID[model_id], {"input": 0, "output": 0}
-             ),
-         }
-         with open(model_config_file, "w") as f:
-             json.dump(model_config, f)
-     else:
+     model_config_file = MODEL_CONFIG_DIR / f"{model_id}.json"
+
+     if model_config_file.exists():
+         log.info(f"Loading {model_id} from {model_config_file}.")
          with open(model_config_file, "r") as f:
              model_config = json.load(f)
-     model_constructor = MODEL_TYPE_CONSTRUCTORS[model_config["model_type"]]
-     model_args = model_config["model_args"]
-     if model_config["model_type"] == "OpenAI":
-         model_args.update(_open_ai_defaults)
-         log.warning("Do NOT use this model in sensitive environments!")
-         log.warning("If you would like to cancel, please press Ctrl+C.")
-         log.warning("Waiting 10 seconds...")
+         model_type_name = model_config["model_type"]
+         model_id = model_config["model_id"]
+         model_args = model_config["model_args"]
+         token_limit = model_config["token_limit"]
+         input_token_cost = model_config["model_cost"]["input"]
+         output_token_cost = model_config["model_cost"]["output"]
+
+     elif model_id in DEFAULT_MODELS:
+         model_id = model_id
+         model_long_id = MODEL_ID_TO_LONG_ID[model_id]
+         model_type_name = MODEL_TYPES[model_id]
+         model_args = MODEL_DEFAULT_ARGUMENTS[model_id]
+
+         token_limit = 0
+         input_token_cost = 0.0
+         output_token_cost = 0.0
+         if model_long_id in TOKEN_LIMITS:
+             token_limit = TOKEN_LIMITS[model_long_id]
+         if model_long_id in COST_PER_1K_TOKENS:
+             token_limits = COST_PER_1K_TOKENS[model_long_id]
+             input_token_cost = token_limits["input"]
+             output_token_cost = token_limits["output"]
+
+     else:
+         model_list = "\n\t".join(DEFAULT_MODELS)
+         message = (
+             f"Model {model_id} not found in user-defined model directory "
+             f"({MODEL_CONFIG_DIR}), and is not a default model. Valid default "
+             f"models:\n\t{model_list}\n"
+             f"To use a custom model, first run `janus llm add`."
+         )
+         log.error(message)
+         raise ValueError(message)
+
+     if model_type_name == "HuggingFaceLocal":
+         model = HuggingFacePipeline.from_model_id(
+             model_id=model_id,
+             task="text-generation",
+             model_kwargs=model_args,
+         )
+         model_args.update(pipeline=model.pipeline)
+
+     elif model_type_name == "OpenAI":
+         model_args.update(
+             openai_api_key=str(os.getenv("OPENAI_API_KEY")),
+             openai_organization=str(os.getenv("OPENAI_ORG_ID")),
+         )
+         # log.warning("Do NOT use this model in sensitive environments!")
+         # log.warning("If you would like to cancel, please press Ctrl+C.")
+         # log.warning("Waiting 10 seconds...")
          # Give enough time for the user to read the warnings and cancel
-         time.sleep(10)
-
-     model = model_constructor(**model_args)
-     return (
-         model,
-         model_config["model_id"],
-         model_config["token_limit"],
-         model_config["model_cost"],
+         # time.sleep(10)
+         raise DeprecationWarning("OpenAI models are no longer supported.")
+
+     elif model_type_name == "Azure":
+         model_args.update(
+             {
+                 "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
+                 "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
+                 "api_version": os.getenv("OPENAI_API_VERSION", "2024-02-01"),
+             }
+         )
+
+     model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
+     prompt_engine = MODEL_PROMPT_ENGINES[model_id]
+
+     class JanusModel(model_type):
+         model_id: str
+         short_model_id: str
+         model_type_name: str
+         token_limit: int
+         input_token_cost: float
+         output_token_cost: float
+         prompt_engine: type[PromptEngine]
+
+     model_args.update(
+         model_id=MODEL_ID_TO_LONG_ID[model_id],
+         short_model_id=model_id,
+     )
+
+     return JanusModel(
+         model_type_name=model_type_name,
+         token_limit=token_limit,
+         input_token_cost=input_token_cost,
+         output_token_cost=output_token_cost,
+         prompt_engine=prompt_engine,
+         **model_args,
      )
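
The net effect of this file's changes is an API break: load_model previously returned a four-tuple of (model, model ID, token limit, cost), and now returns a single JanusModel that carries the same metadata as attributes. A minimal before/after sketch, assuming "gpt-4o" is a configured default model:

    # 3.5.3: metadata came back as a tuple
    llm, model_id, token_limit, model_cost = load_model("gpt-4o")

    # 4.1.0: a single JanusModel object carries the metadata
    llm = load_model("gpt-4o")
    print(llm.model_id, llm.token_limit)
    print(llm.input_token_cost, llm.output_token_cost)
    print(llm.get_num_tokens("SELECT 1;"))  # part of the JanusModelProtocol interface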
janus/metrics/metric.py CHANGED
@@ -8,6 +8,7 @@ import typer
  from typing_extensions import Annotated

  from janus.llm import load_model
+ from janus.llm.model_callbacks import COST_PER_1K_TOKENS
  from janus.metrics.cli import evaluate
  from janus.metrics.file_pairing import FILE_PAIRING_METHODS
  from janus.metrics.splitting import SPLITTING_METHODS
@@ -135,7 +136,7 @@ def metric(
          **kwargs,
      ):
          out = []
-         llm, _, token_limit, model_cost = load_model(llm_name)
+         llm = load_model(llm_name)
          if json_file_name is not None:
              with open(json_file_name, "r") as f:
                  json_obj = json.load(f)
@@ -171,8 +172,8 @@ def metric(
                  out_file=out_file,
                  lang=language,
                  llm=llm,
-                 token_limit=token_limit,
-                 model_cost=model_cost,
+                 token_limit=llm.token_limit,
+                 model_cost=COST_PER_1K_TOKENS[llm.model_id],
              )
          else:
              raise ValueError(
@@ -187,8 +188,8 @@ def metric(
                  progress,
                  language,
                  llm,
-                 token_limit,
-                 model_cost,
+                 llm.token_limit,
+                 COST_PER_1K_TOKENS[llm.model_id],
                  *args,
                  **kwargs,
              )
@@ -199,8 +200,8 @@ def metric(
                  progress,
                  language,
                  llm,
-                 token_limit,
-                 model_cost,
+                 llm.token_limit,
+                 COST_PER_1K_TOKENS[llm.model_id],
                  *args,
                  **kwargs,
              )
@@ -296,7 +297,7 @@ def metric(
          *args,
          **kwargs,
      ):
-         llm, _, token_limit, model_cost = load_model(llm_name)
+         llm = load_model(llm_name)
          if json_file_name is not None:
              with open(json_file_name, "r") as f:
                  json_obj = json.load(f)
@@ -328,8 +329,8 @@ def metric(
                  out_file=out_file,
                  lang=language,
                  llm=llm,
-                 token_limit=token_limit,
-                 model_cost=model_cost,
+                 token_limit=llm.token_limit,
+                 model_cost=COST_PER_1K_TOKENS[llm.model_id],
              )
          else:
              raise ValueError(
@@ -344,8 +345,8 @@ def metric(
                  progress,
                  language,
                  llm,
-                 token_limit,
-                 model_cost,
+                 llm.token_limit,
+                 COST_PER_1K_TOKENS[llm.model_id],
                  *args,
                  **kwargs,
              )
@@ -356,8 +357,8 @@ def metric(
                  progress,
                  language,
                  llm,
-                 token_limit,
-                 model_cost,
+                 llm.token_limit,
+                 COST_PER_1K_TOKENS[llm.model_id],
                  *args,
                  **kwargs,
              )
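
Call sites that previously threaded token_limit and model_cost through as separate return values now read the limit off the model object and look the cost up by its long model ID. A hedged sketch of the new pattern (the token counts are invented for illustration):

    from janus.llm import load_model
    from janus.llm.model_callbacks import COST_PER_1K_TOKENS

    llm = load_model("gpt-4o")  # illustrative model name
    model_cost = COST_PER_1K_TOKENS[llm.model_id]  # a dict with "input"/"output" keys
    # Cost of a call that consumed 1,500 input and 500 output tokens
    estimate = (1_500 * model_cost["input"] + 500 * model_cost["output"]) / 1_000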
janus/parsers/uml.py CHANGED
@@ -1,7 +1,7 @@
  import re
  import subprocess  # nosec
  from pathlib import Path
- from typing import List, Tuple
+ from tempfile import NamedTemporaryFile

  from langchain_core.exceptions import OutputParserException
  from langchain_core.messages import BaseMessage
@@ -13,39 +13,76 @@ log = create_logger(__name__)


  class UMLSyntaxParser(CodeParser):
-     def _get_uml_output(self, file: Path) -> Tuple[str, str]:
-         # NOTE: running subprocess with shell=False, added nosec to label that we know
-         # risk exists
+     def _check_plantuml(self, text: str) -> None:
+         # Leading newlines can break the parser, remove them
+         text = text.replace("\\n", "\n").strip()
+
+         # Write the text to a temporary file (automatically deleted)
+         file = NamedTemporaryFile()
+         fname = file.name
+         with open(fname, "w") as fin:
+             fin.write(text)
+
          try:
              plantuml_path = Path.home().expanduser() / ".janus/lib/plantuml.jar"
+             # NOTE: running subprocess with shell=False, added nosec to
+             # label that we know risk exists
              res = subprocess.run(
-                 ["java", "-jar", plantuml_path, file],
+                 ["java", "-jar", plantuml_path, fname],
                  stdout=subprocess.PIPE,
                  stderr=subprocess.PIPE,
              )  # nosec
              stdout = res.stdout.decode("utf-8")
              stderr = res.stderr.decode("utf-8")
          except FileNotFoundError:
-             log.warning("Plant UML executable not found, skipping syntax check")
-             stdout = ""
-             stderr = ""
-         return stdout, stderr
+             err_txt = (
+                 "Plant UML executable not found. Either choose a different parser"
+                 " or install with `bash scripts/install_plantuml.sh`. Java and"
+                 " graphviz are dependencies for the tool, they must also be installed."
+             )
+             log.error(err_txt)
+             raise Exception(err_txt)
+
+         # Check for bad outputs, raise OutputParserExceptions if so
+         if "Error" in stderr or "Error" in stdout:
+             err_txt = "Recieved UML parsing error(s)."
+
+             line_nos = self._get_error_lines(stderr) + self._get_error_lines(stdout)
+             lines = text.split("\n")
+             for i in line_nos:
+                 i0 = max(0, i - 3)
+                 i1 = min(len(lines) - 1, i + 2)
+                 err_lines = [
+                     f"> {lines[j]}" if j == i - 1 else f"  {lines[j]}"
+                     for j in range(i0, i1)
+                 ]
+                 if i0:
+                     err_lines.insert(0, "  ...")
+                 if i1 < (len(lines) - 1):
+                     err_lines.append("  ...")

-     def _get_errs(self, s: str) -> List[str]:
-         return [x.group() for x in re.finditer(r"Error (.*)\n", s)]
+                 err_txt += f"\nError located at line {i} must be fixed:\n"
+                 err_txt += "\n".join(err_lines)
+             log.warning(err_txt)
+             raise OutputParserException(err_txt)
+
+         if "Warning" in stdout or "Warning" in stderr:
+             err_txt = "Recieved UML parsing warning (often due to missing PLANTUML)."
+             if stderr:
+                 err_txt += f"\nSTDERR:\n```\n{stderr.strip()}\n```\n"
+             if stdout:
+                 err_txt += f"\nSTDOUT:\n```\n{stdout.strip()}\n```\n"
+
+             log.warning(err_txt)
+             raise OutputParserException(err_txt)
+
+     def _get_error_lines(self, s: str) -> list[int]:
+         return [int(x.group(1)) for x in re.finditer(r"Error line (\d+) in file:", s)]
+
+     def _get_warns(self, s: str) -> list[str]:
+         return [x.group() for x in re.finditer(r"Warning: (.*)\n", s)]

      def parse(self, text: str | BaseMessage) -> str:
          text = super().parse(text)
-         janus_path = Path.home().expanduser() / Path(".janus")
-         if not janus_path.exists():
-             janus_path.mkdir()
-         temp_file_path = janus_path / "tmp.txt"
-         with open(temp_file_path, "w") as f:
-             f.write(text)
-         uml_std_out, uml_std_err = self._get_uml_output(temp_file_path)
-         uml_errs = self._get_errs(uml_std_out) + self._get_errs(uml_std_err)
-         if len(uml_errs) > 0:
-             raise OutputParserException(
-                 "Error: Received UML Errors:\n" + "\n".join(uml_errs)
-             )
+         self._check_plantuml(text)
          return text
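
The loop in _check_plantuml excerpts the diagram around each line number PlantUML reports, prefixing the offending line with ">". The windowing logic in isolation, run on a hypothetical diagram (not from the package):

    lines = ["@startuml", "A -> B", "B -> C", "bad syntax", "C -> D", "@enduml"]
    i = 4  # 1-based line number, as matched by r"Error line (\d+) in file:"
    i0, i1 = max(0, i - 3), min(len(lines) - 1, i + 2)
    excerpt = [
        f"> {lines[j]}" if j == i - 1 else f"  {lines[j]}" for j in range(i0, i1)
    ]
    print("\n".join(excerpt))
    #   A -> B
    #   B -> C
    # > bad syntax
    #   C -> D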
janus/refiners/refiner.py CHANGED
@@ -1,73 +1,115 @@
- from langchain_core.prompts import ChatPromptTemplate
+ from typing import Any

- from janus.llm.models_info import MODEL_PROMPT_ENGINES
+ from langchain.output_parsers import RetryWithErrorOutputParser
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompt_values import PromptValue
+ from langchain_core.runnables import RunnableSerializable

+ from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
+ from janus.parsers.parser import JanusParser
+ from janus.utils.logger import create_logger

- class Refiner:
-     def refine(
-         self,
-         original_prompt: str,
-         previous_prompt: str,
-         previous_output: str,
-         errors: str,
-         **kwargs,
-     ) -> tuple[ChatPromptTemplate, dict[str, str]]:
-         """Creates a new prompt based on feedback from original results
-
-         Arguments:
-             original_prompt: original prompt used to produce output
-             original_output: origial output of llm
-             errors: list of errors detected by parser
-
-         Returns:
-             Tuple of new prompt and prompt arguments
-         """
+ log = create_logger(__name__)
+
+
+ class JanusRefiner(JanusParser):
+     parser: JanusParser
+
+     def parse_runnable(self, input: dict[str, Any]) -> Any:
+         return self.parse_completion(**input)
+
+     def parse_completion(self, completion: str, **kwargs) -> Any:
+         return self.parser.parse(completion)
+
+     def parse(self, text: str) -> str:
          raise NotImplementedError


- class BasicRefiner(Refiner):
+ class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
+     def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
+         retry_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+             source_language="text",
+             prompt_template="refinement/fix_exceptions",
+         ).prompt
+         chain = retry_prompt | llm | StrOutputParser()
+         RetryWithErrorOutputParser.__init__(
+             self, parser=parser, retry_chain=chain, max_retries=max_retries
+         )
+
+     def parse_completion(
+         self, completion: str, prompt_value: PromptValue, **kwargs
+     ) -> Any:
+         return self.parse_with_prompt(completion, prompt_value=prompt_value)
+
+
+ class ReflectionRefiner(JanusRefiner):
+     max_retries: int
+     reflection_chain: RunnableSerializable
+     revision_chain: RunnableSerializable
+
      def __init__(
          self,
-         prompt_name: str,
-         model_id: str,
-         source_language: str,
-     ) -> None:
-         """Basic refiner, asks llm to fix output of previous prompt given errors
-
-         Arguments:
-             prompt_name: refinement prompt name to use
-             model_id: ID of the llm to use. Found in models_info.py
-             source_language: source_langauge to use
-         """
-         self._prompt_name = prompt_name
-         self._model_id = model_id
-         self._source_language = source_language
-
-     def refine(
-         self,
-         original_prompt: str,
-         previous_prompt: str,
-         previous_output: str,
-         errors: str,
-         **kwargs,
-     ) -> tuple[ChatPromptTemplate, dict[str, str]]:
-         """Creates a new prompt based on feedback from original results
-
-         Arguments:
-             original_prompt: original prompt used to produce output
-             original_output: origial output of llm
-             errors: list of errors detected by parser
-
-         Returns:
-             Tuple of new prompt and prompt arguments
-         """
-         prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
-             prompt_template=self._prompt_name,
-             source_language=self._source_language,
+         llm: JanusModel,
+         parser: JanusParser,
+         max_retries: int,
+         prompt_template_name: str = "refinement/reflection",
+     ):
+         reflection_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+             source_language="text",
+             prompt_template=prompt_template_name,
+         ).prompt
+         revision_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+             source_language="text",
+             prompt_template="refinement/revision",
+         ).prompt
+
+         reflection_chain = reflection_prompt | llm | StrOutputParser()
+         revision_chain = revision_prompt | llm | StrOutputParser()
+         super().__init__(
+             reflection_chain=reflection_chain,
+             revision_chain=revision_chain,
+             parser=parser,
+             max_retries=max_retries,
+         )
+
+     def parse_completion(
+         self, completion: str, prompt_value: PromptValue, **kwargs
+     ) -> Any:
+         for retry_number in range(self.max_retries):
+             reflection = self.reflection_chain.invoke(
+                 dict(
+                     prompt=prompt_value.to_string(),
+                     completion=completion,
+                 )
+             )
+             if reflection.strip() == "LGTM":
+                 return self.parser.parse(completion)
+             if not retry_number:
+                 log.info(f"Completion:\n{completion}")
+             log.info(f"Reflection:\n{reflection}")
+             completion = self.revision_chain.invoke(
+                 dict(
+                     prompt=prompt_value.to_string(),
+                     completion=completion,
+                     reflection=reflection,
+                 )
+             )
+             log.info(f"Revision:\n{completion}")
+
+         return self.parser.parse(completion)
+
+
+ class HallucinationRefiner(ReflectionRefiner):
+     def __init__(self, **kwargs):
+         super().__init__(
+             prompt_template_name="refinement/hallucination",
+             **kwargs,
          )
-         prompt_arguments = {
-             "ORIGINAL_PROMPT": original_prompt,
-             "OUTPUT": previous_output,
-             "ERRORS": errors,
-         }
-         return prompt_engine.prompt, prompt_arguments
+
+
+ REFINERS = dict(
+     none=JanusRefiner,
+     parser=FixParserExceptions,
+     reflection=ReflectionRefiner,
+     hallucination=HallucinationRefiner,
+ )
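
The old Refiner.refine contract (return a revised prompt plus prompt arguments) is gone: refiners are now parsers that wrap another JanusParser and drive their own retry chains. A sketch of selecting one from the new REFINERS registry; the model name is illustrative and MyParser is a hypothetical JanusParser subclass, since the diff shows no concrete parser construction:

    from janus.llm.models_info import load_model
    from janus.refiners.refiner import REFINERS

    llm = load_model("gpt-4o")  # illustrative model name
    parser = MyParser()  # hypothetical JanusParser subclass
    refiner = REFINERS["reflection"](llm=llm, parser=parser, max_retries=3)

    # completion/prompt_value are stand-ins for an LLM output and the
    # PromptValue that produced it; the refiner revises until the reflection
    # chain answers "LGTM" or max_retries is exhausted
    result = refiner.parse_completion(completion=completion, prompt_value=prompt_value)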