pyrmute 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyrmute/exceptions.py ADDED
@@ -0,0 +1,55 @@
1
+ """Exceptions."""
2
+
3
+ from typing import Self
4
+
5
+
6
class VersionedModelError(Exception):
    """Common ancestor of the exceptions defined in this module.

    Catch this type to handle any versioned-model failure (missing model, failed
    migration, unparsable version) with a single ``except`` clause.
    """
8
+
9
+
10
class ModelNotFoundError(VersionedModelError):
    """Signals that a model, or a particular version of it, is not registered."""

    def __init__(self: Self, name: str, version: str | None = None) -> None:
        """Build the error and remember which lookup failed.

        Args:
            name: Name of the model that was looked up.
            version: Version that was requested, if any. A falsy value (``None`` or
                an empty string) produces the name-only message.
        """
        self.name = name
        self.version = version

        # Only mention the version in the message when one was actually supplied.
        if version:
            subject = f"Model '{name}' version '{version}'"
        else:
            subject = f"Model '{name}'"
        super().__init__(f"{subject} not found in registry")
22
+
23
+
24
class MigrationError(VersionedModelError):
    """Signals that migrating a model between two versions failed or is undefined."""

    def __init__(
        self: Self,
        name: str,
        from_version: str,
        to_version: str,
        reason: str | None = None,
    ) -> None:
        """Build the error, keeping the migration endpoints for later inspection.

        Args:
            name: Name of the model being migrated.
            from_version: Version the data was migrating from.
            to_version: Version the data was migrating to.
            reason: Optional explanation; appended on its own line when truthy.
        """
        self.name = name
        self.from_version = from_version
        self.to_version = to_version
        self.reason = reason

        message_lines = [
            f"Migration failed for '{name}': {from_version} → {to_version}"
        ]
        if reason:
            message_lines.append(f"Reason: {reason}")
        super().__init__("\n".join(message_lines))
44
+
45
+
46
class InvalidVersionError(VersionedModelError):
    """Raised when a version string cannot be parsed."""

    def __init__(self: Self, version_string: str, reason: str | None = None) -> None:
        """Initializes InvalidVersionError.

        Args:
            version_string: The string that failed to parse.
            reason: Optional explanation of why parsing failed; appended on its own
                line of the message when truthy.
        """
        self.version_string = version_string
        # Expose the reason programmatically, consistent with MigrationError.reason,
        # instead of only embedding it in the message text.
        self.reason = reason
        msg = f"Invalid version string: '{version_string}'"
        if reason:
            msg += f"\n{reason}"
        super().__init__(msg)
@@ -0,0 +1,161 @@
1
+ """Migration testing utilities."""
2
+
3
+ from collections.abc import Iterator
4
+ from dataclasses import dataclass
5
+ from typing import Self
6
+
7
+ from .types import ModelData
8
+
9
+
10
@dataclass
class MigrationTestCase:
    """One input/expected-output pair used to exercise a migration function.

    Leaving ``target`` as ``None`` turns the case into a smoke test: it only
    checks that migrating ``source`` does not crash.

    Attributes:
        source: Data handed to the migration.
        target: Data the migration should produce, or ``None`` to skip the output
            comparison and only require that the migration completes.
        description: Optional free-form label describing what the case covers.

    Example:
        >>> case = MigrationTestCase(
        ...     source={"name": "Alice"},
        ...     target={"name": "Alice", "email": "alice@example.com"},
        ...     description="Adds default email field",
        ... )
    """

    # Input data fed to the migration under test.
    source: ModelData
    # Expected output; None means "only check that the migration runs".
    target: ModelData | None = None
    # Human-readable label; empty string when no description is given.
    description: str = ""
34
+
35
+
36
@dataclass
class MigrationTestResult:
    """Outcome of executing a single MigrationTestCase.

    Bundles the case that ran, what the migration actually produced, whether the
    case passed, and an optional error message.

    Attributes:
        test_case: The test case that was executed.
        actual: Output the migration produced.
        passed: True when the case succeeded.
        error: Failure message, or None.

    Example:
        >>> outcome = MigrationTestResult(
        ...     test_case=test_case,
        ...     actual={"name": "Alice", "email": "alice@example.com"},
        ...     passed=True,
        ...     error=None,
        ... )
    """

    test_case: MigrationTestCase
    actual: ModelData
    passed: bool
    error: str | None = None

    def __str__(self: Self) -> str:
        """Render a one-line success marker or a multi-line failure report."""
        label = self.test_case.description
        suffix = f" - {label}" if label else ""

        if not self.passed:
            return (
                f"✗ Test failed{suffix}\n"
                f"Source: {self.test_case.source}\n"
                f"Expected: {self.test_case.target}\n"
                f"Actual: {self.actual}\n"
                f"Error: {self.error}"
            )
        return f"✓ Test passed{suffix}"
