macrostrat.database 3.4.1__tar.gz → 3.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (16) hide show
  1. {macrostrat_database-3.4.1 → macrostrat_database-3.5.1}/PKG-INFO +3 -5
  2. {macrostrat_database-3.4.1 → macrostrat_database-3.5.1}/macrostrat/database/__init__.py +1 -1
  3. {macrostrat_database-3.4.1 → macrostrat_database-3.5.1}/macrostrat/database/transfer/dump_database.py +16 -4
  4. {macrostrat_database-3.4.1 → macrostrat_database-3.5.1}/macrostrat/database/transfer/move_tables.py +1 -2
  5. {macrostrat_database-3.4.1 → macrostrat_database-3.5.1}/macrostrat/database/transfer/restore_database.py +7 -10
  6. macrostrat_database-3.5.1/macrostrat/database/transfer/stream_utils.py +118 -0
  7. macrostrat_database-3.5.1/macrostrat/database/transfer/utils.py +97 -0
  8. {macrostrat_database-3.4.1 → macrostrat_database-3.5.1}/macrostrat/database/utils.py +21 -4
  9. {macrostrat_database-3.4.1 → macrostrat_database-3.5.1}/pyproject.toml +3 -3
  10. macrostrat_database-3.4.1/macrostrat/database/transfer/utils.py +0 -107
  11. {macrostrat_database-3.4.1 → macrostrat_database-3.5.1}/macrostrat/database/mapper/__init__.py +0 -0
  12. {macrostrat_database-3.4.1 → macrostrat_database-3.5.1}/macrostrat/database/mapper/base.py +0 -0
  13. {macrostrat_database-3.4.1 → macrostrat_database-3.5.1}/macrostrat/database/mapper/cache.py +0 -0
  14. {macrostrat_database-3.4.1 → macrostrat_database-3.5.1}/macrostrat/database/mapper/utils.py +0 -0
  15. {macrostrat_database-3.4.1 → macrostrat_database-3.5.1}/macrostrat/database/postgresql.py +0 -0
  16. {macrostrat_database-3.4.1 → macrostrat_database-3.5.1}/macrostrat/database/transfer/__init__.py +0 -0
@@ -1,13 +1,11 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: macrostrat.database
3
- Version: 3.4.1
3
+ Version: 3.5.1
4
4
  Summary: A SQLAlchemy-based database toolkit.
5
5
  Author: Daven Quinn
6
6
  Author-email: dev@davenquinn.com
7
- Requires-Python: >=3.8,<4.0
7
+ Requires-Python: >=3.10,<4.0
8
8
  Classifier: Programming Language :: Python :: 3
9
- Classifier: Programming Language :: Python :: 3.8
10
- Classifier: Programming Language :: Python :: 3.9
11
9
  Classifier: Programming Language :: Python :: 3.10
12
10
  Classifier: Programming Language :: Python :: 3.11
13
11
  Classifier: Programming Language :: Python :: 3.12
@@ -16,7 +14,7 @@ Requires-Dist: SQLAlchemy (>=2.0.18,<3.0.0)
16
14
  Requires-Dist: SQLAlchemy-Utils (>=0.41.1,<0.42.0)
17
15
  Requires-Dist: aiofiles (>=23.2.1,<24.0.0)
18
16
  Requires-Dist: click (>=8.1.3,<9.0.0)
19
- Requires-Dist: macrostrat.utils (>=1.0.0,<2.0.0)
17
+ Requires-Dist: macrostrat.utils (>=1.3.0,<2.0.0)
20
18
  Requires-Dist: psycopg2-binary (>=2.9.6,<3.0.0)
21
19
  Requires-Dist: rich (>=13.7.1,<14.0.0)
22
20
  Requires-Dist: sqlparse (>=0.5.1,<0.6.0)
@@ -128,7 +128,7 @@ class Database(object):
128
128
  Returns: Iterator of results from the query.
