decoy 2.3.0__py3-none-any.whl → 2.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,229 @@
1
+ """Inspect spec object."""
2
+
3
+ import functools
4
+ import inspect
5
+ from typing import (
6
+ Callable,
7
+ NamedTuple,
8
+ Union,
9
+ cast,
10
+ get_args,
11
+ get_origin,
12
+ get_type_hints,
13
+ )
14
+
15
+ from .errors import (
16
+ createMockNameRequiredError,
17
+ createMockNotAsyncError,
18
+ createMockSpecInvalidError,
19
+ createSignatureMismatchError,
20
+ createThenDoActionNotCallableError,
21
+ )
22
+
23
+
24
class BoundArguments(NamedTuple):
    """Positional and keyword arguments after binding them to a spec's signature."""

    # positional arguments, in call order
    args: tuple[object, ...]
    # keyword arguments, keyed by parameter name
    kwargs: dict[str, object]
29
+
30
+
31
def is_async_callable(value: object, fallback: bool = False) -> bool:
    """Get whether a spec object is an asynchronous callable.

    Arguments:
        value: The spec object to inspect. Classes are inspected through
            their `__call__` method via `_unwrap_callable`.
        fallback: The result to return when `value` is `None`.

    Returns:
        Whether the unwrapped callable is a coroutine function; `False` for
        non-callable values.
    """
    if value is None:
        return fallback

    source = _unwrap_callable(value)

    # `iscoroutinefunction` does not see through `partial` on Python < 3.8,
    # so unwrap it here and check the underlying function instead
    if isinstance(source, functools.partial):
        source = source.func

    return inspect.iscoroutinefunction(source)
43
+
44
+
45
def ensure_spec(spec_cls: object, spec_func: object) -> object:
    """Validate the class and function specs and pick the one to use.

    Arguments:
        spec_cls: A class (or type alias of one) to use as the spec, or `None`.
        spec_func: A callable to use as the spec, or `None`.

    Returns:
        The type-alias-unwrapped `spec_cls` if it is truthy, else `spec_func`.

    Raises:
        The error built by `createMockSpecInvalidError("cls")` if the
        unwrapped class spec is not a class, or
        `createMockSpecInvalidError("func")` if the function spec is not
        callable.
    """
    resolved_cls = _unwrap_type_alias(spec_cls)

    class_spec_is_invalid = resolved_cls is not None and not inspect.isclass(
        resolved_cls
    )
    if class_spec_is_invalid:
        raise createMockSpecInvalidError("cls")

    func_spec_is_invalid = spec_func is not None and not callable(spec_func)
    if func_spec_is_invalid:
        raise createMockSpecInvalidError("func")

    return resolved_cls or spec_func
55
+
56
+
57
+ def ensure_spec_name(spec: object, fallback_name: str | None) -> str:
58
+ """Get the name of a source object."""
59
+ source_name = getattr(spec, "__name__", None) if spec is not None else None
60
+ name = source_name if isinstance(source_name, str) else fallback_name
61
+
62
+ if name is None:
63
+ raise createMockNameRequiredError()
64
+
65
+ return name
66
+
67
+
68
def ensure_callable(value: object, is_async: bool) -> Callable[..., object]:
    """Validate that `value` can be used as a `then_do` action.

    Arguments:
        value: The action to validate.
        is_async: Whether the action's mock is asynchronous.

    Returns:
        `value`, typed as a callable.

    Raises:
        The error from `createThenDoActionNotCallableError()` if `value` is
        not callable, or from `createMockNotAsyncError()` if `value` is an
        async callable but the mock is not async.
    """
    if callable(value):
        if is_async_callable(value) and not is_async:
            raise createMockNotAsyncError()

        return cast(Callable[..., object], value)

    raise createThenDoActionNotCallableError()
76
+
77
+
78
+ def get_spec_module_name(spec: object) -> str | None:
79
+ """Get the name of a source object."""
80
+ module_name = getattr(spec, "__module__", None) if spec is not None else None
81
+ return module_name if isinstance(module_name, str) else None
82
+
83
+
84
def get_spec_class_type(spec: object, fallback_type: type[object]) -> type[object]:
    """Return `spec` when it is a class, otherwise `fallback_type`."""
    if inspect.isclass(spec):
        return spec

    return fallback_type
