iceaxe 0.7.1__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of iceaxe might be problematic; see the advisory details below for more information.

Files changed (75) hide show
  1. iceaxe/__init__.py +20 -0
  2. iceaxe/__tests__/__init__.py +0 -0
  3. iceaxe/__tests__/benchmarks/__init__.py +0 -0
  4. iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
  5. iceaxe/__tests__/benchmarks/test_select.py +114 -0
  6. iceaxe/__tests__/conf_models.py +133 -0
  7. iceaxe/__tests__/conftest.py +204 -0
  8. iceaxe/__tests__/docker_helpers.py +208 -0
  9. iceaxe/__tests__/helpers.py +268 -0
  10. iceaxe/__tests__/migrations/__init__.py +0 -0
  11. iceaxe/__tests__/migrations/conftest.py +36 -0
  12. iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
  13. iceaxe/__tests__/migrations/test_generator.py +140 -0
  14. iceaxe/__tests__/migrations/test_generics.py +91 -0
  15. iceaxe/__tests__/mountaineer/__init__.py +0 -0
  16. iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
  17. iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
  18. iceaxe/__tests__/schemas/__init__.py +0 -0
  19. iceaxe/__tests__/schemas/test_actions.py +1264 -0
  20. iceaxe/__tests__/schemas/test_cli.py +25 -0
  21. iceaxe/__tests__/schemas/test_db_memory_serializer.py +1525 -0
  22. iceaxe/__tests__/schemas/test_db_serializer.py +398 -0
  23. iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
  24. iceaxe/__tests__/test_alias.py +83 -0
  25. iceaxe/__tests__/test_base.py +52 -0
  26. iceaxe/__tests__/test_comparison.py +383 -0
  27. iceaxe/__tests__/test_field.py +11 -0
  28. iceaxe/__tests__/test_helpers.py +9 -0
  29. iceaxe/__tests__/test_modifications.py +151 -0
  30. iceaxe/__tests__/test_queries.py +605 -0
  31. iceaxe/__tests__/test_queries_str.py +173 -0
  32. iceaxe/__tests__/test_session.py +1511 -0
  33. iceaxe/__tests__/test_text_search.py +287 -0
  34. iceaxe/alias_values.py +67 -0
  35. iceaxe/base.py +350 -0
  36. iceaxe/comparison.py +560 -0
  37. iceaxe/field.py +250 -0
  38. iceaxe/functions.py +906 -0
  39. iceaxe/generics.py +140 -0
  40. iceaxe/io.py +107 -0
  41. iceaxe/logging.py +91 -0
  42. iceaxe/migrations/__init__.py +5 -0
  43. iceaxe/migrations/action_sorter.py +98 -0
  44. iceaxe/migrations/cli.py +228 -0
  45. iceaxe/migrations/client_io.py +62 -0
  46. iceaxe/migrations/generator.py +404 -0
  47. iceaxe/migrations/migration.py +86 -0
  48. iceaxe/migrations/migrator.py +101 -0
  49. iceaxe/modifications.py +176 -0
  50. iceaxe/mountaineer/__init__.py +10 -0
  51. iceaxe/mountaineer/cli.py +74 -0
  52. iceaxe/mountaineer/config.py +46 -0
  53. iceaxe/mountaineer/dependencies/__init__.py +6 -0
  54. iceaxe/mountaineer/dependencies/core.py +67 -0
  55. iceaxe/postgres.py +133 -0
  56. iceaxe/py.typed +0 -0
  57. iceaxe/queries.py +1455 -0
  58. iceaxe/queries_str.py +294 -0
  59. iceaxe/schemas/__init__.py +0 -0
  60. iceaxe/schemas/actions.py +864 -0
  61. iceaxe/schemas/cli.py +30 -0
  62. iceaxe/schemas/db_memory_serializer.py +705 -0
  63. iceaxe/schemas/db_serializer.py +346 -0
  64. iceaxe/schemas/db_stubs.py +525 -0
  65. iceaxe/session.py +860 -0
  66. iceaxe/session_optimized.c +12035 -0
  67. iceaxe/session_optimized.cpython-313-darwin.so +0 -0
  68. iceaxe/session_optimized.pyx +212 -0
  69. iceaxe/sql_types.py +148 -0
  70. iceaxe/typing.py +73 -0
  71. iceaxe-0.7.1.dist-info/METADATA +261 -0
  72. iceaxe-0.7.1.dist-info/RECORD +75 -0
  73. iceaxe-0.7.1.dist-info/WHEEL +6 -0
  74. iceaxe-0.7.1.dist-info/licenses/LICENSE +21 -0
  75. iceaxe-0.7.1.dist-info/top_level.txt +1 -0
