janus-llm 1.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- janus/__init__.py +9 -1
- janus/__main__.py +4 -0
- janus/_tests/test_cli.py +128 -0
- janus/_tests/test_translate.py +49 -7
- janus/cli.py +530 -46
- janus/converter.py +50 -19
- janus/embedding/_tests/test_collections.py +2 -8
- janus/embedding/_tests/test_database.py +32 -0
- janus/embedding/_tests/test_vectorize.py +9 -4
- janus/embedding/collections.py +49 -6
- janus/embedding/embedding_models_info.py +130 -0
- janus/embedding/vectorize.py +53 -62
- janus/language/_tests/__init__.py +0 -0
- janus/language/_tests/test_combine.py +62 -0
- janus/language/_tests/test_splitter.py +16 -0
- janus/language/binary/_tests/test_binary.py +16 -1
- janus/language/binary/binary.py +10 -3
- janus/language/block.py +31 -30
- janus/language/combine.py +26 -34
- janus/language/mumps/_tests/test_mumps.py +2 -2
- janus/language/mumps/mumps.py +93 -9
- janus/language/naive/__init__.py +4 -0
- janus/language/naive/basic_splitter.py +14 -0
- janus/language/naive/chunk_splitter.py +26 -0
- janus/language/naive/registry.py +13 -0
- janus/language/naive/simple_ast.py +18 -0
- janus/language/naive/tag_splitter.py +61 -0
- janus/language/splitter.py +168 -74
- janus/language/treesitter/_tests/test_treesitter.py +19 -14
- janus/language/treesitter/treesitter.py +37 -13
- janus/llm/model_callbacks.py +177 -0
- janus/llm/models_info.py +165 -72
- janus/metrics/__init__.py +8 -0
- janus/metrics/_tests/__init__.py +0 -0
- janus/metrics/_tests/reference.py +2 -0
- janus/metrics/_tests/target.py +2 -0
- janus/metrics/_tests/test_bleu.py +56 -0
- janus/metrics/_tests/test_chrf.py +67 -0
- janus/metrics/_tests/test_file_pairing.py +59 -0
- janus/metrics/_tests/test_llm.py +91 -0
- janus/metrics/_tests/test_reading.py +28 -0
- janus/metrics/_tests/test_rouge_score.py +65 -0
- janus/metrics/_tests/test_similarity_score.py +23 -0
- janus/metrics/_tests/test_treesitter_metrics.py +110 -0
- janus/metrics/bleu.py +66 -0
- janus/metrics/chrf.py +55 -0
- janus/metrics/cli.py +7 -0
- janus/metrics/complexity_metrics.py +208 -0
- janus/metrics/file_pairing.py +113 -0
- janus/metrics/llm_metrics.py +202 -0
- janus/metrics/metric.py +466 -0
- janus/metrics/reading.py +70 -0
- janus/metrics/rouge_score.py +96 -0
- janus/metrics/similarity.py +53 -0
- janus/metrics/splitting.py +38 -0
- janus/parsers/_tests/__init__.py +0 -0
- janus/parsers/_tests/test_code_parser.py +32 -0
- janus/parsers/code_parser.py +24 -253
- janus/parsers/doc_parser.py +169 -0
- janus/parsers/eval_parser.py +80 -0
- janus/parsers/reqs_parser.py +72 -0
- janus/prompts/prompt.py +103 -30
- janus/translate.py +636 -111
- janus/utils/_tests/__init__.py +0 -0
- janus/utils/_tests/test_logger.py +67 -0
- janus/utils/_tests/test_progress.py +20 -0
- janus/utils/enums.py +56 -3
- janus/utils/progress.py +56 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/METADATA +27 -11
- janus_llm-2.0.1.dist-info/RECORD +94 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/WHEEL +1 -1
- janus_llm-1.0.0.dist-info/RECORD +0 -48
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/LICENSE +0 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,38 @@
|
|
1
|
+
from typing import Callable
|
2
|
+
|
3
|
+
SPLITTING_METHODS: dict[str, Callable[[str, str], list[str]]] = {}
|
4
|
+
|
5
|
+
|
6
|
+
def register_splitting_method(name: None | str = None) -> Callable[[Callable], Callable]:
|
7
|
+
"""Registers a pairing method for splitting strings in files
|
8
|
+
|
9
|
+
Arguments:
|
10
|
+
name: The name of the splitting method. If None, the function name is used.
|
11
|
+
help: The help text for the pairing method.
|
12
|
+
|
13
|
+
Returns:
|
14
|
+
The decorator function.
|
15
|
+
"""
|
16
|
+
|
17
|
+
def decorator(f: Callable[[str, str], list[tuple[str, str]]]):
|
18
|
+
if name is None:
|
19
|
+
splitting_name = f.__name__
|
20
|
+
else:
|
21
|
+
splitting_name = name
|
22
|
+
SPLITTING_METHODS[splitting_name] = f
|
23
|
+
return f
|
24
|
+
|
25
|
+
return decorator
|
26
|
+
|
27
|
+
|
28
|
+
@register_splitting_method(name="file")
def split_by_file(src: str, **kwargs) -> list[str]:
    """Treat the entire source text as one chunk.

    Arguments:
        src: The source text.
        kwargs: Accepted but not used by this method.

    Returns:
        A single-element list holding ``src`` unchanged.
    """
    chunks = [src]
    return chunks
