langroid 0.16.5__py3-none-any.whl → 0.16.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. langroid/agent/md_tool_message_grammar.py +455 -0
  2. langroid/agent/tools/code_file_tool_parse.py +150 -0
  3. langroid/agent/tools/code_file_tool_pyparsing.py +194 -0
  4. langroid/agent/tools/code_file_tool_pyparsing2.py +199 -0
  5. langroid/agent/tools/formatted_model_custom.py +150 -0
  6. langroid/agent/tools/formatted_model_custom2.py +168 -0
  7. langroid/agent/tools/formatted_model_custom3.py +279 -0
  8. langroid/agent/tools/formatted_model_custom4.py +395 -0
  9. langroid/agent/tools/formatted_model_jinja.py +133 -0
  10. langroid/agent/tools/formatted_model_jinja.py-e +122 -0
  11. langroid/agent/tools/formatted_model_jinja2.py +145 -0
  12. langroid/agent/tools/formatted_model_jinja2.py-e +135 -0
  13. langroid/agent/tools/formatted_model_lark.py +0 -0
  14. langroid/agent/tools/formatted_model_lark2.py +168 -0
  15. langroid/agent/tools/formatted_model_parse.py +105 -0
  16. langroid/agent/tools/formatted_model_parse.py-e +98 -0
  17. langroid/agent/tools/formatted_model_parse2.py +113 -0
  18. langroid/agent/tools/formatted_model_parse2.py-e +109 -0
  19. langroid/agent/tools/formatted_model_parse3.py +114 -0
  20. langroid/agent/tools/formatted_model_parse3.py-e +110 -0
  21. langroid/agent/tools/formatted_model_parsimon.py +194 -0
  22. langroid/agent/tools/formatted_model_parsimon.py-e +186 -0
  23. langroid/agent/tools/formatted_model_pyparsing.py +169 -0
  24. langroid/agent/tools/formatted_model_pyparsing.py-e +149 -0
  25. langroid/agent/tools/formatted_model_pyparsing2.py +159 -0
  26. langroid/agent/tools/formatted_model_pyparsing2.py-e +143 -0
  27. langroid/agent/tools/formatted_model_pyparsing3.py +133 -0
  28. langroid/agent/tools/formatted_model_pyparsing3.py-e +121 -0
  29. langroid/agent/tools/formatted_model_pyparsing4.py +213 -0
  30. langroid/agent/tools/formatted_model_pyparsing4.py-e +176 -0
  31. langroid/agent/tools/formatted_model_pyparsing5.py +173 -0
  32. langroid/agent/tools/formatted_model_pyparsing5.py-e +142 -0
  33. langroid/agent/tools/formatted_model_regex.py +246 -0
  34. langroid/agent/tools/formatted_model_regex.py-e +248 -0
  35. langroid/agent/tools/formatted_model_regex2.py +250 -0
  36. langroid/agent/tools/formatted_model_regex2.py-e +253 -0
  37. langroid/agent/tools/formatted_model_tatsu.py +172 -0
  38. langroid/agent/tools/formatted_model_tatsu.py-e +160 -0
  39. langroid/agent/tools/formatted_model_template.py +217 -0
  40. langroid/agent/tools/formatted_model_template.py-e +200 -0
  41. langroid/agent/tools/formatted_model_xml.py +178 -0
  42. langroid/agent/tools/formatted_model_xml2.py +178 -0
  43. langroid/agent/tools/formatted_model_xml3.py +132 -0
  44. langroid/agent/tools/formatted_model_xml4.py +130 -0
  45. langroid/agent/tools/formatted_model_xml5.py +130 -0
  46. langroid/agent/tools/formatted_model_xml6.py +113 -0
  47. langroid/agent/tools/formatted_model_xml7.py +117 -0
  48. langroid/agent/tools/formatted_model_xml8.py +164 -0
  49. langroid/agent/tools/generic_tool.py +165 -0
  50. langroid/agent/tools/generic_tool_tatsu.py +275 -0
  51. langroid/agent/tools/grammar_based_model.py +132 -0
  52. langroid/agent/tools/grammar_based_model.py-e +128 -0
  53. langroid/agent/tools/grammar_based_model_lark.py +156 -0
  54. langroid/agent/tools/grammar_based_model_lark.py-e +153 -0
  55. langroid/agent/tools/grammar_based_model_parse.py +86 -0
  56. langroid/agent/tools/grammar_based_model_parse.py-e +80 -0
  57. langroid/agent/tools/grammar_based_model_parsimonious.py +129 -0
  58. langroid/agent/tools/grammar_based_model_parsimonious.py-e +120 -0
  59. langroid/agent/tools/grammar_based_model_pyparsing.py +105 -0
  60. langroid/agent/tools/grammar_based_model_pyparsing.py-e +103 -0
  61. langroid/agent/tools/grammar_based_model_regex.py +139 -0
  62. langroid/agent/tools/grammar_based_model_regex.py-e +130 -0
  63. langroid/agent/tools/grammar_based_model_regex2.py +124 -0
  64. langroid/agent/tools/grammar_based_model_regex2.py-e +116 -0
  65. langroid/agent/tools/grammar_based_model_tatsu.py +80 -0
  66. langroid/agent/tools/grammar_based_model_tatsu.py-e +77 -0
  67. langroid/agent/tools/lark_earley_example.py +135 -0
  68. langroid/agent/tools/lark_earley_example.py-e +117 -0
  69. langroid/agent/tools/lark_example.py +72 -0
  70. langroid/agent/tools/parse_example.py +76 -0
  71. langroid/agent/tools/parse_example2.py +87 -0
  72. langroid/agent/tools/parse_example3.py +42 -0
  73. langroid/agent/tools/parse_test.py +791 -0
  74. langroid/agent/xml_tool_message.py +106 -0
  75. langroid/language_models/openai_gpt.py +6 -1
  76. {langroid-0.16.5.dist-info → langroid-0.16.7.dist-info}/METADATA +1 -1
  77. {langroid-0.16.5.dist-info → langroid-0.16.7.dist-info}/RECORD +80 -6
  78. pyproject.toml +1 -1
  79. {langroid-0.16.5.dist-info → langroid-0.16.7.dist-info}/LICENSE +0 -0
  80. {langroid-0.16.5.dist-info → langroid-0.16.7.dist-info}/WHEEL +0 -0