129
129
  """
130
130
  params = self._setup_params(params, kwargs)
131
- return iter(run_sql(self.session, fn, params, **kwargs))
131
+ return run_sql(self.session, fn, params, **kwargs)
132
132
 
133
133
  def run_query(self, sql, params=None, **kwargs):
134
134
  """Run a single query on the database object, returning the result.
@@ -1,4 +1,5 @@
1
1
  import asyncio
2
+ import sys
2
3
  from pathlib import Path
3
4
  from typing import Optional
4
5
 
@@ -6,8 +7,8 @@ import aiofiles
6
7
  from sqlalchemy.engine import Engine
7
8
 
8
9
  from macrostrat.utils import get_logger
9
-
10
- from .utils import _create_command, print_stdout, print_stream_progress
10
+ from .stream_utils import print_stdout, print_stream_progress
11
+ from .utils import _create_command
11
12
 
12
13
  log = get_logger(__name__)
13
14
 
@@ -45,11 +46,22 @@ async def pg_dump(
45
46
  )
46
47
 
47
48
 
48
- async def pg_dump_to_file(dumpfile: Path, *args, **kwargs):
49
- proc = await pg_dump(*args, **kwargs)
49
+ async def pg_dump_to_file(engine: Engine, dumpfile: Path | None, **kwargs):
50
+ proc = await pg_dump(engine, **kwargs)
51
+ if dumpfile is None or dumpfile == sys.stdout:
52
+ # If we have no dumpfile, just print to stdout
53
+ await _monitor_stdout(proc)
54
+ return
50
55
  # Open dump file as an async stream
51
56
  async with aiofiles.open(dumpfile, mode="wb") as dest:
52
57
  await asyncio.gather(
53
58
  asyncio.create_task(print_stream_progress(proc.stdout, dest)),
54
59
  asyncio.create_task(print_stdout(proc.stderr)),
55
60
  )
61
+
62
+
63
async def _monitor_stdout(proc):
    """Drain both output pipes of ``proc`` concurrently when no dump file is used.

    The process's stdout is echoed via ``print_stdout``; its stderr is consumed
    by ``print_stream_progress`` with no output sink, so only the progress
    counter is reported for it.
    """
    stdout_task = asyncio.create_task(print_stdout(proc.stdout))
    stderr_task = asyncio.create_task(print_stream_progress(proc.stderr, None))
    await asyncio.gather(stdout_task, stderr_task)
@@ -3,10 +3,9 @@ import asyncio
3
3
  from sqlalchemy.engine import Engine
4
4
 
5
5
  from macrostrat.utils import get_logger
6
-
7
6
  from .dump_database import pg_dump
8
7
  from .restore_database import pg_restore
9
- from .utils import print_stdout, print_stream_progress
8
+ from .stream_utils import print_stdout, print_stream_progress
10
9
 
11
10
  log = get_logger(__name__)
12
11
 
@@ -7,13 +7,8 @@ from rich.console import Console
7
7
  from sqlalchemy.engine import Engine
8
8
 
9
9
  from macrostrat.utils import get_logger
10
-
11
- from .utils import (
12
- _create_command,
13
- _create_database_if_not_exists,
14
- print_stdout,
15
- print_stream_progress,
16
- )
10
+ from .stream_utils import print_stdout, print_stream_progress
11
+ from .utils import _create_command, _create_database_if_not_exists
17
12
 
18
13
  console = Console()
19
14
 
@@ -56,11 +51,13 @@ async def pg_restore(
56
51
  )
57
52
 
58
53
 
59
- async def pg_restore_from_file(dumpfile: Path, *args, **kwargs):
60
- proc = await pg_restore(*args, **kwargs)
54
+ async def pg_restore_from_file(dumpfile: Path, engine: Engine, **kwargs):
55
+ proc = await pg_restore(engine, **kwargs)
61
56
  # Open dump file as an async stream
