langroid 0.16.5__py3-none-any.whl → 0.16.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/md_tool_message_grammar.py +455 -0
- langroid/agent/tools/code_file_tool_parse.py +150 -0
- langroid/agent/tools/code_file_tool_pyparsing.py +194 -0
- langroid/agent/tools/code_file_tool_pyparsing2.py +199 -0
- langroid/agent/tools/formatted_model_custom.py +150 -0
- langroid/agent/tools/formatted_model_custom2.py +168 -0
- langroid/agent/tools/formatted_model_custom3.py +279 -0
- langroid/agent/tools/formatted_model_custom4.py +395 -0
- langroid/agent/tools/formatted_model_jinja.py +133 -0
- langroid/agent/tools/formatted_model_jinja.py-e +122 -0
- langroid/agent/tools/formatted_model_jinja2.py +145 -0
- langroid/agent/tools/formatted_model_jinja2.py-e +135 -0
- langroid/agent/tools/formatted_model_lark.py +0 -0
- langroid/agent/tools/formatted_model_lark2.py +168 -0
- langroid/agent/tools/formatted_model_parse.py +105 -0
- langroid/agent/tools/formatted_model_parse.py-e +98 -0
- langroid/agent/tools/formatted_model_parse2.py +113 -0
- langroid/agent/tools/formatted_model_parse2.py-e +109 -0
- langroid/agent/tools/formatted_model_parse3.py +114 -0
- langroid/agent/tools/formatted_model_parse3.py-e +110 -0
- langroid/agent/tools/formatted_model_parsimon.py +194 -0
- langroid/agent/tools/formatted_model_parsimon.py-e +186 -0
- langroid/agent/tools/formatted_model_pyparsing.py +169 -0
- langroid/agent/tools/formatted_model_pyparsing.py-e +149 -0
- langroid/agent/tools/formatted_model_pyparsing2.py +159 -0
- langroid/agent/tools/formatted_model_pyparsing2.py-e +143 -0
- langroid/agent/tools/formatted_model_pyparsing3.py +133 -0
- langroid/agent/tools/formatted_model_pyparsing3.py-e +121 -0
- langroid/agent/tools/formatted_model_pyparsing4.py +213 -0
- langroid/agent/tools/formatted_model_pyparsing4.py-e +176 -0
- langroid/agent/tools/formatted_model_pyparsing5.py +173 -0
- langroid/agent/tools/formatted_model_pyparsing5.py-e +142 -0
- langroid/agent/tools/formatted_model_regex.py +246 -0
- langroid/agent/tools/formatted_model_regex.py-e +248 -0
- langroid/agent/tools/formatted_model_regex2.py +250 -0
- langroid/agent/tools/formatted_model_regex2.py-e +253 -0
- langroid/agent/tools/formatted_model_tatsu.py +172 -0
- langroid/agent/tools/formatted_model_tatsu.py-e +160 -0
- langroid/agent/tools/formatted_model_template.py +217 -0
- langroid/agent/tools/formatted_model_template.py-e +200 -0
- langroid/agent/tools/formatted_model_xml.py +178 -0
- langroid/agent/tools/formatted_model_xml2.py +178 -0
- langroid/agent/tools/formatted_model_xml3.py +132 -0
- langroid/agent/tools/formatted_model_xml4.py +130 -0
- langroid/agent/tools/formatted_model_xml5.py +130 -0
- langroid/agent/tools/formatted_model_xml6.py +113 -0
- langroid/agent/tools/formatted_model_xml7.py +117 -0
- langroid/agent/tools/formatted_model_xml8.py +164 -0
- langroid/agent/tools/generic_tool.py +165 -0
- langroid/agent/tools/generic_tool_tatsu.py +275 -0
- langroid/agent/tools/grammar_based_model.py +132 -0
- langroid/agent/tools/grammar_based_model.py-e +128 -0
- langroid/agent/tools/grammar_based_model_lark.py +156 -0
- langroid/agent/tools/grammar_based_model_lark.py-e +153 -0
- langroid/agent/tools/grammar_based_model_parse.py +86 -0
- langroid/agent/tools/grammar_based_model_parse.py-e +80 -0
- langroid/agent/tools/grammar_based_model_parsimonious.py +129 -0
- langroid/agent/tools/grammar_based_model_parsimonious.py-e +120 -0
- langroid/agent/tools/grammar_based_model_pyparsing.py +105 -0
- langroid/agent/tools/grammar_based_model_pyparsing.py-e +103 -0
- langroid/agent/tools/grammar_based_model_regex.py +139 -0
- langroid/agent/tools/grammar_based_model_regex.py-e +130 -0
- langroid/agent/tools/grammar_based_model_regex2.py +124 -0
- langroid/agent/tools/grammar_based_model_regex2.py-e +116 -0
- langroid/agent/tools/grammar_based_model_tatsu.py +80 -0
- langroid/agent/tools/grammar_based_model_tatsu.py-e +77 -0
- langroid/agent/tools/lark_earley_example.py +135 -0
- langroid/agent/tools/lark_earley_example.py-e +117 -0
- langroid/agent/tools/lark_example.py +72 -0
- langroid/agent/tools/parse_example.py +76 -0
- langroid/agent/tools/parse_example2.py +87 -0
- langroid/agent/tools/parse_example3.py +42 -0
- langroid/agent/tools/parse_test.py +791 -0
- langroid/agent/xml_tool_message.py +106 -0
- langroid/language_models/openai_gpt.py +6 -1
- {langroid-0.16.5.dist-info → langroid-0.16.7.dist-info}/METADATA +1 -1
- {langroid-0.16.5.dist-info → langroid-0.16.7.dist-info}/RECORD +80 -6
- pyproject.toml +1 -1
- {langroid-0.16.5.dist-info → langroid-0.16.7.dist-info}/LICENSE +0 -0
- {langroid-0.16.5.dist-info → langroid-0.16.7.dist-info}/WHEEL +0 -0
@@ -0,0 +1,279 @@
|
|
1
|
+
from abc import ABC
|
2
|
+
from typing import Dict, List
|
3
|
+
|
4
|
+
from langroid.pydantic_v1 import BaseModel, Field
|
5
|
+
|
6
|
+
|
7
|
+
class FormatMetadata(BaseModel):
    """Layout hints for one field of a ``FormattingModel``.

    An instance is attached to a field through the ``format_metadata`` keyword
    of ``Field(...)``; fields without one fall back to ``FormatMetadata()``.
    """

    # Literal text written immediately before the field's value.
    prefix: str = ""
    # Literal text written immediately after the value; for single-line
    # fields an empty suffix means the value ends at the next newline.
    suffix: str = ""
    # When True the value may span several lines and is delimited only by
    # ``prefix`` and ``suffix``.
    multiline: bool = False