@@ -0,0 +1,194 @@
1
+ """
2
+ Non-JSON Tool for LLM to specify contents of a code file.
3
+
4
+ Why Non-JSON? Because there are numerous issues with even the best LLMs trying
5
+ to return code within JSON strings (e.g. unescaped newlines, quotes, etc.),
6
+ and the problem is even worse with weak LLMs. Json repair methods exist, but
7
+ can't deal with all possible cases.
8
+
9
+ E.g. see this study from Aider: https://aider.chat/2024/08/14/code-in-json.html
10
+
11
+ Note: We express the formatting rules with a template since it has several benefits:
12
+ - all of the formatting rules are in one place,
13
+ - we get a parser for free, and don't have to write parsing code,
14
+ - we get a formatting example generator for free, and don't have to write
15
+ example generation code.
16
+ - consistency between the parser and the example generator is guaranteed.
17
+ """
18
+
19
+ from typing import Any, Callable, Dict, List, Tuple, Type
20
+
21
+ from pyparsing import (
22
+ LineEnd,
23
+ Literal,
24
+ Optional,
25
+ ParserElement,
26
+ SkipTo,
27
+ White,
28
+ Word,
29
+ alphanums,
30
+ lineEnd,
31
+ printables,
32
+ )
33
+
34
+ from langroid.agent.tool_message import ToolMessage
35
+ from langroid.utils.constants import TOOL, TOOL_END
36
+
37
+ CODE_FENCE_START = "`" * 3
38
+ CODE_FENCE_END = "`" * 3
39
+
40
+
41
class CodeFileTool(ToolMessage):
    """
    Used by LLM to specify contents of a code file.

    The textual layout is defined once, as the pyparsing grammar built by
    `create_parser`. `parse` reads that layout, `format` regenerates it by
    walking the same grammar, and `find_candidates` scans free text for it.
    """

    request: str = "code_file_tool"
    purpose: str = """
To specify the contents of a code file.
"""
    file_path: str
    contents: str
    language: str

    @classmethod
    def create_parser(cls):
        """
        Build the pyparsing grammar for one tool message.

        Layout, in order: ``TOOL + ":"`` with an optional request word, a
        line holding the file path, an opening code fence followed by the
        language, the contents up to the closing fence, and `TOOL_END`.

        Returns:
            The combined pyparsing expression, with results named
            "request", "file_path", "language" and "contents".
        """
        TOOL_START = Literal(TOOL + ":")
        CODE_FENCE = Literal("```")

        file_path = SkipTo(lineEnd)("file_path")
        language = Word(alphanums)("language")
        contents = SkipTo(CODE_FENCE)("contents")

        parser = (
            TOOL_START
            + Optional(Word(printables), default=cls.default_value("request"))(
                "request"
            )
            + lineEnd
            + file_path
            + lineEnd
            + CODE_FENCE
            + Optional(White())  # Allow space after ```
            + language
            + lineEnd
            + contents
            + CODE_FENCE
            + lineEnd  # a newline is expected after the closing fence
            + Literal(TOOL_END)
        )
        return parser

    @classmethod
    def parse(cls, string) -> Dict[str, Any]:
        """
        Parse `string`, which must match the grammar in its entirety.

        Returns:
            A dict with keys "request", "file_path", "language" and
            "contents" (path and contents stripped of surrounding
            whitespace), or an empty dict when parsing fails; the failure
            is printed rather than raised.
        """
        parser = cls.create_parser()
        try:
            result = parser.parseString(string, parseAll=True)
            return {
                "request": result["request"],
                "file_path": result["file_path"].strip(),
                "language": result["language"],
                "contents": result["contents"].strip(),
            }
        except Exception as e:
            # Best-effort by design: `from_string` turns the empty dict
            # into a ValueError for callers that need a hard failure.
            print(f"Parsing failed: {e}")
            return {}

    @classmethod
    def format(cls, instance) -> str:
        """
        Render `instance` as text by walking the grammar from `create_parser`.

        Literals emit their own text, named elements emit the matching
        attribute of `instance`, line-end elements emit a newline, and
        anything else (e.g. optional whitespace) emits nothing.

        Returns:
            The rendered text, stripped of surrounding whitespace.
        """
        parser = cls.create_parser()

        def format_element(element):
            if isinstance(element, Literal):
                return element.match
            elif element.resultsName:
                if element.resultsName == "request":
                    return instance.request
                elif element.resultsName == "file_path":
                    return instance.file_path
                elif element.resultsName == "language":
                    return instance.language
                elif element.resultsName == "contents":
                    # Newline keeps the closing fence on its own line.
                    return f"{instance.contents}\n"
            elif isinstance(element, LineEnd):
                return "\n"
            return ""

        def traverse_parser(parser_element):
            if isinstance(parser_element, ParserElement):
                if isinstance(parser_element, SkipTo):
                    return format_element(parser_element)
                elif hasattr(parser_element, "exprs"):
                    # Composite expression: render children in order.
                    return "".join(
                        traverse_parser(expr) for expr in parser_element.exprs
                    )
                else:
                    return format_element(parser_element)
            return str(parser_element)

        formatted_string = traverse_parser(parser)

        return formatted_string.strip()

    @classmethod
    def create(cls, get_directory: Callable[[], str]) -> Type["CodeFileTool"]:
        """
        Create a subclass of CodeFileTool with a static method get_directory,
        which returns the current directory path, so that all file paths are
        interpreted as relative to the current directory.
        """

        class SubCodeFileTool(cls):
            get_directory: Callable[[], str] = staticmethod(get_directory)

        return SubCodeFileTool

    @classmethod
    def examples(cls) -> List[ToolMessage | Tuple[str, ToolMessage]]:
        """Return one sample instance (a small Rust source file)."""
        return [
            cls(
                file_path="src/lib.rs",
                language="rust",
                contents="""
// function to add two numbers
pub fn add(a: i32, b: i32) -> i32 {
a + b
}
""",
            )
        ]

    @classmethod
    def find_candidates(cls, input_str: str) -> List[str]:
        """
        Find all possible (top-level) candidates for
        CodeFileTool in the input string.

        Returns:
            The substrings of `input_str` matched by the grammar, in order
            of appearance.
        """
        # The previous body called the builtin `compile` on a non-existent
        # `get_template()`; scanning with the class's own grammar is the
        # pyparsing equivalent of a "findall".
        parser = cls.create_parser()
        return [
            input_str[start:end]
            for _tokens, start, end in parser.scanString(input_str)
        ]

    @classmethod
    def from_string(cls, input_string: str) -> "CodeFileTool":
        """
        Parse a string into a CodeFileTool object, using the grammar.

        Raises:
            ValueError: if `input_string` does not match the grammar.
        """
        parsed_data = cls.parse(input_string)
        if parsed_data:
            return cls(**parsed_data)
        raise ValueError("Invalid input string format")

    @classmethod
    def to_string(cls, instance) -> str:
        """Convert a CodeFileTool object to a string, using the grammar."""
        return cls.format(instance)

    def __str__(self):
        # `to_string` is a classmethod that takes the instance explicitly;
        # calling it with no argument raised TypeError.
        return self.to_string(self)

    def __repr__(self):
        return f"""CodeFileTool(
file_path='{self.file_path}',
language='{self.language}',
contents='{self.contents}')
"""
@@ -0,0 +1,199 @@
1
+ """
2
+ Non-JSON Tool for LLM to specify contents of a code file.
3
+
4
+ Why Non-JSON? Because there are numerous issues with even the best LLMs trying
5
+ to return code within JSON strings (e.g. unescaped newlines, quotes, etc.),
6
+ and the problem is even worse with weak LLMs. Json repair methods exist, but
7
+ can't deal with all possible cases.
8
+
9
+ E.g. see this study from Aider: https://aider.chat/2024/08/14/code-in-json.html
10
+
11
+ Note: We express the formatting rules with a template since it has several benefits:
12
+ - all of the formatting rules are in one place,
13
+ - we get a parser for free, and don't have to write parsing code,
14
+ - we get a formatting example generator for free, and don't have to write
15
+ example generation code.
16
+ - consistency between the parser and the example generator is guaranteed.
17
+ """
18
+
19
+ from typing import Callable, List, Tuple, Type
20
+
21
+ from pyparsing import (
22
+ Literal,
23
+ Optional,
24
+ SkipTo,
25
+ White,
26
+ Word,
27
+ alphanums,
28
+ lineEnd,
29
+ printables,
30
+ )
31
+
32
+ from langroid.agent.tool_message import ToolMessage
33
+ from langroid.agent.tools.generic_tool import GenericTool
34
+ from langroid.utils.constants import TOOL, TOOL_END
35
+
36
+ CODE_FENCE_START = "`" * 3
37
+ CODE_FENCE_END = "`" * 3
38
+
39
+
40
class CodeFileTool(GenericTool):
    """
    Used by LLM to specify contents of a code file.

    Only the grammar is declared here (see `define_grammar`).
    """

    request: str = "code_file_tool"
    purpose: str = """
To specify the <contents> of a code file at <file_path>,
containing code in a specific <language>.
"""
    file_path: str
    contents: str
    language: str

    @classmethod
    def define_grammar(cls):
        """
        Return the pyparsing grammar describing one tool message.

        Named results: "request", "file_path", "language", "contents".
        Optional whitespace is accepted after the start marker, before the
        file path, after the opening fence, and around `TOOL_END`.
        """
        fence = Literal("```")
        start_marker = Literal(TOOL + ":")

        request_expr = Optional(
            Word(printables), default=cls.default_value("request")
        )("request")
        path_expr = SkipTo(lineEnd)("file_path")
        language_expr = Word(alphanums)("language")
        # Contents end at the line break that precedes the closing fence.
        contents_expr = SkipTo(lineEnd + fence)("contents")

        return (
            start_marker
            + Optional(White())
            + request_expr
            + lineEnd
            + Optional(White())
            + path_expr
            + lineEnd
            + fence
            + Optional(White())
            + language_expr
            + lineEnd
            + contents_expr
            + lineEnd
            + fence
            + lineEnd
            + Optional(White())
            + Literal(TOOL_END)
            + Optional(White())
        )

    @classmethod
    def create(cls, get_directory: Callable[[], str]) -> Type["CodeFileTool"]:
        """
        Return a subclass of this tool that carries `get_directory` as a
        static method, so that file paths can be interpreted relative to
        the directory it reports.
        """

        class SubCodeFileTool(cls):
            get_directory: Callable[[], str] = staticmethod(get_directory)

        return SubCodeFileTool

    @classmethod
    def examples(cls) -> List[ToolMessage | Tuple[str, ToolMessage]]:
        """Return one sample instance (a small Rust source file)."""
        rust_sample = cls(
            file_path="src/lib.rs",
            language="rust",
            contents="""
// function to add two numbers
pub fn add(a: i32, b: i32) -> i32 {
a + b
}
""",
        )
        return [rust_sample]

    def __repr__(self):
        rendered = f"""CodeFileTool(
file_path='{self.file_path}',
language='{self.language}',
contents='{self.contents}')
"""
        return rendered
