janus-llm 4.1.0__py3-none-any.whl → 4.3.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,15 +1,241 @@
1
+ import json
2
+ import re
3
+ from copy import deepcopy
4
+
5
+ from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel
6
+
1
7
  from janus.converter.converter import Converter
8
+ from janus.language.block import TranslatedCodeBlock
2
9
  from janus.language.combine import JsonCombiner
3
- from janus.parsers.eval_parser import EvaluationParser
10
+ from janus.parsers.eval_parsers.incose_parser import IncoseParser
11
+ from janus.parsers.eval_parsers.inline_comment_parser import InlineCommentParser
4
12
  from janus.utils.logger import create_logger
5
13
 
6
14
  log = create_logger(__name__)
7
15
 
8
16
 
9
17
class Evaluator(Converter):
    """Evaluator

    A class that performs an LLM self evaluation on an input target,
    with an associated prompt.

    Current valid evaluation types:
    ['incose', 'comments']

    """

    def __init__(self, **kwargs) -> None:
        """Initialize the Evaluator class

        All keyword arguments are forwarded unchanged to the parent
        `Converter` constructor.

        Arguments:
            model: The LLM to use for translation. If an OpenAI model, the
                `OPENAI_API_KEY` environment variable must be set and the
                `OPENAI_ORG_ID` environment variable should be set if needed.
            model_arguments: Additional arguments to pass to the LLM constructor.
            max_prompts: The maximum number of prompts to try before giving up.
        """
        super().__init__(**kwargs)
        # Evaluation output is JSON, so use the JSON combiner, then reload the
        # converter parameters so the new combiner is taken into account.
        # NOTE(review): what `_load_parameters` rebuilds is defined in
        # `Converter`, outside this file - confirm the ordering requirement.
        self._combiner = JsonCombiner()
        self._load_parameters()
41
+
42
+
43
class RequirementEvaluator(Evaluator):
    """INCOSE Requirement Evaluator

    A class that performs an LLM self evaluation on an input target,
    with an associated prompt.

    The evaluation prompts are for Incose Evaluations

    """

    def __init__(self, eval_items_per_request: int | None = None, **kwargs) -> None:
        """Initialize the RequirementEvaluator class

        Arguments:
            eval_items_per_request: The maximum number of requirements sent to
                the LLM in a single request. If `None`, each block is evaluated
                in one request regardless of how many requirements it holds.
            model: The LLM to use for translation. If an OpenAI model, the
                `OPENAI_API_KEY` environment variable must be set and the
                `OPENAI_ORG_ID` environment variable should be set if needed.
            model_arguments: Additional arguments to pass to the LLM constructor.
            max_prompts: The maximum number of prompts to try before giving up.
        """
        super().__init__(**kwargs)
        self.eval_items_per_request = eval_items_per_request
        self._parser = IncoseParser()
        self.set_prompt("eval_prompts/incose")

    def _input_runnable(self) -> Runnable:
        """Build the runnable that turns a block into the prompt inputs.

        The block is first passed through the parser's `parse_input`, whose
        JSON output is then fanned out into the `SOURCE_CODE`, `REQUIREMENTS`
        and `context` prompt variables.
        """

        def _get_code(json_text: str) -> str:
            return json.loads(json_text)["code"]

        def _get_reqs(json_text: str) -> str:
            return json.dumps(json.loads(json_text)["requirements"])

        return RunnableLambda(self._parser.parse_input) | RunnableParallel(
            SOURCE_CODE=_get_code,
            REQUIREMENTS=_get_reqs,
            context=self._retriever,
        )

    def _add_translation(self, block: TranslatedCodeBlock):
        """Evaluate the requirements in `block`, batching them if necessary.

        When `eval_items_per_request` is set and the block holds more
        requirements than that, the requirements are split into consecutive
        batches, each batch is evaluated through the parent implementation on
        a copy of the block, and the per-batch JSON objects are merged into
        `block.text`. Retries, cost and processing time are summed over all
        batches.

        Arguments:
            block: The block to evaluate. Its `original.text` is expected to be
                a JSON object with a `requirements` entry.
        """
        if block.translated:
            return

        if block.original.text is None:
            block.translated = True
            return

        if self.eval_items_per_request is None:
            return super()._add_translation(block)

        input_obj = json.loads(block.original.text)
        requirements = input_obj.get("requirements", [])

        # For some reason requirements objects are in nested lists?
        # Flatten *before* the emptiness check, and guard the index, so that
        # inputs such as `[[]]` are skipped instead of raising IndexError.
        while requirements and isinstance(requirements[0], list):
            requirements = [r for lst in requirements for r in lst]

        if not requirements:
            log.debug(f"[{block.name}] Skipping empty block")
            block.translated = True
            block.text = None
            block.complete = True
            return

        if len(requirements) <= self.eval_items_per_request:
            # Small enough for one request; store the flattened list back so
            # the parent implementation sees the normalized input.
            input_obj["requirements"] = requirements
            block.original.text = json.dumps(input_obj)
            return super()._add_translation(block)

        block.processing_time = 0
        block.cost = 0
        block.retries = 0
        obj = {}
        for i in range(0, len(requirements), self.eval_items_per_request):
            # Build a new TranslatedBlock using the new working text
            working_requirements = requirements[i : i + self.eval_items_per_request]
            working_copy = deepcopy(block.original)
            working_obj = json.loads(working_copy.text)  # type: ignore
            working_obj["requirements"] = working_requirements
            working_copy.text = json.dumps(working_obj)
            working_block = TranslatedCodeBlock(working_copy, self._target_language)

            # Run the LLM on the working text
            super()._add_translation(working_block)

            # Update metadata to include for all runs
            block.retries += working_block.retries
            block.cost += working_block.cost
            block.processing_time += working_block.processing_time

            # Update the output text to merge this section's output in
            obj.update(json.loads(working_block.text))

        block.text = json.dumps(obj)
        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

        log.debug(f"[{block.name}] Output code:\n{json.dumps(obj, indent=2)}")