76
+
77
+
78
class MigrationTestResults:
    """Collection of migration test results.

    Provides convenient methods for checking overall test status and accessing failed
    tests. Supports ``len()`` and iteration over the individual results.

    Attributes:
        results: List of individual test results.

    Example:
        >>> results = MigrationTestResults([result1, result2, result3])
        >>> if results.all_passed:
        ...     print("All tests passed!")
        ... else:
        ...     print(f"{len(results.failures)} test(s) failed")
        ...     for failure in results.failures:
        ...         print(failure)
    """

    def __init__(self: Self, results: list[MigrationTestResult]) -> None:
        """Initialize test results collection.

        Args:
            results: List of individual test results.
        """
        self.results = results

    @property
    def all_passed(self: Self) -> bool:
        """Check if all tests passed.

        Returns:
            True if all tests passed (vacuously True for an empty collection),
            False if any failed.
        """
        return all(r.passed for r in self.results)

    @property
    def failures(self: Self) -> list[MigrationTestResult]:
        """Get list of failed tests.

        Returns:
            A new list of the test results that failed, in their original order.
        """
        return [r for r in self.results if not r.passed]

    def assert_all_passed(self: Self) -> None:
        """Assert all tests passed, raising detailed error if any failed.

        Raises:
            AssertionError: If any tests failed, with details about failures.

        Example:
            >>> # Use in pytest
            >>> def test_user_migration():
            ...     res = manager.test_migration("User", "1.0.0", "2.0.0", test_cases)
            ...     res.assert_all_passed()
        """
        # Compute the failure list once; `failures` rebuilds it on every access.
        failed = self.failures
        if failed:
            messages = [str(f) for f in failed]
            raise AssertionError(
                f"\n{len(failed)} migration test(s) failed:\n"
                + "\n\n".join(messages)
            )

    def __len__(self: Self) -> int:
        """Get total number of test results."""
        return len(self.results)

    def __iter__(self: Self) -> Iterator[MigrationTestResult]:
        """Iterate over test results."""
        return iter(self.results)

    def __str__(self: Self) -> str:
        """Format results summary as string."""
        total_count = len(self.results)
        failed_count = len(self.failures)

        if failed_count == 0:
            return f"✓ All {total_count} test(s) passed"

        passed_count = total_count - failed_count
        return (
            f"✗ {failed_count} of {total_count} test(s) failed "
            f"({passed_count} passed)"
        )
