janus-llm 1.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (74) hide show
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +130 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +19 -14
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +165 -72
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/METADATA +27 -11
  70. janus_llm-2.0.1.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,38 @@
1
+ from typing import Callable
2
+
3
+ SPLITTING_METHODS: dict[str, Callable[[str, str], list[str]]] = {}
4
+
5
+
6
+ def register_splitting_method(name: None | str = None) -> Callable[[Callable], Callable]:
7
+ """Registers a pairing method for splitting strings in files
8
+
9
+ Arguments:
10
+ name: The name of the splitting method. If None, the function name is used.
11
+ help: The help text for the pairing method.
12
+
13
+ Returns:
14
+ The decorator function.
15
+ """
16
+
17
+ def decorator(f: Callable[[str, str], list[tuple[str, str]]]):
18
+ if name is None:
19
+ splitting_name = f.__name__
20
+ else:
21
+ splitting_name = name
22
+ SPLITTING_METHODS[splitting_name] = f
23
+ return f
24
+
25
+ return decorator
26
+
27
+
28
+ @register_splitting_method(name="file")
29
+ def split_by_file(src: str, **kwargs) -> list[str]:
30
+ """Split the source text by file
31
+
32
+ Arguments:
33
+ src: The source text.
34
+
35
+ Returns:
36
+ A list of strings.
37
+ """
38
+ return [src]
File without changes
@@ -0,0 +1,32 @@
1
+ import unittest
2
+
3
+ from ..code_parser import CodeParser, JanusParser
4
+
5
+
6
class TestJanusParser(unittest.TestCase):
    """Unit tests for the `JanusParser` mixin."""

    def setUp(self):
        self.parser = JanusParser()

    def test_parse_combined_output(self):
        sample = "test text"
        result = self.parser.parse_combined_output(sample)
        self.assertEqual(result, sample)


class TestCodeParser(unittest.TestCase):
    """Unit tests for `CodeParser`."""

    def setUp(self):
        self.parser = CodeParser(language="python")

    def test_parse(self):
        self.parser.language = "python"
        fenced = "```\n# test text\n```"
        # str.strip removes the backtick characters, then the newlines.
        expected = fenced.strip("```").strip("\n")
        self.assertEqual(self.parser.parse(fenced), expected)

    def test_get_format_instructions(self):
        expected = (
            "Output must contain text contained within triple square brackets (```)"
        )
        self.assertEqual(self.parser.get_format_instructions(), expected)


if __name__ == "__main__":
    unittest.main()
@@ -1,32 +1,17 @@
1
- import json
2
1
  import re
3
- from collections import defaultdict
4
- from typing import Any, Set
5
2
 
6
3
  from langchain.schema.output_parser import BaseOutputParser
4
+ from langchain_core.exceptions import OutputParserException
5
+ from langchain_core.messages import BaseMessage
6
+ from langchain_core.output_parsers import StrOutputParser
7
7
 
8
8
  from ..language.block import CodeBlock
9
- from ..language.combine import Combiner
10
9
  from ..utils.logger import create_logger
11
10
 
12
11
  log = create_logger(__name__)
13
12
 
14
13
 
15
class JanusParser:
    """Mixin supplying the Janus-specific hooks shared by the output parsers.

    Every hook accepts either raw text or a LangChain `BaseMessage`; for a
    message, its ``content`` is used.
    """

    def parse_combined_output(self, text: str) -> str:
        """Parse the output text from the LLM when multiple inputs are combined

        Arguments:
            text: The output text from the LLM

        Returns:
            A parsed version of the text
        """
        return text.content if isinstance(text, BaseMessage) else text

    def parse_into_block(self, text: str, block: CodeBlock):
        """Store the LLM output text on ``block``.

        Arguments:
            text: The output text from the LLM
            block: The `CodeBlock` whose ``text`` attribute is overwritten
        """
        block.text = text.content if isinstance(text, BaseMessage) else text

    def set_reference(self, block: CodeBlock):
        """Hook for parsers that need the input block before parsing; a no-op here.

        Arguments:
            block: The reference `CodeBlock` (unused by this base implementation)
        """
        return None
