janus-llm 3.5.2__py3-none-any.whl → 4.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
janus/llm/models_info.py CHANGED
@@ -1,14 +1,15 @@
 import json
 import os
+import time
 from pathlib import Path
-from typing import Any, Callable
+from typing import Protocol, TypeVar

 from dotenv import load_dotenv
 from langchain_community.llms import HuggingFaceTextGenInference
-from langchain_core.language_models import BaseLanguageModel
+from langchain_core.runnables import Runnable
 from langchain_openai import ChatOpenAI

-from janus.llm.model_callbacks import COST_PER_1K_TOKENS
+from janus.llm.model_callbacks import COST_PER_1K_TOKENS, openai_model_reroutes
 from janus.prompts.prompt import (
     ChatGptPromptEngine,
     ClaudePromptEngine,
@@ -43,17 +44,34 @@ except ImportError:
     )


+ModelType = TypeVar(
+    "ModelType",
+    ChatOpenAI,
+    HuggingFaceTextGenInference,
+    Bedrock,
+    BedrockChat,
+    HuggingFacePipeline,
+)
+
+
+class JanusModelProtocol(Protocol):
+    model_id: str
+    model_type_name: str
+    token_limit: int
+    input_token_cost: float
+    output_token_cost: float
+    prompt_engine: type[PromptEngine]
+
+    def get_num_tokens(self, text: str) -> int:
+        ...
+
+
+class JanusModel(Runnable, JanusModelProtocol):
+    ...
+
+
 load_dotenv()

-openai_model_reroutes = {
-    "gpt-4o": "gpt-4o-2024-05-13",
-    "gpt-4o-mini": "gpt-4o-mini",
-    "gpt-4": "gpt-4-0613",
-    "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
-    "gpt-4-turbo-preview": "gpt-4-0125-preview",
-    "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
-    "gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
-}

 openai_models = [
     "gpt-4o",
@@ -104,24 +122,15 @@ bedrock_models = [
 ]
 all_models = [*openai_models, *bedrock_models]

-MODEL_TYPE_CONSTRUCTORS: dict[str, Callable[[Any], BaseLanguageModel]] = {
+MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
     "OpenAI": ChatOpenAI,
     "HuggingFace": HuggingFaceTextGenInference,
+    "Bedrock": Bedrock,
+    "BedrockChat": BedrockChat,
+    "HuggingFaceLocal": HuggingFacePipeline,
 }

-try:
-    MODEL_TYPE_CONSTRUCTORS.update(
-        {
-            "HuggingFaceLocal": HuggingFacePipeline.from_model_id,
-            "Bedrock": Bedrock,
-            "BedrockChat": BedrockChat,
-        }
-    )
-except NameError:
-    pass
-
-
-MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
+MODEL_PROMPT_ENGINES: dict[str, type[PromptEngine]] = {
     **{m: ChatGptPromptEngine for m in openai_models},
     **{m: ClaudePromptEngine for m in claude_models},
     **{m: Llama2PromptEngine for m in llama2_models},
@@ -131,11 +140,6 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
     **{m: MistralPromptEngine for m in mistral_models},
 }

-_open_ai_defaults: dict[str, str] = {
-    "openai_api_key": os.getenv("OPENAI_API_KEY"),
-    "openai_organization": os.getenv("OPENAI_ORG_ID"),
-}
-
 MODEL_ID_TO_LONG_ID = {
     **{m: mr for m, mr in openai_model_reroutes.items()},
     "bedrock-claude-v2": "anthropic.claude-v2",
@@ -167,7 +171,7 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())

 MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"

-MODEL_TYPES: dict[str, PromptEngine] = {
+MODEL_TYPES: dict[str, str] = {
     **{m: "OpenAI" for m in openai_models},
     **{m: "BedrockChat" for m in bedrock_models},
 }
@@ -210,47 +214,90 @@ def get_available_model_names() -> list[str]:
     return avaialable_models


-def load_model(
-    user_model_name: str,
-) -> tuple[BaseLanguageModel, str, int, dict[str, float]]:
+def load_model(model_id) -> JanusModel:
     if not MODEL_CONFIG_DIR.exists():
         MODEL_CONFIG_DIR.mkdir(parents=True)
-    model_config_file = MODEL_CONFIG_DIR / f"{user_model_name}.json"
-    if not model_config_file.exists():
-        log.warning(
-            f"Model {user_model_name} not found in user-defined models, searching "
-            f"default models for {user_model_name}."
-        )
-        model_id = user_model_name
-        if user_model_name not in DEFAULT_MODELS:
-            message = (
-                f"Model {user_model_name} not found in default models. Make sure to run "
-                "`janus llm add` first."
-            )
-            log.error(message)
-            raise ValueError(message)
-        model_config = {
-            "model_type": MODEL_TYPES[model_id],
-            "model_id": model_id,
-            "model_args": MODEL_DEFAULT_ARGUMENTS[model_id],
-            "token_limit": TOKEN_LIMITS.get(MODEL_ID_TO_LONG_ID[model_id], 4096),
-            "model_cost": COST_PER_1K_TOKENS.get(
-                MODEL_ID_TO_LONG_ID[model_id], {"input": 0, "output": 0}
-            ),
-        }
-        with open(model_config_file, "w") as f:
-            json.dump(model_config, f)
-    else:
+    model_config_file = MODEL_CONFIG_DIR / f"{model_id}.json"
+
+    if model_config_file.exists():
+        log.info(f"Loading {model_id} from {model_config_file}.")
         with open(model_config_file, "r") as f:
             model_config = json.load(f)
-    model_constructor = MODEL_TYPE_CONSTRUCTORS[model_config["model_type"]]
-    model_args = model_config["model_args"]
-    if model_config["model_type"] == "OpenAI":
-        model_args.update(_open_ai_defaults)
-    model = model_constructor(**model_args)
-    return (
-        model,
-        model_config["model_id"],
-        model_config["token_limit"],
-        model_config["model_cost"],
+        model_type_name = model_config["model_type"]
+        model_id = model_config["model_id"]
+        model_args = model_config["model_args"]
+        token_limit = model_config["token_limit"]
+        input_token_cost = model_config["model_cost"]["input"]
+        output_token_cost = model_config["model_cost"]["output"]
+
+    elif model_id in DEFAULT_MODELS:
+        model_id = model_id
+        model_long_id = MODEL_ID_TO_LONG_ID[model_id]
+        model_type_name = MODEL_TYPES[model_id]
+        model_args = MODEL_DEFAULT_ARGUMENTS[model_id]
+
+        token_limit = 0
+        input_token_cost = 0.0
+        output_token_cost = 0.0
+        if model_long_id in TOKEN_LIMITS:
+            token_limit = TOKEN_LIMITS[model_long_id]
+        if model_long_id in COST_PER_1K_TOKENS:
+            token_limits = COST_PER_1K_TOKENS[model_long_id]
+            input_token_cost = token_limits["input"]
+            output_token_cost = token_limits["output"]
+
+    else:
+        model_list = "\n\t".join(DEFAULT_MODELS)
+        message = (
+            f"Model {model_id} not found in user-defined model directory "
+            f"({MODEL_CONFIG_DIR}), and is not a default model. Valid default "
+            f"models:\n\t{model_list}\n"
+            f"To use a custom model, first run `janus llm add`."
+        )
+        log.error(message)
+        raise ValueError(message)
+
+    if model_type_name == "HuggingFaceLocal":
+        model = HuggingFacePipeline.from_model_id(
+            model_id=model_id,
+            task="text-generation",
+            model_kwargs=model_args,
+        )
+        model_args.update(pipeline=model.pipeline)
+
+    elif model_type_name == "OpenAI":
+        model_args.update(
+            openai_api_key=str(os.getenv("OPENAI_API_KEY")),
+            openai_organization=str(os.getenv("OPENAI_ORG_ID")),
+        )
+        log.warning("Do NOT use this model in sensitive environments!")
+        log.warning("If you would like to cancel, please press Ctrl+C.")
+        log.warning("Waiting 10 seconds...")
+        # Give enough time for the user to read the warnings and cancel
+        time.sleep(10)
+
+    model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
+    prompt_engine = MODEL_PROMPT_ENGINES[model_id]
+
+    class JanusModel(model_type):
+        model_id: str
+        short_model_id: str
+        model_type_name: str
+        token_limit: int
+        input_token_cost: float
+        output_token_cost: float
+        prompt_engine: type[PromptEngine]
+
+    model_args.update(
+        model_id=MODEL_ID_TO_LONG_ID[model_id],
+        short_model_id=model_id,
+    )
+
+    return JanusModel(
+        model_type_name=model_type_name,
+        token_limit=token_limit,
+        input_token_cost=input_token_cost,
+        output_token_cost=output_token_cost,
+        prompt_engine=prompt_engine,
+        **model_args,
     )
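
The net effect of this refactor: `load_model` no longer returns a 4-tuple; it returns a single `JanusModel` that is both a LangChain `Runnable` and a carrier of Janus metadata. A minimal migration sketch (the model id `"gpt-4o"` is illustrative; the attribute names follow `JanusModelProtocol` above):

```python
from janus.llm.models_info import load_model

# 3.5.x: llm, model_id, token_limit, model_cost = load_model("gpt-4o")
# 4.0.0: one JanusModel object carries the model plus its metadata
llm = load_model("gpt-4o")

print(llm.model_id)           # long id, e.g. "gpt-4o-2024-05-13"
print(llm.short_model_id)     # the id you asked for, "gpt-4o"
print(llm.token_limit)        # context window size in tokens
print(llm.input_token_cost)   # cost per 1k input tokens
print(llm.output_token_cost)  # cost per 1k output tokens

# JanusModel subclasses Runnable, so it can be invoked or piped like any
# LangChain chat model, e.g.: llm.invoke("hello")
```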
janus/metrics/metric.py CHANGED
@@ -8,6 +8,7 @@ import typer
 from typing_extensions import Annotated

 from janus.llm import load_model
+from janus.llm.model_callbacks import COST_PER_1K_TOKENS
 from janus.metrics.cli import evaluate
 from janus.metrics.file_pairing import FILE_PAIRING_METHODS
 from janus.metrics.splitting import SPLITTING_METHODS
@@ -135,7 +136,7 @@ def metric(
         **kwargs,
     ):
         out = []
-        llm, _, token_limit, model_cost = load_model(llm_name)
+        llm = load_model(llm_name)
         if json_file_name is not None:
             with open(json_file_name, "r") as f:
                 json_obj = json.load(f)
@@ -171,8 +172,8 @@ def metric(
                 out_file=out_file,
                 lang=language,
                 llm=llm,
-                token_limit=token_limit,
-                model_cost=model_cost,
+                token_limit=llm.token_limit,
+                model_cost=COST_PER_1K_TOKENS[llm.model_id],
             )
         else:
             raise ValueError(
@@ -187,8 +188,8 @@ def metric(
                 progress,
                 language,
                 llm,
-                token_limit,
-                model_cost,
+                llm.token_limit,
+                COST_PER_1K_TOKENS[llm.model_id],
                 *args,
                 **kwargs,
             )
@@ -199,8 +200,8 @@ def metric(
                 progress,
                 language,
                 llm,
-                token_limit,
-                model_cost,
+                llm.token_limit,
+                COST_PER_1K_TOKENS[llm.model_id],
                 *args,
                 **kwargs,
             )
@@ -296,7 +297,7 @@ def metric(
         *args,
         **kwargs,
     ):
-        llm, _, token_limit, model_cost = load_model(llm_name)
+        llm = load_model(llm_name)
         if json_file_name is not None:
             with open(json_file_name, "r") as f:
                 json_obj = json.load(f)
@@ -328,8 +329,8 @@ def metric(
                 out_file=out_file,
                 lang=language,
                 llm=llm,
-                token_limit=token_limit,
-                model_cost=model_cost,
+                token_limit=llm.token_limit,
+                model_cost=COST_PER_1K_TOKENS[llm.model_id],
             )
         else:
             raise ValueError(
@@ -344,8 +345,8 @@ def metric(
                 progress,
                 language,
                 llm,
-                token_limit,
-                model_cost,
+                llm.token_limit,
+                COST_PER_1K_TOKENS[llm.model_id],
                 *args,
                 **kwargs,
             )
@@ -356,8 +357,8 @@ def metric(
                 progress,
                 language,
                 llm,
-                token_limit,
-                model_cost,
+                llm.token_limit,
+                COST_PER_1K_TOKENS[llm.model_id],
                 *args,
                 **kwargs,
             )
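
Call sites in `metric.py` follow the same pattern: the tuple unpacking is gone, the token limit is read off the returned model, and per-token cost is looked up in `COST_PER_1K_TOKENS` keyed by the model's long id. A sketch of the new shape (with `llm_name` standing in for the CLI-provided model name; the surrounding `metric` decorator machinery is elided):

```python
from janus.llm import load_model
from janus.llm.model_callbacks import COST_PER_1K_TOKENS

llm = load_model(llm_name)  # llm_name, e.g. "gpt-4o", comes from the CLI
token_limit = llm.token_limit
model_cost = COST_PER_1K_TOKENS[llm.model_id]  # {"input": ..., "output": ...}
```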
janus/parsers/uml.py CHANGED
@@ -1,7 +1,7 @@
 import re
 import subprocess # nosec
 from pathlib import Path
-from typing import List, Tuple
+from tempfile import NamedTemporaryFile

 from langchain_core.exceptions import OutputParserException
 from langchain_core.messages import BaseMessage
@@ -13,39 +13,76 @@ log = create_logger(__name__)


 class UMLSyntaxParser(CodeParser):
-    def _get_uml_output(self, file: Path) -> Tuple[str, str]:
-        # NOTE: running subprocess with shell=False, added nosec to label that we know
-        # risk exists
+    def _check_plantuml(self, text: str) -> None:
+        # Leading newlines can break the parser, remove them
+        text = text.replace("\\n", "\n").strip()
+
+        # Write the text to a temporary file (automatically deleted)
+        file = NamedTemporaryFile()
+        fname = file.name
+        with open(fname, "w") as fin:
+            fin.write(text)
+
         try:
             plantuml_path = Path.home().expanduser() / ".janus/lib/plantuml.jar"
+            # NOTE: running subprocess with shell=False, added nosec to
+            # label that we know risk exists
             res = subprocess.run(
-                ["java", "-jar", plantuml_path, file],
+                ["java", "-jar", plantuml_path, fname],
                 stdout=subprocess.PIPE,
                 stderr=subprocess.PIPE,
             ) # nosec
             stdout = res.stdout.decode("utf-8")
             stderr = res.stderr.decode("utf-8")
         except FileNotFoundError:
-            log.warning("Plant UML executable not found, skipping syntax check")
-            stdout = ""
-            stderr = ""
-        return stdout, stderr
+            err_txt = (
+                "Plant UML executable not found. Either choose a different parser"
+                " or install with `bash scripts/install_plantuml.sh`. Java and"
+                " graphviz are dependencies for the tool, they must also be installed."
+            )
+            log.error(err_txt)
+            raise Exception(err_txt)
+
+        # Check for bad outputs, raise OutputParserExceptions if so
+        if "Error" in stderr or "Error" in stdout:
+            err_txt = "Recieved UML parsing error(s)."
+
+            line_nos = self._get_error_lines(stderr) + self._get_error_lines(stdout)
+            lines = text.split("\n")
+            for i in line_nos:
+                i0 = max(0, i - 3)
+                i1 = min(len(lines) - 1, i + 2)
+                err_lines = [
+                    f"> {lines[j]}" if j == i - 1 else f" {lines[j]}"
+                    for j in range(i0, i1)
+                ]
+                if i0:
+                    err_lines.insert(0, " ...")
+                if i1 < (len(lines) - 1):
+                    err_lines.append(" ...")

-    def _get_errs(self, s: str) -> List[str]:
-        return [x.group() for x in re.finditer(r"Error (.*)\n", s)]
+                err_txt += f"\nError located at line {i} must be fixed:\n"
+                err_txt += "\n".join(err_lines)
+            log.warning(err_txt)
+            raise OutputParserException(err_txt)
+
+        if "Warning" in stdout or "Warning" in stderr:
+            err_txt = "Recieved UML parsing warning (often due to missing PLANTUML)."
+            if stderr:
+                err_txt += f"\nSTDERR:\n```\n{stderr.strip()}\n```\n"
+            if stdout:
+                err_txt += f"\nSTDOUT:\n```\n{stdout.strip()}\n```\n"
+
+            log.warning(err_txt)
+            raise OutputParserException(err_txt)
+
+    def _get_error_lines(self, s: str) -> list[int]:
+        return [int(x.group(1)) for x in re.finditer(r"Error line (\d+) in file:", s)]
+
+    def _get_warns(self, s: str) -> list[str]:
+        return [x.group() for x in re.finditer(r"Warning: (.*)\n", s)]

     def parse(self, text: str | BaseMessage) -> str:
         text = super().parse(text)
-        janus_path = Path.home().expanduser() / Path(".janus")
-        if not janus_path.exists():
-            janus_path.mkdir()
-        temp_file_path = janus_path / "tmp.txt"
-        with open(temp_file_path, "w") as f:
-            f.write(text)
-        uml_std_out, uml_std_err = self._get_uml_output(temp_file_path)
-        uml_errs = self._get_errs(uml_std_out) + self._get_errs(uml_std_err)
-        if len(uml_errs) > 0:
-            raise OutputParserException(
-                "Error: Received UML Errors:\n" + "\n".join(uml_errs)
-            )
+        self._check_plantuml(text)
         return text
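
The rewritten parser now pinpoints the offending PlantUML lines instead of silently skipping the check when output looks bad. A standalone sketch of the line-number scraping (the sample message below is synthetic, not verbatim PlantUML output):

```python
import re

def get_error_lines(s: str) -> list[int]:
    # Same regex as UMLSyntaxParser._get_error_lines in the diff above
    return [int(x.group(1)) for x in re.finditer(r"Error line (\d+) in file:", s)]

sample = "Error line 7 in file: /tmp/tmpx1y2z3\n"  # synthetic example
print(get_error_lines(sample))  # [7] -> a few lines around line 7 get quoted
```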
janus/refiners/refiner.py CHANGED
@@ -1,73 +1,115 @@
-from langchain_core.prompts import ChatPromptTemplate
+from typing import Any

-from janus.llm.models_info import MODEL_PROMPT_ENGINES
+from langchain.output_parsers import RetryWithErrorOutputParser
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompt_values import PromptValue
+from langchain_core.runnables import RunnableSerializable

+from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
+from janus.parsers.parser import JanusParser
+from janus.utils.logger import create_logger

-class Refiner:
-    def refine(
-        self,
-        original_prompt: str,
-        previous_prompt: str,
-        previous_output: str,
-        errors: str,
-        **kwargs,
-    ) -> tuple[ChatPromptTemplate, dict[str, str]]:
-        """Creates a new prompt based on feedback from original results
-
-        Arguments:
-            original_prompt: original prompt used to produce output
-            original_output: origial output of llm
-            errors: list of errors detected by parser
-
-        Returns:
-            Tuple of new prompt and prompt arguments
-        """
+log = create_logger(__name__)
+
+
+class JanusRefiner(JanusParser):
+    parser: JanusParser
+
+    def parse_runnable(self, input: dict[str, Any]) -> Any:
+        return self.parse_completion(**input)
+
+    def parse_completion(self, completion: str, **kwargs) -> Any:
+        return self.parser.parse(completion)
+
+    def parse(self, text: str) -> str:
         raise NotImplementedError


-class BasicRefiner(Refiner):
+class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
+    def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
+        retry_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+            source_language="text",
+            prompt_template="refinement/fix_exceptions",
+        ).prompt
+        chain = retry_prompt | llm | StrOutputParser()
+        RetryWithErrorOutputParser.__init__(
+            self, parser=parser, retry_chain=chain, max_retries=max_retries
+        )
+
+    def parse_completion(
+        self, completion: str, prompt_value: PromptValue, **kwargs
+    ) -> Any:
+        return self.parse_with_prompt(completion, prompt_value=prompt_value)
+
+
+class ReflectionRefiner(JanusRefiner):
+    max_retries: int
+    reflection_chain: RunnableSerializable
+    revision_chain: RunnableSerializable
+
     def __init__(
         self,
-        prompt_name: str,
-        model_id: str,
-        source_language: str,
-    ) -> None:
-        """Basic refiner, asks llm to fix output of previous prompt given errors
-
-        Arguments:
-            prompt_name: refinement prompt name to use
-            model_id: ID of the llm to use. Found in models_info.py
-            source_language: source_langauge to use
-        """
-        self._prompt_name = prompt_name
-        self._model_id = model_id
-        self._source_language = source_language
-
-    def refine(
-        self,
-        original_prompt: str,
-        previous_prompt: str,
-        previous_output: str,
-        errors: str,
-        **kwargs,
-    ) -> tuple[ChatPromptTemplate, dict[str, str]]:
-        """Creates a new prompt based on feedback from original results
-
-        Arguments:
-            original_prompt: original prompt used to produce output
-            original_output: origial output of llm
-            errors: list of errors detected by parser
-
-        Returns:
-            Tuple of new prompt and prompt arguments
-        """
-        prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
-            prompt_template=self._prompt_name,
-            source_language=self._source_language,
+        llm: JanusModel,
+        parser: JanusParser,
+        max_retries: int,
+        prompt_template_name: str = "refinement/reflection",
+    ):
+        reflection_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+            source_language="text",
+            prompt_template=prompt_template_name,
+        ).prompt
+        revision_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+            source_language="text",
+            prompt_template="refinement/revision",
+        ).prompt
+
+        reflection_chain = reflection_prompt | llm | StrOutputParser()
+        revision_chain = revision_prompt | llm | StrOutputParser()
+        super().__init__(
+            reflection_chain=reflection_chain,
+            revision_chain=revision_chain,
+            parser=parser,
+            max_retries=max_retries,
+        )
+
+    def parse_completion(
+        self, completion: str, prompt_value: PromptValue, **kwargs
+    ) -> Any:
+        for retry_number in range(self.max_retries):
+            reflection = self.reflection_chain.invoke(
+                dict(
+                    prompt=prompt_value.to_string(),
+                    completion=completion,
+                )
+            )
+            if reflection.strip() == "LGTM":
+                return self.parser.parse(completion)
+            if not retry_number:
+                log.info(f"Completion:\n{completion}")
+            log.info(f"Reflection:\n{reflection}")
+            completion = self.revision_chain.invoke(
+                dict(
+                    prompt=prompt_value.to_string(),
+                    completion=completion,
+                    reflection=reflection,
+                )
+            )
+            log.info(f"Revision:\n{completion}")
+
+        return self.parser.parse(completion)
+
+
+class HallucinationRefiner(ReflectionRefiner):
+    def __init__(self, **kwargs):
+        super().__init__(
+            prompt_template_name="refinement/hallucination",
+            **kwargs,
         )
-        prompt_arguments = {
-            "ORIGINAL_PROMPT": original_prompt,
-            "OUTPUT": previous_output,
-            "ERRORS": errors,
-        }
-        return prompt_engine.prompt, prompt_arguments
+
+
+REFINERS = dict(
+    none=JanusRefiner,
+    parser=FixParserExceptions,
+    reflection=ReflectionRefiner,
+    hallucination=HallucinationRefiner,
+)
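
Refiners are no longer prompt factories; they are `JanusParser` wrappers selected by name from the `REFINERS` registry and run as part of a chain. A hedged construction sketch (assumes a loaded `JanusModel` and some concrete `JanusParser` instance; `max_retries=2` is arbitrary):

```python
from janus.llm.models_info import load_model
from janus.refiners.refiner import REFINERS

llm = load_model("gpt-4o")            # illustrative default model id
refiner_cls = REFINERS["reflection"]  # or "none", "parser", "hallucination"

# refiner = refiner_cls(llm=llm, parser=some_parser, max_retries=2)
# parsed = refiner.parse_completion(completion=raw_output, prompt_value=pv)
#
# "none" maps to the base JanusRefiner, which parses once with no retry loop.
```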
janus/retrievers/retriever.py ADDED
@@ -0,0 +1,42 @@
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import Runnable, RunnableConfig
+
+from janus.language.block import CodeBlock
+
+
+class JanusRetriever(Runnable):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def invoke(
+        self, input: CodeBlock, config: RunnableConfig | None = None, **kwargs
+    ) -> dict:
+        kwargs.update(context=self.get_context(input))
+        return kwargs
+
+    def get_context(self, code_block: CodeBlock) -> str:
+        return ""
+
+
+class ActiveUsingsRetriever(JanusRetriever):
+    def get_context(self, code_block: CodeBlock) -> str:
+        context = "\n".join(
+            f"{context_tag}: {context}"
+            for context_tag, context in code_block.context_tags.items()
+        )
+        return f"You may use the following additional context: {context}"
+
+
+class TextSearchRetriever(JanusRetriever):
+    retriever: BaseRetriever
+
+    def __init__(self, retriever: BaseRetriever):
+        super().__init__()
+        self.retriever = retriever
+
+    def get_context(self, code_block: CodeBlock) -> str:
+        if code_block.text is None:
+            return ""
+        docs = self.retriever.invoke(code_block.text)
+        context = "\n\n".join(doc.page_content for doc in docs)
+        return f"You may use the following additional context: {context}"
janus_llm-3.5.2.dist-info/METADATA → janus_llm-4.0.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: janus-llm
-Version: 3.5.2
+Version: 4.0.0
 Summary: A transcoding library using LLMs.
 Home-page: https://github.com/janus-llm/janus-llm
 License: Apache 2.0