janus-llm 4.1.0__py3-none-any.whl → 4.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,15 +1,241 @@
1
+ import json
2
+ import re
3
+ from copy import deepcopy
4
+
5
+ from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel
6
+
1
7
  from janus.converter.converter import Converter
8
+ from janus.language.block import TranslatedCodeBlock
2
9
  from janus.language.combine import JsonCombiner
3
- from janus.parsers.eval_parser import EvaluationParser
10
+ from janus.parsers.eval_parsers.incose_parser import IncoseParser
11
+ from janus.parsers.eval_parsers.inline_comment_parser import InlineCommentParser
4
12
  from janus.utils.logger import create_logger
5
13
 
6
14
  log = create_logger(__name__)
7
15
 
8
16
 
class Evaluator(Converter):
    """Evaluator

    A class that performs an LLM self evaluation on an input target, with an
    associated prompt.

    Current valid evaluation types:
    ['incose', 'comments']

    """

    def __init__(self, **kwargs) -> None:
        """Initialize the Evaluator class

        All keyword arguments are forwarded unchanged to ``Converter.__init__``.

        Arguments:
            model: The LLM to use for translation. If an OpenAI model, the
                `OPENAI_API_KEY` environment variable must be set and the
                `OPENAI_ORG_ID` environment variable should be set if needed.
            model_arguments: Additional arguments to pass to the LLM constructor.
            max_prompts: The maximum number of prompts to try before giving up.
        """
        super().__init__(**kwargs)
        # Evaluation results are combined as JSON objects
        self._combiner = JsonCombiner()
        # Re-load the converter parameters now that the combiner was replaced
        self._load_parameters()
41
+
42
+
class RequirementEvaluator(Evaluator):
    """INCOSE Requirement Evaluator

    A class that performs an LLM self evaluation on an input target,
    with an associated prompt.

    The evaluation prompts are for INCOSE evaluations.

    """

    def __init__(self, eval_items_per_request: int | None = None, **kwargs) -> None:
        """Initialize the RequirementEvaluator class

        Arguments:
            eval_items_per_request: The maximum number of requirements sent to
                the LLM in a single request. If None, each block is evaluated
                in one request regardless of how many requirements it has.
            model: The LLM to use for translation. If an OpenAI model, the
                `OPENAI_API_KEY` environment variable must be set and the
                `OPENAI_ORG_ID` environment variable should be set if needed.
            model_arguments: Additional arguments to pass to the LLM constructor.
            max_prompts: The maximum number of prompts to try before giving up.
        """
        super().__init__(**kwargs)
        self.eval_items_per_request = eval_items_per_request
        self._parser = IncoseParser()
        self.set_prompt("eval_prompts/incose")

    def _input_runnable(self) -> Runnable:
        """Build the runnable mapping a block to the prompt's input variables.

        The parser's ``parse_input`` produces a JSON string; its ``code`` and
        ``requirements`` entries become the ``SOURCE_CODE`` and
        ``REQUIREMENTS`` prompt variables, alongside the retriever's
        ``context``.
        """

        def _get_code(json_text: str) -> str:
            return json.loads(json_text)["code"]

        def _get_reqs(json_text: str) -> str:
            return json.dumps(json.loads(json_text)["requirements"])

        return RunnableLambda(self._parser.parse_input) | RunnableParallel(
            SOURCE_CODE=_get_code,
            REQUIREMENTS=_get_reqs,
            context=self._retriever,
        )

    def _add_translation(self, block: TranslatedCodeBlock):
        """Evaluate ``block``, batching its requirements if there are many.

        If ``eval_items_per_request`` is None, or the block holds no more
        requirements than that limit, the block is evaluated in a single
        request. Otherwise each batch of requirements is evaluated on a copy
        of the block. The per-batch JSON outputs are merged into
        ``block.text``, and retries, cost and processing time are summed.
        Blocks without requirements are marked complete with no text.
        """
        if block.translated:
            return

        if block.original.text is None:
            block.translated = True
            return

        if self.eval_items_per_request is None:
            return super()._add_translation(block)

        input_obj = json.loads(block.original.text)
        requirements = input_obj.get("requirements", [])

        # Requirements sometimes arrive wrapped in nested lists; flatten them.
        # The emptiness guard prevents an IndexError on inputs like [[]].
        while requirements and isinstance(requirements[0], list):
            requirements = [r for lst in requirements for r in lst]

        # Checked after flattening so that e.g. [[]] is treated as empty
        if not requirements:
            log.debug(f"[{block.name}] Skipping empty block")
            block.translated = True
            block.text = None
            block.complete = True
            return

        batch_size = self.eval_items_per_request
        if len(requirements) <= batch_size:
            input_obj["requirements"] = requirements
            block.original.text = json.dumps(input_obj)
            return super()._add_translation(block)

        block.processing_time = 0
        block.cost = 0
        block.retries = 0
        obj = {}
        for i in range(0, len(requirements), batch_size):
            # Build a new TranslatedBlock using the new working text
            working_requirements = requirements[i : i + batch_size]
            working_copy = deepcopy(block.original)
            working_obj = json.loads(working_copy.text)  # type: ignore
            working_obj["requirements"] = working_requirements
            working_copy.text = json.dumps(working_obj)
            working_block = TranslatedCodeBlock(working_copy, self._target_language)

            # Run the LLM on the working text
            super()._add_translation(working_block)

            # Update metadata to include for all runs
            block.retries += working_block.retries
            block.cost += working_block.cost
            block.processing_time += working_block.processing_time

            # Update the output text to merge this section's output in
            obj.update(json.loads(working_block.text))

        block.text = json.dumps(obj)
        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

        log.debug(
            f"[{block.name}] Output code:\n{json.dumps(json.loads(block.text), indent=2)}"
        )
