macrostrat.database 4.0.0__tar.gz → 4.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (20) hide show
  1. macrostrat_database-4.0.1/PKG-INFO +23 -0
  2. macrostrat_database-4.0.1/macrostrat/.DS_Store +0 -0
  3. {macrostrat_database-4.0.0 → macrostrat_database-4.0.1}/macrostrat/database/__init__.py +2 -1
  4. macrostrat_database-4.0.1/macrostrat/database/compat.py +36 -0
  5. {macrostrat_database-4.0.0 → macrostrat_database-4.0.1}/macrostrat/database/postgresql.py +36 -2
  6. macrostrat_database-4.0.0/macrostrat/database/utils.py → macrostrat_database-4.0.1/macrostrat/database/query.py +176 -271
  7. macrostrat_database-4.0.1/macrostrat/database/utils.py +213 -0
  8. macrostrat_database-4.0.1/pyproject.toml +44 -0
  9. macrostrat_database-4.0.0/PKG-INFO +0 -22
  10. macrostrat_database-4.0.0/pyproject.toml +0 -26
  11. {macrostrat_database-4.0.0 → macrostrat_database-4.0.1}/macrostrat/database/mapper/__init__.py +0 -0
  12. {macrostrat_database-4.0.0 → macrostrat_database-4.0.1}/macrostrat/database/mapper/base.py +0 -0
  13. {macrostrat_database-4.0.0 → macrostrat_database-4.0.1}/macrostrat/database/mapper/cache.py +0 -0
  14. {macrostrat_database-4.0.0 → macrostrat_database-4.0.1}/macrostrat/database/mapper/utils.py +0 -0
  15. {macrostrat_database-4.0.0 → macrostrat_database-4.0.1}/macrostrat/database/transfer/__init__.py +0 -0
  16. {macrostrat_database-4.0.0 → macrostrat_database-4.0.1}/macrostrat/database/transfer/dump_database.py +0 -0
  17. {macrostrat_database-4.0.0 → macrostrat_database-4.0.1}/macrostrat/database/transfer/move_tables.py +0 -0
  18. {macrostrat_database-4.0.0 → macrostrat_database-4.0.1}/macrostrat/database/transfer/restore_database.py +0 -0
  19. {macrostrat_database-4.0.0 → macrostrat_database-4.0.1}/macrostrat/database/transfer/stream_utils.py +0 -0
  20. {macrostrat_database-4.0.0 → macrostrat_database-4.0.1}/macrostrat/database/transfer/utils.py +0 -0
@@ -0,0 +1,23 @@
1
+ Metadata-Version: 2.3
2
+ Name: macrostrat.database
3
+ Version: 4.0.1
4
+ Summary: A SQLAlchemy-based database toolkit.
5
+ Author: Daven Quinn
6
+ Author-email: Daven Quinn <dev@davenquinn.com>
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Programming Language :: Python :: 3.10
9
+ Classifier: Programming Language :: Python :: 3.11
10
+ Classifier: Programming Language :: Python :: 3.12
11
+ Classifier: Programming Language :: Python :: 3.13
12
+ Classifier: Programming Language :: Python :: 3.14
13
+ Requires-Dist: geoalchemy2>=0.15.2,<0.16
14
+ Requires-Dist: sqlalchemy>=2.0.18,<3
15
+ Requires-Dist: sqlalchemy-utils>=0.41.1,<0.42
16
+ Requires-Dist: click>=8.1.3,<9
17
+ Requires-Dist: macrostrat-utils>=1.3.3,<2
18
+ Requires-Dist: sqlparse>=0.5.1,<0.6
19
+ Requires-Dist: aiofiles>=23.2.1,<24
20
+ Requires-Dist: rich>=13.7.1,<14
21
+ Requires-Dist: psycopg>=3.2.1,<4
22
+ Requires-Dist: psycopg2>=2.9.11,<3
23
+ Requires-Python: >=3.10, <4
@@ -12,10 +12,12 @@ from sqlalchemy.orm import Session, scoped_session, sessionmaker
12
12
  from sqlalchemy.sql.expression import Insert
13
13
 
14
14
  from macrostrat.utils import get_logger
15
+
15
16
  from .mapper import DatabaseMapper
16
17
  from .postgresql import on_conflict, prefix_inserts # noqa
