aiohomematic-test-support 2025.12.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aiohomematic-test-support might be problematic. Click here for more details.

@@ -0,0 +1,741 @@
1
+ """
2
+ Mock implementations for RPC clients with session playback.
3
+
4
+ This module provides mock RPC proxy implementations that replay pre-recorded
5
+ backend responses from session data files. This enables deterministic, fast
6
+ testing without live Homematic backend dependencies.
7
+
8
+ Key Components
9
+ --------------
10
+ - **SessionPlayer**: Loads and plays back recorded RPC session data from ZIP archives.
11
+ - **get_mock**: Creates mock instances of data points and devices with configurable
12
+ method/property exclusions.
13
+ - **get_xml_rpc_proxy**: Returns mock XML-RPC proxy with session playback.
14
+ - **get_client_session**: Returns mock aiohttp ClientSession for JSON-RPC tests.
15
+
16
+ Session Playback
17
+ ----------------
18
+ Session data is stored in ZIP archives containing JSON files with recorded
19
+ RPC method calls and responses. The SessionPlayer replays these responses
20
+ when tests invoke RPC methods:
21
+
22
+ player = await get_session_player(file_name="ccu_full.zip")
23
+ proxy = get_xml_rpc_proxy(player=player)
24
+
25
+ # Calls return pre-recorded responses
26
+ devices = await proxy.listDevices()
27
+
28
+ This approach provides:
29
+ - Fast test execution (no network I/O)
30
+ - Reproducible results (same responses every time)
31
+ - Offline testing (no backend required)
32
+
33
+ Public API of this module is defined by __all__.
34
+ """
35
+
36
+ from __future__ import annotations
37
+
38
+ import asyncio
39
+ from collections import defaultdict
40
+ from collections.abc import Callable
41
+ import contextlib
42
+ import json
43
+ import logging
44
+ import os
45
+ import sys
46
+ from typing import Any, cast
47
+ from unittest.mock import MagicMock, Mock
48
+ import zipfile
49
+
50
+ from aiohttp import ClientSession
51
+ import orjson
52
+
53
+ from aiohomematic.central import CentralUnit
54
+ from aiohomematic.client import BaseRpcProxy
55
+ from aiohomematic.client.json_rpc import _JsonKey, _JsonRpcMethod
56
+ from aiohomematic.client.rpc_proxy import _RpcMethod
57
+ from aiohomematic.const import UTF_8, DataOperationResult, Parameter, ParamsetKey, RPCType
58
+ from aiohomematic.store.persistent import _cleanup_params_for_session, _freeze_params, _unfreeze_params
59
+ from aiohomematic_test_support import const
60
+
61
+ _LOGGER = logging.getLogger(__name__)
62
+
63
+ # pylint: disable=protected-access
64
+
65
+
66
def _get_not_mockable_method_names(*, instance: Any, exclude_methods: set[str]) -> set[str]:
    """
    Return the attribute names that must not be replaced by plain mock attributes.

    The result is the union of all ``property`` names found on the instance's
    class and every name listed in ``exclude_methods`` that ``dir(instance)``
    actually reports.
    """
    property_names = _get_properties(data_object=instance, decorator=property)
    # Only keep exclusions that really exist on the instance.
    present_exclusions = {name for name in dir(instance) if name in exclude_methods}
    return set(property_names) | present_exclusions