iceaxe/generics.py ADDED
@@ -0,0 +1,140 @@
1
+ import types
2
+ from inspect import isclass
3
+ from typing import Any, Type, TypeGuard, TypeVar, Union, get_args, get_origin
4
+
5
+
6
def mro_distance(obj_type: Type, target_type: Type) -> float:
    """
    Compute the distance between obj_type and target_type along the MRO.

    0 means an exact match, 1 a direct parent, and so on; ``float("inf")``
    signals that the types are unrelated.
    """
    # Normalize instances to their classes so callers may pass either form.
    if not isclass(obj_type):
        obj_type = type(obj_type)
    if not isclass(target_type):
        target_type = type(target_type)

    # Exact class match wins immediately.
    if obj_type == target_type:
        return 0

    # Otherwise the target's position in the MRO is its ancestor distance.
    try:
        return obj_type.mro().index(target_type)  # type: ignore
    except ValueError:
        # target_type is not an ancestor of obj_type at all.
        return float("inf")
26
+
27
+
28
T = TypeVar("T")


def is_type_compatible(obj_type: Type, target_type: T) -> TypeGuard[T]:
    """True when *obj_type* can be handled where *target_type* is expected."""
    return _is_type_compatible(obj_type, target_type) != float("inf")
33
+
34
+
35
+ def _is_type_compatible(obj_type: Type, target_type: Any) -> float:
36
+ """
37
+ Relatively comprehensive type compatibility checker. This function is
38
+ used to check if a type has has a registered object that can
39
+ handle it.
40
+
41
+ Specifically returns the MRO distance where 0 indicates
42
+ an exact match, 1 indicates a direct ancestor, and so on. Returns a large number
43
+ if no compatibility is found.
44
+
45
+ """
46
+ # Any type is compatible with any other type
47
+ if target_type is Any:
48
+ return 0
49
+
50
+ # If obj_type is a nested type, each of these types must be compatible
51
+ # with the corresponding type in target_type
52
+ if get_origin(obj_type) is Union or isinstance(obj_type, types.UnionType):
53
+ return max(_is_type_compatible(t, target_type) for t in get_args(obj_type))
54
+
55
+ # Handle OR types
56
+ if get_origin(target_type) is Union or isinstance(target_type, types.UnionType):
57
+ return min(_is_type_compatible(obj_type, t) for t in get_args(target_type))
58
+
59
+ # Handle Type[Values] like typehints where we want to typehint a class
60
+ if get_origin(target_type) == type: # noqa: E721
61
+ return _is_type_compatible(obj_type, get_args(target_type)[0])
62
+
63
+ # Handle dict[str, str] like typehints
64
+ # We assume that each arg in order must be matched with the target type
65
+ obj_origin = get_origin(obj_type)
66
+ target_origin = get_origin(target_type)
67
+ if obj_origin and target_origin:
68
+ if obj_origin == target_origin:
69
+ return max(
70
+ _is_type_compatible(t1, t2)
71
+ for t1, t2 in zip(get_args(obj_type), get_args(target_type))
72
+ )
73
+ else:
74
+ return float("inf")
75
+
76
+ # For lists, sets, and tuple objects make sure that each object matches
77
+ # the target type
78
+ if isinstance(obj_type, (list, set, tuple)):
79
+ if type(obj_type) != get_origin(target_type): # noqa: E721
80
+ return float("inf")
81
+ return max(
82
+ _is_type_compatible(obj, get_args(target_type)[0]) for obj in obj_type
83
+ )
84
+
85
+ if isinstance(target_type, type):
86
+ return mro_distance(obj_type, target_type)
87
+
88
+ # Default case
89
+ return float("inf")
90
+
91
+
92
+ def remove_null_type(typehint: Type) -> Type:
93
+ if get_origin(typehint) is Union or isinstance(typehint, types.UnionType):
94
+ return Union[ # type: ignore
95
+ tuple( # type: ignore
96
+ [t for t in get_args(typehint) if t != type(None)] # noqa: E721
97
+ )
98
+ ]
99
+ return typehint
100
+
101
+
102
+ def has_null_type(typehint: Type) -> bool:
103
+ if get_origin(typehint) is Union or isinstance(typehint, types.UnionType):
104
+ return any(arg == type(None) for arg in get_args(typehint)) # noqa: E721
105
+ return typehint == type(None) # noqa: E721
106
+
107
+
108
def get_typevar_mapping(cls):
    """
    Get the raw typevar mappings {typevar: generic value} for each
    typevar that appears in the generic bases across the class
    hierarchy of `cls`.

    Shared logic with Mountaineer. TODO: Move to a shared package.

    """
    mapping: dict[Any, Any] = {}

    # Walk the MRO from the most generic ancestor downwards (skipping
    # `object`) so subclasses override mappings recorded by their parents.
    for base in reversed(cls.__mro__[:-1]):
        # Plain (non-generic) classes contribute nothing.
        if not hasattr(base, "__orig_bases__"):
            continue

        for origin_base in base.__orig_bases__:
            origin = get_origin(origin_base)
            if not origin:
                continue

            parameters = getattr(origin, "__parameters__", [])
            arguments = get_args(origin_base)

            # Pair each declared parameter with its instantiated argument.
            for param, value in zip(parameters, arguments):
                # Chase typevar-to-typevar assignments back to a concrete
                # value when one was already recorded earlier in the walk.
                if isinstance(value, TypeVar) and value in mapping:
                    mapping[param] = mapping[value]
                else:
                    mapping[param] = value

    return mapping
