macrostrat.database 4.2.0__tar.gz → 4.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (20)
  1. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/PKG-INFO +1 -1
  2. macrostrat_database-4.3.0/macrostrat/database/__init__.py +14 -0
  3. macrostrat_database-4.2.0/macrostrat/database/__init__.py → macrostrat_database-4.3.0/macrostrat/database/core.py +125 -8
  4. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/database/mapper/__init__.py +3 -0
  5. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/database/postgresql.py +1 -1
  6. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/database/query.py +31 -12
  7. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/database/transfer/utils.py +6 -12
  8. macrostrat_database-4.3.0/macrostrat/database/utils.py +348 -0
  9. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/pyproject.toml +1 -1
  10. macrostrat_database-4.2.0/macrostrat/database/utils.py +0 -221
  11. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/.DS_Store +0 -0
  12. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/database/compat.py +0 -0
  13. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/database/mapper/base.py +0 -0
  14. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/database/mapper/cache.py +0 -0
  15. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/database/mapper/utils.py +0 -0
  16. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/database/transfer/__init__.py +0 -0
  17. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/database/transfer/dump_database.py +0 -0
  18. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/database/transfer/move_tables.py +0 -0
  19. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/database/transfer/restore_database.py +0 -0
  20. {macrostrat_database-4.2.0 → macrostrat_database-4.3.0}/macrostrat/database/transfer/stream_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: macrostrat.database
3
- Version: 4.2.0
3
+ Version: 4.3.0
4
4
  Summary: A SQLAlchemy-based database toolkit.
5
5
  Author: Daven Quinn
6
6
  Author-email: Daven Quinn <dev@davenquinn.com>
@@ -0,0 +1,14 @@
1
+ from .core import Database
2
+ from .mapper import DatabaseMapper
3
+ from .postgresql import on_conflict, prefix_inserts # noqa
4
+ from .query import run_fixtures, run_query, run_sql, execute # noqa
5
+ from .utils import ( # noqa
6
+ create_database,
7
+ create_engine,
8
+ database_exists,
9
+ drop_database,
10
+ get_dataframe,
11
+ get_or_create,
12
+ reflect_table,
13
+ get_database_url,
14
+ )
@@ -13,16 +13,14 @@ from sqlalchemy.sql.expression import Insert
13
13
 
14
14
  from macrostrat.utils import get_logger
15
15
  from .mapper import DatabaseMapper
