macrostrat.database 3.2.0__tar.gz → 3.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: macrostrat.database
3
- Version: 3.2.0
3
+ Version: 3.3.0
4
4
  Summary: A SQLAlchemy-based database toolkit.
5
5
  Author: Daven Quinn
6
6
  Author-email: dev@davenquinn.com
@@ -10,6 +10,7 @@ Classifier: Programming Language :: Python :: 3.8
10
10
  Classifier: Programming Language :: Python :: 3.9
11
11
  Classifier: Programming Language :: Python :: 3.10
12
12
  Classifier: Programming Language :: Python :: 3.11
13
+ Classifier: Programming Language :: Python :: 3.12
13
14
  Requires-Dist: GeoAlchemy2 (>=0.14.0,<0.15.0)
14
15
  Requires-Dist: SQLAlchemy (>=2.0.18,<3.0.0)
15
16
  Requires-Dist: SQLAlchemy-Utils (>=0.41.1,<0.42.0)
@@ -0,0 +1,338 @@
1
+ import warnings
2
+ from contextlib import contextmanager
3
+ from enum import Enum
4
+ from pathlib import Path
5
+ from typing import Optional, Union
6
+
7
+ from psycopg2.errors import InvalidSavepointSpecification
8
+ from psycopg2.sql import Identifier
9
+ from sqlalchemy import URL, MetaData, create_engine, inspect, text
10
+ from sqlalchemy.exc import IntegrityError, InternalError
11
+ from sqlalchemy.ext.compiler import compiles
12
+ from sqlalchemy.orm import Session, scoped_session, sessionmaker
13
+ from sqlalchemy.sql.expression import Insert
14
+
15
+ from macrostrat.utils import get_logger
16
+
17
+ from .mapper import DatabaseMapper
18
+ from .postgresql import on_conflict, prefix_inserts # noqa
19
+ from .utils import ( # noqa
20
+ create_database,
21
+ database_exists,
22
+ drop_database,
23
+ get_dataframe,
24
+ get_or_create,
25
+ reflect_table,
26
+ run_fixtures,
27
+ run_query,
28
+ run_sql,
29
+ )
30
+
31
+ metadata = MetaData()
32
+
33
+ log = get_logger(__name__)
34
+
35
+
36
class Database(object):
    """Wrapper around a SQLAlchemy engine and session.

    The instance owns an ``engine``, a thread-scoped ``session`` and an
    optional ``mapper`` (populated by :meth:`automap`). The ``session``
    attribute is temporarily swapped out by the :meth:`transaction` and
    :meth:`savepoint` context managers and restored when they exit.
    """

    mapper: Optional[DatabaseMapper] = None
    metadata: MetaData
    session: Session
    instance_params: dict

    __inspector__ = None

    def __init__(self, db_conn: Union[str, URL], *, echo_sql=False, **kwargs):
        """
        Wrapper for interacting with a database using SQLAlchemy.
        Optimized for use with PostgreSQL, but usable with SQLite
        as well.

        Args:
            db_conn (str | URL): Connection string or URL for the database.

        Keyword Args:
            echo_sql (bool): If True, will echo SQL commands to the
                console. Default is False.
            instance_params (dict): Parameters to
                pass to queries and other database operations.
            metadata (MetaData): Metadata instance to use instead of the
                module-level default.

        All remaining keyword arguments are forwarded to ``create_engine``.
        """
        # Function-scope import: only needed to render the URL for logging.
        from sqlalchemy.engine import make_url

        compiles(Insert, "postgresql")(prefix_inserts)

        # Pop wrapper-only options *before* calling create_engine(), which
        # rejects keyword arguments it does not know about.
        self.instance_params = kwargs.pop("instance_params", None) or {}
        self.metadata = kwargs.pop("metadata", metadata)

        # Never write the password to the log.
        safe_url = make_url(db_conn).render_as_string(hide_password=True)
        log.info(f"Setting up database connection '{safe_url}'")
        self.engine = create_engine(db_conn, echo=echo_sql, **kwargs)

        # Scoped session for database
        # https://docs.sqlalchemy.org/en/13/orm/contextual.html#unitofwork-contextual
        # https://docs.sqlalchemy.org/en/13/orm/session_basics.html#session-faq-whentocreate
        self._session_factory = sessionmaker(bind=self.engine)
        self.session = scoped_session(self._session_factory)
        # Use the self.session_scope function to more explicitly manage sessions.

    def create_tables(self):
        """
        Create all tables described by the database's metadata instance.
        """
        self.metadata.create_all(bind=self.engine)

    def automap(self, **kwargs):
        """Build a ``DatabaseMapper`` for this database and reflect it.

        Keyword arguments are passed to ``DatabaseMapper.reflect_database``.
        """
        log.info("Automapping the database")
        self.mapper = DatabaseMapper(self)
        self.mapper.reflect_database(**kwargs)

    @contextmanager
    def session_scope(self, commit=True):
        """Provide a transactional scope around a series of operations.

        Yields the current session. On normal exit the session is committed
        (unless ``commit`` is False); on error it is rolled back and the
        exception re-raised. The session is closed in both cases.
        """
        session = self.session
        try:
            yield session
            if commit:
                session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def _flush_nested_objects(self, session):
        """
        Flush objects remaining in a session (generally these are objects loaded
        during schema-based importing).

        Each object is flushed on its own; an ``IntegrityError`` rolls the
        session back and is logged at debug level instead of propagating.
        """
        for instance in session:
            try:
                session.flush(objects=[instance])
                log.debug(f"Successfully flushed instance {instance}")
            except IntegrityError as err:
                session.rollback()
                log.debug(err)

    def run_sql(self, fn, params=None, **kwargs):
        """Executes SQL files or query strings using the run_sql function.

        Args:
            fn (str|Path): SQL file or query string to execute.
            params (dict): Parameters to pass to the query.

        Keyword Args:
            use_instance_params (bool): If True, will use the instance_params set on
                the Database object. Default is True.

        Returns: Iterator of results from the query.
        """
        params = self._setup_params(params, kwargs)
        return iter(run_sql(self.session, fn, params, **kwargs))

    def run_query(self, sql, params=None, **kwargs):
        """Run a single query on the database object, returning the result.

        Args:
            sql (str): SQL file or query to execute.
            params (dict): Parameters to pass to the query.

        Keyword Args:
            use_instance_params (bool): If True, will use the instance_params set on
                the Database object. Default is True.
        """
        params = self._setup_params(params, kwargs)
        return run_query(self.session, sql, params, **kwargs)

    def run_fixtures(self, fixtures: Union[Path, list[Path]], params=None, **kwargs):
        """Run a set of fixtures on the database object.

        Args:
            fixtures (Path|list[Path]): Path to a directory of fixtures or a list of paths to fixture files.
            params (dict): Parameters to pass to the query.

        Keyword Args:
            use_instance_params (bool): If True, will use the instance_params set on
                the Database object. Default is True.
        """
        params = self._setup_params(params, kwargs)
        return run_fixtures(self.session, fixtures, params, **kwargs)

    def _setup_params(self, params, kwargs):
        """Combine caller-supplied parameters with ``instance_params``.

        Removes ``use_instance_params`` from ``kwargs`` (default True) so it
        is not forwarded to the query helpers. When enabled, returns a *new*
        dict in which instance parameters take precedence over same-named
        caller parameters; the caller's dict is left untouched.
        """
        use_instance_params = kwargs.pop("use_instance_params", True)
        if params is None:
            params = {}
        if use_instance_params:
            # Build a fresh mapping rather than update() in place, so a dict
            # reused by the caller does not accumulate instance parameters.
            params = {**params, **self.instance_params}
        return params

    def exec_sql(self, sql, params=None, **kwargs):
        """Deprecated alias for :meth:`run_sql`."""
        warnings.warn(
            "exec_sql is deprecated and will be removed in version 4.0. Use run_sql instead",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the calling code
        )
        return self.run_sql(sql, params, **kwargs)

    def get_dataframe(self, *args):
        """Returns a Pandas DataFrame from a SQL query"""
        return get_dataframe(self.engine, *args)

    @property
    def inspector(self):
        """SQLAlchemy inspector for the engine, created on first access."""
        if self.__inspector__ is None:
            self.__inspector__ = inspect(self.engine)
        return self.__inspector__

    def entity_names(self, **kwargs):
        """
        Returns an iterator of names of *schema objects*
        (both tables and views) from the database.
        """
        yield from self.inspector.get_table_names(**kwargs)
        yield from self.inspector.get_view_names(**kwargs)

    def get(self, model, *args, **kwargs):
        """Load one instance of ``model`` by primary key.

        ``model`` may be a mapped class or the name of an attribute on
        :attr:`model`. Remaining arguments are passed to ``Session.get``.
        """
        if isinstance(model, str):
            model = getattr(self.model, model)
        # Session.get() is the SQLAlchemy 2.0 replacement for the legacy
        # Query.get() API.
        return self.session.get(model, *args, **kwargs)

    def get_or_create(self, model, **kwargs):
        """
        Get an instance of a model, or create it if it doesn't
        exist.
        """
        if isinstance(model, str):
            model = getattr(self.model, model)
        return get_or_create(self.session, model, **kwargs)

    def reflect_table(self, *args, **kwargs):
        """
        One-off reflection of a database table or view. Note: for most purposes,
        it will be better to use the database tables automapped at runtime using
        `self.automap()`. Then, tables can be accessed using the
        `self.table` object. However, this function can be useful for views (which
        are not reflected automatically), or to customize type definitions for mapped
        tables.

        A set of `column_args` can be used to pass columns to override with the mapper, for
        instance to set up foreign and primary key constraints.
        https://docs.sqlalchemy.org/en/13/core/reflection.html#reflecting-views
        """
        warnings.warn(
            "reflect_table is deprecated and will be removed in version 4.0. Shift away from table refection, or use reflect_table from the macrostrat.database.utils module.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the calling code
        )

        return reflect_table(self.engine, *args, **kwargs)

    @property
    def table(self):
        """
        Map of all tables in the database as SQLAlchemy table objects
        """
        if self.mapper is None or self.mapper._tables is None:
            self.automap()
        return self.mapper._tables

    @property
    def model(self):
        """
        Map of all tables in the database as SQLAlchemy models

        https://docs.sqlalchemy.org/en/latest/orm/extensions/automap.html
        """
        if self.mapper is None or self.mapper._models is None:
            self.automap()
        return self.mapper._models

    @property
    def mapped_classes(self):
        """Alias for :attr:`model`."""
        return self.model

    @contextmanager
    def transaction(self, *, rollback="on-error", connection=None, raise_errors=True):
        """Create a database session that can be rolled back after use.
        This is similar to the `session_scope` method but includes
        more fine-grained control over transactions. The two methods may be integrated
        in the future.

        Keyword Args:
            rollback (str): "always" rolls back on exit, "on-error" (default)
                rolls back only if the body raised, "never" always commits.
            connection: Existing connection to use. If omitted, a connection
                is opened from the engine and closed again on exit.
            raise_errors (bool): If False, exceptions raised in the body are
                suppressed after the transaction is finalized.

        This is based on the Sparrow's implementation:
        https://github.com/EarthCubeGeochron/Sparrow/blob/main/backend/conftest.py

        It can be effectively used in a Pytest fixture like so:
        ```
        @fixture(scope="class")
        def db(base_db):
            with base_db.transaction(rollback="always"):
                yield base_db
        ```
        """
        # Only close the connection on exit if it was opened here.
        owns_connection = connection is None
        if owns_connection:
            connection = self.engine.connect()
        transaction = connection.begin()
        session = Session(bind=connection)
        prev_session = self.session
        self.session = session

        should_rollback = rollback == "always"

        try:
            yield self
        except Exception:
            should_rollback = rollback != "never"
            if raise_errors:
                raise
        finally:
            try:
                if should_rollback:
                    transaction.rollback()
                else:
                    transaction.commit()
            finally:
                # Restore state even if finalizing the transaction failed.
                session.close()
                self.session = prev_session
                if owns_connection:
                    connection.close()

    savepoint_counter = 0

    @contextmanager
    def savepoint(self, name=None, rollback="on-error", connection=None):
        """A PostgreSQL-specific savepoint context manager. This is similar to the
        `transaction` context manager but uses savepoints directly for simpler operation.
        Notably, it supports nested savepoints, a feature that is difficult in SQLAlchemy's `transaction`
        model.

        Args:
            name (str): Savepoint name; defaults to ``sp_<counter>``.
            rollback (str): "always", "on-error" (default) or "never", with
                the same meaning as in `transaction`.
            connection: Connection to use; defaults to the current session's
                connection.

        Yields the savepoint name.

        This function is not yet drop-in compatible with the `transaction` context manager, but that
        is a future goal.
        """
        if name is None:
            name = f"sp_{self.savepoint_counter}"
            self.savepoint_counter += 1

        _prev_session = self.session

        if connection is None:
            connection = self.session.connection()
        params = {"name": Identifier(name)}
        run_query(connection, "SAVEPOINT {name}", params)
        should_rollback = rollback == "always"
        self.session = Session(bind=connection)
        try:
            yield name
        except Exception:
            should_rollback = rollback != "never"
            raise
        finally:
            try:
                _clear_savepoint(connection, name, rollback=should_rollback)
            finally:
                # Restore the previous session even if clearing the
                # savepoint failed.
                self.session.close()
                self.session = _prev_session
