langroid 0.16.5__py3-none-any.whl → 0.16.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/md_tool_message_grammar.py +455 -0
- langroid/agent/tools/code_file_tool_parse.py +150 -0
- langroid/agent/tools/code_file_tool_pyparsing.py +194 -0
- langroid/agent/tools/code_file_tool_pyparsing2.py +199 -0
- langroid/agent/tools/formatted_model_custom.py +150 -0
- langroid/agent/tools/formatted_model_custom2.py +168 -0
- langroid/agent/tools/formatted_model_custom3.py +279 -0
- langroid/agent/tools/formatted_model_custom4.py +395 -0
- langroid/agent/tools/formatted_model_jinja.py +133 -0
- langroid/agent/tools/formatted_model_jinja.py-e +122 -0
- langroid/agent/tools/formatted_model_jinja2.py +145 -0
- langroid/agent/tools/formatted_model_jinja2.py-e +135 -0
- langroid/agent/tools/formatted_model_lark.py +0 -0
- langroid/agent/tools/formatted_model_lark2.py +168 -0
- langroid/agent/tools/formatted_model_parse.py +105 -0
- langroid/agent/tools/formatted_model_parse.py-e +98 -0
- langroid/agent/tools/formatted_model_parse2.py +113 -0
- langroid/agent/tools/formatted_model_parse2.py-e +109 -0
- langroid/agent/tools/formatted_model_parse3.py +114 -0
- langroid/agent/tools/formatted_model_parse3.py-e +110 -0
- langroid/agent/tools/formatted_model_parsimon.py +194 -0
- langroid/agent/tools/formatted_model_parsimon.py-e +186 -0
- langroid/agent/tools/formatted_model_pyparsing.py +169 -0
- langroid/agent/tools/formatted_model_pyparsing.py-e +149 -0
- langroid/agent/tools/formatted_model_pyparsing2.py +159 -0
- langroid/agent/tools/formatted_model_pyparsing2.py-e +143 -0
- langroid/agent/tools/formatted_model_pyparsing3.py +133 -0
- langroid/agent/tools/formatted_model_pyparsing3.py-e +121 -0
- langroid/agent/tools/formatted_model_pyparsing4.py +213 -0
- langroid/agent/tools/formatted_model_pyparsing4.py-e +176 -0
- langroid/agent/tools/formatted_model_pyparsing5.py +173 -0
- langroid/agent/tools/formatted_model_pyparsing5.py-e +142 -0
- langroid/agent/tools/formatted_model_regex.py +246 -0
- langroid/agent/tools/formatted_model_regex.py-e +248 -0
- langroid/agent/tools/formatted_model_regex2.py +250 -0
- langroid/agent/tools/formatted_model_regex2.py-e +253 -0
- langroid/agent/tools/formatted_model_tatsu.py +172 -0
- langroid/agent/tools/formatted_model_tatsu.py-e +160 -0
- langroid/agent/tools/formatted_model_template.py +217 -0
- langroid/agent/tools/formatted_model_template.py-e +200 -0
- langroid/agent/tools/formatted_model_xml.py +178 -0
- langroid/agent/tools/formatted_model_xml2.py +178 -0
- langroid/agent/tools/formatted_model_xml3.py +132 -0
- langroid/agent/tools/formatted_model_xml4.py +130 -0
- langroid/agent/tools/formatted_model_xml5.py +130 -0
- langroid/agent/tools/formatted_model_xml6.py +113 -0
- langroid/agent/tools/formatted_model_xml7.py +117 -0
- langroid/agent/tools/formatted_model_xml8.py +164 -0
- langroid/agent/tools/generic_tool.py +165 -0
- langroid/agent/tools/generic_tool_tatsu.py +275 -0
- langroid/agent/tools/grammar_based_model.py +132 -0
- langroid/agent/tools/grammar_based_model.py-e +128 -0
- langroid/agent/tools/grammar_based_model_lark.py +156 -0
- langroid/agent/tools/grammar_based_model_lark.py-e +153 -0
- langroid/agent/tools/grammar_based_model_parse.py +86 -0
- langroid/agent/tools/grammar_based_model_parse.py-e +80 -0
- langroid/agent/tools/grammar_based_model_parsimonious.py +129 -0
- langroid/agent/tools/grammar_based_model_parsimonious.py-e +120 -0
- langroid/agent/tools/grammar_based_model_pyparsing.py +105 -0
- langroid/agent/tools/grammar_based_model_pyparsing.py-e +103 -0
- langroid/agent/tools/grammar_based_model_regex.py +139 -0
- langroid/agent/tools/grammar_based_model_regex.py-e +130 -0
- langroid/agent/tools/grammar_based_model_regex2.py +124 -0
- langroid/agent/tools/grammar_based_model_regex2.py-e +116 -0
- langroid/agent/tools/grammar_based_model_tatsu.py +80 -0
- langroid/agent/tools/grammar_based_model_tatsu.py-e +77 -0
- langroid/agent/tools/lark_earley_example.py +135 -0
- langroid/agent/tools/lark_earley_example.py-e +117 -0
- langroid/agent/tools/lark_example.py +72 -0
- langroid/agent/tools/parse_example.py +76 -0
- langroid/agent/tools/parse_example2.py +87 -0
- langroid/agent/tools/parse_example3.py +42 -0
- langroid/agent/tools/parse_test.py +791 -0
- langroid/agent/xml_tool_message.py +106 -0
- langroid/language_models/openai_gpt.py +6 -1
- {langroid-0.16.5.dist-info → langroid-0.16.7.dist-info}/METADATA +1 -1
- {langroid-0.16.5.dist-info → langroid-0.16.7.dist-info}/RECORD +80 -6
- pyproject.toml +1 -1
- {langroid-0.16.5.dist-info → langroid-0.16.7.dist-info}/LICENSE +0 -0
- {langroid-0.16.5.dist-info → langroid-0.16.7.dist-info}/WHEEL +0 -0
@@ -0,0 +1,194 @@
|
|
1
|
+
""""
|
2
|
+
Non-JSON Tool for LLM to specify contents of a code file.
|
3
|
+
|
4
|
+
Why Non-JSON? Because there are numerous issues with even the best LLMs trying
|
5
|
+
to return code within JSON strings (e.g. unescaped newlines, quotes, etc.),
|
6
|
+
and the problem is even worse with weak LLMs. JSON repair methods exist, but
|
7
|
+
can't deal with all possible cases.
|
8
|
+
|
9
|
+
E.g. see this study from Aider: https://aider.chat/2024/08/14/code-in-json.html
|
10
|
+
|
11
|
+
Note: We express the formatting rules with a template since it has several benefits:
|
12
|
+
- all of the formatting rules are in one place,
|
13
|
+
- we get a parser for free, and don't have to write parsing code,
|
14
|
+
- we get a formatting example generator for free, and don't have to write
|
15
|
+
example generation code.
|
16
|
+
- consistency between the parser and the example generator is guaranteed.
|
17
|
+
"""
|
18
|
+
|
19
|
+
from typing import Any, Callable, Dict, List, Tuple, Type
|
20
|
+
|
21
|
+
from pyparsing import (
|
22
|
+
LineEnd,
|
23
|
+
Literal,
|
24
|
+
Optional,
|
25
|
+
ParserElement,
|
26
|
+
SkipTo,
|
27
|
+
White,
|
28
|
+
Word,
|
29
|
+
alphanums,
|
30
|
+
lineEnd,
|
31
|
+
printables,
|
32
|
+
)
|
33
|
+
|
34
|
+
from langroid.agent.tool_message import ToolMessage
|
35
|
+
from langroid.utils.constants import TOOL, TOOL_END
|
36
|
+
|
37
|
+
CODE_FENCE_START = "`" * 3
|
38
|
+
CODE_FENCE_END = "`" * 3
|
39
|
+
|
40
|
+
|
41
|
+
class CodeFileTool(ToolMessage):
    """
    Used by LLM to specify contents of a code file.

    The textual layout is described once, as a pyparsing grammar built by
    ``create_parser``. ``parse`` reads that layout and ``format`` writes it by
    walking the same grammar, so the two cannot drift apart.
    """

    # NOTE(review): whitespace inside the multi-line literals of this class
    # (purpose, examples, __repr__) could not be recovered from the rendered
    # diff; confirm it against the packaged file.
    request: str = "code_file_tool"
    purpose: str = """
To specify the contents of a code file.
"""
    file_path: str
    contents: str
    language: str

    @classmethod
    def create_parser(cls) -> ParserElement:
        """
        Build the pyparsing grammar for one tool block.

        Layout, in order: the ``TOOL`` constant followed by a colon and an
        optional request name, a line holding the file path, an opening code
        fence followed by the language, the contents up to the closing fence,
        and finally the ``TOOL_END`` constant.
        """
        TOOL_START = Literal(TOOL + ":")
        CODE_FENCE = Literal("```")

        file_path = SkipTo(lineEnd)("file_path")
        language = Word(alphanums)("language")
        contents = SkipTo(CODE_FENCE)("contents")

        parser = (
            TOOL_START
            + Optional(Word(printables), default=cls.default_value("request"))(
                "request"
            )
            + lineEnd
            + file_path
            + lineEnd
            + CODE_FENCE
            + Optional(White())  # Allow space after ```
            + language
            + lineEnd
            + contents
            + CODE_FENCE
            + lineEnd  # a newline is expected after the closing fence
            + Literal(TOOL_END)
        )
        return parser

    @classmethod
    def parse(cls, string: str) -> Dict[str, Any]:
        """
        Parse one complete tool block into a dict of field values.

        Returns:
            A dict with keys ``request``, ``file_path``, ``language`` and
            ``contents`` (path and contents stripped of surrounding
            whitespace), or an empty dict when the string does not match the
            grammar. Failures are reported on stdout, not raised.
        """
        parser = cls.create_parser()
        try:
            result = parser.parseString(string, parseAll=True)
            return {
                "request": result["request"],
                "file_path": result["file_path"].strip(),
                "language": result["language"],
                "contents": result["contents"].strip(),
            }
        except Exception as e:
            # Deliberate best-effort: callers test the result for emptiness.
            print(f"Parsing failed: {e}")
            return {}

    @classmethod
    def format(cls, instance: "CodeFileTool") -> str:
        """
        Render ``instance`` as text by walking the grammar from
        ``create_parser``.

        Literals emit their own text, named elements emit the matching field
        of ``instance``, line ends emit a newline, and every other element
        (e.g. optional whitespace) emits nothing.
        """
        parser = cls.create_parser()

        def format_element(element):
            if isinstance(element, Literal):
                return element.match
            name = element.resultsName
            if name:
                if name == "contents":
                    # newline so the closing fence starts on its own line
                    return f"{instance.contents}\n"
                if name in ("request", "file_path", "language"):
                    return getattr(instance, name)
                return ""
            if isinstance(element, LineEnd):
                return "\n"
            return ""

        def traverse_parser(parser_element):
            if not isinstance(parser_element, ParserElement):
                return str(parser_element)
            # SkipTo is rendered as a single field, never descended into.
            if not isinstance(parser_element, SkipTo) and hasattr(
                parser_element, "exprs"
            ):
                return "".join(traverse_parser(expr) for expr in parser_element.exprs)
            return format_element(parser_element)

        return traverse_parser(parser).strip()

    @classmethod
    def create(cls, get_directory: Callable[[], str]) -> Type["CodeFileTool"]:
        """
        Create a subclass of CodeFileTool with a static method get_directory,
        which returns the current directory path, so that all file paths are
        interpreted as relative to the current directory.
        """

        class SubCodeFileTool(cls):
            get_directory: Callable[[], str] = staticmethod(get_directory)

        return SubCodeFileTool

    @classmethod
    def examples(cls) -> List[ToolMessage | Tuple[str, ToolMessage]]:
        """Return one sample instance (a small Rust source file)."""
        return [
            cls(
                file_path="src/lib.rs",
                language="rust",
                contents="""
// function to add two numbers
pub fn add(a: i32, b: i32) -> i32 {
a + b
}
""",
            )
        ]

    @classmethod
    def find_candidates(cls, input_str: str) -> List[str]:
        """
        Find all possible (top-level) candidates for
        CodeFileTool in the input string.

        Returns:
            The substrings of ``input_str`` that match the grammar, in order
            of appearance; an empty list when there are none.
        """
        # The previous body called the builtin compile() with one argument and
        # an undefined get_template(), so it could never run. Scan with the
        # class grammar instead.
        # scanString reports offsets into the tab-expanded text, so slice that
        # same text and not the raw input.
        text = input_str.expandtabs()
        return [
            text[start:end]
            for _tokens, start, end in cls.create_parser().scanString(text)
        ]

    @classmethod
    def from_string(cls, input_string: str) -> "CodeFileTool":
        """
        Parse a string into a CodeFileTool object, using the grammar.

        Raises:
            ValueError: if ``parse`` returns an empty dict for the string.
        """
        parsed_data = cls.parse(input_string)
        if parsed_data:
            return cls(**parsed_data)
        raise ValueError("Invalid input string format")

    @classmethod
    def to_string(cls, instance) -> str:
        """Convert a CodeFileTool object to a string, using the grammar."""
        return cls.format(instance)

    def __str__(self):
        # to_string is a classmethod taking the instance explicitly; calling
        # it with no argument raised TypeError.
        return self.to_string(self)

    def __repr__(self):
        return f"""CodeFileTool(
file_path='{self.file_path}',
language='{self.language}',
contents='{self.contents}')
"""
|
@@ -0,0 +1,199 @@
|
|
1
|
+
""""
|
2
|
+
Non-JSON Tool for LLM to specify contents of a code file.
|
3
|
+
|
4
|
+
Why Non-JSON? Because there are numerous issues with even the best LLMs trying
|
5
|
+
to return code within JSON strings (e.g. unescaped newlines, quotes, etc.),
|
6
|
+
and the problem is even worse with weak LLMs. JSON repair methods exist, but
|
7
|
+
can't deal with all possible cases.
|
8
|
+
|
9
|
+
E.g. see this study from Aider: https://aider.chat/2024/08/14/code-in-json.html
|
10
|
+
|
11
|
+
Note: We express the formatting rules with a template since it has several benefits:
|
12
|
+
- all of the formatting rules are in one place,
|
13
|
+
- we get a parser for free, and don't have to write parsing code,
|
14
|
+
- we get a formatting example generator for free, and don't have to write
|
15
|
+
example generation code.
|
16
|
+
- consistency between the parser and the example generator is guaranteed.
|
17
|
+
"""
|
18
|
+
|
19
|
+
from typing import Callable, List, Tuple, Type
|
20
|
+
|
21
|
+
from pyparsing import (
|
22
|
+
Literal,
|
23
|
+
Optional,
|
24
|
+
SkipTo,
|
25
|
+
White,
|
26
|
+
Word,
|
27
|
+
alphanums,
|
28
|
+
lineEnd,
|
29
|
+
printables,
|
30
|
+
)
|
31
|
+
|
32
|
+
from langroid.agent.tool_message import ToolMessage
|
33
|
+
from langroid.agent.tools.generic_tool import GenericTool
|
34
|
+
from langroid.utils.constants import TOOL, TOOL_END
|
35
|
+
|
36
|
+
CODE_FENCE_START = "`" * 3
|
37
|
+
CODE_FENCE_END = "`" * 3
|
38
|
+
|
39
|
+
|
40
|
+
class CodeFileTool(GenericTool):
    """
    Used by LLM to specify contents of a code file.

    Only the grammar is declared here; everything else comes from
    GenericTool.
    """

    request: str = "code_file_tool"
    purpose: str = """
To specify the <contents> of a code file at <file_path>,
containing code in a specific <language>.
"""
    file_path: str
    contents: str
    language: str

    @classmethod
    def define_grammar(cls):
        """
        Return the pyparsing grammar for one tool block.

        Order of elements: start marker, optional request name, file-path
        line, opening fence plus language, contents, closing fence, end
        marker. Optional whitespace is allowed at the points where the
        expression below inserts ``gap()``.
        """

        def gap():
            # a fresh optional-whitespace element for each position
            return Optional(White())

        fence = Literal("```")
        opener = Literal(TOOL + ":")

        request_name = Optional(
            Word(printables), default=cls.default_value("request")
        )("request")
        path_field = SkipTo(lineEnd)("file_path")
        language_field = Word(alphanums)("language")
        # contents stop at the line end that precedes the closing fence
        contents_field = SkipTo(lineEnd + fence)("contents")

        return (
            opener
            + gap()
            + request_name
            + lineEnd
            + gap()
            + path_field
            + lineEnd
            + fence
            + gap()
            + language_field
            + lineEnd
            + contents_field
            + lineEnd
            + fence
            + lineEnd
            + gap()
            + Literal(TOOL_END)
            + gap()
        )

    @classmethod
    def create(cls, get_directory: Callable[[], str]) -> Type["CodeFileTool"]:
        """
        Derive a subclass that carries ``get_directory`` as a static method.

        ``get_directory`` returns the current directory path, so that all
        file paths are interpreted as relative to the current directory.
        """

        class SubCodeFileTool(cls):
            get_directory: Callable[[], str] = staticmethod(get_directory)

        return SubCodeFileTool

    @classmethod
    def examples(cls) -> List[ToolMessage | Tuple[str, ToolMessage]]:
        """Return one sample instance (a small Rust source file)."""
        sample = cls(
            file_path="src/lib.rs",
            language="rust",
            contents="""
// function to add two numbers
pub fn add(a: i32, b: i32) -> i32 {
a + b
}
""",
        )
        return [sample]

    def __repr__(self):
        # debugging aid: shows the three payload fields only
        return f"""CodeFileTool(
file_path='{self.file_path}',
language='{self.language}',
contents='{self.contents}')
"""
|
121
|
+
|
122
|
+
|
123
|
+
if __name__ == "__main__":
    # Informal smoke checks for CodeFileTool: results are printed for a human
    # to inspect; nothing here asserts.
    divider = "-" * 50

    # 1. Instructions
    print("Testing CodeFileTool instructions:")
    print(divider)

    print(CodeFileTool.instructions())

    print(divider)
    print("End of instructions test")

    # 2. Parsing a hand-written tool block
    print("\nTesting CodeFileTool parse method:")
    print(divider)

    test_input = """TOOL: code_file_tool
src/main.py
```python
def hello_world():
print("Hello, World!")

if __name__ == "__main__":
hello_world()
```
TOOL_END"""

    parsed_result = CodeFileTool.parse(test_input)
    print("Parsed result:")
    for key in parsed_result:
        print(f"{key}: {parsed_result[key]}")

    print(divider)
    print("End of parse test")

    # 3. Formatting an instance built in code
    print("\nTesting CodeFileTool format method:")
    print(divider)
    test_instance = CodeFileTool(
        request="code_file_tool",
        file_path="tests/test_file.py",
        language="python",
        contents="""
def test_function():
assert 1 + 1 == 2

if __name__ == "__main__":
test_function()
""",
    )
    formatted_output = CodeFileTool.format(test_instance)
    print("Formatted output:")
    print(formatted_output)
    print(divider)
    print("End of format test")

    # 4. Round-trip (parse -> format -> parse)
    print("\nTesting CodeFileTool round-trip (parse -> format -> parse):")
    print(divider)
    initial_parse = CodeFileTool.parse(test_input)
    initial_instance = CodeFileTool(**initial_parse)
    formatted_output = CodeFileTool.format(initial_instance)
    final_parse = CodeFileTool.parse(formatted_output)

    print("Initial parse:")
    print(initial_parse)
    print("\nFormatted output:")
    print(formatted_output)
    print("\nFinal parse:")
    print(final_parse)

    print(
        "\nRound-trip test passed: Initial and final parses match."
        if initial_parse == final_parse
        else "\nRound-trip test failed: Initial and final parses do not match."
    )
    print(divider)
    print("End of round-trip test")
