fleet-python 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of fleet-python might be problematic.
- fleet/__init__.py +3 -3
- fleet/_async/client.py +40 -8
- fleet/_async/models.py +321 -0
- fleet/_async/resources/sqlite.py +1 -1
- fleet/_async/tasks.py +1 -1
- fleet/_async/verifiers/verifier.py +4 -6
- fleet/client.py +64 -542
- fleet/env/client.py +4 -4
- fleet/instance/client.py +4 -15
- fleet/models.py +235 -190
- fleet/resources/browser.py +8 -7
- fleet/resources/sqlite.py +42 -459
- fleet/tasks.py +2 -2
- fleet/verifiers/verifier.py +10 -13
- {fleet_python-0.2.21.dist-info → fleet_python-0.2.23.dist-info}/METADATA +1 -1
- {fleet_python-0.2.21.dist-info → fleet_python-0.2.23.dist-info}/RECORD +20 -19
- scripts/fix_sync_imports.py +6 -0
- {fleet_python-0.2.21.dist-info → fleet_python-0.2.23.dist-info}/WHEEL +0 -0
- {fleet_python-0.2.21.dist-info → fleet_python-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {fleet_python-0.2.21.dist-info → fleet_python-0.2.23.dist-info}/top_level.txt +0 -0
fleet/resources/sqlite.py
CHANGED
@@ -14,11 +14,11 @@ if TYPE_CHECKING:
 
 
 # Import types from verifiers module
-from
+from fleet.verifiers.db import IgnoreConfig, _get_row_identifier, _format_row_for_error, _values_equivalent
 
 
 class SyncDatabaseSnapshot:
-    """
+    """Async database snapshot that fetches data through API and stores locally for diffing."""
 
     def __init__(self, resource: "SQLiteResource", name: str | None = None):
         self.resource = resource
@@ -26,12 +26,11 @@ class SyncDatabaseSnapshot:
         self.created_at = datetime.utcnow()
         self._data: dict[str, list[dict[str, Any]]] = {}
         self._schemas: dict[str, list[str]] = {}
-        self.
-        self._fetched_tables: set[str] = set()
+        self._fetched = False
 
-    def
-        """Fetch
-        if self.
+    def _ensure_fetched(self):
+        """Fetch all data from remote database if not already fetched."""
+        if self._fetched:
            return
 
        # Get all tables
@@ -40,37 +39,34 @@ class SyncDatabaseSnapshot:
        )
 
        if not tables_response.rows:
-            self.
+            self._fetched = True
            return
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            self._data[table] = []
-
-            self._fetched_tables.add(table)
+        table_names = [row[0] for row in tables_response.rows]
+
+        # Fetch data from each table
+        for table in table_names:
+            # Get table schema
+            schema_response = self.resource.query(f"PRAGMA table_info({table})")
+            if schema_response.rows:
+                self._schemas[table] = [row[1] for row in schema_response.rows]  # Column names
+
+            # Get all data
+            data_response = self.resource.query(f"SELECT * FROM {table}")
+            if data_response.rows and data_response.columns:
+                self._data[table] = [
+                    dict(zip(data_response.columns, row))
+                    for row in data_response.rows
+                ]
+            else:
+                self._data[table] = []
+
+        self._fetched = True
 
    def tables(self) -> list[str]:
        """Get list of all tables in the snapshot."""
-        self.
-        return list(self.
+        self._ensure_fetched()
+        return list(self._data.keys())
 
    def table(self, table_name: str) -> "SyncSnapshotQueryBuilder":
        """Create a query builder for snapshot data."""
@@ -82,12 +78,13 @@ class SyncDatabaseSnapshot:
        ignore_config: IgnoreConfig | None = None,
    ) -> "SyncSnapshotDiff":
        """Compare this snapshot with another."""
-
+        self._ensure_fetched()
+        other._ensure_fetched()
        return SyncSnapshotDiff(self, other, ignore_config)
 
 
 class SyncSnapshotQueryBuilder:
-    """Query builder that works on snapshot data
+    """Query builder that works on local snapshot data."""
 
    def __init__(self, snapshot: SyncDatabaseSnapshot, table: str):
        self._snapshot = snapshot
@@ -97,63 +94,10 @@ class SyncSnapshotQueryBuilder:
        self._limit: int | None = None
        self._order_by: str | None = None
        self._order_desc: bool = False
-        self._use_targeted_query = True  # Try to use targeted queries when possible
-
-    def _can_use_targeted_query(self) -> bool:
-        """Check if we can use a targeted query instead of loading all data."""
-        # We can use targeted query if:
-        # 1. We have simple equality conditions
-        # 2. No complex operations like joins
-        # 3. The query is selective (has conditions)
-        if not self._conditions:
-            return False
-        for col, op, val in self._conditions:
-            if op not in ["=", "IS", "IS NOT"]:
-                return False
-        return True
-
-    def _execute_targeted_query(self) -> list[dict[str, Any]]:
-        """Execute a targeted query directly instead of loading all data."""
-        # Build WHERE clause
-        where_parts = []
-        for col, op, val in self._conditions:
-            if op == "=" and val is None:
-                where_parts.append(f"{col} IS NULL")
-            elif op == "IS":
-                where_parts.append(f"{col} IS NULL")
-            elif op == "IS NOT":
-                where_parts.append(f"{col} IS NOT NULL")
-            elif op == "=":
-                if isinstance(val, str):
-                    escaped_val = val.replace("'", "''")
-                    where_parts.append(f"{col} = '{escaped_val}'")
-                else:
-                    where_parts.append(f"{col} = '{val}'")
-
-        where_clause = " AND ".join(where_parts)
-
-        # Build full query
-        cols = ", ".join(self._select_cols)
-        query = f"SELECT {cols} FROM {self._table} WHERE {where_clause}"
-
-        if self._order_by:
-            query += f" ORDER BY {self._order_by}"
-        if self._limit is not None:
-            query += f" LIMIT {self._limit}"
-
-        # Execute query
-        response = self._snapshot.resource.query(query)
-        if response.rows and response.columns:
-            return [dict(zip(response.columns, row)) for row in response.rows]
-        return []
 
    def _get_data(self) -> list[dict[str, Any]]:
-        """Get table data
-
-        return self._execute_targeted_query()
-
-        # Fall back to loading all data
-        self._snapshot._ensure_table_data(self._table)
+        """Get table data from snapshot."""
+        self._snapshot._ensure_fetched()
        return self._snapshot._data.get(self._table, [])
 
    def eq(self, column: str, value: Any) -> "SyncSnapshotQueryBuilder":
@@ -177,11 +121,6 @@ class SyncSnapshotQueryBuilder:
        return rows[0] if rows else None
 
    def all(self) -> list[dict[str, Any]]:
-        # If we can use targeted query, _get_data already applies filters
-        if self._use_targeted_query and self._can_use_targeted_query():
-            return self._get_data()
-
-        # Otherwise, get all data and apply filters manually
        data = self._get_data()
 
        # Apply filters
@@ -249,7 +188,6 @@ class SyncSnapshotDiff:
        self.after = after
        self.ignore_config = ignore_config or IgnoreConfig()
        self._cached: dict[str, Any] | None = None
-        self._targeted_mode = False  # Flag to use targeted queries
 
    def _get_primary_key_columns(self, table: str) -> list[str]:
        """Get primary key columns for a table."""
@@ -290,10 +228,6 @@ class SyncSnapshotDiff:
            # Get primary key columns
            pk_columns = self._get_primary_key_columns(tbl)
 
-            # Ensure data is fetched for this table
-            self.before._ensure_table_data(tbl)
-            self.after._ensure_table_data(tbl)
-
            # Get data from both snapshots
            before_data = self.before._data.get(tbl, [])
            after_data = self.after._data.get(tbl, [])
@@ -369,348 +303,10 @@ class SyncSnapshotDiff:
        self._cached = diff
        return diff
 
-    def
-        """
-        # We can use targeted queries if all allowed changes specify table and pk
-        for change in allowed_changes:
-            if "table" not in change or "pk" not in change:
-                return False
-        return True
-
-    def _expect_only_targeted(self, allowed_changes: list[dict[str, Any]]):
-        """Optimized version that only queries specific rows mentioned in allowed_changes."""
-        import concurrent.futures
-        from threading import Lock
-
-        # Group allowed changes by table
-        changes_by_table: dict[str, list[dict[str, Any]]] = {}
-        for change in allowed_changes:
-            table = change["table"]
-            if table not in changes_by_table:
-                changes_by_table[table] = []
-            changes_by_table[table].append(change)
-
-        errors = []
-        errors_lock = Lock()
-
-        # Function to check a single row
-        def check_row(table: str, pk: Any, table_changes: list[dict[str, Any]], pk_columns: list[str]):
-            try:
-                # Build WHERE clause for this PK
-                where_sql = self._build_pk_where_clause(pk_columns, pk)
-
-                # Query before snapshot
-                before_query = f"SELECT * FROM {table} WHERE {where_sql}"
-                before_response = self.before.resource.query(before_query)
-                before_row = dict(zip(before_response.columns, before_response.rows[0])) if before_response.rows else None
-
-                # Query after snapshot
-                after_response = self.after.resource.query(before_query)
-                after_row = dict(zip(after_response.columns, after_response.rows[0])) if after_response.rows else None
-
-                # Check changes for this row
-                if before_row and after_row:
-                    # Modified row - check fields
-                    for field in set(before_row.keys()) | set(after_row.keys()):
-                        if self.ignore_config.should_ignore_field(table, field):
-                            continue
-                        before_val = before_row.get(field)
-                        after_val = after_row.get(field)
-                        if not _values_equivalent(before_val, after_val):
-                            # Check if this change is allowed
-                            if not self._is_field_change_allowed(table_changes, pk, field, after_val):
-                                error_msg = (
-                                    f"Unexpected change in table '{table}', "
-                                    f"row {pk}, field '{field}': "
-                                    f"{repr(before_val)} -> {repr(after_val)}"
-                                )
-                                with errors_lock:
-                                    errors.append(AssertionError(error_msg))
-                                return  # Stop checking this row
-                elif not before_row and after_row:
-                    # Added row
-                    if not self._is_row_change_allowed(table_changes, pk, "__added__"):
-                        error_msg = f"Unexpected row added in table '{table}': {pk}"
-                        with errors_lock:
-                            errors.append(AssertionError(error_msg))
-                elif before_row and not after_row:
-                    # Removed row
-                    if not self._is_row_change_allowed(table_changes, pk, "__removed__"):
-                        error_msg = f"Unexpected row removed from table '{table}': {pk}"
-                        with errors_lock:
-                            errors.append(AssertionError(error_msg))
-            except Exception as e:
-                with errors_lock:
-                    errors.append(e)
-
-        # Prepare all row checks
-        row_checks = []
-        for table, table_changes in changes_by_table.items():
-            if self.ignore_config.should_ignore_table(table):
-                continue
-
-            # Get primary key columns once per table
-            pk_columns = self._get_primary_key_columns(table)
-
-            # Extract unique PKs to check
-            pks_to_check = {change["pk"] for change in table_changes}
-
-            for pk in pks_to_check:
-                row_checks.append((table, pk, table_changes, pk_columns))
-
-        # Execute row checks in parallel
-        if row_checks:
-            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
-                futures = [
-                    executor.submit(check_row, table, pk, table_changes, pk_columns)
-                    for table, pk, table_changes, pk_columns in row_checks
-                ]
-                concurrent.futures.wait(futures)
-
-        # Check for errors from row checks
-        if errors:
-            raise errors[0]
-
-        # Now check tables not mentioned in allowed_changes to ensure no changes
-        all_tables = set(self.before.tables()) | set(self.after.tables())
-        tables_to_verify = []
-
-        for table in all_tables:
-            if table not in changes_by_table and not self.ignore_config.should_ignore_table(table):
-                tables_to_verify.append(table)
-
-        # Function to verify no changes in a table
-        def verify_no_changes(table: str):
-            try:
-                # For tables with no allowed changes, just check row counts
-                before_count_response = self.before.resource.query(f"SELECT COUNT(*) FROM {table}")
-                before_count = before_count_response.rows[0][0] if before_count_response.rows else 0
-
-                after_count_response = self.after.resource.query(f"SELECT COUNT(*) FROM {table}")
-                after_count = after_count_response.rows[0][0] if after_count_response.rows else 0
-
-                if before_count != after_count:
-                    error_msg = (
-                        f"Unexpected change in table '{table}': "
-                        f"row count changed from {before_count} to {after_count}"
-                    )
-                    with errors_lock:
-                        errors.append(AssertionError(error_msg))
-            except Exception as e:
-                with errors_lock:
-                    errors.append(e)
-
-        # Execute table verification in parallel
-        if tables_to_verify:
-            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
-                futures = [
-                    executor.submit(verify_no_changes, table)
-                    for table in tables_to_verify
-                ]
-                concurrent.futures.wait(futures)
-
-        # Final error check
-        if errors:
-            raise errors[0]
-
-        return self
-
-    def _build_pk_where_clause(self, pk_columns: list[str], pk_value: Any) -> str:
-        """Build WHERE clause for primary key lookup."""
-        # Escape single quotes in values to prevent SQL injection
-        def escape_value(val: Any) -> str:
-            if val is None:
-                return "NULL"
-            elif isinstance(val, str):
-                escaped = str(val).replace("'", "''")
-                return f"'{escaped}'"
-            else:
-                return f"'{val}'"
-
-        if len(pk_columns) == 1:
-            return f"{pk_columns[0]} = {escape_value(pk_value)}"
-        else:
-            # Composite key
-            if isinstance(pk_value, tuple):
-                conditions = [f"{col} = {escape_value(val)}" for col, val in zip(pk_columns, pk_value)]
-                return " AND ".join(conditions)
-            else:
-                # Shouldn't happen if data is consistent
-                return f"{pk_columns[0]} = {escape_value(pk_value)}"
-
-    def _is_field_change_allowed(self, table_changes: list[dict[str, Any]], pk: Any, field: str, after_val: Any) -> bool:
-        """Check if a specific field change is allowed."""
-        for change in table_changes:
-            if (str(change.get("pk")) == str(pk) and
-                change.get("field") == field and
-                _values_equivalent(change.get("after"), after_val)):
-                return True
-        return False
-
-    def _is_row_change_allowed(self, table_changes: list[dict[str, Any]], pk: Any, change_type: str) -> bool:
-        """Check if a row addition/deletion is allowed."""
-        for change in table_changes:
-            if str(change.get("pk")) == str(pk) and change.get("after") == change_type:
-                return True
-        return False
-
-    def _expect_no_changes(self):
-        """Efficiently verify that no changes occurred between snapshots using row counts."""
-        try:
-            import concurrent.futures
-            from threading import Lock
-
-            # Get all tables from both snapshots
-            before_tables = set(self.before.tables())
-            after_tables = set(self.after.tables())
-
-            # Check for added/removed tables (excluding ignored ones)
-            added_tables = after_tables - before_tables
-            removed_tables = before_tables - after_tables
-
-            for table in added_tables:
-                if not self.ignore_config.should_ignore_table(table):
-                    raise AssertionError(f"Unexpected table added: {table}")
-
-            for table in removed_tables:
-                if not self.ignore_config.should_ignore_table(table):
-                    raise AssertionError(f"Unexpected table removed: {table}")
-
-            # Prepare tables to check
-            tables_to_check = []
-            all_tables = before_tables | after_tables
-            for table in all_tables:
-                if not self.ignore_config.should_ignore_table(table):
-                    tables_to_check.append(table)
-
-            # If no tables to check, we're done
-            if not tables_to_check:
-                return self
-
-            # Use ThreadPoolExecutor to parallelize count queries
-            # We use threads instead of processes since the queries are I/O bound
-            errors = []
-            errors_lock = Lock()
-            tables_needing_verification = []
-            verification_lock = Lock()
-
-            def check_table_counts(table: str):
-                """Check row counts for a single table."""
-                try:
-                    # Get row counts from both snapshots
-                    before_count = 0
-                    after_count = 0
-
-                    if table in before_tables:
-                        before_count_response = self.before.resource.query(f"SELECT COUNT(*) FROM {table}")
-                        before_count = before_count_response.rows[0][0] if before_count_response.rows else 0
-
-                    if table in after_tables:
-                        after_count_response = self.after.resource.query(f"SELECT COUNT(*) FROM {table}")
-                        after_count = after_count_response.rows[0][0] if after_count_response.rows else 0
-
-                    if before_count != after_count:
-                        error_msg = (
-                            f"Unexpected change in table '{table}': "
-                            f"row count changed from {before_count} to {after_count}"
-                        )
-                        with errors_lock:
-                            errors.append(AssertionError(error_msg))
-                    elif before_count > 0 and before_count <= 1000:
-                        # Mark for detailed verification
-                        with verification_lock:
-                            tables_needing_verification.append(table)
-
-                except Exception as e:
-                    with errors_lock:
-                        errors.append(e)
-
-            # Execute count checks in parallel
-            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
-                futures = [executor.submit(check_table_counts, table) for table in tables_to_check]
-                concurrent.futures.wait(futures)
-
-            # Check if any errors occurred during count checking
-            if errors:
-                # Raise the first error
-                raise errors[0]
-
-            # Now verify small tables for data changes (also in parallel)
-            if tables_needing_verification:
-                verification_errors = []
-
-                def verify_table(table: str):
-                    """Verify a single table's data hasn't changed."""
-                    try:
-                        self._verify_table_unchanged(table)
-                    except AssertionError as e:
-                        with errors_lock:
-                            verification_errors.append(e)
-
-                with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
-                    futures = [executor.submit(verify_table, table) for table in tables_needing_verification]
-                    concurrent.futures.wait(futures)
-
-                # Check if any errors occurred during verification
-                if verification_errors:
-                    raise verification_errors[0]
-
-            return self
-
-        except AssertionError:
-            # Re-raise assertion errors (these are expected failures)
-            raise
-        except Exception as e:
-            # If the optimized check fails for other reasons, fall back to full diff
-            print(f"Warning: Optimized no-changes check failed: {e}")
-            print("Falling back to full diff...")
-            return self._expect_only_fallback([])
-
-    def _verify_table_unchanged(self, table: str):
-        """Verify that a table's data hasn't changed (for small tables)."""
-        # Get primary key columns
-        pk_columns = self._get_primary_key_columns(table)
-
-        # Get sorted data from both snapshots
-        order_by = ", ".join(pk_columns) if pk_columns else "rowid"
-
-        before_response = self.before.resource.query(f"SELECT * FROM {table} ORDER BY {order_by}")
-        after_response = self.after.resource.query(f"SELECT * FROM {table} ORDER BY {order_by}")
-
-        # Quick check: if column counts differ, there's a schema change
-        if before_response.columns != after_response.columns:
-            raise AssertionError(f"Schema changed in table '{table}'")
-
-        # Compare row by row
-        if len(before_response.rows) != len(after_response.rows):
-            raise AssertionError(
-                f"Row count mismatch in table '{table}': "
-                f"{len(before_response.rows)} vs {len(after_response.rows)}"
-            )
-
-        for i, (before_row, after_row) in enumerate(zip(before_response.rows, after_response.rows)):
-            before_dict = dict(zip(before_response.columns, before_row))
-            after_dict = dict(zip(after_response.columns, after_row))
-
-            # Compare fields, ignoring those in ignore config
-            for field in before_response.columns:
-                if self.ignore_config.should_ignore_field(table, field):
-                    continue
-
-                if not _values_equivalent(before_dict.get(field), after_dict.get(field)):
-                    pk_val = before_dict.get(pk_columns[0]) if pk_columns else i
-                    raise AssertionError(
-                        f"Unexpected change in table '{table}', row {pk_val}, "
-                        f"field '{field}': {repr(before_dict.get(field))} -> {repr(after_dict.get(field))}"
-                    )
-
-    def _expect_only_fallback(self, allowed_changes: list[dict[str, Any]]):
-        """Fallback to full diff collection when optimized methods fail."""
+    def expect_only(self, allowed_changes: list[dict[str, Any]]):
+        """Ensure only specified changes occurred."""
        diff = self._collect()
-        return self._validate_diff_against_allowed_changes(diff, allowed_changes)
 
-    def _validate_diff_against_allowed_changes(self, diff: dict[str, Any], allowed_changes: list[dict[str, Any]]):
-        """Validate a collected diff against allowed changes."""
        def _is_change_allowed(
            table: str, row_id: Any, field: str | None, after_value: Any
        ) -> bool:
@@ -823,20 +419,6 @@ class SyncSnapshotDiff:
            raise AssertionError("\n".join(error_lines))
 
        return self
-
-    def expect_only(self, allowed_changes: list[dict[str, Any]]):
-        """Ensure only specified changes occurred."""
-        # Special case: empty allowed_changes means no changes should have occurred
-        if not allowed_changes:
-            return self._expect_no_changes()
-
-        # For expect_only, we can optimize by only checking the specific rows mentioned
-        if self._can_use_targeted_queries(allowed_changes):
-            return self._expect_only_targeted(allowed_changes)
-
-        # Fall back to full diff for complex cases
-        diff = self._collect()
-        return self._validate_diff_against_allowed_changes(diff, allowed_changes)
 
 
 class SyncQueryBuilder:
@@ -1087,22 +669,23 @@ class SQLiteResource(Resource):
 
    def snapshot(self, name: str | None = None) -> SyncDatabaseSnapshot:
        """Create a snapshot of the current database state."""
-
-
+        snapshot = SyncDatabaseSnapshot(self, name)
+        snapshot._ensure_fetched()
+        return snapshot
 
    def diff(
        self,
        other: "SQLiteResource",
        ignore_config: IgnoreConfig | None = None,
    ) -> SyncSnapshotDiff:
-        """Compare this database with another
+        """Compare this database with another AsyncSQLiteResource.
 
        Args:
-            other: Another
+            other: Another AsyncSQLiteResource to compare against
            ignore_config: Optional configuration for ignoring specific tables/fields
 
        Returns:
-
+            AsyncSnapshotDiff: Object containing the differences between the two databases
        """
        # Create snapshots of both databases
        before_snapshot = self.snapshot(name=f"before_{datetime.utcnow().isoformat()}")
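Taken together, the sqlite.py changes drop the targeted-query and threaded verification paths in favor of a single eager model: a snapshot pulls every table into memory via _ensure_fetched(), and expect_only() always validates against the full in-memory diff. A rough usage sketch of the resulting flow follows; how the SQLiteResource handle (db) is obtained, and the table and column values in the allowed-changes dict, are illustrative assumptions rather than part of this diff.

from fleet.verifiers.db import IgnoreConfig

# `db` is assumed to be an already-constructed SQLiteResource; obtaining it
# from a Fleet environment is outside the scope of this diff.
before = db.snapshot(name="before")   # eagerly pulls every table into memory
# ... run the behaviour under test against the environment ...
after = db.snapshot(name="after")

# Compare the two in-memory snapshots, optionally ignoring tables or fields.
diff = before.diff(after, ignore_config=IgnoreConfig())

# expect_only() now always works off the collected diff. Allowed changes are
# dicts keyed by table / pk / field / after, following the validation helpers
# in this file; the concrete values below are hypothetical.
diff.expect_only([
    {"table": "orders", "pk": 42, "field": "status", "after": "shipped"},
])

# An empty list asserts that nothing changed at all.
before.diff(after).expect_only([])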
fleet/tasks.py
CHANGED
@@ -10,7 +10,7 @@ from uuid import UUID
 from pydantic import BaseModel, Field, validator
 
 # Import the shared VerifierFunction type that works for both async and sync
-from
+from fleet.types import VerifierFunction
 
 
 class Task(BaseModel):
@@ -41,4 +41,4 @@ class Task(BaseModel):
            datetime: lambda v: v.isoformat(),
        }
        # Allow arbitrary types for the verifier field
-        arbitrary_types_allowed = True
+        arbitrary_types_allowed = True
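The tasks.py change is only an import-path fix: VerifierFunction now comes from fleet.types, the shared alias that, per the comment in this file, covers both sync and async verifiers and is presumably used to type Task's verifier field. A small sketch of what that alias allows; the verifier signature shown is an assumption, not taken from this diff.

from fleet.types import VerifierFunction

# Either style of callable can satisfy the shared alias, which is why the
# same import now works in both fleet/tasks.py and the _async package.
def sync_verifier(env) -> bool:         # hypothetical signature
    return True

async def async_verifier(env) -> bool:  # hypothetical signature
    return True

verifiers: list[VerifierFunction] = [sync_verifier, async_verifier]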