142
+
143
+
144
class InlineCommentEvaluator(Evaluator):
    """Inline Comment Evaluator

    Runs an LLM self evaluation over the comment placeholders found in a
    block, using the inline-comment evaluation prompt.
    """

    def __init__(self, eval_items_per_request: int | None = None, **kwargs) -> None:
        """Initialize the InlineCommentEvaluator class

        Arguments:
            eval_items_per_request: The maximum number of comments evaluated in
                a single LLM request. If `None`, blocks are never split.
            model: The LLM to use for translation. If an OpenAI model, the
                `OPENAI_API_KEY` environment variable must be set and the
                `OPENAI_ORG_ID` environment variable should be set if needed.
            model_arguments: Additional arguments to pass to the LLM constructor.
            max_prompts: The maximum number of prompts to try before giving up.
        """
        super().__init__(**kwargs)
        self._combiner = JsonCombiner()
        self._load_parameters()
        self._parser = InlineCommentParser()
        self.set_prompt("eval_prompts/inline_comments")
        self.eval_items_per_request = eval_items_per_request

    def _add_translation(self, block: TranslatedCodeBlock):
        """Evaluate the comments in `block`, in groups when there are many.

        Blocks with more comment placeholders than `eval_items_per_request`
        are evaluated one group at a time: each request sees the full text
        with every placeholder outside the current group removed. The JSON
        objects returned for the groups are merged into `block.text`, and
        retries, cost and processing time are summed over the groups.
        """
        if block.translated:
            return

        source_text = block.original.text
        if source_text is None:
            block.translated = True
            return

        group_size = self.eval_items_per_request
        if group_size is None:
            return super()._add_translation(block)

        # A placeholder tag plus whatever follows it on the same line
        placeholder = re.compile(
            r"<(?:INLINE|BLOCK)_COMMENT \w{8}>.*$", flags=re.MULTILINE
        )
        found = list(placeholder.finditer(source_text))

        if not found:
            log.info(f"[{block.name}] Skipping commentless block")
            block.translated = True
            block.text = None
            block.complete = True
            return

        if len(found) <= group_size:
            return super()._add_translation(block)

        comment_group_indices = list(range(0, len(found), group_size))
        log.debug(
            f"[{block.name}] Block contains more than {self.eval_items_per_request}"
            f" comments, splitting {len(found)} comments into"
            f" {len(comment_group_indices)} groups"
        )

        block.processing_time = 0
        block.cost = 0
        block.retries = 0
        merged = {}
        for first in comment_group_indices:
            group = found[first : first + group_size]

            # The span from the first comment of the group to the end of its
            # last comment is kept verbatim; placeholders are stripped from
            # everything before and after that span.
            span_start, span_end = group[0].start(), group[-1].end()
            before = placeholder.sub("", source_text[:span_start])
            after = placeholder.sub("", source_text[span_end:])

            # Wrap the reduced text in a fresh block for this request
            reduced = deepcopy(block.original)
            reduced.text = before + source_text[span_start:span_end] + after
            group_block = TranslatedCodeBlock(reduced, self._target_language)

            # Run the LLM on the reduced text
            super()._add_translation(group_block)

            # Accumulate run metadata across all groups
            block.retries += group_block.retries
            block.cost += group_block.cost
            block.processing_time += group_block.processing_time

            # Fold this group's evaluations into the combined result
            merged.update(json.loads(group_block.text))

        block.text = json.dumps(merged)
        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

        log.debug(
            f"[{block.name}] Output code:\n{json.dumps(json.loads(block.text), indent=2)}"
        )