17
18
  from .utils import ( # noqa
18
19
  create_database,
20
+ create_engine,
19
21
  database_exists,
20
22
  drop_database,
21
23
  get_dataframe,
@@ -24,7 +26,6 @@ from .utils import ( # noqa
24
26
  run_fixtures,
25
27
  run_query,
26
28
  run_sql,
27
- create_engine,
28
29
  )
29
30
 
30
31
  metadata = MetaData()
@@ -0,0 +1,36 @@
1
+ from warnings import warn
2
+
3
+ import psycopg.sql as psql3
4
+ import psycopg2.sql as psql2
5
+
6
+
7
+ def update_legacy_identifier(identifier):
8
+ """
9
+ For backwards compatibility with current code, we need to map psycopg2 identifiers to their equivalents in psycopg3,
10
+ while printing a warning that the mapping is deprecated.
11
+ :param identifier:
12
+ :return: psycopg3 equivalent of identifier
13
+ """
14
+ new_identifier = _map_psycopg2_identifier_to_psycopg3_identifier_internal(
15
+ identifier
16
+ )
17
+ if new_identifier is not identifier:
18
+ warn(
19
+ "psycopg2 identifiers are deprecated. Please use psycopg3 identifiers instead.",
20
+ DeprecationWarning,
21
+ )
22
+ return new_identifier
23
+
24
+
25
+ def _map_psycopg2_identifier_to_psycopg3_identifier_internal(identifier):
26
+ if isinstance(identifier, psql2.Identifier):
27
+ return psql3.Identifier(*identifier._wrapped)
28
+ if isinstance(identifier, psql2.SQL):
29
+ return psql3.SQL(identifier._wrapped)
30
+ if isinstance(identifier, psql2.Literal):
31
+ return psql3.Literal(identifier._wrapped)
32
+ if isinstance(identifier, psql2.Placeholder):
33
+ return psql3.Placeholder(identifier._obj)
34
+ if isinstance(identifier, psql2.Composed):
35
+ return psql3.Composed(identifier._obj)
36
+ return identifier
@@ -2,7 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  from contextlib import contextmanager
4
4
  from contextvars import ContextVar
5
- from typing import TYPE_CHECKING
5
+ from enum import Enum
6
+ from typing import Any, Sequence, TYPE_CHECKING
6
7
 
7
8
  from sqlalchemy.dialects import postgresql
8
9
  from sqlalchemy.exc import CompileError
@@ -13,7 +14,14 @@ from sqlalchemy.sql.expression import text
13
14
  if TYPE_CHECKING:
14
15
  from ..database import Database
15
16
 
16
- _insert_mode = ContextVar("insert-mode", default="do-nothing")
17
+
18
+ class OnConflictAction(str, Enum):
19
+ DO_NOTHING = "do-nothing"
20
+ DO_UPDATE = "do-update"
21
+ RESTRICT = "restrict"
22
+
23
+
24
+ _insert_mode = ContextVar("insert-mode", default="restrict")
17
25
 
18
26
 
19
27
  # https://stackoverflow.com/questions/33307250/postgresql-on-conflict-in-sqlalchemy/62305344#62305344
@@ -61,6 +69,32 @@ def prefix_inserts(insert, compiler, **kw):
61
69
  return compiler.visit_insert(insert, **kw)
62
70
 
63
71
 
72
+ def upsert(
73
+ table,
74
+ values: dict[str, Any],
75
+ index_elements: Sequence[str],
76
+ *,
77
+ on_conflict: OnConflictAction = "do-update",
78
+ ):
79
+ stmt = postgresql.insert(table).values(values)
80
+ if on_conflict == "restrict":
81
+ return stmt
82
+
83
+ if on_conflict == "do-nothing":
84
+ return stmt.on_conflict_do_nothing(index_elements=list(index_elements))
85
+
86
+ update_values = {
87
+ column.name: getattr(stmt.excluded, column.name)
88
+ for column in table.columns
89
+ if column.name in values and column.name not in index_elements
90
+ }
91
+
92
+ return stmt.on_conflict_do_update(
93
+ index_elements=list(index_elements),
94
+ set_=update_values,
95
+ )
96
+
97
+
64
98
  def table_exists(db: Database, table_name: str, schema: str = "public") -> bool:
65
99
  """Check if a table exists in a PostgreSQL database."""
66
100
  sql = """SELECT EXISTS (
@@ -1,20 +1,19 @@
1
1
  import os
2
- from contextlib import contextmanager
2
+ from dataclasses import dataclass
3
3
  from enum import Enum
4
4
  from pathlib import Path
5
5
  from re import search
6
6
  from sys import stderr
7
- from typing import IO, Union
7
+ from typing import Callable, Any, IO, Union
8
8
  from warnings import warn
9
9
 
10
- from click import echo, secho
10
+ import psycopg2.sql as psql2
11
+ from click import secho
11
12
  from psycopg.errors import QueryCanceled
12
13
  from psycopg.sql import SQL, Composable, Composed
13
14
  from rich.console import Console
14
- from sqlalchemy import MetaData, text
15
- from sqlalchemy import create_engine as base_create_engine
15
+ from sqlalchemy import text
16
16
  from sqlalchemy.engine import Connection, Engine
17
- from sqlalchemy.engine.url import make_url
18
17
  from sqlalchemy.exc import (
19
18
  IntegrityError,
20
19
  InternalError,
@@ -22,24 +21,17 @@ from sqlalchemy.exc import (
22
21
  OperationalError,
23
22
  ProgrammingError,
24
23
  )
25
- from sqlalchemy.orm import sessionmaker
26
- from sqlalchemy.schema import Table
27
- from sqlalchemy.sql.elements import ClauseElement, TextClause
28
- from sqlalchemy_utils import create_database as _create_database
29
- from sqlalchemy_utils import database_exists, drop_database
24
+ from sqlalchemy.sql.elements import TextClause
30
25
  from sqlparse import format, split
31
- from time import sleep
32
26
 
33
- from macrostrat.utils import cmd, get_logger
27
+ from macrostrat.database.compat import (
28
+ update_legacy_identifier,
29
+ )
30
+ from macrostrat.utils import get_logger
34
31
 
35
32
  log = get_logger(__name__)
36
33
 
37
34
 
38
- def db_session(engine):
39
- factory = sessionmaker(bind=engine)
40
- return factory()
41
-
42
-
43
35
  def infer_is_sql_text(_string: str) -> bool:
44
36
  """
45
37
  Return True if the string is a valid SQL query,
@@ -74,19 +66,6 @@ def canonicalize_query(file_or_text: Union[str, Path, IO]) -> Union[str, Path]:
74
66
  return file_or_text
75
67
 
76
68
 
77
- def get_dataframe(connectable, filename_or_query, **kwargs):
78
- """
79
- Run a query on a SQL database (represented by
80
- a SQLAlchemy database object) and turn it into a
81
- `Pandas` dataframe.
82
- """
83
- from pandas import read_sql
84
-
85
- sql = get_sql_text(filename_or_query)
86
-
87
- return read_sql(sql, connectable, **kwargs)
88
-
89
-
90
69
  def pretty_print(sql, **kwargs):
91
70
  """Print and optionally summarize an SQL query"""
92
71
  summarize = kwargs.pop("summarize", True)
@@ -117,6 +96,7 @@ def summarize_statement(sql):
117
96
  if not line.startswith(i):
118
97
  continue
119
98
  return line.split("(")[0].strip().rstrip(";").replace(" AS", "")
99
+ return sql.strip().split("\n")[0].strip().rstrip(";")
120
100
 
121
101
 
122
102
  def get_sql_text(sql, interpret_as_file=None, echo_file_name=True):
@@ -139,9 +119,7 @@ def _get_queries(sql, interpret_as_file=None):
139
119
  for i in sql:
140
120
  queries.extend(_get_queries(i, interpret_as_file=interpret_as_file))
141
121
  return queries
142
- if isinstance(sql, TextClause):
143
- return [sql]
144
- if isinstance(sql, SQL):
122
+ if isinstance(sql, (SQL, psql2.SQL, TextClause)):
145
123
  return [sql]
146
124
 
147
125
  if sql in [None, ""]:
@@ -158,7 +136,7 @@ def _get_queries(sql, interpret_as_file=None):
158
136
 
159
137
 
160
138
  def _is_prebind_param(param):
161
- return isinstance(param, Composable)
139
+ return isinstance(param, (Composable, psql2.Composable))
162
140
 
163
141
 
164
142
  def _split_params(params):
@@ -169,7 +147,7 @@ def _split_params(params):
169
147
  if isinstance(params, (list, tuple)):
170
148
  for i in params:
171
149
  if _is_prebind_param(i):
172
- new_bind_params.append(i)
150
+ new_bind_params.append(update_legacy_identifier(i))
173
151
  else:
174
152
  new_params.append(i)
175
153
  elif isinstance(params, dict):
@@ -177,7 +155,7 @@ def _split_params(params):
177
155
  new_bind_params = {}
178
156
  for k, v in params.items():
179
157
  if _is_prebind_param(v):
180
- new_bind_params[k] = v
158
+ new_bind_params[k] = update_legacy_identifier(v)
181
159
  else:
182
160
  new_params[k] = v
183
161
  if len(new_bind_params) == 0:
@@ -230,8 +208,11 @@ def _render_query(query: Union[SQL, Composed], connectable: Union[Engine, Connec
230
208
  return query.as_string(conn)
231
209
 
232
210
 
233
- def infer_has_server_binds(sql):
234
- return "%s" in sql or search(r"%\(\w+\)s", sql)
211
+ def infer_has_server_binds(sql) -> bool:
212
+ if "%s" in sql:
213
+ return True
214
+ res = search(r"%\(\w+\)s", sql)
215
+ return res is not None
235
216
 
236
217
 
237
218
  _default_statement_filter = lambda sql_text, params: True
@@ -256,7 +237,40 @@ def _normalize_output_args(kwargs):
256
237
  return output_mode, output_file
257
238
 
258
239
 
259
- def _run_sql(connectable, sql, params=None, **kwargs):
240
+ @dataclass
241
+ class StatementResult:
242
+ query: Any
243
+ params: Any = None
244
+ skip: bool = False
245
+ label: str | None = None
246
+
247
+ @classmethod
248
+ def skipped(cls, query=None, params=None) -> "StatementResult":
249
+ return cls(query=query, params=params, skip=True)
250
+
251
+
252
+ @dataclass
253
+ class StatementContext:
254
+ index: int
255
+ query: Any
256
+ params: Any
257
+ sql_text: str
258
+
259
+
260
+ TransformFn = Callable[[StatementContext], list[StatementResult] | None]
261
+ Connectable = Union[Engine, Connection]
262
+
263
+
264
+ def _statement_filter_to_transform(statement_filter) -> TransformFn:
265
+ def transform(ctx: StatementContext) -> list[StatementResult] | None:
266
+ if not statement_filter(ctx.sql_text, ctx.params):
267
+ return [StatementResult.skipped(query=ctx.query, params=ctx.params)]
268
+ return None
269
+
270
+ return transform
271
+
272
+
273
+ def _run_sql(connectable, sql, params=None, *, print_skipped=True, **kwargs):
260
274
  """
261
275
  Internal function for running a query on a SQLAlchemy connectable,
262
276
  which always returns an iterator. The wrapper function adds the option
@@ -269,17 +283,30 @@ def _run_sql(connectable, sql, params=None, **kwargs):
269
283
 
270
284
  stop_on_error = kwargs.pop("stop_on_error", False)
271
285
  raise_errors = kwargs.pop("raise_errors", False)
272
- has_server_binds = kwargs.pop("has_server_binds", None)
273
286
  ensure_single_query = kwargs.pop("ensure_single_query", False)
274
- statement_filter = kwargs.pop("statement_filter", _default_statement_filter)
275
287
  output_mode, output_file = _normalize_output_args(kwargs)
288
+ has_server_binds = kwargs.pop("has_server_binds", None)
289
+
290
+ statement_filter = kwargs.pop("statement_filter", None)
291
+ transform_statement: TransformFn | None = kwargs.pop("transform_statement", None)
276
292
 
277
293
  if stop_on_error:
278
294
  raise_errors = True
279
295
  warn(DeprecationWarning("stop_on_error is deprecated, use raise_errors"))
280
296
 
281
- interpret_as_file = kwargs.pop("interpret_as_file", None)
297
+ if statement_filter is not None:
298
+ warn(
299
+ DeprecationWarning(
300
+ "statement_filter is deprecated, use transform_statement"
301
+ )
302
+ )
303
+ if transform_statement is not None:
304
+ raise ValueError(
305
+ "Cannot specify both statement_filter and transform_statement"
306
+ )
307
+ transform_statement = _statement_filter_to_transform(statement_filter)
282
308
 
309
+ interpret_as_file = kwargs.pop("interpret_as_file", None)
283
310
  queries = _get_queries(sql, interpret_as_file=interpret_as_file)
284
311
 
285
312
  if queries is None:
@@ -288,78 +315,116 @@ def _run_sql(connectable, sql, params=None, **kwargs):
288
315
  if ensure_single_query and len(queries) > 1:
289
316
  raise ValueError("Multiple queries passed when only one was expected")
290
317
 
291
- # check if parameters is a list of the same length as the number of queries
292
318
  if not isinstance(params, list) or not len(params) == len(queries):
293
319
  params = [params] * len(queries)
294
320
 
295
- for query, _params in zip(queries, params):
296
- params, pre_bind_params = _split_params(_params)
321
+ for index, (query, _params) in enumerate(zip(queries, params)):
322
+ _query, sql_text = _render_query_text(connectable, query, _params)
323
+ if sql_text == "":
324
+ continue
297
325
 
298
- if pre_bind_params is not None:
299
- if not isinstance(query, SQL):
300
- query = SQL(query)
301
- # Pre-bind the parameters using psycopg
302
- query = query.format(**pre_bind_params)
326
+ ctx = StatementContext(
327
+ index=index, query=query, params=_params, sql_text=sql_text
328
+ )
303
329
 
304
- if isinstance(query, (SQL, Composed)):
305
- query = _render_query(query, connectable)
330
+ results = transform_statement(ctx) if transform_statement is not None else None
306
331
 
307
- sql_text = str(query)
332
+ if results is None:
333
+ results = [StatementResult(query=query, params=_params)]
308
334
 
309
- if isinstance(query, str):
310
- sql_text = format(query, strip_comments=True).strip()
311
- if sql_text == "":
312
- continue
313
- # Check for server-bound parameters in sql native style. If there are none, use
314
- # the SQLAlchemy text() function, otherwise use the raw query string
315
- if has_server_binds is None:
316
- has_server_binds = infer_has_server_binds(sql_text)
317
-
318
- should_run = statement_filter(sql_text, params)
319
-
320
- # Shorten summary text for printing
321
- if output_mode != OutputMode.ALL:
322
- sql_text = summarize_statement(sql_text)
323
-
324
- if not should_run:
325
- secho(
326
- sql_text,
327
- dim=True,
328
- strikethrough=True,
329
- file=output_file,
335
+ for result in results:
336
+ yield from _execute_one(
337
+ connectable,
338
+ result,
339
+ output_file,
340
+ raise_errors=raise_errors,
341
+ output_mode=output_mode,
342
+ print_skipped=print_skipped,
343
+ has_server_binds=has_server_binds,
330
344
  )
331
- continue
332
345
 
333
- try:
334
- trans = connectable.begin()
335
- except InvalidRequestError:
336
- trans = None
337
- try:
338
- log.debug("Executing SQL: \n %s", query)
339
- if has_server_binds:
340
- conn = _get_connection(connectable)
341
- res = conn.exec_driver_sql(query, params)
342
- else:
343
- if not isinstance(query, TextClause):
344
- query = text(query)
345
- res = connectable.execute(query, params)
346
- yield res
347
- if trans is not None:
348
- trans.commit()
349
- elif hasattr(connectable, "commit"):
350
- connectable.commit()
351
- secho(sql_text, dim=True, file=output_file)
352
- except Exception as err:
353
- if trans is not None:
354
- trans.rollback()
355
- elif hasattr(connectable, "rollback"):
356
- connectable.rollback()
357
- if raise_errors or _should_raise_query_error(err):
358
- raise err
359
-
360
- _print_error(sql_text, err, file=output_file)
361
- finally:
362
- pass
346
+
347
+ def _render_query_text(connectable, query, params):
348
+ params, pre_bind_params = _split_params(params)
349
+ if isinstance(query, (psql2.SQL, psql2.Composed)):
350
+ query = update_legacy_identifier(query)
351
+
352
+ if pre_bind_params is not None:
353
+ if not isinstance(query, SQL):
354
+ query = SQL(query)
355
+ # Pre-bind the parameters using psycopg
356
+ query = query.format(**pre_bind_params)
357
+
358
+ if isinstance(query, (SQL, Composed)):
359
+ query = _render_query(query, connectable)
360
+
361
+ sql_text = str(query)
362
+ if isinstance(query, str):
363
+ sql_text = format(query, strip_comments=True).strip()
364
+
365
+ return query, sql_text
366
+
367
+
368
+ def _execute_one(
369
+ connectable,
370
+ result: StatementResult,
371
+ output_file: IO,
372
+ *,
373
+ raise_errors: bool = True,
374
+ output_mode: OutputMode = OutputMode.SUMMARY,
375
+ has_server_binds: bool | None = None,
376
+ print_skipped: bool = True,
377
+ ):
378
+ params = result.params
379
+
380
+ query, sql_text = _render_query_text(connectable, result.query, params)
381
+ if has_server_binds is None:
382
+ has_server_binds = infer_has_server_binds(sql_text)
383
+
384
+ if result.label is not None:
385
+ display_text = result.label
386
+ elif output_mode != OutputMode.ALL:
387
+ display_text = summarize_statement(str(query))
388
+ else:
389
+ display_text = str(query)
390
+
391
+ if result.skip:
392
+ if print_skipped:
393
+ secho(display_text, dim=True, strikethrough=True, file=output_file)
394
+ return
395
+
396
+ try:
397
+ trans = connectable.begin()
398
+ except InvalidRequestError:
399
+ trans = None
400
+
401
+ try:
402
+ log.debug("Executing SQL: \n %s", query)
403
+ if has_server_binds:
404
+ conn = _get_connection(connectable)
405
+ res = conn.exec_driver_sql(query, params)
406
+ else:
407
+ if not isinstance(query, TextClause):
408
+ query = text(query)
409
+ res = connectable.execute(query, params)
410
+
411
+ yield res
412
+
413
+ if trans is not None:
414
+ trans.commit()
415
+ elif hasattr(connectable, "commit"):
416
+ connectable.commit()
417
+
418
+ secho(display_text, dim=True, file=output_file)
419
+
420
+ except Exception as err:
421
+ if trans is not None:
422
+ trans.rollback()
423
+ elif hasattr(connectable, "rollback"):
424
+ connectable.rollback()
425
+ if raise_errors or _should_raise_query_error(err):
426
+ raise err
427
+ _print_error(display_text, err, file=output_file)
363
428
 
364
429
 
365
430
  def _should_raise_query_error(err):
@@ -505,173 +570,13 @@ def run_sql(*args, **kwargs):
505
570
  statement_filter : Callable
506
571
  A function that takes a SQL statement and parameters and returns True if the statement
507
572
  should be run, and False if it should be skipped.
573
+ transform_statement: TransformFn | None
574
+ A function that takes a StatementContext and returns a list of StatementResult
575
+ objects, which can modify the query, parameters, and whether the statement
576
+ should be skipped or not. This allows for more complex logic than a simple
577
+ statement filter.
508
578
  """
509
579
  res = _run_sql(*args, **kwargs)
510
580
  if kwargs.pop("yield_results", False):
511
581
  return res
512
582
  return list(res)
513
-
514
-
515
- def execute(connectable, sql, params=None, stop_on_error=False, **kwargs):
516
- output_file = kwargs.pop("output_file", None)
517
- output_mode = kwargs.pop("output_mode", None)
518
- sql = format(sql, strip_comments=True).strip()
519
- if sql == "":
520
- return
521
- try:
522
- connectable.begin()
523
- res = connectable.execute(text(sql), params=params)
524
- if hasattr(connectable, "commit"):
525
- connectable.commit()
526
- pretty_print(sql, dim=True, file=output_file, mode=output_mode)
527
- return res
528
- except (ProgrammingError, IntegrityError) as err:
529
- if hasattr(connectable, "rollback"):
530
- connectable.rollback()
531
- _print_error(sql, dim=True, file=output_file, mode=output_mode)
532
- if stop_on_error:
533
- return
534
- finally:
535
- if hasattr(connectable, "close"):
536
- connectable.close()
537
-
538
-
539
- def get_or_create(session, model, defaults=None, **kwargs):
540
- """
541
- Get an instance of a model, or create it if it doesn't
542
- exist.
543
-
544
- https://stackoverflow.com/questions/2546207
545
- """
546
- instance = session.query(model).filter_by(**kwargs).first()
547
- if instance:
548
- instance._created = False
549
- return instance
550
- else:
551
- params = dict(
552
- (k, v) for k, v in kwargs.items() if not isinstance(v, ClauseElement)
553
- )
554
- params.update(defaults or {})
555
- instance = model(**params)
556
- session.add(instance)
557
- instance._created = True
558
- return instance
559
-
560
-
561
- def get_db_model(db, model_name: str):
562
- return getattr(db.model, model_name)
563
-
564
-
565
- @contextmanager
566
- def temp_database(conn_string, drop=True, ensure_empty=False):
567
- """Create a temporary database and tear it down after tests."""
568
- create_database(conn_string, exists_ok=True, replace=ensure_empty)
569
- try:
570
- yield create_engine(conn_string)
571
- finally:
572
- if drop:
573
- drop_database(conn_string)
574
-
575
-
576
- def create_database(url, **kwargs):
577
- """Create a database if it doesn't exist.
578
-
579
- Parameters
580
- ----------
581
- url : str
582
- A SQLAlchemy database URL.
583
- exists_ok : bool
584
- If True, don't raise an error if the database already exists.
585
- replace : bool
586
- If True, drop the database if it exists and create a new one.
587
- kwargs : dict
588
- Additional keyword arguments to pass to `sqlalchemy_utils.create_database`.
589
- """
590
- db_exists = database_exists(url)
591
-
592
- should_replace = kwargs.pop("replace", False)
593
- exists_ok = kwargs.pop("exists_ok", False)
594
-
595
- if should_replace and db_exists:
596
- drop_database(url)
597
- db_exists = False
598
-
599
- if exists_ok and db_exists:
600
- return
601
- _create_database(url, **kwargs)
602
-
603
-
604
- def create_engine(db_conn, **kwargs):
605
- if isinstance(db_conn, Engine):
606
- log.info(f"Set up database connection with engine {db_conn.url}")
607
- if db_conn.driver == "psycopg2":
608
- log.warning(
609
- "The psycopg2 driver is deprecated. Please use psycopg3 instead."
610
- )
611
- return db_conn
612
- else:
613
- log.info(f"Setting up database connection with URL '{db_conn}'")
614
- url = db_conn
615
- if isinstance(url, str):
616
- url = make_url(url)
617
- # Set the driver to psycopg if not already set
618
- if url.drivername != "postgresql+psycopg":
619
- url = url.set(drivername="postgresql+psycopg")
620
-
621
- return base_create_engine(url, **kwargs)
622
-
623
-
624
- def connection_args(engine):
625
- """Get PostgreSQL connection arguments for an engine"""
626
- _psql_flags = {"-U": "username", "-h": "host", "-p": "port", "-P": "password"}
627
-
628
- if isinstance(engine, str):
629
- # We passed a connection url!
630
- engine = create_engine(engine)
631
- flags = ""
632
- for flag, _attr in _psql_flags.items():
633
- val = getattr(engine.url, _attr)
634
- if val is not None:
635
- flags += f" {flag} {val}"
636
- return flags, engine.url.database
637
-
638
-
639
- def db_isready(engine_or_url, use_shell_command=False):
640
- if use_shell_command:
641
- args, _ = connection_args(engine_or_url)
642
- c = cmd("pg_isready", args, capture_output=True)
643
- return c.returncode == 0
644
- # Use a more typical sqlalchemy connection approach
645
- engine = create_engine(engine_or_url)
646
- try:
647
- with engine.connect() as conn:
648
- conn.execute(text("SELECT 1"))
649
- return True
650
- except OperationalError:
651
- return False
652
-
653
-
654
- def wait_for_database(engine_or_url, *, quiet=False, use_shell_command=False):
655
- msg = "Waiting for database..."
656
- while not db_isready(engine_or_url, use_shell_command=use_shell_command):
657
- if not quiet:
658
- echo(msg, err=True)
659
- log.info(msg)
660
- sleep(1)
661
-
662
-
663
- def reflect_table(engine, tablename, *column_args, **kwargs):
664
- """
665
- One-off reflection of a database table or view. Note: for most purposes,
666
- it will be better to use the database tables automapped at runtime in the
667
- `self.tables` object. However, this function can be useful for views (which
668
- are not reflected automatically), or to customize type definitions for mapped
669
- tables.
670
-
671
- A set of `column_args` can be used to pass columns to override with the mapper, for
672
- instance to set up foreign and primary key constraints.
673
- https://docs.sqlalchemy.org/en/13/core/reflection.html#reflecting-views
674
- """
675
- schema = kwargs.pop("schema", "public")
676
- meta = MetaData(schema=schema)
677
- return Table(tablename, meta, *column_args, autoload_with=engine, **kwargs)
@@ -0,0 +1,213 @@
1
+ from contextlib import contextmanager
2
+ from time import sleep
3
+
4
+ from click import echo
5
+ from sqlalchemy import MetaData
6
+ from sqlalchemy import create_engine as base_create_engine
7
+ from sqlalchemy import text
8
+ from sqlalchemy.engine import Engine
9
+ from sqlalchemy.engine.url import make_url
10
+ from sqlalchemy.exc import (
11
+ IntegrityError,
12
+ OperationalError,
13
+ ProgrammingError,
14
+ )
15
+ from sqlalchemy.orm import sessionmaker
16
+ from sqlalchemy.schema import Table
17
+ from sqlalchemy.sql.elements import ClauseElement
18
+ from sqlalchemy_utils import create_database as _create_database
19
+ from sqlalchemy_utils import database_exists, drop_database
20
+ from sqlparse import format
21
+
22
+ from macrostrat.utils import cmd, get_logger
23
+ from .query import get_sql_text
24
+
25
+ log = get_logger(__name__)
26
+
27
+ # Ensure that old import structure still works
28
+ from .query import run_sql, run_query, run_sql_file, run_fixtures # noqa: F401
29
+
30
+
31
+ def get_dataframe(connectable, filename_or_query, **kwargs):
32
+ """
33
+ Run a query on a SQL database (represented by
34
+ a SQLAlchemy database object) and turn it into a
35
+ `Pandas` dataframe.
36
+ """
37
+ from pandas import read_sql
38
+
39
+ sql = get_sql_text(filename_or_query)
40
+
41
+ return read_sql(sql, connectable, **kwargs)
42
+
43
+
44
+ def db_session(engine):
45
+ factory = sessionmaker(bind=engine)
46
+ return factory()
47
+
48
+
49
+ def execute(connectable, sql, params=None, stop_on_error=False, **kwargs):
50
+ output_file = kwargs.pop("output_file", None)
51
+ output_mode = kwargs.pop("output_mode", None)
52
+ sql = format(sql, strip_comments=True).strip()
53
+ if sql == "":
54
+ return
55
+ try:
56
+ connectable.begin()
57
+ res = connectable.execute(text(sql), params=params)
58
+ if hasattr(connectable, "commit"):
59
+ connectable.commit()
60
+ pretty_print(sql, dim=True, file=output_file, mode=output_mode)
61
+ return res
62
+ except (ProgrammingError, IntegrityError) as err:
63
+ if hasattr(connectable, "rollback"):
64
+ connectable.rollback()
65
+ _print_error(sql, dim=True, file=output_file, mode=output_mode)
66
+ if stop_on_error:
67
+ return
68
+ finally:
69
+ if hasattr(connectable, "close"):
70
+ connectable.close()
71
+
72
+
73
+ def get_or_create(session, model, defaults=None, **kwargs):
74
+ """
75
+ Get an instance of a model, or create it if it doesn't
76
+ exist.
77
+
78
+ https://stackoverflow.com/questions/2546207
79
+ """
80
+ instance = session.query(model).filter_by(**kwargs).first()
81
+ if instance:
82
+ instance._created = False
83
+ return instance
84
+ else:
85
+ params = dict(
86
+ (k, v) for k, v in kwargs.items() if not isinstance(v, ClauseElement)
87
+ )
88
+ params.update(defaults or {})
89
+ instance = model(**params)
90
+ session.add(instance)
91
+ instance._created = True
92
+ return instance
93
+
94
+
95
+ def get_db_model(db, model_name: str):
96
+ return getattr(db.model, model_name)
97
+
98
+
99
+ @contextmanager
100
+ def temp_database(conn_string, drop=True, ensure_empty=False):
101
+ """Create a temporary database and tear it down after tests."""
102
+ create_database(conn_string, exists_ok=True, replace=ensure_empty)
103
+ try:
104
+ yield create_engine(conn_string)
105
+ finally:
106
+ if drop:
107
+ drop_database(conn_string)
108
+
109
+
110
+ def create_database(url, **kwargs):
111
+ """Create a database if it doesn't exist.
112
+
113
+ Parameters
114
+ ----------
115
+ url : str
116
+ A SQLAlchemy database URL.
117
+ exists_ok : bool
118
+ If True, don't raise an error if the database already exists.
119
+ replace : bool
120
+ If True, drop the database if it exists and create a new one.
121
+ kwargs : dict
122
+ Additional keyword arguments to pass to `sqlalchemy_utils.create_database`.
123
+ """
124
+ db_exists = database_exists(url)
125
+
126
+ should_replace = kwargs.pop("replace", False)
127
+ exists_ok = kwargs.pop("exists_ok", False)
128
+
129
+ if should_replace and db_exists:
130
+ drop_database(url)
131
+ db_exists = False
132
+
133
+ if exists_ok and db_exists:
134
+ return
135
+ _create_database(url, **kwargs)
136
+
137
+
138
+ def create_engine(db_conn, **kwargs):
139
+ if isinstance(db_conn, Engine):
140
+ log.info(f"Set up database connection with engine {db_conn.url}")
141
+ if db_conn.driver == "psycopg2":
142
+ log.warning(
143
+ "The psycopg2 driver is deprecated. Please use psycopg3 instead."
144
+ )
145
+ return db_conn
146
+ else:
147
+ log.info(f"Setting up database connection with URL '{db_conn}'")
148
+ url = db_conn
149
+ if isinstance(url, str):
150
+ url = make_url(url)
151
+ # Set the driver to psycopg if not already set
152
+ if url.drivername != "postgresql+psycopg":
153
+ url = url.set(drivername="postgresql+psycopg")
154
+
155
+ return base_create_engine(url, **kwargs)
156
+
157
+
158
+ def connection_args(engine, with_password=False):
159
+ """Get PostgreSQL connection arguments for an engine"""
160
+ _psql_flags = {"-U": "username", "-h": "host", "-p": "port", "-P": "password"}
161
+
162
+ if isinstance(engine, str):
163
+ # We passed a connection url!
164
+ engine = create_engine(engine)
165
+ flags = ""
166
+ for flag, _attr in _psql_flags.items():
167
+ val = getattr(engine.url, _attr)
168
+ if flag == "-P" and not with_password:
169
+ continue
170
+ if val is not None:
171
+ flags += f" {flag} {val}"
172
+ return flags, engine.url.database
173
+
174
+
175
+ def db_isready(engine_or_url, use_shell_command=False):
176
+ if use_shell_command:
177
+ args, _ = connection_args(engine_or_url, with_password=True)
178
+ c = cmd("pg_isready", args, capture_output=True)
179
+ return c.returncode == 0
180
+ # Use a more typical sqlalchemy connection approach
181
+ engine = create_engine(engine_or_url)
182
+ try:
183
+ with engine.connect() as conn:
184
+ conn.execute(text("SELECT 1"))
185
+ return True
186
+ except OperationalError:
187
+ return False
188
+
189
+
190
+ def wait_for_database(engine_or_url, *, quiet=False, use_shell_command=False):
191
+ msg = "Waiting for database..."
192
+ while not db_isready(engine_or_url, use_shell_command=use_shell_command):
193
+ if not quiet:
194
+ echo(msg, err=True)
195
+ log.info(msg)
196
+ sleep(1)
197
+
198
+
199
+ def reflect_table(engine, tablename, *column_args, **kwargs):
200
+ """
201
+ One-off reflection of a database table or view. Note: for most purposes,
202
+ it will be better to use the database tables automapped at runtime in the
203
+ `self.tables` object. However, this function can be useful for views (which
204
+ are not reflected automatically), or to customize type definitions for mapped
205
+ tables.
206
+
207
+ A set of `column_args` can be used to pass columns to override with the mapper, for
208
+ instance to set up foreign and primary key constraints.
209
+ https://docs.sqlalchemy.org/en/13/core/reflection.html#reflecting-views
210
+ """
211
+ schema = kwargs.pop("schema", "public")
212
+ meta = MetaData(schema=schema)
213
+ return Table(tablename, meta, *column_args, autoload_with=engine, **kwargs)
@@ -0,0 +1,44 @@
1
+ [project]
2
+ name = "macrostrat.database"
3
+ version = "4.0.1"
4
+ description = "A SQLAlchemy-based database toolkit."
5
+ authors = [{ name = "Daven Quinn", email = "dev@davenquinn.com" }]
6
+ requires-python = ">=3.10,<4"
7
+ classifiers = [
8
+ "Programming Language :: Python :: 3",
9
+ "Programming Language :: Python :: 3.10",
10
+ "Programming Language :: Python :: 3.11",
11
+ "Programming Language :: Python :: 3.12",
12
+ "Programming Language :: Python :: 3.13",
13
+ "Programming Language :: Python :: 3.14",
14
+ ]
15
+ dependencies = [
16
+ "GeoAlchemy2>=0.15.2,<0.16",
17
+ "SQLAlchemy>=2.0.18,<3",
18
+ "SQLAlchemy-Utils>=0.41.1,<0.42",
19
+ "click>=8.1.3,<9",
20
+ "macrostrat.utils>=1.3.3,<2",
21
+ "sqlparse>=0.5.1,<0.6",
22
+ "aiofiles>=23.2.1,<24",
23
+ "rich>=13.7.1,<14",
24
+ "psycopg>=3.2.1,<4",
25
+ "psycopg2>=2.9.11,<3",
26
+ ]
27
+
28
+ [dependency-groups]
29
+ dev = ["macrostrat.utils"]
30
+
31
+ [tool.uv]
32
+ default-groups = "all"
33
+
34
+ [tool.uv.sources]
35
+ "macrostrat.utils" = { path = "../utils", editable = true }
36
+
37
+ [tool.uv.build-backend]
38
+ module-name = ["macrostrat"]
39
+ module-root = ""
40
+ namespace = true
41
+
42
+ [build-system]
43
+ requires = ["uv_build>=0.9.21,<0.12.0"]
44
+ build-backend = "uv_build"
@@ -1,22 +0,0 @@
1
- Metadata-Version: 2.3
2
- Name: macrostrat.database
3
- Version: 4.0.0
4
- Summary: A SQLAlchemy-based database toolkit.
5
- Author: Daven Quinn
6
- Author-email: dev@davenquinn.com
7
- Requires-Python: >=3.10,<4.0
8
- Classifier: Programming Language :: Python :: 3
9
- Classifier: Programming Language :: Python :: 3.10
10
- Classifier: Programming Language :: Python :: 3.11
11
- Classifier: Programming Language :: Python :: 3.12
12
- Classifier: Programming Language :: Python :: 3.13
13
- Requires-Dist: GeoAlchemy2 (>=0.15.2,<0.16.0)
14
- Requires-Dist: SQLAlchemy (>=2.0.18,<3.0.0)
15
- Requires-Dist: SQLAlchemy-Utils (>=0.41.1,<0.42.0)
16
- Requires-Dist: aiofiles (>=23.2.1,<24.0.0)
17
- Requires-Dist: click (>=8.1.3,<9.0.0)
18
- Requires-Dist: macrostrat.utils (>=1.3.0,<2.0.0)
19
- Requires-Dist: psycopg (>=3.2.1,<4.0.0)
20
- Requires-Dist: psycopg2 (>=2.9.11,<3.0.0)
21
- Requires-Dist: rich (>=13.7.1,<14.0.0)
22
- Requires-Dist: sqlparse (>=0.5.1,<0.6.0)
@@ -1,26 +0,0 @@
1
- [tool.poetry]
2
- authors = ["Daven Quinn <dev@davenquinn.com>"]
3
- description = "A SQLAlchemy-based database toolkit."
4
- name = "macrostrat.database"
5
- packages = [{ include = "macrostrat" }]
6
- version = "4.0.0"
7
-
8
- [tool.poetry.dependencies]
9
- GeoAlchemy2 = "^0.15.2"
10
- SQLAlchemy = "^2.0.18"
11
- SQLAlchemy-Utils = "^0.41.1"
12
- click = "^8.1.3"
13
- "macrostrat.utils" = "^1.3.0"
14
- python = "^3.10"
15
- sqlparse = "^0.5.1"
16
- aiofiles = "^23.2.1"
17
- rich = "^13.7.1"
18
- psycopg = "^3.2.1"
19
- psycopg2 = "^2.9.11"
20
-
21
- [tool.poetry.group.dev.dependencies]
22
- "macrostrat.utils" = { path = "../utils", develop = true }
23
-
24
- [build-system]
25
- build-backend = "poetry.core.masonry.api"
26
- requires = ["poetry-core>=1.0.0"]