37
class GenericParser(StrOutputParser, JanusParser):
    """Pass-through parser: returns the LLM output text as-is."""

    def parse(self, text: str) -> str:
        """Return ``text`` unchanged, unwrapping a `BaseMessage` to its content."""
        return text.content if isinstance(text, BaseMessage) else text
62
42
 
63
43
 
64
class CodeParser(BaseOutputParser[str], JanusParser):
    """Extract code from an LLM response fenced with triple backticks.

    An optional language tag directly after the opening fence (e.g.
    ```python) is skipped when it matches ``language``.
    """

    # Language tag that may follow the opening fence.
    language: str

    def parse(self, text: str) -> str:
        """Return the text of the first fenced code block in ``text``.

        Arguments:
            text: The output text from the LLM (a `BaseMessage` is unwrapped
                to its content).

        Returns:
            The code between the fences, without the language tag and
            without trailing newlines before the closing fence.

        Raises:
            OutputParserException: If no fenced code block is found.
        """
        if isinstance(text, BaseMessage):
            text = text.content
        # Escape the tag so names like "c++" are matched literally instead of
        # being interpreted as regex quantifiers.
        language = re.escape(self.language)
        pattern = rf"```[^\S\r\n]*(?:{language}[^\S\r\n]*)?\n?(.*?)\n*```"
        code = re.search(pattern, text, re.DOTALL)
        if code is None:
            raise OutputParserException(
                "Could not find code between triple backticks"
            )
        return str(code.group(1))

    def get_format_instructions(self) -> str:
        """Return the formatting instructions given to the LLM."""
        return "Output must contain text contained within triple square brackets (```)"
@@ -0,0 +1,169 @@
1
+ import json
2
+ import re
3
+
4
+ from langchain.output_parsers import PydanticOutputParser
5
+ from langchain.output_parsers.json import parse_json_markdown
6
+ from langchain.schema.output_parser import BaseOutputParser
7
+ from langchain_core.exceptions import OutputParserException
8
+ from langchain_core.messages import AIMessage
9
+ from langchain_core.pydantic_v1 import BaseModel, Field
10
+
11
+ from ..language.block import CodeBlock
12
+ from ..utils.logger import create_logger
13
+ from .code_parser import JanusParser
14
+
15
+ log = create_logger(__name__)
16
+
17
+
18
class MultiDoc(BaseModel):
    """Documentation fields requested from the LLM for one code block.

    The field descriptions become part of the model's JSON schema.
    """

    # Full Sphinx-style docstring for the code.
    docstring: str = Field(
        description="A Sphinx-style docstring for the code, including a summary "
        "of its functionality; the name, type, and description of "
        "any parameters or returns; and any potential exceptions "
        "that might arise in its execution"
    )
    # Small, commented usage example of the code.
    example_usage: str = Field(
        description="A well-commented minimal example utilizing the given "
        "code's functionality"
    )
    # Pseudocode rendering of the code's behavior.
    pseudocode: str = Field(
        description="A Python-style pseudocode implementation of the module or "
        "function's behavior"
    )