62
57
  async with aiofiles.open(dumpfile, mode="rb") as source:
63
58
  await asyncio.gather(
64
- asyncio.create_task(print_stream_progress(source, proc.stdin)),
59
+ asyncio.create_task(
60
+ print_stream_progress(source, proc.stdin, prefix="Restored")
61
+ ),
65
62
  asyncio.create_task(print_stdout(proc.stderr)),
66
63
  )
@@ -0,0 +1,118 @@
1
+ import asyncio
2
+ import sys
3
+ import zlib
4
+
5
+ from aiofiles.threadpool import AsyncBufferedIOBase
6
+
7
+ from macrostrat.utils import get_logger
8
+ from .utils import console
9
+
10
+ log = get_logger(__name__)
11
+
12
+
13
async def print_stream_progress(
    input: asyncio.StreamReader | asyncio.subprocess.Process,
    out_stream: asyncio.StreamWriter | None | AsyncBufferedIOBase = None,
    *,
    verbose: bool = False,
    chunk_size: int = 1024,
    prefix: str | None = None,
):
    """Copy a byte stream to an optional sink while reporting progress on stderr.

    Both ``asyncio.StreamWriter`` and aiofiles file objects are accepted as
    sinks; they are handled in separate branches because their write/close
    APIs differ (aiofiles methods are coroutines, ``StreamWriter.write`` is not).

    Args:
        input: Source to read from. If a subprocess is given, its ``stdout``
            pipe is read.
        out_stream: Destination for the bytes, or ``None`` to discard them and
            only count progress.
        verbose: If true, log every chunk that is read.
        chunk_size: Number of bytes requested per read.
        prefix: Label for the progress line; ``_print_progress`` substitutes
            its own default when this is ``None``.

    Cancellation (``asyncio.CancelledError``) is swallowed so that the final
    progress line is still printed and the sink is still closed.
    """
    in_stream = input
    if isinstance(in_stream, asyncio.subprocess.Process):
        in_stream = input.stdout

    megabytes_written = 0
    chunks_since_report = 0

    # Iterate over the stream by chunks
    try:
        while True:
            chunk = await in_stream.read(chunk_size)
            if not chunk:
                log.info("End of stream")
                break
            if verbose:
                log.info(chunk)
            megabytes_written += len(chunk) / 1_000_000
            if isinstance(out_stream, AsyncBufferedIOBase):
                await out_stream.write(chunk)
                await out_stream.flush()
            elif out_stream is not None:
                out_stream.write(chunk)
                await out_stream.drain()
            # Only refresh the progress line every 100 chunks to limit output.
            chunks_since_report += 1
            if chunks_since_report == 100:
                chunks_since_report = 0
                _print_progress(megabytes_written, end="\r", prefix=prefix)
    except asyncio.CancelledError:
        pass
    finally:
        _print_progress(megabytes_written, prefix=prefix)

        if isinstance(out_stream, AsyncBufferedIOBase):
            # aiofiles' close() is a coroutine: without ``await`` it never runs
            # and Python warns that the coroutine was never awaited.
            await out_stream.close()
        elif out_stream is not None:
            out_stream.close()
            await out_stream.wait_closed()
60
+
61
+
62
+ def _print_progress(megabytes: float, **kwargs):
63
+ prefix = kwargs.pop("prefix", None)
64
+ if prefix is None:
65
+ prefix = "Dumped"
66
+ progress = f"{prefix} {megabytes:.1f} MB"
67
+ kwargs["file"] = sys.stderr
68
+ print(progress, **kwargs)
69
+
70
+
71
async def print_stdout(stream: asyncio.StreamReader):
    """Log each line read from ``stream`` and echo it, dimmed, to the console.

    Lines are logged as raw bytes and decoded as UTF-8 before being printed.
    """
    async for raw_line in stream:
        log.info(raw_line)
        text = raw_line.decode("utf-8")
        console.print(text, style="dim")
