kailash 0.4.2__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +1 -1
- kailash/client/__init__.py +12 -0
- kailash/client/enhanced_client.py +306 -0
- kailash/core/actors/__init__.py +16 -0
- kailash/core/actors/connection_actor.py +566 -0
- kailash/core/actors/supervisor.py +364 -0
- kailash/edge/__init__.py +16 -0
- kailash/edge/compliance.py +834 -0
- kailash/edge/discovery.py +659 -0
- kailash/edge/location.py +582 -0
- kailash/gateway/__init__.py +33 -0
- kailash/gateway/api.py +289 -0
- kailash/gateway/enhanced_gateway.py +357 -0
- kailash/gateway/resource_resolver.py +217 -0
- kailash/gateway/security.py +227 -0
- kailash/middleware/auth/models.py +2 -2
- kailash/middleware/database/base_models.py +1 -7
- kailash/middleware/database/repositories.py +3 -1
- kailash/middleware/gateway/__init__.py +22 -0
- kailash/middleware/gateway/checkpoint_manager.py +398 -0
- kailash/middleware/gateway/deduplicator.py +382 -0
- kailash/middleware/gateway/durable_gateway.py +417 -0
- kailash/middleware/gateway/durable_request.py +498 -0
- kailash/middleware/gateway/event_store.py +459 -0
- kailash/nodes/admin/audit_log.py +364 -6
- kailash/nodes/admin/permission_check.py +817 -33
- kailash/nodes/admin/role_management.py +1242 -108
- kailash/nodes/admin/schema_manager.py +438 -0
- kailash/nodes/admin/user_management.py +1209 -681
- kailash/nodes/api/http.py +95 -71
- kailash/nodes/base.py +281 -164
- kailash/nodes/base_async.py +30 -31
- kailash/nodes/code/__init__.py +8 -1
- kailash/nodes/code/async_python.py +1035 -0
- kailash/nodes/code/python.py +1 -0
- kailash/nodes/data/async_sql.py +12 -25
- kailash/nodes/data/sql.py +20 -11
- kailash/nodes/data/workflow_connection_pool.py +643 -0
- kailash/nodes/rag/__init__.py +1 -4
- kailash/resources/__init__.py +40 -0
- kailash/resources/factory.py +533 -0
- kailash/resources/health.py +319 -0
- kailash/resources/reference.py +288 -0
- kailash/resources/registry.py +392 -0
- kailash/runtime/async_local.py +711 -302
- kailash/testing/__init__.py +34 -0
- kailash/testing/async_test_case.py +353 -0
- kailash/testing/async_utils.py +345 -0
- kailash/testing/fixtures.py +458 -0
- kailash/testing/mock_registry.py +495 -0
- kailash/utils/resource_manager.py +420 -0
- kailash/workflow/__init__.py +8 -0
- kailash/workflow/async_builder.py +621 -0
- kailash/workflow/async_patterns.py +766 -0
- kailash/workflow/builder.py +93 -10
- kailash/workflow/cyclic_runner.py +111 -41
- kailash/workflow/graph.py +7 -2
- kailash/workflow/resilience.py +11 -1
- {kailash-0.4.2.dist-info → kailash-0.6.0.dist-info}/METADATA +12 -7
- {kailash-0.4.2.dist-info → kailash-0.6.0.dist-info}/RECORD +64 -28
- {kailash-0.4.2.dist-info → kailash-0.6.0.dist-info}/WHEEL +0 -0
- {kailash-0.4.2.dist-info → kailash-0.6.0.dist-info}/entry_points.txt +0 -0
- {kailash-0.4.2.dist-info → kailash-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.4.2.dist-info → kailash-0.6.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,495 @@
|
|
1
|
+
"""Mock resource registry for testing."""
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import functools
|
5
|
+
import inspect
|
6
|
+
import logging
|
7
|
+
from dataclasses import dataclass, field
|
8
|
+
from datetime import datetime, timezone
|
9
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
10
|
+
from unittest.mock import AsyncMock, MagicMock, Mock, create_autospec
|
11
|
+
|
12
|
+
# Module-level logger; create_mock uses it to report why a spec could not be
# derived from a resource factory.
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
@dataclass
class CallRecord:
    """One observed invocation of a tracked method.

    Captures what was called (``method``, ``args``, ``kwargs``), when the
    record was created (``timestamp``), how the call ended (``result`` for a
    normal return, ``exception`` for a raise) and the elapsed ``duration``
    in seconds.
    """

    # Name of the invoked method.
    method: str
    # Positional and keyword arguments exactly as they were passed.
    args: tuple
    kwargs: dict
    # Moment this record was created.
    timestamp: datetime
    # Outcome of the call.
    result: Any = field(default=None)
    exception: Optional[Exception] = field(default=None)
    # Elapsed time in seconds; 0.0 when it was not measured.
    duration: float = field(default=0.0)
|
26
|
+
|
27
|
+
|
28
|
+
class MockResource:
    """Base class for hand-written mock resources that keep their own call log.

    Subclasses add public methods; the registry can wrap those methods so
    that each invocation is appended here as a :class:`CallRecord`.
    """

    def __init__(self, spec=None):
        # Optional spec object; stored for subclasses, not read by this base class.
        self._spec = spec
        self._call_records: List[CallRecord] = []

    def _record_call(
        self,
        method_name: str,
        args: tuple,
        kwargs: dict,
        result: Any = None,
        exception: Exception = None,
        duration: float = 0.0,
    ):
        """Append a UTC-timestamped CallRecord describing one call."""
        self._call_records.append(
            CallRecord(
                method=method_name,
                args=args,
                kwargs=kwargs,
                timestamp=datetime.now(timezone.utc),
                result=result,
                exception=exception,
                duration=duration,
            )
        )

    def get_calls(self, method_name: str = None) -> List[CallRecord]:
        """Return the recorded calls as a new list.

        When ``method_name`` is truthy only calls to that method are
        returned; otherwise every record is returned.
        """
        if not method_name:
            return self._call_records.copy()
        return [rec for rec in self._call_records if rec.method == method_name]
|
61
|
+
|
62
|
+
|
63
|
+
class MockResourceRegistry:
    """Registry of named mock resources for tests.

    Registered objects are either ``unittest.mock`` objects (``Mock`` /
    ``AsyncMock``), whose own call bookkeeping is consulted by the assertion
    helpers, or :class:`MockResource` instances, whose public methods are
    wrapped so that every call is recorded as a :class:`CallRecord` both in
    this registry and on the resource itself.
    """

    def __init__(self):
        # Registered mock objects, keyed by resource name.
        self._mocks: Dict[str, Any] = {}
        # Calls captured by the wrappers installed on MockResource instances.
        self._call_history: Dict[str, List[CallRecord]] = {}
        # Expectations added through expect_call(), keyed by resource name.
        self._expectations: Dict[str, List["Expectation"]] = {}

    def register_mock(self, name: str, mock: Any):
        """Register ``mock`` under ``name`` and start an empty call history.

        ``unittest.mock`` objects already track their own calls and are
        stored as-is; any other object is passed to ``_wrap_mock_methods``,
        which instruments MockResource instances.
        """
        self._mocks[name] = mock
        self._call_history[name] = []

        if not isinstance(mock, (Mock, AsyncMock)):
            self._wrap_mock_methods(name, mock)

    async def create_mock(self, name: str, factory: Any, spec: Any = None) -> Any:
        """Create, configure and register a mock for a resource factory.

        Args:
            name: Name under which the mock is registered.
            factory: Object whose ``create()`` (sync or async) builds the real
                resource. It is only used to discover the resource type when
                ``spec`` is not given.
            spec: Class/object to autospec. When omitted and ``factory`` has
                ``create``, one real instance is built, its type becomes the
                spec, and the instance is closed again.

        Returns:
            The registered mock: an autospec'd mock when a spec is known,
            otherwise a pre-configured ``AsyncMock``.
        """
        if spec is None and hasattr(factory, "create"):
            spec = await self._spec_from_factory(factory)

        # Autospec when the resource type is known; otherwise fall back to a
        # permissive AsyncMock.
        mock = self._autospec_mock(spec) if spec else AsyncMock()

        # Configure common resource methods (only for non-autospec mocks).
        if not spec or not hasattr(mock, "_spec_class"):
            self._configure_resource_mock(mock)

        self.register_mock(name, mock)
        return mock

    async def _spec_from_factory(self, factory: Any) -> Any:
        """Build one real instance via ``factory.create()`` and return its type.

        ``create()`` and ``close()`` results are awaited when they are
        awaitable. Returns None (and logs at debug level) if creation fails;
        a failure while closing is logged but the type is still returned.
        """
        try:
            instance = factory.create()
            # create() may be an ``async def`` or any callable returning an
            # awaitable; either way the real resource is the awaited result.
            if inspect.isawaitable(instance):
                instance = await instance
        except Exception as e:
            logger.debug("Could not determine spec from factory: %s", e)
            return None

        spec = type(instance)

        # Best-effort cleanup of the probe instance.
        close = getattr(instance, "close", None)
        if callable(close):
            try:
                closed = close()
                if inspect.isawaitable(closed):
                    await closed
            except Exception as e:
                logger.debug("Could not close probe instance from factory: %s", e)
        return spec

    @staticmethod
    def _autospec_mock(spec: Any) -> Any:
        """Return an instance autospec of ``spec`` with async methods as AsyncMocks.

        A public coroutine method named ``acquire`` is replaced by an
        AsyncMock that resolves to an async context manager yielding the
        mock itself.
        """
        mock = create_autospec(spec, spec_set=True, instance=True)
        for attr in dir(spec):
            if attr.startswith("_"):
                continue
            if not asyncio.iscoroutinefunction(getattr(spec, attr, None)):
                continue
            if attr == "acquire":
                async_cm = AsyncMock()
                async_cm.__aenter__ = AsyncMock(return_value=mock)
                async_cm.__aexit__ = AsyncMock(return_value=None)
                setattr(mock, attr, AsyncMock(return_value=async_cm))
            else:
                setattr(mock, attr, AsyncMock())
        return mock

    def create_mock_method(self, return_value=None, side_effect=None):
        """Create a standalone mock method.

        Returns an ``AsyncMock`` when ``return_value`` is a coroutine object
        or ``side_effect`` is a coroutine function, otherwise a plain ``Mock``.

        NOTE(review): with a coroutine ``return_value`` the AsyncMock yields
        that coroutine object itself when awaited — confirm this is intended.
        """
        if asyncio.iscoroutine(return_value) or (
            side_effect and asyncio.iscoroutinefunction(side_effect)
        ):
            return AsyncMock(return_value=return_value, side_effect=side_effect)
        return Mock(return_value=return_value, side_effect=side_effect)

    def _configure_resource_mock(self, mock: Union[Mock, AsyncMock]):
        """Pre-configure common resource methods on a non-autospec mock.

        Sets up database-style (``acquire``/``execute``/``fetch*``),
        HTTP-client-style (``get``/``post``/``put``/``delete``) and
        cache-style (``get``/``set``/``setex``/``delete``/``expire``)
        methods where the mock exposes them, and adds AsyncMock
        ``close``/``cleanup`` when they are absent.

        NOTE(review): on an unspecced mock every ``hasattr`` check is true,
        so the cache section overrides the HTTP-style ``get``/``delete``
        configured before it (``get`` then returns None, not a response).
        """
        # Database-like resources: acquire() resolves to an async context
        # manager whose __aenter__ yields the mock itself. Non-AsyncMock
        # ``acquire`` attributes (e.g. autospec'd ones) are left untouched.
        if isinstance(getattr(mock, "acquire", None), AsyncMock):
            async_cm = AsyncMock()
            async_cm.__aenter__ = AsyncMock(return_value=mock)
            async_cm.__aexit__ = AsyncMock(return_value=None)
            mock.acquire.return_value = async_cm

        if hasattr(mock, "execute"):
            mock.execute = AsyncMock(return_value=None)

        if hasattr(mock, "fetch"):
            mock.fetch = AsyncMock(return_value=[])

        if hasattr(mock, "fetchone"):
            mock.fetchone = AsyncMock(return_value=None)

        if hasattr(mock, "fetchval"):
            mock.fetchval = AsyncMock(return_value=None)

        # HTTP client-like resources share a single canned response.
        if hasattr(mock, "get"):
            response_mock = AsyncMock()
            response_mock.json = AsyncMock(return_value={})
            response_mock.text = AsyncMock(return_value="")
            response_mock.status = 200
            response_mock.raise_for_status = Mock()

            mock.get.return_value = response_mock
            if hasattr(mock, "post"):
                mock.post.return_value = response_mock
            if hasattr(mock, "put"):
                mock.put.return_value = response_mock
            if hasattr(mock, "delete"):
                mock.delete.return_value = response_mock

        # Cache-like resources
        if hasattr(mock, "get") and hasattr(mock, "set"):
            mock.get = AsyncMock(return_value=None)
            mock.set = AsyncMock()
            mock.setex = AsyncMock()
            mock.delete = AsyncMock()
            mock.expire = AsyncMock()

        # Add close/cleanup methods if not present (spec_set mocks refuse
        # new attributes, hence the AttributeError guards).
        try:
            if not hasattr(mock, "close"):
                mock.close = AsyncMock()
        except AttributeError:
            pass

        try:
            if not hasattr(mock, "cleanup"):
                mock.cleanup = AsyncMock()
        except AttributeError:
            pass

    def _wrap_mock_methods(self, name: str, mock: Any):
        """Wrap the public methods of a MockResource so calls are recorded.

        MockResource's own inspection API (``get_calls``) is left unwrapped
        so that reading the call log does not add entries to it. Objects
        that are not MockResource instances are ignored.
        """
        if not isinstance(mock, MockResource):
            return

        # Public names supplied by the MockResource base class itself.
        base_api = {n for n in dir(MockResource) if not n.startswith("_")}

        for attr_name in dir(mock):
            if attr_name.startswith("_") or attr_name in base_api:
                continue

            attr = getattr(mock, attr_name)
            if callable(attr) and not isinstance(attr, (Mock, AsyncMock)):
                wrapped = self._create_wrapper(name, attr_name, attr, mock)
                setattr(mock, attr_name, wrapped)

    def _create_wrapper(
        self,
        resource_name: str,
        method_name: str,
        method: Callable,
        mock_resource: MockResource,
    ) -> Callable:
        """Return a wrapper around ``method`` that records every call.

        Each call, whether it returns or raises, is appended to this
        registry's history for ``resource_name`` and to ``mock_resource``'s
        own log together with its duration in seconds; exceptions are
        re-raised unchanged. Coroutine functions get an async wrapper,
        everything else a sync one.
        """
        import time

        def record(args, kwargs, duration, result=None, exception=None):
            # Log the call in the registry and on the resource itself.
            self._call_history[resource_name].append(
                CallRecord(
                    method=method_name,
                    args=args,
                    kwargs=kwargs,
                    timestamp=datetime.now(timezone.utc),
                    result=result,
                    exception=exception,
                    duration=duration,
                )
            )
            mock_resource._record_call(
                method_name,
                args,
                kwargs,
                result,
                exception=exception,
                duration=duration,
            )

        if asyncio.iscoroutinefunction(method):

            @functools.wraps(method)
            async def async_wrapper(*args, **kwargs):
                start = time.perf_counter()
                try:
                    result = await method(*args, **kwargs)
                except Exception as e:
                    record(args, kwargs, time.perf_counter() - start, exception=e)
                    raise
                record(args, kwargs, time.perf_counter() - start, result=result)
                return result

            return async_wrapper

        @functools.wraps(method)
        def sync_wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                result = method(*args, **kwargs)
            except Exception as e:
                record(args, kwargs, time.perf_counter() - start, exception=e)
                raise
            record(args, kwargs, time.perf_counter() - start, result=result)
            return result

        return sync_wrapper

    def get_calls(
        self, resource_name: str, method_name: str = None
    ) -> List[CallRecord]:
        """Return the calls recorded for ``resource_name`` as a new list.

        Includes records captured by MockResource wrappers. For
        ``unittest.mock`` resources, when ``method_name`` is given, the
        method's ``call_args_list`` is also converted into CallRecords;
        those carry the query time as timestamp and no result/duration.
        When ``method_name`` is truthy only calls to that method are
        returned.
        """
        # Copy so the records synthesized below never leak into (and
        # accumulate in) the stored history.
        calls = list(self._call_history.get(resource_name, []))

        mock = self._mocks.get(resource_name)
        if (
            isinstance(mock, (Mock, AsyncMock))
            and method_name
            and hasattr(mock, method_name)
        ):
            method_mock = getattr(mock, method_name)
            for args, kwargs in getattr(method_mock, "call_args_list", ()):
                calls.append(
                    CallRecord(
                        method=method_name,
                        args=args,
                        kwargs=kwargs,
                        timestamp=datetime.now(timezone.utc),
                    )
                )

        if method_name:
            calls = [c for c in calls if c.method == method_name]
        return calls

    @staticmethod
    def _call_matches(
        record: CallRecord,
        with_args: Optional[tuple],
        with_kwargs: Optional[dict],
    ) -> bool:
        """Return True if ``record`` has exactly ``with_args`` (when given)
        and contains every key/value pair of ``with_kwargs`` (when given)."""
        if with_args is not None and record.args != with_args:
            return False
        if with_kwargs is None:
            return True
        # A missing key never matches, even when the expected value is None.
        return all(
            key in record.kwargs and record.kwargs[key] == value
            for key, value in with_kwargs.items()
        )

    def assert_called(
        self,
        resource_name: str,
        method_name: str,
        times: Optional[int] = None,
        with_args: Optional[tuple] = None,
        with_kwargs: Optional[dict] = None,
    ):
        """Assert that ``resource_name.method_name`` was called.

        For ``unittest.mock`` resources, ``times`` is compared with the
        method's total ``call_count`` and ``with_args``/``with_kwargs`` are
        checked with ``assert_called_with`` (exact match on the most recent
        call). For other resources the recorded calls are first filtered by
        ``with_args`` (exact positional match) and ``with_kwargs`` (subset
        match), and ``times`` is compared with the number of matching calls.

        Raises:
            AssertionError: if the expectation is not met.
        """
        mock = self._mocks.get(resource_name)

        if isinstance(mock, (Mock, AsyncMock)):
            method = getattr(mock, method_name, None)
            if method is None:
                raise AssertionError(f"{resource_name} has no method {method_name}")

            if times is not None:
                if method.call_count != times:
                    raise AssertionError(
                        f"{resource_name}.{method_name} called {method.call_count} times, "
                        f"expected {times}"
                    )
            else:
                method.assert_called()

            if with_args is not None or with_kwargs is not None:
                method.assert_called_with(*(with_args or ()), **(with_kwargs or {}))
            return

        # Use recorded calls
        calls = self.get_calls(resource_name, method_name)
        if with_args is not None or with_kwargs is not None:
            calls = [c for c in calls if self._call_matches(c, with_args, with_kwargs)]

        if times is not None:
            if len(calls) != times:
                raise AssertionError(
                    f"{resource_name}.{method_name} called {len(calls)} times, "
                    f"expected {times}\n"
                    f"Calls: {[(c.args, c.kwargs) for c in calls]}"
                )
        elif not calls:
            raise AssertionError(f"{resource_name}.{method_name} was not called")

    def assert_not_called(self, resource_name: str, method_name: str):
        """Assert that ``resource_name.method_name`` was never called.

        Raises:
            AssertionError: if at least one call was observed.
        """
        mock = self._mocks.get(resource_name)

        if isinstance(mock, (Mock, AsyncMock)):
            method = getattr(mock, method_name, None)
            if method is not None:
                method.assert_not_called()
            return

        calls = self.get_calls(resource_name, method_name)
        if calls:
            raise AssertionError(
                f"{resource_name}.{method_name} was called {len(calls)} times"
            )

    def get_mock(self, name: str) -> Any:
        """Return the mock registered under ``name``, or None if absent."""
        return self._mocks.get(name)

    def reset_history(self, resource_name: str = None):
        """Clear recorded calls for one resource, or for all when no name is given.

        Clears the registry's history as well as the call log held by the
        mock object itself (``reset_mock()`` for ``unittest.mock`` objects,
        the record list for MockResource instances).
        """
        if resource_name:
            self._call_history[resource_name] = []
            self._clear_mock_calls(self._mocks.get(resource_name))
        else:
            for name in self._call_history:
                self._call_history[name] = []
            for mock in self._mocks.values():
                self._clear_mock_calls(mock)

    @staticmethod
    def _clear_mock_calls(mock: Any) -> None:
        """Clear the call log kept by a registered mock object itself."""
        if isinstance(mock, (Mock, AsyncMock)):
            mock.reset_mock()
        elif isinstance(mock, MockResource):
            mock._call_records.clear()

    def expect_call(
        self,
        resource_name: str,
        method_name: str,
        returns: Any = None,
        raises: Exception = None,
    ) -> "Expectation":
        """Record an expectation and configure the registered mock to honour it.

        The Expectation is stored under ``resource_name``. If a mock with
        that name is registered and has ``method_name``, the method's
        ``side_effect`` is set to ``raises`` when given; otherwise any
        previous ``side_effect`` is cleared and ``return_value`` is set to
        ``returns``.

        NOTE(review): these attributes only affect ``unittest.mock``
        methods; on a wrapped MockResource method (a plain function) they
        are inert.
        """
        expectation = Expectation(resource_name, method_name, returns, raises)
        self._expectations.setdefault(resource_name, []).append(expectation)

        mock = self._mocks.get(resource_name)
        if mock is not None and hasattr(mock, method_name):
            method = getattr(mock, method_name)
            if raises is not None:
                method.side_effect = raises
            else:
                # An earlier ``raises`` would otherwise mask the return value.
                method.side_effect = None
                method.return_value = returns

        return expectation
|
468
|
+
|
469
|
+
|
470
|
+
@dataclass
class Expectation:
    """Expectation for a call to ``resource_name.method_name``.

    ``returns``/``raises`` describe the configured outcome. ``with_args``
    (exact positional match) and ``with_kwargs`` (subset match) optionally
    narrow which calls :meth:`matches` accepts. ``times`` is stored but not
    consulted by :meth:`matches`.
    """

    resource_name: str
    method_name: str
    returns: Any = None
    raises: Optional[Exception] = None
    times: Optional[int] = None
    with_args: Optional[tuple] = None
    with_kwargs: Optional[dict] = None

    def matches(self, method_name: str, args: tuple, kwargs: dict) -> bool:
        """Return True if a call with these arguments satisfies the expectation.

        Every key in ``with_kwargs`` must be present in ``kwargs`` with an
        equal value; a missing key never matches, even when the expected
        value is None.
        """
        if method_name != self.method_name:
            return False

        if self.with_args is not None and args != self.with_args:
            return False

        if self.with_kwargs is None:
            return True

        return all(
            key in kwargs and kwargs[key] == value
            for key, value in self.with_kwargs.items()
        )
|