datamarket 0.6.0__py3-none-any.whl → 0.10.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datamarket might be problematic. Click here for more details.
- datamarket/__init__.py +0 -1
- datamarket/exceptions/__init__.py +1 -0
- datamarket/exceptions/main.py +118 -0
- datamarket/interfaces/alchemy.py +1934 -25
- datamarket/interfaces/aws.py +81 -14
- datamarket/interfaces/azure.py +127 -0
- datamarket/interfaces/drive.py +60 -10
- datamarket/interfaces/ftp.py +37 -14
- datamarket/interfaces/llm.py +1220 -0
- datamarket/interfaces/nominatim.py +314 -42
- datamarket/interfaces/peerdb.py +272 -104
- datamarket/interfaces/proxy.py +354 -50
- datamarket/interfaces/tinybird.py +7 -15
- datamarket/params/nominatim.py +439 -0
- datamarket/utils/__init__.py +1 -1
- datamarket/utils/airflow.py +10 -7
- datamarket/utils/alchemy.py +2 -1
- datamarket/utils/logs.py +88 -0
- datamarket/utils/main.py +138 -10
- datamarket/utils/nominatim.py +201 -0
- datamarket/utils/playwright/__init__.py +0 -0
- datamarket/utils/playwright/async_api.py +274 -0
- datamarket/utils/playwright/sync_api.py +281 -0
- datamarket/utils/requests.py +655 -0
- datamarket/utils/selenium.py +6 -12
- datamarket/utils/strings/__init__.py +1 -0
- datamarket/utils/strings/normalization.py +217 -0
- datamarket/utils/strings/obfuscation.py +153 -0
- datamarket/utils/strings/standardization.py +40 -0
- datamarket/utils/typer.py +2 -1
- datamarket/utils/types.py +1 -0
- datamarket-0.10.3.dist-info/METADATA +172 -0
- datamarket-0.10.3.dist-info/RECORD +38 -0
- {datamarket-0.6.0.dist-info → datamarket-0.10.3.dist-info}/WHEEL +1 -2
- datamarket-0.6.0.dist-info/METADATA +0 -49
- datamarket-0.6.0.dist-info/RECORD +0 -24
- datamarket-0.6.0.dist-info/top_level.txt +0 -1
- {datamarket-0.6.0.dist-info → datamarket-0.10.3.dist-info/licenses}/LICENSE +0 -0
datamarket/interfaces/alchemy.py
CHANGED
|
@@ -1,41 +1,702 @@
|
|
|
1
1
|
########################################################################################################################
|
|
2
2
|
# IMPORTS
|
|
3
3
|
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
4
6
|
import logging
|
|
7
|
+
import time
|
|
8
|
+
from collections import defaultdict, deque
|
|
9
|
+
from collections.abc import MutableMapping
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from enum import Enum, auto
|
|
12
|
+
from typing import (
|
|
13
|
+
TYPE_CHECKING,
|
|
14
|
+
Any,
|
|
15
|
+
Callable,
|
|
16
|
+
Deque,
|
|
17
|
+
Dict,
|
|
18
|
+
Iterable,
|
|
19
|
+
Iterator,
|
|
20
|
+
List,
|
|
21
|
+
Optional,
|
|
22
|
+
Sequence,
|
|
23
|
+
Set,
|
|
24
|
+
Tuple,
|
|
25
|
+
Type,
|
|
26
|
+
TypeVar,
|
|
27
|
+
Union,
|
|
28
|
+
)
|
|
5
29
|
from urllib.parse import quote_plus
|
|
6
30
|
|
|
7
|
-
|
|
31
|
+
import numpy as np
|
|
32
|
+
from sqlalchemy import (
|
|
33
|
+
DDL,
|
|
34
|
+
FrozenResult,
|
|
35
|
+
Result,
|
|
36
|
+
Select,
|
|
37
|
+
SQLColumnExpression,
|
|
38
|
+
Table,
|
|
39
|
+
and_,
|
|
40
|
+
bindparam,
|
|
41
|
+
create_engine,
|
|
42
|
+
func,
|
|
43
|
+
inspect,
|
|
44
|
+
or_,
|
|
45
|
+
select,
|
|
46
|
+
text,
|
|
47
|
+
)
|
|
48
|
+
from sqlalchemy.dialects.postgresql import insert
|
|
8
49
|
from sqlalchemy.exc import IntegrityError
|
|
9
|
-
from sqlalchemy.
|
|
50
|
+
from sqlalchemy.ext.declarative import DeclarativeMeta
|
|
51
|
+
from sqlalchemy.orm import Query, Session, sessionmaker
|
|
52
|
+
from sqlalchemy.sql.elements import UnaryExpression
|
|
53
|
+
from sqlalchemy.sql.expression import ClauseElement
|
|
54
|
+
from sqlalchemy.sql.operators import desc_op
|
|
55
|
+
|
|
56
|
+
from datamarket.utils.logs import SystemColor, colorize
|
|
57
|
+
|
|
58
|
+
try:
|
|
59
|
+
import pandas as pd # type: ignore
|
|
60
|
+
except ImportError:
|
|
61
|
+
pd = None # type: ignore
|
|
62
|
+
|
|
63
|
+
if TYPE_CHECKING:
|
|
64
|
+
import pandas as pd # noqa: F401
|
|
65
|
+
|
|
10
66
|
|
|
11
67
|
########################################################################################################################
|
|
12
|
-
#
|
|
68
|
+
# TYPES / CONSTANTS
|
|
69
|
+
|
|
13
70
|
|
|
14
71
|
logger = logging.getLogger(__name__)

# Type variable for the SQLAlchemy declarative models handled by this module
# (bound to DeclarativeMeta so only declarative model types are accepted).
ModelType = TypeVar("ModelType", bound=DeclarativeMeta)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class OpAction(Enum):
    """Kind of write performed by a BatchOp."""

    INSERT = "insert"
    UPSERT = "upsert"


@dataclass(frozen=True)
class BatchOp:
    """
    Immutable representation of a single SQL operation.

    The dataclass is frozen, and ``upsert`` stores its own copy of ``index_elements``,
    so mutating the list the caller passed in cannot alter an already-built op.

    NOTE(review): the dataclass-generated ``__hash__`` hashes every field, so hashing an
    upsert op (whose ``index_elements`` is a list) raises TypeError.
    """

    obj: ModelType  # object to be written
    action: OpAction
    index_elements: Optional[List[str]] = None  # column names used as the upsert key; None for inserts

    @classmethod
    def insert(cls, obj: ModelType) -> "BatchOp":
        """Build a plain INSERT op for ``obj``."""
        return cls(obj=obj, action=OpAction.INSERT)

    @classmethod
    def upsert(cls, obj: ModelType, index_elements: List[str]) -> "BatchOp":
        """
        Build an UPSERT op for ``obj`` keyed on ``index_elements``.

        Raises:
            ValueError: if ``index_elements`` is empty or None.
        """
        if not index_elements:
            raise ValueError("Upsert requires 'index_elements'.")
        # Defensive copy: keeps the op independent of the caller's list and
        # normalizes tuples / other sequences to a list.
        return cls(obj=obj, action=OpAction.UPSERT, index_elements=list(index_elements))


@dataclass
class AtomicUnit:
    """
    A mutable container for BatchOps that must succeed or fail together.

    The ``add*`` methods return ``self`` so calls can be chained.
    """

    ops: List[BatchOp] = field(default_factory=list)

    def add(self, op: Union[BatchOp, List[BatchOp], Tuple[BatchOp, ...]]) -> "AtomicUnit":
        """
        Add an existing BatchOp, or every BatchOp of a list/tuple, to the unit.
        """
        # Tuples are expanded like lists; otherwise a tuple would be stored as one op.
        if isinstance(op, (list, tuple)):
            self.ops.extend(op)
        else:
            self.ops.append(op)
        return self

    def add_insert(self, obj: ModelType) -> "AtomicUnit":
        """Helper to create and add an Insert op."""
        return self.add(BatchOp.insert(obj))

    def add_upsert(self, obj: ModelType, index_elements: List[str]) -> "AtomicUnit":
        """Helper to create and add an Upsert op."""
        return self.add(BatchOp.upsert(obj, index_elements))

    def __len__(self) -> int:
        """Number of ops in the unit."""
        return len(self.ops)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@dataclass
class _BatchStats:
    """Counters for one batch of pipeline work, or an accumulation of several batches."""

    size: int = 0  # every operation in the batch
    inserts: int = 0  # plain inserts (OpAction.INSERT)

    # Upserts: overall intent first, then how each one resolved.
    upserts: int = 0  # operations issued as OpAction.UPSERT
    upsert_inserts: int = 0  # upserts that created a row
    upsert_updates: int = 0  # upserts that modified a row

    duplicates: int = 0  # skipped ops (rowcount == 0)
    failures: int = 0  # atomic units that failed/were skipped because an exception was thrown
    db_time: float = 0.0  # time spent in DB flush

    def add(self, other: "_BatchStats") -> None:
        """Add each counter of ``other`` into this instance, in place."""
        counter_names = (
            "size",
            "inserts",
            "upserts",
            "upsert_inserts",
            "upsert_updates",
            "duplicates",
            "failures",
            "db_time",
        )
        for name in counter_names:
            setattr(self, name, getattr(self, name) + getattr(other, name))
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
# Signature of a pipeline processor: it receives one object produced by a
# RowIterator and returns a single AtomicUnit, a sequence of AtomicUnits, or None.
ProcessorFunc = Callable[
    [ModelType],
    Union[AtomicUnit, Sequence[AtomicUnit], None],
]
|
|
163
|
+
|
|
164
|
+
# --- Batch sizing heuristics --------------------------------------------------
# Batches smaller than this are probably inefficient and are treated as suspicious.
MIN_BATCH_WARNING = 100
# Batches larger than this are considered too big and trigger a warning.
MAX_BATCH_WARNING = 5000

# Safety cap: maximum number of operations a single input atomic item may produce.
MAX_OPS_PER_ATOM = 50

# NOTE(review): consumed by the batch pipeline (not visible here); presumably the
# maximum tolerated run of failures — confirm against its usage.
MAX_FAILURE_BURST = 50

