framework-m-studio 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,368 @@
1
+ """DocType Test Generator.
2
+
3
+ This module generates comprehensive test files for DocTypes, including:
4
+ - CRUD tests (Create, Read, Update, Delete)
5
+ - Validation tests for required fields
6
+ - Integration test scaffolding
7
+
8
+ Usage:
9
+ from framework_m_studio.codegen.test_generator import generate_test
10
+
11
+ test_code = generate_test({
12
+ "name": "Todo",
13
+ "module": "myapp.doctypes.todo",
14
+ "fields": [{"name": "title", "type": "str", "required": True}],
15
+ })
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import re
21
+ from pathlib import Path
22
+ from typing import Any
23
+
24
+ from jinja2 import Environment, FileSystemLoader, select_autoescape
25
+
26
# Directory holding the Jinja2 templates loaded by ``_get_jinja_env``;
# resolved relative to this module's own location.
TEMPLATE_DIR = Path(__file__).parent.joinpath("templates")
28
+
29
+
30
def _get_jinja_env() -> Environment:
    """Build the Jinja2 environment used to render code templates.

    Templates are loaded from ``TEMPLATE_DIR``. ``trim_blocks`` and
    ``lstrip_blocks`` keep template control tags from leaking blank lines
    and indentation into the output. ``keep_trailing_newline`` preserves
    the final newline of each template file.
    """
    loader = FileSystemLoader(str(TEMPLATE_DIR))
    # An empty extension tuple means no template file is autoescaped by
    # extension: the rendered output is source code, not HTML.
    escaping = select_autoescape(enabled_extensions=())
    return Environment(
        keep_trailing_newline=True,
        lstrip_blocks=True,
        trim_blocks=True,
        autoescape=escaping,
        loader=loader,
    )
39
+
40
+
41
+ def _to_snake_case(name: str) -> str:
42
+ """Convert PascalCase to snake_case."""
43
+ s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
44
+ return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
45
+
46
+
47
+ def _get_test_value(field_type: str, field_name: str = "") -> str:
48
+ """Get a test value for a field type."""
49
+ # Handle common field types
50
+ base_type = field_type.split("[")[0].split("|")[0].strip()
51
+
52
+ type_values = {
53
+ "str": f'"{field_name or "test"}_value"',
54
+ "int": "42",
55
+ "float": "3.14",
56
+ "bool": "True",
57
+ "date": "date.today()",
58
+ "datetime": "datetime.now()",
59
+ "UUID": "uuid4()",
60
+ "uuid": "uuid4()",
61
+ "Decimal": 'Decimal("100.00")',
62
+ }
63
+
64
+ return type_values.get(base_type, '"test"')
65
+
66
+
67
+ def _get_update_value(field_type: str, field_name: str = "") -> str:
68
+ """Get an updated test value for a field type."""
69
+ base_type = field_type.split("[")[0].split("|")[0].strip()
70
+
71
+ type_values = {
72
+ "str": f'"{field_name or "test"}_updated"',
73
+ "int": "99",
74
+ "float": "9.99",
75
+ "bool": "False",
76
+ "date": "date.today()",
77
+ "datetime": "datetime.now()",
78
+ "UUID": "uuid4()",
79
+ "uuid": "uuid4()",
80
+ "Decimal": 'Decimal("999.99")',
81
+ }
82
+
83
+ return type_values.get(base_type, '"updated"')
84
+
85
+
86
def _has_default(value: Any) -> bool:
    """Return True when a schema ``default`` entry denotes a real default.

    ``None`` and ``""`` mean "no default". Falsy values such as ``0``,
    ``False`` or ``[]`` are genuine defaults and must not be mistaken for a
    missing one.
    """
    return value is not None and value != ""


def _is_volatile(expression: str) -> bool:
    """Return True if *expression* may produce a new value each time it runs.

    Two separate evaluations of such an expression can compare unequal
    (``uuid4() == uuid4()`` is always False), so generated tests bind the
    value to a local name once and reuse that name in assertions.
    """
    return expression in ("uuid4()", "datetime.now()", "date.today()")


def _collect_imports(fields: list[dict[str, Any]]) -> list[str]:
    """Return the sorted stdlib import lines needed by generated values."""
    needed: set[str] = set()
    for field in fields:
        field_type = field.get("type", "str")
        lowered = field_type.lower()
        base_type = field_type.split("[")[0].split("|")[0].strip()
        # ``date | datetime`` contains "datetime", yet its base type still
        # yields ``date.today()``, so the base type is checked as well.
        if base_type == "date" or ("date" in lowered and "datetime" not in lowered):
            needed.add("from datetime import date")
        if "datetime" in lowered:
            needed.add("from datetime import datetime")
        if "uuid" in lowered:
            needed.add("from uuid import uuid4")
        if "decimal" in lowered:
            needed.add("from decimal import Decimal")
    return sorted(needed)


def _constructor_lines(name: str, arguments: list[tuple[str, str]]) -> list[str]:
    """Return the ``doc = Name(...)`` lines passing *arguments* as keywords."""
    return [
        f"        doc = {name}(",
        *(f"            {arg}={expr}," for arg, expr in arguments),
        "        )",
    ]


def _sample_arguments(fields: list[dict[str, Any]]) -> list[tuple[str, str]]:
    """Pair each field name with its sample ``test_value`` expression."""
    return [(field["name"], field["test_value"]) for field in fields]


def _create_test_lines(
    name: str, snake_name: str, required_fields: list[dict[str, Any]]
) -> list[str]:
    """Return the creation test, asserting every required field round-trips."""
    lines = [
        f"    def test_create_{snake_name}(self) -> None:",
        f'        """Test creating a {name} instance with required fields."""',
    ]
    expected: list[tuple[str, str]] = []
    for field in required_fields:
        value = field["test_value"]
        if _is_volatile(value):
            # Evaluate once so the constructor and the assertion agree.
            local_name = f"expected_{field['name']}"
            lines.append(f"        {local_name} = {value}")
            value = local_name
        expected.append((field["name"], value))

    lines.extend(_constructor_lines(name, expected))
    lines.append("        assert doc is not None")
    lines.extend(f"        assert doc.{field} == {value}" for field, value in expected)
    lines.append("")
    return lines


def _defaults_test_lines(
    name: str,
    snake_name: str,
    required_fields: list[dict[str, Any]],
    optional_fields: list[dict[str, Any]],
) -> list[str]:
    """Return the test asserting optional fields take their defaults."""
    lines = [
        f"    def test_{snake_name}_defaults(self) -> None:",
        '        """Test default values for optional fields."""',
        *_constructor_lines(name, _sample_arguments(required_fields)),
    ]
    for field in optional_fields:
        if _has_default(field["default"]):
            lines.append(f"        assert doc.{field['name']} == {field['default']}")
        else:
            lines.append(f"        assert doc.{field['name']} is None")
    lines.append("")
    return lines


def _validation_test_lines(name: str, snake_name: str) -> list[str]:
    """Return the test expecting a ValidationError when required fields are omitted."""
    return [
        f"    def test_{snake_name}_validation(self) -> None:",
        '        """Test validation error when required fields are missing."""',
        "        import pydantic",
        "",
        "        with pytest.raises(pydantic.ValidationError):",
        f"            {name}()  # Missing required fields",
        "",
    ]


def _update_test_lines(
    name: str, snake_name: str, required_fields: list[dict[str, Any]]
) -> list[str]:
    """Return the test that reassigns the first required field.

    *required_fields* must be non-empty.
    """
    first = required_fields[0]
    new_value = first["update_value"]
    lines = [
        f"    def test_update_{snake_name}(self) -> None:",
        f'        """Test updating {name} fields."""',
        *_constructor_lines(name, _sample_arguments(required_fields)),
        "",
    ]
    if _is_volatile(new_value):
        # Evaluate once so the assignment and the assertion agree.
        local_name = f"updated_{first['name']}"
        lines.append(f"        {local_name} = {new_value}")
        new_value = local_name
    lines.extend(
        [
            f"        doc.{first['name']} = {new_value}",
            f"        assert doc.{first['name']} == {new_value}",
            "",
        ]
    )
    return lines


def _crud_test_lines(
    name: str, snake_name: str, required_fields: list[dict[str, Any]]
) -> list[str]:
    """Return a skipped async CRUD test whose repository calls are commented stubs."""
    lines = [
        "    @pytest.mark.skip(reason='Integration test - requires database')",
        f"    async def test_{snake_name}_crud(self) -> None:",
        f'        """Test full CRUD cycle for {name}."""',
        "        # Create",
        *_constructor_lines(name, _sample_arguments(required_fields)),
        "        # saved_doc = await repository.save(doc)",
        "        # assert saved_doc.id is not None",
        "",
        "        # Read",
        "        # fetched = await repository.get(saved_doc.id)",
        "        # assert fetched is not None",
    ]
    first = required_fields[0] if required_fields else None
    if first is not None:
        lines.append(
            f"        # assert fetched.{first['name']} == {first['test_value']}"
        )
    lines.extend(["", "        # Update"])
    if first is not None:
        lines.extend(
            [
                f"        # fetched.{first['name']} = {first['update_value']}",
                "        # updated = await repository.save(fetched)",
                f"        # assert updated.{first['name']} == {first['update_value']}",
            ]
        )
    lines.extend(
        [
            "",
            "        # Delete",
            "        # await repository.delete(saved_doc.id)",
            "        # deleted = await repository.get(saved_doc.id)",
            "        # assert deleted is None",
            "",
        ]
    )
    return lines


def generate_test(schema: dict[str, Any]) -> str:
    """Generate comprehensive test file for a DocType.

    Args:
        schema: Dictionary with DocType schema:
            - name: str - Class name (PascalCase)
            - module: str - Python module path (defaults to
              ``doctypes.<snake_name>``)
            - fields: list[dict] - Field definitions, each with ``name`` and
              optional ``type`` (default ``"str"``), ``required`` (default
              True) and ``default``

    Returns:
        Generated Python test file source code.

    A field is treated as required when it is marked required and has no
    default. A ``default`` of ``None`` or ``""`` means "no default"; falsy
    values such as ``0`` or ``False`` are real defaults. Defaults are
    rendered into the generated source with ``str()``, so string defaults
    must already be Python expressions (e.g. ``'"open"'``).

    Values that can differ between evaluations (``uuid4()``,
    ``datetime.now()``, ``date.today()``) are bound to a local name before
    being compared, so the generated assertions are stable.

    The generated tests include:
    - test_create_{doctype}: Basic creation test
    - test_{doctype}_defaults: Default value verification (if optional fields)
    - test_{doctype}_validation: Required field validation (if required fields)
    - test_update_{doctype}: Field update test (if required fields)
    - test_{doctype}_crud: Full CRUD cycle stub (skipped; requires a database)

    Example:
        >>> code = generate_test({
        ...     "name": "Todo",
        ...     "module": "myapp.doctypes.todo",
        ...     "fields": [{"name": "title", "type": "str", "required": True}],
        ... })
    """
    name = schema["name"]
    snake_name = _to_snake_case(name)
    module = schema.get("module", f"doctypes.{snake_name}")
    # Tolerate an explicit ``"fields": None`` as well as a missing key.
    fields = schema.get("fields") or []

    # Attach sample/update expressions and split into required vs optional.
    required_fields: list[dict[str, Any]] = []
    optional_fields: list[dict[str, Any]] = []
    for field in fields:
        field_type = field.get("type", "str")
        field_data = {
            "name": field["name"],
            "type": field_type,
            "required": field.get("required", True),
            "default": field.get("default"),
            "test_value": _get_test_value(field_type, field["name"]),
            "update_value": _get_update_value(field_type, field["name"]),
        }
        if field_data["required"] and not _has_default(field_data["default"]):
            required_fields.append(field_data)
        else:
            optional_fields.append(field_data)

    # Module docstring and future import.
    lines = [
        f'"""Tests for {name} DocType.',
        "",
        "Auto-generated by Framework M Studio.",
        '"""',
        "",
        "from __future__ import annotations",
        "",
    ]

    # Stdlib imports needed by the sample values, then pytest and the DocType.
    stdlib_imports = _collect_imports(fields)
    if stdlib_imports:
        lines.extend(stdlib_imports)
        lines.append("")

    lines.extend(
        [
            "import pytest",
            "",
            f"from {module} import {name}",
            "",
            "",
            f"class Test{name}:",
            f'    """Tests for {name} DocType."""',
            "",
        ]
    )

    lines.extend(_create_test_lines(name, snake_name, required_fields))
    if optional_fields:
        lines.extend(
            _defaults_test_lines(name, snake_name, required_fields, optional_fields)
        )
    if required_fields:
        lines.extend(_validation_test_lines(name, snake_name))
        lines.extend(_update_test_lines(name, snake_name, required_fields))
    lines.extend(_crud_test_lines(name, snake_name, required_fields))

    return "\n".join(lines)
337
+
338
+
339
def generate_test_file(
    schema: dict[str, Any],
    output_dir: Path | str,
) -> Path:
    """Render the test module for a DocType and write it to disk.

    Args:
        schema: DocType schema dictionary, as accepted by ``generate_test``.
        output_dir: Directory that receives the file. It is created, along
            with any missing parents, if it does not exist.

    Returns:
        Path of the written ``test_<snake_name>.py`` file. An existing file
        at that path is overwritten.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    file_name = "test_" + _to_snake_case(schema["name"]) + ".py"
    destination = target_dir / file_name
    destination.write_text(generate_test(schema), encoding="utf-8")
    return destination
363
+
364
+
365
# Public API of this module; the underscore-prefixed helpers stay private.
__all__ = ["generate_test", "generate_test_file"]