|
File without changes
|
@@ -0,0 +1,32 @@
|
|
1
|
+
import unittest
|
2
|
+
|
3
|
+
from ..code_parser import CodeParser, JanusParser
|
4
|
+
|
5
|
+
|
6
|
+
class TestJanusParser(unittest.TestCase):
    """Tests for the base ``JanusParser`` hooks."""

    def setUp(self):
        self.parser = JanusParser()

    def test_parse_combined_output(self):
        sample = "test text"
        result = self.parser.parse_combined_output(sample)
        self.assertEqual(result, sample)
|
13
|
+
|
14
|
+
|
15
|
+
class TestCodeParser(unittest.TestCase):
    """Tests for ``CodeParser`` extraction of fenced code."""

    def setUp(self):
        self.parser = CodeParser(language="python")

    def test_parse(self):
        self.parser.language = "python"
        fenced = "```\n# test text\n```"
        expected = fenced.strip("```").strip("\n")
        self.assertEqual(self.parser.parse(fenced), expected)

    def test_get_format_instructions(self):
        expected = (
            "Output must contain text contained within triple square brackets (```)"
        )
        self.assertEqual(self.parser.get_format_instructions(), expected)


if __name__ == "__main__":
    unittest.main()
|
janus/parsers/code_parser.py
CHANGED
@@ -1,32 +1,17 @@
|
|
1
|
-
import json
|
2
1
|
import re
|
3
|
-
from collections import defaultdict
|
4
|
-
from typing import Any, Set
|
5
2
|
|
6
3
|
from langchain.schema.output_parser import BaseOutputParser
|
4
|
+
from langchain_core.exceptions import OutputParserException
|
5
|
+
from langchain_core.messages import BaseMessage
|
6
|
+
from langchain_core.output_parsers import StrOutputParser
|
7
7
|
|
8
8
|
from ..language.block import CodeBlock
|
9
|
-
from ..language.combine import Combiner
|
10
9
|
from ..utils.logger import create_logger
|
11
10
|
|
12
11
|
log = create_logger(__name__)
|
13
12
|
|
14
13
|
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
class JanusParser:
    """Janus-specific parsing hooks, mixed into langchain output parsers.

    Every method accepts either a plain string or a langchain ``BaseMessage``;
    for a message, its ``content`` is used.
    """

    def parse_combined_output(self, text: str | BaseMessage) -> str:
        """Parse the output text from the LLM when multiple inputs are combined.

        The base implementation returns the text unchanged.

        Arguments:
            text: The output text from the LLM, or a message wrapping it.

        Returns:
            A parsed version of the text.
        """
        if isinstance(text, BaseMessage):
            text = text.content
        return text

    def parse_into_block(self, text: str | BaseMessage, block: CodeBlock) -> None:
        """Store the LLM output on ``block`` by assigning it to ``block.text``.

        Arguments:
            text: The output text from the LLM, or a message wrapping it.
            block: The code block to update in place.
        """
        if isinstance(text, BaseMessage):
            text = text.content
        block.text = text

    def set_reference(self, block: CodeBlock) -> None:
        """Hook for parsers that need to inspect the input block before parsing.

        The base implementation does nothing; subclasses may override it.

        Arguments:
            block: The input code block.
        """
        pass