75
+
76
+
77
class DecodingStreamReader(asyncio.StreamReader):
    """A StreamReader that decompresses gzip files (if compressed)

    Wraps another reader (``stream``) and transparently gunzips whatever it
    returns. Whether the data is gzipped is decided once, from the first two
    bytes of the first data seen (the gzip magic number ``1f 8b``).

    The ``encoding`` and ``errors`` constructor arguments are accepted but not
    used by this class; output is always bytes.
    """

    # https://ejosh.co/de/2022/08/stream-a-massive-gzipped-json-file-in-python/

    def __init__(self, stream, encoding="utf-8", errors="strict"):
        super().__init__()
        self.stream = stream
        # None until the first chunk has been inspected, then True/False.
        self._is_gzipped = None
        # wbits = MAX_WBITS | 16 tells zlib to expect a gzip header/trailer.
        self.d = zlib.decompressobj(zlib.MAX_WBITS | 16)

    def decompress(self, input: bytes) -> bytes:
        """Decompress ``input``, following on into any concatenated gzip members.

        A decompressor stops at the end of one gzip member and leaves the rest
        in ``unused_data``; each leftover is fed to a fresh decompressor. The
        output of every member is accumulated (previously only the last
        follow-on member's output was kept, dropping data when a chunk held
        three or more members).
        """
        pieces = [self.d.decompress(input)]
        while self.d.unused_data != b"":
            leftover = self.d.unused_data
            self.d = zlib.decompressobj(zlib.MAX_WBITS | 16)
            pieces.append(self.d.decompress(leftover))
        return b"".join(pieces)

    def transform_data(self, data):
        """Return ``data`` gunzipped if the stream was detected as gzip, else as-is."""
        if self._is_gzipped is None:
            self._is_gzipped = data[:2] == b"\x1f\x8b"
            log.info("is_gzipped: %s", self._is_gzipped)
        if self._is_gzipped:
            # Decompress the data
            data = self.decompress(data)
        return data

    async def read(self, n=-1):
        """Read up to ``n`` bytes from the wrapped stream and transform them."""
        data = await self.stream.read(n)
        return self.transform_data(data)

    async def readline(self):
        """Read from the wrapped stream line by line until some output is produced.

        A line of compressed input may decompress to nothing yet, so reading
        continues until output is non-empty or the wrapped stream is exhausted.
        """
        res = b""
        while res == b"":
            # Read next line
            line = await self.stream.readline()
            if not line:
                break
            res += self.transform_data(line)
        return res
@@ -0,0 +1,97 @@
1
+ from urllib.parse import quote
2
+
3
+ from rich.console import Console
4
+ from sqlalchemy.engine import Engine
5
+ from sqlalchemy.engine.url import URL
6
+ from sqlalchemy_utils import create_database, database_exists, drop_database
7
+
8
+ from macrostrat.utils import get_logger, ApplicationError
9
+
10
+ console = Console()
11
+
12
+ log = get_logger(__name__)
13
+
14
+
15
+ def _docker_local_run_args(postgres_container: str = "postgres:15"):
16
+ return [
17
+ "docker",
18
+ "run",
19
+ "-i",
20
+ "--network",
21
+ "host",
22
+ "--attach",
23
+ "stdin",
24
+ "--attach",
25
+ "stdout",
26
+ "--attach",
27
+ "stderr",
28
+ "--log-driver",
29
+ "none",
30
+ "--rm",
31
+ postgres_container,
32
+ ]
33
+
34
+
35
def _create_database_if_not_exists(
    _url: URL, *, create=False, allow_exists=True, overwrite=False
):
    """Ensure the database at ``_url`` is in the requested state.

    Args:
        _url: Database URL to check.
        create: Create the database when it does not exist.
        allow_exists: When false, an existing database is an error (unless
            ``overwrite`` is set).
        overwrite: Drop an existing database and create it afresh; implies
            ``create``.

    Raises:
        ApplicationError: If the database exists and neither ``overwrite`` nor
            ``allow_exists`` is set, or if it does not exist and ``create`` is
            false.
    """
    database = _url.database
    # Overwriting always recreates the database afterwards.
    create = True if overwrite else create

    exists = database_exists(_url)
    if exists:
        msg = f"Database [bold underline]{database}[/] already exists"
        if overwrite:
            console.print(f"{msg}, overwriting")
            drop_database(_url)
            exists = False
        elif allow_exists:
            console.print(msg)
        else:
            raise ApplicationError(msg, details="Use `--overwrite` to overwrite")

    if exists:
        return

    if not create:
        # NOTE(review): the hint is passed positionally here, whereas the raise
        # above uses ``details=`` — confirm against ApplicationError's signature.
        raise ApplicationError(
            f"Database [bold cyan]{database}[/] does not exist. ",
            "Use `--create` to create it.",
        )

    console.print(f"Creating database [bold cyan]{database}[/]")
    create_database(_url)