86
+
87
+
88
def is_magic_attribute(name: str) -> bool:
    """Return whether `name` both starts and ends with a double underscore."""
    dunder = "__"
    return name[:2] == dunder and name[-2:] == dunder
90
+
91
+
92
def get_child_spec(spec: object, child_name: str) -> object:
    """Resolve the spec for attribute `child_name` of a class spec.

    The attribute is looked up statically on the class, falling back to the
    class's type annotation for `child_name` when no attribute exists.

    Returns:
        - For a `staticmethod`: the unwrapped function.
        - For a `property`: the getter's return annotation, with type
          aliases unwrapped.
        - For a plain function: a `partial` with `None` pre-bound as `self`.
        - Otherwise: the looked-up value with type aliases unwrapped.
        - `None` if `spec` is not a class.
    """
    if not inspect.isclass(spec):
        return None

    annotation = _get_type_hints(spec).get(child_name)
    raw_attr = inspect.getattr_static(spec, child_name, annotation)
    resolved = inspect.unwrap(raw_attr)

    # the staticmethod check looks at the raw attribute, before unwrapping
    if isinstance(raw_attr, staticmethod):
        return resolved

    if isinstance(resolved, property):
        getter_hints = _get_type_hints(resolved.fget)
        return _unwrap_type_alias(getter_hints.get("return"))

    if inspect.isfunction(resolved):
        # pre-fill the `self` slot so the reported signature omits it
        return functools.partial(resolved, None)

    return _unwrap_type_alias(resolved)
115
+
116
+
117
def get_method_class(name: str, maybe_method: object) -> object:
    """Return the object a bound method is bound to, if its name matches.

    Returns `None` unless `maybe_method` is a bound method named `name`.
    """
    if not inspect.ismethod(maybe_method):
        return None

    if maybe_method.__name__ != name:
        return None

    return maybe_method.__self__
122
+
123
+
124
async def get_awaitable_value(value: object) -> object:
    """Await `value` if it is awaitable; otherwise return it unchanged."""
    if not inspect.isawaitable(value):
        return value

    return await value
130
+
131
+
132
def get_signature(value: object) -> inspect.Signature | None:
    """Return the signature of `value`'s callable, or `None`.

    `None` is returned when `value` is not callable, or when
    `inspect.signature` raises `ValueError`/`TypeError` for it.
    """
    target = _unwrap_callable(value)

    if target is not None:
        try:
            return inspect.signature(target, follow_wrapped=True)
        except (ValueError, TypeError):
            pass

    return None
143
+
144
+
145
def bind_args(
    signature: inspect.Signature | None,
    args: tuple[object, ...],
    kwargs: dict[str, object],
    ignore_extra_args: bool = False,
) -> BoundArguments:
    """Bind `args` and `kwargs` to `signature`, if there is one.

    Arguments:
        signature: The signature to bind to; if `None`, the arguments are
            returned as given.
        args: Positional arguments.
        kwargs: Keyword arguments.
        ignore_extra_args: Use `Signature.bind_partial` instead of
            `Signature.bind`, so missing arguments are allowed.

    Raises:
        The error built by `createSignatureMismatchError` when binding fails.
    """
    if signature is None:
        return BoundArguments(args, kwargs)

    binder = signature.bind_partial if ignore_extra_args else signature.bind

    try:
        bound = binder(*args, **kwargs)
    except (TypeError, ValueError) as error:
        raise createSignatureMismatchError(error) from None

    return BoundArguments(bound.args, bound.kwargs)
164
+
165
+
166
def get_func_name(func: Callable[..., object]) -> str:
    """Return a function's `__name__`, looking through a `functools.partial`."""
    target = func.func if isinstance(func, functools.partial) else func
    return target.__name__