|
11
|
+
|
12
|
+
|
13
|
+
class FormattingModel(BaseModel, ABC):
    """Pydantic model that can be rendered to, and parsed back from, plain text.

    Each field may attach a ``FormatMetadata`` through the ``format_metadata``
    keyword of ``Field(...)``; fields without one use ``FormatMetadata()``.
    Fields are laid out in declaration order, each as ``prefix + value +
    suffix`` on its own line, and the whole block is wrapped in
    ``start_token()`` / ``end_token()`` (configurable via ``Config``).
    """

    @classmethod
    def format_spec(cls) -> str:
        """Return the ``str.format`` template used by ``format``.

        Each field contributes one line ``<prefix>{<name>}<suffix>``. Literal
        braces inside a prefix or suffix are doubled so that ``str.format``
        treats them as text rather than as placeholders.
        """
        lines = []
        for name, metadata in cls.parse_spec().items():
            # Without escaping, a "{" or "}" in the prefix/suffix would be read
            # as a replacement field and raise KeyError/ValueError in format().
            prefix = metadata.prefix.replace("{", "{{").replace("}", "}}")
            suffix = metadata.suffix.replace("{", "{{").replace("}", "}}")
            lines.append(f"{prefix}{{{name}}}{suffix}")
        return "\n".join(lines)

    @classmethod
    def parse_spec(cls) -> Dict[str, FormatMetadata]:
        """Map each field name, in declaration order, to its FormatMetadata."""
        return {
            name: field.field_info.extra.get("format_metadata", FormatMetadata())
            for name, field in cls.__fields__.items()
        }

    @classmethod
    def start_token(cls) -> str:
        """Return ``Config.start_token`` if defined, else ``"<format>"``."""
        return getattr(cls.Config, "start_token", "<format>")

    @classmethod
    def end_token(cls) -> str:
        """Return ``Config.end_token`` if defined, else ``"</format>"``."""
        return getattr(cls.Config, "end_token", "</format>")

    @classmethod
    def format(cls, instance: "FormattingModel") -> str:
        """Render ``instance`` as start token, one line per field, end token."""
        formatted = cls.format_spec().format(**instance.dict())
        return f"{cls.start_token()}\n{formatted}\n{cls.end_token()}"

    @classmethod
    def parse(cls, formatted_string: str) -> "FormattingModel":
        """Parse text produced by ``format`` (or a near variant) into a model.

        Parsing is tolerant. Surrounding whitespace is ignored, and the end
        token may be missing. The last field may also lack its terminator, in
        which case its value runs to the end of the text. Single-line values
        are stripped. Multiline values are kept verbatim so that a format/parse
        round trip preserves indentation and blank lines.

        Args:
            formatted_string: Text whose first non-blank line is the start token.

        Returns:
            An instance of ``cls`` built from the extracted values.

        Raises:
            ValueError: If the start token is wrong, a field's prefix cannot be
                found, or a non-last field's terminator is missing. Model
                validation errors from ``cls(...)`` also propagate.
        """
        lines = formatted_string.strip().split("\n")
        # Strip the token line so trailing spaces or "\r" do not cause failure.
        if lines[0].strip() != cls.start_token():
            raise ValueError("Invalid start token")

        end_token = cls.end_token()
        content = "\n".join(lines[1:])
        if content.endswith(end_token):
            content = content[: -len(end_token)].strip()

        spec = cls.parse_spec()
        last_position = len(spec) - 1
        parsed_data: Dict[str, str] = {}
        for position, (name, metadata) in enumerate(spec.items()):
            start_marker = metadata.prefix
            # Single-line fields without an explicit suffix end at a newline.
            if metadata.multiline:
                end_marker = metadata.suffix
            else:
                end_marker = metadata.suffix or "\n"

            start_index = content.find(start_marker)
            if start_index == -1:
                if metadata.multiline:
                    raise ValueError(f"Could not find start of {name}")
                raise ValueError(f"Could not find {name}")
            start_index += len(start_marker)

            if position == last_position:
                # A multiline value runs to the LAST terminator, so the body
                # may itself contain the terminator (e.g. a nested code fence).
                # A single-line value stops at the first terminator. Both
                # searches start at start_index, so text before the value is
                # never matched. A missing terminator is tolerated.
                if metadata.multiline:
                    end_index = content.rfind(end_marker, start_index)
                else:
                    end_index = content.find(end_marker, start_index)
                if end_index == -1:
                    end_index = len(content)
            else:
                end_index = content.find(end_marker, start_index)
                if end_index == -1:
                    raise ValueError(f"Could not find end of {name}")

            value = content[start_index:end_index]
            parsed_data[name] = value if metadata.multiline else value.strip()
            # Drop the consumed field and its terminator before the next one.
            content = content[end_index + len(end_marker) :].strip()

        return cls(**parsed_data)

    @staticmethod
    def find_all_candidates(string: str, begin_token: str, end_token: str) -> List[str]:
        """Return every ``begin_token ... end_token`` span in ``string``.

        Spans are returned in order of appearance and include both tokens.
        Another begin token may appear before the matching end token. In that
        case the earlier, unterminated begin is skipped and scanning resumes at
        the later one. A trailing begin token with no end token yields a span
        running to the end of ``string``.

        Raises:
            ValueError: If ``begin_token`` is empty; the scan could never
                advance and would loop forever.
        """
        if not begin_token:
            raise ValueError("begin_token must be a non-empty string")

        candidates = []
        start = 0
        while True:
            start_index = string.find(begin_token, start)
            if start_index == -1:
                break

            body_start = start_index + len(begin_token)
            end_index = string.find(end_token, body_start)
            if end_index == -1:
                # Unterminated final block: assume it runs to the end.
                candidates.append(string[start_index:])
                break

            next_start = string.find(begin_token, body_start, end_index)
            if next_start != -1:
                # A newer begin token precedes the end token; restart there.
                start = next_start
                continue

            candidates.append(string[start_index : end_index + len(end_token)])
            start = end_index + len(end_token)

        return candidates