142
+
143
+
class InlineCommentEvaluator(Evaluator):
    """Inline Comment Evaluator

    A class that performs an LLM self evaluation on inline comments,
    with an associated prompt.
    """

    def __init__(self, eval_items_per_request: int | None = None, **kwargs) -> None:
        """Initialize the InlineCommentEvaluator class

        Arguments:
            eval_items_per_request: The maximum number of comments evaluated
                in a single LLM request. If None, each block is evaluated in
                one request regardless of how many comments it has.
            model: The LLM to use for translation. If an OpenAI model, the
                `OPENAI_API_KEY` environment variable must be set and the
                `OPENAI_ORG_ID` environment variable should be set if needed.
            model_arguments: Additional arguments to pass to the LLM constructor.
            max_prompts: The maximum number of prompts to try before giving up.
        """
        # Evaluator.__init__ already installs a JsonCombiner and re-loads the
        # converter parameters, so that is not repeated here.
        super().__init__(**kwargs)
        self._parser = InlineCommentParser()
        self.set_prompt("eval_prompts/inline_comments")
        self.eval_items_per_request = eval_items_per_request

    def _add_translation(self, block: TranslatedCodeBlock):
        """Evaluate the comments of ``block``, batching them if there are many.

        Comments are identified by ``<INLINE_COMMENT xxxxxxxx>`` /
        ``<BLOCK_COMMENT xxxxxxxx>`` placeholders in the block text. If
        ``eval_items_per_request`` is None, or there are no more comments than
        that limit, the block is evaluated in a single request.

        Otherwise each batch of comments is evaluated on a copy of the full
        text in which all placeholders outside the batch are removed, together
        with the remainder of their line. Per-batch JSON outputs are merged
        into ``block.text``, and retries, cost and processing time are summed.
        Blocks without comments are marked complete with no text.
        """
        if block.translated:
            return

        if block.original.text is None:
            block.translated = True
            return

        if self.eval_items_per_request is None:
            return super()._add_translation(block)

        # A comment placeholder plus the rest of its line
        comment_re = re.compile(r"<(?:INLINE|BLOCK)_COMMENT \w{8}>.*$", flags=re.MULTILINE)
        text = block.original.text
        comments = list(comment_re.finditer(text))

        if not comments:
            log.info(f"[{block.name}] Skipping commentless block")
            block.translated = True
            block.text = None
            block.complete = True
            return

        batch_size = self.eval_items_per_request
        if len(comments) <= batch_size:
            return super()._add_translation(block)

        comment_group_indices = list(range(0, len(comments), batch_size))
        log.debug(
            f"[{block.name}] Block contains more than {self.eval_items_per_request}"
            f" comments, splitting {len(comments)} comments into"
            f" {len(comment_group_indices)} groups"
        )

        block.processing_time = 0
        block.cost = 0
        block.retries = 0
        obj = {}
        for i in comment_group_indices:
            # Split the text into the section containing comments of interest,
            # all the text prior to those comments, and all the text after them
            working_comments = comments[i : i + batch_size]
            start_idx = working_comments[0].start()
            end_idx = working_comments[-1].end()
            keeper = text[start_idx:end_idx]

            # Strip all comment placeholders outside of the section of interest
            prefix = comment_re.sub("", text[:start_idx])
            suffix = comment_re.sub("", text[end_idx:])

            # Build a new TranslatedBlock using the new working text
            working_copy = deepcopy(block.original)
            working_copy.text = prefix + keeper + suffix
            working_block = TranslatedCodeBlock(working_copy, self._target_language)

            # Run the LLM on the working text
            super()._add_translation(working_block)

            # Update metadata to include for all runs
            block.retries += working_block.retries
            block.cost += working_block.cost
            block.processing_time += working_block.processing_time

            # Update the output text to merge this section's output in
            obj.update(json.loads(working_block.text))

        block.text = json.dumps(obj)
        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

        log.debug(
            f"[{block.name}] Output code:\n{json.dumps(json.loads(block.text), indent=2)}"
        )