33
+
34
+
35
class MultiDocumentationParser(PydanticOutputParser, JanusParser):
    """Parse LLM output into `MultiDoc` fields, tagged with the block name."""

    # Name of the block set via `set_reference`; written into every parsed
    # result under the "name" key.
    block_name: str = ""

    def __init__(self):
        PydanticOutputParser.__init__(self, pydantic_object=MultiDoc)

    def set_reference(self, block: CodeBlock):
        """Remember the name of the block whose output will be parsed next.

        Arguments:
            block: The `CodeBlock` being documented.
        """
        self.block_name = block.name

    def parse(self, text: str) -> str:
        """Parse one LLM response into a JSON string.

        Arguments:
            text: The output text from the LLM (an `AIMessage` is unwrapped
                to its content).

        Returns:
            A JSON string with the `MultiDoc` fields plus a "name" key.

        Raises:
            OutputParserException: If the output does not match `MultiDoc`.
            json.JSONDecodeError: If the parsed model's JSON cannot be decoded.
        """
        if isinstance(text, AIMessage):
            text = text.content
        try:
            docs = json.loads(super().parse(text).json())
        except (OutputParserException, json.JSONDecodeError):
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise
        docs["name"] = self.block_name
        return json.dumps(docs)

    def parse_combined_output(self, text: str) -> str:
        """Parse the output text from the LLM when multiple inputs are combined.

        Each non-blank line is parsed as a JSON object and the objects are
        keyed by their "name" entry (which is removed from the object).

        Arguments:
            text: The output text from the LLM.

        Returns:
            A parsed version of the text.
        """
        objs = [
            parse_json_markdown(line.strip()) for line in text.split("\n") if line.strip()
        ]
        output_obj = {d.pop("name"): d for d in objs}
        return json.dumps(output_obj)

    def get_format_instructions(self) -> str:
        """Get the format instructions for the parser.

        Returns:
            The format instructions for the LLM.
        """
        return (
            "Output must contain a sphinx-style docstring, example usage, and "
            "pseudocode, all in a json-formatted string with the following fields: "
            '"docstring", "example_usage", and "pseudocode".'
        )

    @property
    def _type(self) -> str:
        # Classes expose `__name__`; the previous `.name` raised AttributeError.
        return type(self).__name__
85
+
86
+
87
class MadlibsDocumentationParser(BaseOutputParser[str], JanusParser):
    """Parse LLM output that maps comment placeholder IDs to comment text.

    The expected IDs come from ``<BLOCK_COMMENT xxxxxxxx>`` and
    ``<INLINE_COMMENT xxxxxxxx>`` placeholders in the reference block.
    """

    # Comment IDs found in the reference block; set by `set_reference`.
    expected_keys: set[str]

    def __init__(self):
        super().__init__(expected_keys=set())

    def set_reference(self, block: CodeBlock):
        """Collect the comment placeholder IDs the LLM output must contain.

        Arguments:
            block: The reference `CodeBlock` whose text holds the placeholders.
        """
        comment_ids = re.findall(r"<(?:BLOCK|INLINE)_COMMENT (\w{8})>", block.text)
        self.expected_keys = set(comment_ids)

    def parse(self, text: str) -> str:
        """Validate the LLM's JSON object of comment IDs to comment strings.

        Unexpected keys are dropped; every expected key must be present and
        every kept value must be a string.

        Arguments:
            text: The output text from the LLM (an `AIMessage` is unwrapped
                to its content).

        Returns:
            The validated object as a JSON string.

        Raises:
            OutputParserException: On invalid JSON, a non-dict object, missing
                expected keys, or non-string values.
        """
        if isinstance(text, AIMessage):
            text = text.content
        try:
            obj = parse_json_markdown(text)
        except json.JSONDecodeError as e:
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e

        if not isinstance(obj, dict):
            raise OutputParserException(
                f"Got invalid return object. Expected a dictionary, but got {type(obj)}"
            )

        seen_keys = set(obj.keys())
        missing_keys = self.expected_keys.difference(seen_keys)
        invalid_keys = seen_keys.difference(self.expected_keys)
        if missing_keys:
            log.debug(f"Missing keys: {missing_keys}")
            if invalid_keys:
                log.debug(f"Invalid keys: {invalid_keys}")
            raise OutputParserException(
                f"Got invalid return object. Missing the following expected "
                f"keys: {missing_keys}"
            )

        # Extra keys are tolerated but discarded.
        for key in invalid_keys:
            del obj[key]

        for value in obj.values():
            if not isinstance(value, str):
                raise OutputParserException(
                    f"Got invalid return object. Expected all string values,"
                    f' but got type "{type(value)}"'
                )

        return json.dumps(obj)

    def parse_combined_output(self, text: str) -> str:
        """Parse the output text from the LLM when multiple inputs are combined.

        Each non-blank line is parsed as a JSON object and the objects are
        merged; later lines win on duplicate keys.

        Arguments:
            text: The output text from the LLM.

        Returns:
            A parsed version of the text.
        """
        if not text.strip():
            return json.dumps({})
        objs = [
            parse_json_markdown(line.strip()) for line in text.split("\n") if line.strip()
        ]
        output_obj = {}
        for obj in objs:
            output_obj.update(obj)
        return json.dumps(output_obj)

    def get_format_instructions(self) -> str:
        """Get the format instructions for the parser.

        Returns:
            The format instructions for the LLM.
        """
        return (
            "Output must contain exactly one JSON-formatted block. The JSON "
            "object should contain only (and all of) the comment IDs present "
            "in the input code."
        )

    @property
    def _type(self) -> str:
        # Classes expose `__name__`; the previous `.name` raised AttributeError.
        return type(self).__name__
