iceaxe 0.8.3__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of iceaxe might be problematic. Click here for more details.
- iceaxe/__init__.py +20 -0
- iceaxe/__tests__/__init__.py +0 -0
- iceaxe/__tests__/benchmarks/__init__.py +0 -0
- iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
- iceaxe/__tests__/benchmarks/test_select.py +114 -0
- iceaxe/__tests__/conf_models.py +133 -0
- iceaxe/__tests__/conftest.py +204 -0
- iceaxe/__tests__/docker_helpers.py +208 -0
- iceaxe/__tests__/helpers.py +268 -0
- iceaxe/__tests__/migrations/__init__.py +0 -0
- iceaxe/__tests__/migrations/conftest.py +36 -0
- iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
- iceaxe/__tests__/migrations/test_generator.py +140 -0
- iceaxe/__tests__/migrations/test_generics.py +91 -0
- iceaxe/__tests__/mountaineer/__init__.py +0 -0
- iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
- iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
- iceaxe/__tests__/schemas/__init__.py +0 -0
- iceaxe/__tests__/schemas/test_actions.py +1265 -0
- iceaxe/__tests__/schemas/test_cli.py +25 -0
- iceaxe/__tests__/schemas/test_db_memory_serializer.py +1571 -0
- iceaxe/__tests__/schemas/test_db_serializer.py +435 -0
- iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
- iceaxe/__tests__/test_alias.py +83 -0
- iceaxe/__tests__/test_base.py +52 -0
- iceaxe/__tests__/test_comparison.py +383 -0
- iceaxe/__tests__/test_field.py +11 -0
- iceaxe/__tests__/test_helpers.py +9 -0
- iceaxe/__tests__/test_modifications.py +151 -0
- iceaxe/__tests__/test_queries.py +764 -0
- iceaxe/__tests__/test_queries_str.py +173 -0
- iceaxe/__tests__/test_session.py +1511 -0
- iceaxe/__tests__/test_text_search.py +287 -0
- iceaxe/alias_values.py +67 -0
- iceaxe/base.py +351 -0
- iceaxe/comparison.py +560 -0
- iceaxe/field.py +263 -0
- iceaxe/functions.py +1432 -0
- iceaxe/generics.py +140 -0
- iceaxe/io.py +107 -0
- iceaxe/logging.py +91 -0
- iceaxe/migrations/__init__.py +5 -0
- iceaxe/migrations/action_sorter.py +98 -0
- iceaxe/migrations/cli.py +228 -0
- iceaxe/migrations/client_io.py +62 -0
- iceaxe/migrations/generator.py +404 -0
- iceaxe/migrations/migration.py +86 -0
- iceaxe/migrations/migrator.py +101 -0
- iceaxe/modifications.py +176 -0
- iceaxe/mountaineer/__init__.py +10 -0
- iceaxe/mountaineer/cli.py +74 -0
- iceaxe/mountaineer/config.py +46 -0
- iceaxe/mountaineer/dependencies/__init__.py +6 -0
- iceaxe/mountaineer/dependencies/core.py +67 -0
- iceaxe/postgres.py +133 -0
- iceaxe/py.typed +0 -0
- iceaxe/queries.py +1459 -0
- iceaxe/queries_str.py +294 -0
- iceaxe/schemas/__init__.py +0 -0
- iceaxe/schemas/actions.py +864 -0
- iceaxe/schemas/cli.py +30 -0
- iceaxe/schemas/db_memory_serializer.py +711 -0
- iceaxe/schemas/db_serializer.py +347 -0
- iceaxe/schemas/db_stubs.py +529 -0
- iceaxe/session.py +860 -0
- iceaxe/session_optimized.c +12207 -0
- iceaxe/session_optimized.cpython-313-darwin.so +0 -0
- iceaxe/session_optimized.pyx +212 -0
- iceaxe/sql_types.py +149 -0
- iceaxe/typing.py +73 -0
- iceaxe-0.8.3.dist-info/METADATA +262 -0
- iceaxe-0.8.3.dist-info/RECORD +75 -0
- iceaxe-0.8.3.dist-info/WHEEL +6 -0
- iceaxe-0.8.3.dist-info/licenses/LICENSE +21 -0
- iceaxe-0.8.3.dist-info/top_level.txt +1 -0
iceaxe/generics.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
import types
|
|
2
|
+
from inspect import isclass
|
|
3
|
+
from typing import Any, Type, TypeGuard, TypeVar, Union, get_args, get_origin
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def mro_distance(obj_type: Type, target_type: Type) -> float:
    """
    Calculate the MRO distance between obj_type and target_type.

    Arguments that are not classes are replaced by their runtime type first.
    Returns 0 for an exact match, the position of ``target_type`` in
    ``obj_type``'s MRO when it is an ancestor (1 for a direct parent, and so
    on), and ``float("inf")`` if no match is found.

    """
    if not isclass(obj_type):
        obj_type = type(obj_type)
    if not isclass(target_type):
        target_type = type(target_type)

    # Compare class types for exact match
    if obj_type == target_type:
        return 0

    # Check if obj_type is a subclass of target_type using the MRO.
    # Read the `__mro__` attribute instead of calling `.mro()`: on `type` itself
    # (and on metaclasses) `.mro` resolves to the unbound method, so
    # `type.mro()` raises TypeError.
    try:
        return obj_type.__mro__.index(target_type)
    except ValueError:
        return float("inf")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
