macrostrat.database 3.5.1__tar.gz → 3.5.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (15)
  1. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/PKG-INFO +1 -1
  2. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/macrostrat/database/__init__.py +25 -1
  3. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/macrostrat/database/transfer/dump_database.py +1 -0
  4. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/macrostrat/database/transfer/move_tables.py +1 -0
  5. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/macrostrat/database/transfer/restore_database.py +1 -0
  6. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/macrostrat/database/transfer/stream_utils.py +1 -0
  7. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/macrostrat/database/transfer/utils.py +1 -1
  8. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/macrostrat/database/utils.py +83 -48
  9. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/pyproject.toml +1 -1
  10. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/macrostrat/database/mapper/__init__.py +0 -0
  11. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/macrostrat/database/mapper/base.py +0 -0
  12. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/macrostrat/database/mapper/cache.py +0 -0
  13. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/macrostrat/database/mapper/utils.py +0 -0
  14. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/macrostrat/database/postgresql.py +0 -0
  15. {macrostrat_database-3.5.1 → macrostrat_database-3.5.3}/macrostrat/database/transfer/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: macrostrat.database
3
- Version: 3.5.1
3
+ Version: 3.5.3
4
4
  Summary: A SQLAlchemy-based database toolkit.
5
5
  Author: Daven Quinn
6
6
  Author-email: dev@davenquinn.com
@@ -5,13 +5,14 @@ from typing import Optional, Union
5
5
 
6
6
  from psycopg2.errors import InvalidSavepointSpecification
7
7
  from psycopg2.sql import Identifier
8
- from sqlalchemy import URL, MetaData, create_engine, inspect, Engine
8
+ from sqlalchemy import URL, Engine, MetaData, create_engine, inspect
9
9
  from sqlalchemy.exc import IntegrityError, InternalError
10
10
  from sqlalchemy.ext.compiler import compiles
11
11
  from sqlalchemy.orm import Session, scoped_session, sessionmaker
12
12
  from sqlalchemy.sql.expression import Insert
13
13
 
14
14
  from macrostrat.utils import get_logger
15
+
15
16
  from .mapper import DatabaseMapper
16
17
  from .postgresql import on_conflict, prefix_inserts # noqa