# --- Long-lived session detector ----------------------------------------------
# Age (seconds) after which an open session triggers the first misuse warning.
SESSION_WARN_THRESHOLD_SECONDS = 300.0
# Minimum gap (seconds) between repeated warnings for the same session.
SESSION_WARN_REPEAT_SECONDS = 60.0
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class CommitStrategy(Enum):
    """
    Commit behavior selector for windowed_query (legacy API).

    The exact meaning of each member is defined by windowed_query itself.
    """

    COMMIT_ON_SUCCESS = 1
    FORCE_COMMIT = 2
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class MockContext:
    """
    Lightweight mock context passed to SQLAlchemy default/onupdate callables
    that expect a context-like object.

    It carries an empty parameter dict, the column whose default is being computed,
    and no connection. Defaults that need the real row values therefore only see
    an empty parameter set.
    """

    def __init__(self, column: SQLColumnExpression) -> None:
        """
        Args:
            column: The column whose default/onupdate callable will receive this context.
        """
        self.current_parameters: Dict[str, Any] = {}
        self.current_column = column
        self.connection = None

    def get_current_parameters(self, isolate_multiinsert_groups: bool = True) -> Dict[str, Any]:
        """
        Return the (empty) parameter dictionary.

        Mirrors ``DefaultExecutionContext.get_current_parameters()``, the accessor
        SQLAlchemy documents for context-sensitive default functions. Without it,
        such callables fail with AttributeError when handed this mock.
        The argument is accepted only for signature compatibility.
        """
        return self.current_parameters
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
########################################################################################################################
|
|
200
|
+
# EXCEPTIONS
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class MissingOptionalDependency(ImportError):
    """Raised when a feature needs an optional extra (e.g. pandas) that is not installed."""
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def _require_pandas() -> Any:
    """
    Return the pandas module, or fail loudly when the optional dependency is absent.

    Raises:
        MissingOptionalDependency: when the module-level ``pd`` is None, i.e. the
            guarded ``import pandas`` at import time failed.
    """
    if pd is not None:
        return pd
    raise MissingOptionalDependency(
        "This feature requires pandas. Add the extra dependency. Example: pip install datamarket[pandas]"
    )
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
########################################################################################################################
|
|
216
|
+
# ITERATORS
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
# TODO: enable iterating over joins of tables (multiple entities)
|
|
220
|
+
class RowIterator(Iterator[ModelType]):
    """
    Stateful iterator that performs Composite Keyset Pagination.

    It dynamically builds complex SQL ``WHERE`` predicates (nested OR/AND logic)
    to support efficient pagination across multiple columns with mixed sort directions
    (e.g., [Date DESC, Category ASC, ID ASC]). Primary-key columns are appended to the
    sort key as tie-breakers unless the caller already sorts by them.

    Includes Auto-Rerun logic: if new rows appear "in the past" (behind the cursor)
    during iteration, it detects them at the end and performs a targeted rerun using
    a "Chained Snapshot" strategy. The max value of the first PK column is captured
    up front as a ceiling. Once that window is exhausted, the range
    ``(old ceiling, new max]`` is paginated from the start.

    Manages the session lifecycle internally (Open -> Fetch -> Expunge -> Close), so
    yielded objects are expunged from the session that loaded them.

    NOTE(review): keyset predicates use plain ``<`` / ``>`` / ``=`` comparisons. NULLs in a
    sort column never satisfy them, so nullable sort columns can cause skipped rows.
    The snapshot window also assumes new rows get a larger first-PK value — confirm
    for non-serial keys.
    """

    def __init__(
        self,
        interface: "_BatchPipelineOps",
        source: Union[Select[Any], Query],
        chunk_size: int,
        order_by: Optional[List[SQLColumnExpression[Any]]] = None,
        limit: Optional[int] = None,
    ):
        """
        Args:
            interface: Alchemy interface providing ``start()``/``stop()``, the context
                manager protocol and a ``session`` attribute.
            source: Select or legacy Query describing the rows to iterate. Its first
                column description must be a mapped entity with a primary key.
            chunk_size: Maximum number of rows fetched per round trip.
            order_by: Optional user sort key (plain columns, or columns wrapped in
                ``.desc()`` / ``.nulls_last()`` etc.). Missing PK columns are appended.
            limit: Optional cap on the total number of yielded rows.

        Raises:
            ValueError: if ``source`` is not a Select/Query or the entity has no PK.
        """
        self.interface = interface
        self.chunk_size = chunk_size
        self.limit = limit
        self.yielded_count = 0

        if not isinstance(source, (Select, Query)):
            raise ValueError("RowIterator expects a SQLAlchemy Select or Query object.")

        # Normalize Source: always work with a Core Select.
        self.stmt = source.statement if isinstance(source, Query) and hasattr(source, "statement") else source

        # Identify Primary Key(s) (Tie-Breaker)
        entity = self.stmt.column_descriptions[0]["entity"]
        primary_keys = list(inspect(entity).primary_key)
        if not primary_keys:
            raise ValueError(f"RowIterator: Model {entity} has no primary key.")
        self.pk_cols = primary_keys  # List of PK columns

        # Construct composite sort key: user columns first (keeping their direction),
        # then any PK column not already present so the ordering is total.
        self.sort_cols = []
        user_sort_col_names = set()
        if order_by:
            self.sort_cols.extend(order_by)
            user_sort_col_names = {self._get_col_name(c) for c in order_by}
        for pk_col in self.pk_cols:
            if self._get_col_name(pk_col) not in user_sort_col_names:
                self.sort_cols.append(pk_col)

        # Attribute names used to read the bookmark back from fetched objects.
        # Resolved once here instead of for every row of every chunk.
        self._sort_col_names = [self._get_col_name(c) for c in self.sort_cols]

        # State Management
        self.last_vals: Optional[tuple] = None
        self._buffer: Deque[ModelType] = deque()
        self._done = False

        # Snapshot Windows (Chained Snapshots); the first PK column bounds the window.
        self.id_floor = None
        self.id_col = self.pk_cols[0]
        with self.interface:
            self.id_ceiling = self._get_max_id_at_start()
        self._catch_up_mode = False

        # Calculate total rows
        self.total_rows = self._initial_count(source)

    def _get_max_id_at_start(self) -> Optional[Any]:
        """
        Return the current max of the snapshot column within the query's scope.

        Returns None if the query matches no rows or the lookup fails (the failure
        is logged as a warning). Requires an active session on the interface.
        """
        try:
            # Wrap the original statement in a subquery to respect its filters (e.g. is_crawled=False).
            sub = self.stmt.subquery()
            stmt = select(func.max(sub.c[self.id_col.name]))
            return self.interface.session.execute(stmt).scalar()
        except Exception as e:
            logger.warning(f"RowIterator: Could not capture snapshot ID: {e}")
            return None

    def _initial_count(self, source: Union[Select[Any], Query]) -> int:
        """
        Count the rows matched by ``source``, capped at ``limit`` when one is set.

        Opens its own short-lived session. Returns 0 if counting fails (logged).
        """
        try:
            with self.interface:
                if isinstance(source, Query):
                    db_count = source.with_session(self.interface.session).count()
                else:
                    subquery = self.stmt.subquery()
                    count_stmt = select(func.count()).select_from(subquery)
                    db_count = self.interface.session.execute(count_stmt).scalar() or 0

            return min(db_count, self.limit) if self.limit else db_count
        except Exception as e:
            logger.warning(f"RowIterator: Failed to calculate total rows: {e}")
            return 0

    def __len__(self) -> int:
        """Row count measured at start (plus rows found by reruns), capped at ``limit``."""
        return self.total_rows

    def __iter__(self) -> "RowIterator":
        return self

    def __next__(self) -> ModelType:
        """Return the next detached object, fetching a new chunk when the buffer is empty."""
        # Check global limit
        if self.limit is not None and self.yielded_count >= self.limit:
            raise StopIteration

        if not self._buffer:
            if self._done:
                raise StopIteration

            self._fetch_next_chunk()

            if not self._buffer:
                self._done = True
                raise StopIteration

        self.yielded_count += 1
        return self._buffer.popleft()

    def _get_col_name(self, col_expr: Any) -> str:
        """
        Helper to safely get column names from any SQL expression.
        Handles: ORM Columns, Core Columns, Labels, Aliases, Unary (desc).

        Raises:
            ValueError: if no name can be derived.
        """
        element = col_expr

        # Unwrap Sort Modifiers (DESC, NULLS LAST)
        while isinstance(element, UnaryExpression):
            element = element.element

        # Handle ORM Attributes (User.id)
        if hasattr(element, "key"):
            return element.key

        # Handle Labels/Core Columns (table.c.id)
        if hasattr(element, "name"):
            return element.name

        # Fallback for anonymous expressions
        if hasattr(element, "compile"):
            return str(element.compile(compile_kwargs={"literal_binds": True}))

        raise ValueError(f"Could not determine name for sort column: {col_expr}")

    def _fetch_next_chunk(self) -> None:
        """
        Fill the buffer with the next page, shifting the snapshot window if needed.

        A session is opened for the fetch and always closed with rollback (read-only).
        """
        self.interface.start()
        try:
            # Loop so a window shift (rerun) is fetched immediately.
            while not self._buffer:
                fetch_size = self._calculate_fetch_size()
                if fetch_size <= 0:
                    break

                paged_stmt = self._build_query(fetch_size)
                rows = self.interface.session.execute(paged_stmt).all()

                if rows:
                    self._process_rows(rows)
                    break  # We have data!

                # Current window exhausted: look for new rows (shifts the window).
                if not self._check_for_new_rows():
                    break  # Truly done.
                # Otherwise loop and query the NEW window.
        finally:
            self.interface.stop(commit=False)

    def _calculate_fetch_size(self) -> int:
        """Rows to request next: ``chunk_size``, reduced to what ``limit`` still allows."""
        if self.limit is None:
            return self.chunk_size
        remaining = self.limit - self.yielded_count
        return min(self.chunk_size, remaining)

    def _build_query(self, fetch_size: int) -> Select[Any]:
        """Applies snapshot windows and keyset pagination filters."""
        paged_stmt = self.stmt

        if self.id_ceiling is not None:
            paged_stmt = paged_stmt.where(self.id_col <= self.id_ceiling)

        if self.id_floor is not None:
            paged_stmt = paged_stmt.where(self.id_col > self.id_floor)

        if self.last_vals is not None:
            paged_stmt = paged_stmt.where(self._build_keyset_predicate(self.sort_cols, self.last_vals))

        return paged_stmt.order_by(*self.sort_cols).limit(fetch_size)

    def _process_rows(self, rows: List[Any]) -> None:
        """Expunge fetched objects into the buffer and move the keyset bookmark to the last one."""
        obj = None
        for row in rows:
            obj = row[0] if isinstance(row, tuple) or hasattr(row, "_mapping") else row
            self.interface.session.expunge(obj)
            self._buffer.append(obj)

        if obj is not None:
            # Only the last row determines where the next page starts.
            self.last_vals = tuple(getattr(obj, name) for name in self._sort_col_names)

    def _check_for_new_rows(self) -> bool:
        """
        Look for rows added after the snapshot ceiling; if any exist, shift the window
        to ``(old ceiling, new max]`` and restart pagination inside it.

        Returns:
            True when the window was shifted (caller should fetch again), else False.
        """
        # Without a ceiling the main query was never bounded, so there is no newer
        # range to rerun (and ``col > NULL`` would match nothing anyway).
        if self.id_ceiling is None:
            return False

        new_ceiling = self._get_max_id_at_start()
        if new_ceiling is None or new_ceiling <= self.id_ceiling:
            return False

        try:
            check_stmt = self.stmt.where(self.id_col > self.id_ceiling).where(self.id_col <= new_ceiling)
            count_query = select(func.count()).select_from(check_stmt.subquery())
            new_rows_count = self.interface.session.execute(count_query).scalar() or 0

            if new_rows_count > 0:
                logger.info(f"RowIterator: Found {new_rows_count} new rows. Shifting window.")

                if self.limit is not None:
                    self.total_rows = min(self.total_rows + new_rows_count, self.limit)
                else:
                    self.total_rows += new_rows_count

                self.id_floor = self.id_ceiling
                self.id_ceiling = new_ceiling
                self.last_vals = None  # restart keyset pagination inside the new window
                self._catch_up_mode = True
                return True
        except Exception as e:
            logger.warning(f"RowIterator: Failed to check for new rows: {e}")

        return False

    @staticmethod
    def _unwrap_sort_expr(col_expr: Any) -> Tuple[Any, bool]:
        """Strip every sort modifier from ``col_expr``; report whether one of them is DESC."""
        is_desc = False
        element = col_expr
        # e.g. col.desc().nulls_last() is nulls_last(desc(col)): DESC sits one level down.
        while isinstance(element, UnaryExpression):
            if element.modifier == desc_op:
                is_desc = True
            element = element.element
        return element, is_desc

    def _build_keyset_predicate(self, columns: List[Any], last_values: tuple) -> Any:
        """
        Constructs the recursive OR/AND SQL filter for mixed-direction keyset pagination.

        Logic for columns (A, B, C) compared to values (va, vb, vc):
            1. (A > va)
            2. OR (A = va AND B > vb)
            3. OR (A = va AND B = vb AND C > vc)

        *Swaps > for < if the column is Descending.

        Sort modifiers are unwrapped at any depth, matching _get_col_name. Comparisons
        therefore target the bare column the bookmark values were read from.
        """
        conditions = []
        # Equality prefix (A = va AND B = vb ...) shared by the following terms.
        equality_chain = []

        for col_expr, last_val in zip(columns, last_values):
            actual_col, is_desc = self._unwrap_sort_expr(col_expr)

            # ASC: col > val | DESC: col < val
            diff_check = actual_col < last_val if is_desc else actual_col > last_val

            # (Previous Cols Equal) AND (Current Col is Better)
            conditions.append(and_(*equality_chain, diff_check) if equality_chain else diff_check)

            equality_chain.append(actual_col == last_val)

        # Combine all terms with OR
        return or_(*conditions)
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
class AtomIterator(Iterator[AtomicUnit]):
    """
    Iterator wrapper that strictly yields AtomicUnits.

    Used for directly inserting into DB without querying any of its tables first.
    Ensures type safety of the source stream: any item that is not an AtomicUnit
    raises ValueError when it is reached.
    """

    def __init__(self, source: Iterable[Any], limit: Optional[int] = None):
        """
        Args:
            source: Iterator or plain iterable (list, tuple, generator, ...) of
                AtomicUnits. It is consumed lazily, one item per ``next()`` call.
            limit: Optional maximum number of units to yield.
        """
        # iter() returns an iterator unchanged and turns a plain iterable into one,
        # so next() below works for lists as well as generators.
        self._source = iter(source)
        self.limit = limit
        self.yielded_count = 0

    def __iter__(self) -> "AtomIterator":
        return self

    def __next__(self) -> AtomicUnit:
        """
        Return the next AtomicUnit from the source.

        Raises:
            StopIteration: when ``limit`` is reached or the source is exhausted.
            ValueError: if the source yields something other than an AtomicUnit.
        """
        # Check global limit
        if self.limit is not None and self.yielded_count >= self.limit:
            raise StopIteration

        item = next(self._source)
        if not isinstance(item, AtomicUnit):
            raise ValueError(f"AtomIterator expected AtomicUnit, got {type(item)}. Check your generator.")

        self.yielded_count += 1
        return item
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
########################################################################################################################
|
|
544
|
+
# CLASSES
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
class _BaseAlchemyCore:
|
|
548
|
+
"""
|
|
549
|
+
Core SQLAlchemy infrastructure:
|
|
550
|
+
|
|
551
|
+
- Engine and session factory.
|
|
552
|
+
- Manual session lifecycle (start/stop, context manager).
|
|
553
|
+
- Long-lived session detection (warnings when a session stays open too long).
|
|
554
|
+
- Basic DDL helpers (create/drop/reset tables).
|
|
555
|
+
- Utilities such as reset_column and integrity error logging.
|
|
556
|
+
|
|
557
|
+
This class is meant to be combined with mixins that add higher-level behavior.
|
|
558
|
+
"""
|
|
559
|
+
|
|
560
|
+
def __init__(self, config: MutableMapping) -> None:
    """
    Build the engine and session factory from the ``db`` section of ``config``.

    Expected config format:
        {
            "db": {
                "engine": "postgresql+psycopg2",
                "user": "...",
                "password": "...",
                "host": "...",
                "port": 5432,
                "database": "..."
            }
        }

    When the ``db`` section is missing, only a warning is logged. ``config``,
    ``engine`` and ``Session`` are then left unset, so ``start()`` raises AttributeError.
    """
    # No session is open until start() / __enter__ is called.
    self.session: Optional[Session] = None
    self._session_started_at: Optional[float] = None
    self._last_session_warning_at: Optional[float] = None

    if "db" not in config:
        logger.warning("no db section in config")
        return

    self.config = config["db"]
    self.engine = create_engine(self.get_conn_str())
    self.Session = sessionmaker(bind=self.engine)