@@ -0,0 +1,80 @@
1
+ import json
2
+
3
+ from langchain.output_parsers import PydanticOutputParser
4
+ from langchain_core.pydantic_v1 import BaseModel, Field, validator
5
+
6
+ from ..utils.logger import create_logger
7
+ from .code_parser import JanusParser
8
+
9
+ log = create_logger(__name__)
10
+
11
+
12
class Eval(BaseModel):
    """LLM-assigned scores (each 0-4) for a piece of code.

    Supports ``+`` (including ``sum(...)``, which starts from ``0``) and
    ``/`` so that several evaluations can be averaged.
    """

    syntax: float = Field(description="A numeric score (0-4) for code syntax")
    style: float = Field(description="A numeric score (0-4) for code style")
    completeness: float = Field(description="A numeric score (0-4) for code completeness")
    correctness: float = Field(description="A numeric score (0-4) for code correctness")

    @validator("*")
    def score_is_valid(cls, v: float | int):
        """Ensure every score is a number within [0, 4]."""
        try:
            v = float(v)
        except (TypeError, ValueError):
            raise ValueError("must be a number")

        if not 0 <= v <= 4:
            raise ValueError("must be a value between 0 and 4 inclusive")

        return v

    def __add__(self, other):
        # `sum()` starts from the int 0; treat it as the additive identity.
        if isinstance(other, int) and other == 0:
            return self.copy()
        # `construct` skips validation: summed scores may exceed 4.
        return Eval.construct(
            syntax=self.syntax + other.syntax,
            style=self.style + other.style,
            correctness=self.correctness + other.correctness,
            completeness=self.completeness + other.completeness,
        )

    def __radd__(self, other):
        return self.__add__(other)

    def __truediv__(self, other):
        # Scalar division (e.g. by a count when averaging); floats are
        # accepted too instead of falling into the element-wise branch.
        if isinstance(other, (int, float)):
            return Eval.construct(
                syntax=self.syntax / other,
                style=self.style / other,
                correctness=self.correctness / other,
                completeness=self.completeness / other,
            )
        # Otherwise divide element-wise by another Eval.
        return Eval.construct(
            syntax=self.syntax / other.syntax,
            style=self.style / other.style,
            correctness=self.correctness / other.correctness,
            completeness=self.completeness / other.completeness,
        )
58
+
59
+
60
class EvaluationParser(PydanticOutputParser, JanusParser):
    """Parse LLM code evaluations into `Eval` scores."""

    def __init__(self):
        PydanticOutputParser.__init__(self, pydantic_object=Eval)

    def parse(self, text: str) -> str:
        """Parse one evaluation.

        Arguments:
            text: The output text from the LLM.

        Returns:
            ``json.dumps`` of the model's JSON text.
        """
        evaluation = super().parse(text)
        # NOTE(review): `.json()` already returns a JSON string, so the result
        # is double-encoded; kept as-is because callers may rely on it.
        return json.dumps(evaluation.json())

    def parse_combined_output(self, text: str) -> str:
        """Parse one evaluation per line and average the scores.

        Arguments:
            text: The output text from the LLM, one evaluation per line.

        Returns:
            ``json.dumps`` of the averaged model's JSON text.
        """
        lines = [line.strip() for line in text.split("\n") if line.strip()]
        if not lines:
            # Nothing to average: let the base parser raise its usual error.
            lines = [text.strip()]
        # Call the base parser explicitly: zero-argument `super()` fails
        # inside a comprehension before Python 3.12.
        objs = [PydanticOutputParser.parse(self, line) for line in lines]
        avg_obj = sum(objs) / len(objs)
        return json.dumps(avg_obj.json())