janus-llm 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +9 -1
- janus/__main__.py +4 -0
- janus/_tests/test_cli.py +128 -0
- janus/_tests/test_translate.py +49 -7
- janus/cli.py +530 -46
- janus/converter.py +50 -19
- janus/embedding/_tests/test_collections.py +2 -8
- janus/embedding/_tests/test_database.py +32 -0
- janus/embedding/_tests/test_vectorize.py +9 -4
- janus/embedding/collections.py +49 -6
- janus/embedding/embedding_models_info.py +120 -0
- janus/embedding/vectorize.py +53 -62
- janus/language/_tests/__init__.py +0 -0
- janus/language/_tests/test_combine.py +62 -0
- janus/language/_tests/test_splitter.py +16 -0
- janus/language/binary/_tests/test_binary.py +16 -1
- janus/language/binary/binary.py +10 -3
- janus/language/block.py +31 -30
- janus/language/combine.py +26 -34
- janus/language/mumps/_tests/test_mumps.py +2 -2
- janus/language/mumps/mumps.py +93 -9
- janus/language/naive/__init__.py +4 -0
- janus/language/naive/basic_splitter.py +14 -0
- janus/language/naive/chunk_splitter.py +26 -0
- janus/language/naive/registry.py +13 -0
- janus/language/naive/simple_ast.py +18 -0
- janus/language/naive/tag_splitter.py +61 -0
- janus/language/splitter.py +168 -74
- janus/language/treesitter/_tests/test_treesitter.py +9 -6
- janus/language/treesitter/treesitter.py +37 -13
- janus/llm/model_callbacks.py +177 -0
- janus/llm/models_info.py +134 -70
- janus/metrics/__init__.py +8 -0
- janus/metrics/_tests/__init__.py +0 -0
- janus/metrics/_tests/reference.py +2 -0
- janus/metrics/_tests/target.py +2 -0
- janus/metrics/_tests/test_bleu.py +56 -0
- janus/metrics/_tests/test_chrf.py +67 -0
- janus/metrics/_tests/test_file_pairing.py +59 -0
- janus/metrics/_tests/test_llm.py +91 -0
- janus/metrics/_tests/test_reading.py +28 -0
- janus/metrics/_tests/test_rouge_score.py +65 -0
- janus/metrics/_tests/test_similarity_score.py +23 -0
- janus/metrics/_tests/test_treesitter_metrics.py +110 -0
- janus/metrics/bleu.py +66 -0
- janus/metrics/chrf.py +55 -0
- janus/metrics/cli.py +7 -0
- janus/metrics/complexity_metrics.py +208 -0
- janus/metrics/file_pairing.py +113 -0
- janus/metrics/llm_metrics.py +202 -0
- janus/metrics/metric.py +466 -0
- janus/metrics/reading.py +70 -0
- janus/metrics/rouge_score.py +96 -0
- janus/metrics/similarity.py +53 -0
- janus/metrics/splitting.py +38 -0
- janus/parsers/_tests/__init__.py +0 -0
- janus/parsers/_tests/test_code_parser.py +32 -0
- janus/parsers/code_parser.py +24 -253
- janus/parsers/doc_parser.py +169 -0
- janus/parsers/eval_parser.py +80 -0
- janus/parsers/reqs_parser.py +72 -0
- janus/prompts/prompt.py +103 -30
- janus/translate.py +636 -111
- janus/utils/_tests/__init__.py +0 -0
- janus/utils/_tests/test_logger.py +67 -0
- janus/utils/_tests/test_progress.py +20 -0
- janus/utils/enums.py +56 -3
- janus/utils/progress.py +56 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/METADATA +23 -10
- janus_llm-2.0.0.dist-info/RECORD +94 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/WHEEL +1 -1
- janus_llm-1.0.0.dist-info/RECORD +0 -48
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,38 @@
|
|
1
|
+
from typing import Callable
|
2
|
+
|
3
|
+
SPLITTING_METHODS: dict[str, Callable[[str, str], list[str]]] = {}
|
4
|
+
|
5
|
+
|
6
|
+
def register_splitting_method(name: None | str = None) -> Callable[[Callable], Callable]:
|
7
|
+
"""Registers a pairing method for splitting strings in files
|
8
|
+
|
9
|
+
Arguments:
|
10
|
+
name: The name of the splitting method. If None, the function name is used.
|
11
|
+
help: The help text for the pairing method.
|
12
|
+
|
13
|
+
Returns:
|
14
|
+
The decorator function.
|
15
|
+
"""
|
16
|
+
|
17
|
+
def decorator(f: Callable[[str, str], list[tuple[str, str]]]):
|
18
|
+
if name is None:
|
19
|
+
splitting_name = f.__name__
|
20
|
+
else:
|
21
|
+
splitting_name = name
|
22
|
+
SPLITTING_METHODS[splitting_name] = f
|
23
|
+
return f
|
24
|
+
|
25
|
+
return decorator
|
26
|
+
|
27
|
+
|
28
|
+
@register_splitting_method(name="file")
def split_by_file(src: str, **kwargs) -> list[str]:
    """Treat the whole source text as a single chunk.

    Arguments:
        src: The source text.
        kwargs: Ignored by this method.

    Returns:
        A one-element list holding ``src`` unchanged.
    """
    chunks = [src]
    return chunks