|
37
|
+
class GenericParser(StrOutputParser, JanusParser):
    """Parser that passes the LLM's text through unchanged."""

    def parse(self, text: str) -> str:
        # Chat models may hand back a message object; unwrap it to its content.
        if not isinstance(text, BaseMessage):
            return text
        return text.content
|
62
42
|
|
63
43
|
|
64
|
-
class CodeParser(BaseOutputParser[str], JanusParser):
    """Extract code from a triple-backtick fenced block in LLM output."""

    # Language tag that may directly follow the opening fence (e.g. ```python).
    language: str

    def parse(self, text: str) -> str:
        """Return the contents of the first fenced code block in ``text``.

        An optional tag equal to ``self.language`` right after the opening
        fence is skipped, as are newlines just before the closing fence.

        Arguments:
            text: The output from the LLM, as a string or a message object.

        Returns:
            The code between the first pair of triple backticks.

        Raises:
            OutputParserException: If no fenced code block is found.
        """
        if isinstance(text, BaseMessage):
            text = text.content
        # Escape the tag: names such as "c++" contain regex metacharacters
        # that would otherwise corrupt (or invalidate) the pattern.
        language = re.escape(self.language)
        pattern = rf"```[^\S\r\n]*(?:{language}[^\S\r\n]*)?\n?(.*?)\n*```"
        code = re.search(pattern, text, re.DOTALL)
        if code is None:
            raise OutputParserException("Could not find code between triple backticks")
        return str(code.group(1))

    def get_format_instructions(self) -> str:
        """Get the format instructions for the parser.

        Returns:
            The format instructions for the LLM.
        """
        # NOTE(review): ``` are backticks, not square brackets, but
        # test_code_parser asserts this exact string -- update both together.
        return "Output must contain text contained within triple square brackets (```)"
|
@@ -0,0 +1,169 @@
|
|
1
|
+
import json
|
2
|
+
import re
|
3
|
+
|
4
|
+
from langchain.output_parsers import PydanticOutputParser
|
5
|
+
from langchain.output_parsers.json import parse_json_markdown
|
6
|
+
from langchain.schema.output_parser import BaseOutputParser
|
7
|
+
from langchain_core.exceptions import OutputParserException
|
8
|
+
from langchain_core.messages import AIMessage
|
9
|
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
10
|
+
|
11
|
+
from ..language.block import CodeBlock
|
12
|
+
from ..utils.logger import create_logger
|
13
|
+
from .code_parser import JanusParser
|
14
|
+
|
15
|
+
log = create_logger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
class MultiDoc(BaseModel):
    """Schema for the structured documentation the LLM is asked to produce."""

    docstring: str = Field(
        description="A Sphinx-style docstring for the code, including a summary "
        "of its functionality; the name, type, and description of "
        "any parameters or returns; and any potential exceptions "
        "that might arise in its execution"
    )
    example_usage: str = Field(
        description="A well-commented minimal example utilizing the given "
        "code's functionality"
    )
    pseudocode: str = Field(
        description="A Python-style pseudocode implementation of the module or "
        "function's behavior"
    )
|
33
|
+
|
34
|
+
|
35
|
+
class MultiDocumentationParser(PydanticOutputParser, JanusParser):
    """Parse LLM output into ``MultiDoc`` fields, labeled with the block name."""

    # Name of the block being documented; set by set_reference and attached to
    # each parsed result under the "name" key.
    block_name: str = ""

    def __init__(self):
        PydanticOutputParser.__init__(self, pydantic_object=MultiDoc)

    def set_reference(self, block: CodeBlock):
        """Remember ``block.name`` so parsed output can be labeled with it."""
        self.block_name = block.name

    def parse(self, text: str) -> str:
        """Validate the LLM output against ``MultiDoc`` and tag it with the block name.

        Arguments:
            text: The output text from the LLM, or an ``AIMessage`` wrapping it.

        Returns:
            A JSON string with the ``MultiDoc`` fields plus a ``"name"`` key
            holding the name recorded by ``set_reference``.

        Raises:
            OutputParserException, json.JSONDecodeError: Re-raised (after a
                debug log of the raw output) when the output cannot be parsed.
        """
        if isinstance(text, AIMessage):
            text = text.content
        try:
            docs = json.loads(super().parse(text).json())
        except (OutputParserException, json.JSONDecodeError):
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise
        docs["name"] = self.block_name
        return json.dumps(docs)

    def parse_combined_output(self, text: str) -> str:
        """Parse the output text from the LLM when multiple inputs are combined.

        Each non-blank line is parsed as a JSON object (presumably one ``parse``
        result per line) and keyed by its ``"name"`` field, which is removed
        from the object.

        Arguments:
            text: The output text from the LLM, or an ``AIMessage`` wrapping it.

        Returns:
            A JSON string mapping each block name to its documentation fields.
        """
        if isinstance(text, AIMessage):
            text = text.content
        objs = [
            parse_json_markdown(line.strip()) for line in text.split("\n") if line.strip()
        ]
        output_obj = {d.pop("name"): d for d in objs}
        return json.dumps(output_obj)

    def get_format_instructions(self) -> str:
        """Get the format instructions for the parser.

        Returns:
            The format instructions for the LLM.
        """
        return (
            "Output must contain a sphinx-style docstring, example usage, and "
            "pseudocode, all in a json-formatted string with the following fields: "
            '"docstring", "example_usage", and "pseudocode".'
        )

    @property
    def _type(self) -> str:
        # Classes expose ``__name__``; ``__class__.name`` raised AttributeError.
        return self.__class__.__name__