@@ -0,0 +1,27 @@
1
+ from pathlib import Path
2
+
3
+ from janus.converter.converter import Converter
4
+ from janus.language.block import TranslatedCodeBlock
5
+ from janus.parsers.partition_parser import PartitionParser
6
+ from janus.utils.logger import create_logger
7
+
8
+ log = create_logger(__name__)
9
+
10
+
11
class Partitioner(Converter):
    """A converter that uses the `partition` prompt and a `PartitionParser`.

    The target language and file suffix are set to those of the source, so
    output files keep the source's language and suffix.
    """

    def __init__(self, partition_token_limit: int, **kwargs):
        """Initialize the Partitioner class

        Arguments:
            partition_token_limit: Passed to `PartitionParser` as its
                `token_limit`.
            **kwargs: Forwarded unchanged to the `Converter` constructor.
        """
        super().__init__(**kwargs)
        self.set_prompt("partition")
        # The parser needs the model object, so the model is loaded first.
        # NOTE(review): presumably `_load_model` is what populates `self._llm`
        # - confirm against `Converter`.
        self._load_model()
        self._parser = PartitionParser(
            token_limit=partition_token_limit,
            model=self._llm,
        )
        # Output is written in the source language with the source suffix
        self._target_language = self._source_language
        self._target_suffix = self._source_suffix
        self._load_parameters()

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Write the parsed, combined output of `block` to `out_path`.

        The block's `complete_text` is passed through the parser's
        `parse_combined_output`; missing parent directories are created and
        the result is written as UTF-8.

        Arguments:
            block: The block whose complete text should be saved.
            out_path: The file to write to.
        """
        output_str = self._parser.parse_combined_output(block.complete_text)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(output_str, encoding="utf-8")
@@ -20,7 +20,7 @@ class TestAlcSplitter(unittest.TestCase):
20
20
  def test_split(self):
21
21
  """Test the split method."""
22
22
  tree_root = self.splitter.split(self.test_file)
23
- self.assertAlmostEqual(tree_root.n_descendents, 32, delta=5)
23
+ self.assertAlmostEqual(tree_root.n_descendents, 16, delta=2)
24
24
  self.assertLessEqual(tree_root.max_tokens, self.splitter.max_tokens)
25
25
  self.assertFalse(tree_root.complete)
26
26
  self.combiner.combine_children(tree_root)
janus/language/alc/alc.py CHANGED
@@ -79,10 +79,15 @@ class AlcSplitter(TreeSitterSplitter):
79
79
  if len(sects) > 1:
80
80
  block.children = []
81
81
  for sect in sects:
82
- if sect[0].node_type in sect_types:
83
- sect_node = self.merge_nodes(sect)
84
- sect_node.children = sect
85
- sect_node.node_type = NodeType(str(sect[0].node_type)[:5])
82
+ node_type = sect[0].node_type
83
+ if node_type in sect_types:
84
+ if len(sect) == 1:
85
+ # Don't make a node its own child
86
+ sect_node = sect[0]
87
+ else:
88
+ sect_node = self.merge_nodes(sect)
89
+ sect_node.children = sect
90
+ sect_node.node_type = NodeType(str(node_type)[:5])
86
91
  block.children.append(sect_node)
87
92
  else:
88
93
  block.children.extend(sect)
janus/language/combine.py CHANGED
@@ -1,3 +1,5 @@
1
+ import re
2
+
1
3
  from janus.language.block import CodeBlock, TranslatedCodeBlock
2
4
  from janus.language.file import FileManager
3
5
  from janus.utils.logger import create_logger
@@ -90,3 +92,23 @@ class ChunkCombiner(Combiner):
90
92
  root: The functional code block to combine with its children.
91
93
  """