|
|
28
586
|
|
|
29
|
-
def
|
|
587
|
+
def __enter__(self) -> "_BaseAlchemyCore":
    """Open a fresh session via start() and hand this interface to the with-block."""
    self.start()
    return self
|
|
591
|
+
|
|
592
|
+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
    """
    Leave the with-block by stopping the session.

    The work is committed only when the block finished without an exception and is
    rolled back otherwise. stop() closes the session either way. Returning None
    means exceptions from the block are never suppressed.
    """
    self.stop(commit=exc_type is None)
|
|
602
|
+
|
|
603
|
+
def start(self) -> None:
    """
    Open a new SQLAlchemy session manually.

    Meant for short-lived units of work. Sessions that stay open too long are
    reported by _check_session_duration().

    Raises:
        AttributeError: if no session factory was configured (missing ``db`` config).
        RuntimeError: if a session is already open.
    """
    if not hasattr(self, "Session"):
        raise AttributeError("Database configuration not initialized. Cannot create session.")
    if self.session is not None:
        raise RuntimeError("Session already active.")

    self.session = self.Session()
    # Monotonic clock: session age is unaffected by wall-clock adjustments.
    self._session_started_at = time.monotonic()
    self._last_session_warning_at = None
    logger.debug("SQLAlchemy session started manually.")
|
|
618
|
+
logger.debug("SQLAlchemy session started manually.")
|
|
619
|
+
|
|
620
|
+
def stop(self, commit: bool = True) -> None:
    """
    Stop the current SQLAlchemy session.

    Commits or rolls back the pending work, then always closes the session and
    clears the session bookkeeping. If commit/rollback fails, a second rollback is
    attempted and the original exception is re-raised. The bookkeeping is reset even
    when ``close()`` itself raises, so ``start()`` can open a new session afterwards.
    With no active session, only a warning is logged.

    Args:
        commit: If True, attempt to commit before closing. Otherwise rollback.
    """
    if self.session is None:
        logger.warning("No active session to stop.")
        return

    session = self.session
    try:
        if commit:
            logger.debug("Committing SQLAlchemy session before stopping.")
            session.commit()
        else:
            logger.debug("Rolling back SQLAlchemy session before stopping.")
            session.rollback()
    except Exception as e:
        logger.error(f"Exception during session commit/rollback on stop: {e}", exc_info=True)
        try:
            session.rollback()
        except Exception as rb_exc:
            logger.error(f"Exception during secondary rollback attempt on stop: {rb_exc}", exc_info=True)
        raise
    finally:
        logger.debug("Closing SQLAlchemy session.")
        try:
            session.close()
        finally:
            # Without this inner finally a failing close() would leave self.session
            # set, and every later start() would fail with "Session already active."
            self.session = None
            self._session_started_at = None
            self._last_session_warning_at = None
|
|
651
|
+
|
|
652
|
+
def _check_session_duration(self) -> None:
    """
    Warn when a manually managed session has been open for too long.

    Nothing happens while no session is open, or before the session is
    SESSION_WARN_THRESHOLD_SECONDS old. After that, a warning is logged at most once
    every SESSION_WARN_REPEAT_SECONDS. Intended to catch misuse such as::

        with AlchemyInterface(...) as db:
            # long-running loop / scraper / ETL here
    """
    if self.session is None or self._session_started_at is None:
        return

    now = time.monotonic()
    elapsed = now - self._session_started_at
    if elapsed < SESSION_WARN_THRESHOLD_SECONDS:
        return

    previous = self._last_session_warning_at
    warned_recently = previous is not None and (now - previous) < SESSION_WARN_REPEAT_SECONDS
    if warned_recently:
        return

    logger.warning(
        "SQLAlchemy session has been open for %.1f seconds. "
        "This is likely a misuse of AlchemyInterface (long-lived session). "
        "Prefer short-lived sessions or the batch pipeline API.",
        elapsed,
    )
    self._last_session_warning_at = now
|
|
680
|
+
|
|
681
|
+
def get_conn_str(self) -> str:
    """
    Build the SQLAlchemy connection URL from the loaded ``config``.

    Both the user name and the password are percent-encoded with ``quote_plus``, so
    characters such as ``@``, ``:`` or ``/`` cannot break the URL structure.
    Plain ASCII user names are unaffected. Engine, host, port and database are
    inserted verbatim.

    Raises:
        KeyError: if engine/user/password/host/port/database is missing from config.
    """
    cfg = self.config
    user = quote_plus(str(cfg["user"]))
    password = quote_plus(cfg["password"])
    return (
        f"{cfg['engine']}://"
        f"{user}:{password}"
        f"@{cfg['host']}:{cfg['port']}"
        f"/{cfg['database']}"
    )
|
|
36
691
|
|
|
37
692
|
@staticmethod
|
|
38
|
-
def get_schema_from_table(table):
|
|
693
|
+
def get_schema_from_table(table: Type[ModelType]) -> str:
|
|
694
|
+
"""
|
|
695
|
+
Infer schema name from a SQLAlchemy model class.
|
|
696
|
+
|
|
697
|
+
- Defaults to 'public'.
|
|
698
|
+
- Warns if no explicit schema is provided.
|
|
699
|
+
"""
|
|
39
700
|
schema = "public"
|
|
40
701
|
|
|
41
702
|
if isinstance(table.__table_args__, tuple):
|
|
@@ -51,7 +712,12 @@ class AlchemyInterface:
|
|
|
51
712
|
|
|
52
713
|
return schema
|
|
53
714
|
|
|
54
|
-
def create_tables(self, tables):
|
|
715
|
+
def create_tables(self, tables: List[Type[ModelType]]) -> None:
|
|
716
|
+
"""
|
|
717
|
+
Create schemas and tables (or views) if they do not already exist.
|
|
718
|
+
|
|
719
|
+
For views, it calls a custom `create_view(conn)` on the model if needed.
|
|
720
|
+
"""
|
|
55
721
|
for table in tables:
|
|
56
722
|
schema = self.get_schema_from_table(table)
|
|
57
723
|
|
|
@@ -74,7 +740,12 @@ class AlchemyInterface:
|
|
|
74
740
|
else:
|
|
75
741
|
logger.info(f"table {table.__tablename__} already exists")
|
|
76
742
|
|
|
77
|
-
def drop_tables(self, tables):
|
|
743
|
+
def drop_tables(self, tables: List[Type[ModelType]]) -> None:
|
|
744
|
+
"""
|
|
745
|
+
Drop the given tables or views if they exist.
|
|
746
|
+
|
|
747
|
+
Uses CASCADE to also drop dependent objects.
|
|
748
|
+
"""
|
|
78
749
|
for table in tables:
|
|
79
750
|
schema = self.get_schema_from_table(table)
|
|
80
751
|
|
|
@@ -90,22 +761,1260 @@ class AlchemyInterface:
|
|
|
90
761
|
conn.execute(DDL(f"DROP TABLE {schema}.{table.__tablename__} CASCADE"))
|
|
91
762
|
conn.commit()
|
|
92
763
|
|
|
93
|
-
def reset_db(self, tables, drop):
|
|
764
|
+
def reset_db(self, tables: List[Type[ModelType]], drop: bool = False) -> None:
    """
    Recreate the database objects for the given models.

    Args:
        tables: Model classes to (re)create.
        drop: When True, the tables/views are dropped first and then created again.
              When False, only the create step runs.
    """
    steps = [self.drop_tables] if drop else []
    steps.append(self.create_tables)
    for step in steps:
        step(tables)