121
+
122
+
123
if __name__ == "__main__":
    # Informal smoke tests: each section prints its output between two
    # horizontal rules so the result can be inspected by eye.
    rule = "-" * 50

    # Section 1: instructions text.
    print("Testing CodeFileTool instructions:")
    print(rule)
    print(CodeFileTool.instructions())
    print(rule)
    print("End of instructions test")

    # Section 2: parsing a hand-written message.
    print("\nTesting CodeFileTool parse method:")
    print(rule)

    sample_text = """TOOL: code_file_tool
src/main.py
```python
def hello_world():
print("Hello, World!")

if __name__ == "__main__":
hello_world()
```
TOOL_END"""

    parsed_result = CodeFileTool.parse(sample_text)
    print("Parsed result:")
    for key, value in parsed_result.items():
        print(f"{key}: {value}")

    print(rule)
    print("End of parse test")

    # Section 3: formatting an instance built in code.
    print("\nTesting CodeFileTool format method:")
    print(rule)
    sample_instance = CodeFileTool(
        request="code_file_tool",
        file_path="tests/test_file.py",
        language="python",
        contents="""
def test_function():
assert 1 + 1 == 2

if __name__ == "__main__":
test_function()
""",
    )
    formatted_output = CodeFileTool.format(sample_instance)
    print("Formatted output:")
    print(formatted_output)
    print(rule)
    print("End of format test")

    # Section 4: round-trip (parse -> format -> parse) must be stable.
    print("\nTesting CodeFileTool round-trip (parse -> format -> parse):")
    print(rule)
    initial_parse = CodeFileTool.parse(sample_text)
    formatted_output = CodeFileTool.format(CodeFileTool(**initial_parse))
    final_parse = CodeFileTool.parse(formatted_output)

    print("Initial parse:")
    print(initial_parse)
    print("\nFormatted output:")
    print(formatted_output)
    print("\nFinal parse:")
    print(final_parse)

    if initial_parse != final_parse:
        print("\nRound-trip test failed: Initial and final parses do not match.")
    else:
        print("\nRound-trip test passed: Initial and final parses match.")
    print(rule)
    print("End of round-trip test")