62
+
63
+
64
def _create_command(
    engine: Engine,
    *command,
    args: list[str] | None = None,
    prefix: list[str] | None = None,
    container="postgres:16",
):
    """Assemble the argument list for running a PostgreSQL client command.

    The result is ``[*prefix, *command, <database url>, *args]``.

    Args:
        engine: Engine whose URL is passed to the command.
        *command: The command and any leading arguments (placed before the URL).
        args: Extra arguments placed after the URL; defaults to none.
        prefix: Arguments placed before the command. When empty or ``None``,
            the ``docker run`` prefix for ``container`` is used.
        container: Image name used for the default docker prefix.

    Returns:
        The command as a list of strings, containing the real database
        password if the URL has one.
    """
    # ``args``/``prefix`` previously used ``None | list[str]`` as their default
    # *value* (a union-type object rather than None), so omitting either one
    # failed when the default was unpacked. The union belongs in the annotation.
    if args is None:
        args = []

    command_prefix = prefix or _docker_local_run_args(container)
    _cmd = [*command_prefix, *command, str(engine.url), *args]

    # Log the variant built from str(engine.url), before the real password is
    # substituted below.
    log.info(" ".join(_cmd))

    # Replace asterisks with the real password (if any). This is kind of backwards
    # but it works.
    if "***" in str(engine.url) and engine.url.password is not None:
        _cmd = [
            *command_prefix,
            *command,
            raw_database_url(engine.url),
            *args,
        ]

    return _cmd
90
+
91
+
92
def raw_database_url(url: URL):
    """Return ``url`` as a string with the ``***`` password mask replaced.

    The real password is percent-encoded (no characters treated as safe) before
    substitution so the result can be passed to other commands. If the string
    form contains no mask, or the URL has no password, the string form is
    returned unchanged.
    """
    masked = str(url)
    if "***" in masked and url.password is not None:
        encoded_password = quote(url.password, safe="")
        return masked.replace("***", encoded_password)
    return masked
@@ -5,11 +5,11 @@ from time import sleep
5
5
  from typing import IO, Union
6
6
  from warnings import warn
7
7
 
8
+ import psycopg2.errors
8
9
  from click import echo, secho
9
10
  from psycopg2.extensions import set_wait_callback
10
11
  from psycopg2.extras import wait_select
11
12
  from psycopg2.sql import SQL, Composable, Composed
12
- import psycopg2.errors
13
13
  from rich.console import Console
14
14
  from sqlalchemy import MetaData, create_engine, text
15
15
  from sqlalchemy.engine import Connection, Engine
@@ -18,7 +18,7 @@ from sqlalchemy.exc import (
18
18
  InternalError,
19
19
  InvalidRequestError,
20
20
  ProgrammingError,
21
- OperationalError
21
+ OperationalError,
22
22
  )
23
23
  from sqlalchemy.orm import sessionmaker