@@ -0,0 +1,27 @@
1
+ from pathlib import Path
2
+
3
+ from janus.converter.converter import Converter
4
+ from janus.language.block import TranslatedCodeBlock
5
+ from janus.parsers.partition_parser import PartitionParser
6
+ from janus.utils.logger import create_logger
7
+
8
+ log = create_logger(__name__)
9
+
10
+
class Partitioner(Converter):
    """Converter that splits source files using the ``partition`` prompt.

    Output stays in the source language (target language and suffix are set
    to the source's), and results are post-processed by a ``PartitionParser``.
    """

    def __init__(self, partition_token_limit: int, **kwargs):
        super().__init__(**kwargs)
        self.set_prompt("partition")

        # `_load_model()` runs before the parser is built, presumably so that
        # `self._llm` is available to hand to it.
        self._load_model()
        self._parser = PartitionParser(model=self._llm, token_limit=partition_token_limit)

        # No translation happens: mirror the source language and file suffix
        self._target_language, self._target_suffix = (
            self._source_language,
            self._source_suffix,
        )
        self._load_parameters()

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Write the parser's combined output for ``block`` to ``out_path``.

        Missing parent directories are created; the file is written as UTF-8.
        """
        partitioned = self._parser.parse_combined_output(block.complete_text)
        out_dir = out_path.parent
        out_dir.mkdir(exist_ok=True, parents=True)
        out_path.write_text(partitioned, encoding="utf-8")
@@ -20,7 +20,7 @@ class TestAlcSplitter(unittest.TestCase):
20
20
  def test_split(self):
21
21
  """Test the split method."""
22
22
  tree_root = self.splitter.split(self.test_file)
23
- self.assertAlmostEqual(tree_root.n_descendents, 32, delta=5)
23
+ self.assertAlmostEqual(tree_root.n_descendents, 16, delta=2)
24
24
  self.assertLessEqual(tree_root.max_tokens, self.splitter.max_tokens)
25
25
  self.assertFalse(tree_root.complete)
26
26
  self.combiner.combine_children(tree_root)
janus/language/alc/alc.py CHANGED
@@ -79,10 +79,15 @@ class AlcSplitter(TreeSitterSplitter):
79
79
  if len(sects) > 1:
80
80
  block.children = []
81
81
  for sect in sects:
82
- if sect[0].node_type in sect_types:
83
- sect_node = self.merge_nodes(sect)
84
- sect_node.children = sect
85
- sect_node.node_type = NodeType(str(sect[0].node_type)[:5])
82
+ node_type = sect[0].node_type
83
+ if node_type in sect_types:
84
+ if len(sect) == 1:
85
+ # Don't make a node its own child
86
+ sect_node = sect[0]
87
+ else:
88
+ sect_node = self.merge_nodes(sect)
89
+ sect_node.children = sect
90
+ sect_node.node_type = NodeType(str(node_type)[:5])
86
91
  block.children.append(sect_node)
87
92
  else:
88
93
  block.children.extend(sect)
janus/language/combine.py CHANGED
@@ -1,3 +1,5 @@
1
+ import re
2
+
1
3
  from janus.language.block import CodeBlock, TranslatedCodeBlock
2
4
  from janus.language.file import FileManager
3
5
  from janus.utils.logger import create_logger
@@ -90,3 +92,23 @@ class ChunkCombiner(Combiner):
90
92
  root: The functional code block to combine with its children.
91
93
  """