|
|
98
776
|
|
|
99
|
-
def reset_column(self, query_results: Union[Query, List[Any]], column_name: str) -> None:
    """
    Reset a column to its default value for a set of rows.

    The default may come from:
    - server_default (uses SQL DEFAULT, so the server decides the final value),
    - column.default (Python constant, or a callable evaluated ONCE and shared by
      every row).

    Args:
        query_results: Either an ORM ``Query`` (updated in one statement via
            ``Query.update``) or a list/iterable of ORM instances (the attribute is
            set on each instance; presumably they belong to the active session so
            the change is written on flush/commit -- confirm at call sites).
        column_name: Name of the column to reset. For instance lists it is also
            used as the attribute name, so it must match the mapped attribute key.

    Raises:
        RuntimeError: If no session is active.
        ValueError: If the column defines neither a server nor a Python default.
    """
    if self.session is None:
        raise RuntimeError("Session not active. Use 'with AlchemyInterface(...):' or call start()")

    self._check_session_duration()

    # Only a Query can be updated in bulk; lists/tuples (which have no `.update`)
    # are treated as collections of ORM instances.
    is_query = hasattr(query_results, "update") and not isinstance(query_results, (list, tuple))
    if is_query:
        # .first() returns None for an empty query instead of raising like query[0].
        first_obj = query_results.first()
    else:
        query_results = list(query_results)
        first_obj = query_results[0] if query_results else None

    if first_obj is None:
        logger.warning("No objects to reset column for.")
        return

    model_class = first_obj.__class__
    table = model_class.__table__

    if column_name not in table.columns:
        logger.warning(f"Column {column_name} does not exist in table {table.name}.")
        return

    column = table.columns[column_name]

    if column.server_default is not None:
        # Use SQL DEFAULT so the server decides the final value.
        default_value = text("DEFAULT")
    elif column.default is not None:
        default_value = column.default.arg
        if callable(default_value):
            # Some column defaults expect a context with metadata.
            default_value = default_value(MockContext(column))
    else:
        raise ValueError(f"Column '{column_name}' doesn't have a default value defined.")

    if is_query:
        query_results.update({column_name: default_value}, synchronize_session=False)
    else:
        for obj in query_results:
            setattr(obj, column_name, default_value)
|
|
820
|
+
|
|
821
|
+
@staticmethod
def _log_integrity_error(ex: IntegrityError, alchemy_obj: Any, action: str = "insert") -> None:
    """
    Log PostgreSQL IntegrityError in a compact, human-friendly way using SQLSTATE codes.

    Unique violations (23505) are logged at INFO level, since duplicates are often
    expected; every other integrity error is logged at ERROR level together with the
    driver's message.

    Args:
        ex: The SQLAlchemy IntegrityError wrapping the DBAPI exception (``ex.orig``).
        alchemy_obj: The object whose insert/upsert failed (only used in the message).
        action: Verb used in the message, e.g. "insert" or "upsert".

    For code meanings, see:
    https://www.postgresql.org/docs/current/errcodes-appendix.html
    """

    PG_ERROR_LABELS = {
        "23000": "Integrity constraint violation",
        "23001": "Restrict violation",
        "23502": "NOT NULL violation",
        "23503": "Foreign key violation",
        "23505": "Unique violation",
        "23514": "Check constraint violation",
        "23P01": "Exclusion constraint violation",
    }
    driver_error = ex.orig
    # psycopg2 exposes the SQLSTATE as `pgcode`; psycopg (v3) exposes it as `sqlstate`.
    code = getattr(driver_error, "pgcode", None) or getattr(driver_error, "sqlstate", None)
    label = PG_ERROR_LABELS.get(code, "Integrity error (unspecified)")

    if code == "23505":
        logger.info(f"{label} trying to {action} {alchemy_obj}")
    else:
        logger.error(f"{label} trying to {action} {alchemy_obj}\nPostgreSQL message: {ex.orig}")
|
|
846
|
+
|
|
847
|
+
def log_row(
    self, obj: Any, columns: Optional[List[str]] = None, prefix_msg: str = "Row", full: bool = False
) -> None:
    """
    Emit one INFO line describing selected attributes of an SQLAlchemy object.

    Output shape: ``"<prefix_msg>: <ClassName> | col1=val1, col2=val2"``.

    Args:
        obj: The SQLAlchemy model instance to describe.
        columns: Attribute names to include. When None or empty, every public key
            currently present in ``obj.__dict__`` is used instead.
        prefix_msg: Tag placed at the start of the line (e.g. "Crawled", "Parsed").
        full: When True, values are never truncated; otherwise each value's string
            form is cut at 500 characters and suffixed with "...(truncated)".

    Any exception raised while building the message is logged at ERROR level and
    the object's plain ``str`` form is logged instead.
    """
    max_len = 500
    try:
        # Default to whatever is loaded on the instance, ignoring private keys such
        # as SQLAlchemy's internal instance state.
        selected = columns or [key for key in obj.__dict__ if not key.startswith("_")]
        if not selected:
            logger.info(f"{prefix_msg}: {obj}")
            return

        def _fmt(name: str) -> str:
            rendered = str(getattr(obj, name, "N/A"))
            if not full and len(rendered) > max_len:
                rendered = rendered[:max_len] + "...(truncated)"
            return f"{name}={rendered}"

        details = ", ".join(_fmt(name) for name in selected)
        logger.info(f"{prefix_msg}: {obj.__class__.__name__} | {details}")

    except Exception as e:
        # Never let logging break the caller: report the failure and log the repr.
        logger.error(f"Failed to generate detailed log for {obj}", exc_info=e)
        logger.info(f"{prefix_msg}: {obj}")
|
|
888
|
+
|
|
889
|
+
|
|
890
|
+
class _LegacyAlchemyOps:
    """
    Mixin containing legacy CRUD helpers and range-query utilities.

    These methods remain for backwards compatibility but should not be used in
    new code. Prefer the batch pipeline (process_batch) and get_row_iterator for new flows.
    """

    def insert_alchemy_obj(self, alchemy_obj: ModelType, silent: bool = False) -> bool:
        """
        Legacy insert helper using per-object savepoints.

        - Uses a nested transaction per insert (savepoint).
        - Swallows IntegrityError and returns False if it occurs.

        Prefer using lightweight INSERT via the batch pipeline.

        Args:
            alchemy_obj: ORM instance to add to the active session.
            silent: If True, suppress the "adding" info log and the integrity-error log.

        Returns:
            True if the savepoint completed, False if an IntegrityError was raised.

        Raises:
            RuntimeError: If no session is active.
        """
        logger.warning(
            "DEPRECATED: insert_alchemy_obj is legacy API. "
            "Prefer using the batch pipeline (process_batch) with lightweight inserts."
        )

        if self.session is None:
            raise RuntimeError("Session not active. Use 'with AlchemyInterface(...):' or call start()")

        self._check_session_duration()

        try:
            # Use a savepoint (nested transaction). The IntegrityError is expected to surface
            # when the savepoint is released, which is when SQLAlchemy flushes the pending add.
            with self.session.begin_nested():
                if not silent:
                    logger.info(f"adding {alchemy_obj}...")
                self.session.add(alchemy_obj)
        except IntegrityError as ex:
            # Rollback is handled automatically by begin_nested() context manager on error
            if not silent:
                self._log_integrity_error(ex, alchemy_obj, action="insert")
            # Do not re-raise, allow outer transaction/loop to continue
            return False

        return True

    def upsert_alchemy_obj(self, alchemy_obj: ModelType, index_elements: List[str], silent: bool = False) -> bool:
        """
        Legacy upsert helper using per-object savepoints and ON CONFLICT.

        - Builds insert/update dicts from ORM object attributes (None values are skipped).
        - Columns whose ``onupdate`` is a SQL expression are always added to the SET clause.
        - If nothing is left to update, falls back to ON CONFLICT DO NOTHING, because
          DO UPDATE with an empty SET clause is rejected by SQLAlchemy.
        - Uses a savepoint per upsert.
        - Swallows IntegrityError and returns False if it occurs.

        Prefer using lightweight upsert via the batch pipeline.

        Args:
            alchemy_obj: ORM instance whose non-None column values are upserted.
            index_elements: Column names forming the ON CONFLICT target.
            silent: If True, suppress the "upserting" info log and the integrity-error log.

        Returns:
            True if the statement ran, False if an IntegrityError was raised.

        Raises:
            RuntimeError: If no session is active.
        """
        logger.warning(
            "DEPRECATED: upsert_alchemy_obj is legacy API. "
            "Prefer using the batch pipeline (process_batch) with lightweight upserts."
        )

        if self.session is None:
            raise RuntimeError("Session not active. Use 'with AlchemyInterface(...):' or call start()")

        self._check_session_duration()

        if not silent:
            logger.info(f"upserting {alchemy_obj}")

        table = alchemy_obj.__table__
        primary_keys = {col.name for col in table.primary_key.columns}

        # INSERT values: every column explicitly set (non-None) on the object.
        insert_values = {
            col.name: value for col in table.columns if (value := getattr(alchemy_obj, col.name)) is not None
        }

        # UPDATE SET values: the same values, minus the primary keys.
        update_set_values = {name: value for name, value in insert_values.items() if name not in primary_keys}

        # Add columns with SQL-based onupdate values explicitly to the set clause
        for column in table.columns:
            actual_sql_expression = None
            if column.onupdate is not None:
                if hasattr(column.onupdate, "arg") and isinstance(column.onupdate.arg, ClauseElement):
                    # Wrappers like ColumnElementColumnDefault keep the SQL expression in `.arg`.
                    actual_sql_expression = column.onupdate.arg
                elif isinstance(column.onupdate, ClauseElement):
                    # onupdate given directly as a SQL expression.
                    actual_sql_expression = column.onupdate

            if actual_sql_expression is not None:
                update_set_values[column.name] = actual_sql_expression

        base_insert = insert(table).values(insert_values)
        if update_set_values:
            statement = base_insert.on_conflict_do_update(index_elements=index_elements, set_=update_set_values)
        else:
            # Nothing to update (e.g. only PK / conflict keys are set): an empty SET clause
            # is invalid, so keep the existing row untouched instead.
            statement = base_insert.on_conflict_do_nothing(index_elements=index_elements)

        try:
            # Use a savepoint (nested transaction)
            with self.session.begin_nested():
                self.session.execute(statement)
        except IntegrityError as ex:
            # Rollback is handled automatically by begin_nested() context manager on error
            if not silent:
                self._log_integrity_error(ex, alchemy_obj, action="upsert")
            # Do not re-raise, allow outer transaction/loop to continue
            return False

        return True

    def windowed_query(
        self,
        stmt: Select[Any],
        order_by: List[SQLColumnExpression[Any]],
        windowsize: int,
        commit_strategy: Union[CommitStrategy, str] = CommitStrategy.COMMIT_ON_SUCCESS,
    ) -> Iterator[Result[Any]]:
        """
        Legacy windowed query helper (range query).

        It executes the given SELECT statement in windows of size `windowsize`,
        each in its own short-lived session (start() ... stop()), and yields `Result`
        objects. The next window starts after the last `id` seen (keyset on `id`).

        Prefer `get_row_iterator` for new range-query implementations.

        Args:
            stmt: SELECT statement; the first selected entity exposing an ``id``
                attribute provides the keyset column.
            order_by: Ordering applied to each window. NOTE(review): windows are
                filtered with ``id > last_id``, so they are only gap-free when this
                sorts by that id ascending -- confirm at call sites.
            windowsize: Maximum number of rows per window.
            commit_strategy: CommitStrategy member, or its name in any letter case.

        Yields:
            A fresh Result for each non-empty window.

        Raises:
            Exception: If no entity in ``stmt`` exposes an ``id`` attribute.
        """
        logger.warning("DEPRECATED: windowed_query is legacy API. Prefer using get_row_iterator for range queries.")

        # Parameter mapping
        if isinstance(commit_strategy, str):
            commit_strategy = CommitStrategy[commit_strategy.upper()]

        # Find id column in stmt. Compare against None explicitly: ORM attributes must not be
        # truth-tested, and entities without `id` (or non-entity columns) are skipped.
        id_column = next(
            (
                entity_id
                for description in stmt.column_descriptions
                if (entity_id := getattr(description.get("entity"), "id", None)) is not None
            ),
            None,
        )
        if id_column is None:
            raise Exception("Column 'id' not found in any entity of the query.")

        # NOTE(review): starting at 0 assumes positive integer ids -- confirm.
        last_id = 0
        while True:
            session_active = False
            commit_needed = False
            try:
                self.start()
                session_active = True

                current_query = stmt.where(id_column > last_id).order_by(order_by[0], *order_by[1:]).limit(windowsize)
                result = self.session.execute(current_query)

                # Create a FrozenResult to allow peeking at the data without consuming
                frozen_result: FrozenResult = result.freeze()
                chunk = frozen_result().all()

                if not chunk:
                    break

                # Update for next iteration (assumes each row exposes `.id` -- TODO confirm)
                last_id = chunk[-1].id

                # Create a new Result object from the FrozenResult
                yield_result = frozen_result()

                yield yield_result
                # Only reached when the consumer resumes the generator normally.
                commit_needed = True

            finally:
                if session_active and self.session:
                    # Double check before stopping just in case. The user may never call insert/upsert in the loop,
                    # so a final check needs to be done.
                    self._check_session_duration()

                    if commit_strategy == CommitStrategy.FORCE_COMMIT:
                        # For forced commit, always attempt to commit.
                        self.stop(commit=True)
                    elif commit_strategy == CommitStrategy.COMMIT_ON_SUCCESS:
                        # Commit only if the consumer finished processing the window.
                        self.stop(commit=commit_needed)
                    else:
                        # Unknown strategy: default to rollback for safety.
                        logger.warning(f"Unknown commit strategy: {commit_strategy}. Defaulting to rollback.")
                        self.stop(commit=False)
