macrostrat.database 3.5.4 (tar.gz) → 4.0.1 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (20)
  1. macrostrat_database-4.0.1/PKG-INFO +23 -0
  2. macrostrat_database-4.0.1/macrostrat/.DS_Store +0 -0
  3. {macrostrat_database-3.5.4 → macrostrat_database-4.0.1}/macrostrat/database/__init__.py +9 -11
  4. macrostrat_database-4.0.1/macrostrat/database/compat.py +36 -0
  5. {macrostrat_database-3.5.4 → macrostrat_database-4.0.1}/macrostrat/database/postgresql.py +41 -6
  6. macrostrat_database-3.5.4/macrostrat/database/utils.py → macrostrat_database-4.0.1/macrostrat/database/query.py +181 -248
  7. macrostrat_database-4.0.1/macrostrat/database/utils.py +213 -0
  8. macrostrat_database-4.0.1/pyproject.toml +44 -0
  9. macrostrat_database-3.5.4/PKG-INFO +0 -20
  10. macrostrat_database-3.5.4/pyproject.toml +0 -25
  11. {macrostrat_database-3.5.4 → macrostrat_database-4.0.1}/macrostrat/database/mapper/__init__.py +0 -0
  12. {macrostrat_database-3.5.4 → macrostrat_database-4.0.1}/macrostrat/database/mapper/base.py +0 -0
  13. {macrostrat_database-3.5.4 → macrostrat_database-4.0.1}/macrostrat/database/mapper/cache.py +0 -0
  14. {macrostrat_database-3.5.4 → macrostrat_database-4.0.1}/macrostrat/database/mapper/utils.py +0 -0
  15. {macrostrat_database-3.5.4 → macrostrat_database-4.0.1}/macrostrat/database/transfer/__init__.py +0 -0
  16. {macrostrat_database-3.5.4 → macrostrat_database-4.0.1}/macrostrat/database/transfer/dump_database.py +0 -0
  17. {macrostrat_database-3.5.4 → macrostrat_database-4.0.1}/macrostrat/database/transfer/move_tables.py +0 -0
  18. {macrostrat_database-3.5.4 → macrostrat_database-4.0.1}/macrostrat/database/transfer/restore_database.py +0 -0
  19. {macrostrat_database-3.5.4 → macrostrat_database-4.0.1}/macrostrat/database/transfer/stream_utils.py +0 -0
  20. {macrostrat_database-3.5.4 → macrostrat_database-4.0.1}/macrostrat/database/transfer/utils.py +0 -0
@@ -0,0 +1,23 @@
1
+ Metadata-Version: 2.3
2
+ Name: macrostrat.database
3
+ Version: 4.0.1
4
+ Summary: A SQLAlchemy-based database toolkit.
5
+ Author: Daven Quinn
6
+ Author-email: Daven Quinn <dev@davenquinn.com>
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Programming Language :: Python :: 3.10
9
+ Classifier: Programming Language :: Python :: 3.11
10
+ Classifier: Programming Language :: Python :: 3.12
11
+ Classifier: Programming Language :: Python :: 3.13
12
+ Classifier: Programming Language :: Python :: 3.14
13
+ Requires-Dist: geoalchemy2>=0.15.2,<0.16
14
+ Requires-Dist: sqlalchemy>=2.0.18,<3
15
+ Requires-Dist: sqlalchemy-utils>=0.41.1,<0.42
16
+ Requires-Dist: click>=8.1.3,<9
17
+ Requires-Dist: macrostrat-utils>=1.3.3,<2
18
+ Requires-Dist: sqlparse>=0.5.1,<0.6
19
+ Requires-Dist: aiofiles>=23.2.1,<24
20
+ Requires-Dist: rich>=13.7.1,<14
21
+ Requires-Dist: psycopg>=3.2.1,<4
22
+ Requires-Dist: psycopg2>=2.9.11,<3
23
+ Requires-Python: >=3.10, <4
@@ -3,10 +3,10 @@ from contextlib import contextmanager
3
3
  from pathlib import Path
4
4
  from typing import Optional, Union
5
5
 
6
- from psycopg2.errors import InvalidSavepointSpecification
7
- from psycopg2.sql import Identifier
8
- from sqlalchemy import URL, Engine, MetaData, create_engine, inspect
9
- from sqlalchemy.exc import IntegrityError, InternalError
6
+ from psycopg.errors import InvalidSavepointSpecification
7
+ from psycopg.sql import Identifier
8
+ from sqlalchemy import URL, Engine, MetaData, inspect
9
+ from sqlalchemy.exc import IntegrityError, OperationalError
10
10
  from sqlalchemy.ext.compiler import compiles
11
11
  from sqlalchemy.orm import Session, scoped_session, sessionmaker
12
12
  from sqlalchemy.sql.expression import Insert
@@ -17,6 +17,7 @@ from .mapper import DatabaseMapper
17
17
  from .postgresql import on_conflict, prefix_inserts # noqa