92
94
  return root
95
+
96
+
class PartitionCombiner(Combiner):
    @staticmethod
    def combine(root: CodeBlock) -> None:
        """A combiner which inserts partition tags between code blocks.

        Every leaf block gets a ``<JANUS_PARTITION>`` tag line appended to its
        suffix. The tree is then combined by the base ``Combiner.combine``.
        Finally, any run of tags left at the very end of the root's text and
        suffix is stripped, so no partition tag trails the output.
        """
        trailing_tags = re.compile(r"(?:\n<JANUS_PARTITION>\n)+$")

        # Tag every leaf. Each leaf is updated independently, so visiting
        # order does not matter; a LIFO stack is used because list.pop() is
        # O(1) whereas list.pop(0) is O(n).
        pending = [root]
        while pending:
            block = pending.pop()
            if block.children:
                pending.extend(block.children)
            else:
                block.affixes = (block.prefix, block.suffix + "\n<JANUS_PARTITION>\n")

        super(PartitionCombiner, PartitionCombiner).combine(root)
        root.text = trailing_tags.sub("", root.text)
        root.affixes = (root.prefix, trailing_tags.sub("", root.suffix))
@@ -275,42 +275,50 @@ class Splitter(FileManager):
275
275
 
276
276
  groups = [[n] for n in nodes]
277
277
  while len(groups) > 1 and min(adj_sums) <= self.max_tokens and any(merge_allowed):
278
- # Get the indices of the adjacent nodes that would result in the
279
- # smallest possible merged snippet. Ignore protected nodes.
278
+ # Get the index of the node that would result in the smallest
279
+ # merged snippet when merged with the node that follows it.
280
+ # Ignore protected nodes.
280
281
  mergeable_indices = compress(range(len(adj_sums)), merge_allowed)
281
- i0 = int(min(mergeable_indices, key=adj_sums.__getitem__))
282
- i1 = i0 + 1
282
+ C = int(min(mergeable_indices, key=adj_sums.__getitem__))
283
+
284
+ # C: Central index
285
+ # L: Index to the left
286
+ # R: Index to the right (to be merged in to C)
287
+ # N: Next index (to the right of R, the "new R")
288
+ L, R, N = C - 1, C + 1, C + 2
283
289
 
284
290
  # Recalculate the length. We can't simply use the adj_sum, because
285
291
  # it is an underestimate due to the adjoining suffix/prefix.