@@ -0,0 +1,150 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Tuple
3
+
4
+ from langroid.pydantic_v1 import BaseModel
5
+
6
+
7
class FormattingModel(BaseModel, ABC):
    """
    Base for models that serialise to, and parse from, a delimited text
    block: a start-token line, a body produced from `format_spec`, and an
    end-token line.
    """

    @classmethod
    @abstractmethod
    def format_spec(cls) -> str:
        """Return the `str.format` template for the body, keyed by field name."""

    @classmethod
    @abstractmethod
    def parse_spec(cls) -> List[Tuple[str, str, str]]:
        """Return (field, start marker, end marker) triples used by `parse`."""

    @classmethod
    @abstractmethod
    def start_token(cls) -> str:
        """Return the line that opens a formatted block."""

    @classmethod
    @abstractmethod
    def end_token(cls) -> str:
        """Return the line that closes a formatted block."""

    @classmethod
    def format(cls, instance: "FormattingModel") -> str:
        """Render `instance` as start token, filled-in body, end token."""
        body = cls.format_spec().format(**instance.dict())
        return f"{cls.start_token()}\n{body}\n{cls.end_token()}"

    @classmethod
    def parse(cls, formatted_string: str) -> "FormattingModel":
        """
        Rebuild a model from text produced by `format`.

        Each field's value is the stripped text between the first
        occurrence of its start marker in the body and the next occurrence
        of its end marker.

        Raises:
            ValueError: if the first/last line is not the start/end token,
                or a field's start or end marker cannot be found.
        """
        rows = formatted_string.strip().split("\n")
        if not (rows[0] == cls.start_token() and rows[-1] == cls.end_token()):
            raise ValueError("Invalid start or end token")
        body = "\n".join(rows[1:-1])

        values = {}
        for name, opener, closer in cls.parse_spec():
            begin = body.find(opener)
            if begin == -1:
                raise ValueError(f"Could not find start of {name}")
            begin += len(opener)  # the value starts just past the marker
            finish = body.find(closer, begin)
            if finish == -1:
                raise ValueError(f"Could not find end of {name}")
            values[name] = body[begin:finish].strip()

        return cls(**values)