172
+
173
+
174
+ def _unwrap_callable(value: object) -> Callable[..., object] | None:
175
+ """Return an object's callable, checking if a class has a `__call__` method."""
176
+ if not callable(value):
177
+ return None
178
+
179
+ # check if spec source is a class with a __call__ method
180
+ if inspect.isclass(value):
181
+ call_method = inspect.getattr_static(value, "__call__", None)
182
+ if inspect.isfunction(call_method):
183
+ # consume the `self` argument of the method to ensure proper
184
+ # signature reporting by wrapping it in a partial
185
+ value = functools.partial(call_method, None)
186
+
187
+ return value
188
+
189
+
190
def _get_type_hints(value: object) -> dict[str, object]:
    """Return `typing.get_type_hints(value)`, or `{}` if it raises.

    `typing.get_type_hints` can fail at runtime even for code that passes
    type-checking, e.g. when a type is subscriptable according to mypy but
    not according to Python, so any exception is treated as "no hints".
    """
    try:
        hints = get_type_hints(value)
    except Exception:
        hints = {}

    return hints
201
+
202
+
203
+ def _unwrap_type_alias(value: object) -> object:
204
+ """Return a value's actual type if it's a type alias.
205
+
206
+ If the resolved origin is a `Union`, remove any `None` values
207
+ to see if we can resolve to an individual specific type.
208
+
209
+ If we cannot resolve to a specific type, return the original value.
210
+ """
211
+ origin = _resolve_origin(value)
212
+ type_args = get_args(value)
213
+
214
+ if origin is not Union:
215
+ return origin
216
+
217
+ candidates = [a for a in type_args if a is not type(None)]
218
+
219
+ return candidates[0] if len(candidates) == 1 else origin
220
+
221
+
222
+ def _resolve_origin(source: object) -> object:
223
+ """Resolve the origin of an object.
224
+
225
+ This allows a type alias to be used as a class spec.
226
+ """
227
+ origin = get_origin(source)
228
+
229
+ return origin if origin is not None else source
@@ -0,0 +1,328 @@
1
+ import collections.abc
2
+ import functools
3
+ import re
4
+ import sys
5
+ from typing import Any, Callable, Generic, TypeVar, cast, overload
6
+
7
+ if sys.version_info >= (3, 13):
8
+ from typing import TypeIs
9
+ else:
10
+ from typing_extensions import TypeIs
11
+
12
+ from .errors import createNoMatcherValueCapturedError
13
+ from .inspect import get_func_name
14
+
15
# Type of the values a `Matcher` matches and captures.
ValueT = TypeVar("ValueT")
# Type narrowed by a `TypeIs` match function passed to `Matcher`.
MatchT = TypeVar("MatchT")
# Mapping type accepted by `Matcher.contains` for partial mapping matches.
MappingT = TypeVar("MappingT", bound=collections.abc.Mapping[Any, Any])
# Sequence type accepted by `Matcher.contains` for partial sequence matches.
SequenceT = TypeVar("SequenceT", bound=collections.abc.Sequence[Any])
# Exception type matched by `Matcher.error`.
ErrorT = TypeVar("ErrorT", bound=BaseException)