286
- central_node = groups[i0][-1]
287
- merged_text = "".join([text_chunks[i0], central_node.suffix, text_chunks[i1]])
292
+ central_node = groups[C][-1]
293
+ merged_text = "".join([text_chunks[C], central_node.suffix, text_chunks[R]])
288
294
  merged_text_length = self._count_tokens(merged_text)
289
295
 
290
296
  # If the true length of the merged pair is too long, don't merge them
291
297
  # Instead, correct the estimate, since shorter pairs may yet exist
292
298
  if merged_text_length > self.max_tokens:
293
- adj_sums[i0] = merged_text_length
299
+ adj_sums[C] = merged_text_length
294
300
  continue
295
301
 
296
302
  # Update adjacent sum estimates
297
- if i0 > 0:
298
- adj_sums[i0 - 1] += merged_text_length
299
- if i1 < len(adj_sums) - 1:
300
- adj_sums[i1 + 1] += merged_text_length
301
-
302
- if i0 > 0 and i1 < len(merge_allowed) - 1:
303
- if not (merge_allowed[i0 - 1] and merge_allowed[i1 + 1]):
304
- merge_allowed[i0 - 1] = merge_allowed[i1 + 1] = False
303
+ if L >= 0:
304
+ adj_sums[L] = lengths[L] + merged_text_length
305
+ if N < len(adj_sums):
306
+ adj_sums[R] = lengths[N] + merged_text_length
305
307
 
306
308
  # The potential merge length for this pair is removed
307
- adj_sums.pop(i0)
308
- merge_allowed.pop(i0)
309
+ adj_sums.pop(C)
310
+
311
+ # The merged-in node is removed from the protected list
312
+ # The merge_allowed list need not be updated - if the node now to
313
+ # its right is protected, the merge_allowed element corresponding
314
+ # to the merged neighbor will have been True, and now corresponds
315
+ # to the merged node.
316
+ merge_allowed.pop(C)
309
317
 
310
318
  # Merge the pair of node groups
311
- groups[i0 : i1 + 1] = [groups[i0] + groups[i1]]
312
- text_chunks[i0 : i1 + 1] = [merged_text]
313
- lengths[i0 : i1 + 1] = [merged_text_length]
319
+ groups[C:N] = [groups[C] + groups[R]]
320
+ text_chunks[C:N] = [merged_text]
321
+ lengths[C:N] = [merged_text_length]
314
322
 
315
323
  return groups
316
324
 
@@ -403,13 +411,13 @@ class Splitter(FileManager):
403
411
  self._split_into_lines(node)
404
412
 
405
413
  def _split_into_lines(self, node: CodeBlock):
406
- split_text = re.split(r"(\n+)", node.text)
414
+ split_text = list(re.split(r"(\n+)", node.text))
407
415
 
408
416
  # If the string didn't start/end with newlines, make sure to include
409
417
  # empty strings for the prefix/suffixes
410
- if split_text[0].strip("\n"):
418
+ if not re.match(r"^\n+$", split_text[0]):
411
419
  split_text = [""] + split_text
412
- if split_text[-1].strip("\n"):
420
+ if not re.match(r"^\n+$", split_text[-1]):
413
421
  split_text.append("")
414
422
  betweens = split_text[::2]
415
423
  lines = split_text[1::2]
@@ -154,7 +154,15 @@ class TreeSitterSplitter(Splitter):
154
154
  The pointer to the language.
