janus-llm 3.5.3__py3-none-any.whl → 4.1.0__py3-none-any.whl

janus/llm/models_info.py CHANGED
@@ -1,15 +1,14 @@
  import json
  import os
- import time
  from pathlib import Path
- from typing import Any, Callable
+ from typing import Callable, Protocol, TypeVar

  from dotenv import load_dotenv
  from langchain_community.llms import HuggingFaceTextGenInference
- from langchain_core.language_models import BaseLanguageModel
- from langchain_openai import ChatOpenAI
+ from langchain_core.runnables import Runnable
+ from langchain_openai import AzureChatOpenAI

- from janus.llm.model_callbacks import COST_PER_1K_TOKENS
+ from janus.llm.model_callbacks import COST_PER_1K_TOKENS, azure_model_reroutes
  from janus.prompts.prompt import (
      ChatGptPromptEngine,
      ClaudePromptEngine,
@@ -44,17 +43,33 @@ except ImportError:
      )


- load_dotenv()
+ ModelType = TypeVar(
+     "ModelType",
+     AzureChatOpenAI,
+     HuggingFaceTextGenInference,
+     Bedrock,
+     BedrockChat,
+     HuggingFacePipeline,
+ )

- openai_model_reroutes = {
-     "gpt-4o": "gpt-4o-2024-05-13",
-     "gpt-4o-mini": "gpt-4o-mini",
-     "gpt-4": "gpt-4-0613",
-     "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
-     "gpt-4-turbo-preview": "gpt-4-0125-preview",
-     "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
-     "gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
- }
+
+ class JanusModelProtocol(Protocol):
+     model_id: str
+     model_type_name: str
+     token_limit: int
+     input_token_cost: float
+     output_token_cost: float
+     prompt_engine: type[PromptEngine]
+
+     def get_num_tokens(self, text: str) -> int:
+         ...
+
+
+ class JanusModel(Runnable, JanusModelProtocol):
+     ...
+
+
+ load_dotenv()

  openai_models = [
      "gpt-4o",
@@ -65,6 +80,11 @@ openai_models = [
      "gpt-3.5-turbo",
      "gpt-3.5-turbo-16k",
  ]
+ azure_models = [
+     "gpt-4o",
+     "gpt-4o-mini",
+     "gpt-3.5-turbo-16k",
+ ]
  claude_models = [
      "bedrock-claude-v2",
      "bedrock-claude-instant-v1",
@@ -103,27 +123,21 @@ bedrock_models = [
      *cohere_models,
      *mistral_models,
  ]
- all_models = [*openai_models, *bedrock_models]
+ all_models = [*azure_models, *bedrock_models]

- MODEL_TYPE_CONSTRUCTORS: dict[str, Callable[[Any], BaseLanguageModel]] = {
-     "OpenAI": ChatOpenAI,
+ MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
+     # "OpenAI": ChatOpenAI,
      "HuggingFace": HuggingFaceTextGenInference,
+     "Azure": AzureChatOpenAI,
+     "Bedrock": Bedrock,
+     "BedrockChat": BedrockChat,
+     "HuggingFaceLocal": HuggingFacePipeline,
  }

- try:
-     MODEL_TYPE_CONSTRUCTORS.update(
-         {
-             "HuggingFaceLocal": HuggingFacePipeline.from_model_id,
-             "Bedrock": Bedrock,
-             "BedrockChat": BedrockChat,
-         }
-     )
- except NameError:
-     pass
-

  MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
-     **{m: ChatGptPromptEngine for m in openai_models},
+     # **{m: ChatGptPromptEngine for m in openai_models},
+     **{m: ChatGptPromptEngine for m in azure_models},
      **{m: ClaudePromptEngine for m in claude_models},
      **{m: Llama2PromptEngine for m in llama2_models},
      **{m: Llama3PromptEngine for m in llama3_models},
@@ -132,13 +146,9 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
      **{m: MistralPromptEngine for m in mistral_models},
  }

- _open_ai_defaults: dict[str, str] = {
-     "openai_api_key": os.getenv("OPENAI_API_KEY"),
-     "openai_organization": os.getenv("OPENAI_ORG_ID"),
- }
-
  MODEL_ID_TO_LONG_ID = {
-     **{m: mr for m, mr in openai_model_reroutes.items()},
+     # **{m: mr for m, mr in openai_model_reroutes.items()},
+     **{m: mr for m, mr in azure_model_reroutes.items()},
      "bedrock-claude-v2": "anthropic.claude-v2",
      "bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
      "bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
@@ -169,7 +179,8 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
  MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"

  MODEL_TYPES: dict[str, PromptEngine] = {
-     **{m: "OpenAI" for m in openai_models},
+     # **{m: "OpenAI" for m in openai_models},
+     **{m: "Azure" for m in azure_models},
      **{m: "BedrockChat" for m in bedrock_models},
  }

@@ -179,7 +190,10 @@ TOKEN_LIMITS: dict[str, int] = {
      "gpt-4-1106-preview": 128_000,
      "gpt-4-0125-preview": 128_000,
      "gpt-4o-2024-05-13": 128_000,
+     "gpt-4o-2024-08-06": 128_000,
+     "gpt-4o-mini": 128_000,
      "gpt-3.5-turbo-0125": 16_384,
+     "gpt35-turbo-16k": 16_384,
      "text-embedding-ada-002": 8191,
      "gpt4all": 16_384,
      "anthropic.claude-v2": 100_000,
@@ -211,53 +225,100 @@ def get_available_model_names() -> list[str]:
      return avaialable_models


- def load_model(
-     user_model_name: str,
- ) -> tuple[BaseLanguageModel, str, int, dict[str, float]]:
+ def load_model(model_id) -> JanusModel:
      if not MODEL_CONFIG_DIR.exists():
          MODEL_CONFIG_DIR.mkdir(parents=True)
-     model_config_file = MODEL_CONFIG_DIR / f"{user_model_name}.json"
-     if not model_config_file.exists():
-         log.warning(
-             f"Model {user_model_name} not found in user-defined models, searching "
-             f"default models for {user_model_name}."
-         )
-         model_id = user_model_name
-         if user_model_name not in DEFAULT_MODELS:
-             message = (
-                 f"Model {user_model_name} not found in default models. Make sure to run "
-                 "`janus llm add` first."
-             )
-             log.error(message)
-             raise ValueError(message)
-         model_config = {
-             "model_type": MODEL_TYPES[model_id],
-             "model_id": model_id,
-             "model_args": MODEL_DEFAULT_ARGUMENTS[model_id],
-             "token_limit": TOKEN_LIMITS.get(MODEL_ID_TO_LONG_ID[model_id], 4096),
-             "model_cost": COST_PER_1K_TOKENS.get(
-                 MODEL_ID_TO_LONG_ID[model_id], {"input": 0, "output": 0}
-             ),
-         }
-         with open(model_config_file, "w") as f:
-             json.dump(model_config, f)
-     else:
+     model_config_file = MODEL_CONFIG_DIR / f"{model_id}.json"
+
+     if model_config_file.exists():
+         log.info(f"Loading {model_id} from {model_config_file}.")
          with open(model_config_file, "r") as f:
              model_config = json.load(f)
-     model_constructor = MODEL_TYPE_CONSTRUCTORS[model_config["model_type"]]
-     model_args = model_config["model_args"]
-     if model_config["model_type"] == "OpenAI":
-         model_args.update(_open_ai_defaults)
-         log.warning("Do NOT use this model in sensitive environments!")
-         log.warning("If you would like to cancel, please press Ctrl+C.")
-         log.warning("Waiting 10 seconds...")
+         model_type_name = model_config["model_type"]
+         model_id = model_config["model_id"]
+         model_args = model_config["model_args"]
+         token_limit = model_config["token_limit"]
+         input_token_cost = model_config["model_cost"]["input"]
+         output_token_cost = model_config["model_cost"]["output"]
+
+     elif model_id in DEFAULT_MODELS:
+         model_id = model_id
+         model_long_id = MODEL_ID_TO_LONG_ID[model_id]
+         model_type_name = MODEL_TYPES[model_id]
+         model_args = MODEL_DEFAULT_ARGUMENTS[model_id]
+
+         token_limit = 0
+         input_token_cost = 0.0
+         output_token_cost = 0.0
+         if model_long_id in TOKEN_LIMITS:
+             token_limit = TOKEN_LIMITS[model_long_id]
+         if model_long_id in COST_PER_1K_TOKENS:
+             token_limits = COST_PER_1K_TOKENS[model_long_id]
+             input_token_cost = token_limits["input"]
+             output_token_cost = token_limits["output"]
+
+     else:
+         model_list = "\n\t".join(DEFAULT_MODELS)
+         message = (
+             f"Model {model_id} not found in user-defined model directory "
+             f"({MODEL_CONFIG_DIR}), and is not a default model. Valid default "
+             f"models:\n\t{model_list}\n"
+             f"To use a custom model, first run `janus llm add`."
+         )
+         log.error(message)
+         raise ValueError(message)
+
+     if model_type_name == "HuggingFaceLocal":
+         model = HuggingFacePipeline.from_model_id(
+             model_id=model_id,
+             task="text-generation",
+             model_kwargs=model_args,
+         )
+         model_args.update(pipeline=model.pipeline)
+
+     elif model_type_name == "OpenAI":
+         model_args.update(
+             openai_api_key=str(os.getenv("OPENAI_API_KEY")),
+             openai_organization=str(os.getenv("OPENAI_ORG_ID")),
+         )
+         # log.warning("Do NOT use this model in sensitive environments!")
+         # log.warning("If you would like to cancel, please press Ctrl+C.")
+         # log.warning("Waiting 10 seconds...")
          # Give enough time for the user to read the warnings and cancel
-         time.sleep(10)
-
-     model = model_constructor(**model_args)
-     return (
-         model,
-         model_config["model_id"],
-         model_config["token_limit"],
-         model_config["model_cost"],
+         # time.sleep(10)
+         raise DeprecationWarning("OpenAI models are no longer supported.")
+
+     elif model_type_name == "Azure":
+         model_args.update(
+             {
+                 "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
+                 "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
+                 "api_version": os.getenv("OPENAI_API_VERSION", "2024-02-01"),
+             }
+         )
+
+     model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
+     prompt_engine = MODEL_PROMPT_ENGINES[model_id]
+
+     class JanusModel(model_type):
+         model_id: str
+         short_model_id: str
+         model_type_name: str
+         token_limit: int
+         input_token_cost: float
+         output_token_cost: float
+         prompt_engine: type[PromptEngine]
+
+     model_args.update(
+         model_id=MODEL_ID_TO_LONG_ID[model_id],
+         short_model_id=model_id,
+     )
+
+     return JanusModel(
+         model_type_name=model_type_name,
+         token_limit=token_limit,
+         input_token_cost=input_token_cost,
+         output_token_cost=output_token_cost,
+         prompt_engine=prompt_engine,
+         **model_args,
      )
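
The caller-facing upshot of this file's changes: `load_model` no longer returns a `(model, model_id, token_limit, model_cost)` tuple; it returns a single `Runnable` whose metadata rides along as attributes. A minimal usage sketch (the model name and environment values are placeholders, not part of the diff):

```python
import os

from janus.llm.models_info import load_model

# Assumed to be configured already; the values here are placeholders.
os.environ["AZURE_OPENAI_API_KEY"] = "<key>"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<resource>.openai.azure.com/"

llm = load_model("gpt-4o")  # a JanusModel subclassing AzureChatOpenAI

# Metadata that used to arrive as tuple elements is now attached:
print(llm.short_model_id)   # the short id passed in ("gpt-4o")
print(llm.model_id)         # the long id from MODEL_ID_TO_LONG_ID
print(llm.token_limit)      # context window from TOKEN_LIMITS

# Rough input-cost estimate via the protocol's get_num_tokens:
estimate = llm.get_num_tokens("some source code") / 1000 * llm.input_token_cost
```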
janus/metrics/metric.py CHANGED
@@ -8,6 +8,7 @@ import typer
  from typing_extensions import Annotated

  from janus.llm import load_model
+ from janus.llm.model_callbacks import COST_PER_1K_TOKENS
  from janus.metrics.cli import evaluate
  from janus.metrics.file_pairing import FILE_PAIRING_METHODS
  from janus.metrics.splitting import SPLITTING_METHODS
@@ -135,7 +136,7 @@ def metric(
          **kwargs,
      ):
          out = []
-         llm, _, token_limit, model_cost = load_model(llm_name)
+         llm = load_model(llm_name)
          if json_file_name is not None:
              with open(json_file_name, "r") as f:
                  json_obj = json.load(f)
@@ -171,8 +172,8 @@ def metric(
                  out_file=out_file,
                  lang=language,
                  llm=llm,
-                 token_limit=token_limit,
-                 model_cost=model_cost,
+                 token_limit=llm.token_limit,
+                 model_cost=COST_PER_1K_TOKENS[llm.model_id],
              )
          else:
              raise ValueError(
@@ -187,8 +188,8 @@ def metric(
                  progress,
                  language,
                  llm,
-                 token_limit,
-                 model_cost,
+                 llm.token_limit,
+                 COST_PER_1K_TOKENS[llm.model_id],
                  *args,
                  **kwargs,
              )
@@ -199,8 +200,8 @@ def metric(
                  progress,
                  language,
                  llm,
-                 token_limit,
-                 model_cost,
+                 llm.token_limit,
+                 COST_PER_1K_TOKENS[llm.model_id],
                  *args,
                  **kwargs,
              )
@@ -296,7 +297,7 @@ def metric(
          *args,
          **kwargs,
      ):
-         llm, _, token_limit, model_cost = load_model(llm_name)
+         llm = load_model(llm_name)
          if json_file_name is not None:
              with open(json_file_name, "r") as f:
                  json_obj = json.load(f)
@@ -328,8 +329,8 @@ def metric(
                  out_file=out_file,
                  lang=language,
                  llm=llm,
-                 token_limit=token_limit,
-                 model_cost=model_cost,
+                 token_limit=llm.token_limit,
+                 model_cost=COST_PER_1K_TOKENS[llm.model_id],
              )
          else:
              raise ValueError(
@@ -344,8 +345,8 @@ def metric(
                  progress,
                  language,
                  llm,
-                 token_limit,
-                 model_cost,
+                 llm.token_limit,
+                 COST_PER_1K_TOKENS[llm.model_id],
                  *args,
                  **kwargs,
              )
@@ -356,8 +357,8 @@ def metric(
                  progress,
                  language,
                  llm,
-                 token_limit,
-                 model_cost,
+                 llm.token_limit,
+                 COST_PER_1K_TOKENS[llm.model_id],
                  *args,
                  **kwargs,
              )
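
These edits are mechanical and repeat at every call site: tuple unpacking becomes attribute access, and per-model cost now comes from the shared `COST_PER_1K_TOKENS` table keyed by the model's long id. A before/after sketch (`llm_name` here is a stand-in):

```python
from janus.llm import load_model
from janus.llm.model_callbacks import COST_PER_1K_TOKENS

llm_name = "gpt-4o"  # placeholder

# Before (3.5.x): metadata was unpacked from a tuple
# llm, _, token_limit, model_cost = load_model(llm_name)

# After (4.1.0): the returned model carries its own metadata
llm = load_model(llm_name)
token_limit = llm.token_limit
model_cost = COST_PER_1K_TOKENS[llm.model_id]  # {"input": ..., "output": ...}
```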
janus/parsers/uml.py CHANGED
@@ -1,7 +1,7 @@
  import re
  import subprocess  # nosec
  from pathlib import Path
- from typing import List, Tuple
+ from tempfile import NamedTemporaryFile

  from langchain_core.exceptions import OutputParserException
  from langchain_core.messages import BaseMessage
@@ -13,39 +13,76 @@ log = create_logger(__name__)


  class UMLSyntaxParser(CodeParser):
-     def _get_uml_output(self, file: Path) -> Tuple[str, str]:
-         # NOTE: running subprocess with shell=False, added nosec to label that we know
-         # risk exists
+     def _check_plantuml(self, text: str) -> None:
+         # Leading newlines can break the parser, remove them
+         text = text.replace("\\n", "\n").strip()
+
+         # Write the text to a temporary file (automatically deleted)
+         file = NamedTemporaryFile()
+         fname = file.name
+         with open(fname, "w") as fin:
+             fin.write(text)
+
          try:
              plantuml_path = Path.home().expanduser() / ".janus/lib/plantuml.jar"
+             # NOTE: running subprocess with shell=False, added nosec to
+             # label that we know risk exists
              res = subprocess.run(
-                 ["java", "-jar", plantuml_path, file],
+                 ["java", "-jar", plantuml_path, fname],
                  stdout=subprocess.PIPE,
                  stderr=subprocess.PIPE,
              )  # nosec
              stdout = res.stdout.decode("utf-8")
              stderr = res.stderr.decode("utf-8")
          except FileNotFoundError:
-             log.warning("Plant UML executable not found, skipping syntax check")
-             stdout = ""
-             stderr = ""
-         return stdout, stderr
+             err_txt = (
+                 "Plant UML executable not found. Either choose a different parser"
+                 " or install with `bash scripts/install_plantuml.sh`. Java and"
+                 " graphviz are dependencies for the tool, they must also be installed."
+             )
+             log.error(err_txt)
+             raise Exception(err_txt)
+
+         # Check for bad outputs, raise OutputParserExceptions if so
+         if "Error" in stderr or "Error" in stdout:
+             err_txt = "Received UML parsing error(s)."
+
+             line_nos = self._get_error_lines(stderr) + self._get_error_lines(stdout)
+             lines = text.split("\n")
+             for i in line_nos:
+                 i0 = max(0, i - 3)
+                 i1 = min(len(lines) - 1, i + 2)
+                 err_lines = [
+                     f"> {lines[j]}" if j == i - 1 else f"  {lines[j]}"
+                     for j in range(i0, i1)
+                 ]
+                 if i0:
+                     err_lines.insert(0, "  ...")
+                 if i1 < (len(lines) - 1):
+                     err_lines.append("  ...")

-     def _get_errs(self, s: str) -> List[str]:
-         return [x.group() for x in re.finditer(r"Error (.*)\n", s)]
+                 err_txt += f"\nError located at line {i} must be fixed:\n"
+                 err_txt += "\n".join(err_lines)
+             log.warning(err_txt)
+             raise OutputParserException(err_txt)
+
+         if "Warning" in stdout or "Warning" in stderr:
+             err_txt = "Received UML parsing warning (often due to missing PLANTUML)."
+             if stderr:
+                 err_txt += f"\nSTDERR:\n```\n{stderr.strip()}\n```\n"
+             if stdout:
+                 err_txt += f"\nSTDOUT:\n```\n{stdout.strip()}\n```\n"
+
+             log.warning(err_txt)
+             raise OutputParserException(err_txt)
+
+     def _get_error_lines(self, s: str) -> list[int]:
+         return [int(x.group(1)) for x in re.finditer(r"Error line (\d+) in file:", s)]
+
+     def _get_warns(self, s: str) -> list[str]:
+         return [x.group() for x in re.finditer(r"Warning: (.*)\n", s)]

      def parse(self, text: str | BaseMessage) -> str:
          text = super().parse(text)
-         janus_path = Path.home().expanduser() / Path(".janus")
-         if not janus_path.exists():
-             janus_path.mkdir()
-         temp_file_path = janus_path / "tmp.txt"
-         with open(temp_file_path, "w") as f:
-             f.write(text)
-         uml_std_out, uml_std_err = self._get_uml_output(temp_file_path)
-         uml_errs = self._get_errs(uml_std_out) + self._get_errs(uml_std_err)
-         if len(uml_errs) > 0:
-             raise OutputParserException(
-                 "Error: Received UML Errors:\n" + "\n".join(uml_errs)
-             )
+         self._check_plantuml(text)
          return text
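
The new error path hinges on `_get_error_lines`, which pulls 1-based line numbers out of PlantUML's `Error line <N> in file:` messages so the parser can quote a window of the offending UML back in the exception. A self-contained sketch of that extraction (the sample output string is invented for illustration):

```python
import re

def get_error_lines(s: str) -> list[int]:
    # Same regex as UMLSyntaxParser._get_error_lines above
    return [int(x.group(1)) for x in re.finditer(r"Error line (\d+) in file:", s)]

sample = "Error line 7 in file: /tmp/tmpk3v9\nSome diagram description error\n"
print(get_error_lines(sample))  # [7]
```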
janus/refiners/refiner.py CHANGED
@@ -1,73 +1,115 @@
- from langchain_core.prompts import ChatPromptTemplate
+ from typing import Any

- from janus.llm.models_info import MODEL_PROMPT_ENGINES
+ from langchain.output_parsers import RetryWithErrorOutputParser
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompt_values import PromptValue
+ from langchain_core.runnables import RunnableSerializable

+ from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
+ from janus.parsers.parser import JanusParser
+ from janus.utils.logger import create_logger

- class Refiner:
-     def refine(
-         self,
-         original_prompt: str,
-         previous_prompt: str,
-         previous_output: str,
-         errors: str,
-         **kwargs,
-     ) -> tuple[ChatPromptTemplate, dict[str, str]]:
-         """Creates a new prompt based on feedback from original results
-
-         Arguments:
-             original_prompt: original prompt used to produce output
-             original_output: origial output of llm
-             errors: list of errors detected by parser
-
-         Returns:
-             Tuple of new prompt and prompt arguments
-         """
+ log = create_logger(__name__)
+
+
+ class JanusRefiner(JanusParser):
+     parser: JanusParser
+
+     def parse_runnable(self, input: dict[str, Any]) -> Any:
+         return self.parse_completion(**input)
+
+     def parse_completion(self, completion: str, **kwargs) -> Any:
+         return self.parser.parse(completion)
+
+     def parse(self, text: str) -> str:
          raise NotImplementedError


- class BasicRefiner(Refiner):
+ class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
+     def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
+         retry_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+             source_language="text",
+             prompt_template="refinement/fix_exceptions",
+         ).prompt
+         chain = retry_prompt | llm | StrOutputParser()
+         RetryWithErrorOutputParser.__init__(
+             self, parser=parser, retry_chain=chain, max_retries=max_retries
+         )
+
+     def parse_completion(
+         self, completion: str, prompt_value: PromptValue, **kwargs
+     ) -> Any:
+         return self.parse_with_prompt(completion, prompt_value=prompt_value)
+
+
+ class ReflectionRefiner(JanusRefiner):
+     max_retries: int
+     reflection_chain: RunnableSerializable
+     revision_chain: RunnableSerializable
+
      def __init__(
          self,
-         prompt_name: str,
-         model_id: str,
-         source_language: str,
-     ) -> None:
-         """Basic refiner, asks llm to fix output of previous prompt given errors
-
-         Arguments:
-             prompt_name: refinement prompt name to use
-             model_id: ID of the llm to use. Found in models_info.py
-             source_language: source_langauge to use
-         """
-         self._prompt_name = prompt_name
-         self._model_id = model_id
-         self._source_language = source_language
-
-     def refine(
-         self,
-         original_prompt: str,
-         previous_prompt: str,
-         previous_output: str,
-         errors: str,
-         **kwargs,
-     ) -> tuple[ChatPromptTemplate, dict[str, str]]:
-         """Creates a new prompt based on feedback from original results
-
-         Arguments:
-             original_prompt: original prompt used to produce output
-             original_output: origial output of llm
-             errors: list of errors detected by parser
-
-         Returns:
-             Tuple of new prompt and prompt arguments
-         """
-         prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
-             prompt_template=self._prompt_name,
-             source_language=self._source_language,
+         llm: JanusModel,
+         parser: JanusParser,
+         max_retries: int,
+         prompt_template_name: str = "refinement/reflection",
+     ):
+         reflection_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+             source_language="text",
+             prompt_template=prompt_template_name,
+         ).prompt
+         revision_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+             source_language="text",
+             prompt_template="refinement/revision",
+         ).prompt
+
+         reflection_chain = reflection_prompt | llm | StrOutputParser()
+         revision_chain = revision_prompt | llm | StrOutputParser()
+         super().__init__(
+             reflection_chain=reflection_chain,
+             revision_chain=revision_chain,
+             parser=parser,
+             max_retries=max_retries,
+         )
+
+     def parse_completion(
+         self, completion: str, prompt_value: PromptValue, **kwargs
+     ) -> Any:
+         for retry_number in range(self.max_retries):
+             reflection = self.reflection_chain.invoke(
+                 dict(
+                     prompt=prompt_value.to_string(),
+                     completion=completion,
+                 )
+             )
+             if reflection.strip() == "LGTM":
+                 return self.parser.parse(completion)
+             if not retry_number:
+                 log.info(f"Completion:\n{completion}")
+             log.info(f"Reflection:\n{reflection}")
+             completion = self.revision_chain.invoke(
+                 dict(
+                     prompt=prompt_value.to_string(),
+                     completion=completion,
+                     reflection=reflection,
+                 )
+             )
+             log.info(f"Revision:\n{completion}")
+
+         return self.parser.parse(completion)
+
+
+ class HallucinationRefiner(ReflectionRefiner):
+     def __init__(self, **kwargs):
+         super().__init__(
+             prompt_template_name="refinement/hallucination",
+             **kwargs,
          )
-         prompt_arguments = {
-             "ORIGINAL_PROMPT": original_prompt,
-             "OUTPUT": previous_output,
-             "ERRORS": errors,
-         }
-         return prompt_engine.prompt, prompt_arguments
+
+
+ REFINERS = dict(
+     none=JanusRefiner,
+     parser=FixParserExceptions,
+     reflection=ReflectionRefiner,
+     hallucination=HallucinationRefiner,
+ )
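
The `REFINERS` registry turns the refinement strategy into a lookup. A hedged wiring sketch, assuming some concrete `JanusParser` instance and a configured model (both stand-ins here):

```python
from janus.llm.models_info import load_model
from janus.refiners.refiner import REFINERS

llm = load_model("gpt-4o")  # placeholder model id
parser = ...                # any concrete JanusParser implementation

# "none" parses with no retry logic; "parser" retries on parse errors via
# RetryWithErrorOutputParser; "reflection" and "hallucination" loop
# critique-then-revise until the reflection chain answers "LGTM" or
# max_retries is exhausted.
refiner = REFINERS["reflection"](llm=llm, parser=parser, max_retries=3)
```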