tmock 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmock/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from tmock.interceptor import CallArguments
2
+ from tmock.matchers.any import any
3
+ from tmock.mock_generator import tmock
4
+ from tmock.reset import reset, reset_behaviors, reset_interactions
5
+ from tmock.stubbing_dsl import given
6
+ from tmock.tpatch import tpatch
7
+ from tmock.verification_dsl import verify
8
+
9
# Public API: the names of the objects re-exported by this package.
__all__ = [
    exported.__name__
    for exported in (
        any,
        CallArguments,
        given,
        reset,
        reset_behaviors,
        reset_interactions,
        tmock,
        tpatch,
        verify,
    )
]
tmock/call_record.py ADDED
@@ -0,0 +1,83 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from typing import Any
4
+
5
+ from tmock.matchers.base import Matcher
6
+
7
+
8
@dataclass(frozen=True)
class RecordedArgument:
    """A single named argument belonging to a call record (immutable)."""

    # Name the argument is rendered and matched under.
    name: str
    # The argument value; inside a pattern record this may be a Matcher.
    value: Any
12
+
13
+
14
@dataclass(frozen=True)
class CallRecord(ABC):
    """Immutable description of one interaction with a mock member.

    Concrete subclasses decide how the interaction is rendered for humans.
    """

    # Member name (a method name, or a field name for getter/setter records).
    name: str
    # Ordered arguments, each paired with its name.
    arguments: tuple[RecordedArgument, ...]

    @abstractmethod
    def format_call(self) -> str:
        """Return a human-readable rendering of this call for error messages."""
        ...
23
+
24
+
25
@dataclass(frozen=True)
class MethodCallRecord(CallRecord):
    """Record of a method invocation, rendered as ``name(arg=value, ...)``."""

    def format_call(self) -> str:
        rendered = [f"{arg.name}={_format_value(arg.value)}" for arg in self.arguments]
        return f"{self.name}({', '.join(rendered)})"
30
+
31
+
32
@dataclass(frozen=True)
class GetterCallRecord(CallRecord):
    """Record of a field read, rendered as ``get <name>``."""

    def format_call(self) -> str:
        return "get {}".format(self.name)
36
+
37
+
38
@dataclass(frozen=True)
class SetterCallRecord(CallRecord):
    """Record of a field write, rendered as ``set <name> = <value>``."""

    def format_call(self) -> str:
        # Presumably a setter record always carries the assigned value as its
        # single argument. If it is missing, emit a bare "?" placeholder:
        # passing "?" through _format_value would repr() it as '?', which is
        # indistinguishable from a real assignment of the string "?".
        if not self.arguments:
            return f"set {self.name} = ?"
        return f"set {self.name} = {_format_value(self.arguments[0].value)}"
43
+
44
+
45
def _format_value(v: Any) -> str:
    """Render a value for error messages: matchers describe themselves, anything else uses repr()."""
    return v.describe() if isinstance(v, Matcher) else repr(v)
49
+
50
+
51
def pattern_matches_call(pattern: CallRecord, actual: CallRecord) -> bool:
    """Tell whether *actual* satisfies *pattern*.

    Names and argument counts must agree. Arguments are then compared pairwise
    in order, stopping at the first mismatch:
    - argument names must be equal;
    - a Matcher in the pattern decides via ``matches()``;
    - any other pattern value is compared with ``_safe_equals``.
    """
    same_shape = pattern.name == actual.name and len(pattern.arguments) == len(actual.arguments)
    if not same_shape:
        return False

    for expected, received in zip(pattern.arguments, actual.arguments):
        if expected.name != received.name:
            return False
        expected_value = expected.value
        if isinstance(expected_value, Matcher):
            ok = expected_value.matches(received.value)
        else:
            ok = _safe_equals(expected_value, received.value)
        if not ok:
            return False
    return True