|
85
|
+
|
86
|
+
|
87
|
+
class MadlibsDocumentationParser(BaseOutputParser[str], JanusParser):
    """Parse LLM output mapping comment-placeholder IDs to comment text."""

    # Comment IDs the LLM must fill in; populated from the input block by
    # set_reference.
    expected_keys: set[str]

    def __init__(self):
        super().__init__(expected_keys=set())

    def set_reference(self, block: CodeBlock):
        """Collect the 8-character IDs of comment placeholders in ``block.text``.

        Placeholders look like ``<BLOCK_COMMENT xxxxxxxx>`` or
        ``<INLINE_COMMENT xxxxxxxx>``.
        """
        comment_ids = re.findall(r"<(?:BLOCK|INLINE)_COMMENT (\w{8})>", block.text)
        self.expected_keys = set(comment_ids)

    def parse(self, text: str) -> str:
        """Parse and validate a JSON object of comment ID -> comment text.

        Keys not in ``expected_keys`` are dropped from the result.

        Arguments:
            text: The output text from the LLM, or an ``AIMessage`` wrapping it.

        Returns:
            The validated object, serialized as a JSON string.

        Raises:
            OutputParserException: If the output is not valid JSON, is not an
                object, is missing an expected ID, or has a non-string value.
        """
        if isinstance(text, AIMessage):
            text = text.content
        try:
            obj = parse_json_markdown(text)
        except json.JSONDecodeError as e:
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e

        if not isinstance(obj, dict):
            raise OutputParserException(
                f"Got invalid return object. Expected a dictionary, but got {type(obj)}"
            )

        seen_keys = set(obj.keys())
        missing_keys = self.expected_keys.difference(seen_keys)
        invalid_keys = seen_keys.difference(self.expected_keys)
        if invalid_keys:
            log.debug(f"Invalid keys: {invalid_keys}")
        if missing_keys:
            log.debug(f"Missing keys: {missing_keys}")
            raise OutputParserException(
                f"Got invalid return object. Missing the following expected "
                f"keys: {missing_keys}"
            )

        # Unexpected keys are tolerated but removed from the result.
        for key in invalid_keys:
            del obj[key]

        for value in obj.values():
            if not isinstance(value, str):
                raise OutputParserException(
                    f"Got invalid return object. Expected all string values,"
                    f' but got type "{type(value)}"'
                )

        return json.dumps(obj)

    def parse_combined_output(self, text: str) -> str:
        """Parse the output text from the LLM when multiple inputs are combined.

        Each non-blank line is parsed as a JSON object and all objects are
        merged; later lines win on duplicate keys.

        Arguments:
            text: The output text from the LLM, or an ``AIMessage`` wrapping it.

        Returns:
            The merged object as a JSON string (``"{}"`` for blank input).
        """
        if isinstance(text, AIMessage):
            text = text.content
        if not text.strip():
            return json.dumps({})
        objs = [
            parse_json_markdown(line.strip()) for line in text.split("\n") if line.strip()
        ]
        output_obj = {}
        for obj in objs:
            output_obj.update(obj)
        return json.dumps(output_obj)

    def get_format_instructions(self) -> str:
        """Get the format instructions for the parser.

        Returns:
            The format instructions for the LLM.
        """
        return (
            "Output must contain exactly one JSON-formatted block. The JSON "
            "object should contain only (and all of) the comment IDs present "
            "in the input code."
        )

    @property
    def _type(self) -> str:
        # Classes expose ``__name__``; ``__class__.name`` raised AttributeError.
        return self.__class__.__name__