iceaxe/io.py ADDED
@@ -0,0 +1,107 @@
1
+ import asyncio
2
+ import importlib.metadata
3
+ from functools import lru_cache, wraps
4
+ from json import loads as json_loads
5
+ from pathlib import Path
6
+ from re import search as re_search
7
+ from typing import Any, Callable, Coroutine, TypeVar
8
+
9
+ T = TypeVar("T")
10
+
11
+
12
+ def lru_cache_async(
13
+ maxsize: int | None = 100,
14
+ ):
15
+ def decorator(
16
+ async_function: Callable[..., Coroutine[Any, Any, T]],
17
+ ):
18
+ @lru_cache(maxsize=maxsize)
19
+ @wraps(async_function)
20
+ def internal(*args, **kwargs):
21
+ coroutine = async_function(*args, **kwargs)
22
+ # Unlike regular coroutine functions, futures can be awaited multiple times
23
+ # so our caller functions can await the same future on multiple cache hits
24
+ return asyncio.ensure_future(coroutine)
25
+
26
+ return internal
27
+
28
+ return decorator
29
+
30
+
31
def resolve_package_path(package_name: str):
    """
    Given a package distribution, returns the local file directory where the code
    is located. This is resolved to the original reference if installed with `-e`
    otherwise is a copy of the package.

    :param package_name: The installed distribution name (dashes or underscores).
    :raises ValueError: If no on-disk code directory can be resolved for the
        distribution.

    NOTE: Copied from Mountaineer, refactor to a shared library.

    """
    dist = importlib.metadata.distribution(package_name)

    def normalize_package(package: str):
        # Distribution names use dashes; import package names use underscores.
        return package.replace("-", "_").lower()

    # Recent versions of poetry install development packages (-e .) as direct URLs
    # https://the-hitchhikers-guide-to-packaging.readthedocs.io/en/latest/introduction.html
    # "Path configuration files have an extension of .pth, and each line must
    # contain a single path that will be appended to sys.path."
    package_name = normalize_package(dist.name)
    symbolic_links = [
        path
        for path in (dist.files or [])
        if path.name.lower() == f"{package_name}.pth"
    ]
    dist_links = [
        path
        for path in (dist.files or [])
        if path.name == "direct_url.json"
        and re_search(package_name + r"-[0-9-.]+\.dist-info", path.parent.name.lower())
    ]
    explicit_links = [
        path
        for path in (dist.files or [])
        if path.parent.name.lower() == package_name
        and (
            # Sanity check that the parent is the high level project directory
            # by looking for common base files
            path.name == "__init__.py"
        )
    ]

    # The user installed code as an absolute package (ie. with pip install .) instead of
    # as a reference. There's no need to sniff for the additional package path since
    # we've already found it
    if explicit_links:
        # Right now we have a file pointer to __init__.py. Go up one level
        # to the main package directory to return a directory
        explicit_link = explicit_links[0]
        return Path(str(dist.locate_file(explicit_link.parent)))

    # Raw path will capture the path to the pyproject.toml file directory,
    # not the actual package code directory
    # Find the root, then resolve the package directory
    raw_path: Path | None = None

    if symbolic_links:
        direct_url_path = symbolic_links[0]
        raw_path = Path(str(dist.locate_file(direct_url_path.read_text().strip())))
    elif dist_links:
        dist_link = dist_links[0]
        direct_metadata = json_loads(dist_link.read_text())
        # BUG FIX: the previous `.lstrip("file://")` strips any leading run of
        # the *characters* f, i, l, e, ':' and '/', which corrupts URLs whose
        # path begins with one of those letters (e.g. file:///lib/pkg ->
        # "/b/pkg"). str.removeprefix drops the scheme exactly once, then the
        # leading slash is normalized.
        package_path = "/" + direct_metadata["url"].removeprefix("file://").lstrip("/")
        raw_path = Path(str(dist.locate_file(package_path)))

    if not raw_path:
        raise ValueError(
            f"Could not find a valid path for package {dist.name}, found files: {dist.files}"
        )

    # Sniff for the presence of the code directory
    for path in raw_path.iterdir():
        if path.is_dir() and normalize_package(path.name) == package_name:
            return path

    raise ValueError(
        f"No matching package found in root path: {raw_path} {list(raw_path.iterdir())}"
    )
