janus-llm 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +120 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +9 -6
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +134 -70
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/METADATA +23 -10
  70. janus_llm-2.0.0.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,38 @@
+ from typing import Callable
+
+ SPLITTING_METHODS: dict[str, Callable[[str, str], list[str]]] = {}
+
+
+ def register_splitting_method(name: None | str = None) -> Callable[[Callable], Callable]:
+     """Registers a pairing method for splitting strings in files
+
+     Arguments:
+         name: The name of the splitting method. If None, the function name is used.
+         help: The help text for the pairing method.
+
+     Returns:
+         The decorator function.
+     """
+
+     def decorator(f: Callable[[str, str], list[tuple[str, str]]]):
+         if name is None:
+             splitting_name = f.__name__
+         else:
+             splitting_name = name
+         SPLITTING_METHODS[splitting_name] = f
+         return f
+
+     return decorator
+
+
+ @register_splitting_method(name="file")
+ def split_by_file(src: str, **kwargs) -> list[str]:
+     """Split the source text by file
+
+     Arguments:
+         src: The source text.
+
+     Returns:
+         A list of strings.
+     """
+     return [src]
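
For orientation, a minimal usage sketch of the new splitting registry (assuming this hunk is the janus/metrics/splitting.py file listed above); the split_on_blank_lines function below is a hypothetical example, not part of the package:

    from janus.metrics.splitting import SPLITTING_METHODS, register_splitting_method

    # Register a hypothetical splitting method under an explicit name.
    @register_splitting_method(name="blank_lines")
    def split_on_blank_lines(src: str, **kwargs) -> list[str]:
        """Split the source text on blank lines."""
        return [chunk for chunk in src.split("\n\n") if chunk.strip()]

    # Registered callables are looked up by name in the module-level dict.
    splitter = SPLITTING_METHODS["blank_lines"]
    print(splitter("def a():\n    pass\n\ndef b():\n    pass"))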
File without changes
@@ -0,0 +1,32 @@
+ import unittest
+
+ from ..code_parser import CodeParser, JanusParser
+
+
+ class TestJanusParser(unittest.TestCase):
+     def setUp(self):
+         self.parser = JanusParser()
+
+     def test_parse_combined_output(self):
+         text = "test text"
+         self.assertEqual(self.parser.parse_combined_output(text), text)
+
+
+ class TestCodeParser(unittest.TestCase):
+     def setUp(self):
+         self.parser = CodeParser(language="python")
+
+     def test_parse(self):
+         self.parser.language = "python"
+         text = "```\n# test text\n```"
+         self.assertEqual(self.parser.parse(text), text.strip("```").strip("\n"))
+
+     def test_get_format_instructions(self):
+         self.assertEqual(
+             self.parser.get_format_instructions(),
+             "Output must contain text contained within triple square brackets (```)",
+         )
+
+
+ if __name__ == "__main__":
+     unittest.main()
@@ -1,32 +1,17 @@
- import json
  import re
- from collections import defaultdict
- from typing import Any, Set

  from langchain.schema.output_parser import BaseOutputParser
+ from langchain_core.exceptions import OutputParserException
+ from langchain_core.messages import BaseMessage
+ from langchain_core.output_parsers import StrOutputParser

  from ..language.block import CodeBlock
- from ..language.combine import Combiner
  from ..utils.logger import create_logger

  log = create_logger(__name__)


- PARSER_TYPES: Set[str] = {"code", "text", "eval"}
-
-
- class JanusParser(BaseOutputParser):
-     def parse(self, text: str) -> str:
-         """Parse the output text from the LLM.
-
-         Arguments:
-             text: The output text from the LLM
-
-         Returns:
-             A parsed version of the text
-         """
-         return text
-
+ class JanusParser:
      def parse_combined_output(self, text: str) -> str:
          """Parse the output text from the LLM when multiple inputs are combined

@@ -36,253 +21,39 @@ class JanusParser(BaseOutputParser):
          Returns:
              A parsed version of the text
          """
+         if isinstance(text, BaseMessage):
+             text = text.content
          return text

-     def score(self, input_block: CodeBlock, output_text: str) -> float:
-         """Validate and score the output text based upon the input CodeBlock.
-         Output is a score between 0 and 1.
+     def parse_into_block(self, text: str, block: CodeBlock):
+         if isinstance(text, BaseMessage):
+             text = text.content
+         block.text = text

-         Arguments:
-             input_block: A `CodeBlock` representing the input to the LLM
-             output_text: The parsed text returned by the LLM
+     def set_reference(self, block: CodeBlock):
+         pass

-         Returns:
-             A score between 0 and 1 (inclusive). A score of 1.0 indicates that
-                 the given text is fully acceptable, and no further attempts
-                 should be made.
-         """
-         return 1.0

-     def get_format_instructions(self) -> str:
-         return "No format requirements"
-
-     @property
-     def _type(self) -> str:
-         return type(self).__name__
+ class GenericParser(StrOutputParser, JanusParser):
+     def parse(self, text: str) -> str:
+         if isinstance(text, BaseMessage):
+             return text.content
+         return text


- class CodeParser(JanusParser):
+ class CodeParser(BaseOutputParser[str], JanusParser):
      language: str

      def parse(self, text: str) -> str:
-         """Parse the output text from the LLM.
-
-         Arguments:
-             text: The output text from the LLM
-
-         Returns:
-             A parsed version of the text
-         """
+         if isinstance(text, BaseMessage):
+             text = text.content
          pattern = rf"```[^\S\r\n]*(?:{self.language}[^\S\r\n]*)?\n?(.*?)\n*```"
          code = re.search(pattern, text, re.DOTALL)
          if code is None:
-             raise ValueError("Code not find code between triple backticks")
-         return code.group(1)
-
-     def score(self, input_block: CodeBlock, output_text: str) -> float:
-         """The score for translated code is the percentage of this block's
-         children which are present in the output
-
-         Arguments:
-             input_block: A `CodeBlock` representing the input to the LLM
-             output_text: The parsed text returned by the LLM
-
-         Returns:
-             A score between 0 and 1 (inclusive). A score of 1.0 indicates that
-                 the given text is fully acceptable, and no further attempts
-                 should be made.
-         """
-         if not input_block.children:
-             return 1.0
-
-         missing_children = []
-         for child in input_block.children:
-             if not Combiner.contains_child(output_text, child):
-                 missing_children.append(child.id)
-
-         if missing_children:
-             log.warning(
-                 f"[{input_block.name}] Child placeholders not present in text: "
-                 f"{missing_children}"
+             raise OutputParserException(
+                 "Code not find code between triple square brackets"
              )
-             log.debug(f"Code:\n{output_text}")
-
-         return 1.0 - len(missing_children) / len(input_block.children)
-
-     def get_format_instructions(self) -> str:
-         return "Output must contain text contained within triple backticks."
-
-
- class JsonLinesParser(JanusParser):
-     def parse(self, text: str) -> str:
-         """Parse the output text from the LLM.
-
-         Arguments:
-             text: The output text from the LLM.
-
-         Returns:
-             A parsed version of the text.
-         """
-         string = r"\"\w+\""
-         number = r"-?\d+(?:\.\d*)?"
-         json_value = rf"(?:{string}|{number})"
-         json_line = rf"\s*{string} *: *{json_value},?\s*"
-         pattern = "({" + rf"(?:{json_line})+" + "})"
-         matches = list(re.finditer(pattern, text, re.DOTALL))
-         if not matches:
-             raise ValueError("Could not find JSON output")
-
-         output_strings = [json.dumps(json.loads(match.group(1))) for match in matches]
-         return "\n".join(output_strings)
-
-     def parse_combined_output(self, text: str) -> str:
-         """Parse the output text from the LLM when multiple inputs are combined.
-
-         Arguments:
-             text: The output text from the LLM.
-
-         Returns:
-             A parsed version of the text.
-         """
-         return self.parse(text)
-
-     def get_format_instructions(self) -> str:
-         """Get the format instructions for the parser.
-
-         Returns:
-             The format instructions for the LLM.
-         """
-         return "Output must contain one or more JSON-formatted blocks."
-
-
- class JsonParser(JsonLinesParser):
-     def parse(self, text: str) -> str:
-         """Parse the output text from the LLM.
-
-         Arguments:
-             text: The output text from the LLM.
-
-         Returns:
-             A parsed version of the text.
-         """
-         jsonl_text = super().parse(text)
-         if len(jsonl_text.split("\n")) > 1:
-             raise ValueError("Multiple JSON objects found")
-
-         return jsonl_text
-
-     def parse_combined_output(self, text: str) -> str:
-         """Parse the output text from the LLM when multiple inputs are combined.
-
-         Arguments:
-             text: The output text from the LLM.
-
-         Returns:
-             A parsed version of the text.
-         """
-         jsonl_text = JsonLinesParser.parse(self, text)
-         json_lines = jsonl_text.split("\n")
-         output_obj = {i: json.loads(t) for i, t in enumerate(json_lines)}
-         return json.dumps(output_obj)
+         return str(code.group(1))

      def get_format_instructions(self) -> str:
-         """Get the format instructions for the parser.
-
-         Returns:
-             The format instructions for the LLM.
-         """
-         return "Output must contain exactly one JSON-formatted block."
-
-
- class EvaluationParser(JsonParser):
-     expected_keys: Set[str]
-
-     def __init__(self, expected_keys: Set[str], **kwargs: Any):
-         """Create a new EvaluationParser.
-
-         Arguments:
-             expected_keys: The set of keys that should be present in the JSON
-                 object
-             kwargs: Additional arguments to pass to the parent class
-         """
-         super().__init__(expected_keys=expected_keys, **kwargs)
-         self.expected_keys = {k.lower() for k in expected_keys}
-
-     def parse(self, text: str) -> str:
-         """Parse the JSON object, convert keys to lowercase, filter out
-         unexpected keys
-
-         Arguments:
-             text: The output text from the LLM.
-
-         Returns:
-             A parsed version of the text.
-         """
-         json_text = super().parse(text)
-         obj = json.loads(json_text)
-         obj = {k.lower(): v for k, v in obj.items()}
-         obj = {k: v for k, v in obj.items() if k in self.expected_keys}
-         return json.dumps(obj)
-
-     def parse_combined_output(self, text: str) -> str:
-         """Parse the JSON object, convert keys to lowercase, filter out
-         unexpected keys, and average the values
-
-         Arguments:
-             text: The output text from the LLM.
-
-         Returns:
-             A parsed version of the text.
-         """
-         json_text = super().parse_combined_output(text)
-         multi_obj = json.loads(json_text)
-         n_evals = len(multi_obj)
-
-         output_obj = defaultdict(float)
-         for obj in multi_obj.values():
-             for k, v in obj.items():
-                 output_obj[k] += v / n_evals
-
-         return json.dumps(output_obj)
-
-     def score(self, input_block: CodeBlock, output_text: str) -> float:
-         """The score for the output text is the percentage of expected keys
-         that are present in the json object. Non-numeric values count for
-         half.
-
-         Arguments:
-             input_block: A `CodeBlock` representing the input to the LLM
-             output_text: The parsed text returned by the LLM
-
-         Returns:
-             A score between 0 and 1 (inclusive). A score of 1.0 indicates that
-                 the given text is fully acceptable, and no further attempts
-                 should be made.
-         """
-         obj = json.loads(output_text)
-
-         expected_keys = self.expected_keys.intersection(obj.keys())
-         missing_keys = self.expected_keys.difference(obj.keys())
-         if missing_keys:
-             log.warning(f"[{input_block.name}] Expected keys missing: {missing_keys}")
-
-         non_numerics = {k: v for k, v in obj.items() if not isinstance(v, (int, float))}
-         if non_numerics:
-             log.warning(f"[{input_block.name}] Non-numeric values: {non_numerics}")
-
-         if missing_keys or non_numerics:
-             log.debug(f"Text:\n{output_text}")
-
-         return (len(expected_keys) - len(non_numerics) * 0.5) / len(self.expected_keys)
-
-     def get_format_instructions(self) -> str:
-         """Get the format instructions for the parser.
-
-         Returns:
-             The format instructions for the LLM.
-         """
-         return (
-             "Output must contain exactly one JSON-formatted block. The JSON "
-             "object should contain only the keys contained in the provided "
-             "expected_keys set (if any), and values should be numeric."
-         )
+         return "Output must contain text contained within triple square brackets (```)"
@@ -0,0 +1,169 @@
+ import json
+ import re
+
+ from langchain.output_parsers import PydanticOutputParser
+ from langchain.output_parsers.json import parse_json_markdown
+ from langchain.schema.output_parser import BaseOutputParser
+ from langchain_core.exceptions import OutputParserException
+ from langchain_core.messages import AIMessage
+ from langchain_core.pydantic_v1 import BaseModel, Field
+
+ from ..language.block import CodeBlock
+ from ..utils.logger import create_logger
+ from .code_parser import JanusParser
+
+ log = create_logger(__name__)
+
+
+ class MultiDoc(BaseModel):
+     docstring: str = Field(
+         description="A Sphinx-style docstring for the code, including a summary "
+         "of its functionality; the name, type, and description of "
+         "any parameters or returns; and any potential exceptions "
+         "that might arise in its execution"
+     )
+     example_usage: str = Field(
+         description="A well-commented minimal example utilizing the given "
+         "code's functionality"
+     )
+     pseudocode: str = Field(
+         description="A Python-stype pseudocode implementation of the module or "
+         "function's behavior"
+     )
+
+
+ class MultiDocumentationParser(PydanticOutputParser, JanusParser):
+     block_name: str = ""
+
+     def __init__(self):
+         PydanticOutputParser.__init__(self, pydantic_object=MultiDoc)
+
+     def set_reference(self, block: CodeBlock):
+         self.block_name = block.name
+
+     def parse(self, text: str) -> str:
+         if isinstance(text, AIMessage):
+             text = text.content
+         try:
+             docs = json.loads(super().parse(text).json())
+         except (OutputParserException, json.JSONDecodeError):
+             log.debug(f"Invalid JSON object. Output:\n{text}")
+             raise
+         docs["name"] = self.block_name
+         return json.dumps(docs)
+
+     def parse_combined_output(self, text: str) -> str:
+         """Parse the output text from the LLM when multiple inputs are combined.
+
+         Arguments:
+             text: The output text from the LLM.
+
+         Returns:
+             A parsed version of the text.
+         """
+         objs = [
+             parse_json_markdown(line.strip()) for line in text.split("\n") if line.strip()
+         ]
+         output_obj = {d.pop("name"): d for d in objs}
+         return json.dumps(output_obj)
+
+     def get_format_instructions(self) -> str:
+         """Get the format instructions for the parser.
+
+         Returns:
+             The format instructions for the LLM.
+         """
+         return (
+             "Output must contain a sphinx-style docstring, example usage, and "
+             "pseudocode, all in a json-formatted string with the following fields: "
+             '"docstring", "example_usage", and "pseudocode".'
+         )
+
+     @property
+     def _type(self) -> str:
+         return self.__class__.name
+
+
+ class MadlibsDocumentationParser(BaseOutputParser[str], JanusParser):
+     expected_keys: set[str]
+
+     def __init__(self):
+         super().__init__(expected_keys=[])
+
+     def set_reference(self, block: CodeBlock):
+         comment_ids = re.findall(r"<(?:BLOCK|INLINE)_COMMENT (\w{8})>", block.text)
+         self.expected_keys = set(comment_ids)
+
+     def parse(self, text: str) -> str:
+         if isinstance(text, AIMessage):
+             text = text.content
+         try:
+             obj = parse_json_markdown(text)
+         except json.JSONDecodeError as e:
+             log.debug(f"Invalid JSON object. Output:\n{text}")
+             raise OutputParserException(f"Got invalid JSON object. Error: {e}")
+
+         if not isinstance(obj, dict):
+             raise OutputParserException(
+                 f"Got invalid return object. Expected a dictionary, but got {type(obj)}"
+             )
+
+         seen_keys = set(obj.keys())
+         missing_keys = self.expected_keys.difference(obj.keys())
+         invalid_keys = seen_keys.difference(self.expected_keys)
+         if missing_keys:
+             log.debug(f"Missing keys: {missing_keys}")
+             if invalid_keys:
+                 log.debug(f"Invalid keys: {invalid_keys}")
+             log.debug(f"Missing keys: {missing_keys}")
+             raise OutputParserException(
+                 f"Got invalid return object. Missing the following expected "
+                 f"keys: {missing_keys}"
+             )
+
+         for key in invalid_keys:
+             del obj[key]
+
+         for value in obj.values():
+             if not isinstance(value, str):
+                 raise OutputParserException(
+                     f"Got invalid return object. Expected all string values,"
+                     f' but got type "{type(value)}"'
+                 )
+
+         return json.dumps(obj)
+
+     def parse_combined_output(self, text: str) -> str:
+         """Parse the output text from the LLM when multiple inputs are combined.
+
+         Arguments:
+             text: The output text from the LLM.
+
+         Returns:
+             A parsed version of the text.
+         """
+         if not text.strip():
+             return str({})
+         objs = [
+             parse_json_markdown(line.strip()) for line in text.split("\n") if line.strip()
+         ]
+         output_obj = {}
+         for obj in objs:
+             output_obj.update(obj)
+         return json.dumps(output_obj)
+
+     def get_format_instructions(self) -> str:
+         """Get the format instructions for the parser.
+
+         Returns:
+             The format instructions for the LLM.
+         """
+         return (
+             "Output must contain exactly one JSON-formatted block. The JSON "
+             "object should contain only (and all of) the comment IDs present "
+             "in the input code."
+         )
+
+     @property
+     def _type(self) -> str:
+         return self.__class__.name
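
A minimal sketch of the new documentation parser in use (assuming this hunk is janus/parsers/doc_parser.py from the file list; the JSON response is invented, and block_name is set directly here for brevity instead of via set_reference on a CodeBlock):

    from janus.parsers.doc_parser import MultiDocumentationParser

    parser = MultiDocumentationParser()
    parser.block_name = "add"  # normally populated by set_reference(block)

    llm_output = (
        '{"docstring": "Add two numbers.", '
        '"example_usage": "add(1, 2)", '
        '"pseudocode": "return a + b"}'
    )

    # parse() validates the three MultiDoc fields and tags the result with the
    # block name, returning a JSON string.
    print(parser.parse(llm_output))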
@@ -0,0 +1,80 @@
+ import json
+
+ from langchain.output_parsers import PydanticOutputParser
+ from langchain_core.pydantic_v1 import BaseModel, Field, validator
+
+ from ..utils.logger import create_logger
+ from .code_parser import JanusParser
+
+ log = create_logger(__name__)
+
+
+ class Eval(BaseModel):
+     syntax: float = Field(description="A numeric score (0-4) for code syntax")
+     style: float = Field(description="A numeric score (0-4) for code style")
+     completeness: float = Field(description="A numeric score (0-4) for code completeness")
+     correctness: float = Field(description="A numeric score (0-4) for code correctness")
+
+     # You can add custom validation logic easily with Pydantic.
+     @validator("*")
+     def score_is_valid(cls, v: float | int):
+         try:
+             v = float(v)
+         except ValueError:
+             raise ValueError("must be a number")
+
+         if not 0 <= v <= 4:
+             raise ValueError("must be a value between 0 and 4 inclusive")
+
+         return v
+
+     def __add__(self, other):
+         if isinstance(other, int) and other == 0:
+             return self.copy()
+         return Eval.construct(
+             syntax=self.syntax + other.syntax,
+             style=self.style + other.style,
+             correctness=self.correctness + other.correctness,
+             completeness=self.completeness + other.completeness,
+         )
+
+     def __radd__(self, other):
+         return self.__add__(other)
+
+     def __truediv__(self, other):
+         if isinstance(other, int):
+             return Eval.construct(
+                 syntax=self.syntax / other,
+                 style=self.style / other,
+                 correctness=self.correctness / other,
+                 completeness=self.completeness / other,
+             )
+         return Eval.construct(
+             syntax=self.syntax / other.syntax,
+             style=self.style / other.style,
+             correctness=self.correctness / other.correctness,
+             completeness=self.completeness / other.completeness,
+         )
+
+
+ class EvaluationParser(PydanticOutputParser, JanusParser):
+     def __init__(self):
+         PydanticOutputParser.__init__(self, pydantic_object=Eval)
+
+     def parse(self, text: str) -> str:
+         eval = super().parse(text)
+         return json.dumps(eval.json())
+
+     def parse_combined_output(self, text: str) -> str:
+         """Parse the JSON object, convert keys to lowercase, filter out
+         unexpected keys, and average the values
+
+         Arguments:
+             text: The output text from the LLM.
+
+         Returns:
+             A parsed version of the text.
+         """
+         objs = [super().parse(line.strip()) for line in text.split("\n")]
+         avg_obj = sum(objs) / len(objs)
+         return json.dumps(avg_obj.json())
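
The Eval model's operator overloads make averaging per-chunk scores straightforward; a minimal sketch (assuming this hunk is janus/parsers/eval_parser.py from the file list; the score values are invented):

    from janus.parsers.eval_parser import Eval

    # Two per-chunk evaluations on the 0-4 scale enforced by the validator.
    scores = [
        Eval(syntax=4, style=3, completeness=2, correctness=4),
        Eval(syntax=2, style=1, completeness=4, correctness=2),
    ]

    # sum() starts from 0, which __radd__/__add__ treat as the identity;
    # dividing by an int divides each field, giving the field-wise mean.
    average = sum(scores) / len(scores)
    print(average)  # syntax=3.0 style=2.0 completeness=3.0 correctness=3.0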