64
+
65
+
66
+ def _safe_equals(a: Any, b: Any) -> bool:
67
+ """Compare two values safely, avoiding recursion with TMock objects."""
68
+ # If both are the same object, they are equal
69
+ if a is b:
70
+ return True
71
+
72
+ # If both are TMocks, use identity to avoid triggering their __eq__ interceptors
73
+ # which leads to infinite recursion during pattern matching.
74
+ from tmock.mock_generator import is_tmock
75
+
76
+ if is_tmock(a) and is_tmock(b):
77
+ return a is b
78
+
79
+ # Standard comparison for everything else
80
+ try:
81
+ return bool(a == b)
82
+ except Exception:
83
+ return False
tmock/class_schema.py ADDED
@@ -0,0 +1,336 @@
1
+ import dataclasses
2
+ from dataclasses import dataclass, field
3
+ from enum import Enum, auto
4
+ from inspect import Parameter, Signature, iscoroutinefunction, signature
5
+ from typing import Any, Callable, ClassVar, Type, get_origin, get_type_hints
6
+
7
+
8
class FieldSource(Enum):
    """Where a mockable field was found during class introspection."""

    PROPERTY = 1  # @property descriptor
    ANNOTATION = 2  # class-level type annotation
    DATACLASS = 3  # dataclasses.fields()
    PYDANTIC = 4  # Pydantic v2 model_fields
    EXTRA = 5  # named explicitly by the caller, no type information
16
+
17
+
18
@dataclass
class FieldSchema:
    """One mockable field, expressed as accessor signatures.

    Reading the field is modelled by ``getter_signature`` (no parameters; only
    the return annotation carries type information). Writing it is modelled by
    ``setter_signature`` (one ``value`` parameter), which is None when the field
    is read-only.
    """

    name: str
    # Signature used for reads of the field.
    getter_signature: Signature
    # Signature used for writes, or None when the field cannot be set.
    setter_signature: Signature | None
    # Which discovery mechanism produced this entry.
    source: FieldSource
26
+
27
+
28
@dataclass
class ClassSchema:
    """Result of introspecting a class: its methods, fields and special members."""

    # Instance-method name -> signature with the leading ``self`` removed.
    method_signatures: dict[str, Signature] = field(default_factory=dict)
    # Field name -> accessor schema.
    fields: dict[str, FieldSchema] = field(default_factory=dict)
    # Names bound to classmethod or staticmethod objects.
    class_or_static: set[str] = field(default_factory=set)
    # Method names whose implementation is a coroutine function.
    async_methods: set[str] = field(default_factory=set)