17
18
  from .utils import ( # noqa
@@ -184,6 +185,29 @@ class Database(object):
184
185
  self.__inspector__ = inspect(self.engine)
185
186
  return self.__inspector__
186
187
 
188
+ def refresh_schema(self, *, automap=None):
189
+ """
190
+ Refresh the current database connection
191
+
192
+ - closes the session and flushes
193
+ - removes the inspector
194
+
195
+ If automap is True, will automap the database after refreshing.
196
+ If automap is False, will not automap the database after refreshing.
197
+ If automap is None, it will re-map the database if it was previously mapped.
198
+ """
199
+ # Close the session
200
+ self.session.flush()
201
+ self.session.close()
202
+ # Remove the inspector
203
+ self.__inspector__ = None
204
+
205
+ if automap is None:
206
+ automap = self.mapper is not None
207
+
208
+ if automap:
209
+ self.automap()
210
+
187
211
  def entity_names(self, **kwargs):
188
212
  """
189
213
  Returns an iterator of names of *schema objects*
@@ -7,6 +7,7 @@ import aiofiles
7
7
  from sqlalchemy.engine import Engine
8
8
 
9
9
  from macrostrat.utils import get_logger
10
+
10
11
  from .stream_utils import print_stdout, print_stream_progress
11
12
  from .utils import _create_command
12
13
 
@@ -3,6 +3,7 @@ import asyncio
3
3
  from sqlalchemy.engine import Engine
4
4
 
5
5
  from macrostrat.utils import get_logger
6
+
6
7
  from .dump_database import pg_dump
7
8
  from .restore_database import pg_restore
8
9
  from .stream_utils import print_stdout, print_stream_progress
@@ -7,6 +7,7 @@ from rich.console import Console
7
7
  from sqlalchemy.engine import Engine
8
8
 
9
9
  from macrostrat.utils import get_logger
10
+
10
11
  from .stream_utils import print_stdout, print_stream_progress
11
12
  from .utils import _create_command, _create_database_if_not_exists
12
13
 
@@ -5,6 +5,7 @@ import zlib
5
5
  from aiofiles.threadpool import AsyncBufferedIOBase
6
6
 
7
7
  from macrostrat.utils import get_logger
8
+
8
9
  from .utils import console
9
10
 
10
11
  log = get_logger(__name__)
@@ -5,7 +5,7 @@ from sqlalchemy.engine import Engine
5
5
  from sqlalchemy.engine.url import URL
6
6
  from sqlalchemy_utils import create_database, database_exists, drop_database
7
7
 
8
- from macrostrat.utils import get_logger, ApplicationError
8
+ from macrostrat.utils import ApplicationError, get_logger
9
9
 
10
10
  console = Console()
11
11
 
@@ -1,6 +1,9 @@
1
+ import os
1
2
  from contextlib import contextmanager
3
+ from enum import Enum
2
4
  from pathlib import Path
3
5
  from re import search
6
+ from sys import stderr
4
7
  from time import sleep
5
8
  from typing import IO, Union
6
9
  from warnings import warn
@@ -17,8 +20,8 @@ from sqlalchemy.exc import (
17
20
  IntegrityError,
18
21
  InternalError,
19
22
  InvalidRequestError,
20
- ProgrammingError,
21
23
  OperationalError,
24
+ ProgrammingError,
22
25
  )
23
26
  from sqlalchemy.orm import sessionmaker
24
27
  from sqlalchemy.schema import Table
@@ -46,23 +49,11 @@ def infer_is_sql_text(_string: str) -> bool:
46
49
  if isinstance(_string, bytes):
47
50
  _string = _string.decode("utf-8")
48
51
 
49
- keywords = [
50
- "SELECT",
51
- "INSERT",
52
- "UPDATE",
53
- "CREATE",
54
- "DROP",
55
- "DELETE",
56
- "ALTER",
57
- "SET",
58
- "GRANT",
59
- "WITH",
60
- ]
61
52
  lines = _string.split("\n")
62
53
  if len(lines) > 1:
63
54
  return True
64
55
  _string = _string.lower()
65
- for i in keywords:
56
+ for i in _sql_keywords:
66
57
  if _string.strip().startswith(i.lower() + " "):
67
58
  return True
68
59
  return False
@@ -97,26 +88,35 @@ def get_dataframe(connectable, filename_or_query, **kwargs):
97
88
 
98
89
 
99
90
  def pretty_print(sql, **kwargs):
91
+ """Print and optionally summarize an SQL query"""
92
+ summarize = kwargs.pop("summarize", True)
93
+ if summarize:
94
+ sql = summarize_statement(sql)
95
+ secho(sql, **kwargs)
96
+
97
+
98
+ _sql_keywords = [
99
+ "SELECT",
100
+ "INSERT",
101
+ "UPDATE",
102
+ "CREATE",
103
+ "DROP",
104
+ "DELETE",
105
+ "ALTER",
106
+ "SET",
107
+ "GRANT",
108
+ "WITH",
109
+ "NOTIFY",
110
+ "COPY",
111
+ ]
112
+
113
+
114
+ def summarize_statement(sql):
100
115
  for line in sql.split("\n"):
101
- for i in [
102
- "SELECT",
103
- "INSERT",
104
- "UPDATE",
105
- "CREATE",
106
- "DROP",
107
- "DELETE",
108
- "ALTER",
109
- "SET",
110
- "GRANT",
111
- "WITH",
112
- "NOTIFY",
113
- "COPY",
114
- ]:
116
+ for i in _sql_keywords:
115
117
  if not line.startswith(i):
116
118
  continue
117
- start = line.split("(")[0].strip().rstrip(";").replace(" AS", "")
118
- secho(start, **kwargs)
119
- return
119
+ return line.split("(")[0].strip().rstrip(";").replace(" AS", "")
120
120
 
121
121
 
122
122
  def get_sql_text(sql, interpret_as_file=None, echo_file_name=True):
@@ -235,6 +235,25 @@ def infer_has_server_binds(sql):
235
235
  _default_statement_filter = lambda sql_text, params: True
236
236
 
237
237
 
238
+ class OutputMode(Enum):
239
+ NONE = "none"
240
+ ERRORS = "errors"
241
+ SUMMARY = "summary"
242
+ ALL = "all"
243
+
244
+
245
+ def _normalize_output_args(kwargs):
246
+ output_mode = kwargs.pop("output_mode", OutputMode.SUMMARY)
247
+ output_file = kwargs.pop("output_file", stderr)
248
+
249
+ if not isinstance(output_mode, OutputMode):
250
+ output_mode = OutputMode(output_mode)
251
+
252
+ if output_mode == OutputMode.NONE:
253
+ output_file = open(os.devnull, "w")
254
+ return output_mode, output_file
255
+
256
+
238
257
  def _run_sql(connectable, sql, params=None, **kwargs):
239
258
  """
240
259
  Internal function for running a query on a SQLAlchemy connectable,
@@ -251,6 +270,7 @@ def _run_sql(connectable, sql, params=None, **kwargs):
251
270
  has_server_binds = kwargs.pop("has_server_binds", None)
252
271
  ensure_single_query = kwargs.pop("ensure_single_query", False)
253
272
  statement_filter = kwargs.pop("statement_filter", _default_statement_filter)
273
+ output_mode, output_file = _normalize_output_args(kwargs)
254
274
 
255
275
  if stop_on_error:
256
276
  raise_errors = True
@@ -283,6 +303,7 @@ def _run_sql(connectable, sql, params=None, **kwargs):
283
303
  query = _render_query(query, connectable)
284
304
 
285
305
  sql_text = str(query)
306
+
286
307
  if isinstance(query, str):
287
308
  sql_text = format(query, strip_comments=True).strip()
288
309
  if sql_text == "":
@@ -293,8 +314,18 @@ def _run_sql(connectable, sql, params=None, **kwargs):
293
314
  has_server_binds = infer_has_server_binds(sql_text)
294
315
 
295
316
  should_run = statement_filter(sql_text, params)
317
+
318
+ # Shorten summary text for printing
319
+ if output_mode != OutputMode.ALL:
320
+ sql_text = summarize_statement(sql_text)
321
+
296
322
  if not should_run:
297
- pretty_print(sql_text, dim=True, strikethrough=True)
323
+ secho(
324
+ sql_text,
325
+ dim=True,
326
+ strikethrough=True,
327
+ file=output_file,
328
+ )
298
329
  continue
299
330
 
300
331
  # This only does something for postgresql, but it's harmless to run it for other engines
@@ -318,7 +349,7 @@ def _run_sql(connectable, sql, params=None, **kwargs):
318
349
  trans.commit()
319
350
  elif hasattr(connectable, "commit"):
320
351
  connectable.commit()
321
- pretty_print(sql_text, dim=True)
352
+ secho(sql_text, dim=True, file=output_file)
322
353
  except Exception as err:
323
354
  if trans is not None:
324
355
  trans.rollback()
@@ -327,7 +358,7 @@ def _run_sql(connectable, sql, params=None, **kwargs):
327
358
  if raise_errors or _should_raise_query_error(err):
328
359
  raise err
329
360
 
330
- _print_error(sql_text, err)
361
+ _print_error(sql_text, err, file=output_file)
331
362
  finally:
332
363
  set_wait_callback(None)
333
364
 
@@ -356,17 +387,18 @@ def _should_raise_query_error(err):
356
387
  return False
357
388
 
358
389
 
359
- def _print_error(sql_text, err):
390
+ def _print_error(sql_text, err, **kwargs):
360
391
  if orig := getattr(err, "orig", None):
361
392
  _err = str(orig)
362
393
  else:
363
394
  _err = str(err)
364
395
  _err = _err.strip()
365
- dim = "already exists" in _err
366
- pretty_print(sql_text, fg=None if dim else "red", dim=True)
396
+ # Decide whether error should be dimmed
397
+ dim = kwargs.pop("dim", "already exists" in _err)
398
+ secho(sql_text, fg=None if dim else "red", dim=True, **kwargs)
367
399
  if dim:
368
400
  _err = " " + _err
369
- secho(_err, fg="red", dim=dim)
401
+ secho(_err, fg="red", dim=dim, **kwargs)
370
402
  log.error(err)
371
403
 
372
404
 
@@ -420,11 +452,17 @@ def run_fixtures(connectable, fixtures: Union[Path, list[Path]], params=None, **
420
452
  """
421
453
  recursive = kwargs.pop("recursive", False)
422
454
  order_by_name = kwargs.pop("order_by_name", True)
423
- console = kwargs.pop("console", Console(stderr=True))
455
+ output_mode, output_file = _normalize_output_args(kwargs)
456
+
457
+ console = kwargs.pop("console", Console(stderr=True, file=output_file))
424
458
  files = get_sql_files(fixtures, recursive=recursive, order_by_name=order_by_name)
425
459
 
460
+ prefix = os.path.commonpath(files)
461
+
462
+ console.print(f"Running fixtures in [cyan bold]{prefix}[/]")
426
463
  for fixture in files:
427
- console.print(f"[cyan bold]{fixture}[/]")
464
+ fn = fixture.relative_to(prefix)
465
+ console.print(f"[cyan bold]{fn}[/]")
428
466
  run_sql_file(connectable, fixture, params, **kwargs)
429
467
  console.print()
430
468
 
@@ -468,7 +506,9 @@ def run_sql(*args, **kwargs):
468
506
  return list(res)
469
507
 
470
508
 
471
- def execute(connectable, sql, params=None, stop_on_error=False):
509
+ def execute(connectable, sql, params=None, stop_on_error=False, **kwargs):
510
+ output_file = kwargs.pop("output_file", None)
511
+ output_mode = kwargs.pop("output_mode", None)
472
512
  sql = format(sql, strip_comments=True).strip()
473
513
  if sql == "":
474
514
  return
@@ -477,17 +517,12 @@ def execute(connectable, sql, params=None, stop_on_error=False):
477
517
  res = connectable.execute(text(sql), params=params)
478
518
  if hasattr(connectable, "commit"):
479
519
  connectable.commit()
480
- pretty_print(sql, dim=True)
520
+ pretty_print(sql, dim=True, file=output_file, mode=output_mode)
481
521
  return res
482
522
  except (ProgrammingError, IntegrityError) as err:
483
- err = str(err.orig).strip()
484
- dim = "already exists" in err
485
523
  if hasattr(connectable, "rollback"):
486
524
  connectable.rollback()
487
- pretty_print(sql, fg=None if dim else "red", dim=True)
488
- if dim:
489
- err = " " + err
490
- secho(err, fg="red", dim=dim)
525
+ _print_error(sql, dim=True, file=output_file, mode=output_mode)
491
526
  if stop_on_error:
492
527
  return
493
528
  finally:
@@ -3,7 +3,7 @@ authors = ["Daven Quinn <dev@davenquinn.com>"]
3
3
  description = "A SQLAlchemy-based database toolkit."
4
4
  name = "macrostrat.database"
5
5
  packages = [{ include = "macrostrat" }]
6
- version = "3.5.1"
6
+ version = "3.5.3"
7
7
 
8
8
  [tool.poetry.dependencies]
9
9
  GeoAlchemy2 = "^0.15.2"