74
+
75
+
76
+ def _get_properties(*, data_object: Any, decorator: Any) -> set[str]:
77
+ """Return the object attributes by decorator."""
78
+ cls = data_object.__class__
79
+
80
+ # Resolve function-based decorators to their underlying property class, if provided
81
+ resolved_decorator: Any = decorator
82
+ if not isinstance(decorator, type):
83
+ resolved_decorator = getattr(decorator, "__property_class__", decorator)
84
+
85
+ return {y for y in dir(cls) if isinstance(getattr(cls, y), resolved_decorator)}
86
+
87
+
88
def get_client_session(  # noqa: C901
    *,
    player: SessionPlayer,
    address_device_translation: set[str] | None = None,
    ignore_devices_on_create: list[str] | None = None,
) -> ClientSession:
    """
    Provide a ClientSession-like fixture that answers via SessionPlayer (JSON-RPC).

    Any POST request is answered by looking up the latest recorded JSON-RPC
    response in the session player using the request's method and params.
    Once a central has been attached via ``set_central``, a handful of methods
    are answered with canned data or turned into data point events instead.

    Args:
        player: Source of the recorded JSON-RPC responses.
        address_device_translation: If given, ``Interface.listDevices`` results
            are restricted to entries whose ``address`` or ``parent`` is in
            this set.
        ignore_devices_on_create: If given, ``Interface.listDevices`` entries
            whose ``address`` or ``parent`` is listed here are dropped.

    Returns:
        A mock object cast to ``ClientSession``.

    """
    filter_devices = ignore_devices_on_create is not None or address_device_translation is not None

    def _keep_device(dd: dict[str, Any]) -> bool:
        """Return whether a recorded device description passes the configured filters."""
        if ignore_devices_on_create is not None and (
            dd["address"] in ignore_devices_on_create or dd["parent"] in ignore_devices_on_create
        ):
            return False
        if address_device_translation is not None:
            return dd["address"] in address_device_translation or dd["parent"] in address_device_translation
        return True

    class _MockResponse:
        """Minimal response object exposing ``status``, ``json()`` and ``read()``."""

        def __init__(self, *, json_data: dict[str, Any] | None) -> None:
            # If no match is found, emulate backend error payload
            self._json: dict[str, Any] = json_data or {
                _JsonKey.RESULT: None,
                _JsonKey.ERROR: {"name": "-1", "code": -1, "message": "Not found in session player"},
                _JsonKey.ID: 0,
            }
            self.status = 200

        async def json(self, *, encoding: str | None = None) -> dict[str, Any]:  # mimic aiohttp API
            return self._json

        async def read(self) -> bytes:
            return orjson.dumps(self._json)

    def _success(*, result: Any) -> _MockResponse:
        """Build a successful JSON-RPC response carrying ``result``."""
        return _MockResponse(json_data={_JsonKey.ID: 0, _JsonKey.RESULT: result, _JsonKey.ERROR: None})

    class _MockClientSession:
        """Session mock that routes ``post`` calls to canned data or the session player."""

        def __init__(self) -> None:
            """Initialize the mock client session."""
            self._central: CentralUnit | None = None

        async def close(self) -> None:  # compatibility
            return None

        async def post(
            self,
            *,
            url: str,
            data: bytes | bytearray | str | None = None,
            headers: Any = None,
            timeout: Any = None,  # noqa: ASYNC109
            ssl: Any = None,
        ) -> _MockResponse:
            """Answer a JSON-RPC POST request without any network access."""
            # Payload is produced by AioJsonRpcAioHttpClient via orjson.dumps
            if isinstance(data, (bytes, bytearray)):
                payload = orjson.loads(data)
            elif isinstance(data, str):
                payload = orjson.loads(data.encode(UTF_8))
            else:
                payload = {}

            method = payload.get("method")
            params = payload.get("params")

            # Special handling is only active once a central has been attached.
            if self._central:
                if method in (
                    _JsonRpcMethod.PROGRAM_EXECUTE,
                    _JsonRpcMethod.SYSVAR_SET_BOOL,
                    _JsonRpcMethod.SYSVAR_SET_FLOAT,
                    _JsonRpcMethod.SESSION_LOGOUT,
                ):
                    return _success(result="200")
                if method == _JsonRpcMethod.SYSVAR_GET_ALL:
                    return _success(result=const.SYSVAR_DATA_JSON)
                if method == _JsonRpcMethod.PROGRAM_GET_ALL:
                    return _success(result=const.PROGRAM_DATA_JSON)
                if method == _JsonRpcMethod.REGA_RUN_SCRIPT:
                    script = params[_JsonKey.SCRIPT]
                    # First matching marker wins; unmatched scripts fall through to the player lookup.
                    for marker, canned in (
                        ("get_program_descriptions", const.PROGRAM_DATA_JSON_DESCRIPTION),
                        ("get_system_variable_descriptions", const.SYSVAR_DATA_JSON_DESCRIPTION),
                        ("get_backend_info", const.BACKEND_INFO_JSON),
                    ):
                        if marker in script:
                            return _success(result=canned)

                if method == _JsonRpcMethod.INTERFACE_SET_VALUE:
                    # Echo the written value back as an event.
                    await self._central.event_coordinator.data_point_event(
                        interface_id=params[_JsonKey.INTERFACE],
                        channel_address=params[_JsonKey.ADDRESS],
                        parameter=params[_JsonKey.VALUE_KEY],
                        value=params[_JsonKey.VALUE],
                    )
                    return _success(result="200")
                if method == _JsonRpcMethod.INTERFACE_PUT_PARAMSET:
                    # Only VALUES paramsets are turned into events.
                    if params[_JsonKey.PARAMSET_KEY] == ParamsetKey.VALUES:
                        interface_id = params[_JsonKey.INTERFACE]
                        channel_address = params[_JsonKey.ADDRESS]
                        for param, value in params[_JsonKey.SET].items():
                            await self._central.event_coordinator.data_point_event(
                                interface_id=interface_id,
                                channel_address=channel_address,
                                parameter=param,
                                value=value,
                            )
                    # Same payload shape (including the id) as every other success response.
                    return _success(result="200")

            json_data = player.get_latest_response_by_params(
                rpc_type=RPCType.JSON_RPC,
                method=str(method) if method is not None else "",
                params=params,
            )
            if (
                method == _JsonRpcMethod.INTERFACE_LIST_DEVICES
                and filter_devices
                # Nothing recorded: let _MockResponse produce its error payload.
                and json_data
                and (recorded_devices := json_data.get(_JsonKey.RESULT)) is not None
            ):
                # The player hands out its cached response object, so build a
                # filtered copy instead of assigning into it.
                json_data = {
                    **json_data,
                    _JsonKey.RESULT: [dd for dd in recorded_devices if _keep_device(dd)],
                }
            return _MockResponse(json_data=json_data)

        def set_central(self, *, central: CentralUnit) -> None:
            """Set the central."""
            self._central = central

    return cast(ClientSession, _MockClientSession())