iceaxe/logging.py ADDED
@@ -0,0 +1,91 @@
1
+ import logging
2
+ from contextlib import contextmanager
3
+ from json import dumps as json_dumps
4
+ from logging import Formatter, StreamHandler, getLogger
5
+ from os import environ
6
+ from time import monotonic_ns
7
+
8
+ from click import secho
9
+ from rich.console import Console
10
+
11
# Maps values of the ICEAXE_LOG_LEVEL environment variable to stdlib
# logging levels (see the LOGGER construction at the bottom of this module).
VERBOSITY_MAPPING = {
    "INFO": logging.INFO,
    "DEBUG": logging.DEBUG,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
}
17
+
18
+
19
class JsonFormatter(Formatter):
    """Log formatter that renders each record as a single JSON object."""

    def format(self, record):
        payload = {
            "level": record.levelname,
            "name": record.name,
            "timestamp": self.formatTime(record, self.datefmt),
            "message": record.getMessage(),
        }
        # Attach the formatted traceback when the record carries one.
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        return json_dumps(payload)
30
+
31
+
32
class ColorHandler(StreamHandler):
    """Stream handler that colorizes output by severity via click.secho:
    yellow for warnings, red for errors and above, plain otherwise."""

    def emit(self, record):
        try:
            message = self.format(record)
            level = record.levelno
            if level >= logging.ERROR:
                secho(message, fg="red")
            elif level == logging.WARNING:
                secho(message, fg="yellow")
            else:
                secho(message)
        except Exception:
            # Delegate to the stdlib's standard error reporting path.
            self.handleError(record)
44
+
45
+
46
def setup_logger(name, log_level=logging.DEBUG):
    """
    Constructor for the main logger used by Mountaineer. Provides convenient
    defaults for log level and formatting: colored stdout/stderr output with
    JSON-structured records for machine parsing.

    :param name: Logger name, typically the caller's __name__.
    :param log_level: Minimum level applied to both the logger and its handler.
    """
    logger = getLogger(name)
    logger.setLevel(log_level)

    # Single color-aware handler that writes JSON lines to standard error.
    handler = ColorHandler()
    handler.setLevel(log_level)
    handler.setFormatter(JsonFormatter())

    logger.addHandler(handler)

    return logger
66
+
67
+
68
@contextmanager
def log_time_duration(message: str):
    """
    Context manager to time a code block at runtime.

    ```python
    with log_time_duration("Long computation"):
        # Simulate work
        sleep(10)
    ```

    Note: the duration is logged only when the block exits without raising,
    matching the original behavior (no try/finally on purpose).
    """
    started_at = monotonic_ns()
    yield
    elapsed_seconds = (monotonic_ns() - started_at) / 1e9
    LOGGER.debug(f"{message} : Took {elapsed_seconds:.2f}s")