|
|
1079
|
+
|
|
1080
|
+
|
|
1081
|
+
class _BatchPipelineOps:
    """
    Mixin providing:
    - Safe, paginated iteration over large queries (get_row_iterator, returning a RowIterator).
    - Validation wrapper for atomic streams (get_atom_iterator, returning an AtomIterator).
    - Unified 'process_batch' method for Data Sink and ETL workflows.
    """
|
|
1088
|
+
|
|
1089
|
+
def get_row_iterator(
    self,
    source: Union[Select[Any], Query],
    chunk_size: int,
    order_by: Optional[List[SQLColumnExpression[Any]]] = None,
    limit: Optional[int] = None,
) -> RowIterator:
    """
    Build a RowIterator that pages through ``source`` in chunks.

    RowIterator is documented as providing composite keyset pagination (no gaps or
    duplicates while data changes), mixed ASC/DESC ordering, and automatic appending
    of the primary key to ``order_by`` for a total ordering.

    Args:
        source: SQLAlchemy Select or Query to iterate.
        chunk_size: Rows fetched per transaction.
        order_by: Sort columns (``sqlalchemy.desc()`` allowed); None means primary key ascending.
        limit: Optional global cap on the number of rows yielded.

    Returns:
        A RowIterator yielding detached ORM objects.
    """
    iterator = RowIterator(self, source, chunk_size, order_by, limit)
    return iterator
|
|
1115
|
+
|
|
1116
|
+
def get_atom_iterator(
    self,
    source: Iterator[AtomicUnit],
    limit: Optional[int] = None,
) -> AtomIterator:
    """
    Wrap a stream of AtomicUnits so it can be fed to ``process_batch``.

    Args:
        source: Iterator expected to yield AtomicUnit objects.
        limit: Optional maximum number of units to take from ``source``.

    Returns:
        An AtomIterator around ``source`` (documented as validating item types at runtime).
    """
    wrapped = AtomIterator(source, limit)
    return wrapped
|
|
1132
|
+
|
|
1133
|
+
def process_batch(
    self,
    source: Union[RowIterator, AtomIterator],
    processor_func: Optional[ProcessorFunc] = None,
    batch_size: int = 100,
    use_bulk_strategy: bool = True,
) -> None:
    """
    Unified Batch Processor.

    Accepts specific iterator types to ensure pipeline safety:
    - RowIterator: Yields Model objects (ETL Mode). Requires 'processor_func'.
    - AtomIterator: Yields AtomicUnits (Data Sink Mode). 'processor_func' is optional.

    Each item is transformed (with no DB session open), validated into AtomicUnits and
    buffered. The buffer is flushed whenever it holds at least ``batch_size`` operations,
    and once more at the end. If the processor raises, or returns something that is not
    an AtomicUnit, the already-buffered units are flushed before the error is re-raised.

    Args:
        source: A RowIterator or AtomIterator.
        processor_func: Function to transform items into AtomicUnits.
            If None, assumes source yields AtomicUnits directly.
        batch_size: Target number of SQL operations per transaction.
        use_bulk_strategy: If True (default), attempts fast Bulk Upserts/Inserts first. Duplicates cannot be examined, only counted.
            If False, skips directly to Row-by-Row recovery mode (useful for debugging duplicates).

    Raises:
        RuntimeError: If called while a session is already active.
        ValueError: If processor output is not an AtomicUnit (or a list/tuple of them).
        Exception: Anything raised by ``processor_func`` is re-raised unchanged.
    """
    # Validation Checks
    self._validate_pipeline_config(batch_size)

    total_items = len(source) if hasattr(source, "__len__") else None

    if total_items is not None:
        logger.info(colorize(f"⚙️ Total items to process: {total_items}", SystemColor.PROCESS_BATCH_PROGRESS))
    else:
        logger.info(colorize("⚙️ Total items to process: Unknown (Stream Mode)", SystemColor.PROCESS_BATCH_PROGRESS))

    current_batch: List[AtomicUnit] = []
    current_batch_ops_count = 0
    processed_count = 0  # Global counter

    # Job Accumulators
    job_stats = _BatchStats()
    job_start_time = time.time()

    for item in source:
        processed_count += 1

        # Logs progress counter.
        if total_items:
            # Update in case of rerun
            total_items = len(source) if hasattr(source, "__len__") else None
            logger.info(
                colorize(f"⚙️ Processing item [{processed_count}/{total_items}]", SystemColor.PROCESS_BATCH_PROGRESS)
            )
        else:
            logger.info(colorize(f"⚙️ Processing item [{processed_count}]", SystemColor.PROCESS_BATCH_PROGRESS))

        # Process + validate the item (No DB Session active). Invalid processor output is
        # treated like a processor crash so the buffered work is not lost either way.
        try:
            result = processor_func(item) if processor_func else item
            if not result:
                continue
            units = self._normalize_result_to_units(result)
        except Exception:
            logger.error(f"Pipeline: Processor failed on item {item}. Flushing buffer before crash.")

            # Emergency Flush: Save what we have before dying
            if current_batch:
                self._flush_and_log(current_batch, job_stats, job_start_time, use_bulk_strategy)

            # Re-raise (original traceback preserved) to stop the pipeline
            raise

        # Add to Buffer
        for unit in units:
            current_batch.append(unit)
            current_batch_ops_count += len(unit)

        # Check Buffer
        if current_batch_ops_count >= batch_size:
            self._flush_and_log(current_batch, job_stats, job_start_time, use_bulk_strategy)

            current_batch = []
            current_batch_ops_count = 0

    # Final flush
    if current_batch:
        self._flush_and_log(current_batch, job_stats, job_start_time, use_bulk_strategy)
|
|
1220
|
+
|
|
1221
|
+
def _validate_pipeline_config(self, batch_size: int) -> None:
|
|
1222
|
+
"""Helper to enforce pipeline guardrails."""
|
|
1223
|
+
if self.session is not None:
|
|
1224
|
+
raise RuntimeError(
|
|
1225
|
+
"Pipeline methods should not be called while a session is already active. "
|
|
1226
|
+
"Do not run this inside a 'with AlchemyInterface(...)' block."
|
|
1227
|
+
)
|
|
1228
|
+
|
|
1229
|
+
if batch_size < MIN_BATCH_WARNING:
|
|
1230
|
+
logger.warning(
|
|
1231
|
+
f"PERFORMANCE WARNING: batch_size={batch_size} is low. "
|
|
1232
|
+
f"Consider using at least {MIN_BATCH_WARNING} items per batch."
|
|
1233
|
+
)
|
|
1234
|
+
|
|
1235
|
+
if batch_size > MAX_BATCH_WARNING:
|
|
1236
|
+
logger.warning(
|
|
1237
|
+
f"PERFORMANCE WARNING: batch_size={batch_size} is very large. "
|
|
1238
|
+
f"This creates long-running transactions. Consider lowering it."
|
|
1239
|
+
)
|
|
1240
|
+
|
|
1241
|
+
def _normalize_result_to_units(self, result: Any) -> List[AtomicUnit]:
    """
    Validate processor output and normalize it into a list of AtomicUnits.

    Args:
        result: A single AtomicUnit, or a list/tuple of them.

    Returns:
        The units as a new list, in input order.

    Raises:
        ValueError: If any element is not an AtomicUnit.
    """
    # Handle both single item and list of items
    raw_units = result if isinstance(result, (list, tuple)) else [result]
    valid_units: List[AtomicUnit] = []

    for unit in raw_units:
        if not isinstance(unit, AtomicUnit):
            raise ValueError(f"Expected AtomicUnit, got {type(unit)}. Check your processor_func.")

        unit_len = len(unit)
        if unit_len > MAX_OPS_PER_ATOM:
            # Oversized atoms are only warned about; they are still accepted.
            logger.warning(
                f"Single AtomicUnit contains {unit_len} operations. Max allowed is {MAX_OPS_PER_ATOM}. "
                "Verify your code to make sure your atom contains the minimal number of operations."
            )
        valid_units.append(unit)

    return valid_units