239
+
240
+
241
def get_xml_rpc_proxy(  # noqa: C901
    *,
    player: SessionPlayer,
    address_device_translation: set[str] | None = None,
    ignore_devices_on_create: list[str] | None = None,
) -> BaseRpcProxy:
    """
    Provide a BaseRpcProxy-like fixture that answers via SessionPlayer (XML-RPC).

    Any method call like: await proxy.system.listMethods(...)
    will be answered by looking up the latest recorded XML-RPC response
    in the session player using the provided method and positional params.
    A few methods are implemented explicitly on the proxy and bypass that
    generic lookup.

    Args:
        player: Source of the recorded XML-RPC responses.
        address_device_translation: If given, ``listDevices`` results are
            restricted to entries whose ``ADDRESS`` or ``PARENT`` is in this set.
        ignore_devices_on_create: If given, ``listDevices`` entries whose
            ``ADDRESS`` or ``PARENT`` is listed here are dropped.

    Returns:
        A mock object cast to ``BaseRpcProxy``.

    """

    def _keep_device(dd: dict[str, Any]) -> bool:
        """Return whether a recorded device description passes the configured filters."""
        if ignore_devices_on_create is not None and (
            dd["ADDRESS"] in ignore_devices_on_create or dd["PARENT"] in ignore_devices_on_create
        ):
            return False
        if address_device_translation is not None:
            return dd["ADDRESS"] in address_device_translation or dd["PARENT"] in address_device_translation
        return True

    class _Method:
        """Callable that accumulates a dotted method name and forwards the call."""

        def __init__(self, full_name: str, caller: Any) -> None:
            self._name = full_name
            self._caller = caller

        async def __call__(self, *args: Any) -> Any:
            # Forward to caller with collected method name and positional params
            return await self._caller(self._name, *args)

        def __getattr__(self, sub: str) -> _Method:
            # Allow chaining like proxy.system.listMethods
            return _Method(f"{self._name}.{sub}", self._caller)

    class _AioXmlRpcProxyFromSession:
        """Proxy mock answering XML-RPC calls from the session player."""

        def __init__(self) -> None:
            self._player = player
            self._supported_methods: tuple[str, ...] = ()
            self._central: CentralUnit | None = None

        def __getattr__(self, name: str) -> Any:
            # Start of method chain
            return _Method(name, self._invoke)

        @property
        def supported_methods(self) -> tuple[str, ...]:
            """Return the supported methods."""
            return self._supported_methods

        def clear_connection_issue(self) -> None:
            """Clear connection issue (no-op for test mock)."""

        async def clientServerInitialized(self, interface_id: str) -> None:
            """Answer clientServerInitialized with pong."""
            await self.ping(callerId=interface_id)

        async def do_init(self) -> None:
            """Init the xml rpc proxy."""
            if recorded_methods := await self.system.listMethods():
                # ping is missing in VirtualDevices interface but can be used.
                # Build a new tuple: the recorded list is the player's cached
                # object and appending to it would grow it on every do_init.
                self._supported_methods = (*recorded_methods, _RpcMethod.PING)

        async def getAllSystemVariables(self) -> dict[str, Any]:
            """Return all system variables."""
            return const.SYSVAR_DATA_XML

        async def getParamset(self, channel_address: str, paramset: str) -> Any:
            """
            Return the recorded paramset for a channel.

            Returns ``{}`` when nothing usable is recorded, and ``None`` as
            long as no central has been attached.
            """
            if not self._central:
                return None
            result = self._player.get_latest_response_by_params(
                rpc_type=RPCType.XML_RPC,
                method="getParamset",
                params=(channel_address, paramset),
            )
            return result or {}

        async def listDevices(self) -> list[Any]:
            """Return the recorded device list, filtered if filters were configured."""
            devices = self._player.get_latest_response_by_params(
                rpc_type=RPCType.XML_RPC,
                method="listDevices",
                params="()",
            )

            if ignore_devices_on_create is None and address_device_translation is None:
                return cast(list[Any], devices)

            # A missing recording yields an empty result instead of a TypeError.
            return [dd for dd in devices or [] if _keep_device(dd)]

        async def ping(self, callerId: str) -> None:
            """Answer ping with pong."""
            if self._central:
                await self._central.event_coordinator.data_point_event(
                    interface_id=callerId,
                    channel_address="",
                    parameter=Parameter.PONG,
                    value=callerId,
                )

        async def putParamset(
            self, channel_address: str, paramset_key: str, values: Any, rx_mode: Any | None = None
        ) -> None:
            """Set a paramset."""
            # Only VALUES paramsets are turned into events, one per parameter.
            if self._central and paramset_key == ParamsetKey.VALUES:
                interface_id = self._central.client_coordinator.primary_client.interface_id  # type: ignore[union-attr]
                for param, value in values.items():
                    await self._central.event_coordinator.data_point_event(
                        interface_id=interface_id, channel_address=channel_address, parameter=param, value=value
                    )

        async def setValue(self, channel_address: str, parameter: str, value: Any, rx_mode: Any | None = None) -> None:
            """Set a value."""
            if self._central:
                await self._central.event_coordinator.data_point_event(
                    interface_id=self._central.client_coordinator.primary_client.interface_id,  # type: ignore[union-attr]
                    channel_address=channel_address,
                    parameter=parameter,
                    value=value,
                )

        def set_central(self, *, central: CentralUnit) -> None:
            """Set the central."""
            self._central = central

        async def stop(self) -> None:  # compatibility with AioXmlRpcProxy.stop
            return None

        async def _invoke(self, method: str, *args: Any) -> Any:
            """Look up the latest recorded response for ``method`` and its positional args."""
            return self._player.get_latest_response_by_params(
                rpc_type=RPCType.XML_RPC,
                method=method,
                params=tuple(args),
            )

    return cast(BaseRpcProxy, _AioXmlRpcProxyFromSession())
