datamarket 0.6.0__py3-none-any.whl → 0.10.3__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.

This release has been flagged as potentially problematic.

Files changed (38)
  1. datamarket/__init__.py +0 -1
  2. datamarket/exceptions/__init__.py +1 -0
  3. datamarket/exceptions/main.py +118 -0
  4. datamarket/interfaces/alchemy.py +1934 -25
  5. datamarket/interfaces/aws.py +81 -14
  6. datamarket/interfaces/azure.py +127 -0
  7. datamarket/interfaces/drive.py +60 -10
  8. datamarket/interfaces/ftp.py +37 -14
  9. datamarket/interfaces/llm.py +1220 -0
  10. datamarket/interfaces/nominatim.py +314 -42
  11. datamarket/interfaces/peerdb.py +272 -104
  12. datamarket/interfaces/proxy.py +354 -50
  13. datamarket/interfaces/tinybird.py +7 -15
  14. datamarket/params/nominatim.py +439 -0
  15. datamarket/utils/__init__.py +1 -1
  16. datamarket/utils/airflow.py +10 -7
  17. datamarket/utils/alchemy.py +2 -1
  18. datamarket/utils/logs.py +88 -0
  19. datamarket/utils/main.py +138 -10
  20. datamarket/utils/nominatim.py +201 -0
  21. datamarket/utils/playwright/__init__.py +0 -0
  22. datamarket/utils/playwright/async_api.py +274 -0
  23. datamarket/utils/playwright/sync_api.py +281 -0
  24. datamarket/utils/requests.py +655 -0
  25. datamarket/utils/selenium.py +6 -12
  26. datamarket/utils/strings/__init__.py +1 -0
  27. datamarket/utils/strings/normalization.py +217 -0
  28. datamarket/utils/strings/obfuscation.py +153 -0
  29. datamarket/utils/strings/standardization.py +40 -0
  30. datamarket/utils/typer.py +2 -1
  31. datamarket/utils/types.py +1 -0
  32. datamarket-0.10.3.dist-info/METADATA +172 -0
  33. datamarket-0.10.3.dist-info/RECORD +38 -0
  34. {datamarket-0.6.0.dist-info → datamarket-0.10.3.dist-info}/WHEEL +1 -2
  35. datamarket-0.6.0.dist-info/METADATA +0 -49
  36. datamarket-0.6.0.dist-info/RECORD +0 -24
  37. datamarket-0.6.0.dist-info/top_level.txt +0 -1
  38. {datamarket-0.6.0.dist-info → datamarket-0.10.3.dist-info/licenses}/LICENSE +0 -0
@@ -1,41 +1,702 @@
1
1
  ########################################################################################################################
2
2
  # IMPORTS
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import logging
7
+ import time
8
+ from collections import defaultdict, deque
9
+ from collections.abc import MutableMapping
10
+ from dataclasses import dataclass, field
11
+ from enum import Enum, auto
12
+ from typing import (
13
+ TYPE_CHECKING,
14
+ Any,
15
+ Callable,
16
+ Deque,
17
+ Dict,
18
+ Iterable,
19
+ Iterator,
20
+ List,
21
+ Optional,
22
+ Sequence,
23
+ Set,
24
+ Tuple,
25
+ Type,
26
+ TypeVar,
27
+ Union,
28
+ )
5
29
  from urllib.parse import quote_plus
6
30
 
7
- from sqlalchemy import DDL, create_engine
31
+ import numpy as np
32
+ from sqlalchemy import (
33
+ DDL,
34
+ FrozenResult,
35
+ Result,
36
+ Select,
37
+ SQLColumnExpression,
38
+ Table,
39
+ and_,
40
+ bindparam,
41
+ create_engine,
42
+ func,
43
+ inspect,
44
+ or_,
45
+ select,
46
+ text,
47
+ )
48
+ from sqlalchemy.dialects.postgresql import insert
8
49
  from sqlalchemy.exc import IntegrityError
9
- from sqlalchemy.orm import sessionmaker
50
+ from sqlalchemy.ext.declarative import DeclarativeMeta
51
+ from sqlalchemy.orm import Query, Session, sessionmaker
52
+ from sqlalchemy.sql.elements import UnaryExpression
53
+ from sqlalchemy.sql.expression import ClauseElement
54
+ from sqlalchemy.sql.operators import desc_op
55
+
56
+ from datamarket.utils.logs import SystemColor, colorize
57
+
58
+ try:
59
+ import pandas as pd # type: ignore
60
+ except ImportError:
61
+ pd = None # type: ignore
62
+
63
+ if TYPE_CHECKING:
64
+ import pandas as pd # noqa: F401
65
+
10
66
 
11
67
  ########################################################################################################################
12
- # CLASSES
68
+ # TYPES / CONSTANTS
69
+
13
70
 
14
71
  logger = logging.getLogger(__name__)
15
72
 
73
+ ModelType = TypeVar("ModelType", bound=DeclarativeMeta)
74
+
75
+
76
+ class OpAction(Enum):
77
+ INSERT = "insert"
78
+ UPSERT = "upsert"
79
+
80
+
81
+ @dataclass(frozen=True)
82
+ class BatchOp:
83
+ """
84
+ Immutable representation of a single SQL operation.
85
+ """
86
+
87
+ obj: ModelType
88
+ action: OpAction
89
+ index_elements: Optional[List[str]] = None
90
+
91
+ @classmethod
92
+ def insert(cls, obj: ModelType) -> "BatchOp":
93
+ return cls(obj=obj, action=OpAction.INSERT)
94
+
95
+ @classmethod
96
+ def upsert(cls, obj: ModelType, index_elements: List[str]) -> "BatchOp":
97
+ if not index_elements:
98
+ raise ValueError("Upsert requires 'index_elements'.")
99
+ return cls(obj=obj, action=OpAction.UPSERT, index_elements=index_elements)
100
+
101
+
102
+ @dataclass
103
+ class AtomicUnit:
104
+ """
105
+ A mutable container for BatchOps that must succeed or fail together.
106
+ """
107
+
108
+ ops: List[BatchOp] = field(default_factory=list)
109
+
110
+ def add(self, op: Union[BatchOp, List[BatchOp]]) -> "AtomicUnit":
111
+ """
112
+ Add an existing BatchOp (or list of BatchOps) to the unit.
113
+ """
114
+ if isinstance(op, list):
115
+ self.ops.extend(op)
116
+ else:
117
+ self.ops.append(op)
118
+ return self
119
+
120
+ def add_insert(self, obj: ModelType) -> "AtomicUnit":
121
+ """Helper to create and add an Insert op."""
122
+ return self.add(BatchOp.insert(obj))
123
+
124
+ def add_upsert(self, obj: ModelType, index_elements: List[str]) -> "AtomicUnit":
125
+ """Helper to create and add an Upsert op."""
126
+ return self.add(BatchOp.upsert(obj, index_elements))
127
+
128
+ def __len__(self) -> int:
129
+ return len(self.ops)
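# Illustrative sketch, not part of the released module: typical construction of an
# atomic unit from the helpers above. `Listing` and `Price` are hypothetical declarative
# models; the upsert assumes `url` is a unique column on Listing.
def build_unit(raw: dict) -> AtomicUnit:
    unit = AtomicUnit()
    unit.add_upsert(Listing(url=raw["url"], title=raw["title"]), index_elements=["url"])
    unit.add_insert(Price(listing_url=raw["url"], amount=raw["amount"]))
    return unit  # both rows are flushed together or discarded together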
130
+
131
+
132
+ @dataclass
133
+ class _BatchStats:
134
+ """Helper to track pipeline statistics explicitly."""
135
+
136
+ size: int = 0 # Total operations
137
+ inserts: int = 0 # Pure Inserts (OpAction.INSERT)
138
+
139
+ # Upsert Breakdown
140
+ upserts: int = 0 # Total Upsert Intent (OpAction.UPSERT)
141
+ upsert_inserts: int = 0 # Outcome: Row created
142
+ upsert_updates: int = 0 # Outcome: Row modified
143
+
144
+ duplicates: int = 0 # Skipped ops (rowcount == 0)
145
+ failures: int = 0 # Failed/Skipped atomic units (exception thrown)
146
+ db_time: float = 0.0 # Time spent in DB flush
147
+
148
+ def add(self, other: "_BatchStats") -> None:
149
+ """Accumulate stats from another batch."""
150
+ self.size += other.size
151
+ self.inserts += other.inserts
152
+ self.upserts += other.upserts
153
+ self.upsert_inserts += other.upsert_inserts
154
+ self.upsert_updates += other.upsert_updates
155
+ self.duplicates += other.duplicates
156
+ self.failures += other.failures
157
+ self.db_time += other.db_time
158
+
159
+
160
+ # Input: ModelType from a RowIterator.
161
+ # Output: AtomicUnit (or list of them).
162
+ ProcessorFunc = Callable[[ModelType], Union[AtomicUnit, Sequence[AtomicUnit], None]]
163
+
164
+ # Below this size, batch processing is likely inefficient and considered suspicious.
165
+ MIN_BATCH_WARNING = 100
166
+
167
+ # Above this size, a batch is considered too large and will trigger a warning.
168
+ MAX_BATCH_WARNING = 5_000
169
+
170
+ # Safety limit: max operations allowed for a single input atomic item.
171
+ MAX_OPS_PER_ATOM = 50
172
+
173
+ MAX_FAILURE_BURST = 50
174
+
175
+ # Session misuse guardrails (long-lived session detector).
176
+ SESSION_WARN_THRESHOLD_SECONDS = 300.0
177
+ SESSION_WARN_REPEAT_SECONDS = 60.0
178
+
179
+
180
+ class CommitStrategy(Enum):
181
+ """Commit behavior control for windowed_query (legacy API)."""
182
+
183
+ COMMIT_ON_SUCCESS = auto()
184
+ FORCE_COMMIT = auto()
185
+
186
+
187
+ class MockContext:
188
+ """
189
+ Lightweight mock context passed to SQLAlchemy default/onupdate callables
190
+ that expect a context-like object.
191
+ """
192
+
193
+ def __init__(self, column: SQLColumnExpression) -> None:
194
+ self.current_parameters = {}
195
+ self.current_column = column
196
+ self.connection = None
197
+
198
+
199
+ ########################################################################################################################
200
+ # EXCEPTIONS
201
+
202
+
203
+ class MissingOptionalDependency(ImportError):
204
+ pass
205
+
206
+
207
+ def _require_pandas() -> Any:
208
+ if pd is None:
209
+ raise MissingOptionalDependency(
210
+ "This feature requires pandas. Add the extra dependency. Example: pip install datamarket[pandas]"
211
+ )
212
+ return pd
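# Illustrative sketch, not part of the released module: the guard above keeps pandas
# optional. Call _require_pandas() inside any helper that needs it and use the returned
# module; `rows_to_dataframe` is a hypothetical example, not an API of this package.
def rows_to_dataframe(records: List[Dict[str, Any]]) -> "pd.DataFrame":
    pd = _require_pandas()  # raises MissingOptionalDependency when pandas is absent
    return pd.DataFrame(records)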
213
+
214
+
215
+ ########################################################################################################################
216
+ # ITERATORS
217
+
218
+
219
+ # TODO: enable iterating over joins of tables (multiple entities)
220
+ class RowIterator(Iterator[ModelType]):
221
+ """
222
+ Stateful iterator that performs Composite Keyset Pagination.
223
+
224
+ It dynamically builds complex SQL `WHERE` predicates (nested OR/AND logic)
225
+ to support efficient pagination across multiple columns with mixed sort directions
226
+ (e.g., [Date DESC, Category ASC, ID ASC]).
227
+
228
+ Includes Auto-Rerun logic: If new rows appear "in the past" (behind the cursor)
229
+ during iteration, it detects them at the end and performs a targeted rerun using
230
+ a "Chained Snapshot" strategy to ensure data consistency.
231
+
232
+ Manages the session lifecycle internally (Open -> Fetch -> Expunge -> Close).
233
+ """
234
+
235
+ def __init__(
236
+ self,
237
+ interface: "_BatchPipelineOps",
238
+ source: Union[Select[Any], Query],
239
+ chunk_size: int,
240
+ order_by: Optional[List[SQLColumnExpression[Any]]] = None,
241
+ limit: Optional[int] = None,
242
+ ):
243
+ self.interface = interface
244
+ self.chunk_size = chunk_size
245
+ self.limit = limit
246
+ self.yielded_count = 0
247
+
248
+ if not isinstance(source, (Select, Query)):
249
+ raise ValueError("RowIterator expects a SQLAlchemy Select or Query object.")
250
+
251
+ # Normalize Source
252
+ self.stmt = source.statement if isinstance(source, Query) and hasattr(source, "statement") else source
253
+
254
+ # Identify Primary Key(s) (Tie-Breaker)
255
+ desc = self.stmt.column_descriptions[0]
256
+ entity = desc["entity"]
257
+ primary_keys = list(inspect(entity).primary_key)
258
+
259
+ if not primary_keys:
260
+ raise ValueError(f"RowIterator: Model {entity} has no primary key.")
261
+
262
+ self.pk_cols = primary_keys # List of PK columns
263
+
264
+ # Construct composite sort key
265
+ self.sort_cols = []
266
+ user_sort_col_names = set()
267
+ if order_by:
268
+ self.sort_cols.extend(order_by)
269
+ user_sort_col_names = {self._get_col_name(c) for c in order_by}
270
+
271
+ # Ensure all PK columns are present in sort_cols (preserving user direction if present)
272
+ for pk_col in self.pk_cols:
273
+ pk_name = self._get_col_name(pk_col)
274
+ if pk_name not in user_sort_col_names:
275
+ self.sort_cols.append(pk_col)
276
+
277
+ # State Management
278
+ self.last_vals: Optional[tuple] = None
279
+ self._buffer: Deque[ModelType] = deque()
280
+ self._done = False
281
+
282
+ # Snapshot Windows (Chained Snapshots)
283
+ self.id_floor = None
284
+ # Use first PK for snapshot windowing
285
+ self.id_col = self.pk_cols[0]
286
+ with self.interface:
287
+ self.id_ceiling = self._get_max_id_at_start()
288
+ self._catch_up_mode = False
289
+
290
+ # Calculate total rows
291
+ self.total_rows = self._initial_count(source)
292
+
293
+ def _get_max_id_at_start(self) -> Optional[Any]:
294
+ """Captures the current max ID within the query's scope."""
295
+ try:
296
+ # We wrap the original statement in a subquery to respect filters like is_crawled=False.
297
+ sub = self.stmt.subquery()
298
+ stmt = select(func.max(sub.c[self.id_col.name]))
299
+ return self.interface.session.execute(stmt).scalar()
300
+ except Exception as e:
301
+ logger.warning(f"RowIterator: Could not capture snapshot ID: {e}")
302
+ return None
303
+
304
+ def _initial_count(self, source: Union[Select[Any], Query]) -> int:
305
+ """Performs initial count of the dataset."""
306
+ try:
307
+ with self.interface:
308
+ if isinstance(source, Query):
309
+ db_count = source.with_session(self.interface.session).count()
310
+ else:
311
+ subquery = self.stmt.subquery()
312
+ count_stmt = select(func.count()).select_from(subquery)
313
+ db_count = self.interface.session.execute(count_stmt).scalar() or 0
314
+
315
+ return min(db_count, self.limit) if self.limit else db_count
316
+ except Exception as e:
317
+ logger.warning(f"RowIterator: Failed to calculate total rows: {e}")
318
+ return 0
319
+
320
+ def __len__(self) -> int:
321
+ return self.total_rows
322
+
323
+ def __iter__(self) -> "RowIterator":
324
+ return self
325
+
326
+ def __next__(self) -> ModelType:
327
+ # Check global limit
328
+ if self.limit is not None and self.yielded_count >= self.limit:
329
+ raise StopIteration
330
+
331
+ if not self._buffer:
332
+ if self._done:
333
+ raise StopIteration
334
+
335
+ self._fetch_next_chunk()
336
+
337
+ if not self._buffer:
338
+ self._done = True
339
+ raise StopIteration
340
+
341
+ self.yielded_count += 1
342
+ return self._buffer.popleft()
343
+
344
+ def _get_col_name(self, col_expr: Any) -> str:
345
+ """
346
+ Helper to safely get column names from any SQL expression.
347
+ Handles: ORM Columns, Core Columns, Labels, Aliases, Unary (desc).
348
+ """
349
+ element = col_expr
350
+
351
+ # Unwrap Sort Modifiers (DESC, NULLS LAST)
352
+ while isinstance(element, UnaryExpression):
353
+ element = element.element
354
+
355
+ # Handle ORM Attributes (User.id)
356
+ if hasattr(element, "key"):
357
+ return element.key
358
+
359
+ # Handle Labels/Core Columns (table.c.id)
360
+ if hasattr(element, "name"):
361
+ return element.name
362
+
363
+ # Fallback for anonymous expressions
364
+ if hasattr(element, "compile"):
365
+ return str(element.compile(compile_kwargs={"literal_binds": True}))
366
+
367
+ raise ValueError(f"Could not determine name for sort column: {col_expr}")
368
+
369
+ def _fetch_next_chunk(self) -> None:
370
+ """Session-managed fetch logic with auto-rerun support."""
371
+ self.interface.start()
372
+ try:
373
+ # Use a loop to handle the "Rerun" trigger immediately
374
+ while not self._buffer:
375
+ fetch_size = self._calculate_fetch_size()
376
+ if fetch_size <= 0:
377
+ break
378
+
379
+ paged_stmt = self._build_query(fetch_size)
380
+ result = self.interface.session.execute(paged_stmt)
381
+ rows = result.all()
382
+
383
+ if rows:
384
+ self._process_rows(rows)
385
+ break # We have data!
386
+
387
+ # No rows found in current window. Check for new rows (shifts window).
388
+ if not self._check_for_new_rows():
389
+ break # Truly done.
390
+
391
+ # If _check_for_new_rows was True, the loop continues
392
+ # and runs _build_query again with the NEW window.
393
+
394
+ finally:
395
+ self.interface.stop(commit=False)
396
+
397
+ def _calculate_fetch_size(self) -> int:
398
+ if self.limit is None:
399
+ return self.chunk_size
400
+ remaining = self.limit - self.yielded_count
401
+ return min(self.chunk_size, remaining)
402
+
403
+ def _build_query(self, fetch_size: int) -> Select[Any]:
404
+ """Applies snapshot windows and keyset pagination filters."""
405
+ paged_stmt = self.stmt
406
+
407
+ if self.id_ceiling is not None:
408
+ paged_stmt = paged_stmt.where(self.id_col <= self.id_ceiling)
409
+
410
+ if self.id_floor is not None:
411
+ paged_stmt = paged_stmt.where(self.id_col > self.id_floor)
412
+
413
+ if self.last_vals is not None:
414
+ paged_stmt = paged_stmt.where(self._build_keyset_predicate(self.sort_cols, self.last_vals))
415
+
416
+ return paged_stmt.order_by(*self.sort_cols).limit(fetch_size)
417
+
418
+ def _process_rows(self, rows: List[Any]) -> None:
419
+ """Expunges objects and updates the keyset bookmark."""
420
+ for row in rows:
421
+ obj = row[0] if isinstance(row, tuple) or hasattr(row, "_mapping") else row
422
+ self.interface.session.expunge(obj)
423
+ self._buffer.append(obj)
424
+
425
+ # Update pagination bookmark from object
426
+ last_vals_list = []
427
+ for col in self.sort_cols:
428
+ col_name = self._get_col_name(col)
429
+ last_vals_list.append(getattr(obj, col_name))
430
+ self.last_vals = tuple(last_vals_list)
431
+
432
+ def _check_for_new_rows(self) -> bool:
433
+ """Checks for new data, shifts the snapshot window, and triggers rerun."""
434
+ new_ceiling = self._get_max_id_at_start()
435
+
436
+ if new_ceiling is None or (self.id_ceiling is not None and new_ceiling <= self.id_ceiling):
437
+ return False
438
+
439
+ try:
440
+ check_stmt = self.stmt.where(self.id_col > self.id_ceiling)
441
+ check_stmt = check_stmt.where(self.id_col <= new_ceiling)
442
+
443
+ sub = check_stmt.subquery()
444
+ count_query = select(func.count()).select_from(sub)
445
+
446
+ new_rows_count = self.interface.session.execute(count_query).scalar() or 0
447
+
448
+ if new_rows_count > 0:
449
+ logger.info(f"RowIterator: Found {new_rows_count} new rows. Shifting window.")
450
+
451
+ if self.limit is not None:
452
+ self.total_rows = min(self.total_rows + new_rows_count, self.limit)
453
+ else:
454
+ self.total_rows += new_rows_count
455
+
456
+ self.id_floor = self.id_ceiling
457
+ self.id_ceiling = new_ceiling
458
+ self.last_vals = None
459
+ self._catch_up_mode = True
460
+ return True
461
+ except Exception as e:
462
+ logger.warning(f"RowIterator: Failed to check for new rows: {e}")
463
+
464
+ return False
465
+
466
+ def _build_keyset_predicate(self, columns: List[Any], last_values: tuple) -> Any:
467
+ """
468
+ Constructs the recursive OR/AND SQL filter for mixed-direction keyset pagination.
469
+
470
+ Logic for columns (A, B, C) compared to values (va, vb, vc):
471
+ 1. (A > va)
472
+ 2. OR (A = va AND B > vb)
473
+ 3. OR (A = va AND B = vb AND C > vc)
474
+
475
+ *Swaps > for < if the column is Descending.
476
+ """
477
+
478
+ conditions = []
479
+
480
+ # We need to build the "Equality Chain" (A=va AND B=vb ...)
481
+ # that acts as the prefix for the next column's check.
482
+ equality_chain = []
483
+
484
+ for i, col_expr in enumerate(columns):
485
+ last_val = last_values[i]
486
+
487
+ # 1. INSPECT DIRECTION (ASC vs DESC)
488
+ is_desc = False
489
+ actual_col = col_expr
490
+
491
+ if isinstance(col_expr, UnaryExpression):
492
+ actual_col = col_expr.element
493
+ # Robust check using SQLAlchemy operator identity
494
+ if col_expr.modifier == desc_op:
495
+ is_desc = True
496
+
497
+ # 2. DETERMINE OPERATOR
498
+ # ASC: col > val | DESC: col < val
499
+ diff_check = actual_col < last_val if is_desc else actual_col > last_val
500
+
501
+ # 3. BUILD THE TERM
502
+ # (Previous Cols Equal) AND (Current Col is Better)
503
+ term = diff_check if not equality_chain else and_(*equality_chain, diff_check)
504
+
505
+ conditions.append(term)
506
+
507
+ # 4. EXTEND EQUALITY CHAIN FOR NEXT LOOP
508
+ # Add (Current Col == Last Val) to the chain
509
+ equality_chain.append(actual_col == last_val)
510
+
511
+ # Combine all terms with OR
512
+ return or_(*conditions)
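# Worked example, not part of the released module: for sort columns
# (created_at DESC, id ASC) and a bookmark of (t0, 42), the method above produces a
# predicate equivalent to:
#     (created_at < t0)
#     OR (created_at = t0 AND id > 42)
# i.e. the comparison flips to "<" on DESC columns, and each later column is only
# compared when every earlier column ties with the bookmark value.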
513
+
514
+
515
+ class AtomIterator(Iterator[AtomicUnit]):
516
+ """
517
+ Iterator wrapper that strictly yields AtomicUnits.
518
+ Used for directly inserting into DB without querying any of its tables first.
519
+ Ensures type safety of the source stream.
520
+ """
521
+
522
+ def __init__(self, source: Iterator[Any], limit: Optional[int] = None):
523
+ self._source = source
524
+ self.limit = limit
525
+ self.yielded_count = 0
526
+
527
+ def __iter__(self) -> "AtomIterator":
528
+ return self
529
+
530
+ def __next__(self) -> AtomicUnit:
531
+ # Check global limit
532
+ if self.limit is not None and self.yielded_count >= self.limit:
533
+ raise StopIteration
534
+
535
+ item = next(self._source)
536
+ if not isinstance(item, AtomicUnit):
537
+ raise ValueError(f"AtomIterator expected AtomicUnit, got {type(item)}. Check your generator.")
538
+
539
+ self.yielded_count += 1
540
+ return item
541
+
542
+
543
+ ########################################################################################################################
544
+ # CLASSES
545
+
546
+
547
+ class _BaseAlchemyCore:
548
+ """
549
+ Core SQLAlchemy infrastructure:
550
+
551
+ - Engine and session factory.
552
+ - Manual session lifecycle (start/stop, context manager).
553
+ - Long-lived session detection (warnings when a session stays open too long).
554
+ - Basic DDL helpers (create/drop/reset tables).
555
+ - Utilities such as reset_column and integrity error logging.
556
+
557
+ This class is meant to be combined with mixins that add higher-level behavior.
558
+ """
559
+
560
+ def __init__(self, config: MutableMapping) -> None:
561
+ """
562
+ Initialize the core interface from a configuration mapping.
563
+
564
+ Expected config format:
565
+ {
566
+ "db": {
567
+ "engine": "postgresql+psycopg2",
568
+ "user": "...",
569
+ "password": "...",
570
+ "host": "...",
571
+ "port": 5432,
572
+ "database": "..."
573
+ }
574
+ }
575
+ """
576
+ self.session: Optional[Session] = None
577
+ self._session_started_at: Optional[float] = None
578
+ self._last_session_warning_at: Optional[float] = None
16
579
 
17
- class AlchemyInterface:
18
- def __init__(self, config):
19
580
  if "db" in config:
20
581
  self.config = config["db"]
21
-
22
582
  self.engine = create_engine(self.get_conn_str())
23
- self.session = sessionmaker(bind=self.engine)()
24
- self.cursor = self.session.connection().connection.cursor()
25
-
583
+ self.Session = sessionmaker(bind=self.engine)
26
584
  else:
27
585
  logger.warning("no db section in config")
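# Usage sketch, not part of the released module. It assumes the public AlchemyInterface
# class composes this core with the mixins defined below, as the docstrings in this
# module suggest; `Listing` is a hypothetical mapped model.
config = {
    "db": {
        "engine": "postgresql+psycopg2",
        "user": "etl",
        "password": "s3cr3t pass",  # special characters are escaped by get_conn_str()
        "host": "localhost",
        "port": 5432,
        "database": "warehouse",
    }
}
db = AlchemyInterface(config)
with db:  # start() on enter; commit and close on clean exit, rollback on exception
    db.session.add(Listing(url="https://example.com/1"))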
28
586
 
29
- def get_conn_str(self):
587
+ def __enter__(self) -> "_BaseAlchemyCore":
588
+ """Enter the runtime context related to this object (starts session)."""
589
+ self.start()
590
+ return self
591
+
592
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
593
+ """
594
+ Exit the runtime context related to this object.
595
+
596
+ - Commits if no exception was raised inside the context.
597
+ - Rolls back otherwise.
598
+ - Always closes the session.
599
+ """
600
+ should_commit = exc_type is None
601
+ self.stop(commit=should_commit)
602
+
603
+ def start(self) -> None:
604
+ """
605
+ Start a new SQLAlchemy session manually.
606
+
607
+ This is intended for short-lived units of work. Long-lived sessions will
608
+ trigger warnings via _check_session_duration().
609
+ """
610
+ if not hasattr(self, "Session"):
611
+ raise AttributeError("Database configuration not initialized. Cannot create session.")
612
+ if self.session is not None:
613
+ raise RuntimeError("Session already active.")
614
+ self.session = self.Session()
615
+ now = time.monotonic()
616
+ self._session_started_at = now
617
+ self._last_session_warning_at = None
618
+ logger.debug("SQLAlchemy session started manually.")
619
+
620
+ def stop(self, commit: bool = True) -> None:
621
+ """
622
+ Stop the current SQLAlchemy session.
623
+
624
+ Args:
625
+ commit: If True, attempt to commit before closing. Otherwise rollback.
626
+ """
627
+ if self.session is None:
628
+ logger.warning("No active session to stop.")
629
+ return
630
+
631
+ try:
632
+ if commit:
633
+ logger.debug("Committing SQLAlchemy session before stopping.")
634
+ self.session.commit()
635
+ else:
636
+ logger.debug("Rolling back SQLAlchemy session before stopping.")
637
+ self.session.rollback()
638
+ except Exception as e:
639
+ logger.error(f"Exception during session commit/rollback on stop: {e}", exc_info=True)
640
+ try:
641
+ self.session.rollback()
642
+ except Exception as rb_exc:
643
+ logger.error(f"Exception during secondary rollback attempt on stop: {rb_exc}", exc_info=True)
644
+ raise
645
+ finally:
646
+ logger.debug("Closing SQLAlchemy session.")
647
+ self.session.close()
648
+ self.session = None
649
+ self._session_started_at = None
650
+ self._last_session_warning_at = None
651
+
652
+ def _check_session_duration(self) -> None:
653
+ """
654
+ Emit warnings if a manually managed session has been open for too long.
655
+
656
+ This is meant to detect misuse patterns such as:
657
+ with AlchemyInterface(...) as db:
658
+ # long-running loop / scraper / ETL here
659
+ """
660
+ if self.session is None or self._session_started_at is None:
661
+ return
662
+
663
+ now = time.monotonic()
664
+ elapsed = now - self._session_started_at
665
+
666
+ if elapsed < SESSION_WARN_THRESHOLD_SECONDS:
667
+ return
668
+
669
+ if (
670
+ self._last_session_warning_at is None
671
+ or (now - self._last_session_warning_at) >= SESSION_WARN_REPEAT_SECONDS
672
+ ):
673
+ logger.warning(
674
+ "SQLAlchemy session has been open for %.1f seconds. "
675
+ "This is likely a misuse of AlchemyInterface (long-lived session). "
676
+ "Prefer short-lived sessions or the batch pipeline API.",
677
+ elapsed,
678
+ )
679
+ self._last_session_warning_at = now
680
+
681
+ def get_conn_str(self) -> str:
682
+ """
683
+ Build the SQLAlchemy connection string from the loaded configuration.
684
+ """
30
685
  return (
31
- f'{self.config["engine"]}://'
32
- f'{self.config["user"]}:{quote_plus(self.config["password"])}'
33
- f'@{self.config["host"]}:{self.config["port"]}'
34
- f'/{self.config["database"]}'
686
+ f"{self.config['engine']}://"
687
+ f"{self.config['user']}:{quote_plus(self.config['password'])}"
688
+ f"@{self.config['host']}:{self.config['port']}"
689
+ f"/{self.config['database']}"
35
690
  )
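# Illustration, not part of the released module: with engine "postgresql+psycopg2",
# user "etl", password "s3cr3t pass", host "localhost", port 5432 and database
# "warehouse", get_conn_str() returns
#     postgresql+psycopg2://etl:s3cr3t+pass@localhost:5432/warehouse
# quote_plus() is what keeps spaces and symbols in the password from breaking the URL.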
36
691
 
37
692
  @staticmethod
38
- def get_schema_from_table(table):
693
+ def get_schema_from_table(table: Type[ModelType]) -> str:
694
+ """
695
+ Infer schema name from a SQLAlchemy model class.
696
+
697
+ - Defaults to 'public'.
698
+ - Warns if no explicit schema is provided.
699
+ """
39
700
  schema = "public"
40
701
 
41
702
  if isinstance(table.__table_args__, tuple):
@@ -51,7 +712,12 @@ class AlchemyInterface:
51
712
 
52
713
  return schema
53
714
 
54
- def create_tables(self, tables):
715
+ def create_tables(self, tables: List[Type[ModelType]]) -> None:
716
+ """
717
+ Create schemas and tables (or views) if they do not already exist.
718
+
719
+ For views, it calls a custom `create_view(conn)` on the model if needed.
720
+ """
55
721
  for table in tables:
56
722
  schema = self.get_schema_from_table(table)
57
723
 
@@ -74,7 +740,12 @@ class AlchemyInterface:
74
740
  else:
75
741
  logger.info(f"table {table.__tablename__} already exists")
76
742
 
77
- def drop_tables(self, tables):
743
+ def drop_tables(self, tables: List[Type[ModelType]]) -> None:
744
+ """
745
+ Drop the given tables or views if they exist.
746
+
747
+ Uses CASCADE to also drop dependent objects.
748
+ """
78
749
  for table in tables:
79
750
  schema = self.get_schema_from_table(table)
80
751
 
@@ -90,22 +761,1260 @@ class AlchemyInterface:
90
761
  conn.execute(DDL(f"DROP TABLE {schema}.{table.__tablename__} CASCADE"))
91
762
  conn.commit()
92
763
 
93
- def reset_db(self, tables, drop):
764
+ def reset_db(self, tables: List[Type[ModelType]], drop: bool = False) -> None:
765
+ """
766
+ Reset the database objects for a list of models.
767
+
768
+ Args:
769
+ tables: List of model classes.
770
+ drop: If True, drop tables/views before recreating them.
771
+ """
94
772
  if drop:
95
773
  self.drop_tables(tables)
96
774
 
97
775
  self.create_tables(tables)
98
776
 
99
- def insert_alchemy_obj(self, alchemy_obj, silent=False):
777
+ def reset_column(self, query_results: List[Result[Any]], column_name: str) -> None:
778
+ """
779
+ Reset a column to its default value for a list of query results.
780
+
781
+ The defaults may come from:
782
+ - server_default (uses SQL DEFAULT),
783
+ - column.default (Python callable or constant).
784
+
785
+ Args:
786
+ query_results: List of ORM instances to update.
787
+ column_name: Name of the column to reset.
788
+ """
789
+ if self.session is None:
790
+ raise RuntimeError("Session not active. Use 'with AlchemyInterface(...):' or call start()")
791
+
792
+ self._check_session_duration()
793
+
794
+ if not query_results:
795
+ logger.warning("No objects to reset column for.")
796
+ return
797
+
798
+ first_obj = query_results[0]
799
+ model_class = first_obj.__class__
800
+ table = model_class.__table__
801
+
802
+ if column_name not in table.columns:
803
+ logger.warning(f"Column {column_name} does not exist in table {table.name}.")
804
+ return
805
+
806
+ column = table.columns[column_name]
807
+
808
+ if column.server_default is not None:
809
+ # Use SQL DEFAULT so the server decides the final value.
810
+ default_value = text("DEFAULT")
811
+ elif column.default is not None:
812
+ default_value = column.default.arg
813
+ if callable(default_value):
814
+ # Some column defaults expect a context with metadata.
815
+ default_value = default_value(MockContext(column))
816
+ else:
817
+ raise ValueError(f"Column '{column_name}' doesn't have a default value defined.")
818
+
819
+ query_results.update({column_name: default_value}, synchronize_session=False)
820
+
821
+ @staticmethod
822
+ def _log_integrity_error(ex: IntegrityError, alchemy_obj: Any, action: str = "insert") -> None:
823
+ """
824
+ Log PostgreSQL IntegrityError in a compact, human-friendly way using SQLSTATE codes.
825
+
826
+ For code meanings, see:
827
+ https://www.postgresql.org/docs/current/errcodes-appendix.html
828
+ """
829
+
830
+ PG_ERROR_LABELS = {
831
+ "23000": "Integrity constraint violation",
832
+ "23001": "Restrict violation",
833
+ "23502": "NOT NULL violation",
834
+ "23503": "Foreign key violation",
835
+ "23505": "Unique violation",
836
+ "23514": "Check constraint violation",
837
+ "23P01": "Exclusion constraint violation",
838
+ }
839
+ code = getattr(ex.orig, "pgcode", None)
840
+ label = PG_ERROR_LABELS.get(code, "Integrity error (unspecified)")
841
+
842
+ if code == "23505":
843
+ logger.info(f"{label} trying to {action} {alchemy_obj}")
844
+ else:
845
+ logger.error(f"{label} trying to {action} {alchemy_obj}\nPostgreSQL message: {ex.orig}")
846
+
847
+ def log_row(
848
+ self, obj: Any, columns: Optional[List[str]] = None, prefix_msg: str = "Row", full: bool = False
849
+ ) -> None:
850
+ """
851
+ Logs attributes of an SQLAlchemy object in a standardized format.
852
+
853
+ Args:
854
+ obj: The SQLAlchemy model instance.
855
+ columns: List of attribute names. If None, uses currently set attributes (obj.__dict__).
856
+ prefix_msg: Descriptive tag for the log (e.g. "Crawled", "Parsed", "To be inserted"). Defaults to "Row".
857
+ full: If True, disables truncation of long values (default limit is 500 chars).
858
+ """
859
+ try:
860
+ # If no specific columns requested, grab only what is currently set on the object
861
+ if not columns:
862
+ # Filter out SQLAlchemy internals (keys starting with '_')
863
+ columns = [k for k in obj.__dict__ if not k.startswith("_")]
864
+ if not columns:
865
+ logger.info(f"{prefix_msg}: {obj}")
866
+ return
867
+
868
+ stats_parts = []
869
+ for col in columns:
870
+ # getattr is safe; fallback to 'N/A' if the key is missing
871
+ val = getattr(obj, col, "N/A")
872
+ val_str = str(val)
873
+
874
+ # Truncation logic (Default limit 500 chars)
875
+ if not full and len(val_str) > 500:
876
+ val_str = val_str[:500] + "...(truncated)"
877
+
878
+ stats_parts.append(f"{col}={val_str}")
879
+
880
+ stats_msg = ", ".join(stats_parts)
881
+ # Result: "Crawled: Category | url=http://... , zip_code=28001"
882
+ logger.info(f"{prefix_msg}: {obj.__class__.__name__} | {stats_msg}")
883
+
884
+ except Exception as e:
885
+ # Fallback to standard repr if anything breaks, but log the error
886
+ logger.error(f"Failed to generate detailed log for {obj}", exc_info=e)
887
+ logger.info(f"{prefix_msg}: {obj}")
888
+
889
+
890
+ class _LegacyAlchemyOps:
891
+ """
892
+ Mixin containing legacy CRUD helpers and range-query utilities.
893
+
894
+ These methods remain for backwards compatibility but should not be used in
895
+ new code. Prefer the batch pipeline and iter_query_safe for new flows.
896
+ """
897
+
898
+ def insert_alchemy_obj(self, alchemy_obj: ModelType, silent: bool = False) -> bool:
899
+ """
900
+ Legacy insert helper using per-object savepoints.
901
+
902
+ - Uses a nested transaction per insert (savepoint).
903
+ - Swallows IntegrityError and returns False if it occurs.
904
+
905
+ Prefer using lightweight INSERT via the batch pipeline.
906
+ """
907
+ logger.warning(
908
+ "DEPRECATED: insert_alchemy_obj is legacy API. "
909
+ "Prefer using the batch pipeline (process_batch) with lightweight inserts."
910
+ )
911
+
912
+ if self.session is None:
913
+ raise RuntimeError("Session not active. Use 'with AlchemyInterface(...):' or call start()")
914
+
915
+ self._check_session_duration()
916
+
100
917
  try:
918
+ # Use a savepoint (nested transaction)
919
+ with self.session.begin_nested():
920
+ if not silent:
921
+ logger.info(f"adding {alchemy_obj}...")
922
+ self.session.add(alchemy_obj)
923
+ except IntegrityError as ex:
924
+ # Rollback is handled automatically by begin_nested() context manager on error
101
925
  if not silent:
102
- logger.info(f"adding {alchemy_obj}...")
926
+ self._log_integrity_error(ex, alchemy_obj, action="insert")
927
+ # Do not re-raise, allow outer transaction/loop to continue
928
+ return False
929
+
930
+ return True
103
931
 
104
- self.session.add(alchemy_obj)
105
- self.session.commit()
932
+ def upsert_alchemy_obj(self, alchemy_obj: ModelType, index_elements: List[str], silent: bool = False) -> bool:
933
+ """
934
+ Legacy upsert helper using per-object savepoints and ON CONFLICT DO UPDATE.
935
+
936
+ - Builds insert/update dicts from ORM object attributes.
937
+ - Uses a savepoint per upsert.
938
+ - Swallows IntegrityError and returns False if it occurs.
939
+
940
+ Prefer using lightweight upsert via the batch pipeline.
941
+ """
942
+ logger.warning(
943
+ "DEPRECATED: upsert_alchemy_obj is legacy API. "
944
+ "Prefer using the batch pipeline (process_batch) with lightweight upserts."
945
+ )
946
+
947
+ if self.session is None:
948
+ raise RuntimeError("Session not active. Use 'with AlchemyInterface(...):' or call start()")
949
+
950
+ self._check_session_duration()
951
+
952
+ if not silent:
953
+ logger.info(f"upserting {alchemy_obj}")
954
+
955
+ table = alchemy_obj.__table__
956
+ primary_keys = list(col.name for col in table.primary_key.columns.values())
957
+
958
+ # Build the dictionary for the INSERT values
959
+ insert_values = {
960
+ col.name: getattr(alchemy_obj, col.name)
961
+ for col in table.columns
962
+ if getattr(alchemy_obj, col.name) is not None # Include all non-None values for insert
963
+ }
964
+
965
+ # Build the dictionary for the UPDATE set clause
966
+ # Start with values from the object, excluding primary keys
967
+ update_set_values = {
968
+ col.name: val
969
+ for col in table.columns
970
+ if col.name not in primary_keys and (val := getattr(alchemy_obj, col.name)) is not None
971
+ }
972
+
973
+ # Add columns with SQL-based onupdate values explicitly to the set clause
974
+ for column in table.columns:
975
+ actual_sql_expression = None
976
+ if column.onupdate is not None:
977
+ if hasattr(column.onupdate, "arg") and isinstance(column.onupdate.arg, ClauseElement):
978
+ # This handles wrappers like ColumnElementColumnDefault,
979
+ # where the actual SQL expression is in the .arg attribute.
980
+ actual_sql_expression = column.onupdate.arg
981
+ elif isinstance(column.onupdate, ClauseElement):
982
+ # This handles cases where onupdate might be a direct SQL expression.
983
+ actual_sql_expression = column.onupdate
984
+
985
+ if actual_sql_expression is not None:
986
+ update_set_values[column.name] = actual_sql_expression
987
+
988
+ statement = (
989
+ insert(table)
990
+ .values(insert_values)
991
+ .on_conflict_do_update(index_elements=index_elements, set_=update_set_values)
992
+ )
106
993
 
107
- except IntegrityError:
994
+ try:
995
+ # Use a savepoint (nested transaction)
996
+ with self.session.begin_nested():
997
+ self.session.execute(statement)
998
+ except IntegrityError as ex:
999
+ # Rollback is handled automatically by begin_nested() context manager on error
108
1000
  if not silent:
109
- logger.info(f"{alchemy_obj} already in db")
1001
+ self._log_integrity_error(ex, alchemy_obj, action="upsert")
1002
+ # Do not re-raise, allow outer transaction/loop to continue
1003
+ return False
1004
+
1005
+ return True
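# Illustration, not part of the released module: for a model with a unique `url`
# column, the statement built above compiles to roughly
#     INSERT INTO listings (url, title) VALUES (:url, :title)
#     ON CONFLICT (url) DO UPDATE SET title = :title
# with any SQL-based `onupdate` columns (e.g. an updated_at timestamp) forced into the
# SET clause; `listings` is a hypothetical table name.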
1006
+
1007
+ def windowed_query(
1008
+ self,
1009
+ stmt: Select[Any],
1010
+ order_by: List[SQLColumnExpression[Any]],
1011
+ windowsize: int,
1012
+ commit_strategy: Union[CommitStrategy, str] = CommitStrategy.COMMIT_ON_SUCCESS,
1013
+ ) -> Iterator[Result[Any]]:
1014
+ """
1015
+ Legacy windowed query helper (range query).
1016
+
1017
+ It executes the given SELECT statement in windows of size `windowsize`,
1018
+ each in its own short-lived session, and yields `Result` objects.
1019
+
1020
+ Prefer `get_row_iterator` for new range-query implementations.
1021
+ """
1022
+ logger.warning("DEPRECATED: windowed_query is legacy API. Prefer using get_row_iterator for range queries.")
1023
+
1024
+ # Parameter mapping
1025
+ if isinstance(commit_strategy, str):
1026
+ commit_strategy = CommitStrategy[commit_strategy.upper()]
1027
+
1028
+ # Find id column in stmt
1029
+ if not any(column.get("entity").id for column in stmt.column_descriptions):
1030
+ raise Exception("Column 'id' not found in any entity of the query.")
1031
+ id_column = stmt.column_descriptions[0]["entity"].id
1032
+
1033
+ last_id = 0
1034
+ while True:
1035
+ session_active = False
1036
+ commit_needed = False
1037
+ try:
1038
+ self.start()
1039
+ session_active = True
1040
+
1041
+ # Filter on row_number in the outer query
1042
+ current_query = stmt.where(id_column > last_id).order_by(order_by[0], *order_by[1:]).limit(windowsize)
1043
+ result = self.session.execute(current_query)
1044
+
1045
+ # Create a FrozenResult to allow peeking at the data without consuming
1046
+ frozen_result: FrozenResult = result.freeze()
1047
+ chunk = frozen_result().all()
1048
+
1049
+ if not chunk:
1050
+ break
1051
+
1052
+ # Update for next iteration
1053
+ last_id = chunk[-1].id
1054
+
1055
+ # Create a new Result object from the FrozenResult
1056
+ yield_result = frozen_result()
1057
+
1058
+ yield yield_result
1059
+ commit_needed = True
1060
+
1061
+ finally:
1062
+ if session_active and self.session:
1063
+ # Double check before stopping just in case. The user may never call insert/upsert in the loop,
1064
+ # so a final check needs to be done.
1065
+ self._check_session_duration()
1066
+
1067
+ if commit_strategy == CommitStrategy.FORCE_COMMIT:
1068
+ # For forced commit, always attempt to commit.
1069
+ # The self.stop() method already handles potential exceptions during commit/rollback.
1070
+ self.stop(commit=True)
1071
+ elif commit_strategy == CommitStrategy.COMMIT_ON_SUCCESS:
1072
+ # Commit only if no exception occurred before yielding the result.
1073
+ self.stop(commit=commit_needed)
1074
+ else:
1075
+ # Fallback or error for unknown strategy, though type hinting should prevent this.
1076
+ # For safety, default to rollback.
1077
+ logger.warning(f"Unknown commit strategy: {commit_strategy}. Defaulting to rollback.")
1078
+ self.stop(commit=False)
1079
+
1080
+
1081
+ class _BatchPipelineOps:
1082
+ """
1083
+ Mixin providing:
1084
+ - Safe, paginated iteration over large queries (iter_query).
1085
+ - Validation wrapper for atomic streams (iter_atoms).
1086
+ - Unified 'process_batch' method for Data Sink and ETL workflows.
1087
+ """
1088
+
1089
+ def get_row_iterator(
1090
+ self,
1091
+ source: Union[Select[Any], Query],
1092
+ chunk_size: int,
1093
+ order_by: Optional[List[SQLColumnExpression[Any]]] = None,
1094
+ limit: Optional[int] = None,
1095
+ ) -> RowIterator:
1096
+ """
1097
+ Creates a RowIterator for safe, paginated iteration over large datasets.
1098
+
1099
+ Features:
1100
+ - **Composite Keyset Pagination:** Prevents data gaps/duplicates even with changing data.
1101
+ - **Mixed Sorting:** Supports arbitrary combinations of ASC and DESC columns.
1102
+ - **Automatic Safety:** Automatically appends the Primary Key to `order_by` to ensure total ordering.
1103
+
1104
+ Args:
1105
+ source: SQLAlchemy Select or Query object.
1106
+ chunk_size: Number of rows to fetch per transaction.
1107
+ order_by: List of columns to sort by. Can use `sqlalchemy.desc()`.
1108
+ Default is [Primary Key ASC].
1109
+ limit: Global limit on number of rows to yield.
1110
+
1111
+ Returns:
1112
+ RowIterator: An iterator yielding detached ORM objects.
1113
+ """
1114
+ return RowIterator(self, source, chunk_size, order_by, limit)
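# Usage sketch, not part of the released module. `Listing` is a hypothetical mapped
# model with an `is_crawled` flag; `interface` is an AlchemyInterface-style instance
# exposing this mixin.
stmt = select(Listing).where(Listing.is_crawled.is_(False))
for listing in interface.get_row_iterator(
    stmt,
    chunk_size=500,
    order_by=[Listing.created_at.desc()],  # the primary key is appended as a tie-breaker
    limit=10_000,
):
    print(listing.url)  # objects arrive detached; no session stays open between chunks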
1115
+
1116
+ def get_atom_iterator(
1117
+ self,
1118
+ source: Iterator[AtomicUnit],
1119
+ limit: Optional[int] = None,
1120
+ ) -> AtomIterator:
1121
+ """
1122
+ Wraps a generator to ensure it is treated as a stream of AtomicUnits.
1123
+
1124
+ Args:
1125
+ source: An iterator that yields AtomicUnit objects.
1126
+ limit: Maximum number of units to process.
1127
+
1128
+ Returns:
1129
+ AtomIterator: A wrapper iterator that validates strict typing at runtime.
1130
+ """
1131
+ return AtomIterator(source, limit)
1132
+
1133
+ def process_batch(
1134
+ self,
1135
+ source: Union[RowIterator, AtomIterator],
1136
+ processor_func: Optional[ProcessorFunc] = None,
1137
+ batch_size: int = 100,
1138
+ use_bulk_strategy: bool = True,
1139
+ ) -> None:
1140
+ """
1141
+ Unified Batch Processor.
1142
+
1143
+ Accepts specific iterator types to ensure pipeline safety:
1144
+ - RowIterator: Yields Model objects (ETL Mode). Requires 'processor_func'.
1145
+ - AtomIterator: Yields AtomicUnits (Data Sink Mode). 'processor_func' is optional.
1146
+
1147
+ Args:
1148
+ source: A RowIterator or AtomIterator.
1149
+ processor_func: Function to transform items into AtomicUnits.
1150
+ If None, assumes source yields AtomicUnits directly.
1151
+ batch_size: Target number of SQL operations per transaction.
1152
+ use_bulk_strategy: If True (default), attempts fast Bulk Upserts/Inserts first. Duplicates cannot be examined, only counted.
1153
+ If False, skips directly to Row-by-Row recovery mode (useful for debugging duplicates).
1154
+ """
1155
+ # Validation Checks
1156
+ self._validate_pipeline_config(batch_size)
1157
+
1158
+ total_items = len(source) if hasattr(source, "__len__") else None
1159
+
1160
+ if total_items is not None:
1161
+ logger.info(colorize(f"⚙️ Total items to process: {total_items}", SystemColor.PROCESS_BATCH_PROGRESS))
1162
+ else:
1163
+ logger.info(colorize("⚙️ Total items to process: Unknown (Stream Mode)", SystemColor.PROCESS_BATCH_PROGRESS))
1164
+
1165
+ current_batch: List[AtomicUnit] = []
1166
+ current_batch_ops_count = 0
1167
+ processed_count = 0 # Global counter
1168
+
1169
+ # Job Accumulators
1170
+ job_stats = _BatchStats()
1171
+ job_start_time = time.time()
1172
+
1173
+ for item in source:
1174
+ processed_count += 1
1175
+
1176
+ # Logs progress counter.
1177
+ if total_items:
1178
+ # Update in case of rerun
1179
+ total_items = len(source) if hasattr(source, "__len__") else None
1180
+ logger.info(
1181
+ colorize(f"⚙️ Processing item [{processed_count}/{total_items}]", SystemColor.PROCESS_BATCH_PROGRESS)
1182
+ )
1183
+ else:
1184
+ logger.info(colorize(f"⚙️ Processing item [{processed_count}]", SystemColor.PROCESS_BATCH_PROGRESS))
1185
+
1186
+ # Process Item (No DB Session active)
1187
+ try:
1188
+ result = processor_func(item) if processor_func else item
1189
+ except Exception as e:
1190
+ logger.error(f"Pipeline: Processor failed on item {item}. Flushing buffer before crash.")
1191
+
1192
+ # Emergency Flush: Save what we have before dying
1193
+ if current_batch:
1194
+ self._flush_and_log(current_batch, job_stats, job_start_time, use_bulk_strategy)
1195
+
1196
+ # Re-raise to stop the pipeline
1197
+ raise e
1198
+
1199
+ if not result:
1200
+ continue
1201
+
1202
+ # Normalize & Validate Units
1203
+ units = self._normalize_result_to_units(result)
1204
+
1205
+ # Add to Buffer
1206
+ for unit in units:
1207
+ current_batch.append(unit)
1208
+ current_batch_ops_count += len(unit)
1209
+
1210
+ # Check Buffer
1211
+ if current_batch_ops_count >= batch_size:
1212
+ self._flush_and_log(current_batch, job_stats, job_start_time, use_bulk_strategy)
1213
+
1214
+ current_batch = []
1215
+ current_batch_ops_count = 0
1216
+
1217
+ # Final flush
1218
+ if current_batch:
1219
+ self._flush_and_log(current_batch, job_stats, job_start_time, use_bulk_strategy)
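# End-to-end sketch, not part of the released module. ETL mode: read a large table in
# chunks, transform each row with no session open, and let process_batch group the
# resulting AtomicUnits into bulk transactions. `Raw` and `Clean` are hypothetical
# models; `interface` is an AlchemyInterface-style instance.
def to_clean(row: Raw) -> AtomicUnit:
    unit = AtomicUnit()
    unit.add_upsert(Clean(source_id=row.id, name=row.name.strip()), index_elements=["source_id"])
    return unit

rows = interface.get_row_iterator(select(Raw), chunk_size=1_000)
interface.process_batch(rows, processor_func=to_clean, batch_size=500)
# Data-sink mode: wrap a generator of AtomicUnits with get_atom_iterator() and call
# process_batch without processor_func.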
1220
+
1221
+ def _validate_pipeline_config(self, batch_size: int) -> None:
1222
+ """Helper to enforce pipeline guardrails."""
1223
+ if self.session is not None:
1224
+ raise RuntimeError(
1225
+ "Pipeline methods should not be called while a session is already active. "
1226
+ "Do not run this inside a 'with AlchemyInterface(...)' block."
1227
+ )
1228
+
1229
+ if batch_size < MIN_BATCH_WARNING:
1230
+ logger.warning(
1231
+ f"PERFORMANCE WARNING: batch_size={batch_size} is low. "
1232
+ f"Consider using at least {MIN_BATCH_WARNING} items per batch."
1233
+ )
1234
+
1235
+ if batch_size > MAX_BATCH_WARNING:
1236
+ logger.warning(
1237
+ f"PERFORMANCE WARNING: batch_size={batch_size} is very large. "
1238
+ f"This creates long-running transactions. Consider lowering it."
1239
+ )
1240
+
1241
+ def _normalize_result_to_units(self, result: Any) -> List[AtomicUnit]:
1242
+ """Helper to validate and normalize processor results into a list of AtomicUnits."""
1243
+ # Handle both single item and list of items
1244
+ raw_units = result if isinstance(result, (list, tuple)) else [result]
1245
+ valid_units = []
1246
+
1247
+ for unit in raw_units:
1248
+ if not isinstance(unit, AtomicUnit):
1249
+ raise ValueError(f"Expected AtomicUnit, got {type(unit)}. Check your processor_func.")
1250
+
1251
+ unit_len = len(unit)
1252
+ if unit_len > MAX_OPS_PER_ATOM:
1253
+ logger.warning(
1254
+ f"Single AtomicUnit contains {unit_len} operations. Max allowed is {MAX_OPS_PER_ATOM}."
1255
+ f"Verify your code to make sure your atom contains the minimal number of operations."
1256
+ )
1257
+ valid_units.append(unit)
1258
+
1259
+ return valid_units
1260
+
1261
+ def _flush_and_log(
1262
+ self, batch: List[AtomicUnit], job_stats: _BatchStats, job_start_time: float, use_bulk_strategy: bool
1263
+ ) -> None:
1264
+ """
1265
+ Helper to flush the batch, update cumulative stats, and log the progress.
1266
+ """
1267
+ batch_ops_count = sum(len(u) for u in batch)
1268
+
1269
+ # Measure DB Time for this batch
1270
+ db_start = time.time()
1271
+
1272
+ flush_stats = self._flush_batch_optimistic(batch) if use_bulk_strategy else self._flush_batch_resilient(batch)
1273
+
1274
+ db_duration = time.time() - db_start
1275
+
1276
+ # Update metadata
1277
+ flush_stats.size = batch_ops_count
1278
+ flush_stats.db_time = db_duration
1279
+
1280
+ # Update Job Totals
1281
+ job_stats.add(flush_stats)
1282
+
1283
+ # Log combined status
1284
+ self._log_progress(flush_stats, job_stats, job_start_time)
1285
+
1286
+ def _log_progress(self, batch_stats: _BatchStats, job_stats: _BatchStats, job_start_time: float) -> None:
1287
+ """
1288
+ Logs the batch performance and the running total for the job.
1289
+ """
1290
+ current_time = time.time()
1291
+ total_elapsed = current_time - job_start_time
1292
+
1293
+ # Calculate Job Speed
1294
+ job_ops_sec = job_stats.size / total_elapsed if total_elapsed > 0 else 0.0
1295
+
1296
+ logger.info(
1297
+ colorize(
1298
+ f"📦 PIPELINE PROGRESS | "
1299
+ f"BATCH: {batch_stats.size} ops (DB: {batch_stats.db_time:.2f}s) | "
1300
+ f"TOTAL: {job_stats.size} ops (Time: {total_elapsed:.0f}s, DB: {job_stats.db_time:.0f}s) "
1301
+ f"STATS: {job_stats.inserts} Ins, "
1302
+ f"{job_stats.upserts} Ups ({job_stats.upsert_inserts} Ins, {job_stats.upsert_updates} Upd), "
1303
+ f"{job_stats.duplicates} Dup, {job_stats.failures} Fail | "
1304
+ f"SPEED: {job_ops_sec:.1f} ops/s",
1305
+ SystemColor.BATCH_PIPELINE_STATS,
1306
+ )
1307
+ )
1308
+
1309
+ def _flush_batch_optimistic(self, units: List[AtomicUnit]) -> _BatchStats:
1310
+ """
1311
+ Strategy: Optimistic Bulk Batching.
1312
+
1313
+ 1. Groups operations by Table and Action.
1314
+ 2. Normalizes data (applying defaults) to create uniform bulk payloads.
1315
+ 3. Executes one massive SQL statement per group.
1316
+ 4. Distinguishes Inserts vs Upserts using Postgres system columns (xmax).
1317
+
1318
+ Failure Strategy:
1319
+ If ANY group fails (Constraint, Deadlock, etc.), the entire transaction rolls back
1320
+ and we delegate to '_flush_batch_resilient' (Recovery Mode).
1321
+ """
1322
+ if self.session is not None:
1323
+ raise RuntimeError("Unexpected active session during batch flush.")
1324
+
1325
+ self.start()
1326
+ self._check_session_duration()
1327
+
1328
+ stats = _BatchStats()
1329
+
1330
+ try:
1331
+ # Sort & Group ("The Traffic Cop")
1332
+ buckets = self._group_ops_by_signature(units)
1333
+
1334
+ # Process each bucket in bulk
1335
+ for signature, ops in buckets.items():
1336
+ table, action, index_elements = signature
1337
+
1338
+ bucket_stats = self._bulk_process_bucket(table, action, index_elements, ops)
1339
+ stats.add(bucket_stats)
1340
+
1341
+ self._check_session_duration()
1342
+ self.stop(commit=True)
1343
+ return stats
1344
+
1345
+ except Exception as e:
1346
+ # If ANY bulk op fails, the whole batch is tainted.
1347
+ # Rollback and switch to safe, row-by-row recovery.
1348
+ logger.warning(f"Optimistic bulk commit failed. Switching to Recovery Mode. Error: {e}")
1349
+ self.stop(commit=False)
1350
+ return self._flush_batch_resilient(units)
1351
+
1352
+ def _group_ops_by_signature(self, units: List[AtomicUnit]) -> Dict[tuple, List[BatchOp]]:
1353
+ """
1354
+ Helper: Groups operations into buckets safe for bulk execution.
1355
+ Signature: (Table, OpAction, Tuple(index_elements))
1356
+ """
1357
+ buckets = defaultdict(list)
1358
+
1359
+ for unit in units:
1360
+ for op in unit.ops:
1361
+ # We need to group by index_elements too, because different conflict targets
1362
+ # require different SQL statements.
1363
+ idx_key = tuple(sorted(op.index_elements)) if op.index_elements else None
1364
+
1365
+ sig = (op.obj.__table__, op.action, idx_key)
1366
+ buckets[sig].append(op)
1367
+
1368
+ return buckets
1369
+
1370
+ def _bulk_process_bucket(
1371
+ self,
1372
+ table: Table,
1373
+ action: OpAction,
1374
+ index_elements: Optional[tuple],
1375
+ ops: List[BatchOp],
1376
+ ) -> _BatchStats:
1377
+ """
1378
+ Executes a single bulk operation for a homogeneous group of records.
1379
+ Refactored to reduce cyclomatic complexity.
1380
+ """
1381
+ # Prepare Data (Uniformity Pass)
1382
+ records = self._prepare_bulk_payload(table, ops)
1383
+ if not records:
1384
+ return _BatchStats()
1385
+
1386
+ # Build & Execute Statement
1387
+ stmt = self._build_bulk_stmt(table, action, index_elements, records)
1388
+ result = self.session.execute(stmt)
1389
+ rows = result.all()
1390
+
1391
+ # Calculate Stats
1392
+ return self._calculate_bulk_stats(action, rows, len(records))
1393
+
1394
+ def _prepare_bulk_payload(self, table: Table, ops: List[BatchOp]) -> List[Dict[str, Any]]:
1395
+ """
1396
+ Helper: Iterates through operations and resolves the exact value for every column.
1397
+ 1. Uses 'Sparse' strategy (Union of Keys) via `_get_active_bulk_columns`.
1398
+ 2. Enforces Uniformity: If a column is active in the batch, every row sends a value for it.
1399
+ We do not skip PKs individually; if the batch schema includes the PK, we send it.
1400
+ """
1401
+ active_column_names = self._get_active_bulk_columns(table, ops)
1402
+
1403
+ payload = []
1404
+ for op in ops:
1405
+ row_data = {}
1406
+
1407
+ # Iterate only over the "Union of Keys" found in this batch
1408
+ for col_name in active_column_names:
1409
+ col = table.columns[col_name]
1410
+ val = self._resolve_column_value(col, op.obj)
1411
+
1412
+ # Sentinel check: 'Ellipsis' means "Skip this column entirely" (e.g. Server OnUpdate)
1413
+ # Note: We do NOT skip None here. If the column is active but value is None,
1414
+ # we send None (which maps to NULL in SQL), preserving batch shape uniformity.
1415
+ if val is not Ellipsis:
1416
+ row_data[col_name] = val
1417
+
1418
+ payload.append(row_data)
1419
+
1420
+ return payload
1421
+
1422
+ def _get_active_bulk_columns(self, table: Table, ops: List[BatchOp]) -> Set[str]:
1423
+ """
1424
+ Helper: Scans the batch to find the 'Union of Keys'.
1425
+ A column is active if it is explicitly set (not None) on ANY object in the batch,
1426
+ or if it has a System OnUpdate.
1427
+ """
1428
+ active_names = set()
1429
+
1430
+ # Always include System OnUpdates
1431
+ for col in table.columns:
1432
+ if col.onupdate:
1433
+ active_names.add(col.name)
1434
+
1435
+ # Scan data for explicit values
1436
+ # We assume that if a column is None on all objects, it should be excluded
1437
+ # (to avoid triggering context-dependent defaults and doing unnecessary default simulations).
1438
+ for op in ops:
1439
+ for col in table.columns:
1440
+ # Optimization: Skip if already found
1441
+ if col.name in active_names:
1442
+ continue
1443
+
1444
+ if getattr(op.obj, col.name) is not None:
1445
+ active_names.add(col.name)
1446
+
1447
+ return active_names
1448
+
1449
+ def _build_bulk_stmt(
1450
+ self,
1451
+ table: Table,
1452
+ action: OpAction,
1453
+ index_elements: Optional[tuple],
1454
+ records: List[Dict[str, Any]],
1455
+ ) -> Any:
1456
+ """
1457
+ Helper: Constructs the SQLAlchemy Core statement (Insert or Upsert)
1458
+ with the correct ON CONFLICT and RETURNING clauses.
1459
+ """
1460
+ stmt = insert(table).values(records)
1461
+ pk_col = list(table.primary_key.columns)[0]
1462
+
1463
+ if action == OpAction.INSERT:
1464
+ # INSERT: Do Nothing on conflict, Return ID for count
1465
+ return stmt.on_conflict_do_nothing().returning(pk_col)
1466
+
1467
+ if action == OpAction.UPSERT:
1468
+ # UPSERT: Do Update on conflict, Return ID + xmax
1469
+ if not index_elements:
1470
+ raise ValueError(f"Upsert on {table.name} missing index_elements.")
1471
+
1472
+ update_set = self._build_upsert_set_clause(table, index_elements, records[0].keys())
1473
+
1474
+ # If there are no columns to update (e.g. only PK and Conflict Keys exist),
1475
+ # we must fallback to DO NOTHING to avoid a Postgres syntax error.
1476
+ if not update_set:
1477
+ return stmt.on_conflict_do_nothing().returning(pk_col)
1478
+
1479
+ return stmt.on_conflict_do_update(index_elements=index_elements, set_=update_set).returning(
1480
+ pk_col, text("xmax")
1481
+ )
1482
+
1483
+ raise ValueError(f"Unknown OpAction: {action}")
1484
+
1485
+ def _build_upsert_set_clause(
1486
+ self, table: Table, index_elements: tuple, record_keys: Iterable[str]
1487
+ ) -> Dict[str, Any]:
1488
+ """
1489
+ Helper: Builds the 'set_=' dictionary for ON CONFLICT DO UPDATE.
1490
+ Maps columns to EXCLUDED.column unless overridden by onupdate.
1491
+ """
1492
+ # Default: Update everything provided in the payload (except keys)
1493
+ primary_keys = {col.name for col in table.primary_key.columns}
1494
+ update_set = {
1495
+ key: getattr(insert(table).excluded, key)
1496
+ for key in record_keys
1497
+ if key not in index_elements and key not in primary_keys
1498
+ }
1499
+
1500
+ # Override: Python OnUpdates (System Overrides) must be forced
1501
+ for col in table.columns:
1502
+ if col.onupdate:
1503
+ expr = None
1504
+ if hasattr(col.onupdate, "arg") and isinstance(col.onupdate.arg, ClauseElement):
1505
+ expr = col.onupdate.arg
1506
+ elif isinstance(col.onupdate, ClauseElement):
1507
+ expr = col.onupdate
1508
+
1509
+ if expr is not None:
1510
+ update_set[col.name] = expr
1511
+
1512
+ return update_set
1513
+
1514
+ def _calculate_bulk_stats(self, action: OpAction, rows: List[Any], total_ops: int) -> _BatchStats:
1515
+ """
1516
+ Helper: Analyzes the RETURNING results to produce accurate counts.
1517
+ """
1518
+ stats = _BatchStats()
1519
+ returned_count = len(rows)
1520
+
1521
+ if action == OpAction.INSERT:
1522
+ stats.inserts = returned_count
1523
+ stats.duplicates = total_ops - returned_count
1524
+ return stats
1525
+
1526
+ if action == OpAction.UPSERT:
1527
+ stats.upserts = total_ops # Intent: We tried to upsert this many
1528
+
1529
+ # Handle the fallback case (Empty Update -> Do Nothing).
1530
+ # If we fell back, we only requested pk_col, so rows will have length 1 (no xmax).
1531
+ # If row exists, it's an "Insert" (conceptually created/ensured).
1532
+ if rows and len(rows[0]) < 2:
1533
+ stats.upsert_inserts = returned_count
1534
+ stats.duplicates = total_ops - returned_count
1535
+ return stats
1536
+
1537
+ # Analyze xmax to separate Inserts from Updates
1538
+ # xmax=0 implies insertion; xmax!=0 implies update.
1539
+ created = 0
1540
+ updated = 0
1541
+
1542
+ for row in rows:
1543
+ if row.xmax == 0:
1544
+ created += 1
1545
+ else:
1546
+ updated += 1
1547
+
1548
+ stats.upsert_inserts = created
1549
+ stats.upsert_updates = updated
1550
+ # In DO UPDATE, duplicates usually don't happen (unless filtered by WHERE)
1551
+ stats.duplicates = total_ops - returned_count
1552
+ return stats
1553
+
1554
+ return stats
1555
+
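The xmax heuristic used above is a PostgreSQL detail: a row created by the INSERT arm reports xmax = 0, while a row rewritten by DO UPDATE carries a non-zero xmax. A minimal sketch of the counting step, assuming rows arrive as (pk, xmax) pairs from RETURNING id, xmax:

    from typing import List, Tuple

    def split_upsert_counts(rows: List[Tuple[int, int]]) -> Tuple[int, int]:
        """Return (created, updated) from (pk, xmax) pairs."""
        created = sum(1 for _, xmax in rows if xmax == 0)
        return created, len(rows) - created

    # Two fresh inserts and one conflicting row that was updated in place.
    assert split_upsert_counts([(1, 0), (2, 0), (3, 987654)]) == (2, 1)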
1556
+ def _flush_batch_resilient(self, units: List[AtomicUnit]) -> _BatchStats:
1557
+ """
1558
+ Recovery Mode with Safety Valve.
1559
+
1560
+ - Iterates through units one by one.
1561
+ - Uses Savepoints for isolation.
1562
+ - SAFETY VALVE: Forces a commit every MAX_FAILURE_BURST failures to flush "Dead Transaction IDs"
1563
+ from memory, preventing OOM/Lock exhaustion.
1564
+ """
1565
+ self.start()
1566
+ self._check_session_duration()
1567
+
1568
+ stats = _BatchStats()
1569
+
1570
+ try:
1571
+ for i, unit in enumerate(units):
1572
+ try:
1573
+ self._check_session_duration()
1574
+
1575
+ # Firewall: Each AtomicUnit gets its own isolated transaction
1576
+ with self.session.begin_nested():
1577
+ # Track stats for this specific unit
1578
+ unit_inserts = 0
1579
+ unit_upserts = 0
1580
+ unit_upsert_inserts = 0
1581
+ unit_upsert_updates = 0
1582
+ unit_duplicates = 0
1583
+
1584
+ for op in unit.ops:
1585
+ written, xmax = self._apply_operation(op)
1586
+
1587
+ if written > 0:
1588
+ if op.action == OpAction.INSERT:
1589
+ unit_inserts += 1
1590
+ elif op.action == OpAction.UPSERT:
1591
+ unit_upserts += 1
1592
+ if xmax == 0:
1593
+ unit_upsert_inserts += 1
1594
+ else:
1595
+ unit_upsert_updates += 1
1596
+ else:
1597
+ unit_duplicates += 1
1598
+
1599
+ # Only commit stats if the unit succeeds
1600
+ stats.inserts += unit_inserts
1601
+ stats.upserts += unit_upserts
1602
+ stats.upsert_inserts += unit_upsert_inserts
1603
+ stats.upsert_updates += unit_upsert_updates
1604
+ stats.duplicates += unit_duplicates
1605
+
1606
+ except Exception as e:
1607
+ # Granular Failure: Only THIS unit is lost.
1608
+ stats.failures += 1
1609
+ logger.error(
1610
+ f"Recovery: AtomicUnit failed (index {i} in batch). Discarding {len(unit)} ops. Error: {e}"
1611
+ )
1612
+
1613
+ # SAFETY VALVE: Check if we have accumulated too many dead transactions
1614
+ if stats.failures > 0 and stats.failures % MAX_FAILURE_BURST == 0:
1615
+ logger.warning(
1616
+ f"Safety Valve: {stats.failures} failures accumulated. "
1617
+ "Committing now to clear transaction memory and free up locks."
1618
+ )
1619
+ self.session.commit()
1620
+
1621
+ self._check_session_duration()
1622
+ # Commit whatever survived
1623
+ self.stop(commit=True)
1624
+ return stats
1625
+
1626
+ except Exception as e:
1627
+ # Critical Failure (e.g. DB Down).
1628
+ # Note: Since we might have done intermediate commits, some data might already be saved.
1629
+ # This rollback only affects the pending rows since the last safety commit.
1630
+ logger.error(f"Critical Failure in Recovery Mode: {e}", exc_info=True)
1631
+ self.stop(commit=False)
1632
+ raise e
1633
+
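The recovery loop above is the standard SAVEPOINT pattern plus a periodic commit. A generic sketch of the same idea, assuming a plain ORM Session and units that are simply callables (the AtomicUnit machinery and MAX_FAILURE_BURST constant from this module are not reproduced here):

    from sqlalchemy.orm import Session

    def flush_resilient(session: Session, units, max_failure_burst: int = 50) -> int:
        """Run each unit inside its own SAVEPOINT; commit early after bursts of failures."""
        failures = 0
        for unit in units:
            try:
                with session.begin_nested():      # SAVEPOINT; rolled back if the unit raises
                    unit(session)
            except Exception:
                failures += 1
                if failures % max_failure_burst == 0:
                    session.commit()              # release accumulated subtransaction state
        session.commit()                          # commit whatever survived
        return failures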
1634
+ def _apply_operation(self, op: BatchOp) -> Tuple[int, int]:
1635
+ """
1636
+ Apply a single BatchOp to the current session.
1637
+
1638
+ Returns:
1639
+ Tuple[int, int]: (rowcount, xmax).
1640
+ xmax is 0 for inserts/updates that don't return it.
1641
+ """
1642
+ self._check_session_duration()
1643
+
1644
+ if not isinstance(op, BatchOp):
1645
+ raise ValueError(
1646
+ f"Pipeline Error: Expected BatchOp, got {type(op)}. All operations must be wrapped in BatchOp."
1647
+ )
1648
+
1649
+ if op.action == OpAction.INSERT:
1650
+ # Inserts don't need xmax logic, just rowcount
1651
+ return self._insert_lightweight(op.obj), 0
1652
+ elif op.action == OpAction.UPSERT:
1653
+ if not op.index_elements:
1654
+ raise ValueError(f"Upsert BatchOp missing index_elements: {op}")
1655
+ return self._upsert_lightweight(op.obj, index_elements=op.index_elements)
1656
+ else:
1657
+ raise ValueError(f"Unknown OpAction: {op.action}")
1658
+
1659
+ def _insert_lightweight(self, obj: ModelType) -> int:
1660
+ """
1661
+ Lightweight INSERT using ON CONFLICT DO NOTHING.
1662
+
1663
+ - Builds a dict from non-None column values.
1664
+ - Skips duplicates silently (logs when a conflict happens).
1665
+
1666
+ Returns:
1667
+ int: rowcount (1 if inserted, 0 if duplicate).
1668
+ """
1669
+ self._check_session_duration()
1670
+
1671
+ table = obj.__table__
1672
+ data = {c.name: getattr(obj, c.name) for c in table.columns if getattr(obj, c.name) is not None}
1673
+
1674
+ stmt = insert(table).values(data).on_conflict_do_nothing()
1675
+ result = self.session.execute(stmt)
1676
+
1677
+ if result.rowcount == 0:
1678
+ logger.info(f"Duplicate skipped (Unique Violation): {obj}")
1679
+
1680
+ return result.rowcount
1681
+
1682
+ def _upsert_lightweight(self, obj: ModelType, index_elements: List[str]) -> Tuple[int, int]:
1683
+ """
1684
+ Lightweight UPSERT using ON CONFLICT DO UPDATE.
1685
+ Used primarily in Recovery Mode.
1686
+ """
1687
+ self._check_session_duration()
1688
+
1689
+ table = obj.__table__
1690
+
1691
+ # INSERT: Standard behavior (SQLAlchemy/DB handles defaults for missing cols)
1692
+ insert_values = {
1693
+ col.name: getattr(obj, col.name) for col in table.columns if getattr(obj, col.name) is not None
1694
+ }
1695
+
1696
+ # UPDATE: Calculate manual priority logic
1697
+ update_set_values = self._get_update_values(table, obj)
1698
+
1699
+ stmt = (
1700
+ insert(table)
1701
+ .values(insert_values)
1702
+ .on_conflict_do_update(index_elements=index_elements, set_=update_set_values)
1703
+ .returning(text("xmax"))
1704
+ )
1705
+
1706
+ result = self.session.execute(stmt)
1707
+
1708
+ # Capture the xmax if a row was returned (written)
1709
+ xmax = 0
1710
+ if result.rowcount > 0:
1711
+ row = result.fetchone()
1712
+ if row is not None:
1713
+ xmax = row.xmax
1714
+
1715
+ return result.rowcount, xmax
1716
+
1717
+ def _get_update_values(self, table, obj) -> Dict[str, Any]:
1718
+ """
1719
+ Helper: Builds the SET dictionary for the UPDATE clause.
1720
+
1721
+ Change: Now sparse. Only includes columns that are explicitly present on the object
1722
+ or have a System OnUpdate.
1723
+ """
1724
+ primary_keys = {col.name for col in table.primary_key.columns.values()}
1725
+ update_values = {}
1726
+
1727
+ for col in table.columns:
1728
+ if col.name in primary_keys:
1729
+ continue
1730
+
1731
+ # Skip columns that are None on this object (and have no OnUpdate override)
1732
+ # This prevents triggering context-dependent defaults during Recovery Mode.
1733
+ if getattr(obj, col.name) is None and not col.onupdate:
1734
+ continue
1735
+
1736
+ val = self._resolve_column_value(col, obj)
1737
+
1738
+ # If the resolver returns the Ellipsis sentinel, the column is excluded from the values dict.
1739
+ if val is not Ellipsis:
1740
+ update_values[col.name] = val
1741
+
1742
+ return update_values
1743
+
1744
+ def _resolve_column_value(self, col, obj) -> Any: # noqa: C901
1745
+ """
1746
+ Helper: Determines the correct value for a single column based on priority.
1747
+ Used by both Bulk Processing and Lightweight Recovery.
1748
+
1749
+ Priority Hierarchy:
1750
+ 1. Python onupdate (System Override) -> Wins always.
1751
+ 2. Server onupdate (DB Trigger) -> Skips column so DB handles it.
1752
+ 3. Explicit Value -> Wins if provided.
1753
+ 4. Python Default -> Fallback for None.
1754
+ 5. Server Default -> Fallback for None.
1755
+ 6. NULL -> Last resort.
1756
+
1757
+ Returns `Ellipsis` (...) if the column should be excluded from the values dict.
1758
+ """
1759
+ # 1. Python OnUpdate (System Override)
1760
+ # Note: In Bulk Upsert, this might be overwritten by the stmt construction logic,
1761
+ # but we return it here for consistency in Insert/Row-by-Row.
1762
+ if col.onupdate is not None:
1763
+ expr = None
1764
+ if hasattr(col.onupdate, "arg") and isinstance(col.onupdate.arg, ClauseElement):
1765
+ expr = col.onupdate.arg
1766
+ elif isinstance(col.onupdate, ClauseElement):
1767
+ expr = col.onupdate
1768
+
1769
+ if expr is not None:
1770
+ return expr
1771
+
1772
+ # 2. Server OnUpdate (DB Trigger Override)
1773
+ if col.server_onupdate is not None:
1774
+ return Ellipsis # Sentinel to skip
1775
+
1776
+ # 3. Explicit Value
1777
+ val = getattr(obj, col.name)
1778
+ if val is not None:
1779
+ return val
1780
+
1781
+ # Fallback: Value is None
1782
+
1783
+ # 4. Python Default
1784
+ if col.default is not None:
1785
+ arg = col.default.arg
1786
+ if callable(arg):
1787
+ try:
1788
+ return arg()
1789
+ except TypeError as e:
1790
+ raise TypeError(
1791
+ "Calling the python default function failed. "
1792
+ "Most likely attempted to write NULL to a column with a python default that takes context as parameter."
1793
+ ) from e
1794
+ else:
1795
+ return arg
1796
+
1797
+ # 5. Server Default
1798
+ if col.server_default is not None:
1799
+ return col.server_default.arg
1800
+
1801
+ # 6. Explicit NULL
1802
+ return None
1803
+
1804
+
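As a concrete illustration of that hierarchy, a hypothetical model mixing the different default mechanisms; the comments indicate which branch of _resolve_column_value would apply when the corresponding attribute is None:

    from sqlalchemy import Column, DateTime, FetchedValue, Integer, String, func
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class Job(Base):  # hypothetical model, for illustration only
        __tablename__ = "jobs"
        id = Column(Integer, primary_key=True)
        updated_at = Column(DateTime, onupdate=func.now())             # 1. Python onupdate always wins
        etl_version = Column(Integer, server_onupdate=FetchedValue())  # 2. skipped so the DB handles it
        status = Column(String, default="pending")                     # 4. Python default fills a None
        created_at = Column(DateTime, server_default=func.now())       # 5. server default as fallback
        note = Column(String)                                          # 6. stays NULL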
1805
+ class _BatchUtilities:
1806
+ """
1807
+ Mixin for High-Performance Batch Operations with Smart Default Handling.
1808
+
1809
+ Features:
1810
+ - Smart Defaults: Automatically handles Python static defaults and SQL Server defaults.
1811
+ - Recursive Bisecting: Isolates bad rows without stopping the whole batch.
1812
+
1813
+ Expected Interface:
1814
+ - self.session: sqlalchemy.orm.Session
1815
+ - self.start(): Method to start session
1816
+ - self.stop(commit=bool): Method to stop session
1817
+ - self.log_row(...): Optional helper for logging
1818
+ """
1819
+
1820
+ def insert_dataframe(
1821
+ self, df: pd.DataFrame, model: Type["ModelType"], batch_size: int = 5000, verbose: bool = False
1822
+ ) -> None:
1823
+ """
1824
+ High-Performance Smart Bulk Insert.
1825
+
1826
+ Logic Priority for Empty/Null Values:
1827
+ 1. Python Default (e.g. default="pending") -> Applied in-memory via Pandas.
1828
+ 2. Python Callable (e.g. default=uuid4) -> Executed per-row in-memory.
1829
+ 3. Server Default (e.g. server_default=text("now()")) -> Applied via SQL COALESCE.
1830
+ 4. NULL -> If none of the above exist.
1831
+
1832
+ CRITICAL LIMITATIONS:
1833
+ - NO EXPLICIT NULLS: If a column has a default (Python or Server), sending None/NaN
1834
+ will ALWAYS trigger that default. You cannot force a NULL value into such a column.
1835
+
1836
+ Args:
1837
+ df: Pandas DataFrame containing the data.
1838
+ model: SQLAlchemy Model class.
1839
+ batch_size: Rows per transaction chunk.
1840
+ verbose: If True, logs individual rows during recursive failure handling.
1841
+ """
1842
+ _require_pandas()
1843
+
1844
+ if self.session is not None:
1845
+ raise RuntimeError("insert_dataframe cannot be called when a session is already active.")
1846
+
1847
+ if df.empty:
1848
+ logger.warning("Pipeline: DataFrame is empty. Nothing to insert.")
1849
+ return
1850
+
1851
+ # Apply Python-side Defaults (In-Memory)
1852
+ df_processed = self._apply_python_defaults(df, model)
1853
+
1854
+ # Prepare Records (NaN -> None for SQL binding)
1855
+ records = df_processed.replace({np.nan: None}).to_dict("records")
1856
+ total_records = len(records)
1857
+
1858
+ logger.info(f"BULK INSERT START | Model: {model.__name__} | Records: {total_records}")
1859
+
1860
+ # Create Smart SQL Statement (COALESCE logic)
1861
+ smart_stmt = self._create_smart_insert_stmt(model, df_processed.columns)
1862
+
1863
+ self.start()
1864
+ job_stats = _BatchStats()
1865
+ job_start = time.time()
1866
+
1867
+ try:
1868
+ for i in range(0, total_records, batch_size):
1869
+ self._check_session_duration()
1870
+ chunk = records[i : i + batch_size]
1871
+
1872
+ # Recursive Engine using the pre-compiled Smart Statement
1873
+ self._insert_recursive_smart(chunk, smart_stmt, model, job_stats, verbose)
1874
+
1875
+ self.session.commit()
1876
+
1877
+ elapsed = time.time() - job_start
1878
+ ops_sec = job_stats.size / elapsed if elapsed > 0 else 0
1879
+ logger.info(
1880
+ f"BULK PROGRESS | Processed: {min(i + batch_size, total_records)}/{total_records} | "
1881
+ f"Written: {job_stats.inserts} | Dupes: {job_stats.duplicates} | "
1882
+ f"Skipped: {job_stats.failures} | Speed: {ops_sec:.0f} rows/s"
1883
+ )
1884
+
1885
+ except Exception:
1886
+ logger.error("Critical Failure in Bulk Insert", exc_info=True)
1887
+ self.stop(commit=False)
1888
+ raise
1889
+ finally:
1890
+ self.stop(commit=False)
1891
+ logger.info(f"BULK INSERT FINISHED | Total Time: {time.time() - job_start:.2f}s")
1892
+
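A hedged usage sketch of the method above: `iface` stands for an already-constructed AlchemyInterface (defined at the end of this module) and `job_model` for any mapped model whose columns match the DataFrame. Both names are placeholders; only the call signature comes from the code above.

    import pandas as pd

    def load_jobs(iface, job_model) -> None:
        """Illustrative only: bulk-load a small DataFrame through insert_dataframe."""
        df = pd.DataFrame(
            {
                "id": [1, 2, 3],
                "status": [None, "done", None],  # None lets a Python default (if any) kick in
            }
        )
        iface.insert_dataframe(df, job_model, batch_size=1000, verbose=False)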
1893
+ def _apply_python_defaults(self, df: pd.DataFrame, model: Type["ModelType"]) -> pd.DataFrame:
1894
+ """
1895
+ Fills NaNs with Python-side defaults.
1896
+ - Static values (int, str): Vectorized fill (Fast).
1897
+ - Simple functions (uuid4, datetime.now): Applied per-row (Slower).
1898
+ """
1899
+ df_copy = df.copy()
1900
+
1901
+ for col in model.__table__.columns:
1902
+ # Skip if column irrelevant, no default, or fully populated
1903
+ if col.name not in df_copy.columns or not df_copy[col.name].hasnans:
1904
+ continue
1905
+
1906
+ if col.default is None or not hasattr(col.default, "arg"):
1907
+ continue
1908
+
1909
+ default_arg = col.default.arg
1910
+
1911
+ # Case A: Static Value (Fast Vectorized Fill)
1912
+ if not callable(default_arg):
1913
+ df_copy[col.name] = df_copy[col.name].fillna(default_arg)
1914
+ continue
1915
+
1916
+ # Case B: Callable (Slow Row-by-Row Fill)
1917
+ # We assume it takes 0 arguments; if not, raise an error (the default expects a context argument).
1918
+ mask = df_copy[col.name].isna()
1919
+ try:
1920
+ fill_values = [default_arg() for _ in range(mask.sum())]
1921
+ except TypeError as e:
1922
+ raise TypeError(
1923
+ "Calling the python default function failed. "
1924
+ "Most likely attempted to write NULL to a column with a python default that takes context as parameter."
1925
+ ) from e
1926
+ df_copy.loc[mask, col.name] = fill_values
1927
+
1928
+ return df_copy
1929
+
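The two fill strategies above, shown in isolation: a vectorized fillna for static defaults versus a per-row evaluation for callables (uuid4 stands in for any zero-argument default):

    import uuid

    import pandas as pd

    df = pd.DataFrame({"status": [None, "done"], "token": [None, "abc"]})

    # Static default: one vectorized fill.
    df["status"] = df["status"].fillna("pending")

    # Callable default: evaluated once per missing row.
    mask = df["token"].isna()
    df.loc[mask, "token"] = [str(uuid.uuid4()) for _ in range(mask.sum())]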
1930
+ def _create_smart_insert_stmt(self, model: Type["ModelType"], df_columns: Sequence[str]):
1931
+ """
1932
+ Helper: Builds an INSERT ... ON CONFLICT ... RETURNING statement.
1933
+ """
1934
+ table = model.__table__
1935
+ values_dict = {}
1936
+
1937
+ for col_name in df_columns:
1938
+ if col_name not in table.columns:
1939
+ continue
1940
+
1941
+ col = table.columns[col_name]
1942
+
1943
+ if col.server_default is not None and hasattr(col.server_default, "arg"):
1944
+ values_dict[col_name] = func.coalesce(bindparam(col_name, type_=col.type), col.server_default.arg)
1945
+ else:
1946
+ values_dict[col_name] = bindparam(col_name, type_=col.type)
1947
+
1948
+ # Add .returning() at the end so we get back the IDs of inserted rows
1949
+ # We dynamically grab the primary key column(s) to return.
1950
+ pk_cols = [c for c in table.primary_key.columns]
1951
+
1952
+ return insert(table).values(values_dict).on_conflict_do_nothing().returning(*pk_cols)
1953
+
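For reference, the COALESCE shape the helper above produces, sketched for a hypothetical `events` table: the per-row bound value is coalesced with the column's server default, so a NULL falls through to the database-side default.

    from sqlalchemy import Column, DateTime, Integer, MetaData, String, Table, bindparam, func, text
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import insert

    metadata = MetaData()
    events = Table(  # hypothetical table
        "events",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String),
        Column("created_at", DateTime, server_default=text("now()")),
    )

    created = events.c.created_at
    stmt = (
        insert(events)
        .values(
            {
                "name": bindparam("name", type_=events.c.name.type),
                # COALESCE(:created_at, now()) -> NULL falls back to the server default
                "created_at": func.coalesce(bindparam("created_at", type_=created.type), created.server_default.arg),
            }
        )
        .on_conflict_do_nothing()
        .returning(events.c.id)
    )

    print(stmt.compile(dialect=postgresql.dialect()))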
1954
+ def _insert_recursive_smart(
1955
+ self, records: List[Dict[str, Any]], stmt, model: Type["ModelType"], stats: _BatchStats, verbose: bool
1956
+ ) -> None:
1957
+ """
1958
+ Recursive bisecting engine using the pre-compiled Smart Statement.
1959
+ Uses SAVEPOINTs (begin_nested) to isolate errors without committing.
1960
+ """
1961
+ if not records:
1962
+ return
1963
+
1964
+ # Base Case: Single Row
1965
+ if len(records) == 1:
1966
+ try:
1967
+ with self.session.begin_nested():
1968
+ result = self.session.execute(stmt, records)
1969
+ # For a single row, counting the RETURNING rows tells us the outcome directly
1970
+ written = len(result.all()) # Will be 1 (written) or 0 (duplicate)
1971
+ duplicates = 1 - written
1972
+ stats.inserts += written
1973
+ stats.duplicates += duplicates
1974
+ except Exception as e:
1975
+ stats.failures += 1
1976
+ stats.size += 1
1977
+ if hasattr(self, "log_row"):
1978
+ self.log_row(model(**records[0]), prefix_msg="SKIPPING BAD ROW")
1979
+ logger.error(f"Bulk Insert Error on single row: {e}")
1980
+ return
1981
+
1982
+ # Recursive Step
1983
+ try:
1984
+ with self.session.begin_nested():
1985
+ result = self.session.execute(stmt, records)
1986
+
1987
+ # Count the actually returned rows (IDs)
1988
+ # This works because "ON CONFLICT DO NOTHING" returns NOTHING for duplicates.
1989
+ # result.all() fetches the list of returned PKs.
1990
+ written = len(result.all())
1991
+
1992
+ duplicates = len(records) - written
1993
+
1994
+ stats.inserts += written
1995
+ stats.duplicates += duplicates
1996
+ stats.size += len(records)
1997
+
1998
+ except Exception:
1999
+ # Failure -> Split and Retry
2000
+ mid = len(records) // 2
2001
+ self._insert_recursive_smart(records[:mid], stmt, model, stats, verbose)
2002
+ self._insert_recursive_smart(records[mid:], stmt, model, stats, verbose)
2003
+
2004
+
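Stripped of the SQLAlchemy specifics, the bisecting recursion above reduces to a generic "split on failure" pattern; a sketch with apply_batch standing in for any all-or-nothing batch operation:

    from typing import Callable, Sequence

    def bisect_apply(
        records: Sequence,
        apply_batch: Callable[[Sequence], None],
        on_bad_row: Callable[[object], None],
    ) -> None:
        """Try the whole batch; on failure, halve it until single bad rows are isolated."""
        if not records:
            return
        try:
            apply_batch(records)
        except Exception:
            if len(records) == 1:
                on_bad_row(records[0])
                return
            mid = len(records) // 2
            bisect_apply(records[:mid], apply_batch, on_bad_row)
            bisect_apply(records[mid:], apply_batch, on_bad_row)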
2005
+ class AlchemyInterface(_BaseAlchemyCore, _LegacyAlchemyOps, _BatchPipelineOps, _BatchUtilities):
2006
+ """
2007
+ Concrete interface combining:
2008
+
2009
+ - BaseAlchemyCore: engine/session management, DDL utilities, core helpers.
2010
+ - LegacyAlchemyOps: legacy insert/upsert/windowed query APIs (kept for compatibility).
2011
+ - BatchPipelineOps: modern batch processing and safe iteration utilities.
2012
+ - BatchUtilities: high-performance bulk operations (insert_dataframe) with smart defaults.
2013
+
2014
+ This is the class intended to be used by application code.
2015
+ """
2016
+
2017
+ pass
2018
+
110
2019
 
111
- self.session.rollback()
2020
+ __all__ = ["RowIterator", "AtomIterator", "AtomicUnit", "BatchOp", "OpAction", "AlchemyInterface"]