53
+
54
+
55
class CodeFileModel(FormattingModel):
    """A code file: path and language on labelled lines, code inside a fence."""

    file_path: str
    language: str
    code: str

    @classmethod
    def format_spec(cls):
        """Template: two labelled lines followed by the fenced code."""
        template_lines = [
            "file_path: {file_path}",
            "language: {language}",
            "```",
            "{code}",
            "```",
        ]
        return "\n".join(template_lines)

    @classmethod
    def parse_spec(cls):
        """Markers delimiting each field in the formatted body."""
        newline = "\n"
        return [
            ("file_path", "file_path:", newline),
            ("language", "language:", newline),
            ("code", "```\n", "\n```"),
        ]

    @classmethod
    def start_token(cls):
        """Opening line of a formatted block."""
        return "<code_file>"

    @classmethod
    def end_token(cls):
        """Closing line of a formatted block."""
        return "</code_file>"
79
+
80
+
81
+ # Test cases
82
if __name__ == "__main__":
    # Self-checks for CodeFileModel: formatting, parsing and round-trips.
    # NOTE(review): the indented body line uses four spaces in both the
    # `code=` values and the expected text. As previously written, the two
    # disagreed on indentation, so the first assert could not pass. Confirm
    # that four spaces matches the intended sample.

    # Test formatting
    code_file = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    formatted = CodeFileModel.format(code_file)
    expected_format = """<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')
```
</code_file>"""
    assert (
        formatted == expected_format
    ), f"Formatting failed. Expected:\n{expected_format}\nGot:\n{formatted}"
    print("Formatting test passed.")

    # Test parsing
    parsed = CodeFileModel.parse(formatted)
    assert (
        parsed == code_file
    ), f"Parsing failed. Expected:\n{code_file}\nGot:\n{parsed}"
    print("Parsing test passed.")

    # Test round-trip
    round_trip = CodeFileModel.parse(CodeFileModel.format(code_file))
    assert (
        round_trip == code_file
    ), f"Round-trip failed. Expected:\n{code_file}\nGot:\n{round_trip}"
    print("Round-trip test passed.")

    # Test with different values
    code_file2 = CodeFileModel(
        file_path="src/app.js",
        language="javascript",
        code="function greet() {\n console.log('Hello, World!');\n}",
    )
    formatted2 = CodeFileModel.format(code_file2)
    parsed2 = CodeFileModel.parse(formatted2)
    assert (
        parsed2 == code_file2
    ), f"Parsing failed for different values. Expected:\n{code_file2}\nGot:\n{parsed2}"
    print("Different values test passed.")

    # Test tolerant parsing
    # NOTE(review): this input is identical to the canonical format; if the
    # test was meant to exercise extra whitespace, that needs re-adding.
    tolerant_input = """<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')
```
</code_file>"""
    parsed_tolerant = CodeFileModel.parse(tolerant_input)
    expected_tolerant = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    assert (
        parsed_tolerant == expected_tolerant
    ), f"Tolerant parsing failed. Expected:\n{expected_tolerant}\nGot:\n{parsed_tolerant}"
    print("Tolerant parsing test passed.")

    print("All tests passed successfully!")