383
+
384
+
385
+ def _get_instance_attributes(instance: Any) -> set[str]:
386
+ """
387
+ Get all instance attribute names, supporting both __dict__ and __slots__.
388
+
389
+ For classes with __slots__, iterates through the class hierarchy to collect
390
+ all slot names. For classes with __dict__, returns the keys of __dict__.
391
+ Handles hybrid classes that have both __slots__ and __dict__.
392
+
393
+ Why this is needed:
394
+ Python classes can store instance attributes in two ways:
395
+ 1. __dict__: A dictionary attached to each instance (default behavior)
396
+ 2. __slots__: Pre-declared attribute names stored more efficiently
397
+
398
+ When copying attributes to a mock, we can't just use instance.__dict__
399
+ because __slots__-based classes don't have __dict__ (or have a limited one).
400
+ We must inspect the class hierarchy to find all declared slots.
401
+
402
+ Algorithm:
403
+ 1. Walk the Method Resolution Order (MRO) to find all classes in hierarchy
404
+ 2. For each class with __slots__, collect slot names (skip internal ones)
405
+ 3. Verify each slot actually has a value on this instance (getattr check)
406
+ 4. Also collect any __dict__ attributes if the instance has __dict__
407
+ 5. Return the union of all found attribute names
408
+ """
409
+ attrs: set[str] = set()
410
+
411
+ # Walk the class hierarchy via MRO (Method Resolution Order).
412
+ # __slots__ are inherited but each class defines its own slots separately,
413
+ # so we must check every class in the hierarchy.
414
+ for cls in type(instance).__mro__:
415
+ if hasattr(cls, "__slots__"):
416
+ for slot in cls.__slots__:
417
+ # Skip internal slots like __dict__ and __weakref__ which are
418
+ # automatically added by Python when a class uses __slots__
419
+ if not slot.startswith("__"):
420
+ try:
421
+ # Only include if the attribute actually exists on the instance.
422
+ # Slots can be declared but unset (raises AttributeError).
423
+ getattr(instance, slot)
424
+ attrs.add(slot)
425
+ except AttributeError:
426
+ # Slot is declared but not initialized on this instance
427
+ pass
428
+
429
+ # Also include __dict__ attributes if the instance has __dict__.
430
+ # Some classes have both __slots__ and __dict__ (e.g., if a parent class
431
+ # doesn't use __slots__, or if __dict__ is explicitly in __slots__).
432
+ if hasattr(instance, "__dict__"):
433
+ attrs.update(instance.__dict__.keys())
434
+
435
+ return attrs
436
+
437
+
438
def get_mock(
    *, instance: Any, exclude_methods: set[str] | None = None, include_properties: set[str] | None = None, **kwargs: Any
) -> Any:
    """
    Wrap ``instance`` in a MagicMock whose properties delegate to the instance.

    A dedicated ``MagicMock`` subclass is created per call and receives one
    ``property`` descriptor for every property of the instance's class. Each
    descriptor reads from (and, for writable properties, writes to) the
    wrapped instance on every access, so the mock always reflects the wrapped
    object's current state.

    Args:
        instance: Object to wrap. If it already is a ``Mock``, its attributes
            are refreshed from the object it wraps (if any) and it is returned
            unchanged otherwise.
        exclude_methods: Names that are copied from the instance onto the mock
            directly.
        include_properties: Property names that do NOT get a delegating
            descriptor.
        **kwargs: Forwarded to the mock constructor; names given here are
            neither delegated nor copied.

    Returns:
        The mock wrapping ``instance``.

    """
    excluded = set() if exclude_methods is None else exclude_methods
    mocked_properties = set() if include_properties is None else include_properties

    # Already a mock: only re-sync the attributes from the wrapped object.
    if isinstance(instance, Mock):
        wrapped = getattr(instance, "_mock_wraps", None)
        if wrapped is not None:
            for attr_name in _get_instance_attributes(wrapped):
                with contextlib.suppress(AttributeError, TypeError):
                    setattr(instance, attr_name, getattr(wrapped, attr_name))
        return instance

    property_names = _get_properties(data_object=instance, decorator=property)

    # Descriptors are attached to this throwaway subclass so that they take
    # precedence over the mock's own attribute handling.
    class _DynamicMock(MagicMock):
        pass

    def _delegating_property(name: str, *, writable: bool) -> property:
        """Build a property bound to ``name`` that forwards to the wrapped instance."""

        def getter(self: Any) -> Any:
            # _mock_wraps holds the original instance
            return getattr(self._mock_wraps, name)

        if not writable:
            return property(getter)

        def setter(self: Any, value: Any) -> None:
            setattr(self._mock_wraps, name, value)

        return property(getter, setter)

    instance_type = type(instance)
    for prop_name in property_names:
        # Properties that should be mocked or are overridden via kwargs stay untouched.
        if prop_name in mocked_properties or prop_name in kwargs:
            continue
        original_descriptor = getattr(instance_type, prop_name, None)
        has_setter = getattr(original_descriptor, "fset", None) is not None
        setattr(_DynamicMock, prop_name, _delegating_property(prop_name, writable=has_setter))

    # spec restricts the mock to the instance's attributes, wraps forwards calls to it.
    mock = _DynamicMock(spec=instance, wraps=instance, **kwargs)

    # Mirror the instance state (slots and __dict__) on the mock.
    for attr_name in _get_instance_attributes(instance):
        with contextlib.suppress(AttributeError, TypeError):
            setattr(mock, attr_name, getattr(instance, attr_name))

    # Copy the explicitly excluded members straight from the instance.
    for method_name in _get_not_mockable_method_names(instance=instance, exclude_methods=excluded):
        if method_name in mocked_properties or method_name in kwargs or method_name in property_names:
            continue
        try:
            setattr(mock, method_name, getattr(instance, method_name))
        except (AttributeError, TypeError) as exc:
            _LOGGER.debug("Could not copy method %s to mock: %s", method_name, exc)

    return mock