|
134
|
+
|
135
|
+
|
136
|
+
class CodeFileModel(FormattingModel):
    """A source file described by its path, its language, and its code.

    The code is the only multiline field and is wrapped in a Markdown code
    fence; the whole block is delimited by ``<code_file>`` tags.
    """

    # Single-line fields: the value follows the prefix and ends at a newline.
    file_path: str = Field(
        ...,
        format_metadata=FormatMetadata(prefix="file_path: "),
    )
    language: str = Field(
        ...,
        format_metadata=FormatMetadata(prefix="language: "),
    )
    # Multiline field delimited by an opening and a closing ``` fence.
    code: str = Field(
        ...,
        format_metadata=FormatMetadata(
            multiline=True,
            prefix="```\n",
            suffix="\n```",
        ),
    )

    class Config:
        start_token = "<code_file>"  # opening delimiter read by start_token()
        end_token = "</code_file>"  # closing delimiter read by end_token()
|
147
|
+
|
148
|
+
|
149
|
+
if __name__ == "__main__":
    # Test formatting
    code_file = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    formatted = CodeFileModel.format(code_file)
    expected_format = """<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')
```
</code_file>"""
    assert (
        formatted == expected_format
    ), f"Formatting failed. Expected:\n{expected_format}\nGot:\n{formatted}"
    print("Formatting test passed.")

    # Test parsing
    parsed = CodeFileModel.parse(formatted)
    assert (
        parsed == code_file
    ), f"Parsing failed. Expected:\n{code_file}\nGot:\n{parsed}"
    print("Parsing test passed.")

    # Test round-trip
    round_trip = CodeFileModel.parse(CodeFileModel.format(code_file))
    assert (
        round_trip == code_file
    ), f"Round-trip failed. Expected:\n{code_file}\nGot:\n{round_trip}"
    print("Round-trip test passed.")

    # Test with different values
    code_file2 = CodeFileModel(
        file_path="src/app.js",
        language="javascript",
        code="function greet() {\n    console.log('Hello, World!');\n}",
    )
    formatted2 = CodeFileModel.format(code_file2)
    parsed2 = CodeFileModel.parse(formatted2)
    assert (
        parsed2 == code_file2
    ), f"Parsing failed for different values. Expected:\n{code_file2}\nGot:\n{parsed2}"
    print("Different values test passed.")

    # Test tolerant parsing: blank lines around the block and stray
    # whitespace around single-line values must be ignored. (The previous
    # input was identical to the formatted output and tested nothing extra.)
    tolerant_input = (
        "\n\n<code_file>\n"
        "file_path:   src/main.py   \n"
        "language: python\t\n"
        "```\n"
        "def main():\n"
        "    print('Hello, World!')\n"
        "```\n"
        "</code_file>\n\n"
    )
    parsed_tolerant = CodeFileModel.parse(tolerant_input)
    expected_tolerant = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    assert (
        parsed_tolerant == expected_tolerant
    ), f"Tolerant parsing failed. Expected:\n{expected_tolerant}\nGot:\n{parsed_tolerant}"
    print("Tolerant parsing test passed.")

    # Test tolerant parsing without end token and last field suffix
    tolerant_input_no_end = """<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')"""
    parsed_tolerant_no_end = CodeFileModel.parse(tolerant_input_no_end)
    expected_tolerant_no_end = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    assert (
        parsed_tolerant_no_end == expected_tolerant_no_end
    ), f"Tolerant parsing without end token failed. Expected:\n{expected_tolerant_no_end}\nGot:\n{parsed_tolerant_no_end}"
    print("Tolerant parsing without end token test passed.")

    # Test find_all_candidates method
    test_string = """
Some text before
<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')
```
</code_file>
Some text in between
<code_file>
file_path: src/helper.py
language: python
```
def helper():
    return 'Helper function'
```
</code_file>
<code_file>
file_path: src/incomplete.py
language: python
```
def incomplete():
    print('No end token')
Some text after
"""

    candidates = FormattingModel.find_all_candidates(
        test_string, "<code_file>", "</code_file>"
    )
    assert len(candidates) == 3, f"Expected 3 candidates, got {len(candidates)}"
    assert candidates[0].startswith("<code_file>") and candidates[0].endswith(
        "</code_file>"
    ), "First candidate is incorrect"
    assert candidates[1].startswith("<code_file>") and candidates[1].endswith(
        "</code_file>"
    ), "Second candidate is incorrect"
    assert candidates[2].startswith("<code_file>") and not candidates[2].endswith(
        "</code_file>"
    ), "Third candidate is incorrect"
    print("find_all_candidates test passed.")

    print("All tests passed successfully!")