324
+
325
+
326
def _clear_savepoint(connection, name, rollback=True):
    """Roll back to, or release, the savepoint called ``name``.

    Issues ``ROLLBACK TO SAVEPOINT`` when ``rollback`` is true and
    ``RELEASE SAVEPOINT`` otherwise. An ``InternalError`` raised by that
    statement is not propagated; when its underlying driver error is
    ``InvalidSavepointSpecification`` a warning is logged and a plain
    ``ROLLBACK`` is issued on the connection.
    """
    params = {"name": Identifier(name)}
    statement = (
        "ROLLBACK TO SAVEPOINT {name}" if rollback else "RELEASE SAVEPOINT {name}"
    )
    try:
        run_query(connection, statement, params)
    except InternalError as err:
        savepoint_missing = isinstance(err.orig, InvalidSavepointSpecification)
        if not savepoint_missing:
            # Any other InternalError is swallowed without further action.
            return
        log.warning(
            f"Savepoint {name} does not exist; we may have already rolled back."
        )
        # NOTE(review): the listing this was recovered from had lost its
        # indentation; the fallback ROLLBACK is assumed to belong to the
        # missing-savepoint branch - confirm against the packaged file.
        run_query(connection, "ROLLBACK")
@@ -4,8 +4,8 @@ from sqlalchemy.engine import Engine
4
4
 