560
+
561
+
562
async def get_session_player(*, file_name: str) -> SessionPlayer:
    """
    Return a SessionPlayer for ``file_name``, loading the session files if needed.

    If the player does not yet know ``file_name``, every file listed in
    ``const.ALL_SESSION_FILES`` is loaded from the ``data`` directory next to
    this module. The individual load results are not evaluated.
    """
    player = SessionPlayer(file_id=file_name)
    if not player.supports_file_id(file_id=file_name):
        data_dir = os.path.join(os.path.dirname(__file__), "data")
        for session_file in const.ALL_SESSION_FILES:
            await player.load(file_path=os.path.join(data_dir, session_file), file_id=session_file)
    return player
572
+
573
+
574
class SessionPlayer:
    """
    Player for sessions.

    Recorded session data is kept in a class-level store shared by all
    instances. Each instance has a primary ``file_id``; lookups consult the
    primary file first and then every other loaded file.
    """

    # file_id -> rpc_type -> method -> frozen params -> timestamp -> response
    _store: dict[str, dict[str, dict[str, dict[str, dict[int, Any]]]]] = defaultdict(
        lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict))))
    )

    def __init__(self, *, file_id: str) -> None:
        """Initialize the session player."""
        self._file_id = file_id

    @classmethod
    def clear_all(cls) -> None:
        """Clear all cached session data from all file IDs."""
        cls._store.clear()

    @classmethod
    def clear_file(cls, *, file_id: str) -> None:
        """Clear cached session data for a specific file ID."""
        cls._store.pop(file_id, None)

    @classmethod
    def get_loaded_file_ids(cls) -> list[str]:
        """Return list of currently loaded file IDs."""
        return list(cls._store.keys())

    @classmethod
    def get_memory_usage(cls) -> int:
        """
        Return approximate memory usage of cached session data in bytes.

        ``sys.getsizeof`` is shallow, so the store is traversed and the sizes
        of all contained dicts, sequences and leaf objects are summed. Objects
        reachable more than once are counted once.
        """
        seen: set[int] = set()
        pending: list[Any] = [cls._store]
        total = 0
        while pending:
            obj = pending.pop()
            if id(obj) in seen:
                continue
            seen.add(id(obj))
            total += sys.getsizeof(obj)
            if isinstance(obj, dict):
                pending.extend(obj.keys())
                pending.extend(obj.values())
            elif isinstance(obj, (list, tuple, set, frozenset)):
                pending.extend(obj)
        return total

    @property
    def _secondary_file_ids(self) -> list[str]:
        """Return all loaded file IDs except this player's own file ID."""
        return [fid for fid in self._store if fid != self._file_id]

    def _get_bucket_by_parameter(self, *, file_id: str, rpc_type: str, method: str) -> dict[str, Any] | None:
        """
        Return the params -> timestamps bucket for (file_id, rpc_type, method).

        Uses ``.get`` on every level: the store is a defaultdict, and
        subscripting it with an unknown ``file_id`` would register an empty
        entry that ``supports_file_id`` then reports as loaded.
        """
        if not (bucket_by_rpc_type := self._store.get(file_id)):
            return None
        if not (bucket_by_method := bucket_by_rpc_type.get(rpc_type)):
            return None
        return bucket_by_method.get(method) or None

    def get_latest_response_by_method(self, *, rpc_type: str, method: str) -> list[tuple[Any, Any]]:
        """
        Return the latest responses for a given (rpc_type, method).

        The primary file is consulted first; if it has no entries, the first
        secondary file with entries wins.
        """
        if pri_result := self.get_latest_response_by_method_for_file_id(
            file_id=self._file_id,
            rpc_type=rpc_type,
            method=method,
        ):
            return pri_result

        for secondary_file_id in self._secondary_file_ids:
            if sec_result := self.get_latest_response_by_method_for_file_id(
                file_id=secondary_file_id,
                rpc_type=rpc_type,
                method=method,
            ):
                return sec_result
        return pri_result

    def get_latest_response_by_method_for_file_id(
        self, *, file_id: str, rpc_type: str, method: str
    ) -> list[tuple[Any, Any]]:
        """Return ``(params, response)`` pairs with the latest response per recorded params."""
        result: list[Any] = []
        bucket_by_parameter = self._get_bucket_by_parameter(file_id=file_id, rpc_type=rpc_type, method=method)
        if not bucket_by_parameter:
            return result
        # For each parameter, choose the response at the latest timestamp.
        for frozen_params, bucket_by_ts in bucket_by_parameter.items():
            if not bucket_by_ts:
                continue
            resp = bucket_by_ts[max(bucket_by_ts)]
            params = _unfreeze_params(frozen_params=frozen_params)
            result.append((params, resp))
        return result

    def get_latest_response_by_params(
        self,
        *,
        rpc_type: str,
        method: str,
        params: Any,
    ) -> Any:
        """
        Return the latest response for a given (rpc_type, method, params).

        The primary file is consulted first. A missing or falsy primary
        response falls through to the secondary files; if none of them yields
        a truthy response, the primary result is returned as is.
        """
        if pri_result := self.get_latest_response_by_params_for_file_id(
            file_id=self._file_id,
            rpc_type=rpc_type,
            method=method,
            params=params,
        ):
            return pri_result

        for secondary_file_id in self._secondary_file_ids:
            if sec_result := self.get_latest_response_by_params_for_file_id(
                file_id=secondary_file_id,
                rpc_type=rpc_type,
                method=method,
                params=params,
            ):
                return sec_result
        return pri_result

    def get_latest_response_by_params_for_file_id(
        self,
        *,
        file_id: str,
        rpc_type: str,
        method: str,
        params: Any,
    ) -> Any:
        """Return the latest response for (rpc_type, method, params) in one file, or ``None``."""
        bucket_by_parameter = self._get_bucket_by_parameter(file_id=file_id, rpc_type=rpc_type, method=method)
        if not bucket_by_parameter:
            return None
        frozen_params = _freeze_params(params=_cleanup_params_for_session(params=params))

        # Choose the response at the latest timestamp; an empty bucket has none.
        if not (bucket_by_ts := bucket_by_parameter.get(frozen_params)):
            return None
        return bucket_by_ts[max(bucket_by_ts)]

    async def load(self, *, file_path: str, file_id: str) -> DataOperationResult:
        """
        Load data from disk into the dictionary.

        Supports plain JSON files and ZIP archives containing a JSON file.
        When a ZIP archive is provided, the first JSON member inside the archive
        will be loaded. Reading and parsing run in the default executor.

        Returns ``NO_LOAD`` if ``file_id`` is already loaded or the file does
        not exist, ``LOAD_FAIL`` if the file cannot be read or parsed (or a ZIP
        archive has no JSON member), and ``LOAD_SUCCESS`` otherwise.
        """
        if self.supports_file_id(file_id=file_id):
            return DataOperationResult.NO_LOAD

        if not os.path.exists(file_path):
            return DataOperationResult.NO_LOAD

        def _perform_load() -> DataOperationResult:
            try:
                if zipfile.is_zipfile(file_path):
                    with zipfile.ZipFile(file_path, mode="r") as zf:
                        # Prefer json files; pick the first .json entry if available
                        if not (json_members := [n for n in zf.namelist() if n.lower().endswith(".json")]):
                            return DataOperationResult.LOAD_FAIL
                        raw = zf.read(json_members[0]).decode(UTF_8)
                        data = json.loads(raw)
                else:
                    with open(file=file_path, encoding=UTF_8) as file_pointer:
                        data = json.loads(file_pointer.read())

                self._store[file_id] = data
            except (json.JSONDecodeError, zipfile.BadZipFile, UnicodeDecodeError, OSError):
                return DataOperationResult.LOAD_FAIL
            return DataOperationResult.LOAD_SUCCESS

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _perform_load)

    def supports_file_id(self, *, file_id: str) -> bool:
        """Return whether the session player supports the given file_id."""
        return file_id in self._store