18
18
  from .utils import ( # noqa
19
19
  create_database,
20
+ create_engine,
20
21
  database_exists,
21
22
  drop_database,
22
23
  get_dataframe,
@@ -60,12 +61,8 @@ class Database(object):
60
61
 
61
62
  self.instance_params = kwargs.pop("instance_params", {})
62
63
 
63
- if isinstance(db_conn, Engine):
64
- log.info(f"Set up database connection with engine {db_conn.url}")
65
- self.engine = db_conn
66
- else:
67
- log.info(f"Setting up database connection with URL '{db_conn}'")
68
- self.engine = create_engine(db_conn, echo=echo_sql, **kwargs)
64
+ self.engine = create_engine(db_conn, echo=echo_sql, **kwargs)
65
+
69
66
  self.metadata = kwargs.get("metadata", metadata)
70
67
 
71
68
  # Scoped session for database
@@ -334,6 +331,7 @@ class Database(object):
334
331
 
335
332
  if connection is None:
336
333
  connection = self.session.connection()
334
+
337
335
  params = {"name": Identifier(name)}
338
336
  run_query(connection, "SAVEPOINT {name}", params)
339
337
  should_rollback = rollback == "always"
@@ -356,7 +354,7 @@ def _clear_savepoint(connection, name, rollback=True):
356
354
  run_query(connection, "ROLLBACK TO SAVEPOINT {name}", params)
357
355
  else:
358
356
  run_query(connection, "RELEASE SAVEPOINT {name}", params)
359
- except InternalError as err:
357
+ except OperationalError as err:
360
358
  if isinstance(err.orig, InvalidSavepointSpecification):
361
359
  log.warning(
362
360
  f"Savepoint {name} does not exist; we may have already rolled back."
@@ -0,0 +1,36 @@
1
+ from warnings import warn
2
+
3
+ import psycopg.sql as psql3
4
+ import psycopg2.sql as psql2
5
+
6
+
7
+ def update_legacy_identifier(identifier):
8
+ """
9
+ For backwards compatibility with current code, we need to map psycopg2 identifiers to their equivalents in psycopg3,
10
+ while printing a warning that the mapping is deprecated.
11
+ :param identifier:
12
+ :return: psycopg3 equivalent of identifier
13
+ """
14
+ new_identifier = _map_psycopg2_identifier_to_psycopg3_identifier_internal(
15
+ identifier
16
+ )
17
+ if new_identifier is not identifier:
18
+ warn(
19
+ "psycopg2 identifiers are deprecated. Please use psycopg3 identifiers instead.",
20
+ DeprecationWarning,
21
+ )
22
+ return new_identifier
23
+
24
+
25
+ def _map_psycopg2_identifier_to_psycopg3_identifier_internal(identifier):
26
+ if isinstance(identifier, psql2.Identifier):
27
+ return psql3.Identifier(*identifier._wrapped)
28
+ if isinstance(identifier, psql2.SQL):
29
+ return psql3.SQL(identifier._wrapped)
30
+ if isinstance(identifier, psql2.Literal):
31
+ return psql3.Literal(identifier._wrapped)
32
+ if isinstance(identifier, psql2.Placeholder):
33
+ return psql3.Placeholder(identifier._obj)
34
+ if isinstance(identifier, psql2.Composed):
35
+ return psql3.Composed(identifier._obj)
36
+ return identifier
@@ -2,18 +2,26 @@ from __future__ import annotations
2
2
 
3
3
  from contextlib import contextmanager
4
4
  from contextvars import ContextVar
5
- from typing import TYPE_CHECKING
5
+ from enum import Enum
6
+ from typing import Any, Sequence, TYPE_CHECKING
6
7
 
7
- import psycopg2
8
8
  from sqlalchemy.dialects import postgresql
9
9
  from sqlalchemy.exc import CompileError
10
10
  from sqlalchemy.ext.compiler import compiles
11
- from sqlalchemy.sql.expression import Insert, text
11
+ from sqlalchemy.sql.dml import Insert
12
+ from sqlalchemy.sql.expression import text
12
13
 
13
14
  if TYPE_CHECKING:
14
15
  from ..database import Database
15
16
 
16
- _insert_mode = ContextVar("insert-mode", default="do-nothing")
17
+
18
+ class OnConflictAction(str, Enum):
19
+ DO_NOTHING = "do-nothing"
20
+ DO_UPDATE = "do-update"
21
+ RESTRICT = "restrict"
22
+
23
+
24
+ _insert_mode = ContextVar("insert-mode", default="restrict")
17
25
 
18
26
 
19
27
  # https://stackoverflow.com/questions/33307250/postgresql-on-conflict-in-sqlalchemy/62305344#62305344
@@ -26,9 +34,10 @@ def on_conflict(action="restrict"):
26
34
  _insert_mode.reset(token)
27
35
 
28
36
 
29
- # @compiles(Insert, "postgresql")
37
+ @compiles(Insert, "postgresql")
30
38
  def prefix_inserts(insert, compiler, **kw):
31
39
  """Conditionally adapt insert statements to use on-conflict resolution (a PostgreSQL feature)"""
40
+
32
41
  if insert._post_values_clause is not None:
33
42
  return compiler.visit_insert(insert, **kw)
34
43
 
@@ -60,10 +69,36 @@ def prefix_inserts(insert, compiler, **kw):
60
69
  return compiler.visit_insert(insert, **kw)
61
70
 
62
71
 
72
+ def upsert(
73
+ table,
74
+ values: dict[str, Any],
75
+ index_elements: Sequence[str],
76
+ *,
77
+ on_conflict: OnConflictAction = "do-update",
78
+ ):
79
+ stmt = postgresql.insert(table).values(values)
80
+ if on_conflict == "restrict":
81
+ return stmt
82
+
83
+ if on_conflict == "do-nothing":
84
+ return stmt.on_conflict_do_nothing(index_elements=list(index_elements))
85
+
86
+ update_values = {
87
+ column.name: getattr(stmt.excluded, column.name)
88
+ for column in table.columns
89
+ if column.name in values and column.name not in index_elements
90
+ }
91
+
92
+ return stmt.on_conflict_do_update(
93
+ index_elements=list(index_elements),
94
+ set_=update_values,
95
+ )
96
+
97
+
63
98
  def table_exists(db: Database, table_name: str, schema: str = "public") -> bool:
64
99
  """Check if a table exists in a PostgreSQL database."""
65
100
  sql = """SELECT EXISTS (
66
- SELECT FROM information_schema.tables
101
+ SELECT FROM information_schema.tables
67
102
  WHERE table_schema = :schema
68
103
  AND table_name = :table_name
69
104
  );"""
@@ -1,20 +1,18 @@
1
1
  import os
2
- from contextlib import contextmanager
2
+ from dataclasses import dataclass
3
3
  from enum import Enum
4
4
  from pathlib import Path
5
5
  from re import search
6
6
  from sys import stderr
7
- from time import sleep
8
- from typing import IO, Union
7
+ from typing import Callable, Any, IO, Union
9
8
  from warnings import warn
10
9
 
11
- import psycopg2.errors
12
- from click import echo, secho
13
- from psycopg2.extensions import set_wait_callback
14
- from psycopg2.extras import wait_select
15
- from psycopg2.sql import SQL, Composable, Composed
10
+ import psycopg2.sql as psql2
11
+ from click import secho
12
+ from psycopg.errors import QueryCanceled
13
+ from psycopg.sql import SQL, Composable, Composed
16
14
  from rich.console import Console
17
- from sqlalchemy import MetaData, create_engine, text
15
+ from sqlalchemy import text
18
16
  from sqlalchemy.engine import Connection, Engine
19
17
  from sqlalchemy.exc import (
20
18
  IntegrityError,
@@ -23,23 +21,17 @@ from sqlalchemy.exc import (
23
21
  OperationalError,
24
22
  ProgrammingError,
25
23
  )
26
- from sqlalchemy.orm import sessionmaker
27
- from sqlalchemy.schema import Table
28
- from sqlalchemy.sql.elements import ClauseElement, TextClause
29
- from sqlalchemy_utils import create_database as _create_database
30
- from sqlalchemy_utils import database_exists, drop_database
24
+ from sqlalchemy.sql.elements import TextClause
31
25
  from sqlparse import format, split
32
26
 
33
- from macrostrat.utils import cmd, get_logger
27
+ from macrostrat.database.compat import (
28
+ update_legacy_identifier,
29
+ )
30
+ from macrostrat.utils import get_logger
34
31
 
35
32
  log = get_logger(__name__)
36
33
 
37
34
 
38
- def db_session(engine):
39
- factory = sessionmaker(bind=engine)
40
- return factory()
41
-
42
-
43
35
  def infer_is_sql_text(_string: str) -> bool:
44
36
  """
45
37
  Return True if the string is a valid SQL query,
@@ -74,19 +66,6 @@ def canonicalize_query(file_or_text: Union[str, Path, IO]) -> Union[str, Path]:
74
66
  return file_or_text
75
67
 
76
68
 
77
- def get_dataframe(connectable, filename_or_query, **kwargs):
78
- """
79
- Run a query on a SQL database (represented by
80
- a SQLAlchemy database object) and turn it into a
81
- `Pandas` dataframe.
82
- """
83
- from pandas import read_sql
84
-
85
- sql = get_sql_text(filename_or_query)
86
-
87
- return read_sql(sql, connectable, **kwargs)
88
-
89
-
90
69
  def pretty_print(sql, **kwargs):
91
70
  """Print and optionally summarize an SQL query"""
92
71
  summarize = kwargs.pop("summarize", True)
@@ -117,6 +96,7 @@ def summarize_statement(sql):
117
96
  if not line.startswith(i):
118
97
  continue
119
98
  return line.split("(")[0].strip().rstrip(";").replace(" AS", "")
99
+ return sql.strip().split("\n")[0].strip().rstrip(";")
120
100
 
121
101
 
122
102
  def get_sql_text(sql, interpret_as_file=None, echo_file_name=True):
@@ -139,9 +119,7 @@ def _get_queries(sql, interpret_as_file=None):
139
119
  for i in sql:
140
120
  queries.extend(_get_queries(i, interpret_as_file=interpret_as_file))
141
121
  return queries
142
- if isinstance(sql, TextClause):
143
- return [sql]
144
- if isinstance(sql, SQL):
122
+ if isinstance(sql, (SQL, psql2.SQL, TextClause)):
145
123
  return [sql]
146
124
 
147
125
  if sql in [None, ""]:
@@ -158,7 +136,7 @@ def _get_queries(sql, interpret_as_file=None):
158
136
 
159
137
 
160
138
  def _is_prebind_param(param):
161
- return isinstance(param, Composable)
139
+ return isinstance(param, (Composable, psql2.Composable))
162
140
 
163
141
 
164
142
  def _split_params(params):
@@ -169,7 +147,7 @@ def _split_params(params):
169
147
  if isinstance(params, (list, tuple)):
170
148
  for i in params:
171
149
  if _is_prebind_param(i):
172
- new_bind_params.append(i)
150
+ new_bind_params.append(update_legacy_identifier(i))
173
151
  else:
174
152
  new_params.append(i)
175
153
  elif isinstance(params, dict):
@@ -177,7 +155,7 @@ def _split_params(params):
177
155
  new_bind_params = {}
178
156
  for k, v in params.items():
179
157
  if _is_prebind_param(v):
180
- new_bind_params[k] = v
158
+ new_bind_params[k] = update_legacy_identifier(v)
181
159
  else:
182
160
  new_params[k] = v
183
161
  if len(new_bind_params) == 0:
@@ -196,6 +174,8 @@ def _get_cursor(connectable):
196
174
  while hasattr(conn, "driver_connection") or hasattr(conn, "connection"):
197
175
  if hasattr(conn, "driver_connection"):
198
176
  conn = conn.driver_connection
177
+ elif conn.connection == conn:
178
+ break
199
179
  else:
200
180
  conn = conn.connection
201
181
  if callable(conn):
@@ -228,8 +208,11 @@ def _render_query(query: Union[SQL, Composed], connectable: Union[Engine, Connec
228
208
  return query.as_string(conn)
229
209
 
230
210
 
231
- def infer_has_server_binds(sql):
232
- return "%s" in sql or search(r"%\(\w+\)s", sql)
211
+ def infer_has_server_binds(sql) -> bool:
212
+ if "%s" in sql:
213
+ return True
214
+ res = search(r"%\(\w+\)s", sql)
215
+ return res is not None
233
216
 
234
217
 
235
218
  _default_statement_filter = lambda sql_text, params: True
@@ -254,7 +237,40 @@ def _normalize_output_args(kwargs):
254
237
  return output_mode, output_file
255
238
 
256
239
 
257
- def _run_sql(connectable, sql, params=None, **kwargs):
240
+ @dataclass
241
+ class StatementResult:
242
+ query: Any
243
+ params: Any = None
244
+ skip: bool = False
245
+ label: str | None = None
246
+
247
+ @classmethod
248
+ def skipped(cls, query=None, params=None) -> "StatementResult":
249
+ return cls(query=query, params=params, skip=True)
250
+
251
+
252
+ @dataclass
253
+ class StatementContext:
254
+ index: int
255
+ query: Any
256
+ params: Any
257
+ sql_text: str
258
+
259
+
260
+ TransformFn = Callable[[StatementContext], list[StatementResult] | None]
261
+ Connectable = Union[Engine, Connection]
262
+
263
+
264
+ def _statement_filter_to_transform(statement_filter) -> TransformFn:
265
+ def transform(ctx: StatementContext) -> list[StatementResult] | None:
266
+ if not statement_filter(ctx.sql_text, ctx.params):
267
+ return [StatementResult.skipped(query=ctx.query, params=ctx.params)]
268
+ return None
269
+
270
+ return transform
271
+
272
+
273
+ def _run_sql(connectable, sql, params=None, *, print_skipped=True, **kwargs):
258
274
  """
259
275
  Internal function for running a query on a SQLAlchemy connectable,
260
276
  which always returns an iterator. The wrapper function adds the option
@@ -267,17 +283,30 @@ def _run_sql(connectable, sql, params=None, **kwargs):
267
283
 
268
284
  stop_on_error = kwargs.pop("stop_on_error", False)
269
285
  raise_errors = kwargs.pop("raise_errors", False)
270
- has_server_binds = kwargs.pop("has_server_binds", None)
271
286
  ensure_single_query = kwargs.pop("ensure_single_query", False)
272
- statement_filter = kwargs.pop("statement_filter", _default_statement_filter)
273
287
  output_mode, output_file = _normalize_output_args(kwargs)
288
+ has_server_binds = kwargs.pop("has_server_binds", None)
289
+
290
+ statement_filter = kwargs.pop("statement_filter", None)
291
+ transform_statement: TransformFn | None = kwargs.pop("transform_statement", None)
274
292
 
275
293
  if stop_on_error:
276
294
  raise_errors = True
277
295
  warn(DeprecationWarning("stop_on_error is deprecated, use raise_errors"))
278
296
 
279
- interpret_as_file = kwargs.pop("interpret_as_file", None)
297
+ if statement_filter is not None:
298
+ warn(
299
+ DeprecationWarning(
300
+ "statement_filter is deprecated, use transform_statement"
301
+ )
302
+ )
303
+ if transform_statement is not None:
304
+ raise ValueError(
305
+ "Cannot specify both statement_filter and transform_statement"
306
+ )
307
+ transform_statement = _statement_filter_to_transform(statement_filter)
280
308
 
309
+ interpret_as_file = kwargs.pop("interpret_as_file", None)
281
310
  queries = _get_queries(sql, interpret_as_file=interpret_as_file)
282
311
 
283
312
  if queries is None:
@@ -286,81 +315,116 @@ def _run_sql(connectable, sql, params=None, **kwargs):
286
315
  if ensure_single_query and len(queries) > 1:
287
316
  raise ValueError("Multiple queries passed when only one was expected")
288
317
 
289
- # check if parameters is a list of the same length as the number of queries
290
318
  if not isinstance(params, list) or not len(params) == len(queries):
291
319
  params = [params] * len(queries)
292
320
 
293
- for query, _params in zip(queries, params):
294
- params, pre_bind_params = _split_params(_params)
321
+ for index, (query, _params) in enumerate(zip(queries, params)):
322
+ _query, sql_text = _render_query_text(connectable, query, _params)
323
+ if sql_text == "":
324
+ continue
295
325
 
296
- if pre_bind_params is not None:
297
- if not isinstance(query, SQL):
298
- query = SQL(query)
299
- # Pre-bind the parameters using PsycoPG2
300
- query = query.format(**pre_bind_params)
326
+ ctx = StatementContext(
327
+ index=index, query=query, params=_params, sql_text=sql_text
328
+ )
301
329
 
302
- if isinstance(query, (SQL, Composed)):
303
- query = _render_query(query, connectable)
330
+ results = transform_statement(ctx) if transform_statement is not None else None
304
331
 
305
- sql_text = str(query)
332
+ if results is None:
333
+ results = [StatementResult(query=query, params=_params)]
306
334
 
307
- if isinstance(query, str):
308
- sql_text = format(query, strip_comments=True).strip()
309
- if sql_text == "":
310
- continue
311
- # Check for server-bound parameters in sql native style. If there are none, use
312
- # the SQLAlchemy text() function, otherwise use the raw query string
313
- if has_server_binds is None:
314
- has_server_binds = infer_has_server_binds(sql_text)
315
-
316
- should_run = statement_filter(sql_text, params)
317
-
318
- # Shorten summary text for printing
319
- if output_mode != OutputMode.ALL:
320
- sql_text = summarize_statement(sql_text)
321
-
322
- if not should_run:
323
- secho(
324
- sql_text,
325
- dim=True,
326
- strikethrough=True,
327
- file=output_file,
335
+ for result in results:
336
+ yield from _execute_one(
337
+ connectable,
338
+ result,
339
+ output_file,
340
+ raise_errors=raise_errors,
341
+ output_mode=output_mode,
342
+ print_skipped=print_skipped,
343
+ has_server_binds=has_server_binds,
328
344
  )
329
- continue
330
345
 
331
- # This only does something for postgresql, but it's harmless to run it for other engines
332
- set_wait_callback(wait_select)
333
-
334
- try:
335
- trans = connectable.begin()
336
- except InvalidRequestError:
337
- trans = None
338
- try:
339
- log.debug("Executing SQL: \n %s", query)
340
- if has_server_binds:
341
- conn = _get_connection(connectable)
342
- res = conn.exec_driver_sql(query, params)
343
- else:
344
- if not isinstance(query, TextClause):
345
- query = text(query)
346
- res = connectable.execute(query, params)
347
- yield res
348
- if trans is not None:
349
- trans.commit()
350
- elif hasattr(connectable, "commit"):
351
- connectable.commit()
352
- secho(sql_text, dim=True, file=output_file)
353
- except Exception as err:
354
- if trans is not None:
355
- trans.rollback()
356
- elif hasattr(connectable, "rollback"):
357
- connectable.rollback()
358
- if raise_errors or _should_raise_query_error(err):
359
- raise err
360
-
361
- _print_error(sql_text, err, file=output_file)
362
- finally:
363
- set_wait_callback(None)
346
+
347
+ def _render_query_text(connectable, query, params):
348
+ params, pre_bind_params = _split_params(params)
349
+ if isinstance(query, (psql2.SQL, psql2.Composed)):
350
+ query = update_legacy_identifier(query)
351
+
352
+ if pre_bind_params is not None:
353
+ if not isinstance(query, SQL):
354
+ query = SQL(query)
355
+ # Pre-bind the parameters using psycopg
356
+ query = query.format(**pre_bind_params)
357
+
358
+ if isinstance(query, (SQL, Composed)):
359
+ query = _render_query(query, connectable)
360
+
361
+ sql_text = str(query)
362
+ if isinstance(query, str):
363
+ sql_text = format(query, strip_comments=True).strip()
364
+
365
+ return query, sql_text
366
+
367
+
368
+ def _execute_one(
369
+ connectable,
370
+ result: StatementResult,
371
+ output_file: IO,
372
+ *,
373
+ raise_errors: bool = True,
374
+ output_mode: OutputMode = OutputMode.SUMMARY,
375
+ has_server_binds: bool | None = None,
376
+ print_skipped: bool = True,
377
+ ):
378
+ params = result.params
379
+
380
+ query, sql_text = _render_query_text(connectable, result.query, params)
381
+ if has_server_binds is None:
382
+ has_server_binds = infer_has_server_binds(sql_text)
383
+
384
+ if result.label is not None:
385
+ display_text = result.label
386
+ elif output_mode != OutputMode.ALL:
387
+ display_text = summarize_statement(str(query))
388
+ else:
389
+ display_text = str(query)
390
+
391
+ if result.skip:
392
+ if print_skipped:
393
+ secho(display_text, dim=True, strikethrough=True, file=output_file)
394
+ return
395
+
396
+ try:
397
+ trans = connectable.begin()
398
+ except InvalidRequestError:
399
+ trans = None
400
+
401
+ try:
402
+ log.debug("Executing SQL: \n %s", query)
403
+ if has_server_binds:
404
+ conn = _get_connection(connectable)
405
+ res = conn.exec_driver_sql(query, params)
406
+ else:
407
+ if not isinstance(query, TextClause):
408
+ query = text(query)
409
+ res = connectable.execute(query, params)
410
+
411
+ yield res
412
+
413
+ if trans is not None:
414
+ trans.commit()
415
+ elif hasattr(connectable, "commit"):
416
+ connectable.commit()
417
+
418
+ secho(display_text, dim=True, file=output_file)
419
+
420
+ except Exception as err:
421
+ if trans is not None:
422
+ trans.rollback()
423
+ elif hasattr(connectable, "rollback"):
424
+ connectable.rollback()
425
+ if raise_errors or _should_raise_query_error(err):
426
+ raise err
427
+ _print_error(display_text, err, file=output_file)
364
428
 
365
429
 
366
430
  def _should_raise_query_error(err):
@@ -379,7 +443,7 @@ def _should_raise_query_error(err):
379
443
  # database backends.
380
444
  # Ideally we could handle operational errors more gracefully
381
445
  if (
382
- isinstance(orig_err, psycopg2.errors.QueryCanceled)
446
+ isinstance(orig_err, QueryCanceled)
383
447
  or getattr(orig_err, "pgcode", None) == "57014"
384
448
  ):
385
449
  return True
@@ -506,144 +570,13 @@ def run_sql(*args, **kwargs):
506
570
  statement_filter : Callable
507
571
  A function that takes a SQL statement and parameters and returns True if the statement
508
572
  should be run, and False if it should be skipped.
573
+ transform_statement: TransformFn | None
574
+ A function that takes a StatementContext and returns a list of StatementResult
575
+ objects, which can modify the query, parameters, and whether the statement
576
+ should be skipped or not. This allows for more complex logic than a simple
577
+ statement filter.
509
578
  """
510
579
  res = _run_sql(*args, **kwargs)
511
580
  if kwargs.pop("yield_results", False):
512
581
  return res
513
582
  return list(res)
514
-
515
-
516
- def execute(connectable, sql, params=None, stop_on_error=False, **kwargs):
517
- output_file = kwargs.pop("output_file", None)
518
- output_mode = kwargs.pop("output_mode", None)
519
- sql = format(sql, strip_comments=True).strip()
520
- if sql == "":
521
- return
522
- try:
523
- connectable.begin()
524
- res = connectable.execute(text(sql), params=params)
525
- if hasattr(connectable, "commit"):
526
- connectable.commit()
527
- pretty_print(sql, dim=True, file=output_file, mode=output_mode)
528
- return res
529
- except (ProgrammingError, IntegrityError) as err:
530
- if hasattr(connectable, "rollback"):
531
- connectable.rollback()
532
- _print_error(sql, dim=True, file=output_file, mode=output_mode)
533
- if stop_on_error:
534
- return
535
- finally:
536
- if hasattr(connectable, "close"):
537
- connectable.close()
538
-
539
-
540
- def get_or_create(session, model, defaults=None, **kwargs):
541
- """
542
- Get an instance of a model, or create it if it doesn't
543
- exist.
544
-
545
- https://stackoverflow.com/questions/2546207
546
- """
547
- instance = session.query(model).filter_by(**kwargs).first()
548
- if instance:
549
- instance._created = False
550
- return instance
551
- else:
552
- params = dict(
553
- (k, v) for k, v in kwargs.items() if not isinstance(v, ClauseElement)
554
- )
555
- params.update(defaults or {})
556
- instance = model(**params)
557
- session.add(instance)
558
- instance._created = True
559
- return instance
560
-
561
-
562
- def get_db_model(db, model_name: str):
563
- return getattr(db.model, model_name)
564
-
565
-
566
- @contextmanager
567
- def temp_database(conn_string, drop=True, ensure_empty=False):
568
- """Create a temporary database and tear it down after tests."""
569
- create_database(conn_string, exists_ok=True, replace=ensure_empty)
570
- try:
571
- yield create_engine(conn_string)
572
- finally:
573
- if drop:
574
- drop_database(conn_string)
575
-
576
-
577
- def create_database(url, **kwargs):
578
- """Create a database if it doesn't exist.
579
-
580
- Parameters
581
- ----------
582
- url : str
583
- A SQLAlchemy database URL.
584
- exists_ok : bool
585
- If True, don't raise an error if the database already exists.
586
- replace : bool
587
- If True, drop the database if it exists and create a new one.
588
- kwargs : dict
589
- Additional keyword arguments to pass to `sqlalchemy_utils.create_database`.
590
- """
591
- db_exists = database_exists(url)
592
-
593
- should_replace = kwargs.pop("replace", False)
594
- exists_ok = kwargs.pop("exists_ok", False)
595
-
596
- if should_replace and db_exists:
597
- drop_database(url)
598
- db_exists = False
599
-
600
- if exists_ok and db_exists:
601
- return
602
- _create_database(url, **kwargs)
603
-
604
-
605
- def connection_args(engine):
606
- """Get PostgreSQL connection arguments for an engine"""
607
- _psql_flags = {"-U": "username", "-h": "host", "-p": "port", "-P": "password"}
608
-
609
- if isinstance(engine, str):
610
- # We passed a connection url!
611
- engine = create_engine(engine)
612
- flags = ""
613
- for flag, _attr in _psql_flags.items():
614
- val = getattr(engine.url, _attr)
615
- if val is not None:
616
- flags += f" {flag} {val}"
617
- return flags, engine.url.database
618
-
619
-
620
- def db_isready(engine_or_url):
621
- args, _ = connection_args(engine_or_url)
622
- c = cmd("pg_isready", args, capture_output=True)
623
- return c.returncode == 0
624
-
625
-
626
- def wait_for_database(engine_or_url, quiet=False):
627
- msg = "Waiting for database..."
628
- while not db_isready(engine_or_url):
629
- if not quiet:
630
- echo(msg, err=True)
631
- log.info(msg)
632
- sleep(1)
633
-
634
-
635
- def reflect_table(engine, tablename, *column_args, **kwargs):
636
- """
637
- One-off reflection of a database table or view. Note: for most purposes,
638
- it will be better to use the database tables automapped at runtime in the
639
- `self.tables` object. However, this function can be useful for views (which
640
- are not reflected automatically), or to customize type definitions for mapped
641
- tables.
642
-
643
- A set of `column_args` can be used to pass columns to override with the mapper, for
644
- instance to set up foreign and primary key constraints.
645
- https://docs.sqlalchemy.org/en/13/core/reflection.html#reflecting-views
646
- """
647
- schema = kwargs.pop("schema", "public")
648
- meta = MetaData(schema=schema)
649
- return Table(tablename, meta, *column_args, autoload_with=engine, **kwargs)
@@ -0,0 +1,213 @@
1
+ from contextlib import contextmanager
2
+ from time import sleep
3
+
4
+ from click import echo
5
+ from sqlalchemy import MetaData
6
+ from sqlalchemy import create_engine as base_create_engine
7
+ from sqlalchemy import text
8
+ from sqlalchemy.engine import Engine
9
+ from sqlalchemy.engine.url import make_url
10
+ from sqlalchemy.exc import (
11
+ IntegrityError,
12
+ OperationalError,
13
+ ProgrammingError,
14
+ )
15
+ from sqlalchemy.orm import sessionmaker
16
+ from sqlalchemy.schema import Table
17
+ from sqlalchemy.sql.elements import ClauseElement
18
+ from sqlalchemy_utils import create_database as _create_database
19
+ from sqlalchemy_utils import database_exists, drop_database
20
+ from sqlparse import format
21
+
22
+ from macrostrat.utils import cmd, get_logger
23
+ from .query import get_sql_text
24
+
25
+ log = get_logger(__name__)
26
+
27
+ # Ensure that old import structure still works
28
+ from .query import run_sql, run_query, run_sql_file, run_fixtures # noqa: F401
29
+
30
+
31
def get_dataframe(connectable, filename_or_query, **kwargs):
    """
    Load the result of a SQL query into a `Pandas` dataframe.

    `filename_or_query` is resolved to SQL text with `get_sql_text` and run
    against `connectable` (a SQLAlchemy database object) via
    `pandas.read_sql`. Remaining keyword arguments are passed to `read_sql`.
    """
    # pandas is not a declared dependency of this package, so it is only
    # imported when a dataframe is actually requested.
    from pandas import read_sql

    return read_sql(get_sql_text(filename_or_query), connectable, **kwargs)
42
+
43
+
44
def db_session(engine):
    """Return a new ORM session bound to ``engine``.

    A fresh ``sessionmaker`` factory is built on each call. The session is
    not closed here; that is left to the caller.
    """
    return sessionmaker(bind=engine)()
47
+
48
+
49
def execute(connectable, sql, params=None, stop_on_error=False, **kwargs):
    """Execute one SQL string on ``connectable`` and commit it.

    Comments are stripped with ``sqlparse.format`` first. If nothing remains,
    ``None`` is returned without touching ``connectable``.

    Parameters
    ----------
    connectable
        SQLAlchemy connection-like object. ``begin`` and ``execute`` are
        required; ``commit``, ``rollback`` and ``close`` are called only when
        present.
    sql : str
        SQL text to run.
    params : dict, optional
        Bound parameters passed to ``execute``.
    stop_on_error : bool
        Currently has no observable effect: a failing statement returns
        ``None`` either way.
    output_file, output_mode
        Popped from ``kwargs`` and forwarded to the printing helpers. Other
        keyword arguments are ignored.

    Returns
    -------
    The object returned by ``connectable.execute``, or ``None`` when the SQL
    is empty or raised ``ProgrammingError``/``IntegrityError``. In the error
    case the transaction is rolled back (if ``rollback`` exists).

    Note that ``connectable`` is closed before returning, if it has ``close``.
    """
    output_file = kwargs.pop("output_file", None)
    output_mode = kwargs.pop("output_mode", None)
    sql = format(sql, strip_comments=True).strip()
    if sql == "":
        return

    # ``pretty_print`` and ``_print_error`` are neither defined nor imported at
    # module level here, so referencing them raised NameError. Import them
    # lazily from the sibling ``query`` module (already imported from above).
    # NOTE(review): assumes both helpers live in ``.query`` with the keyword
    # signatures used below -- confirm, and check whether ``_print_error`` is
    # meant to receive the caught exception as well.
    from .query import _print_error, pretty_print

    try:
        connectable.begin()
        res = connectable.execute(text(sql), params=params)
        if hasattr(connectable, "commit"):
            connectable.commit()
        pretty_print(sql, dim=True, file=output_file, mode=output_mode)
        return res
    except (ProgrammingError, IntegrityError):
        if hasattr(connectable, "rollback"):
            connectable.rollback()
        _print_error(sql, dim=True, file=output_file, mode=output_mode)
        # NOTE(review): both outcomes return None, so stop_on_error does not
        # change behaviour here; confirm whether it should re-raise instead.
        if stop_on_error:
            return
    finally:
        if hasattr(connectable, "close"):
            connectable.close()
71
+
72
+
73
def get_or_create(session, model, defaults=None, **kwargs):
    """
    Return the first instance of `model` matching `kwargs`, or build a new one.

    The returned object gets a `_created` attribute: False when an existing
    row was found, True when a new instance was constructed and added to
    `session` (it is not flushed or committed here). A new instance is built
    from the non-expression `kwargs` values plus `defaults`, with `defaults`
    taking precedence.

    https://stackoverflow.com/questions/2546207
    """
    existing = session.query(model).filter_by(**kwargs).first()
    if existing:
        existing._created = False
        return existing

    # SQL expressions can serve as filters but not as constructor values.
    params = {
        key: value
        for key, value in kwargs.items()
        if not isinstance(value, ClauseElement)
    }
    params.update(defaults or {})
    created = model(**params)
    session.add(created)
    created._created = True
    return created
93
+
94
+
95
def get_db_model(db, model_name: str):
    """Look up the attribute named ``model_name`` on ``db.model``.

    Raises ``AttributeError`` if no such model exists.
    """
    models = db.model
    return getattr(models, model_name)
97
+
98
+
99
@contextmanager
def temp_database(conn_string, drop=True, ensure_empty=False):
    """Create a temporary database and tear it down after tests.

    Parameters
    ----------
    conn_string : str
        Database URL.
    drop : bool
        Drop the database when the context exits.
    ensure_empty : bool
        If the database already exists, drop and recreate it first.

    Yields
    ------
    An engine created with this module's ``create_engine``.
    """
    create_database(conn_string, exists_ok=True, replace=ensure_empty)
    engine = None
    try:
        engine = create_engine(conn_string)
        yield engine
    finally:
        # Release the connections pooled by the yielded engine before the
        # database is (optionally) dropped. Previously the pool was never
        # disposed, leaving connections to a database that may be gone.
        if engine is not None:
            engine.dispose()
        if drop:
            drop_database(conn_string)
108
+
109
+
110
def create_database(url, **kwargs):
    """Create a database, optionally tolerating or replacing an existing one.

    Parameters
    ----------
    url : str
        A SQLAlchemy database URL.
    exists_ok : bool
        If True, return silently when the database already exists.
    replace : bool
        If True and the database exists, drop it and create a new one.
        Takes precedence over ``exists_ok``.
    kwargs : dict
        Additional keyword arguments to pass to `sqlalchemy_utils.create_database`.
    """
    already_there = database_exists(url)
    replace = kwargs.pop("replace", False)
    exists_ok = kwargs.pop("exists_ok", False)

    if replace and already_there:
        drop_database(url)
    elif exists_ok and already_there:
        return
    _create_database(url, **kwargs)
136
+
137
+
138
def create_engine(db_conn, **kwargs):
    """Return a SQLAlchemy engine for ``db_conn``.

    Parameters
    ----------
    db_conn : Engine, sqlalchemy URL, or str
        An existing ``Engine`` is returned unchanged (``kwargs`` are ignored).
        A warning is logged if it uses the psycopg2 driver. A string is
        parsed with ``make_url``.
    kwargs : dict
        Passed to ``sqlalchemy.create_engine``.

    PostgreSQL URLs (``postgresql``/``postgres`` backends) are switched to
    the ``postgresql+psycopg`` (psycopg 3) driver. URLs for other backends
    are left untouched. Previously every URL, e.g. ``sqlite://``, was
    rewritten to PostgreSQL.
    """
    if isinstance(db_conn, Engine):
        log.info(f"Set up database connection with engine {db_conn.url}")
        if db_conn.driver == "psycopg2":
            log.warning(
                "The psycopg2 driver is deprecated. Please use psycopg3 instead."
            )
        return db_conn

    url = make_url(db_conn) if isinstance(db_conn, str) else db_conn
    # Log the URL with its password masked; the raw string used to be logged
    # verbatim, leaking credentials into log output.
    log.info(
        "Setting up database connection with URL "
        f"'{url.render_as_string(hide_password=True)}'"
    )
    # Default PostgreSQL connections to psycopg 3; leave other backends alone.
    is_postgres = url.get_backend_name() in ("postgresql", "postgres")
    if is_postgres and url.drivername != "postgresql+psycopg":
        url = url.set(drivername="postgresql+psycopg")

    return base_create_engine(url, **kwargs)
156
+
157
+
158
def connection_args(engine, with_password=False):
    """Get PostgreSQL command-line connection arguments for an engine.

    Parameters
    ----------
    engine : Engine, sqlalchemy URL, or str
        Objects exposing ``.url`` (such as an ``Engine``) are read directly.
        Anything else is parsed with ``make_url``. This avoids building a
        whole engine, and importing its DBAPI driver, just to read URL fields.
    with_password : bool
        Also emit the password as a ``-P`` flag.
        NOTE(review): standard libpq tools such as ``psql`` and
        ``pg_isready`` do not take a password via ``-P`` (``PGPASSWORD`` is
        the usual mechanism). Confirm which command this is intended for.

    Returns
    -------
    tuple
        ``(flags, database)``. ``flags`` contains `` <flag> <value>`` for each
        non-``None`` component, in ``-U``, ``-h``, ``-p`` (, ``-P``) order.
        Values are not shell-quoted. ``database`` is the URL's database name
        (may be ``None``).
    """
    url = engine.url if hasattr(engine, "url") else make_url(engine)
    values = {"-U": url.username, "-h": url.host, "-p": url.port}
    if with_password:
        values["-P"] = url.password
    flags = "".join(
        f" {flag} {val}" for flag, val in values.items() if val is not None
    )
    return flags, url.database
173
+
174
+
175
def db_isready(engine_or_url, use_shell_command=False):
    """Return True if the database accepts connections, False otherwise.

    Parameters
    ----------
    engine_or_url : Engine, sqlalchemy URL, or str
        Database to check.
    use_shell_command : bool
        If True, run the ``pg_isready`` command-line tool. Otherwise open a
        SQLAlchemy connection and run ``SELECT 1``. In that mode, an
        ``OperationalError`` means "not ready".
    """
    if use_shell_command:
        # pg_isready only takes user/host/port (it does not authenticate) and
        # has no -P option. Forwarding the password made it fail with an
        # invalid-option error every time, so callers waited forever.
        args, _ = connection_args(engine_or_url, with_password=False)
        c = cmd("pg_isready", args, capture_output=True)
        return c.returncode == 0

    # Use a more typical sqlalchemy connection approach
    engine = create_engine(engine_or_url)
    # create_engine returns a caller-supplied Engine unchanged; only dispose
    # engines created here, so polling does not leak one pool per call.
    owns_engine = engine is not engine_or_url
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return True
    except OperationalError:
        return False
    finally:
        if owns_engine:
            engine.dispose()
188
+
189
+
190
def wait_for_database(
    engine_or_url, *, quiet=False, use_shell_command=False, timeout=None, interval=1
):
    """Block until ``db_isready`` reports that the database is ready.

    Parameters
    ----------
    engine_or_url : Engine, sqlalchemy URL, or str
        Database to poll.
    quiet : bool
        Suppress the stderr notice. The notice is always logged at INFO level.
    use_shell_command : bool
        Passed to ``db_isready``.
    timeout : float or None
        Maximum number of seconds to wait. ``None`` (default) waits forever,
        as before.
    interval : float
        Seconds to sleep between polls (default 1, the previous fixed value).

    Raises
    ------
    TimeoutError
        If ``timeout`` is set and the database is still not ready once it
        has elapsed.
    """
    from time import monotonic

    msg = "Waiting for database..."
    deadline = None if timeout is None else monotonic() + timeout
    while not db_isready(engine_or_url, use_shell_command=use_shell_command):
        if deadline is not None and monotonic() >= deadline:
            raise TimeoutError(f"Database was not ready after {timeout} seconds")
        if not quiet:
            echo(msg, err=True)
        log.info(msg)
        sleep(interval)
197
+
198
+
199
def reflect_table(engine, tablename, *column_args, **kwargs):
    """
    Reflect a single database table or view, one time.

    For most purposes the tables automapped at runtime in the `self.tables`
    object are preferable. This helper is useful for views (which are not
    reflected automatically) or to customize type definitions for mapped
    tables.

    Extra positional `column_args` override reflected columns, for instance to
    declare foreign and primary key constraints. The `schema` keyword
    (default "public") selects the schema; any other keyword arguments are
    forwarded to `Table`.
    https://docs.sqlalchemy.org/en/13/core/reflection.html#reflecting-views
    """
    metadata = MetaData(schema=kwargs.pop("schema", "public"))
    return Table(tablename, metadata, *column_args, autoload_with=engine, **kwargs)
@@ -0,0 +1,44 @@
1
+ [project]
2
+ name = "macrostrat.database"
3
+ version = "4.0.1"
4
+ description = "A SQLAlchemy-based database toolkit."
5
+ authors = [{ name = "Daven Quinn", email = "dev@davenquinn.com" }]
6
+ requires-python = ">=3.10,<4"
7
+ classifiers = [
8
+ "Programming Language :: Python :: 3",
9
+ "Programming Language :: Python :: 3.10",
10
+ "Programming Language :: Python :: 3.11",
11
+ "Programming Language :: Python :: 3.12",
12
+ "Programming Language :: Python :: 3.13",
13
+ "Programming Language :: Python :: 3.14",
14
+ ]
15
+ dependencies = [
16
+ "GeoAlchemy2>=0.15.2,<0.16",
17
+ "SQLAlchemy>=2.0.18,<3",
18
+ "SQLAlchemy-Utils>=0.41.1,<0.42",
19
+ "click>=8.1.3,<9",
20
+ "macrostrat.utils>=1.3.3,<2",
21
+ "sqlparse>=0.5.1,<0.6",
22
+ "aiofiles>=23.2.1,<24",
23
+ "rich>=13.7.1,<14",
24
+ "psycopg>=3.2.1,<4",
25
+ "psycopg2>=2.9.11,<3",
26
+ ]
27
+
28
+ [dependency-groups]
29
+ dev = ["macrostrat.utils"]
30
+
31
+ [tool.uv]
32
+ default-groups = "all"
33
+
34
+ [tool.uv.sources]
35
+ "macrostrat.utils" = { path = "../utils", editable = true }
36
+
37
+ [tool.uv.build-backend]
38
+ module-name = ["macrostrat"]
39
+ module-root = ""
40
+ namespace = true
41
+
42
+ [build-system]
43
+ requires = ["uv_build>=0.9.21,<0.12.0"]
44
+ build-backend = "uv_build"
@@ -1,20 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: macrostrat.database
3
- Version: 3.5.4
4
- Summary: A SQLAlchemy-based database toolkit.
5
- Author: Daven Quinn
6
- Author-email: dev@davenquinn.com
7
- Requires-Python: >=3.10,<4.0
8
- Classifier: Programming Language :: Python :: 3
9
- Classifier: Programming Language :: Python :: 3.10
10
- Classifier: Programming Language :: Python :: 3.11
11
- Classifier: Programming Language :: Python :: 3.12
12
- Requires-Dist: GeoAlchemy2 (>=0.15.2,<0.16.0)
13
- Requires-Dist: SQLAlchemy (>=2.0.18,<3.0.0)
14
- Requires-Dist: SQLAlchemy-Utils (>=0.41.1,<0.42.0)
15
- Requires-Dist: aiofiles (>=23.2.1,<24.0.0)
16
- Requires-Dist: click (>=8.1.3,<9.0.0)
17
- Requires-Dist: macrostrat.utils (>=1.3.0,<2.0.0)
18
- Requires-Dist: psycopg2-binary (>=2.9.6,<3.0.0)
19
- Requires-Dist: rich (>=13.7.1,<14.0.0)
20
- Requires-Dist: sqlparse (>=0.5.1,<0.6.0)
@@ -1,25 +0,0 @@
1
- [tool.poetry]
2
- authors = ["Daven Quinn <dev@davenquinn.com>"]
3
- description = "A SQLAlchemy-based database toolkit."
4
- name = "macrostrat.database"
5
- packages = [{ include = "macrostrat" }]
6
- version = "3.5.4"
7
-
8
- [tool.poetry.dependencies]
9
- GeoAlchemy2 = "^0.15.2"
10
- SQLAlchemy = "^2.0.18"
11
- SQLAlchemy-Utils = "^0.41.1"
12
- click = "^8.1.3"
13
- "macrostrat.utils" = "^1.3.0"
14
- psycopg2-binary = "^2.9.6"
15
- python = "^3.10"
16
- sqlparse = "^0.5.1"
17
- aiofiles = "^23.2.1"
18
- rich = "^13.7.1"
19
-
20
- [tool.poetry.dev-dependencies]
21
- "macrostrat.utils" = { path = "../utils", develop = true }
22
-
23
- [build-system]
24
- build-backend = "poetry.core.masonry.api"
25
- requires = ["poetry-core>=1.0.0"]