|
@@ -0,0 +1,150 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import List, Tuple
|
3
|
+
|
4
|
+
from langroid.pydantic_v1 import BaseModel
|
5
|
+
|
6
|
+
|
7
|
+
class FormattingModel(BaseModel, ABC):
    """
    Base for models that serialise to and from a delimited text block.

    Subclasses supply a ``str.format`` template, a list of
    ``(field, start_marker, end_marker)`` triples used to read values back,
    and the start and end token lines that wrap the block.
    """

    @classmethod
    @abstractmethod
    def format_spec(cls) -> str:
        """Return the ``str.format`` template, keyed by field name."""
        pass

    @classmethod
    @abstractmethod
    def parse_spec(cls) -> List[Tuple[str, str, str]]:
        """Return ``(field, start_marker, end_marker)`` triples."""
        pass

    @classmethod
    @abstractmethod
    def start_token(cls) -> str:
        """Return the line that opens a formatted block."""
        pass

    @classmethod
    @abstractmethod
    def end_token(cls) -> str:
        """Return the line that closes a formatted block."""
        pass

    @classmethod
    def format(cls, instance: "FormattingModel") -> str:
        """Render ``instance`` with the template, wrapped in the two tokens."""
        spec = cls.format_spec()
        formatted = spec.format(**instance.dict())
        return f"{cls.start_token()}\n{formatted}\n{cls.end_token()}"

    @classmethod
    def parse(cls, formatted_string: str) -> "FormattingModel":
        """
        Rebuild an instance from text produced by ``format``.

        Each value is the text between its start and end markers, stripped of
        surrounding whitespace.

        Raises:
            ValueError: if the first or last line is not the expected token,
                or if a field's start or end marker cannot be found.
        """
        lines = formatted_string.strip().split("\n")
        if lines[0] != cls.start_token() or lines[-1] != cls.end_token():
            raise ValueError("Invalid start or end token")
        content = "\n".join(lines[1:-1])

        parsed_data = {}
        # Search forward from where the previous field ended, so a marker that
        # also occurs inside an earlier field's value is not matched there.
        cursor = 0
        for field, start, end in cls.parse_spec():
            start_index = content.find(start, cursor)
            if start_index == -1:
                # Keep working for specs listed in a different order than the
                # text: fall back to a search over the whole content.
                start_index = content.find(start)
            if start_index == -1:
                raise ValueError(f"Could not find start of {field}")
            value_start = start_index + len(start)
            end_index = content.find(end, value_start)
            if end_index == -1:
                raise ValueError(f"Could not find end of {field}")
            parsed_data[field] = content[value_start:end_index].strip()
            # Stop at the end marker, not past it: it may overlap with the
            # next field's start marker.
            cursor = end_index

        return cls(**parsed_data)
