macrostrat.database 1.0.2__tar.gz → 3.0.0b1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.1
2
+ Name: macrostrat.database
3
+ Version: 3.0.0b1
4
+ Summary: A SQLAlchemy-based database toolkit.
5
+ Author: Daven Quinn
6
+ Author-email: dev@davenquinn.com
7
+ Requires-Python: >=3.8,<4.0
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.8
10
+ Classifier: Programming Language :: Python :: 3.9
11
+ Classifier: Programming Language :: Python :: 3.10
12
+ Classifier: Programming Language :: Python :: 3.11
13
+ Requires-Dist: GeoAlchemy2 (>=0.14.0,<0.15.0)
14
+ Requires-Dist: SQLAlchemy (>=2.0.18,<3.0.0)
15
+ Requires-Dist: SQLAlchemy-Utils (>=0.41.1,<0.42.0)
16
+ Requires-Dist: click (>=8.1.3,<9.0.0)
17
+ Requires-Dist: macrostrat.utils (>=1.0.0,<2.0.0)
18
+ Requires-Dist: psycopg2-binary (>=2.9.6,<3.0.0)
19
+ Requires-Dist: sqlparse (>=0.4.4,<0.5.0)
@@ -1,21 +1,23 @@
1
+ import warnings
1
2
  from contextlib import contextmanager
2
- from pathlib import Path
3
3
  from typing import Optional
4
- from click import secho
5
4
 
6
5
  from sqlalchemy import create_engine, inspect, MetaData, text
7
- from sqlalchemy.orm import sessionmaker, scoped_session
8
- from sqlalchemy.schema import ForeignKey, Column
9
- from sqlalchemy.types import Integer
6
+ from sqlalchemy.orm import sessionmaker, scoped_session, Session
10
7
  from sqlalchemy.exc import IntegrityError
11
- from sqlalchemy.orm.exc import FlushError
12
- from macrostrat.utils import get_logger, relative_path
8
+ from macrostrat.utils import get_logger
13
9
  from sqlalchemy.ext.compiler import compiles
14
10
  from sqlalchemy.sql.expression import Insert
15
11
 
16
- from .utils import run_sql_file, run_query, get_or_create, run_sql_query_file
12
+
13
+ from .utils import (
14
+ run_sql,
15
+ get_or_create,
16
+ reflect_table,
17
+ get_dataframe,
18
+ )
17
19
  from .mapper import DatabaseMapper
18
- from .postgresql import on_conflict, prefix_inserts
20
+ from .postgresql import prefix_inserts, on_conflict # noqa
19
21
 
20
22
 
21
23
  metadata = MetaData()
@@ -25,9 +27,11 @@ log = get_logger(__name__)
25
27
 
26
28
  class Database(object):
27
29
  mapper: Optional[DatabaseMapper] = None
30
+ metadata: MetaData
31
+ session: Session
28
32
  __inspector__ = None
29
33
 
30
- def __init__(self, db_conn, app=None, echo_sql=False, **kwargs):
34
+ def __init__(self, db_conn, echo_sql=False, **kwargs):
31
35
  """
32
36
  We can pass a connection string, a **Flask** application object
33
37
  with the appropriate configuration, or nothing, in which
@@ -38,11 +42,8 @@ class Database(object):
38
42
  compiles(Insert, "postgresql")(prefix_inserts)
39
43
 
40
44
  log.info(f"Setting up database connection '{db_conn}'")
41
- self.engine = create_engine(
42
- db_conn, executemany_mode="batch", echo=echo_sql, **kwargs
43
- )
44
- metadata.create_all(bind=self.engine)
45
- self.meta = metadata
45
+ self.engine = create_engine(db_conn, echo=echo_sql, **kwargs)
46
+ self.metadata = kwargs.get("metadata", metadata)
46
47
 
47
48
  # Scoped session for database
48
49
  # https://docs.sqlalchemy.org/en/13/orm/contextual.html#unitofwork-contextual
@@ -51,10 +52,16 @@ class Database(object):
51
52
  self.session = scoped_session(self._session_factory)
52
53
  # Use the self.session_scope function to more explicitly manage sessions.
53
54
 
54
- def automap(self):
55
def create_tables(self):
    """
    Create all tables described by this database's metadata instance
    (``self.metadata``) using ``self.engine``.
    """
    # Use the instance's metadata: it may differ from the module-level default
    # when a custom ``metadata`` was supplied to the constructor.
    self.metadata.create_all(bind=self.engine)
60
+
61
+ def automap(self, **kwargs):
55
62
  log.info("Automapping the database")
56
63
  self.mapper = DatabaseMapper(self)
57
- self.mapper.automap_database()
64
+ self.mapper.reflect_database(**kwargs)
58
65
 
59
66
  @contextmanager
60
67
  def session_scope(self, commit=True):
@@ -85,26 +92,18 @@ class Database(object):
85
92
  session.rollback()
86
93
  log.debug(err)
87
94
 
88
- def exec_sql_text(self, statement, *args, **kwargs):
89
- """
90
- Executes a sql command, in string on the database
91
- Easy way to load data into a test database instance
92
- """
93
- connection = self.engine.connect()
94
- connection.execute(text(statement), *args, **kwargs)
95
-
96
- def exec_sql(self, fn, params=None):
95
+ def run_sql(self, fn, **kwargs):
97
96
  """Executes SQL files passed"""
98
- # TODO: refactor this to exec_sql_file
99
- secho(Path(fn).name, fg="cyan", bold=True)
100
- run_sql_file(self.session, str(fn), params)
97
+ return iter(run_sql(self.session, fn, **kwargs))
101
98
 
102
- def exec_sql_query(self, fn, params=None):
103
- return run_sql_query_file(self.session, fn, params)
99
def exec_sql(self, sql, **kwargs):
    """
    Deprecated alias for ``run_sql``.

    Emits a ``DeprecationWarning`` attributed to the caller's line, then
    forwards ``sql`` and all keyword arguments to ``self.run_sql`` and returns
    its result.
    """
    warnings.warn(
        "exec_sql is deprecated. Use run_sql instead",
        DeprecationWarning,
        # Point the warning at the code calling exec_sql, not at this line.
        stacklevel=2,
    )
    return self.run_sql(sql, **kwargs)
104
103
 
105
- def exec_query(self, *args):
104
+ def get_dataframe(self, *args):
106
105
  """Returns a Pandas DataFrame from a SQL query"""
107
- return run_query(self.engine, *args)
106
+ return get_dataframe(self.engine, *args)
108
107
 
109
108
  @property
110
109
  def inspector(self):
@@ -134,12 +133,27 @@ class Database(object):
134
133
  model = getattr(self.model, model)
135
134
  return get_or_create(self.session, model, **kwargs)
136
135
 
136
def reflect_table(self, *args, **kwargs):
    """
    Reflect a single database table or view with this database's engine.

    In most cases the tables automapped at runtime are the better choice: call
    ``self.automap()`` and then use ``self.table``. This method is helpful for
    views, which are not reflected automatically, or for customizing the type
    definitions of mapped tables.

    Extra positional ``column_args`` override columns on the mapper, e.g. to
    declare primary and foreign key constraints. Every argument is forwarded
    unchanged to the module-level ``reflect_table`` helper.
    https://docs.sqlalchemy.org/en/13/core/reflection.html#reflecting-views
    """
    engine = self.engine
    return reflect_table(engine, *args, **kwargs)
150
+
137
151
  @property
138
152
  def table(self):
139
153
  """