|
File without changes
|
@@ -0,0 +1,32 @@
|
|
1
|
+
import unittest
|
2
|
+
|
3
|
+
from ..code_parser import CodeParser, JanusParser
|
4
|
+
|
5
|
+
|
6
|
+
class TestJanusParser(unittest.TestCase):
    """Checks for the base ``JanusParser`` behavior."""

    def setUp(self):
        self.parser = JanusParser()

    def test_parse_combined_output(self):
        # A plain string is expected to come back unchanged.
        raw = "test text"
        parsed = self.parser.parse_combined_output(raw)
        self.assertEqual(parsed, raw)
|
13
|
+
|
14
|
+
|
15
|
+
class TestCodeParser(unittest.TestCase):
    """Checks for ``CodeParser`` extraction of fenced code."""

    def setUp(self):
        self.parser = CodeParser(language="python")

    def test_parse(self):
        self.parser.language = "python"
        fenced = "```\n# test text\n```"
        # Expected result: the fence contents without backticks or the
        # surrounding newlines.
        expected = fenced.strip("```").strip("\n")
        parsed = self.parser.parse(fenced)
        self.assertEqual(parsed, expected)

    def test_get_format_instructions(self):
        expected = "Output must contain text contained within triple square brackets (```)"
        actual = self.parser.get_format_instructions()
        self.assertEqual(actual, expected)


# This module uses relative imports, so run it as a module, e.g.
# ``python -m janus.parsers._tests.test_code_parser``.
if __name__ == "__main__":
    unittest.main()