92
94
  return root
95
+
96
+
97
class PartitionCombiner(Combiner):
    """A combiner which inserts partition tags between code blocks"""

    @staticmethod
    def combine(root: CodeBlock) -> None:
        """Combine `root` with its children, separating leaves with tags.

        Every leaf block (one without children) gets a `<JANUS_PARTITION>`
        line appended to its suffix, then the parent combiner is run. Any
        partition tags left at the very end of the root's text and suffix are
        removed afterwards, so tags only appear *between* blocks.

        Arguments:
            root: The functional code block to combine with its children.
        """
        # Local import: keeps this block self-contained without touching the
        # module's import section.
        from collections import deque

        partition_tag = "\n<JANUS_PARTITION>\n"
        trailing_tags = re.compile(r"(?:\n<JANUS_PARTITION>\n)+$")

        # Breadth-first walk; a deque gives O(1) pops from the front
        queue = deque([root])
        while queue:
            block = queue.popleft()
            if block.children:
                queue.extend(block.children)
            else:
                block.affixes = (block.prefix, block.suffix + partition_tag)

        # `combine` is a staticmethod, so the two-argument form of `super` is
        # needed to reach the parent class's implementation.
        super(PartitionCombiner, PartitionCombiner).combine(root)
        root.text = trailing_tags.sub("", root.text)
        root.affixes = (root.prefix, trailing_tags.sub("", root.suffix))
@@ -275,42 +275,50 @@ class Splitter(FileManager):
275
275
 
276
276
  groups = [[n] for n in nodes]
277
277
  while len(groups) > 1 and min(adj_sums) <= self.max_tokens and any(merge_allowed):
278
- # Get the indices of the adjacent nodes that would result in the
279
- # smallest possible merged snippet. Ignore protected nodes.
278
+ # Get the index of the node that would result in the smallest
279
+ # merged snippet when merged with the node that follows it.
280
+ # Ignore protected nodes.
280
281
  mergeable_indices = compress(range(len(adj_sums)), merge_allowed)
281
- i0 = int(min(mergeable_indices, key=adj_sums.__getitem__))
282
- i1 = i0 + 1
282
+ C = int(min(mergeable_indices, key=adj_sums.__getitem__))
283
+
284
+ # C: Central index
285
+ # L: Index to the left
286
+ # R: Index to the right (to be merged in to C)
287
+ # N: Next index (to the right of R, the "new R")
288
+ L, R, N = C - 1, C + 1, C + 2
283
289
 
284
290
  # Recalculate the length. We can't simply use the adj_sum, because
285
291
  # it is an underestimate due to the adjoining suffix/prefix.