|
53
|
+
|
54
|
+
|
55
|
+
class CodeFileModel(FormattingModel):
    """
    A code file: its path, its language, and the code in a fenced block,
    wrapped in ``<code_file>`` / ``</code_file>`` lines.
    """

    file_path: str
    language: str
    code: str

    @classmethod
    def format_spec(cls):
        """Template: two labelled lines followed by the fenced code."""
        template_lines = [
            "file_path: {file_path}",
            "language: {language}",
            "```",
            "{code}",
            "```",
        ]
        return "\n".join(template_lines)

    @classmethod
    def parse_spec(cls):
        """Markers delimiting each field, in the order the template emits them."""
        newline = "\n"
        fence_open, fence_close = "```\n", "\n```"
        spec = []
        spec.append(("file_path", "file_path:", newline))
        spec.append(("language", "language:", newline))
        spec.append(("code", fence_open, fence_close))
        return spec

    @classmethod
    def start_token(cls):
        """Opening line of a serialised code file."""
        return "<code_file>"

    @classmethod
    def end_token(cls):
        """Closing line of a serialised code file."""
        return "</code_file>"
|
79
|
+
|
80
|
+
|
81
|
+
# Test cases
|
82
|
+
if __name__ == "__main__":
    # Self-test for CodeFileModel, run only when the module is executed
    # directly. The multi-line literals below must agree character for
    # character with the `code` values, so the function body uses the same
    # 4-space indent in both places.

    # Test formatting
    code_file = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    formatted = CodeFileModel.format(code_file)
    expected_format = """<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')
```
</code_file>"""
    assert (
        formatted == expected_format
    ), f"Formatting failed. Expected:\n{expected_format}\nGot:\n{formatted}"
    print("Formatting test passed.")

    # Test parsing
    parsed = CodeFileModel.parse(formatted)
    assert (
        parsed == code_file
    ), f"Parsing failed. Expected:\n{code_file}\nGot:\n{parsed}"
    print("Parsing test passed.")

    # Test round-trip
    round_trip = CodeFileModel.parse(CodeFileModel.format(code_file))
    assert (
        round_trip == code_file
    ), f"Round-trip failed. Expected:\n{code_file}\nGot:\n{round_trip}"
    print("Round-trip test passed.")

    # Test with different values
    code_file2 = CodeFileModel(
        file_path="src/app.js",
        language="javascript",
        code="function greet() {\n console.log('Hello, World!');\n}",
    )
    formatted2 = CodeFileModel.format(code_file2)
    parsed2 = CodeFileModel.parse(formatted2)
    assert (
        parsed2 == code_file2
    ), f"Parsing failed for different values. Expected:\n{code_file2}\nGot:\n{parsed2}"
    print("Different values test passed.")

    # Test tolerant parsing
    # NOTE(review): as visible here this input equals the strict format;
    # presumably the original carried extra whitespace -- confirm.
    tolerant_input = """<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')
```
</code_file>"""
    parsed_tolerant = CodeFileModel.parse(tolerant_input)
    expected_tolerant = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    assert (
        parsed_tolerant == expected_tolerant
    ), f"Tolerant parsing failed. Expected:\n{expected_tolerant}\nGot:\n{parsed_tolerant}"
    print("Tolerant parsing test passed.")

    print("All tests passed successfully!")