36
+
37
+
38
class FieldDiscovery:
    """Discovers mockable fields of a class from several sources.

    Sources are consulted in priority order:
    1. Pydantic v2 model fields;
    2. dataclass fields;
    3. ``@property`` descriptors;
    4. plain class annotations.

    When several sources report the same name, the earliest one wins. Names
    starting with ``_`` are never reported.
    """

    def __init__(self, cls: Type[Any]):
        self._cls = cls

    def discover_all(self) -> dict[str, FieldSchema]:
        """Discover all fields, with earlier sources taking precedence."""
        result: dict[str, FieldSchema] = {}
        self._merge(result, self._discover_pydantic_fields())
        self._merge(result, self._discover_dataclass_fields())
        self._merge(result, self._discover_properties())
        self._merge(result, self._discover_annotations())
        return result

    def _merge(self, target: dict[str, FieldSchema], discovered: dict[str, FieldSchema]) -> None:
        """Merge discovered fields into *target*, keeping entries already present."""
        for name, field_schema in discovered.items():
            if name not in target:
                target[name] = field_schema

    def _resolved_class_hints(self) -> dict[str, Any]:
        """Return ``get_type_hints`` for the class, or an empty dict if resolution fails."""
        try:
            return get_type_hints(self._cls)
        except Exception:
            return {}

    def _discover_pydantic_fields(self) -> dict[str, FieldSchema]:
        """Discover fields from Pydantic models (v2).

        A field is read-only when the whole model is frozen
        (``model_config["frozen"]``) or the field itself is declared frozen
        (``Field(frozen=True)``, exposed as ``FieldInfo.frozen``).
        """
        if not hasattr(self._cls, "__pydantic_complete__"):
            return {}

        model_fields = getattr(self._cls, "model_fields", {})
        model_config = getattr(self._cls, "model_config", {})
        model_frozen = model_config.get("frozen", False) if isinstance(model_config, dict) else False

        result: dict[str, FieldSchema] = {}
        for name, field_info in model_fields.items():
            if name.startswith("_"):
                continue
            # FieldInfo.frozen is None unless set explicitly, hence bool().
            field_frozen = bool(getattr(field_info, "frozen", False))
            result[name] = self._create_schema(
                name=name,
                annotation=getattr(field_info, "annotation", Any),
                has_setter=not (model_frozen or field_frozen),
                source=FieldSource.PYDANTIC,
            )

        return result

    def _discover_dataclass_fields(self) -> dict[str, FieldSchema]:
        """Discover fields from dataclasses.

        ``dataclasses.Field.type`` holds the raw annotation. That is a plain
        string when the defining module uses ``from __future__ import
        annotations``. Dataclass entries take precedence over the annotation
        source, which does resolve hints. So the resolved hint from
        ``typing.get_type_hints`` is used whenever resolution succeeds.
        """
        if not dataclasses.is_dataclass(self._cls):
            return {}

        params = getattr(self._cls, "__dataclass_params__", None)
        frozen = params.frozen if params else False
        resolved = self._resolved_class_hints()

        result: dict[str, FieldSchema] = {}
        for fld in dataclasses.fields(self._cls):
            if fld.name.startswith("_"):
                continue
            result[fld.name] = self._create_schema(
                name=fld.name,
                annotation=resolved.get(fld.name, fld.type),
                has_setter=not frozen,
                source=FieldSource.DATACLASS,
            )

        return result

    def _discover_properties(self) -> dict[str, FieldSchema]:
        """Discover @property descriptors; a property without fset is read-only."""
        result: dict[str, FieldSchema] = {}

        for name in dir(self._cls):
            if name.startswith("_"):
                continue

            raw_attr = _get_raw_attribute(self._cls, name)
            if isinstance(raw_attr, property):
                getter_sig = self._extract_property_getter_signature(raw_attr.fget)
                setter_sig = self._extract_property_setter_signature(raw_attr.fset) if raw_attr.fset else None
                result[name] = FieldSchema(
                    name=name,
                    getter_signature=getter_sig,
                    setter_signature=setter_sig,
                    source=FieldSource.PROPERTY,
                )

        return result

    def _discover_annotations(self) -> dict[str, FieldSchema]:
        """Discover class-level type annotations (instance variables).

        ``ClassVar`` annotations, bare or subscripted, declare class attributes
        rather than instance fields and are skipped.
        """
        try:
            hints = get_type_hints(self._cls)
        except Exception:
            # Unresolvable forward references: fall back to the raw (possibly
            # string) annotations.
            hints = getattr(self._cls, "__annotations__", {})

        result: dict[str, FieldSchema] = {}
        for name, annotation in hints.items():
            if name.startswith("_") or self._is_class_var(annotation):
                continue
            result[name] = self._create_schema(
                name=name,
                annotation=annotation,
                has_setter=True,
                source=FieldSource.ANNOTATION,
            )

        return result

    @staticmethod
    def _is_class_var(annotation: Any) -> bool:
        """Return True for ``ClassVar`` / ``ClassVar[...]``, including unresolved string forms."""
        if isinstance(annotation, str):
            head = annotation.split("[", 1)[0].strip()
            return head in ("ClassVar", "typing.ClassVar")
        # get_origin() is None for a bare ClassVar, so check identity as well.
        return annotation is ClassVar or get_origin(annotation) is ClassVar

    def _create_schema(
        self,
        name: str,
        annotation: Any,
        has_setter: bool,
        source: FieldSource,
    ) -> FieldSchema:
        """Create a FieldSchema with synthetic getter/setter signatures.

        The getter returns *annotation*. When *has_setter* is true, the setter
        takes one ``value`` parameter of that type and returns None.
        """
        getter_sig = Signature(return_annotation=annotation)

        if has_setter:
            value_param = Parameter("value", Parameter.POSITIONAL_OR_KEYWORD, annotation=annotation)
            setter_sig = Signature(parameters=[value_param], return_annotation=type(None))
        else:
            setter_sig = None

        return FieldSchema(
            name=name,
            getter_signature=getter_sig,
            setter_signature=setter_sig,
            source=source,
        )

    def _extract_property_getter_signature(self, getter: Any) -> Signature:
        """Creates a signature for a property getter (no params, just return type)."""
        if getter is None:
            return Signature(return_annotation=Signature.empty)

        try:
            hints = get_type_hints(getter)
            return_type = hints.get("return", Signature.empty)
        except Exception:
            return_type = Signature.empty

        return Signature(return_annotation=return_type)

    def _extract_property_setter_signature(self, setter: Any) -> Signature:
        """Creates a signature for a property setter (one 'value' param, returns None)."""
        value_type: Any = Signature.empty
        try:
            hints = get_type_hints(setter)
            # Take the first non-return hint: the assigned value, whatever its parameter name.
            for param_name, param_type in hints.items():
                if param_name != "return":
                    value_type = param_type
                    break
        except Exception:
            # Unresolvable hints: leave the value parameter unannotated.
            pass

        value_param = Parameter("value", Parameter.POSITIONAL_OR_KEYWORD, annotation=value_type)
        return Signature(parameters=[value_param], return_annotation=type(None))