286
- central_node = groups[i0][-1]
287
- merged_text = "".join([text_chunks[i0], central_node.suffix, text_chunks[i1]])
292
+ central_node = groups[C][-1]
293
+ merged_text = "".join([text_chunks[C], central_node.suffix, text_chunks[R]])
288
294
  merged_text_length = self._count_tokens(merged_text)
289
295
 
290
296
  # If the true length of the merged pair is too long, don't merge them
291
297
  # Instead, correct the estimate, since shorter pairs may yet exist
292
298
  if merged_text_length > self.max_tokens:
293
- adj_sums[i0] = merged_text_length
299
+ adj_sums[C] = merged_text_length
294
300
  continue
295
301
 
296
302
  # Update adjacent sum estimates
297
- if i0 > 0:
298
- adj_sums[i0 - 1] += merged_text_length
299
- if i1 < len(adj_sums) - 1:
300
- adj_sums[i1 + 1] += merged_text_length
301
-
302
- if i0 > 0 and i1 < len(merge_allowed) - 1:
303
- if not (merge_allowed[i0 - 1] and merge_allowed[i1 + 1]):
304
- merge_allowed[i0 - 1] = merge_allowed[i1 + 1] = False
303
+ if L >= 0:
304
+ adj_sums[L] = lengths[L] + merged_text_length
305
+ if N < len(adj_sums):
306
+ adj_sums[R] = lengths[N] + merged_text_length
305
307
 
306
308
  # The potential merge length for this pair is removed
307
- adj_sums.pop(i0)
308
- merge_allowed.pop(i0)
309
+ adj_sums.pop(C)
310
+
311
+ # The merged-in node is removed from the protected list
312
+ # The merge_allowed list need not be updated - if the node now to
313
+ # its right is protected, the merge_allowed element corresponding
314
+ # to the merged neighbor will have been True, and now corresponds
315
+ # to the merged node.
316
+ merge_allowed.pop(C)
309
317
 
310
318
  # Merge the pair of node groups
311
- groups[i0 : i1 + 1] = [groups[i0] + groups[i1]]
312
- text_chunks[i0 : i1 + 1] = [merged_text]
313
- lengths[i0 : i1 + 1] = [merged_text_length]
319
+ groups[C:N] = [groups[C] + groups[R]]
320
+ text_chunks[C:N] = [merged_text]
321
+ lengths[C:N] = [merged_text_length]
314
322
 
315
323
  return groups
316
324
 
@@ -403,13 +411,13 @@ class Splitter(FileManager):
403
411
  self._split_into_lines(node)
404
412
 
405
413
  def _split_into_lines(self, node: CodeBlock):
406
- split_text = re.split(r"(\n+)", node.text)
414
+ split_text = list(re.split(r"(\n+)", node.text))
407
415
 
408
416
  # If the string didn't start/end with newlines, make sure to include
409
417
  # empty strings for the prefix/suffixes
410
- if split_text[0].strip("\n"):
418
+ if not re.match(r"^\n+$", split_text[0]):
411
419
  split_text = [""] + split_text
412
- if split_text[-1].strip("\n"):
420
+ if not re.match(r"^\n+$", split_text[-1]):
413
421
  split_text.append("")
414
422
  betweens = split_text[::2]
415
423
  lines = split_text[1::2]
@@ -154,7 +154,15 @@ class TreeSitterSplitter(Splitter):
154
154
  The pointer to the language.