|
@@ -0,0 +1,168 @@
|
|
1
|
+
from abc import ABC
|
2
|
+
from typing import Dict
|
3
|
+
|
4
|
+
from langroid.pydantic_v1 import BaseModel, Field
|
5
|
+
|
6
|
+
|
7
|
+
class FormatMetadata(BaseModel):
    """Per-field layout hints, attached to a model field as ``format_metadata``."""

    # text emitted immediately before the field value
    prefix: str = Field(default="")
    # text emitted immediately after the field value
    suffix: str = Field(default="")
    # whether the parser treats the value as a multi-line block
    multiline: bool = Field(default=False)
|
11
|
+
|
12
|
+
|
13
|
+
class FormattingModel(BaseModel, ABC):
    """
    Base for models that serialise to and from a delimited text block.

    Layout is driven by the ``format_metadata`` extra on each field (a
    FormatMetadata; a default one is used when the extra is absent). The
    wrapping tokens come from ``Config.start_token`` / ``Config.end_token``,
    defaulting to ``<format>`` / ``</format>``.
    """

    @classmethod
    def format_spec(cls) -> str:
        """
        Return the ``str.format`` template: one ``prefix{name}suffix`` entry
        per field, joined by newlines.
        """
        lines = []
        for name, field in cls.__fields__.items():
            metadata: FormatMetadata = field.field_info.extra.get(
                "format_metadata", FormatMetadata()
            )
            # multiline and single-line fields are emitted the same way; the
            # flag only changes how parse() reads them back
            lines.append(f"{metadata.prefix}{{{name}}}{metadata.suffix}")
        return "\n".join(lines)

    @classmethod
    def parse_spec(cls) -> Dict[str, FormatMetadata]:
        """Return the FormatMetadata of every field, keyed by field name."""
        return {
            name: field.field_info.extra.get("format_metadata", FormatMetadata())
            for name, field in cls.__fields__.items()
        }

    @classmethod
    def start_token(cls) -> str:
        """Return the line that opens a formatted block."""
        return getattr(cls.Config, "start_token", "<format>")

    @classmethod
    def end_token(cls) -> str:
        """Return the line that closes a formatted block."""
        return getattr(cls.Config, "end_token", "</format>")

    @classmethod
    def format(cls, instance: "FormattingModel") -> str:
        """Render ``instance`` with the template, wrapped in the two tokens."""
        spec = cls.format_spec()
        formatted = spec.format(**instance.dict())
        return f"{cls.start_token()}\n{formatted}\n{cls.end_token()}"

    @classmethod
    def parse(cls, formatted_string: str) -> "FormattingModel":
        """
        Rebuild an instance from text produced by ``format``.

        Fields are read in declaration order. A multiline field runs from its
        prefix to its suffix. A single-line field runs from its prefix to its
        suffix, or to the next newline when it has no suffix, or to the end of
        the content when neither is found. Values are stripped of surrounding
        whitespace.

        Raises:
            ValueError: if the first or last line is not the expected token,
                if a field's prefix cannot be found, or if a multiline
                field's suffix cannot be found.
        """
        lines = formatted_string.strip().split("\n")
        if lines[0] != cls.start_token() or lines[-1] != cls.end_token():
            raise ValueError("Invalid start or end token")
        content = "\n".join(lines[1:-1])

        parsed_data = {}
        # Search forward from where the previous field ended. Restarting at 0
        # for every field matched markers inside earlier values, and made all
        # fields with an empty prefix resolve to the first line.
        cursor = 0
        for field, metadata in cls.parse_spec().items():
            prefix = metadata.prefix
            start_index = content.find(prefix, cursor)
            if start_index == -1:
                # fall back to a whole-content search, e.g. for a subclass
                # whose format_spec orders fields differently
                start_index = content.find(prefix)
            if start_index == -1:
                if metadata.multiline:
                    raise ValueError(f"Could not find start of {field}")
                raise ValueError(f"Could not find {field}")
            value_start = start_index + len(prefix)

            if metadata.multiline:
                end_index = content.find(metadata.suffix, value_start)
                if end_index == -1:
                    raise ValueError(f"Could not find end of {field}")
            else:
                end_index = content.find(metadata.suffix or "\n", value_start)
                if end_index == -1:
                    end_index = len(content)

            parsed_data[field] = content[value_start:end_index].strip()

            # Step over this field's suffix and the single newline that
            # format_spec() puts between fields.
            cursor = min(end_index + len(metadata.suffix), len(content))
            if content.startswith("\n", cursor):
                cursor += 1

        return cls(**parsed_data)