140
154
  Map of all tables in the database as SQLAlchemy table objects
141
155
  """
142
- if self.mapper._tables is None:
156
+ if self.mapper is None or self.mapper._tables is None:
143
157
  self.automap()
144
158
  return self.mapper._tables
145
159
 
@@ -150,7 +164,7 @@ class Database(object):
150
164
 
151
165
  https://docs.sqlalchemy.org/en/latest/orm/extensions/automap.html
152
166
  """
153
- if self.mapper._models is None:
167
+ if self.mapper is None or self.mapper._models is None:
154
168
  self.automap()
155
169
  return self.mapper._models
156
170
 
@@ -1,12 +1,13 @@
1
- from sqlalchemy.schema import Table
2
- from sqlalchemy import MetaData
1
+ from distutils.log import warn
2
+ from macrostrat.database.utils import reflect_table
3
3
  from sqlalchemy.ext.automap import generate_relationship
4
4
  from macrostrat.utils.logs import get_logger
5
- from .cache import DatabaseModelCache
5
+ from warnings import warn
6
6
 
7
7
  # Drag in geographic types for database reflection
8
8
  from geoalchemy2 import Geometry, Geography
9
9
 
10
+ from .cache import DatabaseModelCache
10
11
  from .utils import (
11
12
  ModelCollection,
12
13
  TableCollection,
@@ -18,12 +19,15 @@ from .utils import (
18
19
 
19
20
  log = get_logger(__name__)
20
21
 
22
+
21
23
  class AutomapError(Exception):
22
24
  pass
23
25
 
26
+
24
27
  model_builder = DatabaseModelCache()
25
28
  BaseModel = model_builder.automap_base()
26
29
 
30
+
27
31
  class DatabaseMapper:
28
32
  automap_base = BaseModel
29
33
  automap_error = None
@@ -38,10 +42,16 @@ class DatabaseMapper:
38
42
 
39
43
  # This stuff should be placed outside of core (one likely extension point).
40
44
  self.reflection_kwargs = dict(
41
- name_for_scalar_relationship=kwargs.get("name_for_scalar_relationship", name_for_scalar_relationship),
42
- name_for_collection_relationship=kwargs.get("name_for_collection_relationship", name_for_collection_relationship),
45
+ name_for_scalar_relationship=kwargs.get(
46
+ "name_for_scalar_relationship", name_for_scalar_relationship
47
+ ),
48
+ name_for_collection_relationship=kwargs.get(
49
+ "name_for_collection_relationship", name_for_collection_relationship
50
+ ),
43
51
  classname_for_table=kwargs.get("classname_for_table", _classname_for_table),
44
- generate_relationship=kwargs.get("generate_relationship", generate_relationship),
52
+ generate_relationship=kwargs.get(
53
+ "generate_relationship", generate_relationship
54
+ ),
45
55
  )
46
56
 
47
57
  self._models = ModelCollection(self.automap_base.classes)
@@ -66,45 +76,26 @@ class DatabaseMapper:
66
76
  def reflect_schema(self, schema, use_cache=True):
67
77
  if use_cache and self.automap_base.loaded_from_cache:
68
78
  log.info("Database models for %s have been loaded from cache", schema)
69
- self.automap_base.prepare(
70
- self.db.engine, schema=schema, **self.reflection_kwargs
71
- )
79
+ self.automap_base.prepare(schema=schema, **self.reflection_kwargs)
72
80
  return
73
81
  log.info(f"Reflecting schema {schema}")
74
82
  if schema == "public":
75
- self.automap_base.prepare(
76
- self.db.engine, reflect=True, schema=None, **self.reflection_kwargs
77
- )
78
- else:
79
- # Reflect tables in schemas we care about
80
- # Note: this will not reflect views because they don't have primary keys.
81
- self.automap_base.metadata.reflect(bind=self.db.engine, schema=schema, **self.reflection_kwargs)
83
+ schema = None
84
+ # Reflect tables in schemas we care about
85
+ # Note: this will not reflect views because they don't have primary keys.
86
+ self.automap_base.prepare(
87
+ autoload_with=self.db.engine, schema=schema, **self.reflection_kwargs
88
+ )
82
89
  self._models = ModelCollection(self.automap_base.classes)
83
90
  self._tables = TableCollection(self._models)
84
91
 
85
92
  def reflect_table(self, tablename, *column_args, **kwargs):
86
- """
87
- One-off reflection of a database table or view. Note: for most purposes,
88
- it will be better to use the database tables automapped at runtime in the
89
- `self.tables` object. However, this function can be useful for views (which
90
- are not reflected automatically), or to customize type definitions for mapped
91
- tables.
92
-
93
- A set of `column_args` can be used to pass columns to override with the mapper, for
94
- instance to set up foreign and primary key constraints.
95
- https://docs.sqlalchemy.org/en/13/core/reflection.html#reflecting-views
96
- """
97
- schema = kwargs.pop("schema", "public")
98
- meta = MetaData(schema=schema)
99
- tables = Table(
100
- tablename,
101
- meta,
102
- *column_args,
103
- autoload=True,
104
- autoload_with=self.db.engine,
105
- **kwargs,
93
+ # Warn that this method is deprecated
94
+ warn(
95
+ "DatabaseMapper.reflect_table is deprecated. Use Database.reflect_table instead",
96
+ DeprecationWarning,
106
97
  )
107
- return tables
98
+ return reflect_table(self.db.engine, tablename, *column_args, **kwargs)
108
99
 
109
100
  def reflect_view(self, tablename, *column_args, **kwargs):
110
101
  pass
@@ -7,16 +7,17 @@ from .base import ModelHelperMixins
7
7
 
8
8
  log = get_logger(__name__)
9
9
 
10
+
10
11
  class AutomapError(Exception):
11
12
  pass
12
13
 
14
+
13
15
  class DatabaseModelCache(object):
14
16
  cache_file = None
15
17
 
16
18
  def __init__(self, cache_file=None):
17
19
  self.cache_file = cache_file
18
20
 
19
-
20
21
  @property
21
22
  def _metadata_cache_filename(self):
22
23
  return self.cache_file
@@ -36,9 +37,7 @@ class DatabaseModelCache(object):
36
37
  log.info(f"Cached database models to {self.cache_file}")
37
38
  except IOError:
38
39
  # couldn't write the file for some reason
39
- log.info(
40
- f"Could not cache database models to {self.cache_file}"
41
- )
40
+ log.info(f"Could not cache database models to {self.cache_file}")
42
41
 
43
42
  def _load_database_map(self):
44
43
  # We have hard-coded the cache file for now
@@ -54,6 +53,8 @@ class DatabaseModelCache(object):
54
53
  log.info(
55
54
  f"Could not find database model cache ({self._metadata_cache_filename})"
56
55
  )
56
+ except Exception as exc:
57
+ log.error(f"Error loading database model cache: {exc}")
57
58
  return cached_metadata
58
59
 
59
60
  def automap_base(self):
@@ -66,4 +67,3 @@ class DatabaseModelCache(object):
66
67
  base.loaded_from_cache = True
67
68
  base.builder = self
68
69
  return base
69
-
@@ -6,6 +6,7 @@ def primary_key(instance):
6
6
  prop_list = [mapper.get_property_by_column(column) for column in mapper.primary_key]
7
7
  return {prop.key: getattr(instance, prop.key) for prop in prop_list}
8
8
 
9
+
9
10
  def classname_for_table(table):
10
11
  if table.schema is not None:
11
12
  return f"{table.schema}_{table.name}"
@@ -83,6 +84,14 @@ class ModelCollection(BaseCollection):
83
84
  def keys(self):
84
85
  return [k for k in self.__models.keys()]
85
86
 
87
# Support for dict-like access
def __getitem__(self, key):
    """Return the model stored under ``key``; lookup errors come from the underlying mapping."""
    models = self.__models
    return models[key]
90
+
91
# Support 'in' operator
def __contains__(self, key):
    """Return True if a model is stored under ``key`` in this collection."""
    models = self.__models
    return key in models
94
+
86
95
 
87
96
  class TableCollection(BaseCollection):
88
97
  """
@@ -5,9 +5,12 @@ from sqlalchemy.exc import CompileError
5
5
  from sqlalchemy.ext.compiler import compiles
6
6
  from sqlalchemy.sql.expression import Insert
7
7
  from sqlalchemy.dialects import postgresql
8
+ import psycopg2
9
+
8
10
 
9
11
  _import_mode = ContextVar("import-mode", default="do-nothing")
10
12
 
13
+
11
14
  # https://stackoverflow.com/questions/33307250/postgresql-on-conflict-in-sqlalchemy/62305344#62305344
12
15
  @contextmanager
13
16
  def on_conflict(action="restrict"):
@@ -47,3 +50,17 @@ def prefix_inserts(insert, compiler, **kw):
47
50
  index_elements=insert.table.primary_key
48
51
  )
49
52
  return compiler.visit_insert(insert, **kw)
53
+
54
+
55
+ _psycopg2_setup_was_run = ContextVar("psycopg2-setup-was-run", default=False)
56
+
57
+
58
+ def _setup_psycopg2_wait_callback():
59
+ """Set up the wait callback for PostgreSQL connections. This allows for query cancellation with Ctrl-C."""
60
+ # TODO: we might want to do this only once on engine creation
61
+ # https://github.com/psycopg/psycopg2/issues/333
62
+ val = _psycopg2_setup_was_run.get()
63
+ if val:
64
+ return
65
+ psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
66
+ _psycopg2_setup_was_run.set(True)
@@ -0,0 +1,450 @@
1
+ from click import echo, secho
2
+ from sqlalchemy.exc import ProgrammingError, IntegrityError, InternalError
3
+ from sqlparse import split, format
4
+ from sqlalchemy.sql.elements import TextClause, ClauseElement
5
+ from sqlalchemy.orm import sessionmaker
6
+ from sqlalchemy.engine import Engine, Connection
7
+ from sqlalchemy.schema import Table
8
+ from sqlalchemy import MetaData, create_engine, text
9
+ from contextlib import contextmanager
10
+ from sqlalchemy_utils import create_database, database_exists, drop_database
11
+ from sqlalchemy.exc import InvalidRequestError
12
+ from macrostrat.utils import cmd, get_logger
13
+ from time import sleep
14
+ from typing import Union, IO
15
+ from pathlib import Path
16
+ from warnings import warn
17
+ from psycopg2.sql import SQL, Composable, Composed
18
+ from re import search
19
+ from macrostrat.utils import get_logger
20
+ from .postgresql import _setup_psycopg2_wait_callback
21
+
22
log = get_logger(__name__)


def db_session(engine):
    """
    Build a new ORM ``Session`` bound to ``engine`` and return it.

    A fresh ``sessionmaker`` is created on every call; the returned session is
    not closed here, so the caller is responsible for it.
    """
    return sessionmaker(bind=engine)()
28
+
29
+
30
def infer_is_sql_text(_string: str) -> bool:
    """
    Guess whether ``_string`` is SQL text (True) or a file path (False).

    Heuristic:
    - Byte strings are decoded as UTF-8 first.
    - Any multi-line input is treated as SQL.
    - A single line is treated as SQL when, after surrounding whitespace is
      stripped, it starts with SELECT/INSERT/UPDATE/CREATE/DROP/DELETE/ALTER/SET
      (case-insensitive). The keyword must be followed by whitespace, ``(``,
      ``;``, ``*`` or the end of the string.

    Requiring that delimiter keeps file names such as ``select_units.sql``,
    ``settings.sql`` or ``create-tables.sql`` from being mistaken for queries.
    """
    if isinstance(_string, bytes):
        _string = _string.decode("utf-8")

    if "\n" in _string:
        return True

    keywords = [
        "SELECT",
        "INSERT",
        "UPDATE",
        "CREATE",
        "DROP",
        "DELETE",
        "ALTER",
        "SET",
    ]
    first_line = _string.strip()
    for keyword in keywords:
        # Keyword must be a whole word that ends the token, not a filename prefix.
        if search(rf"(?i)^{keyword}(?:[\s(;*]|$)", first_line):
            return True
    return False
57
+
58
+
59
def canonicalize_query(file_or_text: Union[str, Path, IO]) -> Union[str, Path]:
    """
    Normalize a query argument into SQL text or a ``Path`` to be read later.

    - ``Path`` objects are returned unchanged.
    - File-like objects (anything with a ``read`` attribute) are read, and their
      contents are returned.
    - Strings that ``infer_is_sql_text`` recognizes as SQL are returned unchanged.
    - Other strings naming an existing regular file become a ``Path``.
    - Everything else is returned unchanged (callers then treat it as SQL text).
    """
    if isinstance(file_or_text, Path):
        return file_or_text
    if hasattr(file_or_text, "read"):
        return file_or_text.read()
    if infer_is_sql_text(file_or_text):
        return file_or_text
    candidate = Path(file_or_text)
    points_to_file = candidate.exists() and candidate.is_file()
    return candidate if points_to_file else file_or_text
72
+
73
+
74
def get_dataframe(connectable, filename_or_query, **kwargs):
    """
    Run a query against ``connectable`` and return the result as a
    ``pandas.DataFrame``.

    ``filename_or_query`` may be SQL text or a path to a SQL file; it is resolved
    by ``get_sql_text``. Extra keyword arguments go to ``pandas.read_sql``.
    ``pandas`` is imported here, on demand, because it is not a declared
    dependency of this package.
    """
    import pandas

    query_text = get_sql_text(filename_or_query)
    return pandas.read_sql(query_text, connectable, **kwargs)
85
+
86
+
87
def pretty_print(sql, **kwargs):
    """
    Echo a one-line summary of ``sql`` with ``secho``.

    The summary comes from the first line that starts with an upper-case
    SELECT/INSERT/UPDATE/CREATE/DROP/DELETE/ALTER keyword. That line is cut at
    its first ``(``, whitespace and a trailing ``;`` are stripped, and every
    `` AS`` is removed. ``kwargs`` are passed to ``secho`` as styling options.
    Nothing is echoed when no line matches.
    """
    leading_keywords = ("SELECT", "INSERT", "UPDATE", "CREATE", "DROP", "DELETE", "ALTER")
    matching = (ln for ln in sql.split("\n") if ln.startswith(leading_keywords))
    first = next(matching, None)
    if first is None:
        return
    summary = first.split("(")[0].strip().rstrip(";").replace(" AS", "")
    secho(summary, **kwargs)
95
+
96
+
97
def get_sql_text(sql, interpret_as_file=None, echo_file_name=True):
    """
    Resolve ``sql`` to a SQL string.

    - If ``interpret_as_file`` is truthy, ``sql`` is a path and its contents are
      returned.
    - If ``interpret_as_file`` is ``None``, ``canonicalize_query`` decides
      whether ``sql`` is text or a path.
    - Otherwise, ``sql`` is used as given.

    When the resolved value is a ``Path``, its file name is echoed (if
    ``echo_file_name``) and its contents are returned.
    """
    if interpret_as_file:
        return Path(sql).read_text()

    resolved = canonicalize_query(sql) if interpret_as_file is None else sql
    if not isinstance(resolved, Path):
        return resolved
    if echo_file_name:
        secho(resolved.name, fg="cyan", bold=True)
    return resolved.read_text()
109
+
110
+
111
+ def _get_queries(sql, interpret_as_file=None):
112
+ if isinstance(sql, (list, tuple)):
113
+ queries = []
114
+ for i in sql:
115
+ queries.extend(_get_queries(i, interpret_as_file=interpret_as_file))
116
+ return queries
117
+ if isinstance(sql, TextClause):
118
+ return [sql]
119
+ if isinstance(sql, SQL):
120
+ return [sql]
121
+
122
+ if sql in [None, ""]:
123
+ return
124
+ if interpret_as_file:
125
+ sql = Path(sql).read_text()
126
+ elif interpret_as_file is None:
127
+ sql = canonicalize_query(sql)
128
+
129
+ if isinstance(sql, Path):
130
+ sql = sql.read_text()
131
+
132
+ return split(sql)
133
+
134
+
135
def _is_prebind_param(param):
    """
    Return True if ``param`` is a psycopg2 ``Composable``. Such values are
    interpolated into the query text before execution rather than sent as
    ordinary bind parameters.
    """
    prebind = isinstance(param, Composable)
    return prebind
137
+
138
+
139
+ def _split_params(params):
140
+ if params is None:
141
+ return None, None
142
+ new_params = []
143
+ new_bind_params = []
144
+ if isinstance(params, (list, tuple)):
145
+ for i in params:
146
+ if _is_prebind_param(i):
147
+ new_bind_params.append(i)
148
+ else:
149
+ new_params.append(i)
150
+ elif isinstance(params, dict):
151
+ new_params = {}
152
+ new_bind_params = {}
153
+ for k, v in params.items():
154
+ if _is_prebind_param(v):
155
+ new_bind_params[k] = v
156
+ else:
157
+ new_params[k] = v
158
+ if len(new_bind_params) == 0:
159
+ new_bind_params = None
160
+ return new_params, new_bind_params
161
+
162
+
163
+ def _get_cursor(connectable):
164
+ if isinstance(connectable, Engine):
165
+ conn = connectable.connect()
166
+
167
+ # Find a connection or cursor object for the connectable
168
+ conn = connectable
169
+ if hasattr(conn, "raw_connection"):
170
+ conn = conn.raw_connection()
171
+ while hasattr(conn, "driver_connection") or hasattr(conn, "connection"):
172
+ if hasattr(conn, "driver_connection"):
173
+ conn = conn.driver_connection
174
+ else:
175
+ conn = conn.connection
176
+ if callable(conn):
177
+ conn = conn()
178
+ if hasattr(conn, "cursor"):
179
+ conn = conn.cursor()
180
+
181
+ return conn
182
+
183
+
184
def _get_connection(connectable) -> Connection:
    """
    Return a connection suitable for driver-level execution.

    - An ``Engine`` opens a new ``Connection``, which is not closed here.
    - A ``Connection``, or any object without a ``connection`` attribute, is
      returned unchanged.
    - Otherwise ``connectable.connection`` is returned. If that attribute is a
      method (e.g. on a session), it is called first.
    """
    if isinstance(connectable, Engine):
        return connectable.connect()
    passthrough = isinstance(connectable, Connection) or not hasattr(
        connectable, "connection"
    )
    if passthrough:
        return connectable
    attr = connectable.connection
    return attr() if callable(attr) else attr
195
+
196
+
197
def _render_query(query: Union[SQL, Composed], connectable: Union[Engine, Connection]):
    """
    Render a query to a SQL string.

    psycopg2 ``SQL``/``Composed`` objects are rendered with ``as_string``, using
    a cursor obtained from ``connectable``. Any other value is returned
    unchanged.
    """
    if isinstance(query, (Composed, SQL)):
        return query.as_string(_get_cursor(connectable))
    return query
204
+
205
+
206
def infer_has_server_binds(sql):
    """
    Guess whether ``sql`` contains driver-style bind placeholders, ``%s`` or
    ``%(name)s``. Such placeholders must be executed through the driver rather
    than SQLAlchemy's ``text()``.

    Returns a ``bool``. A ``%s`` only counts when it is not immediately
    followed by a word character, so a pattern such as ``LIKE '%sand%'`` is
    not mistaken for a placeholder.
    """
    return search(r"%s(?!\w)|%\(\w+\)s", sql) is not None
208
+
209
+
210
def _run_sql(connectable, sql, **kwargs):
    """
    Internal function for running a query on a SQLAlchemy connectable,
    which always returns an iterator. The wrapper function adds the option
    to return a list of results.

    ``sql`` is split into statements by ``_get_queries``. Each non-empty
    statement is executed in its own transaction when ``begin()`` succeeds,
    and its result is yielded. The transaction is committed when the consumer
    requests the next item.

    Recognized keyword arguments: ``params``, ``raise_errors``,
    ``stop_on_error`` (deprecated alias of ``raise_errors``),
    ``has_server_binds`` and ``interpret_as_file``. Other keyword arguments
    are ignored.

    On ``ProgrammingError``, ``IntegrityError`` or ``InternalError`` the
    transaction is rolled back and the error is echoed and logged. The error
    is re-raised only when ``raise_errors`` is set.
    """
    if isinstance(connectable, Engine):
        # Run the whole batch on a single connection.
        with connectable.connect() as conn:
            yield from _run_sql(conn, sql, **kwargs)
        return

    _setup_psycopg2_wait_callback()

    params = kwargs.pop("params", None)
    stop_on_error = kwargs.pop("stop_on_error", False)
    raise_errors = kwargs.pop("raise_errors", False)
    has_server_binds = kwargs.pop("has_server_binds", None)

    if stop_on_error:
        raise_errors = True
        warn(DeprecationWarning("stop_on_error is deprecated, use raise_errors"))

    interpret_as_file = kwargs.pop("interpret_as_file", None)

    queries = _get_queries(sql, interpret_as_file=interpret_as_file)

    if queries is None:
        return

    # A list with one entry per query supplies per-query parameters; anything
    # else is applied to every query.
    if isinstance(params, list) and len(params) == len(queries):
        param_sets = params
    else:
        param_sets = [params] * len(queries)

    for query, query_params in zip(queries, param_sets):
        query_params, pre_bind_params = _split_params(query_params)

        if pre_bind_params is not None:
            if not isinstance(query, SQL):
                query = SQL(query)
            # Pre-bind the parameters using PsycoPG2. `_split_params` keeps the
            # container type: a dict for named values, a list for positional.
            if isinstance(pre_bind_params, dict):
                query = query.format(**pre_bind_params)
            else:
                query = query.format(*pre_bind_params)

        if isinstance(query, (SQL, Composed)):
            query = _render_query(query, connectable)

        sql_text = str(query)
        if isinstance(query, str):
            sql_text = format(query, strip_comments=True).strip()
        if sql_text == "":
            # Nothing to run (e.g. a comment-only fragment). Skip it before a
            # transaction is opened, so none is left dangling.
            continue

        # Check for server-bound parameters in sql native style. If there are
        # none, use the SQLAlchemy text() function, otherwise use the raw query
        # string. Infer per statement so one statement's bind style does not
        # carry over to the next; an explicit caller value always wins.
        if has_server_binds is None:
            use_server_binds = infer_has_server_binds(sql_text)
        else:
            use_server_binds = has_server_binds

        try:
            trans = connectable.begin()
        except InvalidRequestError:
            # A transaction is already in progress; run inside it.
            trans = None
        try:
            log.debug("Executing SQL: \n %s", query)
            if use_server_binds:
                conn = _get_connection(connectable)
                res = conn.exec_driver_sql(query, query_params)
            else:
                if not isinstance(query, TextClause):
                    query = text(query)
                res = connectable.execute(query, query_params)
            yield res
            if trans is not None:
                trans.commit()
            elif hasattr(connectable, "commit"):
                connectable.commit()
            pretty_print(sql_text, dim=True)
        except (ProgrammingError, IntegrityError, InternalError) as err:
            _err = str(err.orig).strip()
            dim = "already exists" in _err
            if trans is not None:
                trans.rollback()
            elif hasattr(connectable, "rollback"):
                connectable.rollback()
            pretty_print(sql_text, fg=None if dim else "red", dim=True)
            if dim:
                _err = " " + _err
            secho(_err, fg="red", dim=dim)
            log.error(err)
            if raise_errors:
                raise
299
+
300
+
301
def run_sql_file(connectable, filename, **kwargs):
    """
    Run the SQL file at ``filename`` with ``run_sql``, always treating
    ``filename`` as a path.

    Other keyword arguments are forwarded to ``run_sql``. Passing
    ``interpret_as_file`` as well raises ``TypeError``.
    """
    return run_sql(connectable, filename, **kwargs, interpret_as_file=True)
303
+
304
+
305
def run_sql(*args, **kwargs):
    """
    Run a query on a SQLAlchemy connectable.

    Parameters
    ----------
    connectable : Union[Engine, Connection]
        A SQLAlchemy engine or connection object.
    sql : Union[str, Path, IO, SQL, Composed]
        A SQL query, or a file containing a SQL query.
    params : Union[dict, list, tuple]
        Parameters to bind. A list with exactly one entry per statement
        supplies per-statement parameters. Any other value (a dict, tuple, or
        a list of a different length) is used for every statement.
        psycopg2 ``Composable`` values are interpolated into the query text
        before execution.
    stop_on_error : bool
        Deprecated alias for ``raise_errors``; emits a ``DeprecationWarning``.
    raise_errors : bool
        If True, raise errors encountered while running queries.
    has_server_binds : bool
        Interpret the query to have server-side bind parameters (requiring execution
        with the backend driver). By default, this is inferred from the query string,
        but inference is not always reliable.
    interpret_as_file : bool
        If True, force interpreting the query as a file path.
    yield_results : bool
        If True, return a lazy iterator that executes each statement as it is
        consumed. Otherwise, all statements run and a list of results is
        returned.
    """
    # Consume `yield_results` here so it is not forwarded to `_run_sql`.
    yield_results = kwargs.pop("yield_results", False)
    results = _run_sql(*args, **kwargs)
    if yield_results:
        return results
    return list(results)
337
+
338
+
339
def execute(connectable, sql, params=None, stop_on_error=False):
    """
    Execute a single SQL statement on a ``Session`` or ``Connection``.

    Comments are stripped first; an empty statement returns ``None`` without
    touching ``connectable``.

    On success the transaction is committed (when ``connectable`` has
    ``commit``), a one-line summary is echoed, and the result is returned.

    On ``ProgrammingError``/``IntegrityError`` the transaction is rolled back,
    the error is echoed, and ``None`` is returned. ``stop_on_error`` is kept
    for backward compatibility and has no effect for a single statement.

    NOTE: ``connectable`` is closed (when it has ``close``) before returning.
    """
    sql = format(sql, strip_comments=True).strip()
    if sql == "":
        return
    try:
        try:
            connectable.begin()
        except InvalidRequestError:
            # A transaction is already in progress (e.g. an autobegun
            # session); run inside it, as `_run_sql` does.
            pass
        # Pass parameters positionally: `Session.execute` names this argument
        # `params`, but `Connection.execute` names it `parameters` and would
        # reject `params=`.
        res = connectable.execute(text(sql), params)
        if hasattr(connectable, "commit"):
            connectable.commit()
        pretty_print(sql, dim=True)
        return res
    except (ProgrammingError, IntegrityError) as err:
        message = str(err.orig).strip()
        dim = "already exists" in message
        if hasattr(connectable, "rollback"):
            connectable.rollback()
        pretty_print(sql, fg=None if dim else "red", dim=True)
        if dim:
            message = " " + message
        secho(message, fg="red", dim=dim)
        return None
    finally:
        if hasattr(connectable, "close"):
            connectable.close()
364
+
365
+
366
def get_or_create(session, model, defaults=None, **kwargs):
    """
    Return the first ``model`` row matching ``kwargs``, or a new, unflushed
    instance added to ``session``.

    The returned object carries a ``_created`` flag: False when it was found,
    True when it was created. New instances are built from the non-expression
    values in ``kwargs``, updated with ``defaults``.

    https://stackoverflow.com/questions/2546207
    """
    existing = session.query(model).filter_by(**kwargs).first()
    if existing:
        existing._created = False
        return existing

    # SQL expressions can filter the query but cannot be passed to the model
    # constructor.
    init_kwargs = {
        key: value
        for key, value in kwargs.items()
        if not isinstance(value, ClauseElement)
    }
    init_kwargs.update(defaults or {})
    created = model(**init_kwargs)
    session.add(created)
    created._created = True
    return created
386
+
387
+
388
def get_db_model(db, model_name: str):
    """Return the attribute ``model_name`` of ``db.model``; a missing name raises ``AttributeError``."""
    models = db.model
    return getattr(models, model_name)
390
+
391
+
392
@contextmanager
def temp_database(conn_string, drop=True, ensure_empty=False):
    """
    Create a temporary database and tear it down after tests.

    If ``ensure_empty`` is set, an existing database is dropped first. The
    database is then created if it does not exist, and an ``Engine`` for
    ``conn_string`` is yielded. On exit, the engine's pooled connections are
    disposed, and the database is dropped if ``drop`` is set.
    """
    # Only drop a database that exists; a plain DROP DATABASE errors otherwise.
    if ensure_empty and database_exists(conn_string):
        drop_database(conn_string)
    if not database_exists(conn_string):
        create_database(conn_string)
    engine = None
    try:
        engine = create_engine(conn_string)
        yield engine
    finally:
        # Release pooled connections so they do not linger past the DROP.
        if engine is not None:
            engine.dispose()
        if drop:
            drop_database(conn_string)
404
+
405
+
406
def connection_args(engine):
    """
    Get PostgreSQL command-line connection flags for an engine.

    ``engine`` may be an ``Engine`` or a connection URL string. Returns a
    ``(flags, database)`` tuple:
    - ``flags`` is a string such as ``" -U user -h host -p 5432"``, built from
      the URL's username, host and port; unset parts are omitted.
    - ``database`` is the URL's database name.

    The password is intentionally NOT included. ``psql``'s ``-P`` sets a print
    option and ``pg_isready`` has no ``-P`` at all, so passing it breaks those
    commands. Secrets on a command line are also visible to other processes.
    Supply the password via ``PGPASSWORD`` or a ``.pgpass`` file instead.
    """
    _psql_flags = {"-U": "username", "-h": "host", "-p": "port"}

    if isinstance(engine, str):
        # We passed a connection url!
        engine = create_engine(engine)
    url = engine.url
    flags = "".join(
        f" {flag} {getattr(url, attr)}"
        for flag, attr in _psql_flags.items()
        if getattr(url, attr) is not None
    )
    return flags, url.database
419
+
420
+
421
def db_isready(engine_or_url):
    """
    Return True when ``pg_isready`` exits with status 0 for the server
    described by ``engine_or_url``, which may be an ``Engine`` or a connection
    URL.
    """
    flags, _database = connection_args(engine_or_url)
    result = cmd("pg_isready", flags, capture_output=True)
    return result.returncode == 0
425
+
426
+
427
def wait_for_database(engine_or_url, quiet=False):
    """
    Block until ``db_isready`` reports that the database is up, checking once
    per second.

    On each retry a waiting message is logged and, unless ``quiet``, also
    echoed to stderr. NOTE(review): there is no timeout.
    """
    msg = "Waiting for database..."
    while True:
        if db_isready(engine_or_url):
            return
        if not quiet:
            echo(msg, err=True)
        log.info(msg)
        sleep(1)
434
+
435
+
436
def reflect_table(engine, tablename, *column_args, **kwargs):
    """
    One-off reflection of a single database table or view into a ``Table``.

    The table is loaded with ``autoload_with=engine`` into a fresh ``MetaData``.
    Its default schema is ``kwargs["schema"]``, or ``"public"`` when omitted.
    This is mainly useful for views, which are not picked up by automapping,
    or for overriding the type definitions of reflected columns.

    ``column_args`` (e.g. ``Column`` objects) override reflected columns, for
    instance to declare primary or foreign keys. Remaining keyword arguments
    are passed to ``Table``.
    https://docs.sqlalchemy.org/en/13/core/reflection.html#reflecting-views
    """
    table_meta = MetaData(schema=kwargs.pop("schema", "public"))
    return Table(
        tablename,
        table_meta,
        *column_args,
        autoload_with=engine,
        **kwargs,
    )
@@ -0,0 +1,23 @@
1
+ [tool.poetry]
2
+ authors = ["Daven Quinn <dev@davenquinn.com>"]
3
+ description = "A SQLAlchemy-based database toolkit."
4
+ name = "macrostrat.database"
5
+ packages = [{ include = "macrostrat" }]
6
+ version = "3.0.0-beta1"
7
+
8
+ [tool.poetry.dependencies]
9
+ GeoAlchemy2 = "^0.14.0"
10
+ SQLAlchemy = "^2.0.18"
11
+ SQLAlchemy-Utils = "^0.41.1"
12
+ click = "^8.1.3"
13
+ "macrostrat.utils" = "^1.0.0"
14
+ psycopg2-binary = "^2.9.6"
15
+ python = "^3.8"
16
+ sqlparse = "^0.4.4"
17
+
18
+ [tool.poetry.dev-dependencies]
19
+ "macrostrat.utils" = { path = "../utils", develop = true }
20
+
21
+ [build-system]
22
+ build-backend = "poetry.core.masonry.api"
23
+ requires = ["poetry-core>=1.0.0"]
@@ -1,20 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: macrostrat.database
3
- Version: 1.0.2
4
- Summary: A small library based on SQLAlchemy to assist with common database tasks.
5
- Author: Daven Quinn
6
- Author-email: dev@davenquinn.com
7
- Requires-Python: >=3.8,<4.0
8
- Classifier: Programming Language :: Python :: 3
9
- Classifier: Programming Language :: Python :: 3.10
10
- Classifier: Programming Language :: Python :: 3.8
11
- Classifier: Programming Language :: Python :: 3.9
12
- Requires-Dist: GeoAlchemy2 (>=0.9.4,<0.10.0)
13
- Requires-Dist: SQLAlchemy (>=1.4.26,<2.0.0)
14
- Requires-Dist: SQLAlchemy-Utils (>=0.37.0,<0.38.0)
15
- Requires-Dist: click (>=8.1.3,<9.0.0)
16
- Requires-Dist: macrostrat.utils (>=1.0.0,<2.0.0)
17
- Requires-Dist: migra (>=3.0.1621480950,<4.0.0)
18
- Requires-Dist: psycopg2-binary (>=2.9.1,<3.0.0)
19
- Requires-Dist: schemainspect (>=3.0.1616029793,<4.0.0)
20
- Requires-Dist: sqlparse (>=0.4.0,<0.5.0)
@@ -1,164 +0,0 @@
1
- from click import echo, secho
2
- from sqlalchemy.exc import ProgrammingError, IntegrityError
3
- from sqlparse import split, format
4
- from sqlalchemy.sql import ClauseElement
5
- from sqlalchemy import create_engine, text
6
- from sqlalchemy.orm import sessionmaker
7
- from contextlib import contextmanager
8
- from sqlalchemy_utils import create_database, database_exists, drop_database
9
- from macrostrat.utils import cmd, get_logger
10
- from time import sleep
11
-
12
- log = get_logger(__name__)
13
-
14
-
15
- def db_session(engine):
16
- factory = sessionmaker(bind=engine)
17
- return factory()
18
-
19
-
20
- def run_query(db, filename_or_query, **kwargs):
21
- """
22
- Run a query on a SQL database (represented by
23
- a SQLAlchemy database object) and turn it into a
24
- `Pandas` dataframe.
25
- """
26
- from pandas import read_sql
27
-
28
- if "SELECT" in str(filename_or_query):
29
- # We are working with a query string instead of
30
- # an SQL file.
31
- sql = filename_or_query
32
- else:
33
- with open(filename_or_query) as f:
34
- sql = f.read()
35
-
36
- return read_sql(sql, db, **kwargs)
37
-
38
-
39
- def pretty_print(sql, **kwargs):
40
- for line in sql.split("\n"):
41
- for i in ["SELECT", "INSERT", "UPDATE", "CREATE", "DROP", "DELETE", "ALTER"]:
42
- if not line.startswith(i):
43
- continue
44
- start = line.split("(")[0].strip().rstrip(";").replace(" AS", "")
45
- secho(start, **kwargs)
46
- return
47
-
48
-
49
- def run_sql(session, sql, params=None, stop_on_error=False):
50
- queries = split(sql)
51
- for q in queries:
52
- sql = format(q, strip_comments=True).strip()
53
- if sql == "":
54
- continue
55
- try:
56
- session.execute(text(sql), params=params)
57
- if hasattr(session, "commit"):
58
- session.commit()
59
- pretty_print(sql, dim=True)
60
- except (ProgrammingError, IntegrityError) as err:
61
- err = str(err.orig).strip()
62
- dim = "already exists" in err
63
- if hasattr(session, "rollback"):
64
- session.rollback()
65
- pretty_print(sql, fg=None if dim else "red", dim=True)
66
- if dim:
67
- err = " " + err
68
- secho(err, fg="red", dim=dim)
69
- if stop_on_error:
70
- return
71
-
72
-
73
- def _exec_raw_sql(engine, sql):
74
- """Execute SQL unsafely on an sqlalchemy Engine"""
75
- try:
76
- engine.execute(text(sql))
77
- pretty_print(sql, dim=True)
78
- except (ProgrammingError, IntegrityError) as err:
79
- err = str(err.orig).strip()
80
- dim = "already exists" in err
81
- pretty_print(sql, fg=None if dim else "red", dim=True)
82
- if dim:
83
- err = " " + err
84
- secho(err, fg="red", dim=dim)
85
-
86
-
87
- def run_sql_file(session, sql_file, params=None):
88
- sql = open(sql_file).read()
89
- run_sql(session, sql, params=params)
90
-
91
-
92
- def run_sql_query_file(session, sql_file, params=None):
93
- sql = open(sql_file).read()
94
- return session.execute(sql, params)
95
-
96
-
97
- def get_or_create(session, model, defaults=None, **kwargs):
98
- """
99
- Get an instance of a model, or create it if it doesn't
100
- exist.
101
-
102
- https://stackoverflow.com/questions/2546207
103
- """
104
- instance = session.query(model).filter_by(**kwargs).first()
105
- if instance:
106
- instance._created = False
107
- return instance
108
- else:
109
- params = dict(
110
- (k, v) for k, v in kwargs.items() if not isinstance(v, ClauseElement)
111
- )
112
- params.update(defaults or {})
113
- instance = model(**params)
114
- session.add(instance)
115
- instance._created = True
116
- return instance
117
-
118
-
119
- def get_db_model(db, model_name: str):
120
- return getattr(db.model, model_name)
121
-
122
-
123
- @contextmanager
124
- def temp_database(conn_string, drop=True, ensure_empty=False):
125
- """Create a temporary database and tear it down after tests."""
126
- if ensure_empty:
127
- drop_database(conn_string)
128
- if not database_exists(conn_string):
129
- create_database(conn_string)
130
- try:
131
- yield create_engine(conn_string)
132
- finally:
133
- if drop:
134
- drop_database(conn_string)
135
-
136
-
137
- def connection_args(engine):
138
- """Get PostgreSQL connection arguments for a engine"""
139
- _psql_flags = {"-U": "username", "-h": "host", "-p": "port", "-P": "password"}
140
-
141
- if isinstance(engine, str):
142
- # We passed a connection url!
143
- engine = create_engine(engine)
144
- flags = ""
145
- for flag, _attr in _psql_flags.items():
146
- val = getattr(engine.url, _attr)
147
- if val is not None:
148
- flags += f" {flag} {val}"
149
- return flags, engine.url.database
150
-
151
-
152
- def db_isready(engine_or_url):
153
- args, _ = connection_args(engine_or_url)
154
- c = cmd("pg_isready", args, capture_output=True)
155
- return c.returncode == 0
156
-
157
-
158
- def wait_for_database(engine_or_url, quiet=False):
159
- msg = "Waiting for database..."
160
- while not db_isready(engine_or_url):
161
- if not quiet:
162
- echo(msg, err=True)
163
- log.info(msg)
164
- sleep(1)
@@ -1,27 +0,0 @@
1
- [tool.poetry]
2
- authors = ["Daven Quinn <dev@davenquinn.com>"]
3
- description = "A small library based on SQLAlchemy to assist with common database tasks."
4
- name = "macrostrat.database"
5
- packages = [
6
- {include = "macrostrat"},
7
- ]
8
- version = "1.0.2"
9
-
10
- [tool.poetry.dependencies]
11
- GeoAlchemy2 = "^0.9.4"
12
- SQLAlchemy = "^1.4.26"
13
- SQLAlchemy-Utils = "^0.37.0"
14
- click = "^8.1.3"
15
- "macrostrat.utils" = "^1.0.0"
16
- migra = "^3.0.1621480950"
17
- psycopg2-binary = "^2.9.1"
18
- python = "^3.8"
19
- schemainspect = "^3.0.1616029793"
20
- sqlparse = "^0.4.0"
21
-
22
- [tool.poetry.dev-dependencies]
23
- "macrostrat.utils" = {path = "../utils", develop = true}
24
-
25
- [build-system]
26
- build-backend = "poetry.core.masonry.api"
27
- requires = ["poetry-core>=1.0.0"]
@@ -1,38 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- from setuptools import setup
3
-
4
- packages = \
5
- ['macrostrat', 'macrostrat.database', 'macrostrat.database.mapper']
6
-
7
- package_data = \
8
- {'': ['*']}
9
-
10
- install_requires = \
11
- ['GeoAlchemy2>=0.9.4,<0.10.0',
12
- 'SQLAlchemy-Utils>=0.37.0,<0.38.0',
13
- 'SQLAlchemy>=1.4.26,<2.0.0',
14
- 'click>=8.1.3,<9.0.0',
15
- 'macrostrat.utils>=1.0.0,<2.0.0',
16
- 'migra>=3.0.1621480950,<4.0.0',
17
- 'psycopg2-binary>=2.9.1,<3.0.0',
18
- 'schemainspect>=3.0.1616029793,<4.0.0',
19
- 'sqlparse>=0.4.0,<0.5.0']
20
-
21
- setup_kwargs = {
22
- 'name': 'macrostrat.database',
23
- 'version': '1.0.2',
24
- 'description': 'A small library based on SQLAlchemy to assist with common database tasks.',
25
- 'long_description': None,
26
- 'author': 'Daven Quinn',
27
- 'author_email': 'dev@davenquinn.com',
28
- 'maintainer': None,
29
- 'maintainer_email': None,
30
- 'url': None,
31
- 'packages': packages,
32
- 'package_data': package_data,
33
- 'install_requires': install_requires,
34
- 'python_requires': '>=3.8,<4.0',
35
- }
36
-
37
-
38
- setup(**setup_kwargs)