|
|
1260
|
+
|
|
1261
|
+
def _flush_and_log(
|
|
1262
|
+
self, batch: List[AtomicUnit], job_stats: _BatchStats, job_start_time: float, use_bulk_strategy: bool
|
|
1263
|
+
) -> None:
|
|
1264
|
+
"""
|
|
1265
|
+
Helper to flush the batch, update cumulative stats, and log the progress.
|
|
1266
|
+
"""
|
|
1267
|
+
batch_ops_count = sum(len(u) for u in batch)
|
|
1268
|
+
|
|
1269
|
+
# Measure DB Time for this batch
|
|
1270
|
+
db_start = time.time()
|
|
1271
|
+
|
|
1272
|
+
flush_stats = self._flush_batch_optimistic(batch) if use_bulk_strategy else self._flush_batch_resilient(batch)
|
|
1273
|
+
|
|
1274
|
+
db_duration = time.time() - db_start
|
|
1275
|
+
|
|
1276
|
+
# Update metadata
|
|
1277
|
+
flush_stats.size = batch_ops_count
|
|
1278
|
+
flush_stats.db_time = db_duration
|
|
1279
|
+
|
|
1280
|
+
# Update Job Totals
|
|
1281
|
+
job_stats.add(flush_stats)
|
|
1282
|
+
|
|
1283
|
+
# Log combined status
|
|
1284
|
+
self._log_progress(flush_stats, job_stats, job_start_time)
|
|
1285
|
+
|
|
1286
|
+
def _log_progress(self, batch_stats: _BatchStats, job_stats: _BatchStats, job_start_time: float) -> None:
    """
    Log the batch that was just flushed together with the running totals for the job.

    Args:
        batch_stats: Stats of the batch just flushed (operation count and DB time).
        job_stats: Accumulated stats for the whole job so far.
        job_start_time: ``time.time()`` value captured when the job started.
    """
    total_elapsed = time.time() - job_start_time

    # Throughput over wall-clock time since the job started (processing + DB time).
    job_ops_sec = job_stats.size / total_elapsed if total_elapsed > 0 else 0.0

    # Every section is separated by " | " (the TOTAL/STATS separator was previously missing).
    logger.info(
        colorize(
            "📦 PIPELINE PROGRESS | "
            f"BATCH: {batch_stats.size} ops (DB: {batch_stats.db_time:.2f}s) | "
            f"TOTAL: {job_stats.size} ops (Time: {total_elapsed:.0f}s, DB: {job_stats.db_time:.0f}s) | "
            f"STATS: {job_stats.inserts} Ins, "
            f"{job_stats.upserts} Ups ({job_stats.upsert_inserts} Ins, {job_stats.upsert_updates} Upd), "
            f"{job_stats.duplicates} Dup, {job_stats.failures} Fail | "
            f"SPEED: {job_ops_sec:.1f} ops/s",
            SystemColor.BATCH_PIPELINE_STATS,
        )
    )
|
|
1308
|
+
|
|
1309
|
+
def _flush_batch_optimistic(self, units: List[AtomicUnit]) -> _BatchStats:
    """
    Persist a batch with one bulk statement per (table, action, conflict target) group.

    Steps:
    1. Bucket all operations by table, action and conflict target.
    2. Build a uniform payload per bucket (defaults resolved per column).
    3. Execute a single bulk INSERT/UPSERT per bucket inside one session.
    4. Count inserts vs. updates from the RETURNING data (xmax for upserts).

    On any exception the whole transaction is rolled back and the batch is retried
    through ``_flush_batch_resilient`` (row-by-row recovery mode).

    Raises:
        RuntimeError: If a session is already active when the flush starts.
    """
    if self.session is not None:
        raise RuntimeError("Unexpected active session during batch flush.")

    self.start()
    self._check_session_duration()

    totals = _BatchStats()

    try:
        grouped = self._group_ops_by_signature(units)
        for (table, action, index_elements), bucket_ops in grouped.items():
            totals.add(self._bulk_process_bucket(table, action, index_elements, bucket_ops))

        self._check_session_duration()
        self.stop(commit=True)
    except Exception as e:
        # One failing bucket taints the whole batch: roll back and retry row by row.
        logger.warning(f"Optimistic bulk commit failed. Switching to Recovery Mode. Error: {e}")
        self.stop(commit=False)
        return self._flush_batch_resilient(units)

    return totals
|
|
1351
|
+
|
|
1352
|
+
def _group_ops_by_signature(self, units: List[AtomicUnit]) -> Dict[tuple, List[BatchOp]]:
|
|
1353
|
+
"""
|
|
1354
|
+
Helper: Groups operations into buckets safe for bulk execution.
|
|
1355
|
+
Signature: (Table, OpAction, Tuple(index_elements))
|
|
1356
|
+
"""
|
|
1357
|
+
buckets = defaultdict(list)
|
|
1358
|
+
|
|
1359
|
+
for unit in units:
|
|
1360
|
+
for op in unit.ops:
|
|
1361
|
+
# We need to group by index_elements too, because different conflict targets
|
|
1362
|
+
# require different SQL statements.
|
|
1363
|
+
idx_key = tuple(sorted(op.index_elements)) if op.index_elements else None
|
|
1364
|
+
|
|
1365
|
+
sig = (op.obj.__table__, op.action, idx_key)
|
|
1366
|
+
buckets[sig].append(op)
|
|
1367
|
+
|
|
1368
|
+
return buckets
|
|
1369
|
+
|
|
1370
|
+
def _bulk_process_bucket(
|
|
1371
|
+
self,
|
|
1372
|
+
table: Table,
|
|
1373
|
+
action: OpAction,
|
|
1374
|
+
index_elements: Optional[tuple],
|
|
1375
|
+
ops: List[BatchOp],
|
|
1376
|
+
) -> _BatchStats:
|
|
1377
|
+
"""
|
|
1378
|
+
Executes a single bulk operation for a homogeneous group of records.
|
|
1379
|
+
Refactored to reduce cyclomatic complexity.
|
|
1380
|
+
"""
|
|
1381
|
+
# Prepare Data (Uniformity Pass)
|
|
1382
|
+
records = self._prepare_bulk_payload(table, ops)
|
|
1383
|
+
if not records:
|
|
1384
|
+
return _BatchStats()
|
|
1385
|
+
|
|
1386
|
+
# Build & Execute Statement
|
|
1387
|
+
stmt = self._build_bulk_stmt(table, action, index_elements, records)
|
|
1388
|
+
result = self.session.execute(stmt)
|
|
1389
|
+
rows = result.all()
|
|
1390
|
+
|
|
1391
|
+
# Calculate Stats
|
|
1392
|
+
return self._calculate_bulk_stats(action, rows, len(records))
|
|
1393
|
+
|
|
1394
|
+
def _prepare_bulk_payload(self, table: Table, ops: List[BatchOp]) -> List[Dict[str, Any]]:
|
|
1395
|
+
"""
|
|
1396
|
+
Helper: Iterates through operations and resolves the exact value for every column.
|
|
1397
|
+
1. Uses 'Sparse' strategy (Union of Keys) via `_get_active_bulk_columns`.
|
|
1398
|
+
2. Enforces Uniformity: If a column is active in the batch, every row sends a value for it.
|
|
1399
|
+
We do not skip PKs individually; if the batch schema includes the PK, we send it.
|
|
1400
|
+
"""
|
|
1401
|
+
active_column_names = self._get_active_bulk_columns(table, ops)
|
|
1402
|
+
|
|
1403
|
+
payload = []
|
|
1404
|
+
for op in ops:
|
|
1405
|
+
row_data = {}
|
|
1406
|
+
|
|
1407
|
+
# Iterate only over the "Union of Keys" found in this batch
|
|
1408
|
+
for col_name in active_column_names:
|
|
1409
|
+
col = table.columns[col_name]
|
|
1410
|
+
val = self._resolve_column_value(col, op.obj)
|
|
1411
|
+
|
|
1412
|
+
# Sentinel check: 'Ellipsis' means "Skip this column entirely" (e.g. Server OnUpdate)
|
|
1413
|
+
# Note: We do NOT skip None here. If the column is active but value is None,
|
|
1414
|
+
# we send None (which maps to NULL in SQL), preserving batch shape uniformity.
|
|
1415
|
+
if val is not Ellipsis:
|
|
1416
|
+
row_data[col_name] = val
|
|
1417
|
+
|
|
1418
|
+
payload.append(row_data)
|
|
1419
|
+
|
|
1420
|
+
return payload
|
|
1421
|
+
|
|
1422
|
+
def _get_active_bulk_columns(self, table: Table, ops: List[BatchOp]) -> Set[str]:
|
|
1423
|
+
"""
|
|
1424
|
+
Helper: Scans the batch to find the 'Union of Keys'.
|
|
1425
|
+
A column is active if it is explicitly set (not None) on ANY object in the batch,
|
|
1426
|
+
or if it has a System OnUpdate.
|
|
1427
|
+
"""
|
|
1428
|
+
active_names = set()
|
|
1429
|
+
|
|
1430
|
+
# Always include System OnUpdates
|
|
1431
|
+
for col in table.columns:
|
|
1432
|
+
if col.onupdate:
|
|
1433
|
+
active_names.add(col.name)
|
|
1434
|
+
|
|
1435
|
+
# Scan data for explicit values
|
|
1436
|
+
# We assume that if a column is None on all objects, it should be excluded
|
|
1437
|
+
# (to avoid triggering context-dependent defaults and doing unnecessary default simulations).
|
|
1438
|
+
for op in ops:
|
|
1439
|
+
for col in table.columns:
|
|
1440
|
+
# Optimization: Skip if already found
|
|
1441
|
+
if col.name in active_names:
|
|
1442
|
+
continue
|
|
1443
|
+
|
|
1444
|
+
if getattr(op.obj, col.name) is not None:
|
|
1445
|
+
active_names.add(col.name)
|
|
1446
|
+
|
|
1447
|
+
return active_names
|
|
1448
|
+
|
|
1449
|
+
def _build_bulk_stmt(
    self,
    table: Table,
    action: OpAction,
    index_elements: Optional[tuple],
    records: List[Dict[str, Any]],
) -> Any:
    """
    Build the multi-row INSERT / UPSERT statement for one bucket.

    - INSERT: ``ON CONFLICT DO NOTHING RETURNING <first pk column>``.
    - UPSERT: ``ON CONFLICT (index_elements) DO UPDATE ... RETURNING <pk>, xmax``;
      when there is nothing to update it degrades to ``DO NOTHING RETURNING <pk>``,
      since an empty SET clause is not valid SQL.

    Raises:
        ValueError: If an upsert has no ``index_elements`` or ``action`` is unknown.
    """
    base = insert(table).values(records)
    first_pk = [*table.primary_key.columns][0]

    if action == OpAction.INSERT:
        return base.on_conflict_do_nothing().returning(first_pk)

    if action != OpAction.UPSERT:
        raise ValueError(f"Unknown OpAction: {action}")

    if not index_elements:
        raise ValueError(f"Upsert on {table.name} missing index_elements.")

    # The payload is uniform, so the first record's keys describe every row.
    set_clause = self._build_upsert_set_clause(table, index_elements, records[0].keys())
    if not set_clause:
        return base.on_conflict_do_nothing().returning(first_pk)

    # xmax is returned so the stats step can tell inserted rows from updated ones.
    return base.on_conflict_do_update(index_elements=index_elements, set_=set_clause).returning(
        first_pk, text("xmax")
    )
|
|
1484
|
+
|
|
1485
|
+
def _build_upsert_set_clause(
    self, table: Table, index_elements: tuple, record_keys: Iterable[str]
) -> Dict[str, Any]:
    """
    Build the ``set_=`` mapping for ``ON CONFLICT DO UPDATE``.

    Every payload key that is neither a conflict-target column nor a
    primary-key column is mapped to ``EXCLUDED.<key>``. Then, every column
    whose ``onupdate`` is a SQL expression (either directly or via its
    ``.arg``) is forced to that expression, overriding the payload value.
    ``onupdate`` values that are not SQL expressions (e.g. Python callables)
    are not applied here.

    Args:
        table: Target table.
        index_elements: Conflict target column names; never placed in SET.
        record_keys: Keys present in the batch payload.

    Returns:
        Dict mapping column name -> SQL expression assigned on conflict.
        May be empty, in which case callers must not emit DO UPDATE.
    """
    primary_keys = {col.name for col in table.primary_key.columns}
    # Build the EXCLUDED pseudo-table once instead of a fresh INSERT construct per key.
    excluded = insert(table).excluded
    update_set = {
        key: getattr(excluded, key)
        for key in record_keys
        if key not in index_elements and key not in primary_keys
    }

    # Override: SQL-expression onupdates (system overrides) are always forced.
    for col in table.columns:
        if not col.onupdate:
            continue

        expr = None
        if hasattr(col.onupdate, "arg") and isinstance(col.onupdate.arg, ClauseElement):
            expr = col.onupdate.arg
        elif isinstance(col.onupdate, ClauseElement):
            expr = col.onupdate

        if expr is not None:
            update_set[col.name] = expr

    return update_set