198
+
199
+
200
# Dunder names that introspect_class treats as mockable members; every other
# name with a leading underscore is skipped.
ALLOWED_MAGIC_METHODS = {
    # Callable objects
    "__call__",
    # Sync / async context managers
    "__enter__", "__exit__", "__aenter__", "__aexit__",
    # Container protocol
    "__getitem__", "__setitem__", "__delitem__", "__len__", "__contains__",
    # Sync / async iteration
    "__iter__", "__next__", "__aiter__", "__anext__",
    # Truthiness, hashing and rich comparison
    "__bool__", "__hash__", "__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__",
    # String conversion
    "__str__", "__repr__", "__format__",
}
227
+
228
+
229
def introspect_class(cls: Type[Any], extra_fields: list[str] | None = None) -> ClassSchema:
    """Build a ClassSchema describing the mockable members of *cls*.

    Fields are discovered first, and any ``extra_fields`` not otherwise found
    are added untyped. The remaining members are then classified as class or
    static members, or as instance methods. Those members are every public
    name plus each ALLOWED_MAGIC_METHODS dunder the class overrides.
    """
    schema = ClassSchema()
    schema.fields = FieldDiscovery(cls).discover_all()
    _apply_extra_fields_if_not_discovered(extra_fields, schema)

    for member in dir(cls):
        allowed_dunder = member in ALLOWED_MAGIC_METHODS
        if member in schema.fields:
            continue
        if member.startswith("_") and not allowed_dunder:
            continue

        raw = _get_raw_attribute(cls, member)
        # None means either "not in any class __dict__" or an explicit None
        # value (e.g. ``__hash__ = None``); neither is mockable.
        if raw is None:
            continue

        # A dunder that only exists because of ``object`` is not worth mocking.
        if allowed_dunder and _default_impl_is_inherited_from_object(cls, member):
            continue

        if isinstance(raw, (classmethod, staticmethod)):
            schema.class_or_static.add(member)
            continue

        if isinstance(raw, property) or not callable(raw):
            continue

        schema.method_signatures[member] = _extract_instance_method_signature(raw)
        if iscoroutinefunction(raw):
            schema.async_methods.add(member)

    return schema