T = TypeVar("T")


def is_type_compatible(obj_type: Type, target_type: T) -> TypeGuard[T]:
    """
    Boolean wrapper around ``_is_type_compatible``.

    True when ``obj_type`` has a finite compatibility distance to
    ``target_type``; typed as a ``TypeGuard`` so static checkers can narrow
    on the result.
    """
    distance = _is_type_compatible(obj_type, target_type)
    return distance < float("inf")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _is_type_compatible(obj_type: Type, target_type: Any) -> float:
|
|
36
|
+
"""
|
|
37
|
+
Relatively comprehensive type compatibility checker. This function is
|
|
38
|
+
used to check if a type has has a registered object that can
|
|
39
|
+
handle it.
|
|
40
|
+
|
|
41
|
+
Specifically returns the MRO distance where 0 indicates
|
|
42
|
+
an exact match, 1 indicates a direct ancestor, and so on. Returns a large number
|
|
43
|
+
if no compatibility is found.
|
|
44
|
+
|
|
45
|
+
"""
|
|
46
|
+
# Any type is compatible with any other type
|
|
47
|
+
if target_type is Any:
|
|
48
|
+
return 0
|
|
49
|
+
|
|
50
|
+
# If obj_type is a nested type, each of these types must be compatible
|
|
51
|
+
# with the corresponding type in target_type
|
|
52
|
+
if get_origin(obj_type) is Union or isinstance(obj_type, types.UnionType):
|
|
53
|
+
return max(_is_type_compatible(t, target_type) for t in get_args(obj_type))
|
|
54
|
+
|
|
55
|
+
# Handle OR types
|
|
56
|
+
if get_origin(target_type) is Union or isinstance(target_type, types.UnionType):
|
|
57
|
+
return min(_is_type_compatible(obj_type, t) for t in get_args(target_type))
|
|
58
|
+
|
|
59
|
+
# Handle Type[Values] like typehints where we want to typehint a class
|
|
60
|
+
if get_origin(target_type) == type: # noqa: E721
|
|
61
|
+
return _is_type_compatible(obj_type, get_args(target_type)[0])
|
|
62
|
+
|
|
63
|
+
# Handle dict[str, str] like typehints
|
|
64
|
+
# We assume that each arg in order must be matched with the target type
|
|
65
|
+
obj_origin = get_origin(obj_type)
|
|
66
|
+
target_origin = get_origin(target_type)
|
|
67
|
+
if obj_origin and target_origin:
|
|
68
|
+
if obj_origin == target_origin:
|
|
69
|
+
return max(
|
|
70
|
+
_is_type_compatible(t1, t2)
|
|
71
|
+
for t1, t2 in zip(get_args(obj_type), get_args(target_type))
|
|
72
|
+
)
|
|
73
|
+
else:
|
|
74
|
+
return float("inf")
|
|
75
|
+
|
|
76
|
+
# For lists, sets, and tuple objects make sure that each object matches
|
|
77
|
+
# the target type
|
|
78
|
+
if isinstance(obj_type, (list, set, tuple)):
|
|
79
|
+
if type(obj_type) != get_origin(target_type): # noqa: E721
|
|
80
|
+
return float("inf")
|
|
81
|
+
return max(
|
|
82
|
+
_is_type_compatible(obj, get_args(target_type)[0]) for obj in obj_type
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
if isinstance(target_type, type):
|
|
86
|
+
return mro_distance(obj_type, target_type)
|
|
87
|
+
|
|
88
|
+
# Default case
|
|
89
|
+
return float("inf")
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def remove_null_type(typehint: Type) -> Type:
|
|
93
|
+
if get_origin(typehint) is Union or isinstance(typehint, types.UnionType):
|
|
94
|
+
return Union[ # type: ignore
|
|
95
|
+
tuple( # type: ignore
|
|
96
|
+
[t for t in get_args(typehint) if t != type(None)] # noqa: E721
|
|
97
|
+
)
|
|
98
|
+
]
|
|
99
|
+
return typehint
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def has_null_type(typehint: Type) -> bool:
|
|
103
|
+
if get_origin(typehint) is Union or isinstance(typehint, types.UnionType):
|
|
104
|
+
return any(arg == type(None) for arg in get_args(typehint)) # noqa: E721
|
|
105
|
+
return typehint == type(None) # noqa: E721
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def get_typevar_mapping(cls):
    """
    Get the raw typevar mappings {typevar: generic value} for each
    typevar in the class hierarchy of `cls`.

    Classes are visited from the most distant ancestor down to `cls` itself
    (reverse MRO, skipping `object`). For every subscripted entry in a class's
    `__orig_bases__`, the origin's `__parameters__` are paired with the
    subscript arguments. If an argument is a TypeVar that already has a
    mapping, that earlier value is substituted in; otherwise the argument is
    stored as-is. TypeVar values are NOT filtered out of the result.

    Shared logic with Mountaineer. TODO: Move to a shared package.

    """
    mapping: dict[Any, Any] = {}

    # The last MRO entry is `object`, which carries no generic information
    for klass in cls.__mro__[:-1][::-1]:
        # Classes without `__orig_bases__` have nothing to contribute
        if not hasattr(klass, "__orig_bases__"):
            continue

        for orig_base in klass.__orig_bases__:
            generic_origin = get_origin(orig_base)
            if not generic_origin:
                continue

            declared_params = getattr(generic_origin, "__parameters__", [])
            bindings = dict(zip(declared_params, get_args(orig_base)))
            for typevar, bound_value in bindings.items():
                # Follow one level of TypeVar indirection through earlier bindings
                if isinstance(bound_value, TypeVar) and bound_value in mapping:
                    bound_value = mapping[bound_value]
                mapping[typevar] = bound_value

    return mapping
|
iceaxe/io.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import importlib.metadata
|
|
3
|
+
from functools import lru_cache, wraps
|
|
4
|
+
from json import loads as json_loads
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from re import search as re_search
|
|
7
|
+
from typing import Any, Callable, Coroutine, TypeVar
|
|
8
|
+
|
|
9
|
+
T = TypeVar("T")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def lru_cache_async(
|
|
13
|
+
maxsize: int | None = 100,
|
|
14
|
+
):
|
|
15
|
+
def decorator(
|
|
16
|
+
async_function: Callable[..., Coroutine[Any, Any, T]],
|
|
17
|
+
):
|
|
18
|
+
@lru_cache(maxsize=maxsize)
|
|
19
|
+
@wraps(async_function)
|
|
20
|
+
def internal(*args, **kwargs):
|
|
21
|
+
coroutine = async_function(*args, **kwargs)
|
|
22
|
+
# Unlike regular coroutine functions, futures can be awaited multiple times
|
|
23
|
+
# so our caller functions can await the same future on multiple cache hits
|
|
24
|
+
return asyncio.ensure_future(coroutine)
|
|
25
|
+
|
|
26
|
+
return internal
|
|
27
|
+
|
|
28
|
+
return decorator
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def resolve_package_path(package_name: str):
|
|
32
|
+
"""
|
|
33
|
+
Given a package distribution, returns the local file directory where the code
|
|
34
|
+
is located. This is resolved to the original reference if installed with `-e`
|
|
35
|
+
otherwise is a copy of the package.
|
|
36
|
+
|
|
37
|
+
NOTE: Copied from Mountaineer, refactor to a shared library.
|
|
38
|
+
|
|
39
|
+
"""
|
|
40
|
+
dist = importlib.metadata.distribution(package_name)
|
|
41
|
+
|
|
42
|
+
def normalize_package(package: str):
|
|
43
|
+
return package.replace("-", "_").lower()
|
|
44
|
+
|
|
45
|
+
# Recent versions of poetry install development packages (-e .) as direct URLs
|
|
46
|
+
# https://the-hitchhikers-guide-to-packaging.readthedocs.io/en/latest/introduction.html
|
|
47
|
+
# "Path configuration files have an extension of .pth, and each line must
|
|
48
|
+
# contain a single path that will be appended to sys.path."
|
|
49
|
+
package_name = normalize_package(dist.name)
|
|
50
|
+
symbolic_links = [
|
|
51
|
+
path
|
|
52
|
+
for path in (dist.files or [])
|
|
53
|
+
if path.name.lower() == f"{package_name}.pth"
|
|
54
|
+
]
|
|
55
|
+
dist_links = [
|
|
56
|
+
path
|
|
57
|
+
for path in (dist.files or [])
|
|
58
|
+
if path.name == "direct_url.json"
|
|
59
|
+
and re_search(package_name + r"-[0-9-.]+\.dist-info", path.parent.name.lower())
|
|
60
|
+
]
|
|
61
|
+
explicit_links = [
|
|
62
|
+
path
|
|
63
|
+
for path in (dist.files or [])
|
|
64
|
+
if path.parent.name.lower() == package_name
|
|
65
|
+
and (
|
|
66
|
+
# Sanity check that the parent is the high level project directory
|
|
67
|
+
# by looking for common base files
|
|
68
|
+
path.name == "__init__.py"
|
|
69
|
+
)
|
|
70
|
+
]
|
|
71
|
+
|
|
72
|
+
# The user installed code as an absolute package (ie. with pip install .) instead of
|
|
73
|
+
# as a reference. There's no need to sniff for the additional package path since
|
|
74
|
+
# we've already found it
|
|
75
|
+
if explicit_links:
|
|
76
|
+
# Right now we have a file pointer to __init__.py. Go up one level
|
|
77
|
+
# to the main package directory to return a directory
|
|
78
|
+
explicit_link = explicit_links[0]
|
|
79
|
+
return Path(str(dist.locate_file(explicit_link.parent)))
|
|
80
|
+
|
|
81
|
+
# Raw path will capture the path to the pyproject.toml file directory,
|
|
82
|
+
# not the actual package code directory
|
|
83
|
+
# Find the root, then resolve the package directory
|
|
84
|
+
raw_path: Path | None = None
|
|
85
|
+
|
|
86
|
+
if symbolic_links:
|
|
87
|
+
direct_url_path = symbolic_links[0]
|
|
88
|
+
raw_path = Path(str(dist.locate_file(direct_url_path.read_text().strip())))
|
|
89
|
+
elif dist_links:
|
|
90
|
+
dist_link = dist_links[0]
|
|
91
|
+
direct_metadata = json_loads(dist_link.read_text())
|
|
92
|
+
package_path = "/" + direct_metadata["url"].lstrip("file://").lstrip("/")
|
|
93
|
+
raw_path = Path(str(dist.locate_file(package_path)))
|
|
94
|
+
|
|
95
|
+
if not raw_path:
|
|
96
|
+
raise ValueError(
|
|
97
|
+
f"Could not find a valid path for package {dist.name}, found files: {dist.files}"
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
# Sniff for the presence of the code directory
|
|
101
|
+
for path in raw_path.iterdir():
|
|
102
|
+
if path.is_dir() and normalize_package(path.name) == package_name:
|
|
103
|
+
return path
|
|
104
|
+
|
|
105
|
+
raise ValueError(
|
|
106
|
+
f"No matching package found in root path: {raw_path} {list(raw_path.iterdir())}"
|
|
107
|
+
)
|
iceaxe/logging.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
from json import dumps as json_dumps
|
|
4
|
+
from logging import Formatter, StreamHandler, getLogger
|
|
5
|
+
from os import environ
|
|
6
|
+
from time import monotonic_ns
|
|
7
|
+
|
|
8
|
+
from click import secho
|
|
9
|
+
from rich.console import Console
|
|
10
|
+
|
|
11
|
+
# Accepted names for the ICEAXE_LOG_LEVEL environment variable, mapped to the
# matching `logging` level constants. Covers all standard named levels.
VERBOSITY_MAPPING = {
    "INFO": logging.INFO,
    "DEBUG": logging.DEBUG,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class JsonFormatter(Formatter):
    """
    Render each log record as a single-line JSON object.

    Keys, in order: ``level``, ``name``, ``timestamp`` (from ``formatTime``
    using this formatter's ``datefmt``) and ``message`` (the interpolated log
    message). When the record carries ``exc_info``, an ``exception`` key with
    the formatted traceback is appended.
    """

    def format(self, record):
        payload = dict(
            level=record.levelname,
            name=record.name,
            timestamp=self.formatTime(record, self.datefmt),
            message=record.getMessage(),
        )
        exc_info = record.exc_info
        if exc_info:
            payload["exception"] = self.formatException(exc_info)
        return json_dumps(payload)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ColorHandler(StreamHandler):
    """
    Log handler that colors formatted records with ``click.secho``: warnings
    in yellow, errors and above in red, everything else uncolored.

    Output goes to standard error, as for a default ``StreamHandler``. Without
    ``err=True`` click would write to stdout, mixing log lines into program
    output.
    """

    def emit(self, record):
        try:
            msg = self.format(record)
            if record.levelno == logging.WARNING:
                color = "yellow"
            elif record.levelno >= logging.ERROR:
                color = "red"
            else:
                color = None
            # `err=True` makes click resolve stderr at call time, so this keeps
            # working when stderr is swapped out (e.g. by test output capture)
            secho(msg, fg=color, err=True)
        except Exception:
            self.handleError(record)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def setup_logger(name, log_level=logging.DEBUG):
    """
    Constructor for Iceaxe's loggers. Provides convenient defaults for log
    level and formatting, alongside colored output and JSON fields for
    structured parsing.

    Safe to call repeatedly for the same ``name``. An existing ``ColorHandler``
    is reused, with its level updated, instead of attaching another one.
    Stacking handlers would emit every record once per call.

    :param name: Logger name passed to ``logging.getLogger``.
    :param log_level: Level applied to both the logger and its handler.
    :return: The configured ``logging.Logger``.

    """
    logger = getLogger(name)
    logger.setLevel(log_level)

    existing_handlers = [h for h in logger.handlers if isinstance(h, ColorHandler)]
    if existing_handlers:
        for existing in existing_handlers:
            existing.setLevel(log_level)
        return logger

    # Create a handler that colors each record and writes it to standard error
    handler = ColorHandler()
    handler.setLevel(log_level)
    handler.setFormatter(JsonFormatter())

    logger.addHandler(handler)

    return logger
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@contextmanager
def log_time_duration(message: str):
    """
    Time the enclosed block and log the elapsed seconds at DEBUG level.

    ```python
    with log_time_duration("Long computation"):
        # Simulate work
        sleep(10)
    ```

    Elapsed time is measured with ``monotonic_ns`` and reported to two decimal
    places. Nothing is logged if the block raises; the exception propagates.

    """
    started_ns = monotonic_ns()
    yield
    elapsed_seconds = (monotonic_ns() - started_ns) / 1e9
    LOGGER.debug(f"{message} : Took {elapsed_seconds:.2f}s")
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Our global logger should only surface warnings and above by default
|
|
86
|
+
# Our global logger should only surface warnings and above by default.
# ICEAXE_LOG_LEVEL is matched case-insensitively (e.g. "debug" or "DEBUG").
_log_level_name = environ.get("ICEAXE_LOG_LEVEL", "WARNING").strip().upper()
if _log_level_name not in VERBOSITY_MAPPING:
    raise ValueError(
        f"Invalid ICEAXE_LOG_LEVEL {_log_level_name!r}; "
        f"expected one of: {', '.join(VERBOSITY_MAPPING)}"
    )

LOGGER = setup_logger(
    __name__,
    log_level=VERBOSITY_MAPPING[_log_level_name],
)

CONSOLE = Console()
|
|
iceaxe/migrations/action_sorter.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from iceaxe.schemas.db_stubs import DBObject
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ActionTopologicalSorter:
    """
    Topologically sort migration actions, grouping them by ``table_name``. This
    provides better semantic grouping within the migrations since most actions
    are oriented by their parent table.

    This is a standalone implementation; it does not subclass
    ``graphlib.TopologicalSorter``. Ordering rules:

    - Root nodes (no dependencies) without a ``table_name`` attribute, like
      global types, are placed first.
    - The remaining root nodes are visited in ``(table_name, representation())``
      order, so tables are processed in alphabetical order of their names.
    - When a node is emitted, the dependents it unblocks are appended to the
      work queue in the order they first appear in the graph. A new root is
      taken only once that queue is empty.

    :param graph: Maps each node to the nodes it depends on. Dependencies that
        are not keys themselves are treated as dependency-free nodes. The
        caller's mapping is not modified.

    """

    def __init__(self, graph: "dict[DBObject, list[DBObject]]"):
        # Copy so that registering missing dependencies below does not mutate the
        # caller's graph. Duplicate dependencies are collapsed (order preserved).
        # sort() decrements a node's in-degree once per emitted dependency, so a
        # repeated entry would leave the node blocked forever and be misreported
        # as a cycle.
        self.graph = {
            node: list(dict.fromkeys(dependencies))
            for node, dependencies in graph.items()
        }
        self.in_degree = defaultdict(int)
        self.nodes = set(self.graph)

        for node, dependencies in list(self.graph.items()):
            for dep in dependencies:
                self.in_degree[node] += 1
                if dep not in self.nodes:
                    self.nodes.add(dep)
                    self.graph[dep] = []

        # Order based on the original yield / creation order of the nodes
        self.node_to_ordering = {node: i for i, node in enumerate(self.graph.keys())}

        # Reverse edges (dependency -> dependents), so sort() can find the nodes
        # a finished node unblocks without rescanning the whole graph
        self.dependents = defaultdict(list)
        for node, dependencies in self.graph.items():
            for dep in dependencies:
                self.dependents[dep].append(node)

    def sort(self):
        """
        Return every node in dependency order (each node after all of its
        dependencies). Does not modify the sorter, so it can be called again.

        :raises ValueError: If the graph contains a cycle.
        """
        from collections import deque

        # Work on a copy of the in-degrees so repeated calls start fresh
        in_degree = {node: self.in_degree.get(node, 0) for node in self.nodes}

        # Sort by the table name and then by the node representation
        roots = sorted(
            (node for node in self.nodes if in_degree[node] == 0),
            key=self.node_key,
        )
        if not roots:
            if self.nodes:
                raise ValueError("Graph contains a cycle")
            return []

        # Always put the non-table actions first (things like global types),
        # keeping the sorted order within each group
        pending_roots = deque(
            [node for node in roots if not hasattr(node, "table_name")]
            + [node for node in roots if hasattr(node, "table_name")]
        )

        result = []
        processed = set()
        queue = deque()

        while queue or pending_roots:
            # Pop another root node only once all unblocked work is done
            if not queue:
                queue.append(pending_roots.popleft())

            current_node = queue.popleft()
            result.append(current_node)
            processed.add(current_node)

            # Newly unblocked nodes, since we've resolved their dependencies
            # with the current processing
            new_ready = []
            for dependent in self.dependents.get(current_node, []):
                if dependent in processed:
                    continue
                in_degree[dependent] -= 1
                if in_degree[dependent] == 0:
                    new_ready.append(dependent)

            # Add newly ready nodes to the queue in their original graph order
            queue.extend(sorted(new_ready, key=self.node_to_ordering.__getitem__))

        if len(result) != len(self.nodes):
            raise ValueError("Graph contains a cycle")

        return result

    @staticmethod
    def node_key(node: "DBObject"):
        # Not all objects specify a table_name, but if they do we want to explicitly
        # sort before the representation
        table_name = getattr(node, "table_name", "")
        return (table_name, node.representation())
|
iceaxe/migrations/cli.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
from inspect import isclass
|
|
2
|
+
from time import monotonic_ns
|
|
3
|
+
|
|
4
|
+
from iceaxe.base import DBModelMetaclass, TableBase
|
|
5
|
+
from iceaxe.io import resolve_package_path
|
|
6
|
+
from iceaxe.logging import CONSOLE
|
|
7
|
+
from iceaxe.schemas.db_serializer import DatabaseSerializer
|
|
8
|
+
from iceaxe.session import DBConnection
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
async def handle_generate(
    package: str, db_connection: DBConnection, message: str | None = None
):
    """
    Creates a new migration definition file, comparing the previous version
    (if it exists) with the current schema.

    :param package: The current python package name. This should match the name of the
        project that's specified in pyproject.toml or setup.py.

    :param db_connection: Connection used to read the live database objects and the
        currently active revision.

    :param message: An optional message to include in the migration file. Helps
        with describing changes and searching for past migration logic over time.

    :raises ValueError: If an existing migration already uses the current revision as
        its down revision, or if the target migration file already exists.

    This is a coroutine, so it must be awaited (or run with ``asyncio.run``):

    ```python {{sticky: True}}
    from asyncio import run

    from iceaxe.migrations.cli import handle_generate
    from click import command, option

    @command()
    @option("--message", help="A message to include in the migration file.")
    def generate_migration(message: str):
        db_connection = DBConnection(...)
        run(handle_generate("my_project", db_connection, message=message))
    ```
    """
    # Any local imports must be done here to avoid circular imports because migrations.__init__
    # imports this file.
    from iceaxe.migrations.client_io import fetch_migrations
    from iceaxe.migrations.generator import MigrationGenerator
    from iceaxe.migrations.migrator import Migrator

    CONSOLE.print("[bold blue]Generating migration to current schema")

    CONSOLE.print(
        "[grey58]Note that Iceaxe's migration support is well tested but still in beta."
    )
    CONSOLE.print(
        "[grey58]File an issue @ https://github.com/piercefreeman/iceaxe/issues if you encounter any problems."
    )

    # Locate the migrations directory that belongs to this project
    package_path = resolve_package_path(package)
    migrations_path = package_path / "migrations"

    # Create the path if it doesn't exist
    migrations_path.mkdir(exist_ok=True)
    if not (migrations_path / "__init__.py").exists():
        (migrations_path / "__init__.py").touch()

    # Get all of the instances that have been registered
    # in memory scope by the user.
    models = [
        cls
        for cls in DBModelMetaclass.get_registry()
        if isclass(cls) and issubclass(cls, TableBase)
    ]

    # Collect the objects the database serializer reports for the live connection
    db_serializer = DatabaseSerializer()
    db_objects = []
    async for values in db_serializer.get_objects(db_connection):
        db_objects.append(values)

    migration_generator = MigrationGenerator()
    up_objects = list(migration_generator.serializer.delegate(models))

    # Get the current revision from the database, this should represent the "down" revision
    # for the new migration
    migrator = Migrator(db_connection)
    await migrator.init_db()
    current_revision = await migrator.get_active_revision()

    # Make sure there's not a duplicate revision that already have this down revision. If so that means
    # that we will have two conflicting migration chains
    migration_revisions = fetch_migrations(migrations_path)
    conflict_migrations = [
        migration
        for migration in migration_revisions
        if migration.down_revision == current_revision
    ]
    if conflict_migrations:
        up_revisions = {migration.up_revision for migration in conflict_migrations}
        raise ValueError(
            f"Found conflicting migrations with down revision {current_revision} (conflicts: {up_revisions}).\n"
            "If you're trying to generate a new migration, make sure to apply the previous migration first - or delete the old one and recreate."
        )

    migration_code, revision = await migration_generator.new_migration(
        db_objects,
        up_objects,
        down_revision=current_revision,
        user_message=message,
    )

    # Create the migration file. The chance of a conflict with this timestamp is very low, but we make sure
    # not to override any existing files anyway. Mode "x" makes the existence check and the creation one
    # atomic step, and the file is written as UTF-8 (the default encoding for Python source files).
    migration_file_path = migrations_path / f"rev_{revision}.py"
    try:
        with migration_file_path.open("x", encoding="utf-8") as migration_file:
            migration_file.write(migration_code)
    except FileExistsError:
        raise ValueError(
            f"Migration file {migration_file_path} already exists. Wait a second and try again."
        ) from None

    CONSOLE.print(f"[bold green]New migration added: {migration_file_path.name}")
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
async def handle_apply(
    package: str,
    db_connection: DBConnection,
):
    """
    Applies all migrations that have not been applied to the database.

    If the database's active revision is already the newest known migration,
    this reports that it is up to date and returns without doing anything.

    :param package: The current python package name. This should match the name of the
        project that's specified in pyproject.toml or setup.py.

    :param db_connection: Connection the migrations are run against.

    :raises ValueError: If the migrations directory is missing, or if no known
        migration follows the active revision and that revision is not itself
        a known migration.

    """
    from iceaxe.migrations.client_io import fetch_migrations, sort_migrations
    from iceaxe.migrations.migrator import Migrator

    migrations_path = resolve_package_path(package) / "migrations"
    if not migrations_path.exists():
        raise ValueError(f"Migrations path {migrations_path} does not exist.")

    # Load all the migration files into memory and locate the subclasses of MigrationRevisionBase
    migration_revisions = fetch_migrations(migrations_path)
    migration_revisions = sort_migrations(migration_revisions)

    # Get the current revision from the database
    migrator = Migrator(db_connection)
    await migrator.init_db()
    current_revision = await migrator.get_active_revision()

    CONSOLE.print(f"Current revision: {current_revision}")

    # Find the item in the sequence that has down_revision equal to the current_revision
    # This indicates the next migration to apply
    next_migration_index = next(
        (
            i
            for i, revision in enumerate(migration_revisions)
            if revision.down_revision == current_revision
        ),
        None,
    )

    if next_migration_index is None:
        # Nothing builds on the active revision. If that revision is itself a known
        # migration, the database is already at the head of the chain.
        if current_revision is not None and any(
            revision.up_revision == current_revision
            for revision in migration_revisions
        ):
            CONSOLE.print("[bold green]Database is already up to date.")
            return
        raise ValueError(
            f"Could not find a migration to apply after revision {current_revision}."
        )

    # Get the chain after this index, this should indicate the next migration to apply
    migration_chain = migration_revisions[next_migration_index:]
    CONSOLE.print(f"Applying {len(migration_chain)} migrations...")

    for migration in migration_chain:
        with CONSOLE.status(
            f"[bold blue]Applying {migration.up_revision}...", spinner="dots"
        ):
            start = monotonic_ns()
            await migration._handle_up(db_connection)

        CONSOLE.print(
            f"[bold green]🚀 Applied {migration.up_revision} in {(monotonic_ns() - start) / 1e9:.2f}s"
        )
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
async def handle_rollback(
    package: str,
    db_connection: DBConnection,
):
    """
    Rolls back the last migration that was applied to the database.

    Exactly one revision is undone per call. The migration whose
    ``up_revision`` equals the database's active revision has its down step run.

    :param package: The current python package name. This should match the name of the
        project that's specified in pyproject.toml or setup.py.

    :param db_connection: Connection the rollback is executed against.

    :raises ValueError: If the migrations directory is missing or no migration
        matches the active revision.

    """
    from iceaxe.migrations.client_io import fetch_migrations, sort_migrations
    from iceaxe.migrations.migrator import Migrator

    migrations_path = resolve_package_path(package) / "migrations"
    if not migrations_path.exists():
        raise ValueError(f"Migrations path {migrations_path} does not exist.")

    # Load every migration module from disk and order them into a chain
    migration_revisions = sort_migrations(fetch_migrations(migrations_path))

    # Initialize the migrator, then read the revision currently active in the database
    migrator = Migrator(db_connection)
    await migrator.init_db()
    current_revision = await migrator.get_active_revision()

    CONSOLE.print(f"Current revision: {current_revision}")

    # The migration to undo is the one that produced the active revision
    this_migration = next(
        (
            candidate
            for candidate in migration_revisions
            if candidate.up_revision == current_revision
        ),
        None,
    )
    if this_migration is None:
        raise ValueError(
            f"Could not find a migration matching {current_revision} for rollback."
        )

    with CONSOLE.status(
        f"[bold blue]Rolling back revision {this_migration.up_revision} to {this_migration.down_revision}...",
        spinner="dots",
    ):
        start = monotonic_ns()
        await this_migration._handle_down(db_connection)

    CONSOLE.print(
        f"[bold green]🪃 Rolled back migration to {this_migration.down_revision} in {(monotonic_ns() - start) / 1e9:.2f}s"
    )
|