|
|
1513
|
+
|
|
1514
|
+
def _calculate_bulk_stats(self, action: OpAction, rows: List[Any], total_ops: int) -> _BatchStats:
    """
    Turn the RETURNING rows of a bulk statement into batch counters.

    Args:
        action: Operation type the statement was built for.
        rows: Rows returned by the statement: one pk column for INSERT and
            the DO NOTHING fallback, pk + ``xmax`` for DO UPDATE.
        total_ops: Number of records sent in the statement.

    Returns:
        _BatchStats:
            - INSERT: ``inserts`` = returned rows, ``duplicates`` = the rest.
            - UPSERT: ``upserts`` = total_ops (attempted); returned rows are split
              into ``upsert_inserts`` (xmax == 0) and ``upsert_updates``
              (xmax != 0); ``duplicates`` = records that returned no row.
            - Any other action: an empty ``_BatchStats``.
    """
    stats = _BatchStats()
    returned_count = len(rows)

    if action == OpAction.INSERT:
        stats.inserts = returned_count
        stats.duplicates = total_ops - returned_count
        return stats

    if action == OpAction.UPSERT:
        stats.upserts = total_ops  # Intent: we tried to upsert this many

        # Fallback case (empty update -> DO NOTHING): only the pk column was
        # requested, so rows have no xmax. Every returned row was inserted.
        if rows and len(rows[0]) < 2:
            stats.upsert_inserts = returned_count
            stats.duplicates = total_ops - returned_count
            return stats

        # xmax == 0 means the row was freshly inserted; non-zero means updated.
        # NOTE: xmax has the PostgreSQL ``xid`` type, which a driver may hand
        # back as text (e.g. "0"); int() normalizes both forms before comparing.
        created = sum(1 for row in rows if int(row.xmax) == 0)

        stats.upsert_inserts = created
        stats.upsert_updates = returned_count - created
        # In DO UPDATE, duplicates usually don't happen (unless filtered by WHERE).
        stats.duplicates = total_ops - returned_count
        return stats

    return stats
|
|
1555
|
+
|
|
1556
|
+
def _flush_batch_resilient(self, units: List[AtomicUnit]) -> _BatchStats:
    """
    Recovery Mode with Safety Valve: apply AtomicUnits one at a time.

    - Each unit runs inside its own SAVEPOINT (``session.begin_nested()``).
      A failing unit only loses its own operations; it is counted in
      ``failures`` and logged, and processing continues with the next unit.
    - SAFETY VALVE: every ``MAX_FAILURE_BURST`` failures the session is
      committed to flush "Dead Transaction IDs" from memory, preventing
      OOM/lock exhaustion. Units that succeeded before that point are
      persisted by this commit.
    - After the loop, surviving work is committed via ``self.stop(commit=True)``.

    Args:
        units: AtomicUnits to apply, in order.

    Returns:
        _BatchStats: Counters from successful units, plus ``failures``.

    Raises:
        Exception: Errors outside the per-unit handler (e.g. a safety-valve or
            final commit failure) are logged; the session is stopped without
            committing and the original exception is re-raised. Work committed
            by earlier safety-valve commits stays committed.
    """
    self.start()
    self._check_session_duration()

    stats = _BatchStats()

    try:
        for i, unit in enumerate(units):
            try:
                # NOTE: an error raised by this check is handled as a unit failure.
                self._check_session_duration()

                # Firewall: each AtomicUnit gets its own isolated savepoint.
                with self.session.begin_nested():
                    unit_inserts = 0
                    unit_upserts = 0
                    unit_upsert_inserts = 0
                    unit_upsert_updates = 0
                    unit_duplicates = 0

                    for op in unit.ops:
                        written, xmax = self._apply_operation(op)

                        if written > 0:
                            if op.action == OpAction.INSERT:
                                unit_inserts += 1
                            elif op.action == OpAction.UPSERT:
                                unit_upserts += 1
                                # xmax (PostgreSQL xid) may arrive as text, e.g. "0".
                                if int(xmax) == 0:
                                    unit_upsert_inserts += 1
                                else:
                                    unit_upsert_updates += 1
                        else:
                            unit_duplicates += 1

                # Only reached once the savepoint was released successfully,
                # so a unit is never counted both as success and as failure.
                stats.inserts += unit_inserts
                stats.upserts += unit_upserts
                stats.upsert_inserts += unit_upsert_inserts
                stats.upsert_updates += unit_upsert_updates
                stats.duplicates += unit_duplicates

            except Exception as e:
                # Granular failure: only THIS unit is lost.
                stats.failures += 1
                logger.error(
                    f"Recovery: AtomicUnit failed (index {i} in batch). Discarding {len(unit)} ops. Error: {e}"
                )

                # SAFETY VALVE: periodically commit to release accumulated dead transactions.
                if stats.failures % MAX_FAILURE_BURST == 0:
                    logger.warning(
                        f"Safety Valve: {stats.failures} failures accumulated. "
                        "Committing now to clear transaction memory and free up locks."
                    )
                    self.session.commit()

        self._check_session_duration()
        # Commit whatever survived
        self.stop(commit=True)
        return stats

    except Exception as e:
        # Critical failure (e.g. DB down). Only rows pending since the last
        # safety-valve commit are rolled back here.
        logger.error(f"Critical Failure in Recovery Mode: {e}", exc_info=True)
        self.stop(commit=False)
        raise
|
|
1633
|
+
|
|
1634
|
+
def _apply_operation(self, op: BatchOp) -> Tuple[int, int]:
    """
    Execute a single BatchOp against the current session.

    Args:
        op: The operation to run; anything other than a ``BatchOp`` is rejected.

    Returns:
        Tuple[int, int]: ``(rowcount, xmax)``. INSERT operations always report
        an ``xmax`` of 0; UPSERT operations report what ``_upsert_lightweight``
        returns.

    Raises:
        ValueError: ``op`` is not a BatchOp, an UPSERT has no
            ``index_elements``, or the action is unknown.
    """
    self._check_session_duration()

    if not isinstance(op, BatchOp):
        raise ValueError(
            f"Pipeline Error: Expected BatchOp, got {type(op)}. All operations must be wrapped in BatchOp."
        )

    action = op.action

    if action == OpAction.INSERT:
        # Plain inserts only need the rowcount; xmax is meaningless here.
        rowcount = self._insert_lightweight(op.obj)
        return rowcount, 0

    if action != OpAction.UPSERT:
        raise ValueError(f"Unknown OpAction: {op.action}")

    if not op.index_elements:
        raise ValueError(f"Upsert BatchOp missing index_elements: {op}")

    return self._upsert_lightweight(op.obj, index_elements=op.index_elements)
|
|
1658
|
+
|
|
1659
|
+
def _insert_lightweight(self, obj: ModelType) -> int:
    """
    Insert one ORM object using ``INSERT ... ON CONFLICT DO NOTHING``.

    Only columns whose attribute value is not None are sent; omitted columns
    are left to SQLAlchemy/database defaults. A conflicting row is skipped
    silently apart from an info-level log line.

    Returns:
        int: The statement rowcount (1 if inserted, 0 if it was a duplicate).
    """
    self._check_session_duration()

    table = obj.__table__

    payload = {}
    for column in table.columns:
        value = getattr(obj, column.name)
        if value is not None:
            payload[column.name] = value

    outcome = self.session.execute(insert(table).values(payload).on_conflict_do_nothing())
    inserted = outcome.rowcount

    if inserted == 0:
        logger.info(f"Duplicate skipped (Unique Violation): {obj}")

    return inserted
|
|
1681
|
+
|
|
1682
|
+
def _upsert_lightweight(self, obj: ModelType, index_elements: List[str]) -> Tuple[int, int]:
    """
    Upsert one ORM object using ``INSERT ... ON CONFLICT (...) DO UPDATE``.

    Used primarily in Recovery Mode.

    - INSERT part: every column whose attribute value is not None
      (SQLAlchemy/DB handle defaults for the omitted columns).
    - UPDATE part: the sparse mapping from ``_get_update_values``. If it is
      empty there is nothing to update and an empty SET cannot be emitted,
      so the statement degrades to ``ON CONFLICT (...) DO NOTHING``.

    Args:
        obj: ORM instance to write.
        index_elements: Conflict target columns.

    Returns:
        Tuple[int, int]: ``(rowcount, xmax)``. ``xmax`` is 0 when no row was
        returned (e.g. the DO NOTHING fallback hit a conflict); otherwise it is
        the returned system column, normalized to ``int``.
    """
    self._check_session_duration()

    table = obj.__table__

    # INSERT: standard behavior (SQLAlchemy/DB handles defaults for missing cols).
    insert_values = {
        col.name: getattr(obj, col.name) for col in table.columns if getattr(obj, col.name) is not None
    }

    # UPDATE: manual priority logic.
    update_set_values = self._get_update_values(table, obj)

    stmt = insert(table).values(insert_values)
    if update_set_values:
        stmt = stmt.on_conflict_do_update(index_elements=index_elements, set_=update_set_values)
    else:
        # Same conflict target, but nothing to SET: skip the update instead of
        # emitting an invalid empty SET clause.
        stmt = stmt.on_conflict_do_nothing(index_elements=index_elements)
    stmt = stmt.returning(text("xmax"))

    result = self.session.execute(stmt)

    # Capture xmax if a row was returned (written).
    xmax = 0
    if result.rowcount > 0:
        row = result.fetchone()
        if row is not None:
            # xmax has the PostgreSQL xid type; a driver may return it as text.
            xmax = int(row.xmax)

    return result.rowcount, xmax
|
|
1716
|
+
|
|
1717
|
+
def _get_update_values(self, table, obj) -> Dict[str, Any]:
    """
    Build the sparse SET mapping for a single-object upsert.

    Primary-key columns are never included. Every other column is considered
    only when the object holds a non-None value for it or the column defines
    ``onupdate``, so unset columns do not trigger context-dependent defaults
    during Recovery Mode. The value comes from ``_resolve_column_value``; a
    column is dropped when that helper returns the ``Ellipsis`` sentinel.
    """
    pk_names = {col.name for col in table.primary_key.columns.values()}

    candidates = (
        col
        for col in table.columns
        if col.name not in pk_names and (getattr(obj, col.name) is not None or col.onupdate)
    )

    resolved = {}
    for col in candidates:
        value = self._resolve_column_value(col, obj)
        if value is Ellipsis:
            continue  # sentinel: leave this column out of the SET clause
        resolved[col.name] = value

    return resolved