|
janus/parsers/code_parser.py
CHANGED
@@ -1,32 +1,17 @@
|
|
1
|
-
import json
|
2
1
|
import re
|
3
|
-
from collections import defaultdict
|
4
|
-
from typing import Any, Set
|
5
2
|
|
6
3
|
from langchain.schema.output_parser import BaseOutputParser
|
4
|
+
from langchain_core.exceptions import OutputParserException
|
5
|
+
from langchain_core.messages import BaseMessage
|
6
|
+
from langchain_core.output_parsers import StrOutputParser
|
7
7
|
|
8
8
|
from ..language.block import CodeBlock
|
9
|
-
from ..language.combine import Combiner
|
10
9
|
from ..utils.logger import create_logger
|
11
10
|
|
12
11
|
log = create_logger(__name__)
|
13
12
|
|
14
13
|
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
class JanusParser(BaseOutputParser):
|
19
|
-
def parse(self, text: str) -> str:
|
20
|
-
"""Parse the output text from the LLM.
|
21
|
-
|
22
|
-
Arguments:
|
23
|
-
text: The output text from the LLM
|
24
|
-
|
25
|
-
Returns:
|
26
|
-
A parsed version of the text
|
27
|
-
"""
|
28
|
-
return text
|
29
|
-
|
14
|
+
class JanusParser:
|
30
15
|
def parse_combined_output(self, text: str) -> str:
|
31
16
|
"""Parse the output text from the LLM when multiple inputs are combined
|
32
17
|
|
@@ -36,253 +21,39 @@ class JanusParser(BaseOutputParser):
|
|
36
21
|
Returns:
|
37
22
|
A parsed version of the text
|
38
23
|
"""
|
24
|
+
if isinstance(text, BaseMessage):
|
25
|
+
text = text.content
|
39
26
|
return text
|
40
27
|
|
41
|
-
def
|
42
|
-
|
43
|
-
|
28
|
+
def parse_into_block(self, text: str, block: CodeBlock):
|
29
|
+
if isinstance(text, BaseMessage):
|
30
|
+
text = text.content
|
31
|
+
block.text = text
|
44
32
|
|
45
|
-
|
46
|
-
|
47
|
-
output_text: The parsed text returned by the LLM
|
33
|
+
def set_reference(self, block: CodeBlock):
|
34
|
+
pass
|
48
35
|
|
49
|
-
Returns:
|
50
|
-
A score between 0 and 1 (inclusive). A score of 1.0 indicates that
|
51
|
-
the given text is fully acceptable, and no further attempts
|
52
|
-
should be made.
|
53
|
-
"""
|
54
|
-
return 1.0
|
55
36
|
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
return type(self).__name__
|
37
|
+
class GenericParser(StrOutputParser, JanusParser):
|
38
|
+
def parse(self, text: str) -> str:
|
39
|
+
if isinstance(text, BaseMessage):
|
40
|
+
return text.content
|
41
|
+
return text
|
62
42
|
|
63
43
|
|
64
|
-
class CodeParser(JanusParser):
|
44
|
+
class CodeParser(BaseOutputParser[str], JanusParser):
|
65
45
|
language: str
|
66
46
|
|
67
47
|
def parse(self, text: str) -> str:
|
68
|
-
|
69
|
-
|
70
|
-
Arguments:
|
71
|
-
text: The output text from the LLM
|
72
|
-
|
73
|
-
Returns:
|
74
|
-
A parsed version of the text
|
75
|
-
"""
|
48
|
+
if isinstance(text, BaseMessage):
|
49
|
+
text = text.content
|
76
50
|
pattern = rf"```[^\S\r\n]*(?:{self.language}[^\S\r\n]*)?\n?(.*?)\n*```"
|
77
51
|
code = re.search(pattern, text, re.DOTALL)
|
78
52
|
if code is None:
|
79
|
-
raise
|
80
|
-
|
81
|
-
|
82
|
-
def score(self, input_block: CodeBlock, output_text: str) -> float:
|
83
|
-
"""The score for translated code is the percentage of this block's
|
84
|
-
children which are present in the output
|
85
|
-
|
86
|
-
Arguments:
|
87
|
-
input_block: A `CodeBlock` representing the input to the LLM
|
88
|
-
output_text: The parsed text returned by the LLM
|
89
|
-
|
90
|
-
Returns:
|
91
|
-
A score between 0 and 1 (inclusive). A score of 1.0 indicates that
|
92
|
-
the given text is fully acceptable, and no further attempts
|
93
|
-
should be made.
|
94
|
-
"""
|
95
|
-
if not input_block.children:
|
96
|
-
return 1.0
|
97
|
-
|
98
|
-
missing_children = []
|
99
|
-
for child in input_block.children:
|
100
|
-
if not Combiner.contains_child(output_text, child):
|
101
|
-
missing_children.append(child.id)
|
102
|
-
|
103
|
-
if missing_children:
|
104
|
-
log.warning(
|
105
|
-
f"[{input_block.name}] Child placeholders not present in text: "
|
106
|
-
f"{missing_children}"
|
53
|
+
raise OutputParserException(
|
54
|
+
"Code not find code between triple square brackets"
|
107
55
|
)
|
108
|
-
|
109
|
-
|
110
|
-
return 1.0 - len(missing_children) / len(input_block.children)
|
111
|
-
|
112
|
-
def get_format_instructions(self) -> str:
|
113
|
-
return "Output must contain text contained within triple backticks."
|
114
|
-
|
115
|
-
|
116
|
-
class JsonLinesParser(JanusParser):
|
117
|
-
def parse(self, text: str) -> str:
|
118
|
-
"""Parse the output text from the LLM.
|
119
|
-
|
120
|
-
Arguments:
|
121
|
-
text: The output text from the LLM.
|
122
|
-
|
123
|
-
Returns:
|
124
|
-
A parsed version of the text.
|
125
|
-
"""
|
126
|
-
string = r"\"\w+\""
|
127
|
-
number = r"-?\d+(?:\.\d*)?"
|
128
|
-
json_value = rf"(?:{string}|{number})"
|
129
|
-
json_line = rf"\s*{string} *: *{json_value},?\s*"
|
130
|
-
pattern = "({" + rf"(?:{json_line})+" + "})"
|
131
|
-
matches = list(re.finditer(pattern, text, re.DOTALL))
|
132
|
-
if not matches:
|
133
|
-
raise ValueError("Could not find JSON output")
|
134
|
-
|
135
|
-
output_strings = [json.dumps(json.loads(match.group(1))) for match in matches]
|
136
|
-
return "\n".join(output_strings)
|
137
|
-
|
138
|
-
def parse_combined_output(self, text: str) -> str:
|
139
|
-
"""Parse the output text from the LLM when multiple inputs are combined.
|
140
|
-
|
141
|
-
Arguments:
|
142
|
-
text: The output text from the LLM.
|
143
|
-
|
144
|
-
Returns:
|
145
|
-
A parsed version of the text.
|
146
|
-
"""
|
147
|
-
return self.parse(text)
|
148
|
-
|
149
|
-
def get_format_instructions(self) -> str:
|
150
|
-
"""Get the format instructions for the parser.
|
151
|
-
|
152
|
-
Returns:
|
153
|
-
The format instructions for the LLM.
|
154
|
-
"""
|
155
|
-
return "Output must contain one or more JSON-formatted blocks."
|
156
|
-
|
157
|
-
|
158
|
-
class JsonParser(JsonLinesParser):
|
159
|
-
def parse(self, text: str) -> str:
|
160
|
-
"""Parse the output text from the LLM.
|
161
|
-
|
162
|
-
Arguments:
|
163
|
-
text: The output text from the LLM.
|
164
|
-
|
165
|
-
Returns:
|
166
|
-
A parsed version of the text.
|
167
|
-
"""
|
168
|
-
jsonl_text = super().parse(text)
|
169
|
-
if len(jsonl_text.split("\n")) > 1:
|
170
|
-
raise ValueError("Multiple JSON objects found")
|
171
|
-
|
172
|
-
return jsonl_text
|
173
|
-
|
174
|
-
def parse_combined_output(self, text: str) -> str:
|
175
|
-
"""Parse the output text from the LLM when multiple inputs are combined.
|
176
|
-
|
177
|
-
Arguments:
|
178
|
-
text: The output text from the LLM.
|
179
|
-
|
180
|
-
Returns:
|
181
|
-
A parsed version of the text.
|
182
|
-
"""
|
183
|
-
jsonl_text = JsonLinesParser.parse(self, text)
|
184
|
-
json_lines = jsonl_text.split("\n")
|
185
|
-
output_obj = {i: json.loads(t) for i, t in enumerate(json_lines)}
|
186
|
-
return json.dumps(output_obj)
|
56
|
+
return str(code.group(1))
|
187
57
|
|
188
58
|
def get_format_instructions(self) -> str:
|
189
|
-
"
|
190
|
-
|
191
|
-
Returns:
|
192
|
-
The format instructions for the LLM.
|
193
|
-
"""
|
194
|
-
return "Output must contain exactly one JSON-formatted block."
|
195
|
-
|
196
|
-
|
197
|
-
class EvaluationParser(JsonParser):
|
198
|
-
expected_keys: Set[str]
|
199
|
-
|
200
|
-
def __init__(self, expected_keys: Set[str], **kwargs: Any):
|
201
|
-
"""Create a new EvaluationParser.
|
202
|
-
|
203
|
-
Arguments:
|
204
|
-
expected_keys: The set of keys that should be present in the JSON
|
205
|
-
object
|
206
|
-
kwargs: Additional arguments to pass to the parent class
|
207
|
-
"""
|
208
|
-
super().__init__(expected_keys=expected_keys, **kwargs)
|
209
|
-
self.expected_keys = {k.lower() for k in expected_keys}
|
210
|
-
|
211
|
-
def parse(self, text: str) -> str:
|
212
|
-
"""Parse the JSON object, convert keys to lowercase, filter out
|
213
|
-
unexpected keys
|
214
|
-
|
215
|
-
Arguments:
|
216
|
-
text: The output text from the LLM.
|
217
|
-
|
218
|
-
Returns:
|
219
|
-
A parsed version of the text.
|
220
|
-
"""
|
221
|
-
json_text = super().parse(text)
|
222
|
-
obj = json.loads(json_text)
|
223
|
-
obj = {k.lower(): v for k, v in obj.items()}
|
224
|
-
obj = {k: v for k, v in obj.items() if k in self.expected_keys}
|
225
|
-
return json.dumps(obj)
|
226
|
-
|
227
|
-
def parse_combined_output(self, text: str) -> str:
|
228
|
-
"""Parse the JSON object, convert keys to lowercase, filter out
|
229
|
-
unexpected keys, and average the values
|
230
|
-
|
231
|
-
Arguments:
|
232
|
-
text: The output text from the LLM.
|
233
|
-
|
234
|
-
Returns:
|
235
|
-
A parsed version of the text.
|
236
|
-
"""
|
237
|
-
json_text = super().parse_combined_output(text)
|
238
|
-
multi_obj = json.loads(json_text)
|
239
|
-
n_evals = len(multi_obj)
|
240
|
-
|
241
|
-
output_obj = defaultdict(float)
|
242
|
-
for obj in multi_obj.values():
|
243
|
-
for k, v in obj.items():
|
244
|
-
output_obj[k] += v / n_evals
|
245
|
-
|
246
|
-
return json.dumps(output_obj)
|
247
|
-
|
248
|
-
def score(self, input_block: CodeBlock, output_text: str) -> float:
|
249
|
-
"""The score for the output text is the percentage of expected keys
|
250
|
-
that are present in the json object. Non-numeric values count for
|
251
|
-
half.
|
252
|
-
|
253
|
-
Arguments:
|
254
|
-
input_block: A `CodeBlock` representing the input to the LLM
|
255
|
-
output_text: The parsed text returned by the LLM
|
256
|
-
|
257
|
-
Returns:
|
258
|
-
A score between 0 and 1 (inclusive). A score of 1.0 indicates that
|
259
|
-
the given text is fully acceptable, and no further attempts
|
260
|
-
should be made.
|
261
|
-
"""
|
262
|
-
obj = json.loads(output_text)
|
263
|
-
|
264
|
-
expected_keys = self.expected_keys.intersection(obj.keys())
|
265
|
-
missing_keys = self.expected_keys.difference(obj.keys())
|
266
|
-
if missing_keys:
|
267
|
-
log.warning(f"[{input_block.name}] Expected keys missing: {missing_keys}")
|
268
|
-
|
269
|
-
non_numerics = {k: v for k, v in obj.items() if not isinstance(v, (int, float))}
|
270
|
-
if non_numerics:
|
271
|
-
log.warning(f"[{input_block.name}] Non-numeric values: {non_numerics}")
|
272
|
-
|
273
|
-
if missing_keys or non_numerics:
|
274
|
-
log.debug(f"Text:\n{output_text}")
|
275
|
-
|
276
|
-
return (len(expected_keys) - len(non_numerics) * 0.5) / len(self.expected_keys)
|
277
|
-
|
278
|
-
def get_format_instructions(self) -> str:
|
279
|
-
"""Get the format instructions for the parser.
|
280
|
-
|
281
|
-
Returns:
|
282
|
-
The format instructions for the LLM.
|
283
|
-
"""
|
284
|
-
return (
|
285
|
-
"Output must contain exactly one JSON-formatted block. The JSON "
|
286
|
-
"object should contain only the keys contained in the provided "
|
287
|
-
"expected_keys set (if any), and values should be numeric."
|
288
|
-
)
|
59
|
+
return "Output must contain text contained within triple square brackets (```)"
|
@@ -0,0 +1,169 @@
|
|
1
|
+
import json
|
2
|
+
import re
|
3
|
+
|
4
|
+
from langchain.output_parsers import PydanticOutputParser
|
5
|
+
from langchain.output_parsers.json import parse_json_markdown
|
6
|
+
from langchain.schema.output_parser import BaseOutputParser
|
7
|
+
from langchain_core.exceptions import OutputParserException
|
8
|
+
from langchain_core.messages import AIMessage
|
9
|
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
10
|
+
|
11
|
+
from ..language.block import CodeBlock
|
12
|
+
from ..utils.logger import create_logger
|
13
|
+
from .code_parser import JanusParser
|
14
|
+
|
15
|
+
log = create_logger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
class MultiDoc(BaseModel):
    """Schema for the documentation parsed by ``MultiDocumentationParser``.

    Each field's ``description`` is phrased as an instruction describing what
    the LLM should produce for that field.
    """

    # Sphinx-style docstring text for the documented code.
    docstring: str = Field(
        description="A Sphinx-style docstring for the code, including a summary "
        "of its functionality; the name, type, and description of "
        "any parameters or returns; and any potential exceptions "
        "that might arise in its execution"
    )
    # A short, commented usage example.
    example_usage: str = Field(
        description="A well-commented minimal example utilizing the given "
        "code's functionality"
    )
    # Python-style pseudocode of the code's behavior.
    pseudocode: str = Field(
        description="A Python-style pseudocode implementation of the module or "
        "function's behavior"
    )
|
33
|
+
|
34
|
+
|
35
|
+
class MultiDocumentationParser(PydanticOutputParser, JanusParser):
    """Parse LLM output into ``MultiDoc`` JSON tagged with the block's name."""

    # Name of the code block being documented. It is set by ``set_reference``
    # and written into each parsed result under the ``"name"`` key.
    block_name: str = ""

    def __init__(self):
        PydanticOutputParser.__init__(self, pydantic_object=MultiDoc)

    def set_reference(self, block: CodeBlock):
        """Remember the name of the block whose documentation is being parsed.

        Arguments:
            block: The code block that was sent to the LLM.
        """
        self.block_name = block.name

    def parse(self, text: str) -> str:
        """Validate the LLM output against ``MultiDoc`` and add the block name.

        Arguments:
            text: The output text (or ``AIMessage``) from the LLM.

        Returns:
            A JSON string containing the ``MultiDoc`` fields plus a ``"name"``
            key holding the current ``block_name``.
        """
        if isinstance(text, AIMessage):
            text = text.content
        try:
            docs = json.loads(super().parse(text).json())
        except (OutputParserException, json.JSONDecodeError):
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise
        docs["name"] = self.block_name
        return json.dumps(docs)

    def parse_combined_output(self, text: str) -> str:
        """Parse the output text from the LLM when multiple inputs are combined.

        Each non-blank line is parsed as one JSON object. Each object must
        carry a ``"name"`` key, as produced by ``parse``.

        Arguments:
            text: The output text from the LLM.

        Returns:
            A JSON object mapping each block name to its documentation fields.
        """
        objs = [
            parse_json_markdown(line.strip()) for line in text.split("\n") if line.strip()
        ]
        # Move each object's "name" out of the body and use it as the key.
        output_obj = {d.pop("name"): d for d in objs}
        return json.dumps(output_obj)

    def get_format_instructions(self) -> str:
        """Get the format instructions for the parser.

        Returns:
            The format instructions for the LLM.
        """
        return (
            "Output must contain a sphinx-style docstring, example usage, and "
            "pseudocode, all in a json-formatted string with the following fields: "
            '"docstring", "example_usage", and "pseudocode".'
        )

    @property
    def _type(self) -> str:
        # A class's name is ``__name__``; ``name`` is not the class-name
        # attribute, so the original ``self.__class__.name`` was wrong.
        return self.__class__.__name__
|
85
|
+
|
86
|
+
|
87
|
+
class MadlibsDocumentationParser(BaseOutputParser[str], JanusParser):
    """Parse LLM output mapping comment placeholder IDs to comment text."""

    # Placeholder IDs found in the reference block's text; set by
    # ``set_reference``.
    expected_keys: set[str]

    def __init__(self):
        super().__init__(expected_keys=set())

    def set_reference(self, block: CodeBlock):
        """Collect the comment placeholder IDs the LLM must fill in.

        Arguments:
            block: The code block sent to the LLM. Its text is searched for
                ``<BLOCK_COMMENT id>`` / ``<INLINE_COMMENT id>`` tags, where
                ``id`` is 8 word characters.
        """
        comment_ids = re.findall(r"<(?:BLOCK|INLINE)_COMMENT (\w{8})>", block.text)
        self.expected_keys = set(comment_ids)

    def parse(self, text: str) -> str:
        """Validate that the LLM returned a string for every expected comment ID.

        Arguments:
            text: The output text (or ``AIMessage``) from the LLM.

        Returns:
            A JSON string mapping each expected comment ID to its comment text.
            Keys that were not expected are dropped.

        Raises:
            OutputParserException: If the output is not a JSON object, is
                missing any expected key, or has a non-string value.
        """
        if isinstance(text, AIMessage):
            text = text.content
        try:
            obj = parse_json_markdown(text)
        except json.JSONDecodeError as e:
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e

        if not isinstance(obj, dict):
            raise OutputParserException(
                f"Got invalid return object. Expected a dictionary, but got {type(obj)}"
            )

        seen_keys = set(obj.keys())
        missing_keys = self.expected_keys.difference(seen_keys)
        invalid_keys = seen_keys.difference(self.expected_keys)
        if invalid_keys:
            log.debug(f"Invalid keys: {invalid_keys}")
        if missing_keys:
            log.debug(f"Missing keys: {missing_keys}")
            raise OutputParserException(
                f"Got invalid return object. Missing the following expected "
                f"keys: {missing_keys}"
            )

        # Unexpected IDs are discarded rather than treated as an error.
        for key in invalid_keys:
            del obj[key]

        for value in obj.values():
            if not isinstance(value, str):
                raise OutputParserException(
                    f"Got invalid return object. Expected all string values,"
                    f' but got type "{type(value)}"'
                )

        return json.dumps(obj)

    def parse_combined_output(self, text: str) -> str:
        """Parse the output text from the LLM when multiple inputs are combined.

        Each non-blank line is parsed as a JSON object, and the objects are
        merged into one. Later lines overwrite earlier ones on key collisions.

        Arguments:
            text: The output text from the LLM.

        Returns:
            A JSON string of the merged object (``"{}"`` for blank input).
        """
        if not text.strip():
            return json.dumps({})
        objs = [
            parse_json_markdown(line.strip()) for line in text.split("\n") if line.strip()
        ]
        output_obj = {}
        for obj in objs:
            output_obj.update(obj)
        return json.dumps(output_obj)

    def get_format_instructions(self) -> str:
        """Get the format instructions for the parser.

        Returns:
            The format instructions for the LLM.
        """
        return (
            "Output must contain exactly one JSON-formatted block. The JSON "
            "object should contain only (and all of) the comment IDs present "
            "in the input code."
        )

    @property
    def _type(self) -> str:
        # A class's name is ``__name__``; ``name`` is not the class-name
        # attribute, so the original ``self.__class__.name`` was wrong.
        return self.__class__.__name__
|
@@ -0,0 +1,80 @@
|
|
1
|
+
import json
|
2
|
+
|
3
|
+
from langchain.output_parsers import PydanticOutputParser
|
4
|
+
from langchain_core.pydantic_v1 import BaseModel, Field, validator
|
5
|
+
|
6
|
+
from ..utils.logger import create_logger
|
7
|
+
from .code_parser import JanusParser
|
8
|
+
|
9
|
+
log = create_logger(__name__)
|
10
|
+
|
11
|
+
|
12
|
+
class Eval(BaseModel):
    """Scores (0-4 inclusive) that an LLM assigns to a piece of code.

    Supports ``+`` and ``/`` so a list of evaluations can be averaged with
    ``sum(evals) / len(evals)``.
    """

    syntax: float = Field(description="A numeric score (0-4) for code syntax")
    style: float = Field(description="A numeric score (0-4) for code style")
    completeness: float = Field(description="A numeric score (0-4) for code completeness")
    correctness: float = Field(description="A numeric score (0-4) for code correctness")

    @validator("*")
    def score_is_valid(cls, v: float | int):
        """Reject any score that is not a number in the range [0, 4]."""
        try:
            v = float(v)
        except (TypeError, ValueError):
            raise ValueError("must be a number")

        if not 0 <= v <= 4:
            raise ValueError("must be a value between 0 and 4 inclusive")

        return v

    def __add__(self, other):
        """Add two evaluations field by field.

        Adding the number 0 returns a copy. This lets ``sum()``, whose start
        value is 0, fold a list of evaluations.
        """
        if isinstance(other, (int, float)) and other == 0:
            return self.copy()
        # ``construct`` bypasses validation, so sums above 4 are allowed.
        return Eval.construct(
            syntax=self.syntax + other.syntax,
            style=self.style + other.style,
            correctness=self.correctness + other.correctness,
            completeness=self.completeness + other.completeness,
        )

    def __radd__(self, other):
        return self.__add__(other)

    def __truediv__(self, other):
        """Divide every score by a number, or field-wise by another ``Eval``."""
        if isinstance(other, (int, float)):
            return Eval.construct(
                syntax=self.syntax / other,
                style=self.style / other,
                correctness=self.correctness / other,
                completeness=self.completeness / other,
            )
        return Eval.construct(
            syntax=self.syntax / other.syntax,
            style=self.style / other.style,
            correctness=self.correctness / other.correctness,
            completeness=self.completeness / other.completeness,
        )
|
58
|
+
|
59
|
+
|
60
|
+
class EvaluationParser(PydanticOutputParser, JanusParser):
    """Parse LLM code evaluations into ``Eval`` scores."""

    def __init__(self):
        PydanticOutputParser.__init__(self, pydantic_object=Eval)

    def parse(self, text: str) -> str:
        """Parse a single evaluation from the LLM output.

        Arguments:
            text: The output text from the LLM.

        Returns:
            ``json.dumps`` applied to the ``Eval`` model's JSON text. The
            scores are therefore JSON-encoded twice.
        """
        result = super().parse(text)
        # NOTE(review): double encoding is kept as-is; callers presumably
        # decode twice — confirm before changing.
        return json.dumps(result.json())

    def parse_combined_output(self, text: str) -> str:
        """Parse several newline-separated evaluations and average them.

        Each non-blank line of ``text`` is parsed into an ``Eval``. The results
        are summed field-wise and divided by their count.

        Arguments:
            text: The output text from the LLM, one evaluation per line.

        Returns:
            The averaged evaluation, serialized the same way as ``parse``.

        Raises:
            OutputParserException: If ``text`` contains no evaluations.
        """
        from langchain_core.exceptions import OutputParserException

        # Bind the parent ``parse`` outside the comprehension. Zero-argument
        # ``super()`` fails inside a list comprehension before Python 3.12,
        # because the comprehension body runs in its own frame.
        parse_one = super().parse
        objs = [parse_one(line.strip()) for line in text.split("\n") if line.strip()]
        if not objs:
            raise OutputParserException("No evaluations found in combined output")
        avg_obj = sum(objs) / len(objs)
        return json.dumps(avg_obj.json())
|