5
5
  from macrostrat.utils import get_logger
6
6
 
7
- from .dump_database import _pg_dump
8
- from .restore_database import _pg_restore
7
+ from .dump_database import pg_dump
8
+ from .restore_database import pg_restore
9
9
  from .utils import print_stdout, print_stream_progress
10
10
 
11
11
  log = get_logger(__name__)
@@ -45,8 +45,8 @@ async def move_tables(
45
45
  log.debug(f"Dump args: {dump_args}")
46
46
  log.debug(f"Restore args: {restore_args}")
47
47
 
48
- source = await _pg_dump(from_database, **kwargs, args=dump_args)
49
- dest = await _pg_restore(to_database, **kwargs, args=restore_args)
48
+ source = await pg_dump(from_database, **kwargs, args=dump_args)
49
+ dest = await pg_restore(to_database, **kwargs, args=restore_args)
50
50
 
51
51
  await asyncio.gather(
52
52
  asyncio.create_task(print_stream_progress(source.stdout, dest.stdin)),
@@ -9,6 +9,7 @@ from click import echo, secho
9
9
  from psycopg2.extensions import set_wait_callback
10
10
  from psycopg2.extras import wait_select
11
11
  from psycopg2.sql import SQL, Composable, Composed
12
+ from rich.console import Console
12
13
  from sqlalchemy import MetaData, create_engine, text
13
14
  from sqlalchemy.engine import Connection, Engine
14
15
  from sqlalchemy.exc import (
@@ -346,6 +347,45 @@ def run_query(connectable, query, params=None, **kwargs):
346
347
  )
347
348
 
348
349
 
350
def get_sql_files(
    fixtures: Union[Path, list[Path]], recursive=False, order_by_name=True
):
    """Collect the ``.sql`` files referenced by ``fixtures``.

    Args:
        fixtures: A single file or directory, or a list of them. Entries may
            be ``Path`` objects or plain strings.
        recursive: If True, directories are searched recursively.
        order_by_name: If True (default), the combined result is sorted.

    Returns:
        list[Path]: The matching files.

    Raises:
        FileNotFoundError: If any entry does not exist.
    """
    # Wrap a lone path so the loop below handles one case only. Strings are
    # included because iterating a bare string would walk it character by
    # character.
    if isinstance(fixtures, (str, Path)):
        fixtures = [fixtures]
    files = []
    for fixture in fixtures:
        files.extend(_get_sql_files(Path(fixture), recursive))
    if order_by_name:
        files = sorted(files)
    return files


def _get_sql_files(fixture: Path, recursive=False):
    """Return the ``.sql`` files for one fixture path.

    A ``.sql`` file yields itself, any other file yields nothing, and a
    directory yields the ``*.sql`` files it contains (searched recursively
    when ``recursive`` is True).

    Raises:
        FileNotFoundError: If ``fixture`` does not exist.
    """
    fixture = Path(fixture)
    if not fixture.exists():
        raise FileNotFoundError(f"Fixture {fixture} does not exist.")
    if fixture.is_file():
        # Explicit about the non-.sql case instead of relying on globbing a
        # regular file producing no matches.
        return [fixture] if fixture.suffix == ".sql" else []
    matches = fixture.rglob("*.sql") if recursive else fixture.glob("*.sql")
    return [match for match in matches if match.is_file()]
371
+
372
+
373
def run_fixtures(connectable, fixtures: Union[Path, list[Path]], params=None, **kwargs):
    """
    Execute SQL fixture files against a database.

    ``fixtures`` is a directory or a list of file paths. The ``recursive``
    (default False), ``order_by_name`` (default True) and ``console`` keyword
    arguments are consumed here; every other keyword argument is passed on to
    ``run_sql_file``. Each file's path is printed to the console before it is
    run, followed by a blank line afterwards.
    """
    search_recursively = kwargs.pop("recursive", False)
    sort_by_name = kwargs.pop("order_by_name", True)
    # The default console writes to stderr.
    console = kwargs.pop("console", Console(stderr=True))

    sql_files = get_sql_files(
        fixtures, recursive=search_recursively, order_by_name=sort_by_name
    )
    for sql_file in sql_files:
        console.print(f"[cyan bold]{sql_file}[/]")
        run_sql_file(connectable, sql_file, params, **kwargs)
        console.print()
387
+
388
+
349
389
  def run_sql(*args, **kwargs):
350
390
  """
351
391
  Run a query on a SQLAlchemy connectable.
@@ -3,7 +3,7 @@ authors = ["Daven Quinn <dev@davenquinn.com>"]
3
3
  description = "A SQLAlchemy-based database toolkit."
4
4
  name = "macrostrat.database"
5
5
  packages = [{ include = "macrostrat" }]
6
- version = "3.2.0"
6
+ version = "3.3.0"
7
7
 
8
8
  [tool.poetry.dependencies]
9
9
  GeoAlchemy2 = "^0.14.0"
@@ -1,179 +0,0 @@
1
- import warnings
2
- from contextlib import contextmanager
3
- from typing import Optional
4
-
5
- from sqlalchemy import MetaData, create_engine, inspect, text
6
- from sqlalchemy.exc import IntegrityError
7
- from sqlalchemy.ext.compiler import compiles
8
- from sqlalchemy.orm import Session, scoped_session, sessionmaker
9
- from sqlalchemy.sql.expression import Insert
10
-
11
- from macrostrat.utils import get_logger
12
-
13
- from .mapper import DatabaseMapper
14
- from .postgresql import on_conflict, prefix_inserts # noqa
15
- from .utils import ( # noqa
16
- create_database,
17
- database_exists,
18
- drop_database,
19
- get_dataframe,
20
- get_or_create,
21
- reflect_table,
22
- run_query,
23
- run_sql,
24
- )
25
-
26
- metadata = MetaData()
27
-
28
- log = get_logger(__name__)
29
-
30
-
31
- class Database(object):
32
- mapper: Optional[DatabaseMapper] = None
33
- metadata: MetaData
34
- session: Session
35
- __inspector__ = None
36
-
37
- def __init__(self, db_conn, echo_sql=False, **kwargs):
38
- """
39
- We can pass a connection string, a **Flask** application object
40
- with the appropriate configuration, or nothing, in which
41
- case we will try to infer the correct database from
42
- the SPARROW_BACKEND_CONFIG file, if available.
43
- """
44
-
45
- compiles(Insert, "postgresql")(prefix_inserts)
46
-
47
- log.info(f"Setting up database connection '{db_conn}'")
48
- self.engine = create_engine(db_conn, echo=echo_sql, **kwargs)
49
- self.metadata = kwargs.get("metadata", metadata)
50
-
51
- # Scoped session for database
52
- # https://docs.sqlalchemy.org/en/13/orm/contextual.html#unitofwork-contextual
53
- # https://docs.sqlalchemy.org/en/13/orm/session_basics.html#session-faq-whentocreate
54
- self._session_factory = sessionmaker(bind=self.engine)
55
- self.session = scoped_session(self._session_factory)
56
- # Use the self.session_scope function to more explicitly manage sessions.
57
-
58
- def create_tables(self):
59
- """
60
- Create all tables described by the database's metadata instance.
61
- """
62
- metadata.create_all(bind=self.engine)
63
-
64
- def automap(self, **kwargs):
65
- log.info("Automapping the database")
66
- self.mapper = DatabaseMapper(self)
67
- self.mapper.reflect_database(**kwargs)
68
-
69
- @contextmanager
70
- def session_scope(self, commit=True):
71
- """Provide a transactional scope around a series of operations."""
72
- # self.__old_session = self.session
73
- # session = self._session_factory()
74
- session = self.session
75
- try:
76
- yield session
77
- if commit:
78
- session.commit()
79
- except Exception as err:
80
- session.rollback()
81
- raise err
82
- finally:
83
- session.close()
84
-
85
- def _flush_nested_objects(self, session):
86
- """
87
- Flush objects remaining in a session (generally these are objects loaded
88
- during schema-based importing).
89
- """
90
- for object in session:
91
- try:
92
- session.flush(objects=[object])
93
- log.debug(f"Successfully flushed instance {object}")
94
- except IntegrityError as err:
95
- session.rollback()
96
- log.debug(err)
97
-
98
- def run_sql(self, fn, params=None, **kwargs):
99
- """Executes SQL files passed"""
100
- return iter(run_sql(self.session, fn, params, **kwargs))
101
-
102
- def run_query(self, sql, params=None, **kwargs):
103
- return run_query(self.session, sql, params, **kwargs)
104
-
105
- def exec_sql(self, sql, params=None, **kwargs):
106
- """Executes SQL files passed"""
107
- warnings.warn("exec_sql is deprecated. Use run_sql instead", DeprecationWarning)
108
- return self.run_sql(sql, params, **kwargs)
109
-
110
- def get_dataframe(self, *args):
111
- """Returns a Pandas DataFrame from a SQL query"""
112
- return get_dataframe(self.engine, *args)
113
-
114
- @property
115
- def inspector(self):
116
- if self.__inspector__ is None:
117
- self.__inspector__ = inspect(self.engine)
118
- return self.__inspector__
119
-
120
- def entity_names(self, **kwargs):
121
- """
122
- Returns an iterator of names of *schema objects*
123
- (both tables and views) from a the database.
124
- """
125
- yield from self.inspector.get_table_names(**kwargs)
126
- yield from self.inspector.get_view_names(**kwargs)
127
-
128
- def get(self, model, *args, **kwargs):
129
- if isinstance(model, str):
130
- model = getattr(self.model, model)
131
- return self.session.query(model).get(*args, **kwargs)
132
-
133
- def get_or_create(self, model, **kwargs):
134
- """
135
- Get an instance of a model, or create it if it doesn't
136
- exist.
137
- """
138
- if isinstance(model, str):
139
- model = getattr(self.model, model)
140
- return get_or_create(self.session, model, **kwargs)
141
-
142
- def reflect_table(self, *args, **kwargs):
143
- """
144
- One-off reflection of a database table or view. Note: for most purposes,
145
- it will be better to use the database tables automapped at runtime using
146
- `self.automap()`. Then, tables can be accessed using the
147
- `self.table` object. However, this function can be useful for views (which
148
- are not reflected automatically), or to customize type definitions for mapped
149
- tables.
150
-
151
- A set of `column_args` can be used to pass columns to override with the mapper, for
152
- instance to set up foreign and primary key constraints.
153
- https://docs.sqlalchemy.org/en/13/core/reflection.html#reflecting-views
154
- """
155
- return reflect_table(self.engine, *args, **kwargs)
156
-
157
- @property
158
- def table(self):
159
- """
160
- Map of all tables in the database as SQLAlchemy table objects
161
- """
162
- if self.mapper is None or self.mapper._tables is None:
163
- self.automap()
164
- return self.mapper._tables
165
-
166
- @property
167
- def model(self):
168
- """
169
- Map of all tables in the database as SQLAlchemy models
170
-
171
- https://docs.sqlalchemy.org/en/latest/orm/extensions/automap.html
172
- """
173
- if self.mapper is None or self.mapper._models is None:
174
- self.automap()
175
- return self.mapper._models
176
-
177
- @property
178
- def mapped_classes(self):
179
- return self.model