# A match function that narrows its argument with a `TypeIs` guard.
TypedMatch = Callable[[object], TypeIs[MatchT]]
# A match function that returns a plain bool.
UntypedMatch = Callable[[object], bool]
23
+
24
+
25
class Matcher(Generic[ValueT]):
    """Create an [argument matcher](./matchers.md).

    Arguments:
        match: A comparison function that returns a bool or `TypeIs` guard.
        name: Optional name for the matcher; defaults to `match.__name__`.
        description: Optional extra description for the matcher's repr.

    Example:
        Use a function to create a custom matcher.

        ```python
        def is_even(target: object) -> TypeIs[int]:
            return isinstance(target, int) and target % 2 == 0

        is_even_matcher = Matcher(is_even)
        ```

        Matchers can also be constructed from built-in inspection functions, like `callable`.

        ```python
        callable_matcher = Matcher(callable)
        ```
    """

    @overload
    def __init__(
        self: "Matcher[MatchT]",
        match: TypedMatch[MatchT],
        name: str | None = None,
        description: str | None = None,
    ) -> None: ...

    @overload
    def __init__(
        self: "Matcher[Any]",
        match: UntypedMatch,
        name: str | None = None,
        description: str | None = None,
    ) -> None: ...

    def __init__(
        self,
        match: TypedMatch[ValueT] | UntypedMatch,
        name: str | None = None,
        description: str | None = None,
    ) -> None:
        self._match = match
        # an empty `name` also falls back to the match function's name
        self._name = name or get_func_name(match)
        self._description = description
        # every compared value that matched, in comparison order
        self._values: list[ValueT] = []

    def __eq__(self, target: object) -> bool:
        # matching values are recorded so they can be read back later
        # through `value` / `values`
        if self._match(target):
            self._values.append(cast(ValueT, target))  # type: ignore[redundant-cast]
            return True

        return False

    def __repr__(self) -> str:
        matcher_name = f"Matcher.{self._name}"
        if self._description:
            return f"<{matcher_name} {self._description.strip()}>"

        return f"<{matcher_name}>"

    @property
    def arg(self) -> ValueT:
        """Type-cast the matcher as the expected value.

        Example:
            If the mock expects a `str` argument, using `arg` prevents the type-checker from raising an error.

            ```python
            (
                decoy
                .when(mock)
                .called_with(Matcher.matches("^(hello|hi)$").arg)
                .then_return("world")
            )
            ```
        """
        # static-typing aid only: the matcher itself is returned
        return cast(ValueT, self)

    @property
    def value(self) -> ValueT:
        """The latest matching compared value.

        Raises:
            NoMatcherValueCapturedError: the matcher has not been compared with any matching value.

        Example:
            You can use `value` to trigger a callback passed to your mock.

            ```python
            callback_matcher = Matcher(callable)
            decoy.verify(mock).called_with(callback_matcher)
            callback_matcher.value("value")
            ```
        """
        if not self._values:
            raise createNoMatcherValueCapturedError(
                f"{self} has not matched any values"
            )

        return self._values[-1]

    @property
    def values(self) -> list[ValueT]:
        """All matching compared values, in comparison order.

        A copy is returned, so mutating it does not affect the matcher.
        """
        return self._values.copy()

    @overload
    @staticmethod
    def any(
        type: type[MatchT],
        attrs: collections.abc.Mapping[str, object] | None = None,
    ) -> "Matcher[MatchT]": ...

    @overload
    @staticmethod
    def any(
        type: None = None,
        attrs: collections.abc.Mapping[str, object] | None = None,
    ) -> "Matcher[Any]": ...

    @staticmethod
    def any(
        type: type[MatchT] | None = None,
        attrs: collections.abc.Mapping[str, object] | None = None,
    ) -> "Matcher[MatchT] | Matcher[Any]":
        """Match an argument, optionally by type and/or attributes.

        If type and attributes are omitted, will match everything,
        including `None`.

        Arguments:
            type: Type to match, if any.
            attrs: Set of attributes to match, if any.
        """
        description = ""

        # compare against `None` explicitly: a class can be falsy (e.g. via
        # a metaclass `__len__`) and must still be described
        if type is not None:
            description = type.__name__

        if attrs:
            description = f"{description} attrs={attrs!r}"

        # `any` here is the module-level match function, not this method
        # or the builtin
        return Matcher(
            match=functools.partial(any, type, attrs),
            description=description,
        )

    @staticmethod
    def is_not(value: object) -> "Matcher[Any]":
        """Match any value that does not `==` the given value.

        Arguments:
            value: The value that the matcher rejects.
        """
        return Matcher(
            lambda t: t != value,
            name="is_not",
            description=repr(value),
        )

    @overload
    @staticmethod
    def contains(values: MappingT) -> "Matcher[MappingT]": ...

    @overload
    @staticmethod
    def contains(values: SequenceT, in_order: bool = False) -> "Matcher[SequenceT]": ...

    @staticmethod
    def contains(
        values: MappingT | SequenceT,
        in_order: bool = False,
    ) -> "Matcher[MappingT] | Matcher[SequenceT]":
        """Match a dict, list, or string with a partial value.

        Arguments:
            values: Partial value to match.
            in_order: Match list values in order.
        """
        description = repr(values)

        if in_order:
            description = f"{description} in_order={in_order}"

        # `contains` here is the module-level match function
        return Matcher(
            match=functools.partial(contains, values, in_order),
            description=description,
        )

    @staticmethod
    def matches(pattern: str) -> "Matcher[str]":
        """Match a string by a pattern.

        Arguments:
            pattern: Regular expression pattern.
        """
        # compile once so an invalid pattern fails at matcher creation
        pattern_re = re.compile(pattern)

        return Matcher(
            match=functools.partial(matches, pattern_re),
            description=repr(pattern),
        )

    @staticmethod
    def error(type: type[ErrorT], message: str | None = None) -> "Matcher[ErrorT]":
        """Match an exception object.

        Arguments:
            type: The type of exception to match.
            message: An optional regular expression pattern to match.
        """
        # an empty pattern matches any message
        message_re = re.compile(message or "")
        description = type.__name__

        if message:
            description = f"{description} message={message!r}"

        return Matcher(
            match=functools.partial(error, type, message_re),
            name="error",
            description=description,
        )