|
83
|
+
|
84
|
+
|
85
|
+
class CodeFileModel(FormattingModel):
    """
    A code file: two labelled single-line fields followed by the code in a
    fenced block, wrapped in ``<code_file>`` / ``</code_file>`` lines.
    """

    # "file_path: <value>" on its own line
    file_path: str = Field(
        ...,
        format_metadata=FormatMetadata(prefix="file_path: "),
    )
    # "language: <value>" on its own line
    language: str = Field(
        ...,
        format_metadata=FormatMetadata(prefix="language: "),
    )
    # the code itself, between a pair of triple-backtick fence lines
    code: str = Field(
        ...,
        format_metadata=FormatMetadata(
            prefix="```\n",
            suffix="\n```",
            multiline=True,
        ),
    )

    class Config:
        # tokens read by FormattingModel.start_token() / end_token()
        start_token = "<code_file>"
        end_token = "</code_file>"
|
96
|
+
|
97
|
+
|
98
|
+
# Test cases
|
99
|
+
#
|
100
|
+
if __name__ == "__main__":
    # Self-test for CodeFileModel, run only when the module is executed
    # directly. The multi-line literals below must agree character for
    # character with the `code` values, so the function body uses the same
    # 4-space indent in both places.

    # Test formatting
    code_file = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    formatted = CodeFileModel.format(code_file)
    expected_format = """<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')
```
</code_file>"""
    assert (
        formatted == expected_format
    ), f"Formatting failed. Expected:\n{expected_format}\nGot:\n{formatted}"
    print("Formatting test passed.")

    # Test parsing
    parsed = CodeFileModel.parse(formatted)
    assert (
        parsed == code_file
    ), f"Parsing failed. Expected:\n{code_file}\nGot:\n{parsed}"
    print("Parsing test passed.")

    # Test round-trip
    round_trip = CodeFileModel.parse(CodeFileModel.format(code_file))
    assert (
        round_trip == code_file
    ), f"Round-trip failed. Expected:\n{code_file}\nGot:\n{round_trip}"
    print("Round-trip test passed.")

    # Test with different values
    code_file2 = CodeFileModel(
        file_path="src/app.js",
        language="javascript",
        code="function greet() {\n console.log('Hello, World!');\n}",
    )
    formatted2 = CodeFileModel.format(code_file2)
    parsed2 = CodeFileModel.parse(formatted2)
    assert (
        parsed2 == code_file2
    ), f"Parsing failed for different values. Expected:\n{code_file2}\nGot:\n{parsed2}"
    print("Different values test passed.")

    # Test tolerant parsing
    # NOTE(review): as visible here this input equals the strict format;
    # presumably the original carried extra whitespace -- confirm.
    tolerant_input = """<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')
```
</code_file>"""
    parsed_tolerant = CodeFileModel.parse(tolerant_input)
    expected_tolerant = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    assert (
        parsed_tolerant == expected_tolerant
    ), f"Tolerant parsing failed. Expected:\n{expected_tolerant}\nGot:\n{parsed_tolerant}"
    print("Tolerant parsing test passed.")

    print("All tests passed successfully!")
|