crawlee-1.0.3b6-py3-none-any.whl → crawlee-1.0.5b18-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. crawlee/_service_locator.py +4 -4
  2. crawlee/_utils/recoverable_state.py +32 -8
  3. crawlee/_utils/recurring_task.py +15 -0
  4. crawlee/_utils/robots.py +17 -5
  5. crawlee/_utils/sitemap.py +1 -1
  6. crawlee/_utils/urls.py +9 -2
  7. crawlee/browsers/_browser_pool.py +4 -1
  8. crawlee/browsers/_playwright_browser_controller.py +1 -1
  9. crawlee/browsers/_playwright_browser_plugin.py +17 -3
  10. crawlee/browsers/_types.py +1 -1
  11. crawlee/configuration.py +3 -1
  12. crawlee/crawlers/_abstract_http/_abstract_http_crawler.py +3 -1
  13. crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py +33 -13
  14. crawlee/crawlers/_basic/_basic_crawler.py +23 -12
  15. crawlee/crawlers/_playwright/_playwright_crawler.py +11 -4
  16. crawlee/fingerprint_suite/_header_generator.py +2 -2
  17. crawlee/otel/crawler_instrumentor.py +3 -3
  18. crawlee/request_loaders/_sitemap_request_loader.py +5 -0
  19. crawlee/sessions/_session_pool.py +1 -1
  20. crawlee/statistics/_error_snapshotter.py +1 -1
  21. crawlee/statistics/_statistics.py +41 -31
  22. crawlee/storage_clients/__init__.py +4 -0
  23. crawlee/storage_clients/_file_system/_request_queue_client.py +24 -6
  24. crawlee/storage_clients/_redis/__init__.py +6 -0
  25. crawlee/storage_clients/_redis/_client_mixin.py +295 -0
  26. crawlee/storage_clients/_redis/_dataset_client.py +325 -0
  27. crawlee/storage_clients/_redis/_key_value_store_client.py +264 -0
  28. crawlee/storage_clients/_redis/_request_queue_client.py +586 -0
  29. crawlee/storage_clients/_redis/_storage_client.py +146 -0
  30. crawlee/storage_clients/_redis/_utils.py +23 -0
  31. crawlee/storage_clients/_redis/lua_scripts/atomic_bloom_add_requests.lua +36 -0
  32. crawlee/storage_clients/_redis/lua_scripts/atomic_fetch_request.lua +49 -0
  33. crawlee/storage_clients/_redis/lua_scripts/atomic_set_add_requests.lua +37 -0
  34. crawlee/storage_clients/_redis/lua_scripts/reclaim_stale_requests.lua +34 -0
  35. crawlee/storage_clients/_redis/py.typed +0 -0
  36. crawlee/storage_clients/_sql/_db_models.py +1 -2
  37. crawlee/storage_clients/_sql/_storage_client.py +9 -0
  38. crawlee/storages/_key_value_store.py +5 -2
  39. {crawlee-1.0.3b6.dist-info → crawlee-1.0.5b18.dist-info}/METADATA +9 -5
  40. {crawlee-1.0.3b6.dist-info → crawlee-1.0.5b18.dist-info}/RECORD +43 -31
  41. {crawlee-1.0.3b6.dist-info → crawlee-1.0.5b18.dist-info}/WHEEL +0 -0
  42. {crawlee-1.0.3b6.dist-info → crawlee-1.0.5b18.dist-info}/entry_points.txt +0 -0
  43. {crawlee-1.0.3b6.dist-info → crawlee-1.0.5b18.dist-info}/licenses/LICENSE +0 -0
crawlee/statistics/_statistics.py

@@ -1,6 +1,7 @@
 # Inspiration: https://github.com/apify/crawlee/blob/v3.9.2/packages/core/src/crawlers/statistics.ts
 from __future__ import annotations
 
+import asyncio
 import math
 import time
 from datetime import datetime, timedelta, timezone
@@ -17,8 +18,11 @@ from crawlee.statistics import FinalStatistics, StatisticsState
 from crawlee.statistics._error_tracker import ErrorTracker
 
 if TYPE_CHECKING:
+    from collections.abc import Callable, Coroutine
     from types import TracebackType
 
+    from crawlee.storages import KeyValueStore
+
 TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState)
 TNewStatisticsState = TypeVar('TNewStatisticsState', bound=StatisticsState, default=StatisticsState)
 logger = getLogger(__name__)
@@ -70,6 +74,7 @@ class Statistics(Generic[TStatisticsState]):
         persistence_enabled: bool | Literal['explicit_only'] = False,
         persist_state_kvs_name: str | None = None,
         persist_state_key: str | None = None,
+        persist_state_kvs_factory: Callable[[], Coroutine[None, None, KeyValueStore]] | None = None,
         log_message: str = 'Statistics',
         periodic_message_logger: Logger | None = None,
         log_interval: timedelta = timedelta(minutes=1),
@@ -80,8 +85,6 @@ class Statistics(Generic[TStatisticsState]):
         self._id = Statistics.__next_id
         Statistics.__next_id += 1
 
-        self._instance_start: datetime | None = None
-
         self.error_tracker = ErrorTracker(
             save_error_snapshots=save_error_snapshots,
             snapshot_kvs_name=persist_state_kvs_name,
@@ -92,9 +95,10 @@ class Statistics(Generic[TStatisticsState]):
 
         self._state = RecoverableState(
             default_state=state_model(stats_id=self._id),
-            persist_state_key=persist_state_key or f'SDK_CRAWLER_STATISTICS_{self._id}',
+            persist_state_key=persist_state_key or f'__CRAWLER_STATISTICS_{self._id}',
             persistence_enabled=persistence_enabled,
             persist_state_kvs_name=persist_state_kvs_name,
+            persist_state_kvs_factory=persist_state_kvs_factory,
             logger=logger,
         )
 
@@ -106,12 +110,15 @@ class Statistics(Generic[TStatisticsState]):
         # Flag to indicate the context state.
         self._active = False
 
+        # Pre-existing runtime offset, that can be non-zero when restoring serialized state from KVS.
+        self._runtime_offset = timedelta(seconds=0)
+
     def replace_state_model(self, state_model: type[TNewStatisticsState]) -> Statistics[TNewStatisticsState]:
         """Create near copy of the `Statistics` with replaced `state_model`."""
         new_statistics: Statistics[TNewStatisticsState] = Statistics(
             persistence_enabled=self._state._persistence_enabled,  # noqa: SLF001
-            persist_state_kvs_name=self._state._persist_state_kvs_name,  # noqa: SLF001
             persist_state_key=self._state._persist_state_key,  # noqa: SLF001
+            persist_state_kvs_factory=self._state._persist_state_kvs_factory,  # noqa: SLF001
             log_message=self._log_message,
             periodic_message_logger=self._periodic_message_logger,
             state_model=state_model,
@@ -125,6 +132,7 @@ class Statistics(Generic[TStatisticsState]):
         persistence_enabled: bool = False,
         persist_state_kvs_name: str | None = None,
         persist_state_key: str | None = None,
+        persist_state_kvs_factory: Callable[[], Coroutine[None, None, KeyValueStore]] | None = None,
         log_message: str = 'Statistics',
         periodic_message_logger: Logger | None = None,
         log_interval: timedelta = timedelta(minutes=1),
@@ -136,6 +144,7 @@ class Statistics(Generic[TStatisticsState]):
             persistence_enabled=persistence_enabled,
             persist_state_kvs_name=persist_state_kvs_name,
             persist_state_key=persist_state_key,
+            persist_state_kvs_factory=persist_state_kvs_factory,
             log_message=log_message,
             periodic_message_logger=periodic_message_logger,
             log_interval=log_interval,
@@ -158,14 +167,17 @@ class Statistics(Generic[TStatisticsState]):
         if self._active:
             raise RuntimeError(f'The {self.__class__.__name__} is already active.')
 
-        self._active = True
-        self._instance_start = datetime.now(timezone.utc)
-
         await self._state.initialize()
-        self._after_initialize()
 
+        self._runtime_offset = self.state.crawler_runtime
+
+        # Start periodic logging and let it print initial state before activation.
         self._periodic_logger.start()
+        await asyncio.sleep(0.01)
+        self._active = True
 
+        self.state.crawler_last_started_at = datetime.now(timezone.utc)
+        self.state.crawler_started_at = self.state.crawler_started_at or self.state.crawler_last_started_at
         return self
 
     async def __aexit__(
@@ -182,13 +194,18 @@ class Statistics(Generic[TStatisticsState]):
         if not self._active:
             raise RuntimeError(f'The {self.__class__.__name__} is not active.')
 
-        self._state.current_value.crawler_finished_at = datetime.now(timezone.utc)
-
-        await self._state.teardown()
+        if not self.state.crawler_last_started_at:
+            raise RuntimeError('Statistics.state.crawler_last_started_at not set.')
 
+        # Stop logging and deactivate the statistics to prevent further changes to crawler_runtime
        await self._periodic_logger.stop()
+        self.state.crawler_finished_at = datetime.now(timezone.utc)
+        self.state.crawler_runtime = (
+            self._runtime_offset + self.state.crawler_finished_at - self.state.crawler_last_started_at
+        )
 
         self._active = False
+        await self._state.teardown()
 
     @property
     def state(self) -> TStatisticsState:
@@ -245,13 +262,21 @@ class Statistics(Generic[TStatisticsState]):
 
         del self._requests_in_progress[request_id_or_key]
 
+    def _update_crawler_runtime(self) -> None:
+        current_run_duration = (
+            (datetime.now(timezone.utc) - self.state.crawler_last_started_at)
+            if self.state.crawler_last_started_at
+            else timedelta()
+        )
+        self.state.crawler_runtime = current_run_duration + self._runtime_offset
+
     def calculate(self) -> FinalStatistics:
         """Calculate the current statistics."""
-        if self._instance_start is None:
-            raise RuntimeError('The Statistics object is not initialized')
+        if self._active:
+            # Only update state when active. If not, just report the last known runtime.
+            self._update_crawler_runtime()
 
-        crawler_runtime = datetime.now(timezone.utc) - self._instance_start
-        total_minutes = crawler_runtime.total_seconds() / 60
+        total_minutes = self.state.crawler_runtime.total_seconds() / 60
         state = self._state.current_value
         serialized_state = state.model_dump(by_alias=False)
 
@@ -262,7 +287,7 @@ class Statistics(Generic[TStatisticsState]):
             requests_failed_per_minute=math.floor(state.requests_failed / total_minutes) if total_minutes else 0,
             request_total_duration=state.request_total_finished_duration + state.request_total_failed_duration,
             requests_total=state.requests_failed + state.requests_finished,
-            crawler_runtime=crawler_runtime,
+            crawler_runtime=state.crawler_runtime,
             requests_finished=state.requests_finished,
             requests_failed=state.requests_failed,
             retry_histogram=serialized_state['request_retry_histogram'],
@@ -282,21 +307,6 @@ class Statistics(Generic[TStatisticsState]):
         else:
             self._periodic_message_logger.info(self._log_message, extra=stats.to_dict())
 
-    def _after_initialize(self) -> None:
-        state = self._state.current_value
-
-        if state.crawler_started_at is None:
-            state.crawler_started_at = datetime.now(timezone.utc)
-
-        if state.stats_persisted_at is not None and state.crawler_last_started_at:
-            self._instance_start = datetime.now(timezone.utc) - (
-                state.stats_persisted_at - state.crawler_last_started_at
-            )
-        elif state.crawler_last_started_at:
-            self._instance_start = state.crawler_last_started_at
-
-        state.crawler_last_started_at = self._instance_start
-
     def _save_retry_count_for_request(self, record: RequestProcessingRecord) -> None:
         retry_count = record.retry_count
         state = self._state.current_value
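
The hunks above replace the in-memory `_instance_start` anchor with a persisted `crawler_runtime` plus a `_runtime_offset`, so runtime accumulates correctly when a run is restored from a key-value store, and they add a `persist_state_kvs_factory` hook for choosing which `KeyValueStore` backs the persisted state. A minimal sketch of supplying the factory; the factory body is an assumption for illustration, and only the constructor parameters shown in the hunks above are relied on:

```python
import asyncio

from crawlee.statistics import Statistics
from crawlee.storages import KeyValueStore


async def kvs_factory() -> KeyValueStore:
    # Hypothetical: open whichever store should hold the persisted statistics state.
    return await KeyValueStore.open()


async def main() -> None:
    statistics = Statistics(
        persistence_enabled=True,
        persist_state_kvs_factory=kvs_factory,
    )
    async with statistics:
        ...  # crawl; crawler_runtime now accumulates across resumed runs


asyncio.run(main())
```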
crawlee/storage_clients/__init__.py

@@ -13,9 +13,13 @@ _install_import_hook(__name__)
 with _try_import(__name__, 'SqlStorageClient'):
     from ._sql import SqlStorageClient
 
+with _try_import(__name__, 'RedisStorageClient'):
+    from ._redis import RedisStorageClient
+
 __all__ = [
     'FileSystemStorageClient',
     'MemoryStorageClient',
+    'RedisStorageClient',
     'SqlStorageClient',
     'StorageClient',
 ]
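
With the guarded import above, `RedisStorageClient` is exported from `crawlee.storage_clients` alongside the SQL client, presumably failing lazily via `_try_import` when the optional Redis dependency is missing. A sketch of wiring it in as the default backend; the constructor arguments are an assumption, since this diff shows only the re-export and not the signature defined in `crawlee/storage_clients/_redis/_storage_client.py`:

```python
from redis.asyncio import Redis

from crawlee import service_locator
from crawlee.storage_clients import RedisStorageClient

# Hypothetical constructor arguments; the actual signature is not shown in this diff.
storage_client = RedisStorageClient(redis=Redis.from_url('redis://localhost:6379'))

# Make Redis the default backend for datasets, key-value stores, and request queues.
service_locator.set_storage_client(storage_client)
```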
crawlee/storage_clients/_file_system/_request_queue_client.py

@@ -31,6 +31,7 @@ if TYPE_CHECKING:
     from collections.abc import Sequence
 
     from crawlee.configuration import Configuration
+    from crawlee.storages import KeyValueStore
 
 logger = getLogger(__name__)
 
@@ -92,6 +93,7 @@ class FileSystemRequestQueueClient(RequestQueueClient):
         metadata: RequestQueueMetadata,
         path_to_rq: Path,
         lock: asyncio.Lock,
+        recoverable_state: RecoverableState[RequestQueueState],
     ) -> None:
         """Initialize a new instance.
 
@@ -114,12 +116,7 @@ class FileSystemRequestQueueClient(RequestQueueClient):
         self._is_empty_cache: bool | None = None
         """Cache for is_empty result: None means unknown, True/False is cached state."""
 
-        self._state = RecoverableState[RequestQueueState](
-            default_state=RequestQueueState(),
-            persist_state_key=f'__RQ_STATE_{self._metadata.id}',
-            persistence_enabled=True,
-            logger=logger,
-        )
+        self._state = recoverable_state
         """Recoverable state to maintain request ordering, in-progress status, and handled status."""
 
     @override
@@ -136,6 +133,22 @@ class FileSystemRequestQueueClient(RequestQueueClient):
         """The full path to the request queue metadata file."""
         return self.path_to_rq / METADATA_FILENAME
 
+    @classmethod
+    async def _create_recoverable_state(cls, id: str, configuration: Configuration) -> RecoverableState:
+        async def kvs_factory() -> KeyValueStore:
+            from crawlee.storage_clients import FileSystemStorageClient  # noqa: PLC0415 avoid circular import
+            from crawlee.storages import KeyValueStore  # noqa: PLC0415 avoid circular import
+
+            return await KeyValueStore.open(storage_client=FileSystemStorageClient(), configuration=configuration)
+
+        return RecoverableState[RequestQueueState](
+            default_state=RequestQueueState(),
+            persist_state_key=f'__RQ_STATE_{id}',
+            persist_state_kvs_factory=kvs_factory,
+            persistence_enabled=True,
+            logger=logger,
+        )
+
     @classmethod
     async def open(
         cls,
@@ -194,6 +207,9 @@ class FileSystemRequestQueueClient(RequestQueueClient):
             metadata=metadata,
             path_to_rq=rq_base_path / rq_dir,
             lock=asyncio.Lock(),
+            recoverable_state=await cls._create_recoverable_state(
+                id=id, configuration=configuration
+            ),
         )
         await client._state.initialize()
         await client._discover_existing_requests()
@@ -230,6 +246,7 @@ class FileSystemRequestQueueClient(RequestQueueClient):
             metadata=metadata,
             path_to_rq=path_to_rq,
             lock=asyncio.Lock(),
+            recoverable_state=await cls._create_recoverable_state(id=metadata.id, configuration=configuration),
         )
 
         await client._state.initialize()
@@ -254,6 +271,7 @@ class FileSystemRequestQueueClient(RequestQueueClient):
             metadata=metadata,
             path_to_rq=path_to_rq,
             lock=asyncio.Lock(),
+            recoverable_state=await cls._create_recoverable_state(id=metadata.id, configuration=configuration),
        )
         await client._state.initialize()
         await client._update_metadata()
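
The new `_create_recoverable_state` hook is the consumer side of the `persist_state_kvs_factory` parameter: the queue's internal state is pinned to a file-system-backed `KeyValueStore` instead of whatever default storage client is globally configured, and the deferred imports (per the inline `noqa` comments) avoid a circular import. A generic sketch of the same pattern for any `RecoverableState`; `MyState` is a hypothetical model, `crawlee._utils.recoverable_state` is the internal module listed in the file table above, and only the parameters visible in these hunks are assumed:

```python
from logging import getLogger

from pydantic import BaseModel

from crawlee._utils.recoverable_state import RecoverableState
from crawlee.storages import KeyValueStore


class MyState(BaseModel):
    # Hypothetical state persisted between runs.
    processed_count: int = 0


async def kvs_factory() -> KeyValueStore:
    # Decide at persistence time which store backs the state.
    return await KeyValueStore.open()


state = RecoverableState[MyState](
    default_state=MyState(),
    persist_state_key='__MY_STATE',
    persist_state_kvs_factory=kvs_factory,
    persistence_enabled=True,
    logger=getLogger(__name__),
)
```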
crawlee/storage_clients/_redis/__init__.py

@@ -0,0 +1,6 @@
+from ._dataset_client import RedisDatasetClient
+from ._key_value_store_client import RedisKeyValueStoreClient
+from ._request_queue_client import RedisRequestQueueClient
+from ._storage_client import RedisStorageClient
+
+__all__ = ['RedisDatasetClient', 'RedisKeyValueStoreClient', 'RedisRequestQueueClient', 'RedisStorageClient']
crawlee/storage_clients/_redis/_client_mixin.py

@@ -0,0 +1,295 @@
+from __future__ import annotations
+
+import asyncio
+from contextlib import asynccontextmanager
+from datetime import datetime, timezone
+from logging import getLogger
+from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, overload
+
+from crawlee._utils.crypto import crypto_random_object_id
+
+from ._utils import await_redis_response, read_lua_script
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+
+    from redis.asyncio import Redis
+    from redis.asyncio.client import Pipeline
+    from redis.commands.core import AsyncScript
+    from typing_extensions import NotRequired, Self
+
+    from crawlee.storage_clients.models import DatasetMetadata, KeyValueStoreMetadata, RequestQueueMetadata
+
+
+logger = getLogger(__name__)
+
+
+class MetadataUpdateParams(TypedDict, total=False):
+    """Parameters for updating metadata."""
+
+    update_accessed_at: NotRequired[bool]
+    update_modified_at: NotRequired[bool]
+
+
+class RedisClientMixin:
+    """Mixin class for Redis clients.
+
+    This mixin provides common Redis operations and basic methods for Redis storage clients.
+    """
+
+    _DEFAULT_NAME = 'default'
+    """Default storage name in key prefix when none provided."""
+
+    _MAIN_KEY: ClassVar[str]
+    """Main Redis key prefix for this storage type."""
+
+    _CLIENT_TYPE: ClassVar[str]
+    """Human-readable client type for error messages."""
+
+    def __init__(self, storage_name: str, storage_id: str, redis: Redis) -> None:
+        self._storage_name = storage_name
+        self._storage_id = storage_id
+        self._redis = redis
+
+        self._scripts_loaded = False
+
+    @property
+    def redis(self) -> Redis:
+        """Return the Redis client instance."""
+        return self._redis
+
+    @property
+    def metadata_key(self) -> str:
+        """Return the Redis key for the metadata of this storage."""
+        return f'{self._MAIN_KEY}:{self._storage_name}:metadata'
+
+    @classmethod
+    async def _get_metadata_by_name(cls, name: str, redis: Redis, *, with_wait: bool = False) -> dict | None:
+        """Retrieve metadata by storage name.
+
+        Args:
+            name: The name of the storage.
+            redis: The Redis client instance.
+            with_wait: Whether to wait for the storage to be created if it doesn't exist.
+        """
+        if with_wait:
+            # Wait for the creation signal (max 30 seconds)
+            await await_redis_response(redis.blpop([f'{cls._MAIN_KEY}:{name}:created_signal'], timeout=30))
+            # Signal consumed, push it back for other waiters
+            await await_redis_response(redis.lpush(f'{cls._MAIN_KEY}:{name}:created_signal', 1))
+
+        response = await await_redis_response(redis.json().get(f'{cls._MAIN_KEY}:{name}:metadata'))
+        data = response[0] if response is not None and isinstance(response, list) else response
+        if data is not None and not isinstance(data, dict):
+            raise TypeError('The metadata data was received in an incorrect format.')
+        return data
+
+    @classmethod
+    async def _get_metadata_name_by_id(cls, id: str, redis: Redis) -> str | None:
+        """Retrieve storage name by ID from id_to_name index.
+
+        Args:
+            id: The ID of the storage.
+            redis: The Redis client instance.
+        """
+        name = await await_redis_response(redis.hget(f'{cls._MAIN_KEY}:id_to_name', id))
+        if isinstance(name, str) or name is None:
+            return name
+        if isinstance(name, bytes):
+            return name.decode('utf-8')
+        return None
+
+    @classmethod
+    async def _open(
+        cls,
+        *,
+        id: str | None,
+        name: str | None,
+        alias: str | None,
+        metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata],
+        redis: Redis,
+        extra_metadata_fields: dict[str, Any],
+        instance_kwargs: dict[str, Any],
+    ) -> Self:
+        """Open or create a new Redis storage client.
+
+        Args:
+            id: The ID of the storage. If not provided, a random ID will be generated.
+            name: The name of the storage for named (global scope) storages.
+            alias: The alias of the storage for unnamed (run scope) storages.
+            redis: Redis client instance.
+            metadata_model: Pydantic model for metadata validation.
+            extra_metadata_fields: Storage-specific metadata fields.
+            instance_kwargs: Additional arguments for the client constructor.
+
+        Returns:
+            An instance for the opened or created storage client.
+        """
+        internal_name = name or alias or cls._DEFAULT_NAME
+        storage_id: str | None = None
+        # Determine if storage exists by ID or name
+        if id:
+            storage_name = await cls._get_metadata_name_by_id(id=id, redis=redis)
+            storage_id = id
+            if storage_name is None:
+                raise ValueError(f'{cls._CLIENT_TYPE} with ID "{id}" does not exist.')
+        else:
+            metadata_data = await cls._get_metadata_by_name(name=internal_name, redis=redis)
+            storage_name = internal_name if metadata_data is not None else None
+            storage_id = metadata_data['id'] if metadata_data is not None else None
+        # If both storage_name and storage_id are found, open existing storage
+        if storage_name and storage_id:
+            client = cls(storage_name=storage_name, storage_id=storage_id, redis=redis, **instance_kwargs)
+            async with client._get_pipeline() as pipe:
+                await client._update_metadata(pipe, update_accessed_at=True)
+        # Otherwise, create a new storage
+        else:
+            now = datetime.now(timezone.utc)
+            metadata = metadata_model(
+                id=crypto_random_object_id(),
+                name=name,
+                created_at=now,
+                accessed_at=now,
+                modified_at=now,
+                **extra_metadata_fields,
+            )
+            client = cls(storage_name=internal_name, storage_id=metadata.id, redis=redis, **instance_kwargs)
+            created = await client._create_metadata_and_storage(internal_name, metadata.model_dump())
+            # The client was probably not created due to a race condition. Let's try to open it using the name.
+            if not created:
+                metadata_data = await cls._get_metadata_by_name(name=internal_name, redis=redis, with_wait=True)
+                client = cls(storage_name=internal_name, storage_id=metadata.id, redis=redis, **instance_kwargs)
+
+        # Ensure Lua scripts are loaded
+        await client._ensure_scripts_loaded()
+        return client
+
+    async def _load_scripts(self) -> None:
+        """Load Lua scripts in Redis."""
+        return
+
+    async def _ensure_scripts_loaded(self) -> None:
+        """Ensure Lua scripts are loaded in Redis."""
+        if not self._scripts_loaded:
+            await self._load_scripts()
+            self._scripts_loaded = True
+
+    @asynccontextmanager
+    async def _get_pipeline(self, *, with_execute: bool = True) -> AsyncIterator[Pipeline]:
+        """Create a new Redis pipeline."""
+        async with self._redis.pipeline() as pipe:
+            try:
+                pipe.multi()  # type: ignore[no-untyped-call]
+                yield pipe
+            finally:
+                if with_execute:
+                    await pipe.execute()
+
+    async def _create_storage(self, pipeline: Pipeline) -> None:
+        """Create the actual storage structure in Redis."""
+        _ = pipeline  # To avoid unused variable mypy error
+
+    async def _create_script(self, script_name: str) -> AsyncScript:
+        """Load a Lua script from a file and return a Script object."""
+        script_content = await asyncio.to_thread(read_lua_script, script_name)
+
+        return self._redis.register_script(script_content)
+
+    async def _create_metadata_and_storage(self, storage_name: str, metadata: dict) -> bool:
+        index_id_to_name = f'{self._MAIN_KEY}:id_to_name'
+        index_name_to_id = f'{self._MAIN_KEY}:name_to_id'
+        metadata['created_at'] = metadata['created_at'].isoformat()
+        metadata['accessed_at'] = metadata['accessed_at'].isoformat()
+        metadata['modified_at'] = metadata['modified_at'].isoformat()
+
+        # Try to create name_to_id index entry, if it already exists, return False.
+        name_to_id = await await_redis_response(self._redis.hsetnx(index_name_to_id, storage_name, metadata['id']))
+        # If name already exists, return False. Probably an attempt at parallel creation.
+        if not name_to_id:
+            return False
+
+        # Create id_to_name index entry, metadata, and storage structure in a transaction.
+        async with self._get_pipeline() as pipe:
+            await await_redis_response(pipe.hsetnx(index_id_to_name, metadata['id'], storage_name))
+            await await_redis_response(pipe.json().set(self.metadata_key, '$', metadata))
+            await await_redis_response(pipe.lpush(f'{self._MAIN_KEY}:{storage_name}:created_signal', 1))
+
+            await self._create_storage(pipe)
+
+        return True
+
+    async def _drop(self, extra_keys: list[str]) -> None:
+        async with self._get_pipeline() as pipe:
+            await pipe.delete(self.metadata_key)
+            await pipe.delete(f'{self._MAIN_KEY}:id_to_name', self._storage_id)
+            await pipe.delete(f'{self._MAIN_KEY}:name_to_id', self._storage_name)
+            await pipe.delete(f'{self._MAIN_KEY}:{self._storage_name}:created_signal')
+            for key in extra_keys:
+                await pipe.delete(key)
+
+    async def _purge(self, extra_keys: list[str], metadata_kwargs: MetadataUpdateParams) -> None:
+        async with self._get_pipeline() as pipe:
+            for key in extra_keys:
+                await pipe.delete(key)
+            await self._update_metadata(pipe, **metadata_kwargs)
+            await self._create_storage(pipe)
+
+    @overload
+    async def _get_metadata(self, metadata_model: type[DatasetMetadata]) -> DatasetMetadata: ...
+    @overload
+    async def _get_metadata(self, metadata_model: type[KeyValueStoreMetadata]) -> KeyValueStoreMetadata: ...
+    @overload
+    async def _get_metadata(self, metadata_model: type[RequestQueueMetadata]) -> RequestQueueMetadata: ...
+
+    async def _get_metadata(
+        self, metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata]
+    ) -> DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata:
+        """Retrieve client metadata."""
+        metadata_dict = await self._get_metadata_by_name(name=self._storage_name, redis=self._redis)
+        if metadata_dict is None:
+            raise ValueError(f'{self._CLIENT_TYPE} with name "{self._storage_name}" does not exist.')
+        async with self._get_pipeline() as pipe:
+            await self._update_metadata(pipe, update_accessed_at=True)
+
+        return metadata_model.model_validate(metadata_dict)
+
+    async def _specific_update_metadata(self, pipeline: Pipeline, **kwargs: Any) -> None:
+        """Pipeline operations for storage-specific metadata updates.
+
+        Must be implemented by concrete classes.
+
+        Args:
+            pipeline: The Redis pipeline to use for the update.
+            **kwargs: Storage-specific update parameters.
+        """
+        _ = pipeline  # To avoid unused variable mypy error
+        _ = kwargs
+
+    async def _update_metadata(
+        self,
+        pipeline: Pipeline,
+        *,
+        update_accessed_at: bool = False,
+        update_modified_at: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        """Update storage metadata combining common and specific fields.
+
+        Args:
+            pipeline: The Redis pipeline to use for the update.
+            update_accessed_at: Whether to update accessed_at timestamp.
+            update_modified_at: Whether to update modified_at timestamp.
+            **kwargs: Additional arguments for _specific_update_metadata.
+        """
+        now = datetime.now(timezone.utc)
+
+        if update_accessed_at:
+            await await_redis_response(
+                pipeline.json().set(self.metadata_key, '$.accessed_at', now.isoformat(), nx=False, xx=True)
+            )
+        if update_modified_at:
+            await await_redis_response(
+                pipeline.json().set(self.metadata_key, '$.modified_at', now.isoformat(), nx=False, xx=True)
+            )
+
+        await self._specific_update_metadata(pipeline, **kwargs)
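
`RedisClientMixin` leaves four extension points for the concrete clients: the `_MAIN_KEY` and `_CLIENT_TYPE` class vars, `_create_storage`, `_load_scripts`, and `_specific_update_metadata`. A minimal sketch of a subclass wired through those hooks; `ExampleRedisClient` and its items list are hypothetical and not one of the clients shipped in this release, and the imports target the internal modules listed in the file table above:

```python
from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar

from crawlee.storage_clients._redis._client_mixin import RedisClientMixin
from crawlee.storage_clients._redis._utils import await_redis_response

if TYPE_CHECKING:
    from redis.asyncio.client import Pipeline


class ExampleRedisClient(RedisClientMixin):
    """Hypothetical list-backed storage built on the mixin's hooks."""

    _MAIN_KEY: ClassVar[str] = 'example'  # keys look like 'example:<name>:...'
    _CLIENT_TYPE: ClassVar[str] = 'Example storage'

    @property
    def items_key(self) -> str:
        return f'{self._MAIN_KEY}:{self._storage_name}:items'

    async def add_item(self, item: str) -> None:
        # Append the item and bump modified_at in a single MULTI/EXEC transaction.
        async with self._get_pipeline() as pipe:
            await await_redis_response(pipe.rpush(self.items_key, item))
            await self._update_metadata(pipe, update_modified_at=True)

    async def drop(self) -> None:
        # Delete the items list together with the metadata and index entries.
        await self._drop(extra_keys=[self.items_key])
```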