@@ -0,0 +1,168 @@
1
+ from abc import ABC
2
+ from typing import Dict
3
+
4
+ from langroid.pydantic_v1 import BaseModel, Field
5
+
6
+
7
class FormatMetadata(BaseModel):
    """Per-field formatting hints, attached to a field as `format_metadata`."""

    # Text emitted before the field value.
    prefix: str = ""
    # Text emitted after the field value.
    suffix: str = ""
    # Selects the multi-line branch when the field is parsed.
    multiline: bool = False
11
+
12
+
13
class FormattingModel(BaseModel, ABC):
    """
    Base for models whose text layout is derived from per-field
    `format_metadata` (prefix, suffix, multiline) and from optional
    `start_token` / `end_token` attributes on the model's `Config`.
    """

    @classmethod
    def format_spec(cls) -> str:
        """Return the body template: one `prefix{name}suffix` entry per field."""

        def placeholder(name, field) -> str:
            meta: FormatMetadata = field.field_info.extra.get(
                "format_metadata", FormatMetadata()
            )
            # Single-line and multi-line fields share the same template.
            return f"{meta.prefix}{{{name}}}{meta.suffix}"

        return "\n".join(
            placeholder(name, field) for name, field in cls.__fields__.items()
        )

    @classmethod
    def parse_spec(cls) -> Dict[str, FormatMetadata]:
        """Map each field name to its metadata (defaults when none is set)."""
        spec = {}
        for name, field in cls.__fields__.items():
            spec[name] = field.field_info.extra.get(
                "format_metadata", FormatMetadata()
            )
        return spec

    @classmethod
    def start_token(cls) -> str:
        """Opening line; `Config.start_token` if present, else "<format>"."""
        return getattr(cls.Config, "start_token", "<format>")

    @classmethod
    def end_token(cls) -> str:
        """Closing line; `Config.end_token` if present, else "</format>"."""
        return getattr(cls.Config, "end_token", "</format>")

    @classmethod
    def format(cls, instance: "FormattingModel") -> str:
        """Render `instance` as start token, filled-in body, end token."""
        body = cls.format_spec().format(**instance.dict())
        return f"{cls.start_token()}\n{body}\n{cls.end_token()}"

    @classmethod
    def parse(cls, formatted_string: str) -> "FormattingModel":
        """
        Rebuild a model from text produced by `format`.

        A multi-line field needs both its prefix and its suffix to be
        present. A single-line field needs only its prefix, and runs to its
        suffix (or the next newline when the suffix is empty), or to the
        end of the body when that terminator is absent.

        Raises:
            ValueError: if the first/last line is not the start/end token,
                or a required marker cannot be found.
        """
        rows = formatted_string.strip().split("\n")
        if not (rows[0] == cls.start_token() and rows[-1] == cls.end_token()):
            raise ValueError("Invalid start or end token")
        body = "\n".join(rows[1:-1])

        values = {}
        for name, meta in cls.parse_spec().items():
            opener = meta.prefix
            begin = body.find(opener)
            value_start = begin + len(opener)

            if meta.multiline:
                if begin == -1:
                    raise ValueError(f"Could not find start of {name}")
                finish = body.find(meta.suffix, value_start)
                if finish == -1:
                    raise ValueError(f"Could not find end of {name}")
            else:
                if begin == -1:
                    raise ValueError(f"Could not find {name}")
                finish = body.find(meta.suffix or "\n", value_start)
                if finish == -1:
                    # No terminator: the value runs to the end of the body.
                    finish = len(body)

            values[name] = body[value_start:finish].strip()

        return cls(**values)