83
+
84
+
85
# Our global logger should only surface warnings and above by default;
# override via the ICEAXE_LOG_LEVEL environment variable (see VERBOSITY_MAPPING).
LOGGER = setup_logger(
    __name__,
    log_level=VERBOSITY_MAPPING[environ.get("ICEAXE_LOG_LEVEL", "WARNING")],
)

# Shared rich console used for CLI status/progress output.
CONSOLE = Console()
@@ -0,0 +1,5 @@
1
+ from .cli import (
2
+ handle_apply as handle_apply,
3
+ handle_generate as handle_generate,
4
+ handle_rollback as handle_rollback,
5
+ )
@@ -0,0 +1,98 @@
1
+ from collections import defaultdict
2
+
3
+ from iceaxe.schemas.db_stubs import DBObject
4
+
5
+
6
class ActionTopologicalSorter:
    """
    Extends Python's native TopologicalSorter to group nodes by table_name. This provides
    better semantic grouping within the migrations since most actions are oriented by their
    parent table.

    - Places cross-table dependencies and non-table actions (like types) according
      to their default DAG order.
    - Tables are processed in alphabetical order of their names.

    """

    def __init__(self, graph: "dict[DBObject, list[DBObject]]"):
        """
        :param graph: Mapping of node -> list of nodes it depends on. Nodes that
            only appear as dependencies are registered with no dependencies of
            their own.
        """
        self.graph = graph
        self.in_degree = defaultdict(int)
        self.nodes = set(graph.keys())

        # Reverse adjacency (dependency -> dependents) so sort() can unblock
        # nodes in O(edges) total instead of rescanning the whole graph on
        # every pop. Duplicate edges are deliberately kept so the decrements
        # in sort() stay symmetric with the increments below.
        self.dependents: "dict[DBObject, list[DBObject]]" = defaultdict(list)

        for node, dependencies in list(graph.items()):
            for dep in dependencies:
                self.in_degree[node] += 1
                self.dependents[dep].append(node)
                if dep not in self.nodes:
                    self.nodes.add(dep)
                    self.graph[dep] = []

        # Order based on the original yield / creation order of the nodes
        self.node_to_ordering = {node: i for i, node in enumerate(self.graph.keys())}

    def sort(self):
        """
        Return the nodes in dependency order, grouped so a table's actions are
        emitted together wherever the DAG allows it.

        :raises ValueError: If the graph contains a cycle.
        """
        result = []

        # Nodes with no dependencies seed the traversal, sorted by
        # (table_name, representation) for deterministic table grouping.
        root_nodes_queued = sorted(
            [node for node in self.nodes if self.in_degree[node] == 0],
            key=self.node_key,
        )
        if not root_nodes_queued and self.nodes:
            raise ValueError("Graph contains a cycle")
        elif not root_nodes_queued:
            return []

        # Always put the non-table actions first (things like global types)
        non_table_nodes = [
            node for node in root_nodes_queued if not hasattr(node, "table_name")
        ]
        root_nodes_queued = [
            node for node in root_nodes_queued if node not in non_table_nodes
        ]
        root_nodes_queued = non_table_nodes + root_nodes_queued

        # Only one root is in flight at a time: its dependent subtree drains
        # fully before the next root starts, which keeps tables grouped.
        queue = [root_nodes_queued.pop(0)]
        processed = set()

        while True:
            if not queue:
                # Pop another root node, if available
                if root_nodes_queued:
                    queue.append(root_nodes_queued.pop(0))
                    continue
                else:
                    # If no more root nodes are available and no more work to be done
                    break

            current_node = queue.pop(0)

            result.append(current_node)
            processed.add(current_node)

            # Newly unblocked nodes, since we've resolved their dependencies
            # with the current processing. BUG FIX: decrement once per
            # recorded edge — the previous membership test (`current in deps`)
            # released duplicate dependency entries only once even though
            # __init__ counted them twice, falsely reporting a cycle.
            new_ready = []
            for dependent in self.dependents[current_node]:
                if dependent in processed:
                    continue
                self.in_degree[dependent] -= 1
                if self.in_degree[dependent] == 0:
                    new_ready.append(dependent)

            # Add newly ready nodes to queue in sorted order
            queue.extend(
                sorted(new_ready, key=lambda node: self.node_to_ordering[node])
            )

        if len(result) != len(self.nodes):
            raise ValueError("Graph contains a cycle")

        return result

    @staticmethod
    def node_key(node: "DBObject"):
        # Not all objects specify a table_name, but if they do we want to explicitly
        # sort before the representation
        table_name = getattr(node, "table_name", "")
        return (table_name, node.representation())