155
155
  """
156
156
  lib = cdll.LoadLibrary(os.fspath(so_file))
157
- language_function = getattr(lib, f"tree_sitter_{self.language}")
157
+ # Added this try-except block to handle the case where the language is not
158
+ # supported in lowercase by the creator of the grammar. Ex: COBOL
159
+ # https://github.com/yutaro-sakamoto/tree-sitter-cobol/blob/main/grammar.js#L13
160
+ try:
161
+ language_function = getattr(lib, f"tree_sitter_{self.language}")
162
+ except AttributeError:
163
+ language = self.language.upper()
164
+ language_function = getattr(lib, f"tree_sitter_{language}")
165
+
158
166
  language_function.restype = c_void_p
159
167
  pointer = language_function()
160
168
  return pointer
janus/llm/models_info.py CHANGED
@@ -6,9 +6,13 @@ from typing import Callable, Protocol, TypeVar
6
6
  from dotenv import load_dotenv
7
7
  from langchain_community.llms import HuggingFaceTextGenInference
8
8
  from langchain_core.runnables import Runnable
9
- from langchain_openai import AzureChatOpenAI
9
+ from langchain_openai import AzureChatOpenAI, ChatOpenAI
10
10
 
11
- from janus.llm.model_callbacks import COST_PER_1K_TOKENS, azure_model_reroutes
11
+ from janus.llm.model_callbacks import (
12
+ COST_PER_1K_TOKENS,
13
+ azure_model_reroutes,
14
+ openai_model_reroutes,
15
+ )
12
16
  from janus.prompts.prompt import (
13
17
  ChatGptPromptEngine,
14
18
  ClaudePromptEngine,
@@ -90,6 +94,7 @@ claude_models = [
90
94
  "bedrock-claude-instant-v1",
91
95
  "bedrock-claude-haiku",
92
96
  "bedrock-claude-sonnet",
97
+ "bedrock-claude-sonnet-3.5",
93
98
  ]
94
99
  llama2_models = [
95
100
  "bedrock-llama2-70b",
@@ -126,7 +131,7 @@ bedrock_models = [
126
131
  all_models = [*azure_models, *bedrock_models]
127
132
 
128
133
  MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
129
- # "OpenAI": ChatOpenAI,
134
+ "OpenAI": ChatOpenAI,
130
135
  "HuggingFace": HuggingFaceTextGenInference,
131
136
  "Azure": AzureChatOpenAI,
132
137
  "Bedrock": Bedrock,
@@ -136,7 +141,7 @@ MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
136
141
 
137
142
 
138
143
  MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
139
- # **{m: ChatGptPromptEngine for m in openai_models},
144
+ **{m: ChatGptPromptEngine for m in openai_models},
140
145
  **{m: ChatGptPromptEngine for m in azure_models},
141
146
  **{m: ClaudePromptEngine for m in claude_models},
142
147
  **{m: Llama2PromptEngine for m in llama2_models},
@@ -147,12 +152,13 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
147
152
  }
148
153
 
149
154
  MODEL_ID_TO_LONG_ID = {
150
- # **{m: mr for m, mr in openai_model_reroutes.items()},
155
+ **{m: mr for m, mr in openai_model_reroutes.items()},
151
156
  **{m: mr for m, mr in azure_model_reroutes.items()},
152
157
  "bedrock-claude-v2": "anthropic.claude-v2",
153
158
  "bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
154
159
  "bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
155
160
  "bedrock-claude-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
161
+ "bedrock-claude-sonnet-3.5": "anthropic.claude-3-5-sonnet-20240620-v1:0",
156
162
  "bedrock-llama2-70b": "meta.llama2-70b-v1",
157
163
  "bedrock-llama2-70b-chat": "meta.llama2-70b-chat-v1",
158
164
  "bedrock-llama2-13b": "meta.llama2-13b-chat-v1",
@@ -179,7 +185,7 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
179
185
  MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
180
186
 
181
187
  MODEL_TYPES: dict[str, PromptEngine] = {
182
- # **{m: "OpenAI" for m in openai_models},
188
+ **{m: "OpenAI" for m in openai_models},
183
189
  **{m: "Azure" for m in azure_models},
184
190
  **{m: "BedrockChat" for m in bedrock_models},
185
191
  }
@@ -200,6 +206,7 @@ TOKEN_LIMITS: dict[str, int] = {
200
206
  "anthropic.claude-instant-v1": 100_000,
201
207
  "anthropic.claude-3-haiku-20240307-v1:0": 248_000,
202
208
  "anthropic.claude-3-sonnet-20240229-v1:0": 248_000,
209
+ "anthropic.claude-3-5-sonnet-20240620-v1:0": 200_000,
203
210
  "meta.llama2-70b-v1": 4096,
204
211
  "meta.llama2-70b-chat-v1": 4096,
205
212
  "meta.llama2-13b-chat-v1": 4096,
@@ -286,15 +293,16 @@ def load_model(model_id) -> JanusModel:
286
293
  # log.warning("Waiting 10 seconds...")
287
294
  # Give enough time for the user to read the warnings and cancel
288
295
  # time.sleep(10)
289
- raise DeprecationWarning("OpenAI models are no longer supported.")
296
+ # raise DeprecationWarning("OpenAI models are no longer supported.")
290
297
 
291
298
  elif model_type_name == "Azure":
292
299
  model_args.update(
293
- {
294
- "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
295
- "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
296
- "api_version": os.getenv("OPENAI_API_VERSION", "2024-02-01"),
297
- }
300
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
301
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
302
+ api_version=os.getenv("OPENAI_API_VERSION", "2024-02-01"),
303
+ azure_deployment=model_id,
304
+ request_timeout=3600,
305
+ max_tokens=4096,
298
306
  )
299
307
 
300
308
  model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
@@ -0,0 +1,134 @@
1
+ import json
2
+ import random
3
+ import uuid
4
+ from typing import List
5
+
6
+ from langchain.output_parsers import PydanticOutputParser
7
+ from langchain_core.exceptions import OutputParserException
8
+ from langchain_core.messages import BaseMessage
9
+ from langchain_core.pydantic_v1 import BaseModel, Field, validator
10
+
11
+ from janus.language.block import CodeBlock
12
+ from janus.parsers.parser import JanusParser
13
+ from janus.utils.logger import create_logger
14
+
15
+ log = create_logger(__name__)
16
+ RNG = random.Random()
17
+
18
+
19
class Criteria(BaseModel):
    # The assessment of a single criterion: a free-text justification plus a
    # pass/fail verdict.
    # (A `#` comment rather than a docstring: a model docstring would be
    # emitted as the schema description.)
    reasoning: str = Field(description="A short explanation for the given assessment")
    # `description=` must be passed by keyword: the first positional argument
    # of `Field` is the *default value*, which would make the field optional.
    score: str = Field(description="A simple `pass` or `fail`")

    @validator("score")
    def score_is_valid(cls, v: str):
        """Normalize the score to lowercase and require `pass` or `fail`.

        Raises:
            OutputParserException: If the normalized score is anything else.
        """
        v = v.lower().strip()
        if v not in {"pass", "fail"}:
            raise OutputParserException("Score must be either 'pass' or 'fail'")
        return v
29
+
30
+
31
class Requirement(BaseModel):
    # The evaluation of one requirement: its ID, its text, and one `Criteria`
    # verdict for each of the nine criteria C1-C9.
    # (A `#` comment rather than a docstring: a model docstring would be
    # emitted as the schema description.)
    requirement_id: str = Field(description="The 8-character requirement ID")
    requirement: str = Field(description="The original requirement being evaluated")
    C1: Criteria
    C2: Criteria
    C3: Criteria
    C4: Criteria
    C5: Criteria
    C6: Criteria
    C7: Criteria
    C8: Criteria
    C9: Criteria
43
+
44
+
45
class RequirementList(BaseModel):
    # Custom-root model: the whole document is a bare JSON list of
    # `Requirement` objects rather than an object wrapping such a list.
    # (A `#` comment rather than a docstring: a model docstring would be
    # emitted as the schema description.)
    __root__: List[Requirement] = Field(
        description=(
            "A list of requirement evaluations. Each element should include"
            " the requirement's 8-character ID in the `requirement_id` field,"
            " the original requirement in the 'requirement' field, "
            " and nine score objects corresponding to each criterion."
        )
    )
54
+
55
+
56
class IncoseParser(JanusParser, PydanticOutputParser):
    """Parser for INCOSE requirement evaluations.

    `parse_input` labels each requirement with a short ID before it is sent
    to the LLM; `parse` validates the LLM's answer against `RequirementList`
    and re-keys it by those IDs.
    """

    # Maps each generated 8-character ID to its requirement, as assigned by
    # the most recent `parse_input` call.
    requirements: dict[str, str]

    def __init__(self):
        PydanticOutputParser.__init__(
            self,
            pydantic_object=RequirementList,
            requirements={},
        )

    def parse_input(self, block: CodeBlock) -> str:
        """Assign an ID to every requirement and render them as one string.

        Arguments:
            block: The block to parse. The text produced by the parent
                `parse_input` must be a JSON object with a `requirements` list.

        Returns:
            The JSON object re-serialized, with `requirements` replaced by a
            single string of `Requirement <id> : <requirement>` paragraphs.
        """
        # TODO: Perform comment stripping/placeholding here rather than in script
        text = super().parse_input(block)
        # Seeding with the input makes the generated IDs reproducible
        RNG.seed(text)

        obj = json.loads(text)

        # NOTE(review): the original comment here suggested the requirements
        # may arrive as a nested list; no flattening is done at this point.
        reqs = obj["requirements"]

        # Generate a unique ID for each requirement (ensure they are unique).
        # A dict is used as an insertion-ordered set: iterating a plain `set`
        # of strings follows hash order, which varies between interpreter
        # runs and would make the ID -> requirement pairing irreproducible.
        req_ids: dict[str, None] = {}
        while len(req_ids) < len(reqs):
            rid = str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8]
            req_ids[rid] = None

        self.requirements = dict(zip(req_ids, reqs))
        reqs_str = "\n\n".join(
            f"Requirement {rid} : {req}" for rid, req in self.requirements.items()
        )
        obj["requirements"] = reqs_str
        return json.dumps(obj)

    def parse(self, text: str | BaseMessage) -> str:
        """Validate the LLM output and key the evaluations by requirement ID.

        Arguments:
            text: The raw LLM output, as a string or a message.

        Returns:
            A JSON object mapping each expected requirement ID to its
            evaluation. The `requirement` field is overwritten with the
            original requirement text and `requirement_id` is removed.
            Evaluations with unknown IDs are dropped.

        Raises:
            OutputParserException: If the output contains no JSON array, is
                not valid JSON, or lacks an evaluation for an expected ID.
        """
        if isinstance(text, BaseMessage):
            text = str(text.content)

        # Strip everything outside the outermost JSON array
        begin, end = text.find("["), text.rfind("]")
        if begin == -1 or end < begin:
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise OutputParserException(
                "Got invalid JSON object. Error: no JSON array found in output"
            )
        text = text[begin : end + 1]

        try:
            out: RequirementList = super().parse(text)
        except json.JSONDecodeError as e:
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e

        evals: dict[str, dict] = {c.requirement_id: c.dict() for c in out.__root__}

        seen_keys = set(evals)
        expected_keys = set(self.requirements)
        missing_keys = expected_keys - seen_keys
        invalid_keys = seen_keys - expected_keys
        if invalid_keys:
            log.debug(f"Invalid keys: {invalid_keys}")
        if missing_keys:
            log.debug(f"Missing keys: {missing_keys}")
            raise OutputParserException(
                f"Got invalid return object. Missing the following expected "
                f"keys: {missing_keys}"
            )

        # Unknown IDs are tolerated but discarded
        for key in invalid_keys:
            del evals[key]

        # Trust our own copy of the requirement text over the LLM's echo of it
        for rid, evaluation in evals.items():
            evaluation["requirement"] = self.requirements[rid]
            evaluation.pop("requirement_id")

        return json.dumps(evals)

    def parse_combined_output(self, text: str) -> str:
        """Merge newline-separated JSON objects into a single JSON object.

        Arguments:
            text: Zero or more JSON objects, one per line. Blank lines are
                ignored.

        Returns:
            The merged object serialized as JSON; later lines override
            earlier ones on key collisions. `{}` if `text` is blank.
        """
        if not text.strip():
            return json.dumps({})
        output_obj = {}
        for line in text.split("\n"):
            line = line.strip()
            if line:
                output_obj.update(json.loads(line))
        return json.dumps(output_obj)