260
+
261
+
262
+ def _default_impl_is_inherited_from_object(cls: Type[Any], name: str) -> bool:
263
+ """Returns True if the attribute is resolved from the 'object' class directly."""
264
+ for base in cls.__mro__:
265
+ if name in base.__dict__:
266
+ return base is object
267
+ return False
268
+
269
+
270
+ def _apply_extra_fields_if_not_discovered(extra_fields, schema):
271
+ if not extra_fields:
272
+ return
273
+ for name in extra_fields:
274
+ if name not in schema.fields:
275
+ schema.fields[name] = _create_extra_field_schema(name)
276
+
277
+
278
def _create_extra_field_schema(name: str) -> FieldSchema:
    """Build a read/write FieldSchema typed as ``Any`` for a caller-named field."""
    any_value = Parameter("value", Parameter.POSITIONAL_OR_KEYWORD, annotation=Any)
    return FieldSchema(
        name=name,
        getter_signature=Signature(return_annotation=Any),
        setter_signature=Signature(parameters=[any_value], return_annotation=type(None)),
        source=FieldSource.EXTRA,
    )
+ )
289
+
290
+
291
+ def _get_raw_attribute(cls: Type[Any], name: str) -> Any:
292
+ """Retrieves the raw attribute from the class hierarchy, bypassing descriptors."""
293
+ for klass in cls.__mro__:
294
+ if name in klass.__dict__:
295
+ return klass.__dict__[name]
296
+ return None
297
+
298
+
299
def _extract_instance_method_signature(method: Any) -> Signature:
    """Return *method*'s signature with forward references resolved and the first parameter (``self``) dropped."""
    resolved = resolve_forward_refs(method, signature(method))
    params = list(resolved.parameters.values())
    return resolved.replace(parameters=params[1:]) if params else resolved
307
+
308
+
309
def resolve_forward_refs(func: Callable[..., Any], sig: Signature) -> Signature:
    """Replace string annotations in *sig* with the types ``get_type_hints`` resolves.

    Resolution lets typeguard validate annotations written as forward
    references (e.g. ``-> "ClassName"``).

    Args:
        func: The function/method whose type hints are resolved.
        sig: The signature previously obtained from ``inspect.signature(func)``.

    Returns:
        A new Signature in which every parameter that has a resolved hint, and
        the return annotation when one is resolved, uses the resolved type.
        Parameters without a hint are kept as-is. If ``get_type_hints`` raises
        for any reason, *sig* itself is returned unchanged.
    """
    try:
        hints = get_type_hints(func)
    except Exception:
        return sig

    resolved_params = [
        param.replace(annotation=hints[param.name]) if param.name in hints else param
        for param in sig.parameters.values()
    ]
    return sig.replace(
        parameters=resolved_params,
        return_annotation=hints.get("return", sig.return_annotation),
    )
tmock/exceptions.py ADDED
@@ -0,0 +1,24 @@
1
class TMockError(Exception):
    """Root of the tmock exception hierarchy; catching it covers every tmock error below."""
5
+
6
+
7
class TMockStubbingError(TMockError):
    """tmock error category for stubbing problems."""
9
+
10
+
11
class TMockVerificationError(TMockError, AssertionError):
    """tmock error category for verification failures.

    Also an AssertionError, so ``except AssertionError`` handlers and test
    runners treat it as an ordinary failed assertion.
    """
13
+
14
+
15
class TMockUnexpectedCallError(TMockError):
    """tmock error category for calls on a mock that were not expected."""
17
+
18
+
19
class TMockPatchingError(TMockError):
    """tmock error category for patching problems."""
21
+
22
+
23
class TMockResetError(TMockError):
    """tmock error category for reset problems."""
tmock/field_ref.py ADDED
@@ -0,0 +1,15 @@
1
+ from dataclasses import dataclass
2
+ from typing import TYPE_CHECKING, Any, Optional
3
+
4
+ if TYPE_CHECKING:
5
+ from tmock.interceptor import Interceptor
6
+
7
+
8
@dataclass
class FieldRef:
    """Handle to a field of a mock, returned while the DSL is active.

    Bundles the interceptors for reading and writing that field.
    """

    # The mock object the field belongs to.
    mock: Any
    # Name of the field on that mock.
    name: str
    # Interceptor for reads of the field.
    getter_interceptor: "Interceptor"
    # Interceptor for writes; presumably None when the field is read-only.
    setter_interceptor: Optional["Interceptor"]