251
+
252
+
253
+ def any(
254
+ match_type: type[Any] | None,
255
+ attrs: collections.abc.Mapping[str, object] | None,
256
+ target: object,
257
+ ) -> bool:
258
+ return (match_type is None or isinstance(target, match_type)) and (
259
+ attrs is None or _has_attrs(attrs, target)
260
+ )
261
+
262
+
263
def _has_attrs(
    attributes: collections.abc.Mapping[str, object],
    target: object,
) -> bool:
    """Check that `target` has every attribute in `attributes` with an `==` value."""
    for attr_name, expected in attributes.items():
        if not hasattr(target, attr_name):
            return False

        if not getattr(target, attr_name) == expected:
            return False

    return True
271
+
272
+
273
+ def contains(
274
+ values: collections.abc.Mapping[object, object] | collections.abc.Sequence[object],
275
+ in_order: bool,
276
+ target: object,
277
+ ) -> bool:
278
+ if isinstance(values, collections.abc.Mapping):
279
+ return _dict_containing(values, target)
280
+ if isinstance(values, str):
281
+ return isinstance(target, str) and values in target
282
+
283
+ return _list_containing(values, in_order, target)
284
+
285
+
286
def _dict_containing(
    values: collections.abc.Mapping[object, object],
    target: object,
) -> bool:
    """Check that `target` has every key of `values` with an `==` value.

    A `TypeError` (for example, a target that does not support `in` or
    item access) counts as a mismatch.
    """
    try:
        for key, expected in values.items():
            if not (key in target and target[key] == expected):  # type: ignore[index,operator]
                return False
    except TypeError:
        return False

    return True
297
+
298
+
299
def _list_containing(
    values: collections.abc.Sequence[object],
    in_order: bool,
    target: object,
) -> bool:
    """Check whether `target` contains every item of `values`.

    Arguments:
        values: Items that must be present in `target`.
        in_order: If True, each item must be found after the element matched
            by the previous item. Otherwise items may appear anywhere, and a
            single target element may satisfy several equal items.
        target: The compared value; it needs `.index()` and, when
            `in_order` is set, slicing.

    Returns:
        True if every item was found. False otherwise, including when
        `target` does not support the operations above.
    """
    search_from = 0

    try:
        for value in values:
            if in_order:
                # search strictly after the previous match, so repeated
                # values must each match a distinct element
                target = target[search_from:]  # type: ignore[index]

            search_from = target.index(value) + 1  # type: ignore[attr-defined]

    except (AttributeError, TypeError, ValueError):
        return False

    return True
317
+
318
+
319
def error(
    type: type[ErrorT],
    message_pattern: re.Pattern[str],
    target: object,
) -> bool:
    """Check that `target` is an instance of `type` whose message matches the pattern."""
    if not isinstance(target, type):
        return False

    return message_pattern.search(str(target)) is not None
325
+
326
+
327
def matches(pattern: re.Pattern[str], target: object) -> bool:
    """Check that `target` is a string in which `pattern` finds a match."""
    if not isinstance(target, str):
        return False

    return pattern.search(target) is not None