fleet-python 0.2.13__py3-none-any.whl → 0.2.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fleet-python has been flagged as potentially problematic; see the package's registry page for details.
- examples/diff_example.py +161 -0
- examples/dsl_example.py +50 -1
- examples/example_action_log.py +28 -0
- examples/example_mcp_anthropic.py +77 -0
- examples/example_mcp_openai.py +27 -0
- examples/example_task.py +199 -0
- examples/example_verifier.py +71 -0
- examples/query_builder_example.py +117 -0
- fleet/__init__.py +51 -40
- fleet/_async/base.py +14 -1
- fleet/_async/client.py +137 -19
- fleet/_async/env/client.py +4 -4
- fleet/_async/instance/__init__.py +1 -2
- fleet/_async/instance/client.py +3 -2
- fleet/_async/playwright.py +2 -2
- fleet/_async/resources/sqlite.py +654 -0
- fleet/_async/tasks.py +44 -0
- fleet/_async/verifiers/__init__.py +17 -0
- fleet/_async/verifiers/bundler.py +699 -0
- fleet/_async/verifiers/verifier.py +301 -0
- fleet/base.py +14 -1
- fleet/client.py +645 -12
- fleet/config.py +1 -1
- fleet/instance/__init__.py +1 -2
- fleet/instance/client.py +15 -5
- fleet/models.py +171 -4
- fleet/resources/browser.py +7 -8
- fleet/resources/mcp.py +60 -0
- fleet/resources/sqlite.py +654 -0
- fleet/tasks.py +44 -0
- fleet/types.py +18 -0
- fleet/verifiers/__init__.py +11 -5
- fleet/verifiers/bundler.py +699 -0
- fleet/verifiers/decorator.py +103 -0
- fleet/verifiers/verifier.py +301 -0
- {fleet_python-0.2.13.dist-info → fleet_python-0.2.15.dist-info}/METADATA +3 -42
- fleet_python-0.2.15.dist-info/RECORD +69 -0
- fleet_python-0.2.13.dist-info/RECORD +0 -52
- {fleet_python-0.2.13.dist-info → fleet_python-0.2.15.dist-info}/WHEEL +0 -0
- {fleet_python-0.2.13.dist-info → fleet_python-0.2.15.dist-info}/licenses/LICENSE +0 -0
- {fleet_python-0.2.13.dist-info → fleet_python-0.2.15.dist-info}/top_level.txt +0 -0
fleet/resources/sqlite.py
CHANGED
|
@@ -2,6 +2,10 @@ from typing import Any, List, Optional
|
|
|
2
2
|
from ..instance.models import Resource as ResourceModel
|
|
3
3
|
from ..instance.models import DescribeResponse, QueryRequest, QueryResponse
|
|
4
4
|
from .base import Resource
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
import tempfile
|
|
7
|
+
import sqlite3
|
|
8
|
+
import os
|
|
5
9
|
|
|
6
10
|
from typing import TYPE_CHECKING
|
|
7
11
|
|
|
@@ -9,6 +13,625 @@ if TYPE_CHECKING:
|
|
|
9
13
|
from ..instance.base import SyncWrapper
|
|
10
14
|
|
|
11
15
|
|
|
16
|
+
# Import types from verifiers module
|
|
17
|
+
from ..verifiers.db import IgnoreConfig, _get_row_identifier, _format_row_for_error, _values_equivalent
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SyncDatabaseSnapshot:
|
|
21
|
+
"""Async database snapshot that fetches data through API and stores locally for diffing."""
|
|
22
|
+
|
|
23
|
+
def __init__(self, resource: "SQLiteResource", name: str | None = None):
|
|
24
|
+
self.resource = resource
|
|
25
|
+
self.name = name or f"snapshot_{datetime.utcnow().isoformat()}"
|
|
26
|
+
self.created_at = datetime.utcnow()
|
|
27
|
+
self._data: dict[str, list[dict[str, Any]]] = {}
|
|
28
|
+
self._schemas: dict[str, list[str]] = {}
|
|
29
|
+
self._fetched = False
|
|
30
|
+
|
|
31
|
+
def _ensure_fetched(self):
|
|
32
|
+
"""Fetch all data from remote database if not already fetched."""
|
|
33
|
+
if self._fetched:
|
|
34
|
+
return
|
|
35
|
+
|
|
36
|
+
# Get all tables
|
|
37
|
+
tables_response = self.resource.query(
|
|
38
|
+
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
if not tables_response.rows:
|
|
42
|
+
self._fetched = True
|
|
43
|
+
return
|
|
44
|
+
|
|
45
|
+
table_names = [row[0] for row in tables_response.rows]
|
|
46
|
+
|
|
47
|
+
# Fetch data from each table
|
|
48
|
+
for table in table_names:
|
|
49
|
+
# Get table schema
|
|
50
|
+
schema_response = self.resource.query(f"PRAGMA table_info({table})")
|
|
51
|
+
if schema_response.rows:
|
|
52
|
+
self._schemas[table] = [row[1] for row in schema_response.rows] # Column names
|
|
53
|
+
|
|
54
|
+
# Get all data
|
|
55
|
+
data_response = self.resource.query(f"SELECT * FROM {table}")
|
|
56
|
+
if data_response.rows and data_response.columns:
|
|
57
|
+
self._data[table] = [
|
|
58
|
+
dict(zip(data_response.columns, row))
|
|
59
|
+
for row in data_response.rows
|
|
60
|
+
]
|
|
61
|
+
else:
|
|
62
|
+
self._data[table] = []
|
|
63
|
+
|
|
64
|
+
self._fetched = True
|
|
65
|
+
|
|
66
|
+
def tables(self) -> list[str]:
|
|
67
|
+
"""Get list of all tables in the snapshot."""
|
|
68
|
+
self._ensure_fetched()
|
|
69
|
+
return list(self._data.keys())
|
|
70
|
+
|
|
71
|
+
def table(self, table_name: str) -> "SyncSnapshotQueryBuilder":
|
|
72
|
+
"""Create a query builder for snapshot data."""
|
|
73
|
+
return SyncSnapshotQueryBuilder(self, table_name)
|
|
74
|
+
|
|
75
|
+
def diff(
|
|
76
|
+
self,
|
|
77
|
+
other: "SyncDatabaseSnapshot",
|
|
78
|
+
ignore_config: IgnoreConfig | None = None,
|
|
79
|
+
) -> "SyncSnapshotDiff":
|
|
80
|
+
"""Compare this snapshot with another."""
|
|
81
|
+
self._ensure_fetched()
|
|
82
|
+
other._ensure_fetched()
|
|
83
|
+
return SyncSnapshotDiff(self, other, ignore_config)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class SyncSnapshotQueryBuilder:
|
|
87
|
+
"""Query builder that works on local snapshot data."""
|
|
88
|
+
|
|
89
|
+
def __init__(self, snapshot: SyncDatabaseSnapshot, table: str):
|
|
90
|
+
self._snapshot = snapshot
|
|
91
|
+
self._table = table
|
|
92
|
+
self._select_cols: list[str] = ["*"]
|
|
93
|
+
self._conditions: list[tuple[str, str, Any]] = []
|
|
94
|
+
self._limit: int | None = None
|
|
95
|
+
self._order_by: str | None = None
|
|
96
|
+
self._order_desc: bool = False
|
|
97
|
+
|
|
98
|
+
def _get_data(self) -> list[dict[str, Any]]:
|
|
99
|
+
"""Get table data from snapshot."""
|
|
100
|
+
self._snapshot._ensure_fetched()
|
|
101
|
+
return self._snapshot._data.get(self._table, [])
|
|
102
|
+
|
|
103
|
+
def eq(self, column: str, value: Any) -> "SyncSnapshotQueryBuilder":
|
|
104
|
+
qb = self._clone()
|
|
105
|
+
qb._conditions.append((column, "=", value))
|
|
106
|
+
return qb
|
|
107
|
+
|
|
108
|
+
def limit(self, n: int) -> "SyncSnapshotQueryBuilder":
|
|
109
|
+
qb = self._clone()
|
|
110
|
+
qb._limit = n
|
|
111
|
+
return qb
|
|
112
|
+
|
|
113
|
+
def sort(self, column: str, desc: bool = False) -> "SyncSnapshotQueryBuilder":
|
|
114
|
+
qb = self._clone()
|
|
115
|
+
qb._order_by = column
|
|
116
|
+
qb._order_desc = desc
|
|
117
|
+
return qb
|
|
118
|
+
|
|
119
|
+
def first(self) -> dict[str, Any] | None:
|
|
120
|
+
rows = self.all()
|
|
121
|
+
return rows[0] if rows else None
|
|
122
|
+
|
|
123
|
+
def all(self) -> list[dict[str, Any]]:
|
|
124
|
+
data = self._get_data()
|
|
125
|
+
|
|
126
|
+
# Apply filters
|
|
127
|
+
filtered = data
|
|
128
|
+
for col, op, val in self._conditions:
|
|
129
|
+
if op == "=":
|
|
130
|
+
filtered = [row for row in filtered if row.get(col) == val]
|
|
131
|
+
|
|
132
|
+
# Apply sorting
|
|
133
|
+
if self._order_by:
|
|
134
|
+
filtered = sorted(
|
|
135
|
+
filtered,
|
|
136
|
+
key=lambda r: r.get(self._order_by),
|
|
137
|
+
reverse=self._order_desc
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
# Apply limit
|
|
141
|
+
if self._limit is not None:
|
|
142
|
+
filtered = filtered[:self._limit]
|
|
143
|
+
|
|
144
|
+
# Apply column selection
|
|
145
|
+
if self._select_cols != ["*"]:
|
|
146
|
+
filtered = [
|
|
147
|
+
{col: row.get(col) for col in self._select_cols}
|
|
148
|
+
for row in filtered
|
|
149
|
+
]
|
|
150
|
+
|
|
151
|
+
return filtered
|
|
152
|
+
|
|
153
|
+
def assert_exists(self):
|
|
154
|
+
row = self.first()
|
|
155
|
+
if row is None:
|
|
156
|
+
error_msg = (
|
|
157
|
+
f"Expected at least one matching row, but found none.\n"
|
|
158
|
+
f"Table: {self._table}"
|
|
159
|
+
)
|
|
160
|
+
if self._conditions:
|
|
161
|
+
conditions_str = ", ".join(
|
|
162
|
+
[f"{col} {op} {val}" for col, op, val in self._conditions]
|
|
163
|
+
)
|
|
164
|
+
error_msg += f"\nConditions: {conditions_str}"
|
|
165
|
+
raise AssertionError(error_msg)
|
|
166
|
+
return self
|
|
167
|
+
|
|
168
|
+
def _clone(self) -> "SyncSnapshotQueryBuilder":
|
|
169
|
+
qb = SyncSnapshotQueryBuilder(self._snapshot, self._table)
|
|
170
|
+
qb._select_cols = list(self._select_cols)
|
|
171
|
+
qb._conditions = list(self._conditions)
|
|
172
|
+
qb._limit = self._limit
|
|
173
|
+
qb._order_by = self._order_by
|
|
174
|
+
qb._order_desc = self._order_desc
|
|
175
|
+
return qb
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class SyncSnapshotDiff:
    """Compute & validate changes between two snapshots fetched via the API.

    Rows are matched across the two snapshots by primary key. When the key
    columns are not present in the fetched rows (notably the ``rowid``
    fallback, which ``SELECT *`` does not return), rows are matched by the
    content of their non-ignored columns instead, so they do not all
    collapse onto the key ``None``. The diff is computed once and cached.
    """

    def __init__(
        self,
        before: "SyncDatabaseSnapshot",
        after: "SyncDatabaseSnapshot",
        ignore_config: "IgnoreConfig | None" = None,
    ):
        """Store the two snapshots to compare.

        Args:
            before: Snapshot of the earlier state.
            after: Snapshot of the later state.
            ignore_config: Tables/fields to leave out of the diff; a default
                ``IgnoreConfig()`` is used when omitted.
        """
        self.before = before
        self.after = after
        self.ignore_config = ignore_config or IgnoreConfig()
        # Lazily computed {table: report} mapping, see _collect().
        self._cached: dict[str, Any] | None = None

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQLite identifier."""
        return '"' + name.replace('"', '""') + '"'

    def _get_primary_key_columns(self, table: str) -> list[str]:
        """Return the primary-key column names of ``table`` in key order.

        The schema is queried live from the ``after`` snapshot's resource,
        falling back to the ``before`` resource when the table is not found
        there (e.g. it was dropped). Without a declared primary key this
        returns ``["id"]`` if such a column exists, else ``["rowid"]``; if no
        schema can be read at all it returns ``["id"]``.
        """
        sql = f"PRAGMA table_info({self._quote_identifier(table)})"
        schema_response = self.after.resource.query(sql)
        if not schema_response.rows and self.before.resource is not self.after.resource:
            schema_response = self.before.resource.query(sql)
        if not schema_response.rows:
            return ["id"]  # Default fallback

        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk);
        # pk > 0 is the column's 1-based position within the primary key.
        pk_columns = sorted(
            ((row[5], row[1]) for row in schema_response.rows if row[5] > 0),
            key=lambda item: item[0],
        )
        if pk_columns:
            return [name for _, name in pk_columns]

        all_columns = [row[1] for row in schema_response.rows]
        return ["id"] if "id" in all_columns else ["rowid"]

    def _index_rows(
        self,
        table: str,
        rows: list[dict[str, Any]],
        pk_columns: list[str],
        by_content: bool,
    ) -> dict[Any, dict[str, Any]]:
        """Map each row of ``table`` to the key used to match it across snapshots.

        With ``by_content`` False the key is the primary-key value (a tuple
        for composite keys); a later row with the same key replaces an
        earlier one. With ``by_content`` True the key is
        ``(sorted non-ignored (column, value) pairs, occurrence number)`` so
        duplicate rows stay distinct. Assumes column values are hashable
        scalars, as SQLite returns them — NOTE(review): confirm for blobs.
        """
        index: dict[Any, dict[str, Any]] = {}
        if not by_content:
            for row in rows:
                if len(pk_columns) == 1:
                    key = row.get(pk_columns[0])
                else:
                    key = tuple(row.get(col) for col in pk_columns)
                index[key] = row
            return index

        occurrences: dict[Any, int] = {}
        for row in rows:
            # Column names are unique, so sorting never has to compare values.
            content = tuple(
                sorted(
                    (field, value)
                    for field, value in row.items()
                    if not self.ignore_config.should_ignore_field(table, field)
                )
            )
            seen = occurrences.get(content, 0)
            occurrences[content] = seen + 1
            index[(content, seen)] = row
        return index

    def _diff_table(self, table: str) -> dict[str, Any]:
        """Build the change report (added/removed/modified rows) for one table."""
        pk_columns = self._get_primary_key_columns(table)
        before_data = self.before._data.get(table, [])
        after_data = self.after._data.get(table, [])

        # If the key columns are absent from the fetched rows, every row would
        # get the key None and collapse into one; match by content instead.
        by_content = any(
            col not in row
            for row in (*before_data, *after_data)
            for col in pk_columns
        )
        before_index = self._index_rows(table, before_data, pk_columns, by_content)
        after_index = self._index_rows(table, after_data, pk_columns, by_content)

        result: dict[str, Any] = {
            "table_name": table,
            "primary_key": pk_columns,
            "added_rows": [],
            "removed_rows": [],
            "modified_rows": [],
            "unchanged_count": 0,
            "total_changes": 0,
        }

        # Walk the indexes in snapshot order (not set order) so reports and
        # error messages are deterministic.
        for key, row in after_index.items():
            if key not in before_index:
                result["added_rows"].append({"row_id": key, "data": row})

        for key, row in before_index.items():
            if key not in after_index:
                result["removed_rows"].append({"row_id": key, "data": row})

        for key, before_row in before_index.items():
            if key not in after_index:
                continue
            if by_content:
                # Equal keys imply equal non-ignored content: nothing changed.
                result["unchanged_count"] += 1
                continue

            after_row = after_index[key]
            changes = {}
            for field in dict.fromkeys([*before_row, *after_row]):
                if self.ignore_config.should_ignore_field(table, field):
                    continue
                before_val = before_row.get(field)
                after_val = after_row.get(field)
                if not _values_equivalent(before_val, after_val):
                    changes[field] = {"before": before_val, "after": after_val}

            if changes:
                result["modified_rows"].append(
                    {"row_id": key, "changes": changes, "data": after_row}  # data = current state
                )
            else:
                result["unchanged_count"] += 1

        result["total_changes"] = (
            len(result["added_rows"])
            + len(result["removed_rows"])
            + len(result["modified_rows"])
        )
        return result

    def _collect(self):
        """Return ``{table: report}`` for every non-ignored table (cached)."""
        if self._cached is not None:
            return self._cached

        all_tables = set(self.before.tables()) | set(self.after.tables())
        diff: dict[str, dict[str, Any]] = {}
        for tbl in sorted(all_tables):
            if self.ignore_config.should_ignore_table(tbl):
                continue
            diff[tbl] = self._diff_table(tbl)

        self._cached = diff
        return diff

    def expect_only(self, allowed_changes: list[dict[str, Any]]):
        """Ensure only the specified changes occurred.

        Each allowed change is a dict with keys ``table``, ``pk`` (compared
        to the row id as strings; a list is treated as a composite-key
        tuple), ``field`` (``None`` for whole-row changes) and ``after``
        (the new value, or ``"__added__"`` / ``"__removed__"`` for inserted
        / deleted rows).

        Returns:
            self, to allow chaining.

        Raises:
            AssertionError: describing up to five unexpected changes.
        """
        diff = self._collect()

        def _is_change_allowed(
            table: str, row_id: Any, field: str | None, after_value: Any
        ) -> bool:
            """Check if a change is in the allowed list using semantic comparison."""
            for allowed in allowed_changes:
                allowed_pk = allowed.get("pk")
                if allowed_pk is None:
                    continue
                if isinstance(allowed_pk, list):
                    # str([1, 2]) != str((1, 2)); composite row ids are tuples.
                    allowed_pk = tuple(allowed_pk)
                if (
                    allowed.get("table") == table
                    and str(allowed_pk) == str(row_id)
                    and allowed.get("field") == field
                    and _values_equivalent(allowed.get("after"), after_value)
                ):
                    return True
            return False

        unexpected_changes = []

        for tbl, report in diff.items():
            for row in report.get("modified_rows", []):
                for f, vals in row["changes"].items():
                    if self.ignore_config.should_ignore_field(tbl, f):
                        continue
                    if not _is_change_allowed(tbl, row["row_id"], f, vals["after"]):
                        unexpected_changes.append(
                            {
                                "type": "modification",
                                "table": tbl,
                                "row_id": row["row_id"],
                                "field": f,
                                "before": vals.get("before"),
                                "after": vals["after"],
                                "full_row": row,
                            }
                        )

            for row in report.get("added_rows", []):
                if not _is_change_allowed(tbl, row["row_id"], None, "__added__"):
                    unexpected_changes.append(
                        {
                            "type": "insertion",
                            "table": tbl,
                            "row_id": row["row_id"],
                            "field": None,
                            "after": "__added__",
                            "full_row": row,
                        }
                    )

            for row in report.get("removed_rows", []):
                if not _is_change_allowed(tbl, row["row_id"], None, "__removed__"):
                    unexpected_changes.append(
                        {
                            "type": "deletion",
                            "table": tbl,
                            "row_id": row["row_id"],
                            "field": None,
                            "after": "__removed__",
                            "full_row": row,
                        }
                    )

        if unexpected_changes:
            error_lines = ["Unexpected database changes detected:"]
            error_lines.append("")

            for i, change in enumerate(unexpected_changes[:5], 1):
                error_lines.append(f"{i}. {change['type'].upper()} in table '{change['table']}':")
                error_lines.append(f" Row ID: {change['row_id']}")

                if change["type"] == "modification":
                    error_lines.append(f" Field: {change['field']}")
                    error_lines.append(f" Before: {repr(change['before'])}")
                    error_lines.append(f" After: {repr(change['after'])}")
                elif change["type"] == "insertion":
                    error_lines.append(" New row added")
                elif change["type"] == "deletion":
                    error_lines.append(" Row deleted")

                # Show some context from the row
                if "full_row" in change and change["full_row"]:
                    row_data = change["full_row"]
                    if "data" in row_data:
                        formatted_row = _format_row_for_error(
                            row_data.get("data", {}), max_fields=5
                        )
                        error_lines.append(f" Row data: {formatted_row}")

                error_lines.append("")

            if len(unexpected_changes) > 5:
                error_lines.append(f"... and {len(unexpected_changes) - 5} more unexpected changes")
                error_lines.append("")

            error_lines.append("Allowed changes were:")
            if allowed_changes:
                for i, allowed in enumerate(allowed_changes[:3], 1):
                    error_lines.append(
                        f" {i}. Table: {allowed.get('table')}, "
                        f"ID: {allowed.get('pk')}, "
                        f"Field: {allowed.get('field')}, "
                        f"After: {repr(allowed.get('after'))}"
                    )
                if len(allowed_changes) > 3:
                    error_lines.append(f" ... and {len(allowed_changes) - 3} more allowed changes")
            else:
                error_lines.append(" (No changes were allowed)")

            raise AssertionError("\n".join(error_lines))

        return self
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
class SyncQueryBuilder:
|
|
425
|
+
"""Async query builder that translates DSL to SQL and executes through the API."""
|
|
426
|
+
|
|
427
|
+
def __init__(self, resource: "SQLiteResource", table: str):
|
|
428
|
+
self._resource = resource
|
|
429
|
+
self._table = table
|
|
430
|
+
self._select_cols: list[str] = ["*"]
|
|
431
|
+
self._conditions: list[tuple[str, str, Any]] = []
|
|
432
|
+
self._joins: list[tuple[str, dict[str, str]]] = []
|
|
433
|
+
self._limit: int | None = None
|
|
434
|
+
self._order_by: str | None = None
|
|
435
|
+
|
|
436
|
+
# Column projection / limiting / ordering
|
|
437
|
+
def select(self, *columns: str) -> "SyncQueryBuilder":
|
|
438
|
+
qb = self._clone()
|
|
439
|
+
qb._select_cols = list(columns) if columns else ["*"]
|
|
440
|
+
return qb
|
|
441
|
+
|
|
442
|
+
def limit(self, n: int) -> "SyncQueryBuilder":
|
|
443
|
+
qb = self._clone()
|
|
444
|
+
qb._limit = n
|
|
445
|
+
return qb
|
|
446
|
+
|
|
447
|
+
def sort(self, column: str, desc: bool = False) -> "SyncQueryBuilder":
|
|
448
|
+
qb = self._clone()
|
|
449
|
+
qb._order_by = f"{column} {'DESC' if desc else 'ASC'}"
|
|
450
|
+
return qb
|
|
451
|
+
|
|
452
|
+
# WHERE helpers
|
|
453
|
+
def _add_condition(self, column: str, op: str, value: Any) -> "SyncQueryBuilder":
|
|
454
|
+
qb = self._clone()
|
|
455
|
+
qb._conditions.append((column, op, value))
|
|
456
|
+
return qb
|
|
457
|
+
|
|
458
|
+
def eq(self, column: str, value: Any) -> "SyncQueryBuilder":
|
|
459
|
+
return self._add_condition(column, "=", value)
|
|
460
|
+
|
|
461
|
+
def neq(self, column: str, value: Any) -> "SyncQueryBuilder":
|
|
462
|
+
return self._add_condition(column, "!=", value)
|
|
463
|
+
|
|
464
|
+
def gt(self, column: str, value: Any) -> "SyncQueryBuilder":
|
|
465
|
+
return self._add_condition(column, ">", value)
|
|
466
|
+
|
|
467
|
+
def gte(self, column: str, value: Any) -> "SyncQueryBuilder":
|
|
468
|
+
return self._add_condition(column, ">=", value)
|
|
469
|
+
|
|
470
|
+
def lt(self, column: str, value: Any) -> "SyncQueryBuilder":
|
|
471
|
+
return self._add_condition(column, "<", value)
|
|
472
|
+
|
|
473
|
+
def lte(self, column: str, value: Any) -> "SyncQueryBuilder":
|
|
474
|
+
return self._add_condition(column, "<=", value)
|
|
475
|
+
|
|
476
|
+
def in_(self, column: str, values: list[Any]) -> "SyncQueryBuilder":
|
|
477
|
+
qb = self._clone()
|
|
478
|
+
qb._conditions.append((column, "IN", tuple(values)))
|
|
479
|
+
return qb
|
|
480
|
+
|
|
481
|
+
def not_in(self, column: str, values: list[Any]) -> "SyncQueryBuilder":
|
|
482
|
+
qb = self._clone()
|
|
483
|
+
qb._conditions.append((column, "NOT IN", tuple(values)))
|
|
484
|
+
return qb
|
|
485
|
+
|
|
486
|
+
def is_null(self, column: str) -> "SyncQueryBuilder":
|
|
487
|
+
return self._add_condition(column, "IS", None)
|
|
488
|
+
|
|
489
|
+
def not_null(self, column: str) -> "SyncQueryBuilder":
|
|
490
|
+
return self._add_condition(column, "IS NOT", None)
|
|
491
|
+
|
|
492
|
+
def ilike(self, column: str, pattern: str) -> "SyncQueryBuilder":
|
|
493
|
+
qb = self._clone()
|
|
494
|
+
qb._conditions.append((column, "LIKE", pattern))
|
|
495
|
+
return qb
|
|
496
|
+
|
|
497
|
+
# JOIN
|
|
498
|
+
def join(self, other_table: str, on: dict[str, str]) -> "SyncQueryBuilder":
|
|
499
|
+
qb = self._clone()
|
|
500
|
+
qb._joins.append((other_table, on))
|
|
501
|
+
return qb
|
|
502
|
+
|
|
503
|
+
# Compile to SQL
|
|
504
|
+
def _compile(self) -> tuple[str, list[Any]]:
|
|
505
|
+
cols = ", ".join(self._select_cols)
|
|
506
|
+
sql = [f"SELECT {cols} FROM {self._table}"]
|
|
507
|
+
params: list[Any] = []
|
|
508
|
+
|
|
509
|
+
# Joins
|
|
510
|
+
for tbl, onmap in self._joins:
|
|
511
|
+
join_clauses = [
|
|
512
|
+
f"{self._table}.{l} = {tbl}.{r}"
|
|
513
|
+
for l, r in onmap.items()
|
|
514
|
+
]
|
|
515
|
+
sql.append(f"JOIN {tbl} ON {' AND '.join(join_clauses)}")
|
|
516
|
+
|
|
517
|
+
# WHERE
|
|
518
|
+
if self._conditions:
|
|
519
|
+
placeholders = []
|
|
520
|
+
for col, op, val in self._conditions:
|
|
521
|
+
if op in ("IN", "NOT IN") and isinstance(val, tuple):
|
|
522
|
+
ph = ", ".join(["?" for _ in val])
|
|
523
|
+
placeholders.append(f"{col} {op} ({ph})")
|
|
524
|
+
params.extend(val)
|
|
525
|
+
elif op in ("IS", "IS NOT"):
|
|
526
|
+
placeholders.append(f"{col} {op} NULL")
|
|
527
|
+
else:
|
|
528
|
+
placeholders.append(f"{col} {op} ?")
|
|
529
|
+
params.append(val)
|
|
530
|
+
sql.append("WHERE " + " AND ".join(placeholders))
|
|
531
|
+
|
|
532
|
+
# ORDER / LIMIT
|
|
533
|
+
if self._order_by:
|
|
534
|
+
sql.append(f"ORDER BY {self._order_by}")
|
|
535
|
+
if self._limit is not None:
|
|
536
|
+
sql.append(f"LIMIT {self._limit}")
|
|
537
|
+
|
|
538
|
+
return " ".join(sql), params
|
|
539
|
+
|
|
540
|
+
# Execution methods
|
|
541
|
+
def count(self) -> int:
|
|
542
|
+
qb = self.select("COUNT(*) AS __cnt__").limit(None)
|
|
543
|
+
sql, params = qb._compile()
|
|
544
|
+
response = self._resource.query(sql, params)
|
|
545
|
+
if response.rows and len(response.rows) > 0:
|
|
546
|
+
# Convert row list to dict
|
|
547
|
+
row_dict = dict(zip(response.columns or [], response.rows[0]))
|
|
548
|
+
return row_dict.get("__cnt__", 0)
|
|
549
|
+
return 0
|
|
550
|
+
|
|
551
|
+
def first(self) -> dict[str, Any] | None:
|
|
552
|
+
rows = self.limit(1).all()
|
|
553
|
+
return rows[0] if rows else None
|
|
554
|
+
|
|
555
|
+
def all(self) -> list[dict[str, Any]]:
|
|
556
|
+
sql, params = self._compile()
|
|
557
|
+
response = self._resource.query(sql, params)
|
|
558
|
+
if not response.rows:
|
|
559
|
+
return []
|
|
560
|
+
# Convert List[List] to List[dict] using column names
|
|
561
|
+
return [
|
|
562
|
+
dict(zip(response.columns or [], row))
|
|
563
|
+
for row in response.rows
|
|
564
|
+
]
|
|
565
|
+
|
|
566
|
+
# Assertions
|
|
567
|
+
def assert_exists(self):
|
|
568
|
+
row = self.first()
|
|
569
|
+
if row is None:
|
|
570
|
+
sql, params = self._compile()
|
|
571
|
+
error_msg = (
|
|
572
|
+
f"Expected at least one matching row, but found none.\n"
|
|
573
|
+
f"Query: {sql}\n"
|
|
574
|
+
f"Parameters: {params}\n"
|
|
575
|
+
f"Table: {self._table}"
|
|
576
|
+
)
|
|
577
|
+
if self._conditions:
|
|
578
|
+
conditions_str = ", ".join(
|
|
579
|
+
[f"{col} {op} {val}" for col, op, val in self._conditions]
|
|
580
|
+
)
|
|
581
|
+
error_msg += f"\nConditions: {conditions_str}"
|
|
582
|
+
raise AssertionError(error_msg)
|
|
583
|
+
return self
|
|
584
|
+
|
|
585
|
+
def assert_none(self):
|
|
586
|
+
row = self.first()
|
|
587
|
+
if row is not None:
|
|
588
|
+
sql, params = self._compile()
|
|
589
|
+
error_msg = (
|
|
590
|
+
f"Expected no matching rows, but found at least one.\n"
|
|
591
|
+
f"Found row: {row}\n"
|
|
592
|
+
f"Query: {sql}\n"
|
|
593
|
+
f"Parameters: {params}\n"
|
|
594
|
+
f"Table: {self._table}"
|
|
595
|
+
)
|
|
596
|
+
raise AssertionError(error_msg)
|
|
597
|
+
return self
|
|
598
|
+
|
|
599
|
+
def assert_eq(self, column: str, value: Any):
|
|
600
|
+
row = self.first()
|
|
601
|
+
if row is None:
|
|
602
|
+
sql, params = self._compile()
|
|
603
|
+
error_msg = (
|
|
604
|
+
f"Row not found for equality assertion.\n"
|
|
605
|
+
f"Expected to find a row with {column}={repr(value)}\n"
|
|
606
|
+
f"Query: {sql}\n"
|
|
607
|
+
f"Parameters: {params}\n"
|
|
608
|
+
f"Table: {self._table}"
|
|
609
|
+
)
|
|
610
|
+
raise AssertionError(error_msg)
|
|
611
|
+
|
|
612
|
+
actual_value = row.get(column)
|
|
613
|
+
if actual_value != value:
|
|
614
|
+
error_msg = (
|
|
615
|
+
f"Field value assertion failed.\n"
|
|
616
|
+
f"Field: {column}\n"
|
|
617
|
+
f"Expected: {repr(value)}\n"
|
|
618
|
+
f"Actual: {repr(actual_value)}\n"
|
|
619
|
+
f"Full row data: {row}\n"
|
|
620
|
+
f"Table: {self._table}"
|
|
621
|
+
)
|
|
622
|
+
raise AssertionError(error_msg)
|
|
623
|
+
return self
|
|
624
|
+
|
|
625
|
+
def _clone(self) -> "SyncQueryBuilder":
|
|
626
|
+
qb = SyncQueryBuilder(self._resource, self._table)
|
|
627
|
+
qb._select_cols = list(self._select_cols)
|
|
628
|
+
qb._conditions = list(self._conditions)
|
|
629
|
+
qb._joins = list(self._joins)
|
|
630
|
+
qb._limit = self._limit
|
|
631
|
+
qb._order_by = self._order_by
|
|
632
|
+
return qb
|
|
633
|
+
|
|
634
|
+
|
|
12
635
|
class SQLiteResource(Resource):
|
|
13
636
|
def __init__(self, resource: ResourceModel, client: "SyncWrapper"):
|
|
14
637
|
super().__init__(resource)
|
|
@@ -39,3 +662,34 @@ class SQLiteResource(Resource):
|
|
|
39
662
|
json=request.model_dump(),
|
|
40
663
|
)
|
|
41
664
|
return QueryResponse(**response.json())
|
|
665
|
+
|
|
666
|
+
def table(self, table_name: str) -> SyncQueryBuilder:
    """Return a fresh query builder targeting ``table_name`` on this resource."""
    builder = SyncQueryBuilder(self, table_name)
    return builder
|
|
669
|
+
|
|
670
|
+
def snapshot(self, name: str | None = None) -> SyncDatabaseSnapshot:
    """Capture the current contents of this database.

    The snapshot is populated eagerly here rather than on first use, so it
    reflects the database state at the time of this call.

    Args:
        name: Optional label for the snapshot.
    """
    captured = SyncDatabaseSnapshot(self, name)
    captured._ensure_fetched()
    return captured
|
|
675
|
+
|
|
676
|
+
def diff(
    self,
    other: "SQLiteResource",
    ignore_config: IgnoreConfig | None = None,
) -> SyncSnapshotDiff:
    """Snapshot this database and ``other``, then diff them.

    This database is treated as the "before" state and ``other`` as the
    "after" state; each snapshot is labelled with a UTC timestamp.

    Args:
        other: The SQLiteResource holding the state to compare against.
        ignore_config: Optional tables/fields to leave out of the comparison.

    Returns:
        SyncSnapshotDiff describing how ``other`` differs from this database.
    """
    before_label = f"before_{datetime.utcnow().isoformat()}"
    before_snapshot = self.snapshot(name=before_label)

    after_label = f"after_{datetime.utcnow().isoformat()}"
    after_snapshot = other.snapshot(name=after_label)

    return before_snapshot.diff(after_snapshot, ignore_config=ignore_config)
|