pyworkflow-engine 0.1.21__py3-none-any.whl → 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,299 @@
+"""
+Base classes for the database migration framework.
+
+Provides Migration dataclass, MigrationRegistry for tracking migrations,
+and MigrationRunner abstract base class for backend-specific implementations.
+"""
+
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from dataclasses import dataclass
+from datetime import UTC, datetime
+from typing import Any
+
+
+@dataclass
+class Migration:
+    """
+    Represents a database schema migration.
+
+    Attributes:
+        version: Integer version number (must be unique and sequential)
+        description: Human-readable description of what the migration does
+        up_sql: SQL to apply the migration (can be None for Python-based migrations)
+        down_sql: SQL to rollback the migration (optional, for future use)
+        up_func: Optional Python function for complex migrations (receives connection)
+    """
+
+    version: int
+    description: str
+    up_sql: str | None = None
+    down_sql: str | None = None
+    up_func: Callable[[Any], Any] | None = None
+
+    def __post_init__(self) -> None:
+        if self.version < 1:
+            raise ValueError("Migration version must be >= 1")
+        if not self.up_sql and not self.up_func:
+            raise ValueError("Migration must have either up_sql or up_func")
+
+
+@dataclass
+class AppliedMigration:
+    """
+    Record of an applied migration.
+
+    Attributes:
+        version: Migration version number
+        applied_at: When the migration was applied
+        description: Description of the migration
+    """
+
+    version: int
+    applied_at: datetime
+    description: str
+
+
+class MigrationRegistry:
+    """
+    Registry for managing and tracking migrations.
+
+    Maintains an ordered list of migrations and provides methods to
+    get pending migrations based on the current database version.
+    """
+
+    def __init__(self) -> None:
+        self._migrations: dict[int, Migration] = {}
+
+    def register(self, migration: Migration) -> None:
+        """
+        Register a migration.
+
+        Args:
+            migration: Migration to register
+
+        Raises:
+            ValueError: If a migration with the same version already exists
+        """
+        if migration.version in self._migrations:
+            raise ValueError(f"Migration version {migration.version} already registered")
+        self._migrations[migration.version] = migration
+
+    def get_all(self) -> list[Migration]:
+        """
+        Get all registered migrations, ordered by version.
+
+        Returns:
+            List of migrations sorted by version ascending
+        """
+        return [self._migrations[v] for v in sorted(self._migrations.keys())]
+
+    def get_pending(self, current_version: int) -> list[Migration]:
+        """
+        Get migrations that need to be applied.
+
+        Args:
+            current_version: Current schema version (0 if fresh database)
+
+        Returns:
+            List of migrations with version > current_version, sorted ascending
+        """
+        return [self._migrations[v] for v in sorted(self._migrations.keys()) if v > current_version]
+
+    def get_latest_version(self) -> int:
+        """
+        Get the latest migration version.
+
+        Returns:
+            Highest registered version, or 0 if no migrations
+        """
+        return max(self._migrations.keys()) if self._migrations else 0
+
+    def get(self, version: int) -> Migration | None:
+        """
+        Get a specific migration by version.
+
+        Args:
+            version: Migration version number
+
+        Returns:
+            Migration if found, None otherwise
+        """
+        return self._migrations.get(version)
+
+
+# Global migration registry for SQL backends
+_global_registry = MigrationRegistry()
+
+
+def get_global_registry() -> MigrationRegistry:
+    """Get the global migration registry."""
+    return _global_registry
+
+
+def register_migration(migration: Migration) -> None:
+    """Register a migration in the global registry."""
+    _global_registry.register(migration)
+
+
+class MigrationRunner(ABC):
+    """
+    Abstract base class for running migrations on a storage backend.
+
+    Subclasses must implement the backend-specific methods for:
+    - Ensuring the schema_versions table exists
+    - Getting the current schema version
+    - Applying individual migrations
+    - Detecting existing schemas (for backward compatibility)
+    """
+
+    def __init__(self, registry: MigrationRegistry | None = None) -> None:
+        """
+        Initialize the migration runner.
+
+        Args:
+            registry: Migration registry to use (defaults to global registry)
+        """
+        self.registry = registry or get_global_registry()
+
+    @abstractmethod
+    async def ensure_schema_versions_table(self) -> None:
+        """
+        Create the schema_versions table if it doesn't exist.
+
+        The table should have:
+        - version: INTEGER PRIMARY KEY
+        - applied_at: TIMESTAMP NOT NULL
+        - description: TEXT
+        """
+        pass
+
+    @abstractmethod
+    async def get_current_version(self) -> int:
+        """
+        Get the current schema version from the database.
+
+        Returns:
+            Current version (highest applied), or 0 if no migrations applied
+        """
+        pass
+
+    @abstractmethod
+    async def apply_migration(self, migration: Migration) -> None:
+        """
+        Apply a single migration.
+
+        This should:
+        1. Execute the migration SQL/function in a transaction
+        2. Record the migration in schema_versions
+        3. Rollback on failure
+
+        Args:
+            migration: Migration to apply
+
+        Raises:
+            Exception: If migration fails
+        """
+        pass
+
+    @abstractmethod
+    async def detect_existing_schema(self) -> bool:
+        """
+        Detect if the database has an existing schema (pre-versioning).
+
+        This is used for backward compatibility with databases created
+        before the migration framework was added. If tables exist but
+        no schema_versions table, we assume it's a V1 schema.
+
+        Returns:
+            True if existing schema detected, False if fresh database
+        """
+        pass
+
+    @abstractmethod
+    async def record_baseline_version(self, version: int, description: str) -> None:
+        """
+        Record a baseline version without running migrations.
+
+        Used when detecting an existing schema to mark it as a known version.
+
+        Args:
+            version: Version number to record
+            description: Description of the baseline
+        """
+        pass
+
+    async def run_migrations(self) -> list[AppliedMigration]:
+        """
+        Run all pending migrations.
+
+        This is the main entry point for the migration runner. It:
+        1. Ensures the schema_versions table exists
+        2. Detects existing schemas and records baseline if needed
+        3. Applies all pending migrations in order
+
+        Returns:
+            List of applied migrations
+
+        Raises:
+            Exception: If any migration fails (partial migrations are rolled back)
+        """
+        # Ensure we have a schema_versions table
+        await self.ensure_schema_versions_table()
+
+        # Get current version
+        current_version = await self.get_current_version()
+
+        # If no migrations recorded, check for existing schema
+        if current_version == 0:
+            has_existing_schema = await self.detect_existing_schema()
+            if has_existing_schema:
+                # Database has tables but no version tracking
+                # Assume it's at V1 (original schema)
+                await self.record_baseline_version(1, "Baseline: original schema")
+                current_version = 1
+
+        # Get and apply pending migrations
+        pending = self.registry.get_pending(current_version)
+        applied: list[AppliedMigration] = []
+
+        for migration in pending:
+            await self.apply_migration(migration)
+            applied.append(
+                AppliedMigration(
+                    version=migration.version,
+                    applied_at=datetime.now(UTC),
+                    description=migration.description,
+                )
+            )
+
+        return applied
+
+
+# =============================================================================
+# Migration Definitions
+# =============================================================================
+
+# Version 1: Baseline schema (represents the original schema before versioning)
+# This is auto-detected for existing databases
+
+_v1_migration = Migration(
+    version=1,
+    description="Baseline: original schema",
+    up_sql="SELECT 1",  # No-op, baseline is detected not applied
+)
+register_migration(_v1_migration)
+
+
+# Version 2: Add step_id column to events table for optimized queries
+# The actual SQL and backfill are backend-specific, so each backend's
+# migration runner implements them in its apply_migration()
+
+_v2_migration = Migration(
+    version=2,
+    description="Add step_id column to events table for optimized has_event() queries",
+    # No portable SQL exists for this change; backend runners skip this
+    # placeholder and apply their own DDL and backfill for version 2
+    up_sql="SELECT 1",
+)
+register_migration(_v2_migration)
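The module above is consumed automatically by the storage backends (the MySQL changes follow below), but the registration API can also be exercised directly. A minimal sketch, assuming the import path pyworkflow.storage.migrations shown in this diff; the version number, description, and SQL are illustrative placeholders, not part of the package:

from pyworkflow.storage.migrations import (
    Migration,
    get_global_registry,
    register_migration,
)

# Hypothetical V3 migration: version, description, and SQL are examples only.
_v3_migration = Migration(
    version=3,
    description="Add an index on workflow_runs.status (illustrative)",
    up_sql="CREATE INDEX idx_workflow_runs_status ON workflow_runs(status)",
)
register_migration(_v3_migration)

# Relative to a database already at version 2, only the new migration is pending.
pending = get_global_registry().get_pending(current_version=2)
assert [m.version for m in pending] == [3]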
@@ -17,6 +17,7 @@ import aiomysql
 
 from pyworkflow.engine.events import Event, EventType
 from pyworkflow.storage.base import StorageBackend
+from pyworkflow.storage.migrations import Migration, MigrationRegistry, MigrationRunner
 from pyworkflow.storage.schemas import (
     Hook,
     HookStatus,
@@ -31,6 +32,124 @@ from pyworkflow.storage.schemas import (
 )
 
 
+class MySQLMigrationRunner(MigrationRunner):
+    """MySQL-specific migration runner."""
+
+    def __init__(self, pool: aiomysql.Pool, registry: MigrationRegistry | None = None) -> None:
+        super().__init__(registry)
+        self._pool = pool
+
+    async def ensure_schema_versions_table(self) -> None:
+        """Create schema_versions table if it doesn't exist."""
+        async with self._pool.acquire() as conn, conn.cursor() as cur:
+            await cur.execute("""
+                CREATE TABLE IF NOT EXISTS schema_versions (
+                    version INT PRIMARY KEY,
+                    applied_at DATETIME(6) NOT NULL,
+                    description TEXT
+                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
+            """)
+
+    async def get_current_version(self) -> int:
+        """Get the highest applied migration version."""
+        async with self._pool.acquire() as conn, conn.cursor() as cur:
+            await cur.execute("SELECT COALESCE(MAX(version), 0) as version FROM schema_versions")
+            row = await cur.fetchone()
+            return row[0] if row else 0
+
+    async def detect_existing_schema(self) -> bool:
+        """Check if the events table exists (pre-versioning database)."""
+        async with self._pool.acquire() as conn, conn.cursor() as cur:
+            await cur.execute("""
+                SELECT COUNT(*) FROM information_schema.tables
+                WHERE table_schema = DATABASE() AND table_name = 'events'
+            """)
+            row = await cur.fetchone()
+            return row[0] > 0 if row else False
+
+    async def record_baseline_version(self, version: int, description: str) -> None:
+        """Record a baseline version without running migrations."""
+        async with self._pool.acquire() as conn, conn.cursor() as cur:
+            await cur.execute(
+                """
+                INSERT IGNORE INTO schema_versions (version, applied_at, description)
+                VALUES (%s, %s, %s)
+                """,
+                (version, datetime.now(UTC), description),
+            )
+
+    async def apply_migration(self, migration: Migration) -> None:
+        """Apply a migration with MySQL-specific handling."""
+        async with self._pool.acquire() as conn:
+            await conn.begin()
+            try:
+                async with conn.cursor() as cur:
+                    if migration.version == 2:
+                        # V2: Add step_id column to events table
+                        # First check if events table exists (fresh databases won't have it yet)
+                        await cur.execute("""
+                            SELECT COUNT(*) FROM information_schema.tables
+                            WHERE table_schema = DATABASE() AND table_name = 'events'
+                        """)
+                        row = await cur.fetchone()
+                        table_exists = row[0] > 0 if row else False
+
+                        if table_exists:
+                            # Check if column exists first
+                            await cur.execute("""
+                                SELECT COUNT(*) FROM information_schema.columns
+                                WHERE table_schema = DATABASE()
+                                AND table_name = 'events'
+                                AND column_name = 'step_id'
+                            """)
+                            row = await cur.fetchone()
+                            if row[0] == 0:
+                                await cur.execute(
+                                    "ALTER TABLE events ADD COLUMN step_id VARCHAR(255)"
+                                )
+
+                            # Create index for optimized has_event() queries
+                            # Check if index exists first
+                            await cur.execute("""
+                                SELECT COUNT(*) FROM information_schema.statistics
+                                WHERE table_schema = DATABASE()
+                                AND table_name = 'events'
+                                AND index_name = 'idx_events_run_id_step_id_type'
+                            """)
+                            row = await cur.fetchone()
+                            if row[0] == 0:
+                                await cur.execute("""
+                                    CREATE INDEX idx_events_run_id_step_id_type
+                                    ON events(run_id, step_id, type)
+                                """)
+
+                            # Backfill step_id from JSON data
+                            await cur.execute("""
+                                UPDATE events
+                                SET step_id = JSON_UNQUOTE(JSON_EXTRACT(data, '$.step_id'))
+                                WHERE step_id IS NULL
+                                AND JSON_EXTRACT(data, '$.step_id') IS NOT NULL
+                            """)
+                        # If table doesn't exist, schema will be created with step_id column
+                    elif migration.up_func:
+                        await migration.up_func(conn)
+                    elif migration.up_sql and migration.up_sql != "SELECT 1":
+                        await cur.execute(migration.up_sql)
+
+                    # Record the migration
+                    await cur.execute(
+                        """
+                        INSERT INTO schema_versions (version, applied_at, description)
+                        VALUES (%s, %s, %s)
+                        """,
+                        (migration.version, datetime.now(UTC), migration.description),
+                    )
+                await conn.commit()
+            except Exception:
+                await conn.rollback()
+                raise
+
+
 class MySQLStorageBackend(StorageBackend):
     """
     MySQL storage backend using aiomysql for async operations.
@@ -101,11 +220,16 @@ class MySQLStorageBackend(StorageBackend):
         self._initialized = False
 
     async def _initialize_schema(self) -> None:
-        """Create database tables if they don't exist."""
+        """Create database tables if they don't exist and run migrations."""
        if not self._pool:
             await self.connect()
 
         pool = self._ensure_connected()
+
+        # Run migrations first (handles schema versioning)
+        runner = MySQLMigrationRunner(pool)
+        await runner.run_migrations()
+
         async with pool.acquire() as conn, conn.cursor() as cur:
             # Workflow runs table
             await cur.execute("""
@@ -140,7 +264,7 @@ class MySQLStorageBackend(StorageBackend):
                 ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
             """)
 
-            # Events table
+            # Events table (includes step_id column added in V2 migration)
             await cur.execute("""
                 CREATE TABLE IF NOT EXISTS events (
                     event_id VARCHAR(255) PRIMARY KEY,
@@ -149,8 +273,10 @@ class MySQLStorageBackend(StorageBackend):
                     type VARCHAR(100) NOT NULL,
                     timestamp DATETIME(6) NOT NULL,
                     data LONGTEXT NOT NULL DEFAULT '{}',
+                    step_id VARCHAR(255),
                     INDEX idx_events_run_id_sequence (run_id, sequence),
-                    INDEX idx_events_type (type),
+                    INDEX idx_events_run_id_type (run_id, type),
+                    INDEX idx_events_run_id_step_id_type (run_id, step_id, type),
                     FOREIGN KEY (run_id) REFERENCES workflow_runs(run_id) ON DELETE CASCADE
                 ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
             """)
@@ -449,6 +575,9 @@ class MySQLStorageBackend(StorageBackend):
         """Record an event to the append-only event log."""
         pool = self._ensure_connected()
 
+        # Extract step_id from event data for indexed column
+        step_id = event.data.get("step_id") if event.data else None
+
         async with pool.acquire() as conn:
             # Use transaction for atomic sequence assignment
             await conn.begin()
@@ -464,8 +593,8 @@ class MySQLStorageBackend(StorageBackend):
 
                     await cur.execute(
                         """
-                        INSERT INTO events (event_id, run_id, sequence, type, timestamp, data)
-                        VALUES (%s, %s, %s, %s, %s, %s)
+                        INSERT INTO events (event_id, run_id, sequence, type, timestamp, data, step_id)
+                        VALUES (%s, %s, %s, %s, %s, %s, %s)
                         """,
                         (
                             event.event_id,
@@ -474,6 +603,7 @@ class MySQLStorageBackend(StorageBackend):
                             event.type.value,
                             event.timestamp,
                             json.dumps(event.data),
+                            step_id,
                         ),
                     )
                 await conn.commit()
@@ -545,6 +675,57 @@ class MySQLStorageBackend(StorageBackend):
 
         return self._row_to_event(row)
 
+    async def has_event(
+        self,
+        run_id: str,
+        event_type: str,
+        **filters: str,
+    ) -> bool:
+        """
+        Check if an event exists using optimized indexed queries.
+
+        When step_id is the only filter, uses a single indexed lookup. For other
+        filters, falls back to loading events of the type and filtering in Python.
+
+        Args:
+            run_id: Workflow run identifier
+            event_type: Event type to check for
+            **filters: Additional filters for event data fields
+
+        Returns:
+            True if a matching event exists, False otherwise
+        """
+        pool = self._ensure_connected()
+
+        # Optimized path: if only filtering by step_id, use indexed column directly
+        if filters.keys() == {"step_id"}:
+            step_id = str(filters["step_id"])
+            async with pool.acquire() as conn, conn.cursor() as cur:
+                await cur.execute(
+                    """
+                    SELECT 1 FROM events
+                    WHERE run_id = %s AND type = %s AND step_id = %s
+                    LIMIT 1
+                    """,
+                    (run_id, event_type, step_id),
+                )
+                row = await cur.fetchone()
+                return row is not None
+
+        # Fallback: load events of type and filter in Python
+        events = await self.get_events(run_id, event_types=[event_type])
+
+        for event in events:
+            match = True
+            for key, value in filters.items():
+                if str(event.data.get(key)) != str(value):
+                    match = False
+                    break
+            if match:
+                return True
+
+        return False
+
     # Step Operations
 
     async def create_step(self, step: StepExecution) -> None:
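The diff wires MySQLMigrationRunner into _initialize_schema, so migrations run whenever the backend initializes its schema, but the runner can also be driven directly against an existing aiomysql pool. A minimal sketch, assuming the module path pyworkflow.storage.mysql for the backend file in this diff and placeholder connection settings:

import asyncio

import aiomysql

from pyworkflow.storage.mysql import MySQLMigrationRunner  # module path assumed from this diff


async def migrate() -> None:
    # Placeholder credentials; substitute real connection settings.
    pool = await aiomysql.create_pool(
        host="127.0.0.1", port=3306, user="pyworkflow", password="secret", db="pyworkflow"
    )
    try:
        runner = MySQLMigrationRunner(pool)
        applied = await runner.run_migrations()
        for migration in applied:
            print(f"applied v{migration.version}: {migration.description}")
    finally:
        pool.close()
        await pool.wait_closed()


asyncio.run(migrate())

Because run_migrations records each applied version in schema_versions, rerunning this is safe: a second invocation finds no pending migrations and returns an empty list.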