|
@@ -0,0 +1,395 @@
|
|
1
|
+
from abc import ABC
|
2
|
+
from typing import Dict, List
|
3
|
+
|
4
|
+
from langroid.pydantic_v1 import BaseModel, Field
|
5
|
+
|
6
|
+
|
7
|
+
class FormatMetadata(BaseModel):
    """Layout hints for one field of a ``FormattingModel``.

    Attach an instance to a field through the ``format_metadata`` keyword of
    ``Field(...)``; fields without one fall back to ``FormatMetadata()``.
    """

    # Literal text written immediately before the field's value.
    prefix: str = ""
    # Literal text written immediately after the value; for single-line
    # fields an empty suffix means the value ends at the next newline.
    suffix: str = ""
    # When True the value may span several lines and is delimited only by
    # ``prefix`` and ``suffix``.
    multiline: bool = False
    # Sort key for the field's position in the text (ascending); fields with
    # equal keys keep their declaration order.
    order: int = 0
|
12
|
+
|
13
|
+
|
14
|
+
class FormattingModel(BaseModel, ABC):
    """Pydantic model that can be rendered to, and parsed back from, plain text.

    Each field may attach a ``FormatMetadata`` through the ``format_metadata``
    keyword of ``Field(...)``; fields without one use ``FormatMetadata()``.
    Fields are laid out in ascending ``FormatMetadata.order``, with ties keeping
    declaration order. Each field is written as ``prefix + value + suffix`` on
    its own line. The whole block is wrapped in ``start_token()`` /
    ``end_token()``, which are configurable via ``Config``.
    """

    @classmethod
    def format_spec(cls) -> str:
        """Return the ``str.format`` template used by ``format``.

        Each field, in ``parse_spec`` order, contributes one line
        ``<prefix>{<name>}<suffix>``. Literal braces inside a prefix or suffix
        are doubled so that ``str.format`` treats them as text.
        """
        lines = []
        for name, metadata in cls.parse_spec().items():
            # Without escaping, a "{" or "}" in the prefix/suffix would be read
            # as a replacement field and raise KeyError/ValueError in format().
            prefix = metadata.prefix.replace("{", "{{").replace("}", "}}")
            suffix = metadata.suffix.replace("{", "{{").replace("}", "}}")
            lines.append(f"{prefix}{{{name}}}{suffix}")
        return "\n".join(lines)

    @classmethod
    def parse_spec(cls) -> Dict[str, FormatMetadata]:
        """Map field names to their FormatMetadata, sorted by ``order``.

        ``sorted`` is stable and ``__fields__`` is in declaration order, so
        fields with equal ``order`` keep their declaration order.
        """
        metadata_by_name = {
            name: field.field_info.extra.get("format_metadata", FormatMetadata())
            for name, field in cls.__fields__.items()
        }
        return dict(sorted(metadata_by_name.items(), key=lambda item: item[1].order))

    @classmethod
    def start_token(cls) -> str:
        """Return ``Config.start_token`` if defined, else ``"<format>"``."""
        return getattr(cls.Config, "start_token", "<format>")

    @classmethod
    def end_token(cls) -> str:
        """Return ``Config.end_token`` if defined, else ``"</format>"``."""
        return getattr(cls.Config, "end_token", "</format>")

    @classmethod
    def format(cls, instance: "FormattingModel") -> str:
        """Render ``instance`` as start token, one line per field, end token."""
        formatted = cls.format_spec().format(**instance.dict())
        return f"{cls.start_token()}\n{formatted}\n{cls.end_token()}"

    @classmethod
    def parse(cls, formatted_string: str) -> "FormattingModel":
        """Parse text produced by ``format`` (or a near variant) into a model.

        Parsing is tolerant. Surrounding whitespace is ignored, and the end
        token may be missing. The last field may also lack its terminator, in
        which case its value runs to the end of the text. Single-line values
        are stripped. Multiline values are kept verbatim.

        Args:
            formatted_string: Text whose first non-blank line is the start token.

        Returns:
            An instance of ``cls`` built from the extracted values.

        Raises:
            ValueError: If the start token is wrong, a field's prefix cannot be
                found, or a non-last field's terminator is missing. Model
                validation errors from ``cls(...)`` also propagate.
        """
        lines = formatted_string.strip().split("\n")
        # Strip the token line so trailing spaces or "\r" do not cause failure.
        if lines[0].strip() != cls.start_token():
            raise ValueError("Invalid start token")

        end_token = cls.end_token()
        content = "\n".join(lines[1:])
        if content.endswith(end_token):
            # Not stripped: multiline values must be preserved exactly.
            content = content[: -len(end_token)]

        spec = cls.parse_spec()
        last_position = len(spec) - 1
        parsed_data: Dict[str, str] = {}
        for position, (name, metadata) in enumerate(spec.items()):
            start_marker = metadata.prefix
            # Single-line fields without an explicit suffix end at a newline.
            if metadata.multiline:
                end_marker = metadata.suffix
            else:
                end_marker = metadata.suffix or "\n"

            start_index = content.find(start_marker)
            if start_index == -1:
                if metadata.multiline:
                    raise ValueError(f"Could not find start of {name}")
                raise ValueError(f"Could not find {name}")
            start_index += len(start_marker)

            if position == last_position:
                # A multiline value runs to the LAST terminator, so the body
                # may contain the terminator itself. The search is bounded by
                # start_index: an unbounded rfind could match a terminator
                # *before* the value (e.g. "\n```" from a blank line preceding
                # the opening fence) and yield an empty value. A single-line
                # value stops at its first terminator instead of swallowing
                # later lines. A missing terminator is tolerated.
                if metadata.multiline:
                    end_index = content.rfind(end_marker, start_index)
                else:
                    end_index = content.find(end_marker, start_index)
                if end_index == -1:
                    end_index = len(content)
            else:
                end_index = content.find(end_marker, start_index)
                if end_index == -1:
                    raise ValueError(f"Could not find end of {name}")

            value = content[start_index:end_index]
            parsed_data[name] = value if metadata.multiline else value.strip()
            # Drop the consumed field and its terminator before the next one.
            content = content[end_index + len(end_marker) :]

        return cls(**parsed_data)

    @staticmethod
    def find_all_candidates(string: str, begin_token: str, end_token: str) -> List[str]:
        """Return every ``begin_token ... end_token`` span in ``string``.

        Spans are returned in order of appearance and include both tokens.
        Another begin token may appear before the matching end token. In that
        case the earlier, unterminated begin is skipped and scanning resumes at
        the later one. A trailing begin token with no end token yields a span
        running to the end of ``string``.

        Raises:
            ValueError: If ``begin_token`` is empty; the scan could never
                advance and would loop forever.
        """
        if not begin_token:
            raise ValueError("begin_token must be a non-empty string")

        candidates = []
        start = 0
        while True:
            start_index = string.find(begin_token, start)
            if start_index == -1:
                break

            body_start = start_index + len(begin_token)
            end_index = string.find(end_token, body_start)
            if end_index == -1:
                # Unterminated final block: assume it runs to the end.
                candidates.append(string[start_index:])
                break

            next_start = string.find(begin_token, body_start, end_index)
            if next_start != -1:
                # A newer begin token precedes the end token; restart there.
                start = next_start
                continue

            candidates.append(string[start_index : end_index + len(end_token)])
            start = end_index + len(end_token)

        return candidates
