crawlee-1.0.0rc1-py3-none-any.whl → crawlee-1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- crawlee/_autoscaling/snapshotter.py +1 -1
- crawlee/_request.py +2 -1
- crawlee/_service_locator.py +44 -24
- crawlee/_types.py +76 -17
- crawlee/_utils/raise_if_too_many_kwargs.py +12 -0
- crawlee/_utils/sitemap.py +3 -1
- crawlee/_utils/system.py +3 -3
- crawlee/browsers/_playwright_browser_controller.py +20 -14
- crawlee/configuration.py +1 -1
- crawlee/crawlers/_abstract_http/_abstract_http_crawler.py +3 -1
- crawlee/crawlers/_abstract_http/_abstract_http_parser.py +1 -1
- crawlee/crawlers/_abstract_http/_http_crawling_context.py +1 -1
- crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py +6 -2
- crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler_statistics.py +1 -1
- crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawling_context.py +2 -1
- crawlee/crawlers/_adaptive_playwright/_rendering_type_predictor.py +1 -1
- crawlee/crawlers/_basic/_basic_crawler.py +107 -27
- crawlee/crawlers/_basic/_logging_utils.py +5 -1
- crawlee/crawlers/_playwright/_playwright_crawler.py +6 -1
- crawlee/events/_types.py +6 -6
- crawlee/fingerprint_suite/_fingerprint_generator.py +3 -0
- crawlee/fingerprint_suite/_types.py +2 -2
- crawlee/project_template/{{cookiecutter.project_name}}/pyproject.toml +2 -2
- crawlee/project_template/{{cookiecutter.project_name}}/requirements.txt +3 -0
- crawlee/request_loaders/_request_list.py +1 -1
- crawlee/request_loaders/_request_loader.py +5 -1
- crawlee/request_loaders/_sitemap_request_loader.py +228 -48
- crawlee/sessions/_models.py +2 -2
- crawlee/statistics/_models.py +1 -1
- crawlee/storage_clients/__init__.py +12 -0
- crawlee/storage_clients/_base/_storage_client.py +13 -0
- crawlee/storage_clients/_file_system/_dataset_client.py +27 -25
- crawlee/storage_clients/_file_system/_key_value_store_client.py +27 -23
- crawlee/storage_clients/_file_system/_request_queue_client.py +84 -98
- crawlee/storage_clients/_file_system/_storage_client.py +16 -3
- crawlee/storage_clients/_file_system/_utils.py +0 -0
- crawlee/storage_clients/_memory/_dataset_client.py +14 -2
- crawlee/storage_clients/_memory/_key_value_store_client.py +14 -2
- crawlee/storage_clients/_memory/_request_queue_client.py +43 -12
- crawlee/storage_clients/_memory/_storage_client.py +6 -3
- crawlee/storage_clients/_sql/__init__.py +6 -0
- crawlee/storage_clients/_sql/_client_mixin.py +385 -0
- crawlee/storage_clients/_sql/_dataset_client.py +310 -0
- crawlee/storage_clients/_sql/_db_models.py +269 -0
- crawlee/storage_clients/_sql/_key_value_store_client.py +299 -0
- crawlee/storage_clients/_sql/_request_queue_client.py +706 -0
- crawlee/storage_clients/_sql/_storage_client.py +282 -0
- crawlee/storage_clients/_sql/py.typed +0 -0
- crawlee/storage_clients/models.py +10 -10
- crawlee/storages/_base.py +3 -1
- crawlee/storages/_dataset.py +9 -2
- crawlee/storages/_key_value_store.py +9 -2
- crawlee/storages/_request_queue.py +7 -2
- crawlee/storages/_storage_instance_manager.py +126 -72
- {crawlee-1.0.0rc1.dist-info → crawlee-1.0.1.dist-info}/METADATA +12 -5
- {crawlee-1.0.0rc1.dist-info → crawlee-1.0.1.dist-info}/RECORD +59 -49
- {crawlee-1.0.0rc1.dist-info → crawlee-1.0.1.dist-info}/WHEEL +0 -0
- {crawlee-1.0.0rc1.dist-info → crawlee-1.0.1.dist-info}/entry_points.txt +0 -0
- {crawlee-1.0.0rc1.dist-info → crawlee-1.0.1.dist-info}/licenses/LICENSE +0 -0
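Two changes dominate this diff: a new SQL storage backend under crawlee/storage_clients/_sql/, and an alias parameter threaded through the storage clients to distinguish named (global scope) storages from unnamed (run scope) ones. Below is a hedged sketch of the alias semantics as the client-level diffs describe them, assuming the keyword is exposed on the public Dataset.open helper in the same way; treat it as illustrative rather than confirmed API.

import asyncio

from crawlee.storages import Dataset


async def main() -> None:
    # `alias` (assumed public here) opens an unnamed, run-scoped storage;
    # `name` opens a named, globally scoped one. Per the new validation
    # helper, passing more than one of `id`/`name`/`alias` raises ValueError.
    scratch = await Dataset.open(alias='scratch')
    results = await Dataset.open(name='results')

    await scratch.push_data({'seen': 'once'})
    await results.push_data({'kept': 'forever'})


asyncio.run(main())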
crawlee/storage_clients/_memory/_request_queue_client.py

@@ -10,6 +10,7 @@ from typing_extensions import override
 
 from crawlee import Request
 from crawlee._utils.crypto import crypto_random_object_id
+from crawlee._utils.raise_if_too_many_kwargs import raise_if_too_many_kwargs
 from crawlee.storage_clients._base import RequestQueueClient
 from crawlee.storage_clients.models import AddRequestsResponse, ProcessedRequest, RequestQueueMetadata
 
@@ -63,6 +64,7 @@ class MemoryRequestQueueClient(RequestQueueClient):
         *,
         id: str | None,
         name: str | None,
+        alias: str | None,
     ) -> MemoryRequestQueueClient:
         """Open or create a new memory request queue client.
 
@@ -70,14 +72,24 @@ class MemoryRequestQueueClient(RequestQueueClient):
         memory queues don't check for existing queues with the same name or ID since all data exists only
         in memory and is lost when the process terminates.
 
+        Alias does not have any effect on the memory storage client implementation, because unnamed storages
+        are supported by default, since data are not persisted.
+
         Args:
             id: The ID of the request queue. If not provided, a random ID will be generated.
-            name: The name of the request queue
+            name: The name of the request queue for named (global scope) storages.
+            alias: The alias of the request queue for unnamed (run scope) storages.
 
         Returns:
             An instance for the opened or created storage client.
+
+        Raises:
+            ValueError: If both name and alias are provided.
         """
-        #
+        # Validate input parameters.
+        raise_if_too_many_kwargs(id=id, name=name, alias=alias)
+
+        # Create a new queue
         queue_id = id or crypto_random_object_id()
         now = datetime.now(timezone.utc)
 
@@ -137,6 +149,7 @@ class MemoryRequestQueueClient(RequestQueueClient):
 
             was_already_present = existing_request is not None
             was_already_handled = was_already_present and existing_request and existing_request.handled_at is not None
+            is_in_progress = request.unique_key in self._in_progress_requests
 
             # If the request is already in the queue and handled, don't add it again.
             if was_already_handled:
@@ -149,21 +162,40 @@ class MemoryRequestQueueClient(RequestQueueClient):
                 )
                 continue
 
+            # If the request is already in progress, don't add it again.
+            if is_in_progress:
+                processed_requests.append(
+                    ProcessedRequest(
+                        unique_key=request.unique_key,
+                        was_already_present=True,
+                        was_already_handled=False,
+                    )
+                )
+                continue
+
             # If the request is already in the queue but not handled, update it.
             if was_already_present and existing_request:
-                # Update the existing request with any new data and
-                # remove old request from pending queue if it's there.
-                with suppress(ValueError):
-                    self._pending_requests.remove(existing_request)
-
                 # Update indexes.
                 self._requests_by_unique_key[request.unique_key] = request
 
-                #
+                # We only update `forefront` by updating its position by shifting it to the left.
                 if forefront:
+                    # Update the existing request with any new data and
+                    # remove old request from pending queue if it's there.
+                    with suppress(ValueError):
+                        self._pending_requests.remove(existing_request)
+
+                    # Add updated request back to queue.
                     self._pending_requests.appendleft(request)
-
-
+
+                processed_requests.append(
+                    ProcessedRequest(
+                        unique_key=request.unique_key,
+                        was_already_present=True,
+                        was_already_handled=False,
+                    )
+                )
+
             # Add the new request to the queue.
             else:
                 if forefront:
@@ -205,8 +237,7 @@ class MemoryRequestQueueClient(RequestQueueClient):
 
             # Skip if already in progress (shouldn't happen, but safety check).
             if request.unique_key in self._in_progress_requests:
-
-                break
+                continue
 
             # Mark as in progress.
             self._in_progress_requests[request.unique_key] = request
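Beyond the alias plumbing, the behavioral changes in this file are that add_batch_of_requests now reports a request that is currently in progress as already present instead of re-queueing it, and the fetch loop skips in-progress requests with continue instead of bailing out with break. A minimal sketch of the dedup behavior, using only calls visible in this diff; the private import path mirrors the file paths above and is not stable public API:

import asyncio

from crawlee import Request
from crawlee.storage_clients._memory import MemoryRequestQueueClient


async def main() -> None:
    client = await MemoryRequestQueueClient.open(id=None, name=None, alias=None)
    request = Request.from_url('https://example.com')

    await client.add_batch_of_requests([request])
    await client.fetch_next_request()  # Marks the request as in progress.

    # Re-adding while in progress: reported as already present, not re-queued.
    response = await client.add_batch_of_requests([request])
    print(response.processed_requests[0].was_already_present)  # True


asyncio.run(main())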
crawlee/storage_clients/_memory/_storage_client.py

@@ -33,10 +33,11 @@ class MemoryStorageClient(StorageClient):
         *,
         id: str | None = None,
         name: str | None = None,
+        alias: str | None = None,
         configuration: Configuration | None = None,
     ) -> MemoryDatasetClient:
         configuration = configuration or Configuration.get_global_configuration()
-        client = await MemoryDatasetClient.open(id=id, name=name)
+        client = await MemoryDatasetClient.open(id=id, name=name, alias=alias)
         await self._purge_if_needed(client, configuration)
         return client
 
@@ -46,10 +47,11 @@ class MemoryStorageClient(StorageClient):
         *,
         id: str | None = None,
         name: str | None = None,
+        alias: str | None = None,
         configuration: Configuration | None = None,
     ) -> MemoryKeyValueStoreClient:
         configuration = configuration or Configuration.get_global_configuration()
-        client = await MemoryKeyValueStoreClient.open(id=id, name=name)
+        client = await MemoryKeyValueStoreClient.open(id=id, name=name, alias=alias)
         await self._purge_if_needed(client, configuration)
         return client
 
@@ -59,9 +61,10 @@ class MemoryStorageClient(StorageClient):
         *,
         id: str | None = None,
         name: str | None = None,
+        alias: str | None = None,
         configuration: Configuration | None = None,
     ) -> MemoryRequestQueueClient:
         configuration = configuration or Configuration.get_global_configuration()
-        client = await MemoryRequestQueueClient.open(id=id, name=name)
+        client = await MemoryRequestQueueClient.open(id=id, name=name, alias=alias)
         await self._purge_if_needed(client, configuration)
         return client
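At the storage-client level the change is pure plumbing: each factory method gains an alias keyword and forwards it to the corresponding client open call. A short sketch; the method name create_dataset_client is an assumption inferred from the MemoryDatasetClient return annotation, since the method names are not visible in these hunks:

import asyncio

from crawlee.storage_clients import MemoryStorageClient


async def main() -> None:
    storage_client = MemoryStorageClient()
    # Hypothetical method name; forwards alias down to MemoryDatasetClient.open.
    dataset_client = await storage_client.create_dataset_client(alias='scratch')
    print((await dataset_client.get_metadata()).id)


asyncio.run(main())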
crawlee/storage_clients/_sql/__init__.py

@@ -0,0 +1,6 @@
+from ._dataset_client import SqlDatasetClient
+from ._key_value_store_client import SqlKeyValueStoreClient
+from ._request_queue_client import SqlRequestQueueClient
+from ._storage_client import SqlStorageClient
+
+__all__ = ['SqlDatasetClient', 'SqlKeyValueStoreClient', 'SqlRequestQueueClient', 'SqlStorageClient']
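The file list also shows crawlee/storage_clients/__init__.py growing by 12 lines, which suggests these four clients are re-exported from crawlee.storage_clients. A hypothetical wiring sketch; the connection_string argument and the service_locator registration are assumptions, not confirmed by this diff:

from crawlee import service_locator
from crawlee.storage_clients import SqlStorageClient

# Assumed constructor argument: an async SQLAlchemy connection string.
storage_client = SqlStorageClient(connection_string='sqlite+aiosqlite:///crawlee.db')

# Register globally so storages opened later use the SQL backend.
service_locator.set_storage_client(storage_client)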
crawlee/storage_clients/_sql/_client_mixin.py

@@ -0,0 +1,385 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from contextlib import asynccontextmanager
+from datetime import datetime, timezone
+from logging import getLogger
+from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast, overload
+
+from sqlalchemy import delete, select, text, update
+from sqlalchemy.dialects.postgresql import insert as pg_insert
+from sqlalchemy.dialects.sqlite import insert as lite_insert
+from sqlalchemy.exc import SQLAlchemyError
+
+from crawlee._utils.crypto import crypto_random_object_id
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+
+    from sqlalchemy import Insert
+    from sqlalchemy.ext.asyncio import AsyncSession
+    from sqlalchemy.orm import DeclarativeBase
+    from typing_extensions import NotRequired, Self
+
+    from crawlee.storage_clients.models import DatasetMetadata, KeyValueStoreMetadata, RequestQueueMetadata
+
+    from ._db_models import (
+        DatasetItemDb,
+        DatasetMetadataDb,
+        KeyValueStoreMetadataDb,
+        KeyValueStoreRecordDb,
+        RequestDb,
+        RequestQueueMetadataDb,
+    )
+    from ._storage_client import SqlStorageClient
+
+
+logger = getLogger(__name__)
+
+
+class MetadataUpdateParams(TypedDict, total=False):
+    """Parameters for updating metadata."""
+
+    update_accessed_at: NotRequired[bool]
+    update_modified_at: NotRequired[bool]
+    force: NotRequired[bool]
+
+
+class SqlClientMixin(ABC):
+    """Mixin class for SQL clients.
+
+    This mixin provides common SQL operations and basic methods for SQL storage clients.
+    """
+
+    _DEFAULT_NAME: ClassVar[str]
+    """Default name when none provided."""
+
+    _METADATA_TABLE: ClassVar[type[DatasetMetadataDb | KeyValueStoreMetadataDb | RequestQueueMetadataDb]]
+    """SQLAlchemy model for metadata."""
+
+    _ITEM_TABLE: ClassVar[type[DatasetItemDb | KeyValueStoreRecordDb | RequestDb]]
+    """SQLAlchemy model for items."""
+
+    _CLIENT_TYPE: ClassVar[str]
+    """Human-readable client type for error messages."""
+
+    def __init__(self, *, id: str, storage_client: SqlStorageClient) -> None:
+        self._id = id
+        self._storage_client = storage_client
+
+        # Time tracking to reduce database writes during frequent operation
+        self._accessed_at_allow_update_after: datetime | None = None
+        self._modified_at_allow_update_after: datetime | None = None
+        self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval()
+
+    @classmethod
+    async def _open(
+        cls,
+        *,
+        id: str | None,
+        name: str | None,
+        internal_name: str,
+        storage_client: SqlStorageClient,
+        metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata],
+        session: AsyncSession,
+        extra_metadata_fields: dict[str, Any],
+    ) -> Self:
+        """Open existing storage or create new one.
+
+        Internal method used by _safely_open.
+
+        Args:
+            id: Storage ID to open (takes precedence over name).
+            name: The name of the storage.
+            internal_name: The database name for the storage based on name or alias.
+            storage_client: SQL storage client instance.
+            metadata_model: Pydantic model for metadata validation.
+            session: Active database session.
+            extra_metadata_fields: Storage-specific metadata fields.
+        """
+        orm_metadata: DatasetMetadataDb | KeyValueStoreMetadataDb | RequestQueueMetadataDb | None = None
+        if id:
+            orm_metadata = await session.get(cls._METADATA_TABLE, id)
+            if not orm_metadata:
+                raise ValueError(f'{cls._CLIENT_TYPE} with ID "{id}" not found.')
+        else:
+            stmt = select(cls._METADATA_TABLE).where(cls._METADATA_TABLE.internal_name == internal_name)
+            result = await session.execute(stmt)
+            orm_metadata = result.scalar_one_or_none()  # type: ignore[assignment]
+
+        if orm_metadata:
+            client = cls(id=orm_metadata.id, storage_client=storage_client)
+            await client._update_metadata(session, update_accessed_at=True)
+        else:
+            now = datetime.now(timezone.utc)
+            metadata = metadata_model(
+                id=crypto_random_object_id(),
+                name=name,
+                created_at=now,
+                accessed_at=now,
+                modified_at=now,
+                **extra_metadata_fields,
+            )
+            client = cls(id=metadata.id, storage_client=storage_client)
+            client._accessed_at_allow_update_after = now + client._accessed_modified_update_interval
+            client._modified_at_allow_update_after = now + client._accessed_modified_update_interval
+            session.add(cls._METADATA_TABLE(**metadata.model_dump(), internal_name=internal_name))
+
+        return client
+
+    @classmethod
+    async def _safely_open(
+        cls,
+        *,
+        id: str | None,
+        name: str | None,
+        alias: str | None = None,
+        storage_client: SqlStorageClient,
+        metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata],
+        extra_metadata_fields: dict[str, Any],
+    ) -> Self:
+        """Safely open storage with transaction handling.
+
+        Args:
+            id: Storage ID to open (takes precedence over name).
+            name: The name of the storage for named (global scope) storages.
+            alias: The alias of the storage for unnamed (run scope) storages.
+            storage_client: SQL storage client instance.
+            client_class: Concrete client class to instantiate.
+            metadata_model: Pydantic model for metadata validation.
+            extra_metadata_fields: Storage-specific metadata fields.
+        """
+        # Validate input parameters.
+        specified_params = sum(1 for param in [id, name, alias] if param is not None)
+        if specified_params > 1:
+            raise ValueError('Only one of "id", "name", or "alias" can be specified, not multiple.')
+
+        internal_name = name or alias or cls._DEFAULT_NAME
+
+        async with storage_client.create_session() as session:
+            try:
+                client = await cls._open(
+                    id=id,
+                    name=name,
+                    internal_name=internal_name,
+                    storage_client=storage_client,
+                    metadata_model=metadata_model,
+                    session=session,
+                    extra_metadata_fields=extra_metadata_fields,
+                )
+                await session.commit()
+            except SQLAlchemyError:
+                await session.rollback()
+
+                stmt = select(cls._METADATA_TABLE).where(cls._METADATA_TABLE.internal_name == internal_name)
+                result = await session.execute(stmt)
+                orm_metadata: DatasetMetadataDb | KeyValueStoreMetadataDb | RequestQueueMetadataDb | None
+                orm_metadata = cast(
+                    'DatasetMetadataDb | KeyValueStoreMetadataDb | RequestQueueMetadataDb | None',
+                    result.scalar_one_or_none(),
+                )
+
+                if not orm_metadata:
+                    raise ValueError(f'{cls._CLIENT_TYPE} with Name "{internal_name}" not found.') from None
+
+                client = cls(id=orm_metadata.id, storage_client=storage_client)
+
+        return client
+
+    @asynccontextmanager
+    async def get_session(self, *, with_simple_commit: bool = False) -> AsyncIterator[AsyncSession]:
+        """Create a new SQLAlchemy session for this storage."""
+        async with self._storage_client.create_session() as session:
+            # For operations where a final commit is mandatory and does not require specific processing conditions
+            if with_simple_commit:
+                try:
+                    yield session
+                    await session.commit()
+                except SQLAlchemyError as e:
+                    logger.warning(f'Error occurred during session transaction: {e}')
+                    await session.rollback()
+            else:
+                yield session
+
+    def _build_insert_stmt_with_ignore(
+        self, table_model: type[DeclarativeBase], insert_values: dict[str, Any] | list[dict[str, Any]]
+    ) -> Insert:
+        """Build an insert statement with ignore for the SQL dialect.
+
+        Args:
+            table_model: SQLAlchemy table model.
+            insert_values: Single dict or list of dicts to insert.
+        """
+        if isinstance(insert_values, dict):
+            insert_values = [insert_values]
+
+        dialect = self._storage_client.get_dialect_name()
+
+        if dialect == 'postgresql':
+            return pg_insert(table_model).values(insert_values).on_conflict_do_nothing()
+
+        if dialect == 'sqlite':
+            return lite_insert(table_model).values(insert_values).on_conflict_do_nothing()
+
+        raise NotImplementedError(f'Insert with ignore not supported for dialect: {dialect}')
+
+    def _build_upsert_stmt(
+        self,
+        table_model: type[DeclarativeBase],
+        insert_values: dict[str, Any] | list[dict[str, Any]],
+        update_columns: list[str],
+        conflict_cols: list[str] | None = None,
+    ) -> Insert:
+        """Build an upsert statement for the SQL dialect.
+
+        Args:
+            table_model: SQLAlchemy table model.
+            insert_values: Single dict or list of dicts to upsert.
+            update_columns: Column names to update on conflict.
+            conflict_cols: Column names that define uniqueness (for PostgreSQL/SQLite).
+
+        """
+        if isinstance(insert_values, dict):
+            insert_values = [insert_values]
+
+        dialect = self._storage_client.get_dialect_name()
+
+        if dialect == 'postgresql':
+            pg_stmt = pg_insert(table_model).values(insert_values)
+            set_ = {col: getattr(pg_stmt.excluded, col) for col in update_columns}
+            return pg_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=set_)
+
+        if dialect == 'sqlite':
+            lite_stmt = lite_insert(table_model).values(insert_values)
+            set_ = {col: getattr(lite_stmt.excluded, col) for col in update_columns}
+            return lite_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=set_)
+
+        raise NotImplementedError(f'Upsert not supported for dialect: {dialect}')
+
+    async def _purge(self, metadata_kwargs: MetadataUpdateParams) -> None:
+        """Drop all items in storage and update metadata.
+
+        Args:
+            metadata_kwargs: Arguments to pass to _update_metadata.
+        """
+        stmt = delete(self._ITEM_TABLE).where(self._ITEM_TABLE.storage_id == self._id)
+        async with self.get_session(with_simple_commit=True) as session:
+            await session.execute(stmt)
+            await self._update_metadata(session, **metadata_kwargs)
+
+    async def _drop(self) -> None:
+        """Delete this storage and all its data.
+
+        This operation is irreversible. Uses CASCADE deletion to remove all related items.
+        """
+        stmt = delete(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id)
+        async with self.get_session(with_simple_commit=True) as session:
+            if self._storage_client.get_dialect_name() == 'sqlite':
+                # foreign_keys=ON is set at the connection level. Required for cascade deletion.
+                await session.execute(text('PRAGMA foreign_keys=ON'))
+            await session.execute(stmt)
+
+    @overload
+    async def _get_metadata(self, metadata_model: type[DatasetMetadata]) -> DatasetMetadata: ...
+    @overload
+    async def _get_metadata(self, metadata_model: type[KeyValueStoreMetadata]) -> KeyValueStoreMetadata: ...
+    @overload
+    async def _get_metadata(self, metadata_model: type[RequestQueueMetadata]) -> RequestQueueMetadata: ...
+
+    async def _get_metadata(
+        self, metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata]
+    ) -> DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata:
+        """Retrieve client metadata."""
+        async with self.get_session() as session:
+            orm_metadata = await session.get(self._METADATA_TABLE, self._id)
+            if not orm_metadata:
+                raise ValueError(f'{self._CLIENT_TYPE} with ID "{self._id}" not found.')
+
+            return metadata_model.model_validate(orm_metadata)
+
+    def _default_update_metadata(
+        self, *, update_accessed_at: bool = False, update_modified_at: bool = False, force: bool = False
+    ) -> dict[str, Any]:
+        """Prepare common metadata updates with rate limiting.
+
+        Args:
+            update_accessed_at: Whether to update accessed_at timestamp.
+            update_modified_at: Whether to update modified_at timestamp.
+            force: Whether to force the update regardless of rate limiting.
+        """
+        values_to_set: dict[str, Any] = {}
+        now = datetime.now(timezone.utc)
+
+        # If the record must be updated (for example, when updating counters), we update timestamps and shift the time.
+        if force:
+            if update_modified_at:
+                values_to_set['modified_at'] = now
+                self._modified_at_allow_update_after = now + self._accessed_modified_update_interval
+            if update_accessed_at:
+                values_to_set['accessed_at'] = now
+                self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval
+
+        elif update_modified_at and (
+            self._modified_at_allow_update_after is None or now >= self._modified_at_allow_update_after
+        ):
+            values_to_set['modified_at'] = now
+            self._modified_at_allow_update_after = now + self._accessed_modified_update_interval
+            # The record will be updated, we can update `accessed_at` and shift the time.
+            if update_accessed_at:
+                values_to_set['accessed_at'] = now
+                self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval
+
+        elif update_accessed_at and (
+            self._accessed_at_allow_update_after is None or now >= self._accessed_at_allow_update_after
+        ):
+            values_to_set['accessed_at'] = now
+            self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval
+
+        return values_to_set
+
+    @abstractmethod
+    def _specific_update_metadata(self, **kwargs: Any) -> dict[str, Any]:
+        """Prepare storage-specific metadata updates.
+
+        Must be implemented by concrete classes.
+
+        Args:
+            **kwargs: Storage-specific update parameters.
+        """
+
+    async def _update_metadata(
+        self,
+        session: AsyncSession,
+        *,
+        update_accessed_at: bool = False,
+        update_modified_at: bool = False,
+        force: bool = False,
+        **kwargs: Any,
+    ) -> bool:
+        """Update storage metadata combining common and specific fields.
+
+        Args:
+            session: Active database session.
+            update_accessed_at: Whether to update accessed_at timestamp.
+            update_modified_at: Whether to update modified_at timestamp.
+            force: Whether to force the update timestamps regardless of rate limiting.
+            **kwargs: Additional arguments for _specific_update_metadata.
+
+        Returns:
+            True if any updates were made, False otherwise
+        """
+        values_to_set = self._default_update_metadata(
+            update_accessed_at=update_accessed_at, update_modified_at=update_modified_at, force=force
+        )
+
+        values_to_set.update(self._specific_update_metadata(**kwargs))
+
+        if values_to_set:
+            if (stmt := values_to_set.pop('custom_stmt', None)) is None:
+                stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id)
+
+            stmt = stmt.values(**values_to_set)
+            await session.execute(stmt)
+            return True
+
+        return False
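_build_insert_stmt_with_ignore and _build_upsert_stmt branch on the dialect because SQLAlchemy keeps the ON CONFLICT constructs in dialect-specific insert modules rather than in core. The standalone sketch below compiles the same upsert through both dialects; the kv_record table is a hypothetical stand-in for the real _db_models tables:

from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as lite_insert

metadata = MetaData()
kv_record = Table(
    'kv_record',
    metadata,
    Column('key', String, primary_key=True),
    Column('value', String),
)

values = {'key': 'greeting', 'value': 'hello'}

# SQLite: INSERT ... ON CONFLICT (key) DO UPDATE SET value = excluded.value
lite_stmt = lite_insert(kv_record).values(values)
lite_stmt = lite_stmt.on_conflict_do_update(
    index_elements=['key'],
    set_={'value': lite_stmt.excluded.value},
)
print(lite_stmt.compile(dialect=sqlite.dialect()))

# PostgreSQL emits the same ON CONFLICT clause through its own insert construct.
pg_stmt = pg_insert(kv_record).values(values)
pg_stmt = pg_stmt.on_conflict_do_update(
    index_elements=['key'],
    set_={'value': pg_stmt.excluded.value},
)
print(pg_stmt.compile(dialect=postgresql.dialect()))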