24
24
  from sqlalchemy.schema import Table
@@ -232,6 +232,9 @@ def infer_has_server_binds(sql):
232
232
  return "%s" in sql or search(r"%\(\w+\)s", sql)
233
233
 
234
234
 
235
+ _default_statement_filter = lambda sql_text, params: True
236
+
237
+
235
238
  def _run_sql(connectable, sql, params=None, **kwargs):
236
239
  """
237
240
  Internal function for running a query on a SQLAlchemy connectable,
@@ -247,6 +250,7 @@ def _run_sql(connectable, sql, params=None, **kwargs):
247
250
  raise_errors = kwargs.pop("raise_errors", False)
248
251
  has_server_binds = kwargs.pop("has_server_binds", None)
249
252
  ensure_single_query = kwargs.pop("ensure_single_query", False)
253
+ statement_filter = kwargs.pop("statement_filter", _default_statement_filter)
250
254
 
251
255
  if stop_on_error:
252
256
  raise_errors = True
@@ -288,6 +292,11 @@ def _run_sql(connectable, sql, params=None, **kwargs):
288
292
  if has_server_binds is None:
289
293
  has_server_binds = infer_has_server_binds(sql_text)
290
294
 
295
+ should_run = statement_filter(sql_text, params)
296
+ if not should_run:
297
+ pretty_print(sql_text, dim=True, strikethrough=True)
298
+ continue
299
+
291
300
  # This only does something for postgresql, but it's harmless to run it for other engines
292
301
  set_wait_callback(wait_select)
293
302
 
@@ -325,7 +334,9 @@ def _run_sql(connectable, sql, params=None, **kwargs):
325
334
 
326
335
  def _should_raise_query_error(err):
327
336
  """Determine if an error should be raised for a query or not."""
328
- if not isinstance(err, (ProgrammingError, IntegrityError, InternalError, OperationalError)):
337
+ if not isinstance(
338
+ err, (ProgrammingError, IntegrityError, InternalError, OperationalError)
339
+ ):
329
340
  return True
330
341
 
331
342
  orig_err = getattr(err, "orig", None)
@@ -336,7 +347,10 @@ def _should_raise_query_error(err):
336
347
  # We might want to change this behavior in the future, or support more graceful handling of errors from other
337
348
  # database backends.
338
349
  # Ideally we could handle operational errors more gracefully
339
- if isinstance(orig_err, psycopg2.errors.QueryCanceled) or getattr(orig_err, "pgcode", None) == "57014":
350
+ if (
351
+ isinstance(orig_err, psycopg2.errors.QueryCanceled)
352
+ or getattr(orig_err, "pgcode", None) == "57014"
353
+ ):
340
354
  return True
341
355
 
342
356
  return False
@@ -444,6 +458,9 @@ def run_sql(*args, **kwargs):
444
458
  returning a list after completion.
445
459
  ensure_single_query : bool
446
460
  If True, raise an error if multiple queries are passed when only one is expected.
461
+ statement_filter : Callable
462
+ A function that takes a SQL statement and parameters and returns True if the statement
463
+ should be run, and False if it should be skipped.
447
464
  """
448
465
  res = _run_sql(*args, **kwargs)
449
466
  if kwargs.pop("yield_results", False):
@@ -3,16 +3,16 @@ authors = ["Daven Quinn <dev@davenquinn.com>"]
3
3
  description = "A SQLAlchemy-based database toolkit."
4
4
  name = "macrostrat.database"
5
5
  packages = [{ include = "macrostrat" }]
6
- version = "3.4.1"
6
+ version = "3.5.1"
7
7
 
8
8
  [tool.poetry.dependencies]
9
9
  GeoAlchemy2 = "^0.15.2"
10
10
  SQLAlchemy = "^2.0.18"
11
11
  SQLAlchemy-Utils = "^0.41.1"
12
12
  click = "^8.1.3"