|
155
|
+
|
156
|
+
|
157
|
+
class CodeFileModel(FormattingModel):
    """A source file described by its path, its language, and its code.

    The ``order`` values place the fields as file_path, language, then code.
    The code is wrapped in a Markdown fence and the block in ``<code_file>``
    tags.
    """

    # Single-line fields: the value follows the prefix and ends at a newline.
    file_path: str = Field(
        ...,
        format_metadata=FormatMetadata(order=1, prefix="file_path: "),
    )
    language: str = Field(
        ...,
        format_metadata=FormatMetadata(order=2, prefix="language: "),
    )
    # Multiline field delimited by an opening and a closing ``` fence.
    code: str = Field(
        ...,
        format_metadata=FormatMetadata(
            order=3,
            multiline=True,
            prefix="```\n",
            suffix="\n```",
        ),
    )

    class Config:
        start_token = "<code_file>"  # opening delimiter read by start_token()
        end_token = "</code_file>"  # closing delimiter read by end_token()
|
174
|
+
|
175
|
+
|
176
|
+
if __name__ == "__main__":
    # Test formatting
    code_file = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    formatted = CodeFileModel.format(code_file)
    expected_format = """<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')
```
</code_file>"""
    assert (
        formatted == expected_format
    ), f"Formatting failed. Expected:\n{expected_format}\nGot:\n{formatted}"
    print("Formatting test passed.")

    # Test parsing
    parsed = CodeFileModel.parse(formatted)
    assert (
        parsed == code_file
    ), f"Parsing failed. Expected:\n{code_file}\nGot:\n{parsed}"
    print("Parsing test passed.")

    # Test round-trip
    round_trip = CodeFileModel.parse(CodeFileModel.format(code_file))
    assert (
        round_trip == code_file
    ), f"Round-trip failed. Expected:\n{code_file}\nGot:\n{round_trip}"
    print("Round-trip test passed.")

    # Test with different values
    code_file2 = CodeFileModel(
        file_path="src/app.js",
        language="javascript",
        code="function greet() {\n    console.log('Hello, World!');\n}",
    )
    formatted2 = CodeFileModel.format(code_file2)
    parsed2 = CodeFileModel.parse(formatted2)
    assert (
        parsed2 == code_file2
    ), f"Parsing failed for different values. Expected:\n{code_file2}\nGot:\n{parsed2}"
    print("Different values test passed.")

    # Test tolerant parsing: blank lines around the block and stray
    # whitespace around single-line values must be ignored. (The previous
    # input was identical to the formatted output and tested nothing extra.)
    tolerant_input = (
        "\n\n<code_file>\n"
        "file_path:   src/main.py   \n"
        "language: python\t\n"
        "```\n"
        "def main():\n"
        "    print('Hello, World!')\n"
        "```\n"
        "</code_file>\n\n"
    )
    parsed_tolerant = CodeFileModel.parse(tolerant_input)
    expected_tolerant = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    assert (
        parsed_tolerant == expected_tolerant
    ), f"Tolerant parsing failed. Expected:\n{expected_tolerant}\nGot:\n{parsed_tolerant}"
    print("Tolerant parsing test passed.")

    # Test tolerant parsing without end token and last field suffix
    tolerant_input_no_end = """<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')"""
    parsed_tolerant_no_end = CodeFileModel.parse(tolerant_input_no_end)
    expected_tolerant_no_end = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    assert (
        parsed_tolerant_no_end == expected_tolerant_no_end
    ), f"Tolerant parsing without end token failed. Expected:\n{expected_tolerant_no_end}\nGot:\n{parsed_tolerant_no_end}"
    print("Tolerant parsing without end token test passed.")

    # Test find_all_candidates method
    test_string = """
Some text before
<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')
```
</code_file>
Some text in between
<code_file>
file_path: src/helper.py
language: python
```
def helper():
    return 'Helper function'
```
</code_file>
<code_file>
file_path: src/incomplete.py
language: python
```
def incomplete():
    print('No end token')
Some text after
"""

    candidates = FormattingModel.find_all_candidates(
        test_string, "<code_file>", "</code_file>"
    )
    assert len(candidates) == 3, f"Expected 3 candidates, got {len(candidates)}"
    assert candidates[0].startswith("<code_file>") and candidates[0].endswith(
        "</code_file>"
    ), "First candidate is incorrect"
    assert candidates[1].startswith("<code_file>") and candidates[1].endswith(
        "</code_file>"
    ), "Second candidate is incorrect"
    assert candidates[2].startswith("<code_file>") and not candidates[2].endswith(
        "</code_file>"
    ), "Third candidate is incorrect"
    print("find_all_candidates test passed.")

    # Model whose declaration order (file_path first) differs from its
    # ``order`` values (language first), so the tests below actually exercise
    # order-based sorting rather than declaration order.
    class DifferentOrderCodeFileModel(FormattingModel):
        file_path: str = Field(
            ..., format_metadata=FormatMetadata(prefix="file_path: ", order=2)
        )
        language: str = Field(
            ..., format_metadata=FormatMetadata(prefix="language: ", order=1)
        )
        code: str = Field(
            ...,
            format_metadata=FormatMetadata(
                prefix="```\n", suffix="\n```", multiline=True, order=3
            ),
        )

        class Config:
            start_token = "<code_file>"
            end_token = "</code_file>"

    # Test field order: formatting follows FormatMetadata.order.
    reordered = DifferentOrderCodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    formatted_reordered = DifferentOrderCodeFileModel.format(reordered)
    expected_reordered_format = """<code_file>
language: python
file_path: src/main.py
```
def main():
    print('Hello, World!')
```
</code_file>"""
    assert (
        formatted_reordered == expected_reordered_format
    ), f"Formatting with field order failed. Expected:\n{expected_reordered_format}\nGot:\n{formatted_reordered}"
    print("Field order test passed.")

    # Test parsing with different field order
    different_order_input = """<code_file>
language: python
file_path: src/main.py
```
def main():
    print('Hello, World!')
```
</code_file>"""
    parsed_different_order = DifferentOrderCodeFileModel.parse(different_order_input)
    expected_different_order = DifferentOrderCodeFileModel(
        language="python",
        file_path="src/main.py",
        code="def main():\n    print('Hello, World!')",
    )
    assert (
        parsed_different_order == expected_different_order
    ), f"Parsing with different field order failed. Expected:\n{expected_different_order}\nGot:\n{parsed_different_order}"
    print("Different field order parsing test passed.")

    # Test with code containing special characters
    complex_code = CodeFileModel(
        file_path="src/complex.py",
        language="python",
        code='''
def complex_function():
    # This is a comment with "quotes" and 'apostrophes'
    special_chars = "!@#$%^&*()_+{}[]|\\:;<>?,./"
    multiline_string = """
    This is a multiline string.
    It can contain anything:
    1. Numbers: 12345
    2. Symbols: !@#$%^&*()
    3. Quotes: "Hello" 'World'
    4. Backticks: `code`
    5. Even triple backticks: ```python
    """
    print(f"Special chars: {special_chars}")
    print(multiline_string)
''',
    )

    formatted_complex = CodeFileModel.format(complex_code)
    parsed_complex = CodeFileModel.parse(formatted_complex)
    assert (
        parsed_complex == complex_code
    ), f"Complex code parsing failed. Expected:\n{complex_code}\nGot:\n{parsed_complex}"
    print("Complex code test passed.")

    # Announce success only after every test has run.
    print("All tests passed successfully!")
|