|
|
1743
|
+
|
|
1744
|
+
def _resolve_column_value(self, col, obj) -> Any:  # noqa: C901
    """
    Helper: Determines the value for a single column based on priority.
    Used by both Bulk Processing and Lightweight Recovery.

    Priority Hierarchy:
        1. Python onupdate that is a SQL expression (System Override) -> wins always.
        2. Server onupdate (DB Trigger) -> column skipped so the DB handles it.
        3. Explicit value -> wins if not None.
        4. Python default -> fallback for None (static value, or callable
           invoked without an execution context).
        5. Server default -> fallback for None.
        6. NULL -> last resort.

    Returns:
        The value/expression to use, or ``Ellipsis`` (...) if the column should
        be excluded from the values dict.

    Raises:
        TypeError: A callable Python default could not be evaluated without an
            execution context (e.g. a context-sensitive default).
    """
    # 1. Python OnUpdate (System Override). Only SQL expressions are honored;
    # in Bulk Upsert this may also be applied by the statement builder.
    if col.onupdate is not None:
        expr = None
        if hasattr(col.onupdate, "arg") and isinstance(col.onupdate.arg, ClauseElement):
            expr = col.onupdate.arg
        elif isinstance(col.onupdate, ClauseElement):
            expr = col.onupdate

        if expr is not None:
            return expr

    # 2. Server OnUpdate (DB Trigger Override)
    if col.server_onupdate is not None:
        return Ellipsis  # Sentinel to skip

    # 3. Explicit Value
    val = getattr(obj, col.name)
    if val is not None:
        return val

    # Fallback: value is None

    # 4. Python Default. Some defaults (e.g. a Sequence) have no ``arg``; skip those.
    if col.default is not None and hasattr(col.default, "arg"):
        arg = col.default.arg
        if not callable(arg):
            return arg
        try:
            # SQLAlchemy's CallableColumnDefault stores ``arg`` as a callable that
            # takes the execution context positionally; zero-argument functions are
            # wrapped as ``lambda ctx: fn()``. There is no context here, so pass None:
            # plain defaults (uuid4, datetime.now, ...) work, context-sensitive ones fail.
            return arg(None)
        except (TypeError, AttributeError) as e:
            raise TypeError(
                "Calling the python default function failed. "
                "Most likely attempted to write NULL to a column with a python default that takes context as parameter."
            ) from e

    # 5. Server Default. FetchedValue-style server defaults carry no ``arg``.
    if col.server_default is not None and hasattr(col.server_default, "arg"):
        return col.server_default.arg

    # 6. Explicit NULL
    return None
|
|
1803
|
+
|
|
1804
|
+
|
|
1805
|
+
class _BatchUtilities:
    """
    Mixin for high-performance bulk inserts from pandas DataFrames.

    Features:
        - Smart defaults: Python-side defaults (static values and callables) are
          filled in memory; columns with a server default are bound through SQL
          ``COALESCE`` so a NULL falls back to the server default.
        - Recursive bisecting: a failing group of rows is split in halves (each
          attempt in its own SAVEPOINT) until bad rows are isolated and skipped,
          without stopping the whole batch.

    Expected interface on the host class:
        - self.session: sqlalchemy.orm.Session (must be None before insert_dataframe)
        - self.start(): method to start a session
        - self.stop(commit=bool): method to stop the session
        - self._check_session_duration(): called once per chunk
        - self.log_row(...): optional helper, used for skipped rows when verbose=True
    """

    def insert_dataframe(
        self, df: pd.DataFrame, model: Type["ModelType"], batch_size: int = 5000, verbose: bool = False
    ) -> None:
        """
        High-Performance Smart Bulk Insert.

        Logic Priority for Empty/Null Values:
            1. Python Default (e.g. default="pending") -> Applied in-memory via Pandas.
            2. Python Callable (e.g. default=uuid4) -> Executed per-row in-memory.
            3. Server Default (e.g. server_default=text("now()")) -> Applied via SQL COALESCE.
            4. NULL -> If none of the above exist.

        Rows that hit a unique conflict are skipped (ON CONFLICT DO NOTHING) and
        counted as duplicates; rows that raise are isolated by bisection and
        skipped. Each chunk of ``batch_size`` rows is committed separately.

        CRITICAL LIMITATIONS:
            - NO EXPLICIT NULLS: If a column has a default (Python or Server), sending None/NaN
              will ALWAYS trigger that default. You cannot force a NULL value into such a column.

        Args:
            df: Pandas DataFrame containing the data.
            model: SQLAlchemy Model class.
            batch_size: Rows per transaction chunk.
            verbose: If True, skipped rows are also logged via ``log_row`` (when available).

        Raises:
            RuntimeError: If a session is already active.
        """
        _require_pandas()

        if self.session is not None:
            raise RuntimeError("insert_dataframe cannot be called when a session is already active.")

        if df.empty:
            logger.warning("Pipeline: DataFrame is empty. Nothing to insert.")
            return

        # Apply Python-side defaults (in memory).
        df_processed = self._apply_python_defaults(df, model)

        # Prepare records (NaN -> None for SQL binding).
        records = df_processed.replace({np.nan: None}).to_dict("records")
        total_records = len(records)

        logger.info(f"BULK INSERT START | Model: {model.__name__} | Records: {total_records}")

        # Create the smart SQL statement (COALESCE logic) once for all chunks.
        smart_stmt = self._create_smart_insert_stmt(model, df_processed.columns)

        self.start()
        job_stats = _BatchStats()
        job_start = time.time()

        try:
            for i in range(0, total_records, batch_size):
                self._check_session_duration()
                chunk = records[i : i + batch_size]

                # Recursive engine using the pre-compiled smart statement.
                self._insert_recursive_smart(chunk, smart_stmt, model, job_stats, verbose)

                self.session.commit()

                elapsed = time.time() - job_start
                ops_sec = job_stats.size / elapsed if elapsed > 0 else 0
                logger.info(
                    f"BULK PROGRESS | Processed: {min(i + batch_size, total_records)}/{total_records} | "
                    f"Written: {job_stats.inserts} | Dupes: {job_stats.duplicates} | "
                    f"Skipped: {job_stats.failures} | Speed: {ops_sec:.0f} rows/s"
                )

        except Exception:
            # The session is stopped (without commit) exactly once, in ``finally``.
            # Chunks committed before the failure stay committed.
            logger.error("Critical Failure in Bulk Insert", exc_info=True)
            raise
        finally:
            self.stop(commit=False)
            logger.info(f"BULK INSERT FINISHED | Total Time: {time.time() - job_start:.2f}s")

    def _apply_python_defaults(self, df: pd.DataFrame, model: Type["ModelType"]) -> pd.DataFrame:
        """
        Return a copy of ``df`` with missing values filled from Python-side column defaults.

        - Static values (int, str, ...): vectorized ``fillna`` (fast).
        - Callables (uuid4, datetime.now, ...): evaluated once per missing cell (slower).

        Columns absent from ``df``, without missing values, or without a Python
        default carrying ``arg`` are left untouched. The input frame is not modified.

        Raises:
            TypeError: A callable default could not be evaluated without an
                execution context (e.g. a context-sensitive default).
        """
        df_copy = df.copy()

        for col in model.__table__.columns:
            # Skip if column irrelevant or fully populated.
            if col.name not in df_copy.columns or not df_copy[col.name].hasnans:
                continue

            # Skip if there is no Python default with a usable ``arg``.
            if col.default is None or not hasattr(col.default, "arg"):
                continue

            default_arg = col.default.arg

            # Case A: static value (fast vectorized fill).
            if not callable(default_arg):
                df_copy[col.name] = df_copy[col.name].fillna(default_arg)
                continue

            # Case B: callable (row-by-row fill).
            mask = df_copy[col.name].isna()
            try:
                # SQLAlchemy's CallableColumnDefault stores ``arg`` as a callable taking
                # the execution context positionally (zero-argument functions are
                # wrapped as ``lambda ctx: fn()``). There is no context here, so pass None.
                fill_values = [default_arg(None) for _ in range(mask.sum())]
            except (TypeError, AttributeError) as e:
                raise TypeError(
                    "Calling the python default function failed. "
                    "Most likely attempted to write NULL to a column with a python default that takes context as parameter."
                ) from e
            df_copy.loc[mask, col.name] = fill_values

        return df_copy

    def _create_smart_insert_stmt(self, model: Type["ModelType"], df_columns: Sequence[str]):
        """
        Build ``INSERT ... ON CONFLICT DO NOTHING RETURNING <primary key columns>``.

        Only ``df_columns`` that exist on the table are bound. A column with a
        server default is bound as ``COALESCE(:param, <server default>)`` so a
        NULL value falls back to that default.

        Returns:
            The statement; executing it yields one row (the primary keys) per
            row actually written, and none for conflicting rows.
        """
        table = model.__table__
        values_dict = {}

        for col_name in df_columns:
            if col_name not in table.columns:
                continue

            col = table.columns[col_name]
            param = bindparam(col_name, type_=col.type)

            if col.server_default is not None and hasattr(col.server_default, "arg"):
                values_dict[col_name] = func.coalesce(param, col.server_default.arg)
            else:
                values_dict[col_name] = param

        # RETURNING the primary key(s) lets callers count written rows.
        pk_cols = list(table.primary_key.columns)

        return insert(table).values(values_dict).on_conflict_do_nothing().returning(*pk_cols)

    def _insert_recursive_smart(
        self, records: List[Dict[str, Any]], stmt, model: Type["ModelType"], stats: "_BatchStats", verbose: bool
    ) -> None:
        """
        Recursive bisecting engine using the pre-compiled smart statement.

        Every attempt runs in a SAVEPOINT (``begin_nested``), so a failure is
        rolled back without ending the outer transaction. A failing group is
        split in half and retried until single bad rows are isolated, counted in
        ``stats.failures`` and skipped. Counters are updated only after the
        savepoint was released successfully.

        Args:
            records: Row dicts to insert.
            stmt: Statement built by ``_create_smart_insert_stmt``.
            model: Model class, used to render a skipped row for ``log_row``.
            stats: Accumulator updated in place (``inserts``, ``duplicates``,
                ``failures``, ``size``).
            verbose: If True and ``log_row`` exists, skipped rows are logged through it.
        """
        if not records:
            return

        # Base case: single row.
        if len(records) == 1:
            try:
                with self.session.begin_nested():
                    result = self.session.execute(stmt, records)
                    written = len(result.all())  # 1 (written) or 0 (duplicate)
            except Exception as e:
                stats.failures += 1
                stats.size += 1
                if verbose and hasattr(self, "log_row"):
                    try:
                        self.log_row(model(**records[0]), prefix_msg="SKIPPING BAD ROW")
                    except Exception as log_err:
                        # Row logging is best-effort and must not abort the bulk job.
                        logger.warning(f"Could not log skipped row: {log_err}")
                logger.error(f"Bulk Insert Error on single row: {e}")
            else:
                stats.inserts += written
                stats.duplicates += 1 - written
                stats.size += 1
            return

        # Recursive step.
        try:
            with self.session.begin_nested():
                result = self.session.execute(stmt, records)
                # ON CONFLICT DO NOTHING returns no row for duplicates, so the number
                # of returned primary keys equals the number of rows written.
                written = len(result.all())
        except Exception:
            # Failure -> split and retry each half.
            mid = len(records) // 2
            self._insert_recursive_smart(records[:mid], stmt, model, stats, verbose)
            self._insert_recursive_smart(records[mid:], stmt, model, stats, verbose)
        else:
            stats.inserts += written
            stats.duplicates += len(records) - written
            stats.size += len(records)
|
|
2003
|
+
|
|
2004
|
+
|
|
2005
|
+
class AlchemyInterface(_BaseAlchemyCore, _LegacyAlchemyOps, _BatchPipelineOps, _BatchUtilities):
    """
    Public, application-facing database interface.

    Defines no behavior of its own; it only composes the following mixins:

    - _BaseAlchemyCore: engine/session management, DDL utilities and core helpers.
    - _LegacyAlchemyOps: legacy insert/upsert/windowed-query APIs, kept for compatibility.
    - _BatchPipelineOps: modern batch processing and safe iteration utilities.
    - _BatchUtilities: high-performance bulk operations (``insert_dataframe``) with smart defaults.

    Application code should use this class rather than the individual mixins.
    """
|
|
2018
|
+
|
|
110
2019
|
|
|
111
|
-
|
|
2020
|
+
# Public API of this module.
__all__ = [
    "RowIterator",
    "AtomIterator",
    "AtomicUnit",
    "BatchOp",
    "OpAction",
    "AlchemyInterface",
]
|