janus-llm 3.5.2__py3-none-any.whl → 4.0.0__py3-none-any.whl

janus/llm/models_info.py CHANGED
@@ -1,14 +1,15 @@
  import json
  import os
+ import time
  from pathlib import Path
- from typing import Any, Callable
+ from typing import Protocol, TypeVar

  from dotenv import load_dotenv
  from langchain_community.llms import HuggingFaceTextGenInference
- from langchain_core.language_models import BaseLanguageModel
+ from langchain_core.runnables import Runnable
  from langchain_openai import ChatOpenAI

- from janus.llm.model_callbacks import COST_PER_1K_TOKENS
+ from janus.llm.model_callbacks import COST_PER_1K_TOKENS, openai_model_reroutes
  from janus.prompts.prompt import (
      ChatGptPromptEngine,
      ClaudePromptEngine,
@@ -43,17 +44,34 @@ except ImportError:
  )


+ ModelType = TypeVar(
+     "ModelType",
+     ChatOpenAI,
+     HuggingFaceTextGenInference,
+     Bedrock,
+     BedrockChat,
+     HuggingFacePipeline,
+ )
+
+
+ class JanusModelProtocol(Protocol):
+     model_id: str
+     model_type_name: str
+     token_limit: int
+     input_token_cost: float
+     output_token_cost: float
+     prompt_engine: type[PromptEngine]
+
+     def get_num_tokens(self, text: str) -> int:
+         ...
+
+
+ class JanusModel(Runnable, JanusModelProtocol):
+     ...
+
+
  load_dotenv()

- openai_model_reroutes = {
-     "gpt-4o": "gpt-4o-2024-05-13",
-     "gpt-4o-mini": "gpt-4o-mini",
-     "gpt-4": "gpt-4-0613",
-     "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
-     "gpt-4-turbo-preview": "gpt-4-0125-preview",
-     "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
-     "gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
- }

  openai_models = [
      "gpt-4o",
@@ -104,24 +122,15 @@ bedrock_models = [
  ]
  all_models = [*openai_models, *bedrock_models]

- MODEL_TYPE_CONSTRUCTORS: dict[str, Callable[[Any], BaseLanguageModel]] = {
+ MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
      "OpenAI": ChatOpenAI,
      "HuggingFace": HuggingFaceTextGenInference,
+     "Bedrock": Bedrock,
+     "BedrockChat": BedrockChat,
+     "HuggingFaceLocal": HuggingFacePipeline,
  }

- try:
-     MODEL_TYPE_CONSTRUCTORS.update(
-         {
-             "HuggingFaceLocal": HuggingFacePipeline.from_model_id,
-             "Bedrock": Bedrock,
-             "BedrockChat": BedrockChat,
-         }
-     )
- except NameError:
-     pass
-
-
- MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
+ MODEL_PROMPT_ENGINES: dict[str, type[PromptEngine]] = {
      **{m: ChatGptPromptEngine for m in openai_models},
      **{m: ClaudePromptEngine for m in claude_models},
      **{m: Llama2PromptEngine for m in llama2_models},
@@ -131,11 +140,6 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
      **{m: MistralPromptEngine for m in mistral_models},
  }

- _open_ai_defaults: dict[str, str] = {
-     "openai_api_key": os.getenv("OPENAI_API_KEY"),
-     "openai_organization": os.getenv("OPENAI_ORG_ID"),
- }
-
  MODEL_ID_TO_LONG_ID = {
      **{m: mr for m, mr in openai_model_reroutes.items()},
      "bedrock-claude-v2": "anthropic.claude-v2",
@@ -167,7 +171,7 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())

  MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"

- MODEL_TYPES: dict[str, PromptEngine] = {
+ MODEL_TYPES: dict[str, str] = {
      **{m: "OpenAI" for m in openai_models},
      **{m: "BedrockChat" for m in bedrock_models},
  }
@@ -210,47 +214,90 @@ def get_available_model_names() -> list[str]:
      return avaialable_models


- def load_model(
-     user_model_name: str,
- ) -> tuple[BaseLanguageModel, str, int, dict[str, float]]:
+ def load_model(model_id) -> JanusModel:
      if not MODEL_CONFIG_DIR.exists():
          MODEL_CONFIG_DIR.mkdir(parents=True)
-     model_config_file = MODEL_CONFIG_DIR / f"{user_model_name}.json"
-     if not model_config_file.exists():
-         log.warning(
-             f"Model {user_model_name} not found in user-defined models, searching "
-             f"default models for {user_model_name}."
-         )
-         model_id = user_model_name
-         if user_model_name not in DEFAULT_MODELS:
-             message = (
-                 f"Model {user_model_name} not found in default models. Make sure to run "
-                 "`janus llm add` first."
-             )
-             log.error(message)
-             raise ValueError(message)
-         model_config = {
-             "model_type": MODEL_TYPES[model_id],
-             "model_id": model_id,
-             "model_args": MODEL_DEFAULT_ARGUMENTS[model_id],
-             "token_limit": TOKEN_LIMITS.get(MODEL_ID_TO_LONG_ID[model_id], 4096),
-             "model_cost": COST_PER_1K_TOKENS.get(
-                 MODEL_ID_TO_LONG_ID[model_id], {"input": 0, "output": 0}
-             ),
-         }
-         with open(model_config_file, "w") as f:
-             json.dump(model_config, f)
-     else:
+     model_config_file = MODEL_CONFIG_DIR / f"{model_id}.json"
+
+     if model_config_file.exists():
+         log.info(f"Loading {model_id} from {model_config_file}.")
          with open(model_config_file, "r") as f:
              model_config = json.load(f)
-     model_constructor = MODEL_TYPE_CONSTRUCTORS[model_config["model_type"]]
-     model_args = model_config["model_args"]
-     if model_config["model_type"] == "OpenAI":
-         model_args.update(_open_ai_defaults)
-     model = model_constructor(**model_args)
-     return (
-         model,
-         model_config["model_id"],
-         model_config["token_limit"],
-         model_config["model_cost"],
+         model_type_name = model_config["model_type"]
+         model_id = model_config["model_id"]
+         model_args = model_config["model_args"]
+         token_limit = model_config["token_limit"]
+         input_token_cost = model_config["model_cost"]["input"]
+         output_token_cost = model_config["model_cost"]["output"]
+
+     elif model_id in DEFAULT_MODELS:
+         model_long_id = MODEL_ID_TO_LONG_ID[model_id]
+         model_type_name = MODEL_TYPES[model_id]
+         model_args = MODEL_DEFAULT_ARGUMENTS[model_id]
+
+         token_limit = 0
+         input_token_cost = 0.0
+         output_token_cost = 0.0
+         if model_long_id in TOKEN_LIMITS:
+             token_limit = TOKEN_LIMITS[model_long_id]
+         if model_long_id in COST_PER_1K_TOKENS:
+             token_costs = COST_PER_1K_TOKENS[model_long_id]
+             input_token_cost = token_costs["input"]
+             output_token_cost = token_costs["output"]
+
+     else:
+         model_list = "\n\t".join(DEFAULT_MODELS)
+         message = (
+             f"Model {model_id} not found in user-defined model directory "
+             f"({MODEL_CONFIG_DIR}), and is not a default model. Valid default "
+             f"models:\n\t{model_list}\n"
+             f"To use a custom model, first run `janus llm add`."
+         )
+         log.error(message)
+         raise ValueError(message)
+
+     if model_type_name == "HuggingFaceLocal":
+         model = HuggingFacePipeline.from_model_id(
+             model_id=model_id,
+             task="text-generation",
+             model_kwargs=model_args,
+         )
+         model_args.update(pipeline=model.pipeline)
+
+     elif model_type_name == "OpenAI":
+         model_args.update(
+             openai_api_key=str(os.getenv("OPENAI_API_KEY")),
+             openai_organization=str(os.getenv("OPENAI_ORG_ID")),
+         )
+         log.warning("Do NOT use this model in sensitive environments!")
+         log.warning("If you would like to cancel, please press Ctrl+C.")
+         log.warning("Waiting 10 seconds...")
+         # Give enough time for the user to read the warnings and cancel
+         time.sleep(10)
+
+     model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
+     prompt_engine = MODEL_PROMPT_ENGINES[model_id]
+
+     class JanusModel(model_type):
+         model_id: str
+         short_model_id: str
+         model_type_name: str
+         token_limit: int
+         input_token_cost: float
+         output_token_cost: float
+         prompt_engine: type[PromptEngine]
+
+     model_args.update(
+         model_id=MODEL_ID_TO_LONG_ID[model_id],
+         short_model_id=model_id,
+     )
+
+     return JanusModel(
+         model_type_name=model_type_name,
+         token_limit=token_limit,
+         input_token_cost=input_token_cost,
+         output_token_cost=output_token_cost,
+         prompt_engine=prompt_engine,
+         **model_args,
      )
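
In 4.0.0, `load_model` returns a single `JanusModel` that carries its own metadata instead of a `(model, name, token_limit, cost)` tuple. A minimal usage sketch, assuming the package and API keys are configured; the model name and prompt string are illustrative, not taken from the diff:

    from janus.llm.models_info import load_model

    llm = load_model("gpt-4o")  # any default or previously added model name

    # Metadata that previously arrived as separate tuple elements:
    print(llm.model_id, llm.model_type_name, llm.token_limit)

    # COST_PER_1K_TOKENS rates are per 1,000 tokens, so a rough input-cost estimate:
    n_tokens = llm.get_num_tokens("Translate this COBOL paragraph to Python.")
    print(f"~${n_tokens / 1000 * llm.input_token_cost:.4f} for {n_tokens} input tokens")

Note that loading an OpenAI-type model now logs the sensitivity warnings shown above and sleeps for 10 seconds before constructing the model.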
janus/metrics/metric.py CHANGED
@@ -8,6 +8,7 @@ import typer
  from typing_extensions import Annotated

  from janus.llm import load_model
+ from janus.llm.model_callbacks import COST_PER_1K_TOKENS
  from janus.metrics.cli import evaluate
  from janus.metrics.file_pairing import FILE_PAIRING_METHODS
  from janus.metrics.splitting import SPLITTING_METHODS
@@ -135,7 +136,7 @@ def metric(
          **kwargs,
      ):
          out = []
-         llm, _, token_limit, model_cost = load_model(llm_name)
+         llm = load_model(llm_name)
          if json_file_name is not None:
              with open(json_file_name, "r") as f:
                  json_obj = json.load(f)
@@ -171,8 +172,8 @@ def metric(
                      out_file=out_file,
                      lang=language,
                      llm=llm,
-                     token_limit=token_limit,
-                     model_cost=model_cost,
+                     token_limit=llm.token_limit,
+                     model_cost=COST_PER_1K_TOKENS[llm.model_id],
                  )
              else:
                  raise ValueError(
@@ -187,8 +188,8 @@ def metric(
                      progress,
                      language,
                      llm,
-                     token_limit,
-                     model_cost,
+                     llm.token_limit,
+                     COST_PER_1K_TOKENS[llm.model_id],
                      *args,
                      **kwargs,
                  )
@@ -199,8 +200,8 @@ def metric(
                      progress,
                      language,
                      llm,
-                     token_limit,
-                     model_cost,
+                     llm.token_limit,
+                     COST_PER_1K_TOKENS[llm.model_id],
                      *args,
                      **kwargs,
                  )
@@ -296,7 +297,7 @@ def metric(
          *args,
          **kwargs,
      ):
-         llm, _, token_limit, model_cost = load_model(llm_name)
+         llm = load_model(llm_name)
          if json_file_name is not None:
              with open(json_file_name, "r") as f:
                  json_obj = json.load(f)
@@ -328,8 +329,8 @@ def metric(
                      out_file=out_file,
                      lang=language,
                      llm=llm,
-                     token_limit=token_limit,
-                     model_cost=model_cost,
+                     token_limit=llm.token_limit,
+                     model_cost=COST_PER_1K_TOKENS[llm.model_id],
                  )
              else:
                  raise ValueError(
@@ -344,8 +345,8 @@ def metric(
                      progress,
                      language,
                      llm,
-                     token_limit,
-                     model_cost,
+                     llm.token_limit,
+                     COST_PER_1K_TOKENS[llm.model_id],
                      *args,
                      **kwargs,
                  )
@@ -356,8 +357,8 @@ def metric(
                      progress,
                      language,
                      llm,
-                     token_limit,
-                     model_cost,
+                     llm.token_limit,
+                     COST_PER_1K_TOKENS[llm.model_id],
                      *args,
                      **kwargs,
                  )
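
Every call site above follows the same replacement pattern; it works because `load_model` sets `llm.model_id` to the long (rerouted) ID, which is the key `COST_PER_1K_TOKENS` uses. The pattern in isolation (variable names as in `metric`):

    llm = load_model(llm_name)
    token_limit = llm.token_limit
    model_cost = COST_PER_1K_TOKENS[llm.model_id]  # e.g. {"input": ..., "output": ...}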
janus/parsers/uml.py CHANGED
@@ -1,7 +1,7 @@
  import re
  import subprocess  # nosec
  from pathlib import Path
- from typing import List, Tuple
+ from tempfile import NamedTemporaryFile

  from langchain_core.exceptions import OutputParserException
  from langchain_core.messages import BaseMessage
@@ -13,39 +13,76 @@ log = create_logger(__name__)


  class UMLSyntaxParser(CodeParser):
-     def _get_uml_output(self, file: Path) -> Tuple[str, str]:
-         # NOTE: running subprocess with shell=False, added nosec to label that we know
-         # risk exists
+     def _check_plantuml(self, text: str) -> None:
+         # Leading newlines can break the parser, remove them
+         text = text.replace("\\n", "\n").strip()
+
+         # Write the text to a temporary file (automatically deleted)
+         file = NamedTemporaryFile()
+         fname = file.name
+         with open(fname, "w") as fin:
+             fin.write(text)
+
          try:
              plantuml_path = Path.home().expanduser() / ".janus/lib/plantuml.jar"
+             # NOTE: running subprocess with shell=False, added nosec to
+             # label that we know risk exists
              res = subprocess.run(
-                 ["java", "-jar", plantuml_path, file],
+                 ["java", "-jar", plantuml_path, fname],
                  stdout=subprocess.PIPE,
                  stderr=subprocess.PIPE,
              )  # nosec
              stdout = res.stdout.decode("utf-8")
              stderr = res.stderr.decode("utf-8")
          except FileNotFoundError:
-             log.warning("Plant UML executable not found, skipping syntax check")
-             stdout = ""
-             stderr = ""
-         return stdout, stderr
+             err_txt = (
+                 "PlantUML executable not found. Either choose a different parser"
+                 " or install with `bash scripts/install_plantuml.sh`. Java and"
+                 " graphviz are dependencies for the tool, they must also be installed."
+             )
+             log.error(err_txt)
+             raise Exception(err_txt)
+
+         # Check for bad outputs, raise OutputParserExceptions if so
+         if "Error" in stderr or "Error" in stdout:
+             err_txt = "Received UML parsing error(s)."
+
+             line_nos = self._get_error_lines(stderr) + self._get_error_lines(stdout)
+             lines = text.split("\n")
+             for i in line_nos:
+                 i0 = max(0, i - 3)
+                 i1 = min(len(lines) - 1, i + 2)
+                 err_lines = [
+                     f"> {lines[j]}" if j == i - 1 else f"  {lines[j]}"
+                     for j in range(i0, i1)
+                 ]
+                 if i0:
+                     err_lines.insert(0, "  ...")
+                 if i1 < (len(lines) - 1):
+                     err_lines.append("  ...")

-     def _get_errs(self, s: str) -> List[str]:
-         return [x.group() for x in re.finditer(r"Error (.*)\n", s)]
+                 err_txt += f"\nError located at line {i} must be fixed:\n"
+                 err_txt += "\n".join(err_lines)
+             log.warning(err_txt)
+             raise OutputParserException(err_txt)
+
+         if "Warning" in stdout or "Warning" in stderr:
+             err_txt = "Received UML parsing warning (often due to missing PLANTUML)."
+             if stderr:
+                 err_txt += f"\nSTDERR:\n```\n{stderr.strip()}\n```\n"
+             if stdout:
+                 err_txt += f"\nSTDOUT:\n```\n{stdout.strip()}\n```\n"
+
+             log.warning(err_txt)
+             raise OutputParserException(err_txt)
+
+     def _get_error_lines(self, s: str) -> list[int]:
+         return [int(x.group(1)) for x in re.finditer(r"Error line (\d+) in file:", s)]
+
+     def _get_warns(self, s: str) -> list[str]:
+         return [x.group() for x in re.finditer(r"Warning: (.*)\n", s)]

      def parse(self, text: str | BaseMessage) -> str:
          text = super().parse(text)
-         janus_path = Path.home().expanduser() / Path(".janus")
-         if not janus_path.exists():
-             janus_path.mkdir()
-         temp_file_path = janus_path / "tmp.txt"
-         with open(temp_file_path, "w") as f:
-             f.write(text)
-         uml_std_out, uml_std_err = self._get_uml_output(temp_file_path)
-         uml_errs = self._get_errs(uml_std_out) + self._get_errs(uml_std_err)
-         if len(uml_errs) > 0:
-             raise OutputParserException(
-                 "Error: Received UML Errors:\n" + "\n".join(uml_errs)
-             )
+         self._check_plantuml(text)
          return text
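
The new `_get_error_lines` relies on PlantUML reporting syntax errors in the form `Error line N in file:`. A self-contained check of that regex; the sample stderr text below is an assumed PlantUML message format, not captured from a real run:

    import re

    def get_error_lines(s: str) -> list[int]:
        # Same pattern as UMLSyntaxParser._get_error_lines
        return [int(x.group(1)) for x in re.finditer(r"Error line (\d+) in file:", s)]

    sample = "Error line 7 in file: /tmp/tmp1234\nSome diagram description error\n"
    print(get_error_lines(sample))  # [7]

The surrounding code then quotes the offending line with two lines of context on either side, so the refinement chains receive a pinpointed error rather than raw PlantUML output.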
janus/refiners/refiner.py CHANGED
@@ -1,73 +1,115 @@
- from langchain_core.prompts import ChatPromptTemplate
+ from typing import Any

- from janus.llm.models_info import MODEL_PROMPT_ENGINES
+ from langchain.output_parsers import RetryWithErrorOutputParser
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompt_values import PromptValue
+ from langchain_core.runnables import RunnableSerializable

+ from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
+ from janus.parsers.parser import JanusParser
+ from janus.utils.logger import create_logger

- class Refiner:
-     def refine(
-         self,
-         original_prompt: str,
-         previous_prompt: str,
-         previous_output: str,
-         errors: str,
-         **kwargs,
-     ) -> tuple[ChatPromptTemplate, dict[str, str]]:
-         """Creates a new prompt based on feedback from original results
-
-         Arguments:
-             original_prompt: original prompt used to produce output
-             original_output: origial output of llm
-             errors: list of errors detected by parser
-
-         Returns:
-             Tuple of new prompt and prompt arguments
-         """
+ log = create_logger(__name__)
+
+
+ class JanusRefiner(JanusParser):
+     parser: JanusParser
+
+     def parse_runnable(self, input: dict[str, Any]) -> Any:
+         return self.parse_completion(**input)
+
+     def parse_completion(self, completion: str, **kwargs) -> Any:
+         return self.parser.parse(completion)
+
+     def parse(self, text: str) -> str:
          raise NotImplementedError


- class BasicRefiner(Refiner):
+ class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
+     def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
+         retry_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+             source_language="text",
+             prompt_template="refinement/fix_exceptions",
+         ).prompt
+         chain = retry_prompt | llm | StrOutputParser()
+         RetryWithErrorOutputParser.__init__(
+             self, parser=parser, retry_chain=chain, max_retries=max_retries
+         )
+
+     def parse_completion(
+         self, completion: str, prompt_value: PromptValue, **kwargs
+     ) -> Any:
+         return self.parse_with_prompt(completion, prompt_value=prompt_value)
+
+
+ class ReflectionRefiner(JanusRefiner):
+     max_retries: int
+     reflection_chain: RunnableSerializable
+     revision_chain: RunnableSerializable
+
      def __init__(
          self,
-         prompt_name: str,
-         model_id: str,
-         source_language: str,
-     ) -> None:
-         """Basic refiner, asks llm to fix output of previous prompt given errors
-
-         Arguments:
-             prompt_name: refinement prompt name to use
-             model_id: ID of the llm to use. Found in models_info.py
-             source_language: source_langauge to use
-         """
-         self._prompt_name = prompt_name
-         self._model_id = model_id
-         self._source_language = source_language
-
-     def refine(
-         self,
-         original_prompt: str,
-         previous_prompt: str,
-         previous_output: str,
-         errors: str,
-         **kwargs,
-     ) -> tuple[ChatPromptTemplate, dict[str, str]]:
-         """Creates a new prompt based on feedback from original results
-
-         Arguments:
-             original_prompt: original prompt used to produce output
-             original_output: origial output of llm
-             errors: list of errors detected by parser
-
-         Returns:
-             Tuple of new prompt and prompt arguments
-         """
-         prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
-             prompt_template=self._prompt_name,
-             source_language=self._source_language,
+         llm: JanusModel,
+         parser: JanusParser,
+         max_retries: int,
+         prompt_template_name: str = "refinement/reflection",
+     ):
+         reflection_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+             source_language="text",
+             prompt_template=prompt_template_name,
+         ).prompt
+         revision_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+             source_language="text",
+             prompt_template="refinement/revision",
+         ).prompt
+
+         reflection_chain = reflection_prompt | llm | StrOutputParser()
+         revision_chain = revision_prompt | llm | StrOutputParser()
+         super().__init__(
+             reflection_chain=reflection_chain,
+             revision_chain=revision_chain,
+             parser=parser,
+             max_retries=max_retries,
+         )
+
+     def parse_completion(
+         self, completion: str, prompt_value: PromptValue, **kwargs
+     ) -> Any:
+         for retry_number in range(self.max_retries):
+             reflection = self.reflection_chain.invoke(
+                 dict(
+                     prompt=prompt_value.to_string(),
+                     completion=completion,
+                 )
+             )
+             if reflection.strip() == "LGTM":
+                 return self.parser.parse(completion)
+             if not retry_number:
+                 log.info(f"Completion:\n{completion}")
+             log.info(f"Reflection:\n{reflection}")
+             completion = self.revision_chain.invoke(
+                 dict(
+                     prompt=prompt_value.to_string(),
+                     completion=completion,
+                     reflection=reflection,
+                 )
+             )
+             log.info(f"Revision:\n{completion}")
+
+         return self.parser.parse(completion)
+
+
+ class HallucinationRefiner(ReflectionRefiner):
+     def __init__(self, **kwargs):
+         super().__init__(
+             prompt_template_name="refinement/hallucination",
+             **kwargs,
          )
-         prompt_arguments = {
-             "ORIGINAL_PROMPT": original_prompt,
-             "OUTPUT": previous_output,
-             "ERRORS": errors,
-         }
-         return prompt_engine.prompt, prompt_arguments
+
+
+ REFINERS = dict(
+     none=JanusRefiner,
+     parser=FixParserExceptions,
+     reflection=ReflectionRefiner,
+     hallucination=HallucinationRefiner,
+ )
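
The `REFINERS` registry lets callers pick a strategy by name. A wiring sketch, assuming a concrete `JanusParser` instance and a prompt/completion pair from an earlier chain call; `code_parser`, `raw_output`, and `prompt_value` are hypothetical placeholders:

    from janus.llm.models_info import load_model
    from janus.refiners.refiner import REFINERS

    llm = load_model("gpt-4o")  # illustrative model name
    refiner = REFINERS["reflection"](llm=llm, parser=code_parser, max_retries=3)

    # The refiner critiques the completion and revises it up to max_retries times,
    # stopping early if the reflection chain answers "LGTM", then parses the result.
    parsed = refiner.parse_completion(completion=raw_output, prompt_value=prompt_value)

Because every refiner is itself a `JanusParser`, it can be dropped into a chain wherever the bare parser used to sit.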
@@ -0,0 +1,42 @@
+ from langchain_core.retrievers import BaseRetriever
+ from langchain_core.runnables import Runnable, RunnableConfig
+
+ from janus.language.block import CodeBlock
+
+
+ class JanusRetriever(Runnable):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def invoke(
+         self, input: CodeBlock, config: RunnableConfig | None = None, **kwargs
+     ) -> dict:
+         kwargs.update(context=self.get_context(input))
+         return kwargs
+
+     def get_context(self, code_block: CodeBlock) -> str:
+         return ""
+
+
+ class ActiveUsingsRetriever(JanusRetriever):
+     def get_context(self, code_block: CodeBlock) -> str:
+         context = "\n".join(
+             f"{context_tag}: {context}"
+             for context_tag, context in code_block.context_tags.items()
+         )
+         return f"You may use the following additional context: {context}"
+
+
+ class TextSearchRetriever(JanusRetriever):
+     retriever: BaseRetriever
+
+     def __init__(self, retriever: BaseRetriever):
+         super().__init__()
+         self.retriever = retriever
+
+     def get_context(self, code_block: CodeBlock) -> str:
+         if code_block.text is None:
+             return ""
+         docs = self.retriever.invoke(code_block.text)
+         context = "\n\n".join(doc.page_content for doc in docs)
+         return f"You may use the following additional context: {context}"
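
This new module (its path is not shown in this diff) folds retrieval context into translation chains. A sketch of `TextSearchRetriever` over a toy vector store, using the class defined above; the FAISS and embedding choices are illustrative, and `code_block` stands in for a real janus `CodeBlock`:

    from langchain_community.vectorstores import FAISS
    from langchain_openai import OpenAIEmbeddings

    store = FAISS.from_texts(
        ["int add(int a, int b) { return a + b; }"],  # toy reference corpus
        embedding=OpenAIEmbeddings(),
    )
    retriever = TextSearchRetriever(retriever=store.as_retriever())

    # invoke() uses the block's text as the search query and returns the kwargs
    # dict with a "context" entry added for the downstream prompt.
    kwargs = retriever.invoke(code_block)
    print(kwargs["context"])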
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: janus-llm
- Version: 3.5.2
+ Version: 4.0.0
  Summary: A transcoding library using LLMs.
  Home-page: https://github.com/janus-llm/janus-llm
  License: Apache 2.0