@@ -0,0 +1,228 @@
1
+ from inspect import isclass
2
+ from time import monotonic_ns
3
+
4
+ from iceaxe.base import DBModelMetaclass, TableBase
5
+ from iceaxe.io import resolve_package_path
6
+ from iceaxe.logging import CONSOLE
7
+ from iceaxe.schemas.db_serializer import DatabaseSerializer
8
+ from iceaxe.session import DBConnection
9
+
10
+
11
async def handle_generate(
    package: str, db_connection: DBConnection, message: str | None = None
):
    """
    Creates a new migration definition file, comparing the previous version
    (if it exists) with the current schema.

    :param package: The current python package name. This should match the name of the
        project that's specified in pyproject.toml or setup.py.

    :param db_connection: Open database connection; its active revision becomes
        the new migration's "down" revision.

    :param message: An optional message to include in the migration file. Helps
        with describing changes and searching for past migration logic over time.

    :raises ValueError: If a migration with the same down revision already exists
        (conflicting chains) or the target migration file already exists.

    ```python {{sticky: True}}
    import asyncio
    from iceaxe.migrations.cli import handle_generate
    from click import command, option

    @command()
    @option("--message", help="A message to include in the migration file.")
    def generate_migration(message: str):
        db_connection = DBConnection(...)
        asyncio.run(handle_generate("my_project", db_connection, message=message))
    ```
    """
    # Any local imports must be done here to avoid circular imports because migrations.__init__
    # imports this file.
    from iceaxe.migrations.client_io import fetch_migrations
    from iceaxe.migrations.generator import MigrationGenerator
    from iceaxe.migrations.migrator import Migrator

    CONSOLE.print("[bold blue]Generating migration to current schema")

    CONSOLE.print(
        "[grey58]Note that Iceaxe's migration support is well tested but still in beta."
    )
    CONSOLE.print(
        "[grey58]File an issue @ https://github.com/piercefreeman/iceaxe/issues if you encounter any problems."
    )

    # Locate the migrations directory that belongs to this project
    package_path = resolve_package_path(package)
    migrations_path = package_path / "migrations"

    # Create the path (and its __init__.py) if it doesn't exist
    migrations_path.mkdir(exist_ok=True)
    if not (migrations_path / "__init__.py").exists():
        (migrations_path / "__init__.py").touch()

    # Get all of the instances that have been registered
    # in memory scope by the user.
    models = [
        cls
        for cls in DBModelMetaclass.get_registry()
        if isclass(cls) and issubclass(cls, TableBase)
    ]

    # Snapshot the live database schema as the "before" state
    db_serializer = DatabaseSerializer()
    db_objects = []
    async for values in db_serializer.get_objects(db_connection):
        db_objects.append(values)

    # Serialize the in-memory models as the "after" state
    migration_generator = MigrationGenerator()
    up_objects = list(migration_generator.serializer.delegate(models))

    # Get the current revision from the database, this should represent the "down" revision
    # for the new migration
    migrator = Migrator(db_connection)
    await migrator.init_db()
    current_revision = await migrator.get_active_revision()

    # Make sure there's not a duplicate revision that already has this down revision. If so that means
    # that we would have two conflicting migration chains
    migration_revisions = fetch_migrations(migrations_path)
    conflict_migrations = [
        migration
        for migration in migration_revisions
        if migration.down_revision == current_revision
    ]
    if conflict_migrations:
        up_revisions = {migration.up_revision for migration in conflict_migrations}
        raise ValueError(
            f"Found conflicting migrations with down revision {current_revision} (conflicts: {up_revisions}).\n"
            "If you're trying to generate a new migration, make sure to apply the previous migration first - or delete the old one and recreate."
        )

    migration_code, revision = await migration_generator.new_migration(
        db_objects,
        up_objects,
        down_revision=current_revision,
        user_message=message,
    )

    # Create the migration file. The chance of a conflict with this timestamp is very low, but we make sure
    # not to override any existing files anyway.
    migration_file_path = migrations_path / f"rev_{revision}.py"
    if migration_file_path.exists():
        raise ValueError(
            f"Migration file {migration_file_path} already exists. Wait a second and try again."
        )

    migration_file_path.write_text(migration_code)

    CONSOLE.print(f"[bold green]New migration added: {migration_file_path.name}")