13
- "macrostrat.utils" = "^1.0.0"
13
+ "macrostrat.utils" = "^1.3.0"
14
14
  psycopg2-binary = "^2.9.6"
15
- python = "^3.8"
15
+ python = "^3.10"
16
16
  sqlparse = "^0.5.1"
17
17
  aiofiles = "^23.2.1"
18
18
  rich = "^13.7.1"
@@ -1,107 +0,0 @@
1
- import asyncio
2
- from urllib.parse import quote
3
-
4
- from aiofiles.threadpool.binary import AsyncBufferedIOBase
5
- from rich.console import Console
6
- from sqlalchemy.engine import Engine
7
- from sqlalchemy.engine.url import URL
8
- from sqlalchemy_utils import create_database, database_exists
9
-
10
- from macrostrat.utils import get_logger
11
-
12
- console = Console()
13
-
14
- log = get_logger(__name__)
15
-
16
-
17
- def _docker_local_run_args(postgres_container: str = "postgres:15"):
18
- return [
19
- "docker",
20
- "run",
21
- "-i",
22
- "--rm",
23
- "--network",
24
- "host",
25
- postgres_container,
26
- ]
27
-
28
-
29
- def _create_database_if_not_exists(_url: URL, create=False):
30
- database = _url.database
31
- db_exists = database_exists(_url)
32
- if db_exists:
33
- console.print(f"Database [bold cyan]{database}[/] already exists")
34
-
35
- if create and not db_exists:
36
- console.print(f"Creating database [bold cyan]{database}[/]")
37
- create_database(_url)
38
-
39
- if not db_exists and not create:
40
- raise ValueError(
41
- f"Database [bold cyan]{database}[/] does not exist. "
42
- "Use `--create` to create it."
43
- )
44
-
45
-
46
- def _create_command(
47
- engine: Engine,
48
- *command,
49
- args=[],
50
- prefix=None | list[str],
51
- container="postgres:16",
52
- ):
53
- command_prefix = prefix or _docker_local_run_args(container)
54
- _cmd = [*command_prefix, *command, str(engine.url), *args]
55
-
56
- log.info(" ".join(_cmd))
57
-
58
- # Replace asterisks with the real password (if any). This is kind of backwards
59
- # but it works.
60
- if "***" in str(engine.url) and engine.url.password is not None:
61
- _cmd = [
62
- *command_prefix,
63
- *command,
64
- raw_database_url(engine.url),
65
- *args,
66
- ]
67
-
68
- return _cmd
69
-
70
-
71
- async def print_stream_progress(
72
- in_stream: asyncio.StreamReader,
73
- out_stream: asyncio.StreamWriter | AsyncBufferedIOBase,
74
- ):
75
- """This should be unified with print_stream_progress, but there seem to be
76
- slight API differences between aiofiles and asyncio.StreamWriter APIs.?"""
77
- megabytes_written = 0
78
- i = 0
79
- async for line in in_stream:
80
- megabytes_written += len(line) / 1_000_000
81
- if isinstance(out_stream, AsyncBufferedIOBase):
82
- await out_stream.write(line)
83
- await out_stream.flush()
84
- else:
85
- out_stream.write(line)
86
- await out_stream.drain()
87
- i += 1
88
- if i == 1000:
89
- i = 0
90
- _print_progress(megabytes_written, end="\r")
91
-
92
- out_stream.close()
93
- _print_progress(megabytes_written)
94
-
95
-
96
- def _print_progress(megabytes: float, **kwargs):
97
- progress = f"Dumped {megabytes:.1f} MB"
98
- print(progress, **kwargs)
99
-
100
-
101
- async def print_stdout(stream: asyncio.StreamReader):
102
- async for line in stream:
103
- console.print(line.decode("utf-8"), style="dim")
104
-
105
-
106
- def raw_database_url(url: URL):
107
- return str(url).replace("***", quote(url.password, safe=""))