grai-build 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,346 @@
1
+ """
2
+ Migration executor for applying migrations to Neo4j.
3
+
4
+ This module handles executing migration Cypher scripts against Neo4j
5
+ and tracking migration state.
6
+ """
7
+
8
+ import time
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+ from typing import List, Optional
12
+
13
+ import yaml
14
+ from neo4j import Driver
15
+
16
+ from grai.core.migrations.models import Migration, MigrationHistory, MigrationStatus
17
+
18
+
19
class MigrationExecutor:
    """
    Executes migrations against Neo4j and tracks state.

    This class applies migration Cypher scripts to the database and
    maintains migration history using __GraiMigration nodes.
    """

    def __init__(self, driver: Driver, project_root: Path):
        """
        Initialize the migration executor.

        Args:
            driver: Neo4j driver instance.
            project_root: Path to the project root directory. Migration files
                are read from its ``migrations`` sub-directory.
        """
        self.driver = driver
        self.project_root = project_root
        self.migrations_dir = project_root / "migrations"

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _logger():
        """Return this module's logger (imported lazily to keep the block self-contained)."""
        import logging

        return logging.getLogger(__name__)

    @staticmethod
    def _elapsed_ms(start_time: float) -> int:
        """Return whole milliseconds elapsed since ``start_time`` (a ``time.time()`` value)."""
        return int((time.time() - start_time) * 1000)

    @staticmethod
    def _to_native_datetime(value):
        """
        Convert a driver temporal value to a standard ``datetime`` when possible.

        Values exposing a callable ``to_native`` (as the Neo4j driver's
        temporal types do) are converted; anything else, including plain
        ``datetime`` objects and ``None``, is returned unchanged.
        """
        to_native = getattr(value, "to_native", None)
        return to_native() if callable(to_native) else value

    def _record_migration(
        self,
        session,
        migration: Migration,
        status: MigrationStatus,
        execution_time_ms: int,
        error_message: Optional[str] = None,
    ) -> None:
        """
        Create a __GraiMigration history node for ``migration``.

        The ``error_message`` property is only written when a message is given,
        so successful records carry no such property.
        """
        properties = [
            "version: $version",
            "description: $description",
            "applied_at: datetime()",
            "status: $status",
            "checksum: $checksum",
            "execution_time_ms: $execution_time_ms",
        ]
        params = {
            "version": migration.version,
            "description": migration.description,
            "status": status.value,
            "checksum": migration.checksum,
            "execution_time_ms": execution_time_ms,
        }
        if error_message is not None:
            properties.append("error_message: $error_message")
            params["error_message"] = error_message

        query = "CREATE (m:__GraiMigration {" + ", ".join(properties) + "})"
        # consume() forces the statement to complete so errors surface here.
        session.run(query, **params).consume()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_pending_migrations(self) -> List[Migration]:
        """
        Get list of pending migrations that haven't been applied.

        A migration is pending when its version is not returned by
        ``_get_applied_versions`` (which covers both applied and failed
        records).

        Returns:
            List of Migration objects for pending migrations, in file-name order.
        """
        all_migrations = self._load_all_migrations()
        applied_versions = self._get_applied_versions()
        return [m for m in all_migrations if m.version not in applied_versions]

    def get_migration_history(self) -> List[MigrationHistory]:
        """
        Get the full migration history from Neo4j.

        Returns:
            List of MigrationHistory objects, sorted by application time
            (oldest first).
        """
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (m:__GraiMigration)
                RETURN m.version as version,
                       m.description as description,
                       m.applied_at as applied_at,
                       m.status as status,
                       m.checksum as checksum,
                       m.execution_time_ms as execution_time_ms,
                       m.error_message as error_message
                ORDER BY m.applied_at ASC
                """
            )

            # Records must be read while the session is still open.
            return [
                MigrationHistory(
                    version=record["version"],
                    description=record["description"],
                    # NOTE(review): assumes MigrationHistory expects a standard
                    # datetime rather than the driver's temporal type - confirm.
                    applied_at=self._to_native_datetime(record["applied_at"]),
                    status=MigrationStatus(record["status"]),
                    checksum=record["checksum"],
                    execution_time_ms=record.get("execution_time_ms"),
                    error_message=record.get("error_message"),
                )
                for record in result
            ]

    def apply_migration(self, migration: Migration, dry_run: bool = False) -> MigrationHistory:
        """
        Apply a single migration to Neo4j.

        Each statement in ``migration.up_cypher`` is run in order, then a
        history node with status APPLIED is created. If anything raises, a
        FAILED history node is written (best effort) and the original
        exception is re-raised.

        Args:
            migration: Migration to apply.
            dry_run: If True, nothing is sent to the database; a history
                record with ``execution_time_ms=0`` is returned.

        Returns:
            MigrationHistory record of the execution.

        Raises:
            Exception: If migration fails.
        """
        start_time = time.time()
        execution_time_ms = 0

        if not dry_run:
            try:
                with self.driver.session() as session:
                    for cypher in migration.up_cypher:
                        session.run(cypher).consume()

                    execution_time_ms = self._elapsed_ms(start_time)
                    self._record_migration(
                        session, migration, MigrationStatus.APPLIED, execution_time_ms
                    )
            except Exception as e:
                failed_after_ms = self._elapsed_ms(start_time)
                try:
                    with self.driver.session() as session:
                        self._record_migration(
                            session,
                            migration,
                            MigrationStatus.FAILED,
                            failed_after_ms,
                            error_message=str(e),
                        )
                except Exception:
                    # A bookkeeping failure must not hide the migration error
                    # that the caller actually needs to see.
                    self._logger().exception(
                        "Could not record failure of migration %s", migration.version
                    )
                raise

        return MigrationHistory(
            version=migration.version,
            description=migration.description,
            applied_at=datetime.now(),
            status=MigrationStatus.APPLIED,
            checksum=migration.checksum or "",
            execution_time_ms=execution_time_ms,
            error_message=None,
        )

    def apply_all_pending(self, dry_run: bool = False) -> List[MigrationHistory]:
        """
        Apply all pending migrations in order.

        Args:
            dry_run: If True, don't actually execute, just validate.

        Returns:
            List of MigrationHistory records for applied migrations.

        Raises:
            Exception: If any migration fails (stops execution).
        """
        return [
            self.apply_migration(migration, dry_run=dry_run)
            for migration in self.get_pending_migrations()
        ]

    def rollback_migration(self, version: Optional[str] = None) -> MigrationHistory:
        """
        Rollback a migration using its down script.

        Args:
            version: Specific version to rollback. If None, the most recently
                applied migration that is still in the APPLIED state is
                rolled back.

        Returns:
            MigrationHistory record of the rollback.

        Raises:
            ValueError: If there is nothing to roll back or the migration
                file for ``version`` cannot be found/loaded.
            Exception: If executing the down script fails.
        """
        if version is None:
            # Skip entries that already failed or were rolled back; re-running
            # their down scripts would be wrong.
            still_applied = [
                entry
                for entry in self.get_migration_history()
                if entry.status == MigrationStatus.APPLIED
            ]
            if not still_applied:
                raise ValueError("No migrations to rollback")
            version = still_applied[-1].version

        migration = self._load_migration(version)
        if not migration:
            raise ValueError(f"Migration {version} not found")

        start_time = time.time()
        status = MigrationStatus.ROLLED_BACK

        with self.driver.session() as session:
            for cypher in migration.down_cypher:
                session.run(cypher).consume()

            execution_time_ms = self._elapsed_ms(start_time)
            session.run(
                """
                MATCH (m:__GraiMigration {version: $version})
                SET m.status = $status,
                    m.rolled_back_at = datetime(),
                    m.rollback_time_ms = $execution_time_ms
                """,
                version=version,
                status=status.value,
                execution_time_ms=execution_time_ms,
            ).consume()

        return MigrationHistory(
            version=migration.version,
            description=migration.description,
            applied_at=datetime.now(),
            status=status,
            checksum=migration.checksum or "",
            execution_time_ms=execution_time_ms,
            error_message=None,
        )

    def _load_all_migrations(self) -> List[Migration]:
        """Load all ``*.yml`` migration files, sorted by path; unreadable files are skipped."""
        loaded = (
            self._load_migration_from_file(migration_file)
            for migration_file in sorted(self.migrations_dir.glob("*.yml"))
        )
        return [migration for migration in loaded if migration]

    def _load_migration(self, version: str) -> Optional[Migration]:
        """
        Load a specific migration by version.

        Looks for files named ``<version>_*.yml``; when several match, the
        first in sorted order is used so the choice is deterministic.
        Returns None when no file matches or the file cannot be loaded.
        """
        candidates = sorted(self.migrations_dir.glob(f"{version}_*.yml"))
        if not candidates:
            return None
        return self._load_migration_from_file(candidates[0])

    def _load_migration_from_file(self, filepath: Path) -> Optional[Migration]:
        """
        Load migration from YAML file.

        Returns None (after logging a warning) when the file cannot be read,
        parsed, or turned into a Migration.
        """
        try:
            with open(filepath, encoding="utf-8") as f:
                data = yaml.safe_load(f)

            # YAML turns unquoted ISO timestamps into datetime/date objects,
            # so only parse when we did not already get a datetime.
            raw_timestamp = data["timestamp"]
            if isinstance(raw_timestamp, datetime):
                timestamp = raw_timestamp
            else:
                timestamp = datetime.fromisoformat(str(raw_timestamp))

            # For now, just load basic info - full reconstruction coming later
            return Migration(
                version=data["version"],
                description=data["description"],
                author=data.get("author", "unknown"),
                timestamp=timestamp,
                changes=data.get("changes", {"entities": [], "relations": []}),
                up_cypher=data.get("up", []),
                down_cypher=data.get("down", []),
                checksum=data.get("checksum"),
            )
        except Exception as exc:
            # Best effort by design, but make the skip visible instead of silent.
            self._logger().warning("Skipping migration file %s: %s", filepath, exc)
            return None

    def _get_applied_versions(self) -> set:
        """Get set of migration versions recorded in Neo4j with status 'applied' or 'failed'."""
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (m:__GraiMigration)
                WHERE m.status IN ['applied', 'failed']
                RETURN m.version as version
                """
            )
            return {record["version"] for record in result}

    def verify_migrations(self) -> bool:
        """
        Verify that migration files match applied migrations in database.

        Prints a warning and returns False on the first inconsistency found:
        a history entry without a migration file, or a checksum that differs
        between file and history.

        Returns:
            True if all migrations are consistent, False otherwise.
        """
        applied_history = self.get_migration_history()

        # Index files by version once; the first file for a version wins.
        files_by_version = {}
        for migration in self._load_all_migrations():
            files_by_version.setdefault(migration.version, migration)

        missing_files = {h.version for h in applied_history} - set(files_by_version)
        if missing_files:
            print(f"Warning: Applied migrations missing files: {missing_files}")
            return False

        for history_entry in applied_history:
            migration = files_by_version.get(history_entry.version)
            if migration and migration.checksum != history_entry.checksum:
                print(f"Warning: Checksum mismatch for migration {migration.version}")
                return False

        return True