crawlee-1.0.3b6-py3-none-any.whl → crawlee-1.0.5b18-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package. It is provided for informational purposes only and reflects the versions as they appear in their public registries.
- crawlee/_service_locator.py +4 -4
- crawlee/_utils/recoverable_state.py +32 -8
- crawlee/_utils/recurring_task.py +15 -0
- crawlee/_utils/robots.py +17 -5
- crawlee/_utils/sitemap.py +1 -1
- crawlee/_utils/urls.py +9 -2
- crawlee/browsers/_browser_pool.py +4 -1
- crawlee/browsers/_playwright_browser_controller.py +1 -1
- crawlee/browsers/_playwright_browser_plugin.py +17 -3
- crawlee/browsers/_types.py +1 -1
- crawlee/configuration.py +3 -1
- crawlee/crawlers/_abstract_http/_abstract_http_crawler.py +3 -1
- crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py +33 -13
- crawlee/crawlers/_basic/_basic_crawler.py +23 -12
- crawlee/crawlers/_playwright/_playwright_crawler.py +11 -4
- crawlee/fingerprint_suite/_header_generator.py +2 -2
- crawlee/otel/crawler_instrumentor.py +3 -3
- crawlee/request_loaders/_sitemap_request_loader.py +5 -0
- crawlee/sessions/_session_pool.py +1 -1
- crawlee/statistics/_error_snapshotter.py +1 -1
- crawlee/statistics/_statistics.py +41 -31
- crawlee/storage_clients/__init__.py +4 -0
- crawlee/storage_clients/_file_system/_request_queue_client.py +24 -6
- crawlee/storage_clients/_redis/__init__.py +6 -0
- crawlee/storage_clients/_redis/_client_mixin.py +295 -0
- crawlee/storage_clients/_redis/_dataset_client.py +325 -0
- crawlee/storage_clients/_redis/_key_value_store_client.py +264 -0
- crawlee/storage_clients/_redis/_request_queue_client.py +586 -0
- crawlee/storage_clients/_redis/_storage_client.py +146 -0
- crawlee/storage_clients/_redis/_utils.py +23 -0
- crawlee/storage_clients/_redis/lua_scripts/atomic_bloom_add_requests.lua +36 -0
- crawlee/storage_clients/_redis/lua_scripts/atomic_fetch_request.lua +49 -0
- crawlee/storage_clients/_redis/lua_scripts/atomic_set_add_requests.lua +37 -0
- crawlee/storage_clients/_redis/lua_scripts/reclaim_stale_requests.lua +34 -0
- crawlee/storage_clients/_redis/py.typed +0 -0
- crawlee/storage_clients/_sql/_db_models.py +1 -2
- crawlee/storage_clients/_sql/_storage_client.py +9 -0
- crawlee/storages/_key_value_store.py +5 -2
- {crawlee-1.0.3b6.dist-info → crawlee-1.0.5b18.dist-info}/METADATA +9 -5
- {crawlee-1.0.3b6.dist-info → crawlee-1.0.5b18.dist-info}/RECORD +43 -31
- {crawlee-1.0.3b6.dist-info → crawlee-1.0.5b18.dist-info}/WHEEL +0 -0
- {crawlee-1.0.3b6.dist-info → crawlee-1.0.5b18.dist-info}/entry_points.txt +0 -0
- {crawlee-1.0.3b6.dist-info → crawlee-1.0.5b18.dist-info}/licenses/LICENSE +0 -0
--- a/crawlee/statistics/_statistics.py
+++ b/crawlee/statistics/_statistics.py
@@ -1,6 +1,7 @@
 # Inspiration: https://github.com/apify/crawlee/blob/v3.9.2/packages/core/src/crawlers/statistics.ts
 from __future__ import annotations
 
+import asyncio
 import math
 import time
 from datetime import datetime, timedelta, timezone
@@ -17,8 +18,11 @@ from crawlee.statistics import FinalStatistics, StatisticsState
 from crawlee.statistics._error_tracker import ErrorTracker
 
 if TYPE_CHECKING:
+    from collections.abc import Callable, Coroutine
     from types import TracebackType
 
+    from crawlee.storages import KeyValueStore
+
 TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState)
 TNewStatisticsState = TypeVar('TNewStatisticsState', bound=StatisticsState, default=StatisticsState)
 logger = getLogger(__name__)
@@ -70,6 +74,7 @@ class Statistics(Generic[TStatisticsState]):
         persistence_enabled: bool | Literal['explicit_only'] = False,
         persist_state_kvs_name: str | None = None,
         persist_state_key: str | None = None,
+        persist_state_kvs_factory: Callable[[], Coroutine[None, None, KeyValueStore]] | None = None,
         log_message: str = 'Statistics',
         periodic_message_logger: Logger | None = None,
         log_interval: timedelta = timedelta(minutes=1),
@@ -80,8 +85,6 @@ class Statistics(Generic[TStatisticsState]):
         self._id = Statistics.__next_id
         Statistics.__next_id += 1
 
-        self._instance_start: datetime | None = None
-
         self.error_tracker = ErrorTracker(
             save_error_snapshots=save_error_snapshots,
             snapshot_kvs_name=persist_state_kvs_name,
@@ -92,9 +95,10 @@ class Statistics(Generic[TStatisticsState]):
 
         self._state = RecoverableState(
             default_state=state_model(stats_id=self._id),
-            persist_state_key=persist_state_key or f'
+            persist_state_key=persist_state_key or f'__CRAWLER_STATISTICS_{self._id}',
             persistence_enabled=persistence_enabled,
             persist_state_kvs_name=persist_state_kvs_name,
+            persist_state_kvs_factory=persist_state_kvs_factory,
             logger=logger,
         )
 
@@ -106,12 +110,15 @@ class Statistics(Generic[TStatisticsState]):
         # Flag to indicate the context state.
         self._active = False
 
+        # Pre-existing runtime offset, that can be non-zero when restoring serialized state from KVS.
+        self._runtime_offset = timedelta(seconds=0)
+
     def replace_state_model(self, state_model: type[TNewStatisticsState]) -> Statistics[TNewStatisticsState]:
         """Create near copy of the `Statistics` with replaced `state_model`."""
         new_statistics: Statistics[TNewStatisticsState] = Statistics(
             persistence_enabled=self._state._persistence_enabled,  # noqa: SLF001
-            persist_state_kvs_name=self._state._persist_state_kvs_name,  # noqa: SLF001
             persist_state_key=self._state._persist_state_key,  # noqa: SLF001
+            persist_state_kvs_factory=self._state._persist_state_kvs_factory,  # noqa: SLF001
             log_message=self._log_message,
             periodic_message_logger=self._periodic_message_logger,
             state_model=state_model,
@@ -125,6 +132,7 @@ class Statistics(Generic[TStatisticsState]):
         persistence_enabled: bool = False,
         persist_state_kvs_name: str | None = None,
         persist_state_key: str | None = None,
+        persist_state_kvs_factory: Callable[[], Coroutine[None, None, KeyValueStore]] | None = None,
         log_message: str = 'Statistics',
         periodic_message_logger: Logger | None = None,
         log_interval: timedelta = timedelta(minutes=1),
@@ -136,6 +144,7 @@ class Statistics(Generic[TStatisticsState]):
             persistence_enabled=persistence_enabled,
             persist_state_kvs_name=persist_state_kvs_name,
             persist_state_key=persist_state_key,
+            persist_state_kvs_factory=persist_state_kvs_factory,
             log_message=log_message,
             periodic_message_logger=periodic_message_logger,
             log_interval=log_interval,
@@ -158,14 +167,17 @@ class Statistics(Generic[TStatisticsState]):
         if self._active:
             raise RuntimeError(f'The {self.__class__.__name__} is already active.')
 
-        self._active = True
-        self._instance_start = datetime.now(timezone.utc)
-
         await self._state.initialize()
-        self._after_initialize()
 
+        self._runtime_offset = self.state.crawler_runtime
+
+        # Start periodic logging and let it print initial state before activation.
         self._periodic_logger.start()
+        await asyncio.sleep(0.01)
+        self._active = True
 
+        self.state.crawler_last_started_at = datetime.now(timezone.utc)
+        self.state.crawler_started_at = self.state.crawler_started_at or self.state.crawler_last_started_at
         return self
 
     async def __aexit__(
@@ -182,13 +194,18 @@ class Statistics(Generic[TStatisticsState]):
         if not self._active:
             raise RuntimeError(f'The {self.__class__.__name__} is not active.')
 
-
-
-        await self._state.teardown()
+        if not self.state.crawler_last_started_at:
+            raise RuntimeError('Statistics.state.crawler_last_started_at not set.')
 
+        # Stop logging and deactivate the statistics to prevent further changes to crawler_runtime
         await self._periodic_logger.stop()
+        self.state.crawler_finished_at = datetime.now(timezone.utc)
+        self.state.crawler_runtime = (
+            self._runtime_offset + self.state.crawler_finished_at - self.state.crawler_last_started_at
+        )
 
         self._active = False
+        await self._state.teardown()
 
     @property
     def state(self) -> TStatisticsState:
@@ -245,13 +262,21 @@ class Statistics(Generic[TStatisticsState]):
 
         del self._requests_in_progress[request_id_or_key]
 
+    def _update_crawler_runtime(self) -> None:
+        current_run_duration = (
+            (datetime.now(timezone.utc) - self.state.crawler_last_started_at)
+            if self.state.crawler_last_started_at
+            else timedelta()
+        )
+        self.state.crawler_runtime = current_run_duration + self._runtime_offset
+
     def calculate(self) -> FinalStatistics:
         """Calculate the current statistics."""
-        if self.
-
+        if self._active:
+            # Only update state when active. If not, just report the last known runtime.
+            self._update_crawler_runtime()
 
-
-        total_minutes = crawler_runtime.total_seconds() / 60
+        total_minutes = self.state.crawler_runtime.total_seconds() / 60
         state = self._state.current_value
         serialized_state = state.model_dump(by_alias=False)
 
@@ -262,7 +287,7 @@ class Statistics(Generic[TStatisticsState]):
             requests_failed_per_minute=math.floor(state.requests_failed / total_minutes) if total_minutes else 0,
             request_total_duration=state.request_total_finished_duration + state.request_total_failed_duration,
             requests_total=state.requests_failed + state.requests_finished,
-            crawler_runtime=crawler_runtime,
+            crawler_runtime=state.crawler_runtime,
             requests_finished=state.requests_finished,
             requests_failed=state.requests_failed,
             retry_histogram=serialized_state['request_retry_histogram'],
@@ -282,21 +307,6 @@ class Statistics(Generic[TStatisticsState]):
         else:
             self._periodic_message_logger.info(self._log_message, extra=stats.to_dict())
 
-    def _after_initialize(self) -> None:
-        state = self._state.current_value
-
-        if state.crawler_started_at is None:
-            state.crawler_started_at = datetime.now(timezone.utc)
-
-        if state.stats_persisted_at is not None and state.crawler_last_started_at:
-            self._instance_start = datetime.now(timezone.utc) - (
-                state.stats_persisted_at - state.crawler_last_started_at
-            )
-        elif state.crawler_last_started_at:
-            self._instance_start = state.crawler_last_started_at
-
-        state.crawler_last_started_at = self._instance_start
-
     def _save_retry_count_for_request(self, record: RequestProcessingRecord) -> None:
         retry_count = record.retry_count
         state = self._state.current_value
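The refactor above replaces the in-memory `_instance_start` bookkeeping with persisted `crawler_started_at` / `crawler_last_started_at` / `crawler_runtime` fields plus a `_runtime_offset`, so runtime keeps accumulating when state is restored from a key-value store. Below is a minimal sketch of how the new `persist_state_kvs_factory` hook might be wired up; it assumes `Statistics.with_default_state` is the factory classmethod the hunks at lines 125-150 belong to, and the KVS name is made up:

```python
import asyncio

from crawlee.statistics import Statistics
from crawlee.storages import KeyValueStore


async def main() -> None:
    async def open_stats_kvs() -> KeyValueStore:
        # Hypothetical target store; the factory is awaited lazily by RecoverableState.
        return await KeyValueStore.open(name='crawler-stats')

    async with Statistics.with_default_state(
        persistence_enabled=True,
        persist_state_kvs_factory=open_stats_kvs,
    ) as statistics:
        ...  # run the crawl; crawler_runtime now spans restored runs via _runtime_offset

    print(statistics.calculate().to_dict())


asyncio.run(main())
```

Previously the persisted KVS could only be selected by name; the factory lets callers hand `RecoverableState` a fully configured store. The file-system request queue client further down uses exactly this to sidestep a circular import.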
--- a/crawlee/storage_clients/__init__.py
+++ b/crawlee/storage_clients/__init__.py
@@ -13,9 +13,13 @@ _install_import_hook(__name__)
 with _try_import(__name__, 'SqlStorageClient'):
     from ._sql import SqlStorageClient
 
+with _try_import(__name__, 'RedisStorageClient'):
+    from ._redis import RedisStorageClient
+
 __all__ = [
     'FileSystemStorageClient',
     'MemoryStorageClient',
+    'RedisStorageClient',
     'SqlStorageClient',
     'StorageClient',
 ]
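With the export in place, the Redis backend can be selected like the other storage clients. A hedged sketch follows: the `RedisStorageClient` constructor argument is an assumption (the actual signature lives in `crawlee/storage_clients/_redis/_storage_client.py`, whose body is not shown in this diff), and since the metadata handling below uses RedisJSON, a Redis Stack server is expected:

```python
from redis.asyncio import Redis

from crawlee import service_locator
from crawlee.storage_clients import RedisStorageClient  # importable only with the redis extra installed

# Assumed constructor: passing a pre-built redis.asyncio client.
storage_client = RedisStorageClient(redis=Redis.from_url('redis://localhost:6379'))
service_locator.set_storage_client(storage_client)
```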
--- a/crawlee/storage_clients/_file_system/_request_queue_client.py
+++ b/crawlee/storage_clients/_file_system/_request_queue_client.py
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
     from collections.abc import Sequence
 
     from crawlee.configuration import Configuration
+    from crawlee.storages import KeyValueStore
 
 logger = getLogger(__name__)
 
@@ -92,6 +93,7 @@ class FileSystemRequestQueueClient(RequestQueueClient):
         metadata: RequestQueueMetadata,
         path_to_rq: Path,
         lock: asyncio.Lock,
+        recoverable_state: RecoverableState[RequestQueueState],
     ) -> None:
         """Initialize a new instance.
 
@@ -114,12 +116,7 @@ class FileSystemRequestQueueClient(RequestQueueClient):
         self._is_empty_cache: bool | None = None
         """Cache for is_empty result: None means unknown, True/False is cached state."""
 
-        self._state =
-            default_state=RequestQueueState(),
-            persist_state_key=f'__RQ_STATE_{self._metadata.id}',
-            persistence_enabled=True,
-            logger=logger,
-        )
+        self._state = recoverable_state
         """Recoverable state to maintain request ordering, in-progress status, and handled status."""
 
     @override
@@ -136,6 +133,22 @@ class FileSystemRequestQueueClient(RequestQueueClient):
         """The full path to the request queue metadata file."""
         return self.path_to_rq / METADATA_FILENAME
 
+    @classmethod
+    async def _create_recoverable_state(cls, id: str, configuration: Configuration) -> RecoverableState:
+        async def kvs_factory() -> KeyValueStore:
+            from crawlee.storage_clients import FileSystemStorageClient  # noqa: PLC0415 avoid circular import
+            from crawlee.storages import KeyValueStore  # noqa: PLC0415 avoid circular import
+
+            return await KeyValueStore.open(storage_client=FileSystemStorageClient(), configuration=configuration)
+
+        return RecoverableState[RequestQueueState](
+            default_state=RequestQueueState(),
+            persist_state_key=f'__RQ_STATE_{id}',
+            persist_state_kvs_factory=kvs_factory,
+            persistence_enabled=True,
+            logger=logger,
+        )
+
     @classmethod
     async def open(
         cls,
@@ -194,6 +207,9 @@ class FileSystemRequestQueueClient(RequestQueueClient):
                 metadata=metadata,
                 path_to_rq=rq_base_path / rq_dir,
                 lock=asyncio.Lock(),
+                recoverable_state=await cls._create_recoverable_state(
+                    id=id, configuration=configuration
+                ),
             )
             await client._state.initialize()
             await client._discover_existing_requests()
@@ -230,6 +246,7 @@ class FileSystemRequestQueueClient(RequestQueueClient):
                 metadata=metadata,
                 path_to_rq=path_to_rq,
                 lock=asyncio.Lock(),
+                recoverable_state=await cls._create_recoverable_state(id=metadata.id, configuration=configuration),
             )
 
             await client._state.initialize()
@@ -254,6 +271,7 @@ class FileSystemRequestQueueClient(RequestQueueClient):
                 metadata=metadata,
                 path_to_rq=path_to_rq,
                 lock=asyncio.Lock(),
+                recoverable_state=await cls._create_recoverable_state(id=metadata.id, configuration=configuration),
            )
             await client._state.initialize()
             await client._update_metadata()
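Note the design here: instead of building its own `RecoverableState` in `__init__`, the file-system request queue now receives one from `_create_recoverable_state`, whose local `kvs_factory` defers the `KeyValueStore` and `FileSystemStorageClient` imports (per the `# noqa: PLC0415` comments, to avoid a circular import) and pins the state store to the file-system backend. Presumably this keeps queue state persisted on disk next to the queue even when a different default storage client is configured.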
--- /dev/null
+++ b/crawlee/storage_clients/_redis/__init__.py
@@ -0,0 +1,6 @@
+from ._dataset_client import RedisDatasetClient
+from ._key_value_store_client import RedisKeyValueStoreClient
+from ._request_queue_client import RedisRequestQueueClient
+from ._storage_client import RedisStorageClient
+
+__all__ = ['RedisDatasetClient', 'RedisKeyValueStoreClient', 'RedisRequestQueueClient', 'RedisStorageClient']
--- /dev/null
+++ b/crawlee/storage_clients/_redis/_client_mixin.py
@@ -0,0 +1,295 @@
+from __future__ import annotations
+
+import asyncio
+from contextlib import asynccontextmanager
+from datetime import datetime, timezone
+from logging import getLogger
+from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, overload
+
+from crawlee._utils.crypto import crypto_random_object_id
+
+from ._utils import await_redis_response, read_lua_script
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+
+    from redis.asyncio import Redis
+    from redis.asyncio.client import Pipeline
+    from redis.commands.core import AsyncScript
+    from typing_extensions import NotRequired, Self
+
+    from crawlee.storage_clients.models import DatasetMetadata, KeyValueStoreMetadata, RequestQueueMetadata
+
+
+logger = getLogger(__name__)
+
+
+class MetadataUpdateParams(TypedDict, total=False):
+    """Parameters for updating metadata."""
+
+    update_accessed_at: NotRequired[bool]
+    update_modified_at: NotRequired[bool]
+
+
+class RedisClientMixin:
+    """Mixin class for Redis clients.
+
+    This mixin provides common Redis operations and basic methods for Redis storage clients.
+    """
+
+    _DEFAULT_NAME = 'default'
+    """Default storage name in key prefix when none provided."""
+
+    _MAIN_KEY: ClassVar[str]
+    """Main Redis key prefix for this storage type."""
+
+    _CLIENT_TYPE: ClassVar[str]
+    """Human-readable client type for error messages."""
+
+    def __init__(self, storage_name: str, storage_id: str, redis: Redis) -> None:
+        self._storage_name = storage_name
+        self._storage_id = storage_id
+        self._redis = redis
+
+        self._scripts_loaded = False
+
+    @property
+    def redis(self) -> Redis:
+        """Return the Redis client instance."""
+        return self._redis
+
+    @property
+    def metadata_key(self) -> str:
+        """Return the Redis key for the metadata of this storage."""
+        return f'{self._MAIN_KEY}:{self._storage_name}:metadata'
+
+    @classmethod
+    async def _get_metadata_by_name(cls, name: str, redis: Redis, *, with_wait: bool = False) -> dict | None:
+        """Retrieve metadata by storage name.
+
+        Args:
+            name: The name of the storage.
+            redis: The Redis client instance.
+            with_wait: Whether to wait for the storage to be created if it doesn't exist.
+        """
+        if with_wait:
+            # Wait for the creation signal (max 30 seconds)
+            await await_redis_response(redis.blpop([f'{cls._MAIN_KEY}:{name}:created_signal'], timeout=30))
+            # Signal consumed, push it back for other waiters
+            await await_redis_response(redis.lpush(f'{cls._MAIN_KEY}:{name}:created_signal', 1))
+
+        response = await await_redis_response(redis.json().get(f'{cls._MAIN_KEY}:{name}:metadata'))
+        data = response[0] if response is not None and isinstance(response, list) else response
+        if data is not None and not isinstance(data, dict):
+            raise TypeError('The metadata data was received in an incorrect format.')
+        return data
+
+    @classmethod
+    async def _get_metadata_name_by_id(cls, id: str, redis: Redis) -> str | None:
+        """Retrieve storage name by ID from id_to_name index.
+
+        Args:
+            id: The ID of the storage.
+            redis: The Redis client instance.
+        """
+        name = await await_redis_response(redis.hget(f'{cls._MAIN_KEY}:id_to_name', id))
+        if isinstance(name, str) or name is None:
+            return name
+        if isinstance(name, bytes):
+            return name.decode('utf-8')
+        return None
+
+    @classmethod
+    async def _open(
+        cls,
+        *,
+        id: str | None,
+        name: str | None,
+        alias: str | None,
+        metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata],
+        redis: Redis,
+        extra_metadata_fields: dict[str, Any],
+        instance_kwargs: dict[str, Any],
+    ) -> Self:
+        """Open or create a new Redis storage client.
+
+        Args:
+            id: The ID of the storage. If not provided, a random ID will be generated.
+            name: The name of the storage for named (global scope) storages.
+            alias: The alias of the storage for unnamed (run scope) storages.
+            redis: Redis client instance.
+            metadata_model: Pydantic model for metadata validation.
+            extra_metadata_fields: Storage-specific metadata fields.
+            instance_kwargs: Additional arguments for the client constructor.
+
+        Returns:
+            An instance for the opened or created storage client.
+        """
+        internal_name = name or alias or cls._DEFAULT_NAME
+        storage_id: str | None = None
+        # Determine if storage exists by ID or name
+        if id:
+            storage_name = await cls._get_metadata_name_by_id(id=id, redis=redis)
+            storage_id = id
+            if storage_name is None:
+                raise ValueError(f'{cls._CLIENT_TYPE} with ID "{id}" does not exist.')
+        else:
+            metadata_data = await cls._get_metadata_by_name(name=internal_name, redis=redis)
+            storage_name = internal_name if metadata_data is not None else None
+            storage_id = metadata_data['id'] if metadata_data is not None else None
+        # If both storage_name and storage_id are found, open existing storage
+        if storage_name and storage_id:
+            client = cls(storage_name=storage_name, storage_id=storage_id, redis=redis, **instance_kwargs)
+            async with client._get_pipeline() as pipe:
+                await client._update_metadata(pipe, update_accessed_at=True)
+        # Otherwise, create a new storage
+        else:
+            now = datetime.now(timezone.utc)
+            metadata = metadata_model(
+                id=crypto_random_object_id(),
+                name=name,
+                created_at=now,
+                accessed_at=now,
+                modified_at=now,
+                **extra_metadata_fields,
+            )
+            client = cls(storage_name=internal_name, storage_id=metadata.id, redis=redis, **instance_kwargs)
+            created = await client._create_metadata_and_storage(internal_name, metadata.model_dump())
+            # The client was probably not created due to a race condition. Let's try to open it using the name.
+            if not created:
+                metadata_data = await cls._get_metadata_by_name(name=internal_name, redis=redis, with_wait=True)
+                client = cls(storage_name=internal_name, storage_id=metadata.id, redis=redis, **instance_kwargs)
+
+        # Ensure Lua scripts are loaded
+        await client._ensure_scripts_loaded()
+        return client
+
+    async def _load_scripts(self) -> None:
+        """Load Lua scripts in Redis."""
+        return
+
+    async def _ensure_scripts_loaded(self) -> None:
+        """Ensure Lua scripts are loaded in Redis."""
+        if not self._scripts_loaded:
+            await self._load_scripts()
+            self._scripts_loaded = True
+
+    @asynccontextmanager
+    async def _get_pipeline(self, *, with_execute: bool = True) -> AsyncIterator[Pipeline]:
+        """Create a new Redis pipeline."""
+        async with self._redis.pipeline() as pipe:
+            try:
+                pipe.multi()  # type: ignore[no-untyped-call]
+                yield pipe
+            finally:
+                if with_execute:
+                    await pipe.execute()
+
+    async def _create_storage(self, pipeline: Pipeline) -> None:
+        """Create the actual storage structure in Redis."""
+        _ = pipeline  # To avoid unused variable mypy error
+
+    async def _create_script(self, script_name: str) -> AsyncScript:
+        """Load a Lua script from a file and return a Script object."""
+        script_content = await asyncio.to_thread(read_lua_script, script_name)
+
+        return self._redis.register_script(script_content)
+
+    async def _create_metadata_and_storage(self, storage_name: str, metadata: dict) -> bool:
+        index_id_to_name = f'{self._MAIN_KEY}:id_to_name'
+        index_name_to_id = f'{self._MAIN_KEY}:name_to_id'
+        metadata['created_at'] = metadata['created_at'].isoformat()
+        metadata['accessed_at'] = metadata['accessed_at'].isoformat()
+        metadata['modified_at'] = metadata['modified_at'].isoformat()
+
+        # Try to create name_to_id index entry, if it already exists, return False.
+        name_to_id = await await_redis_response(self._redis.hsetnx(index_name_to_id, storage_name, metadata['id']))
+        # If name already exists, return False. Probably an attempt at parallel creation.
+        if not name_to_id:
+            return False
+
+        # Create id_to_name index entry, metadata, and storage structure in a transaction.
+        async with self._get_pipeline() as pipe:
+            await await_redis_response(pipe.hsetnx(index_id_to_name, metadata['id'], storage_name))
+            await await_redis_response(pipe.json().set(self.metadata_key, '$', metadata))
+            await await_redis_response(pipe.lpush(f'{self._MAIN_KEY}:{storage_name}:created_signal', 1))
+
+            await self._create_storage(pipe)
+
+        return True
+
+    async def _drop(self, extra_keys: list[str]) -> None:
+        async with self._get_pipeline() as pipe:
+            await pipe.delete(self.metadata_key)
+            await pipe.delete(f'{self._MAIN_KEY}:id_to_name', self._storage_id)
+            await pipe.delete(f'{self._MAIN_KEY}:name_to_id', self._storage_name)
+            await pipe.delete(f'{self._MAIN_KEY}:{self._storage_name}:created_signal')
+            for key in extra_keys:
+                await pipe.delete(key)
+
+    async def _purge(self, extra_keys: list[str], metadata_kwargs: MetadataUpdateParams) -> None:
+        async with self._get_pipeline() as pipe:
+            for key in extra_keys:
+                await pipe.delete(key)
+            await self._update_metadata(pipe, **metadata_kwargs)
+            await self._create_storage(pipe)
+
+    @overload
+    async def _get_metadata(self, metadata_model: type[DatasetMetadata]) -> DatasetMetadata: ...
+    @overload
+    async def _get_metadata(self, metadata_model: type[KeyValueStoreMetadata]) -> KeyValueStoreMetadata: ...
+    @overload
+    async def _get_metadata(self, metadata_model: type[RequestQueueMetadata]) -> RequestQueueMetadata: ...
+
+    async def _get_metadata(
+        self, metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata]
+    ) -> DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata:
+        """Retrieve client metadata."""
+        metadata_dict = await self._get_metadata_by_name(name=self._storage_name, redis=self._redis)
+        if metadata_dict is None:
+            raise ValueError(f'{self._CLIENT_TYPE} with name "{self._storage_name}" does not exist.')
+        async with self._get_pipeline() as pipe:
+            await self._update_metadata(pipe, update_accessed_at=True)
+
+        return metadata_model.model_validate(metadata_dict)
+
+    async def _specific_update_metadata(self, pipeline: Pipeline, **kwargs: Any) -> None:
+        """Pipeline operations storage-specific metadata updates.
+
+        Must be implemented by concrete classes.
+
+        Args:
+            pipeline: The Redis pipeline to use for the update.
+            **kwargs: Storage-specific update parameters.
+        """
+        _ = pipeline  # To avoid unused variable mypy error
+        _ = kwargs
+
+    async def _update_metadata(
+        self,
+        pipeline: Pipeline,
+        *,
+        update_accessed_at: bool = False,
+        update_modified_at: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        """Update storage metadata combining common and specific fields.
+
+        Args:
+            pipeline: The Redis pipeline to use for the update.
+            update_accessed_at: Whether to update accessed_at timestamp.
+            update_modified_at: Whether to update modified_at timestamp.
+            **kwargs: Additional arguments for _specific_update_metadata.
+        """
+        now = datetime.now(timezone.utc)
+
+        if update_accessed_at:
+            await await_redis_response(
+                pipeline.json().set(self.metadata_key, '$.accessed_at', now.isoformat(), nx=False, xx=True)
+            )
+        if update_modified_at:
+            await await_redis_response(
+                pipeline.json().set(self.metadata_key, '$.modified_at', now.isoformat(), nx=False, xx=True)
+            )
+
+        await self._specific_update_metadata(pipeline, **kwargs)