155
155
  """
156
156
  lib = cdll.LoadLibrary(os.fspath(so_file))
157
- language_function = getattr(lib, f"tree_sitter_{self.language}")
157
+ # Added this try-except block to handle the case where the language is not
158
+ # supported in lowercase by the creator of the grammar. Ex: COBOL
159
+ # https://github.com/yutaro-sakamoto/tree-sitter-cobol/blob/main/grammar.js#L13
160
+ try:
161
+ language_function = getattr(lib, f"tree_sitter_{self.language}")
162
+ except AttributeError:
163
+ language = self.language.upper()
164
+ language_function = getattr(lib, f"tree_sitter_{language}")
165
+
158
166
  language_function.restype = c_void_p
159
167
  pointer = language_function()
160
168
  return pointer
janus/llm/models_info.py CHANGED
@@ -6,9 +6,13 @@ from typing import Callable, Protocol, TypeVar
6
6
  from dotenv import load_dotenv
7
7
  from langchain_community.llms import HuggingFaceTextGenInference
8
8
  from langchain_core.runnables import Runnable
9
- from langchain_openai import AzureChatOpenAI
9
+ from langchain_openai import AzureChatOpenAI, ChatOpenAI
10
10
 
11
- from janus.llm.model_callbacks import COST_PER_1K_TOKENS, azure_model_reroutes
11
+ from janus.llm.model_callbacks import (
12
+ COST_PER_1K_TOKENS,
13
+ azure_model_reroutes,
14
+ openai_model_reroutes,
15
+ )
12
16
  from janus.prompts.prompt import (
13
17
  ChatGptPromptEngine,
14
18
  ClaudePromptEngine,
@@ -90,6 +94,7 @@ claude_models = [
90
94
  "bedrock-claude-instant-v1",
91
95
  "bedrock-claude-haiku",
92
96
  "bedrock-claude-sonnet",
97
+ "bedrock-claude-sonnet-3.5",
93
98
  ]
94
99
  llama2_models = [
95
100
  "bedrock-llama2-70b",
@@ -126,7 +131,7 @@ bedrock_models = [
126
131
  all_models = [*azure_models, *bedrock_models]
127
132
 
128
133
  MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
129
- # "OpenAI": ChatOpenAI,
134
+ "OpenAI": ChatOpenAI,
130
135
  "HuggingFace": HuggingFaceTextGenInference,
131
136
  "Azure": AzureChatOpenAI,
132
137
  "Bedrock": Bedrock,
@@ -136,7 +141,7 @@ MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
136
141
 
137
142
 
138
143
  MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
139
- # **{m: ChatGptPromptEngine for m in openai_models},
144
+ **{m: ChatGptPromptEngine for m in openai_models},
140
145
  **{m: ChatGptPromptEngine for m in azure_models},
141
146
  **{m: ClaudePromptEngine for m in claude_models},
142
147
  **{m: Llama2PromptEngine for m in llama2_models},
@@ -147,12 +152,13 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
147
152
  }
148
153
 
149
154
  MODEL_ID_TO_LONG_ID = {
150
- # **{m: mr for m, mr in openai_model_reroutes.items()},
155
+ **{m: mr for m, mr in openai_model_reroutes.items()},
151
156
  **{m: mr for m, mr in azure_model_reroutes.items()},
152
157
  "bedrock-claude-v2": "anthropic.claude-v2",
153
158
  "bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
154
159
  "bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
155
160
  "bedrock-claude-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
161
+ "bedrock-claude-sonnet-3.5": "anthropic.claude-3-5-sonnet-20240620-v1:0",
156
162
  "bedrock-llama2-70b": "meta.llama2-70b-v1",
157
163
  "bedrock-llama2-70b-chat": "meta.llama2-70b-chat-v1",
158
164
  "bedrock-llama2-13b": "meta.llama2-13b-chat-v1",
@@ -179,7 +185,7 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
179
185
  MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
180
186
 
181
187
  MODEL_TYPES: dict[str, PromptEngine] = {
182
- # **{m: "OpenAI" for m in openai_models},
188
+ **{m: "OpenAI" for m in openai_models},
183
189
  **{m: "Azure" for m in azure_models},
184
190
  **{m: "BedrockChat" for m in bedrock_models},
185
191
  }
@@ -200,6 +206,7 @@ TOKEN_LIMITS: dict[str, int] = {
200
206
  "anthropic.claude-instant-v1": 100_000,
201
207
  "anthropic.claude-3-haiku-20240307-v1:0": 248_000,
202
208
  "anthropic.claude-3-sonnet-20240229-v1:0": 248_000,
209
+ "anthropic.claude-3-5-sonnet-20240620-v1:0": 200_000,
203
210
  "meta.llama2-70b-v1": 4096,
204
211
  "meta.llama2-70b-chat-v1": 4096,
205
212
  "meta.llama2-13b-chat-v1": 4096,
@@ -286,15 +293,16 @@ def load_model(model_id) -> JanusModel:
286
293
  # log.warning("Waiting 10 seconds...")
287
294
  # Give enough time for the user to read the warnings and cancel
288
295
  # time.sleep(10)
289
- raise DeprecationWarning("OpenAI models are no longer supported.")
296
+ # raise DeprecationWarning("OpenAI models are no longer supported.")
290
297
 
291
298
  elif model_type_name == "Azure":
292
299
  model_args.update(
293
- {
294
- "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
295
- "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
296
- "api_version": os.getenv("OPENAI_API_VERSION", "2024-02-01"),
297
- }
300
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
301
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
302
+ api_version=os.getenv("OPENAI_API_VERSION", "2024-02-01"),
303
+ azure_deployment=model_id,
304
+ request_timeout=3600,
305
+ max_tokens=4096,
298
306
  )
299
307
 
300
308
  model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
@@ -0,0 +1,134 @@
1
+ import json
2
+ import random
3
+ import uuid
4
+ from typing import List
5
+
6
+ from langchain.output_parsers import PydanticOutputParser
7
+ from langchain_core.exceptions import OutputParserException
8
+ from langchain_core.messages import BaseMessage
9
+ from langchain_core.pydantic_v1 import BaseModel, Field, validator
10
+
11
+ from janus.language.block import CodeBlock
12
+ from janus.parsers.parser import JanusParser
13
+ from janus.utils.logger import create_logger
14
+
15
+ log = create_logger(__name__)
16
+ RNG = random.Random()
17
+
18
+
class Criteria(BaseModel):
    # One pass/fail judgement for a single evaluation criterion, together with
    # the reasoning behind it.
    # NOTE: documented with comments rather than a docstring because pydantic
    # copies class docstrings into the generated JSON schema.
    reasoning: str = Field(description="A short explanation for the given assessment")
    # `description` must be passed by keyword: the first positional argument
    # of `Field` is the *default value*. Passing it positionally made `score`
    # optional, defaulting to this text (validators do not run on defaults),
    # and left the field without a description.
    score: str = Field(description="A simple `pass` or `fail`")

    @validator("score")
    def score_is_valid(cls, v: str):
        """Normalize ``score`` (lower-cased, stripped); require pass or fail."""
        v = v.lower().strip()
        if v not in {"pass", "fail"}:
            raise OutputParserException("Score must be either 'pass' or 'fail'")
        return v


class Requirement(BaseModel):
    # The evaluation of one requirement: its ID, its text, and one `Criteria`
    # judgement for each of the nine criteria C1-C9.
    requirement_id: str = Field(description="The 8-character comment ID")
    requirement: str = Field(description="The original requirement being evaluated")
    C1: Criteria
    C2: Criteria
    C3: Criteria
    C4: Criteria
    C5: Criteria
    C6: Criteria
    C7: Criteria
    C8: Criteria
    C9: Criteria


class RequirementList(BaseModel):
    # Top-level model output: a bare JSON list of `Requirement` evaluations
    # (pydantic v1 custom root type via `__root__`).
    __root__: List[Requirement] = Field(
        description=(
            "A list of requirement evaluations. Each element should include"
            " the requirement's 8-character ID in the `requirement_id` field,"
            " the original requirement in the 'requirement' field, "
            " and nine score objects corresponding to each criterion."
        )
    )
54
+
55
+
class IncoseParser(JanusParser, PydanticOutputParser):
    """Parser for INCOSE requirement evaluations.

    ``parse_input`` assigns every requirement a short ID derived
    deterministically from the input and rewrites the requirements as a text
    listing. ``parse`` validates the LLM's JSON list of evaluations against
    those IDs and returns a JSON object keyed by requirement ID.
    """

    # Maps generated requirement IDs to requirements; set by `parse_input`
    requirements: dict[str, str]

    def __init__(self):
        PydanticOutputParser.__init__(
            self,
            pydantic_object=RequirementList,
            requirements={},
        )

    def parse_input(self, block: CodeBlock) -> str:
        """Tag each requirement of ``block`` with a unique 8-character ID.

        Returns the block's JSON text with ``requirements`` replaced by a
        single string of ``Requirement <id> : <requirement>`` entries. The ID
        mapping is stored on ``self.requirements`` for use by :meth:`parse`.
        """
        # TODO: Perform comment stripping/placeholding here rather than in script
        text = super().parse_input(block)
        # Seed from the input so identical inputs always yield identical IDs
        RNG.seed(text)

        obj = json.loads(text)

        # NOTE(review): requirements were observed to sometimes arrive as a
        # nested ("double") list; they are used as-is here — confirm.
        reqs = obj["requirements"]

        # Generate a unique ID for each requirement. IDs are kept in generation
        # order rather than read back out of a set: set iteration order for
        # strings changes between interpreter runs (hash randomization), which
        # would make the ID -> requirement assignment nondeterministic.
        req_ids: list[str] = []
        seen_ids: set[str] = set()
        while len(req_ids) < len(reqs):
            rid = str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8]
            if rid not in seen_ids:
                seen_ids.add(rid)
                req_ids.append(rid)

        self.requirements = dict(zip(req_ids, reqs))
        reqs_str = "\n\n".join(
            f"Requirement {rid} : {req}" for rid, req in self.requirements.items()
        )
        obj["requirements"] = reqs_str
        return json.dumps(obj)

    def parse(self, text: str | BaseMessage) -> str:
        """Validate the LLM's evaluation list and key it by requirement ID.

        Evaluations for IDs that were never issued are dropped. Each
        evaluation's ``requirement`` is replaced by the stored original text,
        and its ``requirement_id`` field is removed.

        Raises:
            OutputParserException: if no JSON list is found, it does not
                validate as a ``RequirementList``, or an evaluation is missing
                for any ID produced by the last ``parse_input`` call.
        """
        if isinstance(text, BaseMessage):
            text = str(text.content)

        # Strip everything outside the outermost JSON list
        begin, end = text.find("["), text.rfind("]")
        if begin == -1 or end < begin:
            log.debug(f"No JSON list found. Output:\n{text}")
            raise OutputParserException("Got invalid return object. No JSON list found.")
        text = text[begin : end + 1]

        try:
            out: RequirementList = super().parse(text)
        except json.JSONDecodeError as e:
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e

        evals: dict[str, dict] = {c.requirement_id: c.dict() for c in out.__root__}

        seen_keys = set(evals)
        expected_keys = set(self.requirements)
        missing_keys = expected_keys - seen_keys
        invalid_keys = seen_keys - expected_keys
        if invalid_keys:
            log.debug(f"Invalid keys: {invalid_keys}")
        if missing_keys:
            log.debug(f"Missing keys: {missing_keys}")
            raise OutputParserException(
                f"Got invalid return object. Missing the following expected "
                f"keys: {missing_keys}"
            )

        # Unknown IDs are tolerated but their evaluations are discarded
        for key in invalid_keys:
            del evals[key]

        for rid, evaluation in evals.items():
            # Use the stored original requirement rather than the model's echo
            evaluation["requirement"] = self.requirements[rid]
            evaluation.pop("requirement_id")

        return json.dumps(evals)

    def parse_combined_output(self, text: str) -> str:
        """Merge newline-separated JSON objects into one JSON object.

        Later objects override earlier ones on duplicate keys; blank lines are
        ignored, and blank input yields ``"{}"``.
        """
        output_obj = {}
        for line in text.split("\n"):
            line = line.strip()
            if line:
                output_obj.update(json.loads(line))
        return json.dumps(output_obj)