16
- from .postgresql import on_conflict, prefix_inserts # noqa
17
- from .query import run_fixtures, run_query, run_sql, execute # noqa
18
- from .utils import ( # noqa
19
- create_database,
16
+ from .postgresql import prefix_inserts
17
+ from .query import run_fixtures, run_query, run_sql
18
+ from .utils import (
20
19
  create_engine,
21
- database_exists,
22
- drop_database,
23
20
  get_dataframe,
24
21
  get_or_create,
25
22
  reflect_table,
23
+ DatabaseInput,
26
24
  )
27
25
 
28
26
  metadata = MetaData()
@@ -30,6 +28,28 @@ metadata = MetaData()
30
28
  log = get_logger(__name__)
31
29
 
32
30
 
31
+ def _parse_table_name(name, schema=None):
32
+ """Parse a table name into (schema, table_name).
33
+
34
+ Accepts "table", "schema.table", or ("schema", "table").
35
+ An explicit schema kwarg takes precedence. Defaults to "public".
36
+ """
37
+ if isinstance(name, tuple):
38
+ parsed_schema, table_name = name[0], name[1]
39
+ elif "." in str(name):
40
+ parsed_schema, table_name = str(name).split(".", 1)
41
+ else:
42
+ parsed_schema, table_name = None, str(name)
43
+ return (schema or parsed_schema or "public"), table_name
44
+
45
+
46
+ def _model_key(schema, table_name):
47
+ """Return the ModelCollection key used by automap for a given table."""
48
+ if schema == "public":
49
+ return table_name
50
+ return f"{schema}_{table_name}"
51
+
52
+
33
53
  class Database(object):
34
54
  mapper: Optional[DatabaseMapper] = None
35
55
  metadata: MetaData
@@ -38,7 +58,7 @@ class Database(object):
38
58
 
39
59
  __inspector__ = None
40
60
 
41
- def __init__(self, db_conn: Union[str, URL, Engine], *, echo_sql=False, **kwargs):
61
+ def __init__(self, db_conn: DatabaseInput, *, echo_sql=False, **kwargs):
42
62
  """
43
63
  Wrapper for interacting with a database using SQLAlchemy.
44
64
  Optimized for use with PostgreSQL, but usable with SQLite
@@ -58,7 +78,10 @@ class Database(object):
58
78
 
59
79
  self.instance_params = kwargs.pop("instance_params", {})
60
80
 
61
- self.engine = create_engine(db_conn, echo=echo_sql, **kwargs)
81
+ if echo_sql:
82
+ kwargs["echo"] = True
83
+
84
+ self.engine = create_engine(db_conn, **kwargs)
62
85
 
63
86
  self.metadata = kwargs.get("metadata", metadata)
64
87
 
@@ -68,6 +91,7 @@ class Database(object):
68
91
  self._session_factory = sessionmaker(bind=self.engine)
69
92
  self.session = scoped_session(self._session_factory)
70
93
  # Use the self.session_scope function to more explicitly manage sessions.
94
+ self._table_cache: dict = {}
71
95
 
72
96
  def create_tables(self):
73
97
  """
@@ -80,6 +104,10 @@ class Database(object):
80
104
  self.mapper = DatabaseMapper(self)
81
105
  self.mapper.reflect_database(**kwargs)
82
106
 
107
def get_server_version(self):
    """Return the server version tuple reported by the engine's dialect.

    A connection is opened (and released) before reading
    ``dialect.server_version_info`` so the dialect has been initialized.
    """
    engine = self.engine
    with engine.connect():
        version_info = engine.dialect.server_version_info
    return version_info
110
+
83
111
  @contextmanager
84
112
  def session_scope(self, commit=True):
85
113
  """Provide a transactional scope around a series of operations."""
@@ -343,6 +371,95 @@ class Database(object):
343
371
  self.session.close()
344
372
  self.session = _prev_session
345
373
 
374
def get_table(self, name, *, schema=None):
    """Look up a reflected SQLAlchemy ``Table``, memoizing it per instance.

    Resolution order:

    1. The per-instance cache.
    2. If :meth:`automap` has built a mapper, the ``__table__`` of the
       matching automapped model. This avoids a second reflection round-trip.
    3. A one-off reflection via ``reflect_table``.

    Whatever is found is stored in the cache, so later calls for the same
    table skip the lookup.

    Args:
        name: ``"table"``, ``"schema.table"`` or ``("schema", "table")``.
        schema: Schema that overrides any schema in ``name``; ``"public"``
            when neither provides one.
    """
    schema_name, table_name = _parse_table_name(name, schema)
    key = (schema_name, table_name)
    try:
        return self._table_cache[key]
    except KeyError:
        pass

    model_key = _model_key(schema_name, table_name)
    if self.mapper is not None and model_key in self.mapper._models:
        table = self.mapper._models[model_key].__table__
    else:
        # automap stores the public schema as None, so reflect it the same way
        table = reflect_table(
            self.engine,
            table_name,
            schema=None if schema_name == "public" else schema_name,
        )
    self._table_cache[key] = table
    return table
404
+
405
def get_model(self, name, *, schema=None, automap=True):
    """Return the automapped ORM model class for a table.

    If the model is not yet known and ``automap`` is true (the default), the
    target schema is reflected on demand:

    - with no mapper yet, :meth:`automap` is run for just that schema;
    - otherwise the existing mapper reflects the schema, unless it has
      already done so.

    With ``automap=False`` a missing model raises immediately instead, which
    gives strict control over when reflection happens.

    Args:
        name: ``"table"``, ``"schema.table"`` or ``("schema", "table")``.
        schema: Schema that overrides any schema in ``name``; ``"public"``
            when neither provides one.
        automap: Whether to reflect the schema lazily if the model is missing.

    Raises:
        LookupError: If no model exists for the table, either without lazy
            reflection or after it.
    """
    schema_name, table_name = _parse_table_name(name, schema)
    key = _model_key(schema_name, table_name)

    def _known_model():
        # Returns the mapped class, or None when the mapper lacks it.
        mapper = self.mapper
        if mapper is not None and key in mapper._models:
            return mapper._models[key]
        return None

    model = _known_model()
    if model is not None:
        return model

    if not automap:
        raise LookupError(
            f"No ORM model found for {schema_name}.{table_name}. "
            "Call db.automap() first, or use get_table() for a Table object."
        )

    # Reflect only the schema we need
    if self.mapper is None:
        self.automap(schemas=[schema_name])
    elif schema_name not in self.mapper._reflected_schemas:
        self.mapper.reflect_schema(schema_name)

    model = _known_model()
    if model is not None:
        return model

    raise LookupError(
        f"No ORM model found for {schema_name}.{table_name} after reflecting "
        f"schema '{schema_name}'. Verify the table exists, or use get_table()."
    )
447
+
448
def __getitem__(self, name):
    """Allow ``db["schema.table"]`` as shorthand for :meth:`get_table`."""
    table = self.get_table(name)
    return table
451
+
452
# Destroy engine on cleanup
def cleanup(self):
    """Close the scoped session and dispose of the engine's connection pool.

    This is safe to call on a partially-initialized instance. ``__del__``
    calls it even when ``__init__`` raised before ``session``/``engine`` were
    assigned, so missing attributes are skipped instead of raising
    AttributeError.
    """
    session = getattr(self, "session", None)
    if session is not None:
        try:
            session.close()
        except OperationalError:
            # The underlying connection may already be gone; closing is
            # best-effort here.
            pass
    engine = getattr(self, "engine", None)
    if engine is not None:
        engine.dispose()
459
+
460
def __del__(self):
    """Finalizer: release session and engine resources via :meth:`cleanup`."""
    self.cleanup()
462
+
346
463
 
347
464
  def _clear_savepoint(connection, name, rollback=True):
348
465
  params = {"name": Identifier(name)}
@@ -33,6 +33,7 @@ class DatabaseMapper:
33
33
  automap_error = None
34
34
  _models = None
35
35
  _tables = None
36
+ _reflected_schemas: set
36
37
 
37
38
  def __init__(self, db, **kwargs):
38
39
  # https://docs.sqlalchemy.org/en/13/orm/extensions/automap.html#sqlalchemy.ext.automap.AutomapBase.prepare
@@ -56,6 +57,7 @@ class DatabaseMapper:
56
57
 
57
58
  self._models = ModelCollection(self.automap_base.classes)
58
59
  self._tables = TableCollection(self._models)
60
+ self._reflected_schemas = set()
59
61
 
60
62
  def reflect_database(self, schemas=["public"], use_cache=True):
61
63
  # This stuff should be placed outside of core (one likely extension point).
@@ -74,6 +76,7 @@ class DatabaseMapper:
74
76
  self.automap_base.builder._cache_database_map(self.automap_base.metadata)
75
77
 
76
78
  def reflect_schema(self, schema, use_cache=True):
79
+ self._reflected_schemas.add(schema or "public")
77
80
  if use_cache and self.automap_base.loaded_from_cache:
78
81
  log.info("Database models for %s have been loaded from cache", schema)
79
82
  self.automap_base.prepare(schema=schema, **self.reflection_kwargs)
@@ -12,7 +12,7 @@ from sqlalchemy.sql.dml import Insert
12
12
  from sqlalchemy.sql.expression import text
13
13
 
14
14
  if TYPE_CHECKING:
15
- from ..database import Database
15
+ from .core import Database
16
16
 
17
17
 
18
18
  class OnConflictAction(str, Enum):
@@ -10,10 +10,6 @@ from warnings import warn
10
10
 
11
11
  import psycopg2.sql as psql2
12
12
  from click import secho
13
- from macrostrat.database.compat import (
14
- update_legacy_identifier,
15
- )
16
- from macrostrat.utils import get_logger
17
13
  from psycopg.errors import QueryCanceled
18
14
  from psycopg.sql import SQL, Composable, Composed
19
15
  from rich.console import Console
@@ -29,6 +25,11 @@ from sqlalchemy.exc import (
29
25
  from sqlalchemy.sql.elements import TextClause
30
26
  from sqlparse import format, split
31
27
 
28
+ from macrostrat.database.compat import (
29
+ update_legacy_identifier,
30
+ )
31
+ from macrostrat.utils import get_logger
32
+
32
33
  log = get_logger(__name__)
33
34
 
34
35
 
@@ -270,7 +271,15 @@ def _statement_filter_to_transform(statement_filter) -> TransformFn:
270
271
  return transform
271
272
 
272
273
 
273
- def _run_sql(connectable, sql, params=None, *, print_skipped=True, **kwargs):
274
+ def _run_sql(
275
+ connectable,
276
+ sql,
277
+ params=None,
278
+ *,
279
+ print_skipped=True,
280
+ use_transaction=True,
281
+ **kwargs,
282
+ ):
274
283
  """
275
284
  Internal function for running a query on a SQLAlchemy connectable,
276
285
  which always returns an iterator. The wrapper function adds the option
@@ -341,6 +350,7 @@ def _run_sql(connectable, sql, params=None, *, print_skipped=True, **kwargs):
341
350
  output_mode=output_mode,
342
351
  print_skipped=print_skipped,
343
352
  has_server_binds=has_server_binds,
353
+ use_transaction=use_transaction,
344
354
  )
345
355
 
346
356
 
@@ -389,6 +399,7 @@ def _execute_one(
389
399
  output_mode: OutputMode = OutputMode.SUMMARY,
390
400
  has_server_binds: bool | None = None,
391
401
  print_skipped: bool = True,
402
+ use_transaction: bool = True,
392
403
  ):
393
404
  params = result.params
394
405
 
@@ -398,20 +409,24 @@ def _execute_one(
398
409
 
399
410
  if result.label is not None:
400
411
  display_text = result.label
412
+ elif output_mode == OutputMode.NONE:
413
+ display_text = None
401
414
  elif output_mode != OutputMode.ALL:
402
415
  display_text = summarize_statement(str(query))
403
416
  else:
404
417
  display_text = str(query)
405
418
 
406
419
  if result.skip:
407
- if print_skipped:
420
+ if print_skipped and display_text is not None:
408
421
  secho(display_text, dim=True, strikethrough=True, file=output_file)
409
422
  return
410
423
 
411
- try:
412
- trans = connectable.begin()
413
- except InvalidRequestError:
414
- trans = None
424
+ trans = None
425
+ if use_transaction:
426
+ try:
427
+ trans = connectable.begin()
428
+ except InvalidRequestError:
429
+ pass
415
430
 
416
431
  try:
417
432
  log.debug("Executing SQL: \n %s", query)
@@ -430,7 +445,8 @@ def _execute_one(
430
445
  elif hasattr(connectable, "commit"):
431
446
  connectable.commit()
432
447
 
433
- secho(display_text, dim=True, file=output_file)
448
+ if display_text is not None:
449
+ secho(display_text, dim=True, file=output_file)
434
450
 
435
451
  except Exception as err:
436
452
  if trans is not None:
@@ -439,7 +455,8 @@ def _execute_one(
439
455
  connectable.rollback()
440
456
  if raise_errors or _should_raise_query_error(err):
441
457
  raise err
442
- _print_error(display_text, err, file=output_file)
458
+ if display_text is not None:
459
+ _print_error(display_text, err, file=output_file)
443
460
 
444
461
 
445
462
  def _should_raise_query_error(err):
@@ -590,6 +607,8 @@ def run_sql(*args, **kwargs):
590
607
  objects, which can modify the query, parameters, and whether the statement
591
608
  should be skipped or not. This allows for more complex logic than a simple
592
609
  statement filter.
610
+ use_transaction: bool
611
+ Whether to run the query in a transaction block
593
612
  """
594
613
  res = _run_sql(*args, **kwargs)
595
614
  if kwargs.pop("yield_results", False):
@@ -72,20 +72,14 @@ def _create_command(
72
72
  args = []
73
73
 
74
74
  command_prefix = prefix or _docker_local_run_args(container)
75
- _cmd = [*command_prefix, *command, str(engine.url), *args]
76
75
 
77
- log.info(" ".join(_cmd))
78
-
79
- # Replace asterisks with the real password (if any). This is kind of backwards
80
- # but it works.
81
- if "***" in str(engine.url) and engine.url.password is not None:
82
- _cmd = [
83
- *command_prefix,
84
- *command,
85
- raw_database_url(engine.url),
86
- *args,
87
- ]
76
+ # Strip the SQLAlchemy dialect suffix (e.g. +psycopg) — pg_dump/pg_restore
77
+ # only understand plain postgresql:// URLs. Also expand any *** password
78
+ # masking so the subprocess receives the real credentials.
79
+ pg_url = raw_database_url(engine.url.set(drivername="postgresql"))
88
80
 
81
+ _cmd = [*command_prefix, *command, pg_url, *args]
82
+ log.info(" ".join(_cmd))
89
83
  return _cmd
90
84
 
91
85
 
@@ -0,0 +1,348 @@
1
+ from contextlib import contextmanager
2
+ from time import sleep
3
+ from typing import Union
4
+ from uuid import uuid4
5
+ from warnings import warn
6
+
7
+ from click import echo
8
+ from psycopg.errors import AdminShutdown
9
+ from psycopg.sql import Identifier
10
+ from sqlalchemy import MetaData
11
+ from sqlalchemy import create_engine as base_create_engine
12
+ from sqlalchemy import text
13
+ from sqlalchemy.engine import Engine
14
+ from sqlalchemy.engine.url import make_url, URL
15
+ from sqlalchemy.exc import (
16
+ OperationalError,
17
+ )
18
+ from sqlalchemy.orm import sessionmaker
19
+ from sqlalchemy.schema import Table
20
+ from sqlalchemy.sql.elements import ClauseElement
21
+ from sqlalchemy_utils import (
22
+ create_database as _create_database,
23
+ database_exists,
24
+ drop_database as _drop_database,
25
+ )
26
+
27
+ from macrostrat.utils import cmd, get_logger
28
+ from .query import get_sql_text, execute # noqa
29
+
30
+ log = get_logger(__name__)
31
+
32
+ # Ensure that old import structure still works
33
+ from .query import run_sql, run_query, run_sql_file, run_fixtures # noqa: F401
34
+
35
# Anything accepted where a database is expected:
# - a ``Database`` wrapper, given as a forward reference (``.core`` imports
#   this module, so ``Database`` is imported lazily inside functions);
# - an existing SQLAlchemy ``Engine``;
# - a URL, as a string or a ``URL`` object.
DatabaseInput = Union["Database", Engine, str, URL]
36
+
37
+
38
def get_dataframe(connectable, filename_or_query, **kwargs):
    """Execute a query and return the result as a pandas ``DataFrame``.

    ``filename_or_query`` is resolved to SQL text by ``get_sql_text``. That
    text, ``connectable`` and any extra keyword arguments are then handed to
    ``pandas.read_sql``. pandas is imported inside the function (presumably
    to keep it an optional dependency).
    """
    import pandas

    query_text = get_sql_text(filename_or_query)
    return pandas.read_sql(query_text, connectable, **kwargs)
49
+
50
+
51
def db_session(engine):
    """Create and return a new ORM ``Session`` bound to ``engine``."""
    return sessionmaker(bind=engine)()
54
+
55
+
56
def get_or_create(session, model, defaults=None, **kwargs):
    """Fetch the first ``model`` row matching ``kwargs``, or build a new one.

    https://stackoverflow.com/questions/2546207

    Parameters
    ----------
    session
        ORM session used for the lookup. A newly built instance is
        ``add``-ed to it, but not flushed or committed.
    model
        Mapped class to query and instantiate.
    defaults : dict, optional
        Extra constructor values, used only when a new instance is created.
    **kwargs
        Filter criteria passed to ``filter_by``. Values that are SQL
        expressions (``ClauseElement``) are used only for filtering and are
        not passed to the constructor.

    Returns
    -------
    The found or newly created instance. Its ``_created`` attribute is set to
    ``False`` if it was found and ``True`` if it was created.
    """
    instance = session.query(model).filter_by(**kwargs).first()
    # ``first()`` returns None when nothing matches. Test identity rather than
    # truthiness so that instances defining __bool__/__len__ are not mistaken
    # for a missing row, which would create a duplicate.
    if instance is not None:
        instance._created = False
        return instance

    params = {k: v for k, v in kwargs.items() if not isinstance(v, ClauseElement)}
    params.update(defaults or {})
    instance = model(**params)
    session.add(instance)
    instance._created = True
    return instance
76
+
77
+
78
def get_db_model(db, model_name: str):
    """Return the model class named ``model_name`` from ``db.model``."""
    model_namespace = db.model
    return getattr(model_namespace, model_name)
80
+
81
+
82
@contextmanager
def temp_database(*args, **kwargs):
    """Deprecated alias for :func:`temporary_database`.

    Emits a ``DeprecationWarning``, then forwards all arguments unchanged and
    yields the engine produced by ``temporary_database``.
    """
    warn(
        "temp_database is deprecated, use temporary_database instead",
        category=DeprecationWarning,
    )
    with temporary_database(*args, **kwargs) as db_engine:
        yield db_engine
90
+
91
+
92
@contextmanager
def temporary_database(
    _input: DatabaseInput,
    *,
    drop=True,
    ensure_empty=False,
    exists_ok=True,
    template=None,
    force_drop=False,
):
    """Create a database for the duration of a ``with`` block, then drop it.

    Parameters
    ----------
    _input : Database, Engine, str or URL
        Identifies the database to create (only its URL is used).
    drop : bool
        Drop the database when the block exits.
    ensure_empty : bool
        Drop and recreate the database if it already exists.
    exists_ok : bool
        Reuse an already-existing database instead of trying to create it.
    template : str, optional
        Name of a database to use as the template for creation.
    force_drop : bool
        Passed to :func:`drop_database` as ``force``.

    Yields
    ------
    Engine
        A new engine connected to the database.
    """
    url = get_database_url(_input)
    create_database(url, exists_ok=exists_ok, replace=ensure_empty, template=template)
    engine = create_engine(url)
    try:
        yield engine
    finally:
        # Dispose on every exit path, not only on success. Otherwise pooled
        # connections left open by a failing block can prevent the DROP.
        engine.dispose()
        if drop:
            drop_database(engine, force=force_drop)
112
+
113
+
114
def drop_database(_input: DatabaseInput, force=None, allow_missing=False):
    """Drop a database.

    Parameters
    ----------
    _input : Database, Engine, str or URL
        Identifies the database to drop.
    force : bool, optional
        Applies to PostgreSQL URLs only. Any value other than ``False``,
        including the default ``None``, force-drops the database even while
        other sessions are connected. ``False`` performs a plain drop.
    allow_missing : bool
        If False (the default), raise ``ValueError`` when the database does
        not exist. If True, do nothing in that case.
    """
    url = get_database_url(_input)
    if database_exists(url):
        if force is not False and "postgres" in url.drivername:
            _force_drop_postgresql_database(url)
        else:
            _drop_database(url)
        return
    if not allow_missing:
        raise ValueError(f"Database {url} does not exist")
134
+
135
+
136
def _force_drop_postgresql_database(url):
    """Drop a PostgreSQL database even if other sessions are connected to it.

    A separate AUTOCOMMIT engine is opened with the database name cleared
    from ``url``. This is needed because DROP DATABASE cannot run inside a
    transaction block or while connected to the target database.

    - On PostgreSQL 13+ the statement uses ``WITH (FORCE)``.
    - On older servers, existing connections are first terminated via
      :func:`close_all_connections`.
    """
    database_name = url.database
    # ``URL.set`` treats ``None`` as "leave unchanged", so the namedtuple
    # ``_replace`` is used to actually clear the database name.
    user_url = url._replace(database=None)
    user_engine = create_engine(
        user_url, execution_options={"isolation_level": "AUTOCOMMIT"}
    )
    try:
        with allow_shutdown(user_engine) as conn:
            conn.autocommit = True
            major_version = user_engine.dialect.server_version_info[0]
            sql = "DROP DATABASE {database_name}"
            params = dict(database_name=Identifier(database_name))
            if major_version >= 13:
                sql += " WITH (FORCE)"
            else:
                close_all_connections(user_engine, database=database_name)
            run_sql(conn, sql, params=params, raise_errors=True, use_transaction=False)
    finally:
        # Dispose even when the drop fails, so the pool is not leaked.
        user_engine.dispose()
158
+
159
+
160
@contextmanager
def allow_shutdown(engine):
    """Yield a connection from ``engine``, tolerating backend termination.

    An ``AdminShutdown`` raised inside the block is swallowed, whether it is
    raised directly or wrapped as the ``orig`` of SQLAlchemy's
    ``OperationalError``. Any other ``OperationalError`` propagates.
    """
    with engine.connect() as connection:
        connection.autocommit = True
        try:
            yield connection
        except AdminShutdown:
            pass
        except OperationalError as err:
            if not isinstance(err.orig, AdminShutdown):
                raise
173
+
174
+
175
def close_all_connections(engine: Engine, database: str | None = None):
    """Terminate all other backends connected to ``database``.

    Parameters
    ----------
    engine : Engine
        Engine used to issue the ``pg_terminate_backend`` calls.
    database : str, optional
        Name of the database whose connections are closed. Defaults to the
        database in ``engine.url``.
    """
    if database is None:
        database = engine.url.database
    # Skip our own backend. Terminating it mid-statement may abort the query
    # before the remaining connections have been terminated.
    sql = (
        "SELECT pg_terminate_backend(pid) FROM pg_stat_activity "
        "WHERE datname = :database AND pid <> pg_backend_pid()"
    )
    params = dict(database=database)
    with allow_shutdown(engine) as conn:
        run_sql(conn, sql, params=params, raise_errors=True, use_transaction=False)
183
+
184
+
185
def get_database_url(_input: DatabaseInput) -> URL:
    """Resolve any supported database reference to a SQLAlchemy ``URL``.

    Accepts a ``Database`` wrapper, an ``Engine``, a URL string or a ``URL``
    object, and raises ``ValueError`` for anything else.
    """
    from .core import Database

    if isinstance(_input, (str, URL)):
        return make_url(_input)
    if isinstance(_input, Engine):
        return _input.url
    if isinstance(_input, Database):
        return _input.engine.url
    raise ValueError(f"Invalid input type: {_input}")
196
+
197
+
198
@contextmanager
def template_database(
    _input: DatabaseInput,
    *,
    name: str | None = None,
    force_drop=True,
    close_source_connections=False,
):
    """Create a temporary database that uses an existing database as a template.

    Parameters
    ----------
    _input : Database, Engine, str or URL
        The source database to copy.
    name : str, optional
        Name for the copy. Defaults to
        ``<source>_template_<8 hex chars from a UUID4>``.
    force_drop : bool
        Passed to :func:`temporary_database`, which uses it when dropping the
        copy on exit.
    close_source_connections : bool
        Terminate other sessions on the source database first. PostgreSQL
        refuses to copy a template database that has other connections.

    Yields
    ------
    Engine
        Engine connected to the new copy, which is dropped when the block exits.
    """
    url = get_database_url(_input)
    if close_source_connections:
        # Build the engine from the resolved URL so that an engine owned by
        # the caller is never disposed here.
        engine = create_engine(url)
        try:
            close_all_connections(engine)
        finally:
            engine.dispose()

    db_name = url.database
    template_db_name = name
    if name is None:
        uid = str(uuid4())[:8]
        template_db_name = db_name + "_template_" + uid
    new_db_url = url.set(database=template_db_name)
    with temporary_database(
        new_db_url, drop=True, exists_ok=False, template=db_name, force_drop=force_drop
    ) as engine:
        yield engine
225
+
226
+
227
def create_database(_input: DatabaseInput, **kwargs):
    """Create a database, optionally tolerating or replacing an existing one.

    Parameters
    ----------
    _input : Database, Engine, str or URL
        Identifies the database to create.
    exists_ok : bool
        If True, return without doing anything when the database exists.
    replace : bool
        If True, drop an existing database first and then create it anew.
    kwargs : dict
        Remaining keyword arguments are forwarded to
        ``sqlalchemy_utils.create_database`` (e.g. ``template``).
    """
    url = get_database_url(_input)
    already_there = database_exists(url)
    replace = kwargs.pop("replace", False)
    exists_ok = kwargs.pop("exists_ok", False)

    if replace and already_there:
        drop_database(url)
    elif exists_ok and already_there:
        return
    _create_database(url, **kwargs)
254
+
255
+
256
def create_engine(_input: DatabaseInput, **kwargs):
    """Return a SQLAlchemy ``Engine`` for any supported database reference.

    Parameters
    ----------
    _input : Database, Engine, str or URL
        Database to connect to.
    **kwargs
        Options forwarded to ``sqlalchemy.create_engine``.

    An existing engine (passed directly or through a ``Database``) is reused
    as-is when no options are given. When options are given, a new engine is
    built from its URL so that the options actually take effect. PostgreSQL
    URLs are always switched to the ``psycopg`` (v3) driver when a new engine
    is built.

    Raises
    ------
    ValueError
        If ``_input`` is not one of the supported types.
    """
    from .core import Database

    # Engine options can only be applied when an engine is constructed, so
    # supplying any forces a new engine even if we were handed an existing one.
    recreate = len(kwargs) > 0

    db_conn = _input
    if isinstance(_input, Database):
        db_conn = _input.engine
    elif isinstance(_input, str):
        db_conn = make_url(_input)

    if isinstance(db_conn, Engine):
        if not recreate:
            # Reuse the existing engine
            log.info(f"Set up database connection with engine {db_conn.url}")
            if db_conn.driver == "psycopg2":
                log.warning(
                    "The psycopg2 driver is deprecated. Please use psycopg3 instead."
                )
            return db_conn
        # Rebuild from the engine's URL (which keeps the real password)
        db_conn = db_conn.url

    if not isinstance(db_conn, URL):
        raise ValueError(f"Invalid input type: {_input}")
    url = db_conn

    log.info(f"Setting up database connection with URL '{url}'")
    # Set the driver to psycopg if not already set
    if "postgres" in url.drivername:
        url = url.set(drivername="postgresql+psycopg")

    return base_create_engine(url, **kwargs)
293
+
294
+
295
def connection_args(_input: DatabaseInput, with_password=False):
    """Build PostgreSQL command-line connection flags for a database.

    Returns a ``(flags, database_name)`` tuple. ``flags`` is a string made of
    `` -U <user>``, `` -h <host>`` and `` -p <port>`` entries, each included
    only if that value is set on the URL. A `` -P <password>`` entry is added
    as well when ``with_password`` is true.

    NOTE(review): ``-P`` is not a password option for the PostgreSQL client
    tools: ``psql`` treats it as ``--pset`` and ``pg_isready`` has no such
    flag. Passwords are normally supplied via ``PGPASSWORD`` or ``.pgpass``.
    """
    url = get_database_url(_input)
    flag_attrs = (
        ("-U", "username"),
        ("-h", "host"),
        ("-p", "port"),
        ("-P", "password"),
    )
    pieces = []
    for flag, attr in flag_attrs:
        value = getattr(url, attr)
        if flag == "-P" and not with_password:
            continue
        if value is not None:
            pieces.append(f" {flag} {value}")
    return "".join(pieces), url.database
308
+
309
+
310
def db_isready(_input: DatabaseInput, use_shell_command=False):
    """Return True if the database server accepts connections.

    With ``use_shell_command`` the check is delegated to ``pg_isready`` and
    succeeds on exit status 0. Otherwise a ``SELECT 1`` is attempted through
    SQLAlchemy, and an ``OperationalError`` is reported as "not ready".
    """
    if use_shell_command:
        # pg_isready does not authenticate and has no password option. A
        # forwarded "-P <password>" made it fail on invalid arguments (so the
        # check never passed) and exposed the password on the command line.
        args, _ = connection_args(_input, with_password=False)
        c = cmd("pg_isready", args, capture_output=True)
        return c.returncode == 0
    # Use a more typical sqlalchemy connection approach
    engine = create_engine(_input)
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return True
    except OperationalError:
        return False
    finally:
        # A fresh engine is built from str/URL input on every poll; release
        # its pool. Engines owned by the caller are left alone.
        if isinstance(_input, (str, URL)):
            engine.dispose()
323
+
324
+
325
def wait_for_database(
    _input: DatabaseInput, *, quiet=False, use_shell_command=False, timeout=None
):
    """Block until the database accepts connections.

    Polls ``db_isready`` once per second.

    Parameters
    ----------
    quiet : bool
        Suppress the "Waiting for database..." message on stderr. The
        message is still logged.
    use_shell_command : bool
        Forwarded to ``db_isready``.
    timeout : float, optional
        Maximum number of seconds to wait. ``None`` (the default) waits
        indefinitely, which matches the previous behavior.

    Raises
    ------
    TimeoutError
        If ``timeout`` elapses before the database is ready.
    """
    from time import monotonic

    deadline = None if timeout is None else monotonic() + timeout
    msg = "Waiting for database..."
    while not db_isready(_input, use_shell_command=use_shell_command):
        if deadline is not None and monotonic() >= deadline:
            raise TimeoutError(f"Database not ready after {timeout} seconds")
        if not quiet:
            echo(msg, err=True)
        log.info(msg)
        sleep(1)
332
+
333
+
334
def reflect_table(engine, tablename, *column_args, **kwargs):
    """Reflect a single table or view from the database.

    For most purposes the tables automapped at runtime are preferable. This
    is useful for views, which are not reflected automatically, or to
    customize type definitions. Any ``column_args`` are passed to ``Table`` to
    override reflected columns, for instance to declare primary- or
    foreign-key constraints.

    ``schema`` defaults to ``"public"`` when it is not passed as a keyword.
    All other keyword arguments are forwarded to ``Table``.

    https://docs.sqlalchemy.org/en/13/core/reflection.html#reflecting-views
    """
    table_schema = kwargs.pop("schema", "public")
    return Table(
        tablename,
        MetaData(schema=table_schema),
        *column_args,
        autoload_with=engine,
        **kwargs,
    )
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "macrostrat.database"
3
- version = "4.2.0"
3
+ version = "4.3.0"
4
4
  description = "A SQLAlchemy-based database toolkit."
5
5
  authors = [{ name = "Daven Quinn", email = "dev@davenquinn.com" }]
6
6
  requires-python = ">=3.10,<4"
@@ -1,221 +0,0 @@
1
- import warnings
2
- from contextlib import contextmanager
3
- from time import sleep
4
- from uuid import uuid4
5
-
6
- from click import echo
7
- from sqlalchemy import MetaData
8
- from sqlalchemy import create_engine as base_create_engine
9
- from sqlalchemy import text
10
- from sqlalchemy.engine import Engine
11
- from sqlalchemy.engine.url import make_url
12
- from sqlalchemy.exc import (
13
- OperationalError,
14
- )
15
- from sqlalchemy.orm import sessionmaker
16
- from sqlalchemy.schema import Table
17
- from sqlalchemy.sql.elements import ClauseElement
18
- from sqlalchemy_utils import create_database as _create_database
19
- from sqlalchemy_utils import database_exists, drop_database
20
-
21
- from macrostrat.utils import cmd, get_logger
22
- from .query import get_sql_text, execute # noqa
23
-
24
- log = get_logger(__name__)
25
-
26
- # Ensure that old import structure still works
27
- from .query import run_sql, run_query, run_sql_file, run_fixtures # noqa: F401
28
-
29
-
30
- def get_dataframe(connectable, filename_or_query, **kwargs):
31
- """
32
- Run a query on a SQL database (represented by
33
- a SQLAlchemy database object) and turn it into a
34
- `Pandas` dataframe.
35
- """
36
- from pandas import read_sql
37
-
38
- sql = get_sql_text(filename_or_query)
39
-
40
- return read_sql(sql, connectable, **kwargs)
41
-
42
-
43
- def db_session(engine):
44
- factory = sessionmaker(bind=engine)
45
- return factory()
46
-
47
-
48
- def get_or_create(session, model, defaults=None, **kwargs):
49
- """
50
- Get an instance of a model, or create it if it doesn't
51
- exist.
52
-
53
- https://stackoverflow.com/questions/2546207
54
- """
55
- instance = session.query(model).filter_by(**kwargs).first()
56
- if instance:
57
- instance._created = False
58
- return instance
59
- else:
60
- params = dict(
61
- (k, v) for k, v in kwargs.items() if not isinstance(v, ClauseElement)
62
- )
63
- params.update(defaults or {})
64
- instance = model(**params)
65
- session.add(instance)
66
- instance._created = True
67
- return instance
68
-
69
-
70
- def get_db_model(db, model_name: str):
71
- return getattr(db.model, model_name)
72
-
73
-
74
- @contextmanager
75
- def temporary_database(
76
- conn_string, *, drop=True, ensure_empty=False, exists_ok=True, template=None
77
- ):
78
- """Create a temporary database and tear it down after tests."""
79
- create_database(
80
- conn_string, exists_ok=exists_ok, replace=ensure_empty, template=template
81
- )
82
- try:
83
- engine = create_engine(conn_string)
84
- yield engine
85
- engine.dispose()
86
- finally:
87
- if drop:
88
- drop_database(conn_string)
89
-
90
-
91
- @contextmanager
92
- def temp_database(*args, **kwargs):
93
- warnings.warn(
94
- "temp_database is deprecated, use temporary_database instead",
95
- DeprecationWarning,
96
- )
97
- with temporary_database(*args, **kwargs) as engine:
98
- yield engine
99
-
100
-
101
- @contextmanager
102
- def template_database(engine: Engine, *, name: str = None):
103
- """Create a temporary template database using an existing database as a template."""
104
- db_name = engine.url.database
105
- template_db_name = name
106
- if name is None:
107
- uid = str(uuid4())[:8]
108
- template_db_name = db_name + "_template_" + uid
109
- # Close connection to the database so we can create a new one based on the template
110
- new_db_url = engine.url.set(database=template_db_name)
111
- engine.dispose()
112
- with temporary_database(
113
- new_db_url, drop=True, exists_ok=False, template=db_name
114
- ) as engine:
115
- yield engine
116
-
117
-
118
- def create_database(url, **kwargs):
119
- """Create a database if it doesn't exist.
120
-
121
- Parameters
122
- ----------
123
- url : str
124
- A SQLAlchemy database URL.
125
- exists_ok : bool
126
- If True, don't raise an error if the database already exists.
127
- replace : bool
128
- If True, drop the database if it exists and create a new one.
129
- kwargs : dict
130
- Additional keyword arguments to pass to `sqlalchemy_utils.create_database`.
131
- """
132
- db_exists = database_exists(url)
133
-
134
- should_replace = kwargs.pop("replace", False)
135
- exists_ok = kwargs.pop("exists_ok", False)
136
-
137
- if should_replace and db_exists:
138
- drop_database(url)
139
- db_exists = False
140
-
141
- if exists_ok and db_exists:
142
- return
143
- _create_database(url, **kwargs)
144
-
145
-
146
- def create_engine(db_conn, **kwargs):
147
- if isinstance(db_conn, Engine):
148
- log.info(f"Set up database connection with engine {db_conn.url}")
149
- if db_conn.driver == "psycopg2":
150
- log.warning(
151
- "The psycopg2 driver is deprecated. Please use psycopg3 instead."
152
- )
153
- return db_conn
154
- else:
155
- log.info(f"Setting up database connection with URL '{db_conn}'")
156
- url = db_conn
157
- if isinstance(url, str):
158
- url = make_url(url)
159
- # Set the driver to psycopg if not already set
160
- if "postgres" in url.drivername:
161
- url = url.set(drivername="postgresql+psycopg")
162
-
163
- return base_create_engine(url, **kwargs)
164
-
165
-
166
- def connection_args(engine, with_password=False):
167
- """Get PostgreSQL connection arguments for an engine"""
168
- _psql_flags = {"-U": "username", "-h": "host", "-p": "port", "-P": "password"}
169
-
170
- if isinstance(engine, str):
171
- # We passed a connection url!
172
- engine = create_engine(engine)
173
- flags = ""
174
- for flag, _attr in _psql_flags.items():
175
- val = getattr(engine.url, _attr)
176
- if flag == "-P" and not with_password:
177
- continue
178
- if val is not None:
179
- flags += f" {flag} {val}"
180
- return flags, engine.url.database
181
-
182
-
183
- def db_isready(engine_or_url, use_shell_command=False):
184
- if use_shell_command:
185
- args, _ = connection_args(engine_or_url, with_password=True)
186
- c = cmd("pg_isready", args, capture_output=True)
187
- return c.returncode == 0
188
- # Use a more typical sqlalchemy connection approach
189
- engine = create_engine(engine_or_url)
190
- try:
191
- with engine.connect() as conn:
192
- conn.execute(text("SELECT 1"))
193
- return True
194
- except OperationalError:
195
- return False
196
-
197
-
198
- def wait_for_database(engine_or_url, *, quiet=False, use_shell_command=False):
199
- msg = "Waiting for database..."
200
- while not db_isready(engine_or_url, use_shell_command=use_shell_command):
201
- if not quiet:
202
- echo(msg, err=True)
203
- log.info(msg)
204
- sleep(1)
205
-
206
-
207
- def reflect_table(engine, tablename, *column_args, **kwargs):
208
- """
209
- One-off reflection of a database table or view. Note: for most purposes,
210
- it will be better to use the database tables automapped at runtime in the
211
- `self.tables` object. However, this function can be useful for views (which
212
- are not reflected automatically), or to customize type definitions for mapped
213
- tables.
214
-
215
- A set of `column_args` can be used to pass columns to override with the mapper, for
216
- instance to set up foreign and primary key constraints.
217
- https://docs.sqlalchemy.org/en/13/core/reflection.html#reflecting-views
218
- """
219
- schema = kwargs.pop("schema", "public")
220
- meta = MetaData(schema=schema)
221
- return Table(tablename, meta, *column_args, autoload_with=engine, **kwargs)