83
+
84
+
85
class CodeFileModel(FormattingModel):
    """A code file: path and language on prefixed lines, code inside a fence."""

    # Rendered as "file_path: <value>".
    file_path: str = Field(
        ...,
        format_metadata=FormatMetadata(prefix="file_path: "),
    )
    # Rendered as "language: <value>".
    language: str = Field(
        ...,
        format_metadata=FormatMetadata(prefix="language: "),
    )
    # Rendered between an opening and a closing fence line.
    code: str = Field(
        ...,
        format_metadata=FormatMetadata(
            prefix="```\n",
            suffix="\n```",
            multiline=True,
        ),
    )

    class Config:
        start_token = "<code_file>"
        end_token = "</code_file>"
96
+
97
+
98
+ # Test cases
99
+ #
100
if __name__ == "__main__":
    # Self-checks for CodeFileModel: formatting, parsing and round-trips.
    # NOTE(review): the indented body line uses four spaces in both the
    # `code=` values and the expected text. As previously written, the two
    # disagreed on indentation, so the first assert could not pass. Confirm
    # that four spaces matches the intended sample.

    # Test formatting
    code_file = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    formatted = CodeFileModel.format(code_file)
    expected_format = """<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')
```
</code_file>"""
    assert (
        formatted == expected_format
    ), f"Formatting failed. Expected:\n{expected_format}\nGot:\n{formatted}"
    print("Formatting test passed.")

    # Test parsing
    parsed = CodeFileModel.parse(formatted)
    assert (
        parsed == code_file
    ), f"Parsing failed. Expected:\n{code_file}\nGot:\n{parsed}"
    print("Parsing test passed.")

    # Test round-trip
    round_trip = CodeFileModel.parse(CodeFileModel.format(code_file))
    assert (
        round_trip == code_file
    ), f"Round-trip failed. Expected:\n{code_file}\nGot:\n{round_trip}"
    print("Round-trip test passed.")

    # Test with different values
    code_file2 = CodeFileModel(
        file_path="src/app.js",
        language="javascript",
        code="function greet() {\n console.log('Hello, World!');\n}",
    )
    formatted2 = CodeFileModel.format(code_file2)
    parsed2 = CodeFileModel.parse(formatted2)
    assert (
        parsed2 == code_file2
    ), f"Parsing failed for different values. Expected:\n{code_file2}\nGot:\n{parsed2}"
    print("Different values test passed.")

    # Test tolerant parsing
    # NOTE(review): this input is identical to the canonical format; if the
    # test was meant to exercise extra whitespace, that needs re-adding.
    tolerant_input = """<code_file>
file_path: src/main.py
language: python
```
def main():
    print('Hello, World!')
```
</code_file>"""
    parsed_tolerant = CodeFileModel.parse(tolerant_input)
    expected_tolerant = CodeFileModel(
        file_path="src/main.py",
        language="python",
        code="def main():\n    print('Hello, World!')",
    )
    assert (
        parsed_tolerant == expected_tolerant
    ), f"Tolerant parsing failed. Expected:\n{expected_tolerant}\nGot:\n{parsed_tolerant}"
    print("Tolerant parsing test passed.")

    print("All tests passed successfully!")