114
+
115
+
116
async def handle_apply(
    package: str,
    db_connection: DBConnection,
):
    """
    Applies all migrations that have not been applied to the database.

    :param package: The current python package name. This should match the name of the
        project that's specified in pyproject.toml or setup.py.

    :param db_connection: Open database connection to run the migrations against.

    :raises ValueError: If the migrations directory is missing or no migration
        follows the database's active revision.
    """
    # Local imports avoid circular imports via migrations.__init__
    from iceaxe.migrations.client_io import fetch_migrations, sort_migrations
    from iceaxe.migrations.migrator import Migrator

    migrations_path = resolve_package_path(package) / "migrations"
    if not migrations_path.exists():
        raise ValueError(f"Migrations path {migrations_path} does not exist.")

    # Load every migration module and order it along the revision chain
    ordered_migrations = sort_migrations(fetch_migrations(migrations_path))

    # Determine where the database currently sits in the chain
    migrator = Migrator(db_connection)
    await migrator.init_db()
    current_revision = await migrator.get_active_revision()

    CONSOLE.print(f"Current revision: {current_revision}")

    # The first migration whose down_revision matches the active revision
    # is the next one that should be applied
    next_migration_index = next(
        (
            index
            for index, migration in enumerate(ordered_migrations)
            if migration.down_revision == current_revision
        ),
        None,
    )

    if next_migration_index is None:
        raise ValueError(
            f"Could not find a migration to apply after revision {current_revision}."
        )

    # Everything from that point onwards still needs to run
    pending_migrations = ordered_migrations[next_migration_index:]
    CONSOLE.print(f"Applying {len(pending_migrations)} migrations...")

    for migration in pending_migrations:
        with CONSOLE.status(
            f"[bold blue]Applying {migration.up_revision}...", spinner="dots"
        ):
            start = monotonic_ns()
            await migration._handle_up(db_connection)

        CONSOLE.print(
            f"[bold green]🚀 Applied {migration.up_revision} in {(monotonic_ns() - start) / 1e9:.2f}s"
        )
+ )
172
+
173
+
174
+ async def handle_rollback(
175
+ package: str,
176
+ db_connection: DBConnection,
177
+ ):
178
+ """
179
+ Rolls back the last migration that was applied to the database.
180
+
181
+ :param package: The current python package name. This should match the name of the
182
+ project that's specified in pyproject.toml or setup.py.
183
+
184
+ """
185
+ from iceaxe.migrations.client_io import fetch_migrations, sort_migrations
186
+ from iceaxe.migrations.migrator import Migrator
187
+
188
+ migrations_path = resolve_package_path(package) / "migrations"
189
+ if not migrations_path.exists():
190
+ raise ValueError(f"Migrations path {migrations_path} does not exist.")
191
+
192
+ # Load all the migration files into memory and locate the subclasses of MigrationRevisionBase
193
+ migration_revisions = fetch_migrations(migrations_path)
194
+ migration_revisions = sort_migrations(migration_revisions)
195
+
196
+ # Get the current revision from the database
197
+ migrator = Migrator(db_connection)
198
+ await migrator.init_db()
199
+ current_revision = await migrator.get_active_revision()
200
+
201
+ CONSOLE.print(f"Current revision: {current_revision}")
202
+
203
+ # Find the item in the sequence that has down_revision equal to the current_revision
204
+ # This indicates the next migration to apply
205
+ this_migration_index = None
206
+ for i, revision in enumerate(migration_revisions):
207
+ if revision.up_revision == current_revision:
208
+ this_migration_index = i
209
+ break
210
+
211
+ if this_migration_index is None:
212
+ raise ValueError(
213
+ f"Could not find a migration matching {current_revision} for rollback."
214
+ )
215
+
216
+ # Get the chain after this index, this should indicate the next migration to apply
217
+ this_migration = migration_revisions[this_migration_index]
218
+
219
+ with CONSOLE.status(
220
+ f"[bold blue]Rolling back revision {this_migration.up_revision} to {this_migration.down_revision}...",
221
+ spinner="dots",
222
+ ):
223
+ start = monotonic_ns()
224
+ await this_migration._handle_down(db_connection)
225
+
226
+ CONSOLE.print(
227
+ f"[bold green]🪃 Rolled back migration to {this_migration.down_revision} in {(monotonic_ns() - start) / 1e9:.2f}s"
228
+ )