+ )
pyrmute/model_diff.py ADDED
@@ -0,0 +1,272 @@
1
+ """Model diff class."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Self
5
+
6
+ from pydantic import BaseModel
7
+ from pydantic_core import PydanticUndefined
8
+
9
+
10
@dataclass
class ModelDiff:
    """Contains the difference between two models.

    Usually built with ``from_models`` and rendered with ``to_markdown``.

    Attributes:
        model_name: Name of the compared model.
        from_version: Source version string.
        to_version: Target version string.
        added_fields: Fields present only in the target model.
        removed_fields: Fields present only in the source model.
        modified_fields: Maps each changed field (present in both models) to a dict
            of changes. Keys may include ``"type_changed"``, ``"required_changed"``
            and ``"default_changed"`` (each a ``{"from": ..., "to": ...}`` dict),
            and ``"default_added"`` / ``"default_removed"`` (the default value).
        added_field_info: Maps each added field to a dict with ``"type"``,
            ``"required"`` and ``"default"`` (``None`` when no plain default is set).
        unchanged_fields: Fields present in both models with no detected change.
    """

    model_name: str
    from_version: str
    to_version: str
    added_fields: list[str]
    removed_fields: list[str]
    modified_fields: dict[str, Any]
    added_field_info: dict[str, Any]
    unchanged_fields: list[str]

    def to_markdown(self: Self, header_depth: int = 1) -> str:
        """Generate a markdown representation of the diff.

        Args:
            header_depth: Base header level (1-6). All headers are relative to this.
                For example, header_depth=2 makes the title "##" and subsections "###".
                Out-of-range values are clamped to 1-6, and subsection headers are
                capped at level 6, the deepest heading markdown defines.

        Returns:
            Formatted markdown string showing the differences.

        Example:
            >>> diff.to_markdown(header_depth=1)  # Default: # Title, ## Sections
            >>> diff.to_markdown(header_depth=2)  # ## Title, ### Sections
            >>> diff.to_markdown(header_depth=3)  # ### Title, #### Sections
        """
        header_depth = max(1, min(6, header_depth))

        h1 = "#" * header_depth
        # "#######" is not a markdown heading, so never go deeper than level 6.
        h2 = "#" * min(6, header_depth + 1)

        lines = [
            f"{h1} {self.model_name}: {self.from_version} → {self.to_version}",
            "",
        ]

        added = [
            self._format_field_description(field_name, "added")
            for field_name in sorted(self.added_fields)
        ]
        removed = [
            self._format_field_description(field_name, "removed")
            for field_name in sorted(self.removed_fields)
        ]
        modified = [
            self._format_modified_field(field_name, self.modified_fields[field_name])
            for field_name in sorted(self.modified_fields)
        ]
        lines += self._markdown_section(h2, "Added Fields", added)
        lines += self._markdown_section(h2, "Removed Fields", removed)
        lines += self._markdown_section(h2, "Modified Fields", modified)

        # Unlike the sections above, this one is omitted entirely when empty.
        breaking_changes = self._identify_breaking_changes()
        if breaking_changes:
            # The item text is prefixed by "- " in _markdown_section; CommonMark
            # requires that space after the marker for the line to be a list item.
            lines += self._markdown_section(
                h2,
                "Breaking Changes",
                [f"⚠️ {warning}" for warning in breaking_changes],
            )

        return "\n".join(lines)

    @staticmethod
    def _markdown_section(heading: str, title: str, items: list[str]) -> list[str]:
        """Return the lines of one titled bullet-list section ("None" if empty)."""
        bullets = [f"- {item}" for item in items] if items else ["None"]
        return [f"{heading} {title}", "", *bullets, ""]

    def _format_field_description(self: Self, field_name: str, context: str) -> str:
        """Format a field for display.

        Added fields with recorded info are shown with their type and whether they
        are required; every other field is shown by name only.
        """
        if context == "added" and field_name in self.added_field_info:
            info = self.added_field_info[field_name]
            type_str = self._format_type(info["type"])
            req_str = "required" if info["required"] else "optional"
            return f"`{field_name}: {type_str}` ({req_str})"

        return f"`{field_name}`"

    def _format_modified_field(
        self: Self, field_name: str, changes: dict[str, Any]
    ) -> str:
        """Format a modified field with its changes, joined by " - "."""
        parts = [f"`{field_name}`"]

        if "type_changed" in changes:
            from_type = self._format_type(changes["type_changed"]["from"])
            to_type = self._format_type(changes["type_changed"]["to"])
            parts.append(f"type: `{from_type}` → `{to_type}`")

        if "required_changed" in changes:
            req_change = changes["required_changed"]
            if req_change["from"] and not req_change["to"]:
                parts.append("now optional")
            elif not req_change["from"] and req_change["to"]:
                parts.append("now required")

        if "default_changed" in changes:
            from_val = changes["default_changed"]["from"]
            to_val = changes["default_changed"]["to"]
            parts.append(f"default: `{from_val}` → `{to_val}`")

        if "default_added" in changes:
            parts.append(f"default added: `{changes['default_added']}`")

        if "default_removed" in changes:
            parts.append(f"default removed (was `{changes['default_removed']}`)")

        return " - ".join(parts)

    def _format_type(self: Self, type_annotation: Any) -> str:
        """Format a type annotation for display.

        Unparameterized annotations (``int``, model classes, TypeVars) are shown by
        their ``__name__``. Parameterized ones (``list[int]``, ``Optional[str]``,
        ``int | None``) go through ``str()`` so their type arguments are kept:
        ``list[int].__name__`` is forwarded to the origin and would yield just
        ``"list"``, hiding changes such as ``list[int]`` → ``list[str]``.
        """
        has_type_args = bool(getattr(type_annotation, "__args__", None))
        if hasattr(type_annotation, "__name__") and not has_type_args:
            return str(type_annotation.__name__)

        type_str = str(type_annotation)
        type_str = type_str.replace("typing.", "")
        return type_str.replace("typing_extensions.", "")

    def _identify_breaking_changes(self: Self) -> list[str]:
        """Identify breaking changes that could cause issues.

        Fields are visited in sorted order so the warnings are deterministic and
        follow the same order as the sorted sections in ``to_markdown``.
        """
        warnings: list[str] = []

        for field_name in sorted(self.added_fields):
            if field_name in self.added_field_info:
                info = self.added_field_info[field_name]
                is_required = info["required"] and info["default"] is None

                if is_required:
                    warnings.append(
                        f"New required field '{field_name}' will fail for existing "
                        "data without defaults"
                    )

        if self.removed_fields:
            fields_str = ", ".join(f"'{f}'" for f in sorted(self.removed_fields))
            warnings.append(
                f"Removed fields {fields_str} will be lost during migration"
            )

        for field_name in sorted(self.modified_fields):
            changes = self.modified_fields[field_name]
            if "required_changed" in changes:
                req_change = changes["required_changed"]
                if not req_change["from"] and req_change["to"]:
                    warnings.append(
                        f"Field '{field_name}' changed from optional to required"
                    )

            if "type_changed" in changes:
                warnings.append(
                    f"Field '{field_name}' type changed - may cause validation errors"
                )

        return warnings

    @classmethod
    def from_models(
        cls,
        name: str,
        from_model: type[BaseModel],
        to_model: type[BaseModel],
        from_version: str,
        to_version: str,
    ) -> Self:
        """Create a ModelDiff by comparing two Pydantic models.

        Args:
            name: Name of the model.
            from_model: Source model class.
            to_model: Target model class.
            from_version: Source version string.
            to_version: Target version string.

        Returns:
            ModelDiff instance with computed differences. Field lists follow the
            iteration order of the models' ``model_fields`` dicts.
        """
        from_fields = from_model.model_fields
        to_fields = to_model.model_fields

        # Iterate the field dicts rather than set differences: dict order is stable,
        # while set iteration order of strings varies between interpreter runs.
        added = [field_name for field_name in to_fields if field_name not in from_fields]
        removed = [
            field_name for field_name in from_fields if field_name not in to_fields
        ]
        common = [field_name for field_name in from_fields if field_name in to_fields]

        modified: dict[str, Any] = {}
        unchanged: list[str] = []

        for field_name in common:
            from_field = from_fields[field_name]
            to_field = to_fields[field_name]

            changes: dict[str, Any] = {}

            if from_field.annotation != to_field.annotation:
                changes["type_changed"] = {
                    "from": from_field.annotation,
                    "to": to_field.annotation,
                }

            from_required = from_field.is_required()
            to_required = to_field.is_required()
            if from_required != to_required:
                changes["required_changed"] = {
                    "from": from_required,
                    "to": to_required,
                }

            # PydanticUndefined marks "no plain default" on a field.
            from_default = from_field.default
            to_default = to_field.default
            has_from_default = from_default is not PydanticUndefined
            has_to_default = to_default is not PydanticUndefined

            if has_from_default and has_to_default:
                if from_default != to_default:
                    changes["default_changed"] = {
                        "from": from_default,
                        "to": to_default,
                    }
            elif has_to_default:
                changes["default_added"] = to_default
            elif has_from_default:
                changes["default_removed"] = from_default

            if changes:
                modified[field_name] = changes
            else:
                unchanged.append(field_name)

        added_field_info: dict[str, Any] = {}
        for field_name in added:
            to_field = to_fields[field_name]
            default = to_field.default
            added_field_info[field_name] = {
                "type": to_field.annotation,
                "required": to_field.is_required(),
                "default": None if default is PydanticUndefined else default,
            }

        return cls(
            model_name=name,
            from_version=from_version,
            to_version=to_version,
            added_fields=added,
            removed_fields=removed,
            modified_fields=modified,
            unchanged_fields=unchanged,
            added_field_info=added_field_info,
        )