|
@@ -0,0 +1,80 @@
|
|
1
|
+
import json
|
2
|
+
|
3
|
+
from langchain.output_parsers import PydanticOutputParser
|
4
|
+
from langchain_core.pydantic_v1 import BaseModel, Field, validator
|
5
|
+
|
6
|
+
from ..utils.logger import create_logger
|
7
|
+
from .code_parser import JanusParser
|
8
|
+
|
9
|
+
log = create_logger(__name__)
|
10
|
+
|
11
|
+
|
12
|
+
class Eval(BaseModel):
    """Scores (0-4 each) that an LLM assigns to a piece of code.

    Instances support ``+`` (including ``sum()``, whose start value is ``0``)
    and division by a number or by another ``Eval``, so a list of evaluations
    can be averaged with ``sum(evals) / len(evals)``.
    """

    syntax: float = Field(description="A numeric score (0-4) for code syntax")
    style: float = Field(description="A numeric score (0-4) for code style")
    completeness: float = Field(description="A numeric score (0-4) for code completeness")
    correctness: float = Field(description="A numeric score (0-4) for code correctness")

    @validator("*")
    def score_is_valid(cls, v: float | int):
        """Coerce a score to ``float`` and require it to lie within [0, 4]."""
        try:
            v = float(v)
        except (TypeError, ValueError) as err:
            # float() raises TypeError (not ValueError) for inputs such as None.
            raise ValueError("must be a number") from err

        if not 0 <= v <= 4:
            raise ValueError("must be a value between 0 and 4 inclusive")

        return v

    def __add__(self, other):
        # sum() starts from the integer 0; treat it as the additive identity.
        if isinstance(other, int) and other == 0:
            return self.copy()
        return Eval.construct(
            syntax=self.syntax + other.syntax,
            style=self.style + other.style,
            correctness=self.correctness + other.correctness,
            completeness=self.completeness + other.completeness,
        )

    def __radd__(self, other):
        return self.__add__(other)

    def __truediv__(self, other):
        # Scalar divisor (e.g. a count when averaging): divide every score.
        if isinstance(other, (int, float)):
            return Eval.construct(
                syntax=self.syntax / other,
                style=self.style / other,
                correctness=self.correctness / other,
                completeness=self.completeness / other,
            )
        # Otherwise divide field-by-field by another Eval.
        return Eval.construct(
            syntax=self.syntax / other.syntax,
            style=self.style / other.style,
            correctness=self.correctness / other.correctness,
            completeness=self.completeness / other.completeness,
        )
|
58
|
+
|
59
|
+
|
60
|
+
class EvaluationParser(PydanticOutputParser, JanusParser):
    """Parse LLM code evaluations into serialized ``Eval`` scores."""

    def __init__(self):
        PydanticOutputParser.__init__(self, pydantic_object=Eval)

    def parse(self, text: str) -> str:
        """Validate the LLM output as an ``Eval`` and serialize it.

        Arguments:
            text: The output text from the LLM.

        Returns:
            ``json.dumps`` applied to ``Eval.json()``.
        """
        evaluation = super().parse(text)
        # NOTE(review): Eval.json() already returns a JSON string, so this
        # double-encodes it (decoding once yields a str, not a dict). Kept to
        # preserve the existing output format -- confirm this is intended.
        return json.dumps(evaluation.json())

    def parse_combined_output(self, text: str) -> str:
        """Parse one evaluation per non-blank line and average the scores.

        Arguments:
            text: The output text from the LLM, one evaluation per line.

        Returns:
            ``json.dumps`` applied to the averaged ``Eval``'s ``json()``.

        Raises:
            OutputParserException: If ``text`` has no non-blank lines.
        """
        from langchain_core.exceptions import OutputParserException

        # Bind the base-class parse outside the comprehension: zero-argument
        # super() fails inside a comprehension's nested scope on Python < 3.12.
        parse_eval = super().parse
        objs = [parse_eval(line.strip()) for line in text.split("\n") if line.strip()]
        if not objs:
            raise OutputParserException("No evaluations found to combine")
        avg_obj = sum(objs) / len(objs)
        return json.dumps(avg_obj.json())
|