crochet-migration 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,282 @@
1
+ """SQLite ledger — the authoritative record of applied migrations and data batches."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import sqlite3
6
+ from dataclasses import dataclass
7
+ from datetime import datetime, timezone
8
+ from pathlib import Path
9
+ from typing import Iterator
10
+
11
+ from crochet.errors import LedgerError, LedgerIntegrityError
12
+
13
# Version of the ledger's own table layout.  It is written to ledger_meta
# under the 'schema_version' key the first time a ledger database is
# initialised (see Ledger._init_schema).
_SCHEMA_VERSION = 1

# DDL executed with executescript() each time a Ledger is opened.  Every
# statement is CREATE TABLE IF NOT EXISTS, so running it against an existing
# database leaves the existing tables in place.
#
# Tables:
#   ledger_meta         key/value metadata about the ledger itself
#   applied_migrations  one row per applied migration revision
#   dataset_batches     one row per data batch; migration_id is a foreign key
#                       to applied_migrations(revision_id)
#   schema_snapshots    serialized schema snapshots keyed by schema hash
_INIT_SQL = """\
CREATE TABLE IF NOT EXISTS ledger_meta (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS applied_migrations (
revision_id TEXT PRIMARY KEY,
parent_id TEXT,
description TEXT NOT NULL DEFAULT '',
schema_hash TEXT NOT NULL,
applied_at TEXT NOT NULL,
rollback_safe INTEGER NOT NULL DEFAULT 1
);

CREATE TABLE IF NOT EXISTS dataset_batches (
batch_id TEXT PRIMARY KEY,
migration_id TEXT,
source_file TEXT,
file_checksum TEXT,
loader_version TEXT,
record_count INTEGER,
created_at TEXT NOT NULL,
FOREIGN KEY (migration_id) REFERENCES applied_migrations(revision_id)
);

CREATE TABLE IF NOT EXISTS schema_snapshots (
schema_hash TEXT PRIMARY KEY,
snapshot_json TEXT NOT NULL,
created_at TEXT NOT NULL
);
"""
47
+
48
+
49
@dataclass
class AppliedMigration:
    """One row of the ``applied_migrations`` ledger table."""

    # Primary key of the row: the migration's revision identifier.
    revision_id: str
    # Revision this migration builds on; None when it has no parent.
    parent_id: str | None
    # Free-form description stored with the row.
    description: str
    # Schema hash recorded for this revision.
    schema_hash: str
    # Timestamp text as stored in the ledger (the ledger writes UTC ISO-8601).
    applied_at: str
    # Stored as 0/1 in SQLite and exposed here as a bool.
    rollback_safe: bool
57
+
58
+
59
@dataclass
class DatasetBatch:
    """One row of the ``dataset_batches`` ledger table."""

    # Primary key of the row.
    batch_id: str
    # Revision the batch is attached to (foreign key into
    # applied_migrations); None when the batch is not tied to a migration.
    migration_id: str | None
    # The remaining descriptive columns are all nullable in the table.
    source_file: str | None
    file_checksum: str | None
    loader_version: str | None
    record_count: int | None
    # Timestamp text as stored in the ledger (the ledger writes UTC ISO-8601).
    created_at: str
68
+
69
+
70
class Ledger:
    """SQLite-backed ledger for migration and data-batch tracking.

    A ledger owns a single ``sqlite3`` connection for its whole lifetime.
    It can be used as a context manager, in which case the connection is
    closed on exit.  Every mutating method commits before returning.
    """

    # Column lists shared by all SELECT statements so that the row
    # converters below can rely on a fixed column order.
    _MIGRATION_COLUMNS = (
        "revision_id, parent_id, description, schema_hash, applied_at, rollback_safe"
    )
    _BATCH_COLUMNS = (
        "batch_id, migration_id, source_file, file_checksum, "
        "loader_version, record_count, created_at"
    )

    def __init__(self, db_path: Path) -> None:
        """Open (creating if necessary) the ledger database at *db_path*.

        Missing parent directories are created.  The connection is switched
        to WAL journaling and foreign-key enforcement is turned on before
        the tables are created.
        """
        self._db_path = db_path
        db_path.parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(str(db_path))
        try:
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._conn.execute("PRAGMA foreign_keys=ON")
            self._init_schema()
        except BaseException:
            # Do not leak the connection when setup fails part-way.
            self._conn.close()
            raise

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def _init_schema(self) -> None:
        """Create the ledger tables and stamp the schema version if absent."""
        self._conn.executescript(_INIT_SQL)
        cur = self._conn.execute(
            "SELECT value FROM ledger_meta WHERE key = 'schema_version'"
        )
        if cur.fetchone() is None:
            self._conn.execute(
                "INSERT INTO ledger_meta (key, value) VALUES (?, ?)",
                ("schema_version", str(_SCHEMA_VERSION)),
            )
        self._conn.commit()

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()

    def __enter__(self) -> "Ledger":
        return self

    def __exit__(self, *exc: object) -> None:
        self.close()

    # ------------------------------------------------------------------
    # Migrations
    # ------------------------------------------------------------------

    @staticmethod
    def _migration_from_row(row: tuple) -> AppliedMigration:
        """Build an AppliedMigration from a row selected with _MIGRATION_COLUMNS."""
        return AppliedMigration(
            revision_id=row[0],
            parent_id=row[1],
            description=row[2],
            schema_hash=row[3],
            applied_at=row[4],
            rollback_safe=bool(row[5]),
        )

    def record_migration(
        self,
        revision_id: str,
        parent_id: str | None,
        description: str,
        schema_hash: str,
        rollback_safe: bool = True,
    ) -> AppliedMigration:
        """Insert a row for an applied migration and commit.

        The row is timestamped with the current UTC time in ISO-8601 form.

        Returns:
            The recorded migration.

        Raises:
            LedgerError: if *revision_id* is already recorded, or if SQLite
                rejects the row for another integrity reason.
        """
        now = datetime.now(timezone.utc).isoformat()
        try:
            self._conn.execute(
                """INSERT INTO applied_migrations
                   (revision_id, parent_id, description, schema_hash, applied_at, rollback_safe)
                   VALUES (?, ?, ?, ?, ?, ?)""",
                (revision_id, parent_id, description, schema_hash, now, int(rollback_safe)),
            )
            self._conn.commit()
        except sqlite3.IntegrityError as exc:
            # The failed INSERT leaves the implicitly opened transaction
            # (and its write lock) pending; release it before reporting.
            self._conn.rollback()
            # A duplicate key is only one possible integrity failure, so
            # confirm it instead of assuming it.
            if self.is_applied(revision_id):
                raise LedgerError(
                    f"Migration '{revision_id}' is already recorded in the ledger."
                ) from exc
            raise LedgerError(
                f"Could not record migration '{revision_id}': {exc}"
            ) from exc
        return AppliedMigration(
            revision_id=revision_id,
            parent_id=parent_id,
            description=description,
            schema_hash=schema_hash,
            applied_at=now,
            rollback_safe=rollback_safe,
        )

    def remove_migration(self, revision_id: str) -> None:
        """Delete the row for *revision_id* (a no-op if absent) and commit.

        Foreign keys are enforced on this connection and dataset_batches
        references applied_migrations, so SQLite rejects the delete with
        ``sqlite3.IntegrityError`` while batches still point at the row.
        """
        self._conn.execute(
            "DELETE FROM applied_migrations WHERE revision_id = ?", (revision_id,)
        )
        self._conn.commit()

    def get_applied_migrations(self) -> list[AppliedMigration]:
        """Return all recorded migrations, oldest first.

        Rows are ordered by ``applied_at``; ``rowid`` breaks ties so that
        migrations recorded with identical timestamps keep a deterministic
        (insertion) order.
        """
        cur = self._conn.execute(
            f"SELECT {self._MIGRATION_COLUMNS} FROM applied_migrations "
            "ORDER BY applied_at, rowid"
        )
        return [self._migration_from_row(row) for row in cur.fetchall()]

    def get_head(self) -> AppliedMigration | None:
        """Return the most recently applied migration, or None if there is none."""
        # Mirror image of the ordering in get_applied_migrations(), limited
        # to one row so the whole table is not loaded.
        cur = self._conn.execute(
            f"SELECT {self._MIGRATION_COLUMNS} FROM applied_migrations "
            "ORDER BY applied_at DESC, rowid DESC LIMIT 1"
        )
        row = cur.fetchone()
        return self._migration_from_row(row) if row is not None else None

    def is_applied(self, revision_id: str) -> bool:
        """Return True if *revision_id* has a row in applied_migrations."""
        cur = self._conn.execute(
            "SELECT 1 FROM applied_migrations WHERE revision_id = ?", (revision_id,)
        )
        return cur.fetchone() is not None

    # ------------------------------------------------------------------
    # Dataset batches
    # ------------------------------------------------------------------

    def _batch_exists(self, batch_id: str) -> bool:
        """Return True if *batch_id* has a row in dataset_batches."""
        cur = self._conn.execute(
            "SELECT 1 FROM dataset_batches WHERE batch_id = ?", (batch_id,)
        )
        return cur.fetchone() is not None

    def record_batch(
        self,
        batch_id: str,
        migration_id: str | None = None,
        source_file: str | None = None,
        file_checksum: str | None = None,
        loader_version: str | None = None,
        record_count: int | None = None,
    ) -> DatasetBatch:
        """Insert a row for a data batch and commit.

        The row is timestamped with the current UTC time in ISO-8601 form.

        Returns:
            The recorded batch.

        Raises:
            LedgerError: if *batch_id* is already recorded, or if SQLite
                rejects the row for another integrity reason (for example
                *migration_id* does not name a recorded migration).
        """
        now = datetime.now(timezone.utc).isoformat()
        try:
            self._conn.execute(
                """INSERT INTO dataset_batches
                   (batch_id, migration_id, source_file, file_checksum,
                    loader_version, record_count, created_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (batch_id, migration_id, source_file, file_checksum,
                 loader_version, record_count, now),
            )
            self._conn.commit()
        except sqlite3.IntegrityError as exc:
            # Release the transaction left open by the failed INSERT.
            self._conn.rollback()
            # Tell a duplicate batch apart from, e.g., a foreign-key failure.
            if self._batch_exists(batch_id):
                raise LedgerError(
                    f"Batch '{batch_id}' is already recorded in the ledger."
                ) from exc
            raise LedgerError(
                f"Could not record batch '{batch_id}': {exc}"
            ) from exc
        return DatasetBatch(
            batch_id=batch_id,
            migration_id=migration_id,
            source_file=source_file,
            file_checksum=file_checksum,
            loader_version=loader_version,
            record_count=record_count,
            created_at=now,
        )

    def get_batches(self, migration_id: str | None = None) -> list[DatasetBatch]:
        """Return recorded batches, oldest first.

        When *migration_id* is truthy only batches attached to that
        migration are returned; otherwise (None or an empty string) all
        batches are returned.  ``rowid`` breaks ``created_at`` ties so the
        order is deterministic.
        """
        sql = f"SELECT {self._BATCH_COLUMNS} FROM dataset_batches"
        params: tuple[str, ...] = ()
        if migration_id:
            sql += " WHERE migration_id = ?"
            params = (migration_id,)
        sql += " ORDER BY created_at, rowid"
        cur = self._conn.execute(sql, params)
        return [
            DatasetBatch(
                batch_id=row[0],
                migration_id=row[1],
                source_file=row[2],
                file_checksum=row[3],
                loader_version=row[4],
                record_count=row[5],
                created_at=row[6],
            )
            for row in cur.fetchall()
        ]

    def remove_batch(self, batch_id: str) -> None:
        """Delete the row for *batch_id* (a no-op if absent) and commit."""
        self._conn.execute(
            "DELETE FROM dataset_batches WHERE batch_id = ?", (batch_id,)
        )
        self._conn.commit()

    # ------------------------------------------------------------------
    # Schema snapshots
    # ------------------------------------------------------------------

    def store_snapshot(self, schema_hash: str, snapshot_json: str) -> None:
        """Store *snapshot_json* under *schema_hash*, replacing any existing row."""
        now = datetime.now(timezone.utc).isoformat()
        self._conn.execute(
            """INSERT OR REPLACE INTO schema_snapshots
               (schema_hash, snapshot_json, created_at) VALUES (?, ?, ?)""",
            (schema_hash, snapshot_json, now),
        )
        self._conn.commit()

    def get_snapshot(self, schema_hash: str) -> str | None:
        """Return the snapshot JSON stored under *schema_hash*, or None."""
        cur = self._conn.execute(
            "SELECT snapshot_json FROM schema_snapshots WHERE schema_hash = ?",
            (schema_hash,),
        )
        row = cur.fetchone()
        return row[0] if row else None

    # ------------------------------------------------------------------
    # Integrity
    # ------------------------------------------------------------------

    def verify_chain(self) -> list[str]:
        """Verify the parent-chain integrity. Returns a list of issues.

        An issue is reported for every recorded migration whose
        ``parent_id`` is set but does not match any recorded revision.
        An empty list means no such dangling parent was found.
        """
        migrations = self.get_applied_migrations()
        known_ids = {m.revision_id for m in migrations}
        return [
            f"Migration '{m.revision_id}' references unknown parent "
            f"'{m.parent_id}'."
            for m in migrations
            if m.parent_id is not None and m.parent_id not in known_ids
        ]
@@ -0,0 +1,6 @@
1
+ """Migration engine and operations."""
2
+
3
+ from crochet.migrations.engine import MigrationEngine
4
+ from crochet.migrations.operations import MigrationContext
5
+
6
+ __all__ = ["MigrationEngine", "MigrationContext"]
@@ -0,0 +1,279 @@
1
+ """Migration execution engine — ordering, upgrade, downgrade."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import importlib.util
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ from crochet.config import CrochetConfig
11
+ from crochet.errors import (
12
+ MigrationChainError,
13
+ MigrationError,
14
+ RollbackUnsafeError,
15
+ )
16
+ from crochet.ir.diff import SchemaDiff, diff_snapshots
17
+ from crochet.ir.hash import hash_snapshot
18
+ from crochet.ir.schema import SchemaSnapshot
19
+ from crochet.ledger.sqlite import Ledger
20
+ from crochet.migrations.operations import MigrationContext
21
+ from crochet.migrations.template import (
22
+ generate_revision_id,
23
+ render_migration,
24
+ write_migration_file,
25
+ )
26
+
27
+
28
class MigrationFile:
    """A loaded migration module together with the file it came from.

    ``revision_id``, ``parent_id`` and ``schema_hash`` are copied from
    module-level attributes of the same name and must exist on the module;
    ``rollback_safe`` is optional and defaults to True.
    """

    # Declarations only (no class-level values); set per instance below.
    revision_id: str
    parent_id: str | None
    schema_hash: str
    rollback_safe: bool

    def __init__(self, module: Any, path: Path) -> None:
        self.module = module
        self.path = path
        # Mandatory metadata, read in a fixed order: a module lacking one of
        # these raises AttributeError here.
        for required in ("revision_id", "parent_id", "schema_hash"):
            setattr(self, required, getattr(module, required))
        self.rollback_safe = getattr(module, "rollback_safe", True)

    def upgrade(self, ctx: MigrationContext) -> None:
        """Run the module's ``upgrade`` hook with *ctx*."""
        run_upgrade = self.module.upgrade
        run_upgrade(ctx)

    def downgrade(self, ctx: MigrationContext) -> None:
        """Run the module's ``downgrade`` hook with *ctx*."""
        run_downgrade = self.module.downgrade
        run_downgrade(ctx)
44
+
45
+
46
class MigrationEngine:
    """Orchestrates creation, ordering, and execution of migrations.

    Migration files are discovered in ``config.migrations_dir``; which of
    them have been applied is tracked in the supplied ledger.
    """

    def __init__(self, config: CrochetConfig, ledger: Ledger) -> None:
        self._config = config
        self._ledger = ledger

    # ------------------------------------------------------------------
    # Discovery
    # ------------------------------------------------------------------

    def discover_migrations(self) -> list[MigrationFile]:
        """Load all migration files from the migrations directory, ordered.

        ``*.py`` files whose name starts with an underscore, files that
        cannot be turned into a module spec, and modules without a
        ``revision_id`` attribute are skipped.  Returns an empty list when
        the directory does not exist.
        """
        migrations_dir = self._config.migrations_dir
        if not migrations_dir.exists():
            return []

        files: list[MigrationFile] = []
        for py_file in sorted(migrations_dir.glob("*.py")):
            if py_file.name.startswith("_"):
                continue
            module = self._load_migration_module(py_file)
            if module is None or not hasattr(module, "revision_id"):
                continue
            files.append(MigrationFile(module, py_file))

        return self._sort_by_chain(files)

    def _load_migration_module(self, path: Path) -> Any:
        """Import the file at *path* as ``crochet._migrations.<stem>``.

        Returns the executed module, or None when no import spec/loader
        can be built for the file.  Exceptions raised while executing the
        module propagate to the caller.
        """
        mod_name = f"crochet._migrations.{path.stem}"
        spec = importlib.util.spec_from_file_location(mod_name, path)
        if spec is None or spec.loader is None:
            return None
        module = importlib.util.module_from_spec(spec)
        sys.modules[mod_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module registered.
            sys.modules.pop(mod_name, None)
            raise
        return module

    def _sort_by_chain(self, migrations: list[MigrationFile]) -> list[MigrationFile]:
        """Sort migrations by their parent chain (topological order).

        Every migration is placed after its parent whenever the parent is
        among *migrations*.  Migrations that are unrelated to each other
        keep their relative input order.  If no migration has
        ``parent_id`` None, the list is returned sorted by revision id.
        """
        roots = [m for m in migrations if m.parent_id is None]
        if not roots and migrations:
            # No root to anchor a chain: fall back to sorting by revision id.
            return sorted(migrations, key=lambda m: m.revision_id)

        by_id = {m.revision_id: m for m in migrations}
        ordered: list[MigrationFile] = []
        # Tracked by object identity so that files which (erroneously)
        # share a revision_id are all kept in the result.
        placed: set[int] = set()

        # Walking must start from *every* migration, not only the roots: a
        # root has no parent to follow, so a walk that begins at the roots
        # never reaches any child.  The ancestry is followed iteratively,
        # so a long chain cannot hit the recursion limit, and the `placed`
        # check also stops the walk on a parent cycle.
        for start in migrations:
            lineage: list[MigrationFile] = []
            node: MigrationFile | None = start
            while node is not None and id(node) not in placed:
                placed.add(id(node))
                lineage.append(node)
                node = by_id.get(node.parent_id) if node.parent_id else None
            # lineage runs child -> ancestor; emit ancestors first.
            ordered.extend(reversed(lineage))

        return ordered

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------

    def pending_migrations(self) -> list[MigrationFile]:
        """Return migrations that have not yet been applied."""
        return [
            m
            for m in self.discover_migrations()
            if not self._ledger.is_applied(m.revision_id)
        ]

    def applied_migrations(self) -> list[MigrationFile]:
        """Return migrations that have been applied (in order)."""
        return [
            m
            for m in self.discover_migrations()
            if self._ledger.is_applied(m.revision_id)
        ]

    # ------------------------------------------------------------------
    # Create
    # ------------------------------------------------------------------

    def create_migration(
        self,
        description: str,
        current_snapshot: SchemaSnapshot | None = None,
        rollback_safe: bool = True,
    ) -> Path:
        """Scaffold a new migration file.

        The new revision's parent is the last discovered migration, if
        any, and its sequence number is the number of discovered
        migrations plus one.

        If *current_snapshot* is provided, it is hashed and stored in the
        ledger, and a diff against the previous snapshot is computed and
        passed to the template as ``diff_summary``.

        Returns:
            The path returned by ``write_migration_file``.
        """
        all_migrations = self.discover_migrations()
        seq = len(all_migrations) + 1
        previous = all_migrations[-1] if all_migrations else None
        parent_id: str | None = previous.revision_id if previous is not None else None

        revision_id = generate_revision_id(seq, description)

        # Compute schema hash and diff
        schema_hash = ""
        diff_summary = ""
        if current_snapshot is not None:
            current_snapshot = hash_snapshot(current_snapshot)
            schema_hash = current_snapshot.schema_hash

            # Store the snapshot
            self._ledger.store_snapshot(schema_hash, current_snapshot.to_json())

            # Try to diff against the previous snapshot; skipped when the
            # ledger holds no snapshot for the parent's schema hash.
            if parent_id and previous is not None:
                prev_json = self._ledger.get_snapshot(previous.schema_hash)
                if prev_json:
                    prev_snapshot = SchemaSnapshot.from_json(prev_json)
                    diff = diff_snapshots(prev_snapshot, current_snapshot)
                    if diff.has_changes:
                        diff_summary = diff.summary()

        content = render_migration(
            revision_id=revision_id,
            parent_id=parent_id,
            description=description,
            schema_hash=schema_hash,
            rollback_safe=rollback_safe,
            diff_summary=diff_summary,
        )

        return write_migration_file(
            self._config.migrations_dir, revision_id, content
        )

    # ------------------------------------------------------------------
    # Upgrade
    # ------------------------------------------------------------------

    def upgrade(
        self,
        target: str | None = None,
        driver: Any | None = None,
        dry_run: bool = False,
    ) -> list[str]:
        """Apply pending migrations up to *target* (or all).

        Migrations are applied in discovery order, stopping after the one
        whose revision id equals *target*.

        Returns the list of applied revision IDs.

        Raises:
            MigrationError: if a migration's ``upgrade`` hook raises.
        """
        # NOTE(review): when *target* does not match any pending migration
        # (a typo, or a revision that is already applied) every pending
        # migration is applied.  Kept as-is because callers may rely on it;
        # consider validating the target.
        applied_ids: list[str] = []
        for mf in self.pending_migrations():
            self._apply_one(mf, driver, dry_run)
            applied_ids.append(mf.revision_id)
            if target and mf.revision_id == target:
                break
        return applied_ids

    def _apply_one(
        self, mf: MigrationFile, driver: Any | None, dry_run: bool
    ) -> None:
        """Run one migration's upgrade hook and record it in the ledger.

        Nothing is recorded when *dry_run* is true.

        Raises:
            MigrationError: if the upgrade hook raises.
        """
        ctx = MigrationContext(driver=driver, dry_run=dry_run)
        try:
            mf.upgrade(ctx)
        except Exception as exc:
            raise MigrationError(
                f"Migration '{mf.revision_id}' failed during upgrade: {exc}"
            ) from exc

        if not dry_run:
            # NOTE(review): the ledger description is always recorded as an
            # empty string; MigrationFile does not expose a description.
            self._ledger.record_migration(
                revision_id=mf.revision_id,
                parent_id=mf.parent_id,
                description="",
                schema_hash=mf.schema_hash,
                rollback_safe=mf.rollback_safe,
            )

    # ------------------------------------------------------------------
    # Downgrade
    # ------------------------------------------------------------------

    def downgrade(
        self,
        target: str | None = None,
        driver: Any | None = None,
        dry_run: bool = False,
    ) -> list[str]:
        """Revert applied migrations back to *target* (or one step).

        Applied migrations are walked newest first.  With a *target*, the
        walk stops when it reaches that revision, which itself stays
        applied; without one, exactly one migration is reverted.  Ledger
        rows are only removed when *dry_run* is false.

        Returns the list of reverted revision IDs.

        Raises:
            RollbackUnsafeError: when the next migration to revert is
                marked ``rollback_safe = False``.  Migrations reverted
                earlier in the same call stay reverted.
            MigrationError: if a migration's ``downgrade`` hook raises.
        """
        # NOTE(review): when *target* does not match any applied migration,
        # every applied migration is reverted.  Kept as-is because callers
        # may rely on it; consider validating the target.
        reverted_ids: list[str] = []
        for mf in reversed(self.applied_migrations()):
            if target and mf.revision_id == target:
                break
            if not mf.rollback_safe:
                raise RollbackUnsafeError(mf.revision_id)

            ctx = MigrationContext(driver=driver, dry_run=dry_run)
            try:
                mf.downgrade(ctx)
            except Exception as exc:
                raise MigrationError(
                    f"Migration '{mf.revision_id}' failed during downgrade: {exc}"
                ) from exc

            if not dry_run:
                self._ledger.remove_migration(mf.revision_id)

            reverted_ids.append(mf.revision_id)

            # If no target, only revert one step
            if target is None:
                break

        return reverted_ids