fleet-python 0.2.13__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff shows the content of publicly released package versions from one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
Potentially problematic release: this version of fleet-python has been flagged as possibly problematic.
- examples/diff_example.py +161 -0
- examples/dsl_example.py +50 -1
- examples/example_action_log.py +28 -0
- examples/example_mcp_anthropic.py +77 -0
- examples/example_mcp_openai.py +27 -0
- examples/example_task.py +199 -0
- examples/example_verifier.py +71 -0
- examples/query_builder_example.py +117 -0
- fleet/__init__.py +51 -40
- fleet/_async/base.py +14 -1
- fleet/_async/client.py +155 -19
- fleet/_async/env/client.py +4 -4
- fleet/_async/instance/__init__.py +1 -2
- fleet/_async/instance/client.py +3 -2
- fleet/_async/playwright.py +2 -2
- fleet/_async/resources/sqlite.py +654 -0
- fleet/_async/tasks.py +44 -0
- fleet/_async/verifiers/__init__.py +17 -0
- fleet/_async/verifiers/bundler.py +699 -0
- fleet/_async/verifiers/verifier.py +301 -0
- fleet/base.py +14 -1
- fleet/client.py +664 -12
- fleet/config.py +1 -1
- fleet/instance/__init__.py +1 -2
- fleet/instance/client.py +15 -5
- fleet/models.py +171 -4
- fleet/resources/browser.py +7 -8
- fleet/resources/mcp.py +60 -0
- fleet/resources/sqlite.py +654 -0
- fleet/tasks.py +44 -0
- fleet/types.py +18 -0
- fleet/verifiers/__init__.py +11 -5
- fleet/verifiers/bundler.py +699 -0
- fleet/verifiers/decorator.py +103 -0
- fleet/verifiers/verifier.py +301 -0
- {fleet_python-0.2.13.dist-info → fleet_python-0.2.16.dist-info}/METADATA +3 -42
- fleet_python-0.2.16.dist-info/RECORD +69 -0
- fleet_python-0.2.13.dist-info/RECORD +0 -52
- {fleet_python-0.2.13.dist-info → fleet_python-0.2.16.dist-info}/WHEEL +0 -0
- {fleet_python-0.2.13.dist-info → fleet_python-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {fleet_python-0.2.13.dist-info → fleet_python-0.2.16.dist-info}/top_level.txt +0 -0
fleet/_async/resources/sqlite.py CHANGED

@@ -2,6 +2,10 @@ from typing import Any, List, Optional
 from ...instance.models import Resource as ResourceModel
 from ...instance.models import DescribeResponse, QueryRequest, QueryResponse
 from .base import Resource
+from datetime import datetime
+import tempfile
+import sqlite3
+import os
 
 from typing import TYPE_CHECKING
 

@@ -9,6 +13,625 @@ if TYPE_CHECKING:
     from ..instance.base import AsyncWrapper
 
 
+# Import types from verifiers module
+from ...verifiers.db import IgnoreConfig, _get_row_identifier, _format_row_for_error, _values_equivalent
+
+
+class AsyncDatabaseSnapshot:
+    """Async database snapshot that fetches data through API and stores locally for diffing."""
+
+    def __init__(self, resource: "AsyncSQLiteResource", name: str | None = None):
+        self.resource = resource
+        self.name = name or f"snapshot_{datetime.utcnow().isoformat()}"
+        self.created_at = datetime.utcnow()
+        self._data: dict[str, list[dict[str, Any]]] = {}
+        self._schemas: dict[str, list[str]] = {}
+        self._fetched = False
+
+    async def _ensure_fetched(self):
+        """Fetch all data from remote database if not already fetched."""
+        if self._fetched:
+            return
+
+        # Get all tables
+        tables_response = await self.resource.query(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
+        )
+
+        if not tables_response.rows:
+            self._fetched = True
+            return
+
+        table_names = [row[0] for row in tables_response.rows]
+
+        # Fetch data from each table
+        for table in table_names:
+            # Get table schema
+            schema_response = await self.resource.query(f"PRAGMA table_info({table})")
+            if schema_response.rows:
+                self._schemas[table] = [row[1] for row in schema_response.rows]  # Column names
+
+            # Get all data
+            data_response = await self.resource.query(f"SELECT * FROM {table}")
+            if data_response.rows and data_response.columns:
+                self._data[table] = [
+                    dict(zip(data_response.columns, row))
+                    for row in data_response.rows
+                ]
+            else:
+                self._data[table] = []
+
+        self._fetched = True
+
+    async def tables(self) -> list[str]:
+        """Get list of all tables in the snapshot."""
+        await self._ensure_fetched()
+        return list(self._data.keys())
+
+    def table(self, table_name: str) -> "AsyncSnapshotQueryBuilder":
+        """Create a query builder for snapshot data."""
+        return AsyncSnapshotQueryBuilder(self, table_name)
+
+    async def diff(
+        self,
+        other: "AsyncDatabaseSnapshot",
+        ignore_config: IgnoreConfig | None = None,
+    ) -> "AsyncSnapshotDiff":
+        """Compare this snapshot with another."""
+        await self._ensure_fetched()
+        await other._ensure_fetched()
+        return AsyncSnapshotDiff(self, other, ignore_config)
+
+
+class AsyncSnapshotQueryBuilder:
+    """Query builder that works on local snapshot data."""
+
+    def __init__(self, snapshot: AsyncDatabaseSnapshot, table: str):
+        self._snapshot = snapshot
+        self._table = table
+        self._select_cols: list[str] = ["*"]
+        self._conditions: list[tuple[str, str, Any]] = []
+        self._limit: int | None = None
+        self._order_by: str | None = None
+        self._order_desc: bool = False
+
+    async def _get_data(self) -> list[dict[str, Any]]:
+        """Get table data from snapshot."""
+        await self._snapshot._ensure_fetched()
+        return self._snapshot._data.get(self._table, [])
+
+    def eq(self, column: str, value: Any) -> "AsyncSnapshotQueryBuilder":
+        qb = self._clone()
+        qb._conditions.append((column, "=", value))
+        return qb
+
+    def limit(self, n: int) -> "AsyncSnapshotQueryBuilder":
+        qb = self._clone()
+        qb._limit = n
+        return qb
+
+    def sort(self, column: str, desc: bool = False) -> "AsyncSnapshotQueryBuilder":
+        qb = self._clone()
+        qb._order_by = column
+        qb._order_desc = desc
+        return qb
+
+    async def first(self) -> dict[str, Any] | None:
+        rows = await self.all()
+        return rows[0] if rows else None
+
+    async def all(self) -> list[dict[str, Any]]:
+        data = await self._get_data()
+
+        # Apply filters
+        filtered = data
+        for col, op, val in self._conditions:
+            if op == "=":
+                filtered = [row for row in filtered if row.get(col) == val]
+
+        # Apply sorting
+        if self._order_by:
+            filtered = sorted(
+                filtered,
+                key=lambda r: r.get(self._order_by),
+                reverse=self._order_desc
+            )
+
+        # Apply limit
+        if self._limit is not None:
+            filtered = filtered[:self._limit]
+
+        # Apply column selection
+        if self._select_cols != ["*"]:
+            filtered = [
+                {col: row.get(col) for col in self._select_cols}
+                for row in filtered
+            ]
+
+        return filtered
+
+    async def assert_exists(self):
+        row = await self.first()
+        if row is None:
+            error_msg = (
+                f"Expected at least one matching row, but found none.\n"
+                f"Table: {self._table}"
+            )
+            if self._conditions:
+                conditions_str = ", ".join(
+                    [f"{col} {op} {val}" for col, op, val in self._conditions]
+                )
+                error_msg += f"\nConditions: {conditions_str}"
+            raise AssertionError(error_msg)
+        return self
+
+    def _clone(self) -> "AsyncSnapshotQueryBuilder":
+        qb = AsyncSnapshotQueryBuilder(self._snapshot, self._table)
+        qb._select_cols = list(self._select_cols)
+        qb._conditions = list(self._conditions)
+        qb._limit = self._limit
+        qb._order_by = self._order_by
+        qb._order_desc = self._order_desc
+        return qb
+
+
+class AsyncSnapshotDiff:
+    """Compute & validate changes between two snapshots fetched via API."""
+
+    def __init__(
+        self,
+        before: AsyncDatabaseSnapshot,
+        after: AsyncDatabaseSnapshot,
+        ignore_config: IgnoreConfig | None = None,
+    ):
+        self.before = before
+        self.after = after
+        self.ignore_config = ignore_config or IgnoreConfig()
+        self._cached: dict[str, Any] | None = None
+
+    async def _get_primary_key_columns(self, table: str) -> list[str]:
+        """Get primary key columns for a table."""
+        # Try to get from schema
+        schema_response = await self.after.resource.query(f"PRAGMA table_info({table})")
+        if not schema_response.rows:
+            return ["id"]  # Default fallback
+
+        pk_columns = []
+        for row in schema_response.rows:
+            # row format: (cid, name, type, notnull, dflt_value, pk)
+            if row[5] > 0:  # pk > 0 means it's part of primary key
+                pk_columns.append((row[5], row[1]))  # (pk_position, column_name)
+
+        if not pk_columns:
+            # Try common defaults
+            all_columns = [row[1] for row in schema_response.rows]
+            if "id" in all_columns:
+                return ["id"]
+            return ["rowid"]
+
+        # Sort by primary key position and return just the column names
+        pk_columns.sort(key=lambda x: x[0])
+        return [col[1] for col in pk_columns]
+
+    async def _collect(self):
+        """Collect all differences between snapshots."""
+        if self._cached is not None:
+            return self._cached
+
+        all_tables = set(await self.before.tables()) | set(await self.after.tables())
+        diff: dict[str, dict[str, Any]] = {}
+
+        for tbl in all_tables:
+            if self.ignore_config.should_ignore_table(tbl):
+                continue
+
+            # Get primary key columns
+            pk_columns = await self._get_primary_key_columns(tbl)
+
+            # Get data from both snapshots
+            before_data = self.before._data.get(tbl, [])
+            after_data = self.after._data.get(tbl, [])
+
+            # Create indexes by primary key
+            def make_key(row: dict, pk_cols: list[str]) -> Any:
+                if len(pk_cols) == 1:
+                    return row.get(pk_cols[0])
+                return tuple(row.get(col) for col in pk_cols)
+
+            before_index = {make_key(row, pk_columns): row for row in before_data}
+            after_index = {make_key(row, pk_columns): row for row in after_data}
+
+            before_keys = set(before_index.keys())
+            after_keys = set(after_index.keys())
+
+            # Find changes
+            result = {
+                "table_name": tbl,
+                "primary_key": pk_columns,
+                "added_rows": [],
+                "removed_rows": [],
+                "modified_rows": [],
+                "unchanged_count": 0,
+                "total_changes": 0,
+            }
+
+            # Added rows
+            for key in after_keys - before_keys:
+                result["added_rows"].append({
+                    "row_id": key,
+                    "data": after_index[key]
+                })
+
+            # Removed rows
+            for key in before_keys - after_keys:
+                result["removed_rows"].append({
+                    "row_id": key,
+                    "data": before_index[key]
+                })
+
+            # Modified rows
+            for key in before_keys & after_keys:
+                before_row = before_index[key]
+                after_row = after_index[key]
+                changes = {}
+
+                for field in set(before_row.keys()) | set(after_row.keys()):
+                    if self.ignore_config.should_ignore_field(tbl, field):
+                        continue
+                    before_val = before_row.get(field)
+                    after_val = after_row.get(field)
+                    if not _values_equivalent(before_val, after_val):
+                        changes[field] = {"before": before_val, "after": after_val}
+
+                if changes:
+                    result["modified_rows"].append({
+                        "row_id": key,
+                        "changes": changes,
+                        "data": after_row  # Current state
+                    })
+                else:
+                    result["unchanged_count"] += 1
+
+            result["total_changes"] = (
+                len(result["added_rows"]) +
+                len(result["removed_rows"]) +
+                len(result["modified_rows"])
+            )
+
+            diff[tbl] = result
+
+        self._cached = diff
+        return diff
+
+    async def expect_only(self, allowed_changes: list[dict[str, Any]]):
+        """Ensure only specified changes occurred."""
+        diff = await self._collect()
+
+        def _is_change_allowed(
+            table: str, row_id: Any, field: str | None, after_value: Any
+        ) -> bool:
+            """Check if a change is in the allowed list using semantic comparison."""
+            for allowed in allowed_changes:
+                allowed_pk = allowed.get("pk")
+                # Handle type conversion for primary key comparison
+                pk_match = (
+                    str(allowed_pk) == str(row_id) if allowed_pk is not None else False
+                )
+
+                if (
+                    allowed["table"] == table
+                    and pk_match
+                    and allowed.get("field") == field
+                    and _values_equivalent(allowed.get("after"), after_value)
+                ):
+                    return True
+            return False
+
+        # Collect all unexpected changes
+        unexpected_changes = []
+
+        for tbl, report in diff.items():
+            for row in report.get("modified_rows", []):
+                for f, vals in row["changes"].items():
+                    if self.ignore_config.should_ignore_field(tbl, f):
+                        continue
+                    if not _is_change_allowed(tbl, row["row_id"], f, vals["after"]):
+                        unexpected_changes.append({
+                            "type": "modification",
+                            "table": tbl,
+                            "row_id": row["row_id"],
+                            "field": f,
+                            "before": vals.get("before"),
+                            "after": vals["after"],
+                            "full_row": row,
+                        })
+
+            for row in report.get("added_rows", []):
+                if not _is_change_allowed(tbl, row["row_id"], None, "__added__"):
+                    unexpected_changes.append({
+                        "type": "insertion",
+                        "table": tbl,
+                        "row_id": row["row_id"],
+                        "field": None,
+                        "after": "__added__",
+                        "full_row": row,
+                    })
+
+            for row in report.get("removed_rows", []):
+                if not _is_change_allowed(tbl, row["row_id"], None, "__removed__"):
+                    unexpected_changes.append({
+                        "type": "deletion",
+                        "table": tbl,
+                        "row_id": row["row_id"],
+                        "field": None,
+                        "after": "__removed__",
+                        "full_row": row,
+                    })
+
+        if unexpected_changes:
+            # Build comprehensive error message
+            error_lines = ["Unexpected database changes detected:"]
+            error_lines.append("")
+
+            for i, change in enumerate(unexpected_changes[:5], 1):
+                error_lines.append(f"{i}. {change['type'].upper()} in table '{change['table']}':")
+                error_lines.append(f" Row ID: {change['row_id']}")
+
+                if change["type"] == "modification":
+                    error_lines.append(f" Field: {change['field']}")
+                    error_lines.append(f" Before: {repr(change['before'])}")
+                    error_lines.append(f" After: {repr(change['after'])}")
+                elif change["type"] == "insertion":
+                    error_lines.append(" New row added")
+                elif change["type"] == "deletion":
+                    error_lines.append(" Row deleted")
+
+                # Show some context from the row
+                if "full_row" in change and change["full_row"]:
+                    row_data = change["full_row"]
+                    if "data" in row_data:
+                        formatted_row = _format_row_for_error(
+                            row_data.get("data", {}), max_fields=5
+                        )
+                        error_lines.append(f" Row data: {formatted_row}")
+
+                error_lines.append("")
+
+            if len(unexpected_changes) > 5:
+                error_lines.append(f"... and {len(unexpected_changes) - 5} more unexpected changes")
+                error_lines.append("")
+
+            # Show what changes were allowed
+            error_lines.append("Allowed changes were:")
+            if allowed_changes:
+                for i, allowed in enumerate(allowed_changes[:3], 1):
+                    error_lines.append(
+                        f" {i}. Table: {allowed.get('table')}, "
+                        f"ID: {allowed.get('pk')}, "
+                        f"Field: {allowed.get('field')}, "
+                        f"After: {repr(allowed.get('after'))}"
+                    )
+                if len(allowed_changes) > 3:
+                    error_lines.append(f" ... and {len(allowed_changes) - 3} more allowed changes")
+            else:
+                error_lines.append(" (No changes were allowed)")
+
+            raise AssertionError("\n".join(error_lines))
+
+        return self
+
+
+class AsyncQueryBuilder:
+    """Async query builder that translates DSL to SQL and executes through the API."""
+
+    def __init__(self, resource: "AsyncSQLiteResource", table: str):
+        self._resource = resource
+        self._table = table
+        self._select_cols: list[str] = ["*"]
+        self._conditions: list[tuple[str, str, Any]] = []
+        self._joins: list[tuple[str, dict[str, str]]] = []
+        self._limit: int | None = None
+        self._order_by: str | None = None
+
+    # Column projection / limiting / ordering
+    def select(self, *columns: str) -> "AsyncQueryBuilder":
+        qb = self._clone()
+        qb._select_cols = list(columns) if columns else ["*"]
+        return qb
+
+    def limit(self, n: int) -> "AsyncQueryBuilder":
+        qb = self._clone()
+        qb._limit = n
+        return qb
+
+    def sort(self, column: str, desc: bool = False) -> "AsyncQueryBuilder":
+        qb = self._clone()
+        qb._order_by = f"{column} {'DESC' if desc else 'ASC'}"
+        return qb
+
+    # WHERE helpers
+    def _add_condition(self, column: str, op: str, value: Any) -> "AsyncQueryBuilder":
+        qb = self._clone()
+        qb._conditions.append((column, op, value))
+        return qb
+
+    def eq(self, column: str, value: Any) -> "AsyncQueryBuilder":
+        return self._add_condition(column, "=", value)
+
+    def neq(self, column: str, value: Any) -> "AsyncQueryBuilder":
+        return self._add_condition(column, "!=", value)
+
+    def gt(self, column: str, value: Any) -> "AsyncQueryBuilder":
+        return self._add_condition(column, ">", value)
+
+    def gte(self, column: str, value: Any) -> "AsyncQueryBuilder":
+        return self._add_condition(column, ">=", value)
+
+    def lt(self, column: str, value: Any) -> "AsyncQueryBuilder":
+        return self._add_condition(column, "<", value)
+
+    def lte(self, column: str, value: Any) -> "AsyncQueryBuilder":
+        return self._add_condition(column, "<=", value)
+
+    def in_(self, column: str, values: list[Any]) -> "AsyncQueryBuilder":
+        qb = self._clone()
+        qb._conditions.append((column, "IN", tuple(values)))
+        return qb
+
+    def not_in(self, column: str, values: list[Any]) -> "AsyncQueryBuilder":
+        qb = self._clone()
+        qb._conditions.append((column, "NOT IN", tuple(values)))
+        return qb
+
+    def is_null(self, column: str) -> "AsyncQueryBuilder":
+        return self._add_condition(column, "IS", None)
+
+    def not_null(self, column: str) -> "AsyncQueryBuilder":
+        return self._add_condition(column, "IS NOT", None)
+
+    def ilike(self, column: str, pattern: str) -> "AsyncQueryBuilder":
+        qb = self._clone()
+        qb._conditions.append((column, "LIKE", pattern))
+        return qb
+
+    # JOIN
+    def join(self, other_table: str, on: dict[str, str]) -> "AsyncQueryBuilder":
+        qb = self._clone()
+        qb._joins.append((other_table, on))
+        return qb
+
+    # Compile to SQL
+    def _compile(self) -> tuple[str, list[Any]]:
+        cols = ", ".join(self._select_cols)
+        sql = [f"SELECT {cols} FROM {self._table}"]
+        params: list[Any] = []
+
+        # Joins
+        for tbl, onmap in self._joins:
+            join_clauses = [
+                f"{self._table}.{l} = {tbl}.{r}"
+                for l, r in onmap.items()
+            ]
+            sql.append(f"JOIN {tbl} ON {' AND '.join(join_clauses)}")
+
+        # WHERE
+        if self._conditions:
+            placeholders = []
+            for col, op, val in self._conditions:
+                if op in ("IN", "NOT IN") and isinstance(val, tuple):
+                    ph = ", ".join(["?" for _ in val])
+                    placeholders.append(f"{col} {op} ({ph})")
+                    params.extend(val)
+                elif op in ("IS", "IS NOT"):
+                    placeholders.append(f"{col} {op} NULL")
+                else:
+                    placeholders.append(f"{col} {op} ?")
+                    params.append(val)
+            sql.append("WHERE " + " AND ".join(placeholders))
+
+        # ORDER / LIMIT
+        if self._order_by:
+            sql.append(f"ORDER BY {self._order_by}")
+        if self._limit is not None:
+            sql.append(f"LIMIT {self._limit}")
+
+        return " ".join(sql), params
+
+    # Execution methods
+    async def count(self) -> int:
+        qb = self.select("COUNT(*) AS __cnt__").limit(None)
+        sql, params = qb._compile()
+        response = await self._resource.query(sql, params)
+        if response.rows and len(response.rows) > 0:
+            # Convert row list to dict
+            row_dict = dict(zip(response.columns or [], response.rows[0]))
+            return row_dict.get("__cnt__", 0)
+        return 0
+
+    async def first(self) -> dict[str, Any] | None:
+        rows = await self.limit(1).all()
+        return rows[0] if rows else None
+
+    async def all(self) -> list[dict[str, Any]]:
+        sql, params = self._compile()
+        response = await self._resource.query(sql, params)
+        if not response.rows:
+            return []
+        # Convert List[List] to List[dict] using column names
+        return [
+            dict(zip(response.columns or [], row))
+            for row in response.rows
+        ]
+
+    # Assertions
+    async def assert_exists(self):
+        row = await self.first()
+        if row is None:
+            sql, params = self._compile()
+            error_msg = (
+                f"Expected at least one matching row, but found none.\n"
+                f"Query: {sql}\n"
+                f"Parameters: {params}\n"
+                f"Table: {self._table}"
+            )
+            if self._conditions:
+                conditions_str = ", ".join(
+                    [f"{col} {op} {val}" for col, op, val in self._conditions]
+                )
+                error_msg += f"\nConditions: {conditions_str}"
+            raise AssertionError(error_msg)
+        return self
+
+    async def assert_none(self):
+        row = await self.first()
+        if row is not None:
+            sql, params = self._compile()
+            error_msg = (
+                f"Expected no matching rows, but found at least one.\n"
+                f"Found row: {row}\n"
+                f"Query: {sql}\n"
+                f"Parameters: {params}\n"
+                f"Table: {self._table}"
+            )
+            raise AssertionError(error_msg)
+        return self
+
+    async def assert_eq(self, column: str, value: Any):
+        row = await self.first()
+        if row is None:
+            sql, params = self._compile()
+            error_msg = (
+                f"Row not found for equality assertion.\n"
+                f"Expected to find a row with {column}={repr(value)}\n"
+                f"Query: {sql}\n"
+                f"Parameters: {params}\n"
+                f"Table: {self._table}"
+            )
+            raise AssertionError(error_msg)
+
+        actual_value = row.get(column)
+        if actual_value != value:
+            error_msg = (
+                f"Field value assertion failed.\n"
+                f"Field: {column}\n"
+                f"Expected: {repr(value)}\n"
+                f"Actual: {repr(actual_value)}\n"
+                f"Full row data: {row}\n"
+                f"Table: {self._table}"
+            )
+            raise AssertionError(error_msg)
+        return self
+
+    def _clone(self) -> "AsyncQueryBuilder":
+        qb = AsyncQueryBuilder(self._resource, self._table)
+        qb._select_cols = list(self._select_cols)
+        qb._conditions = list(self._conditions)
+        qb._joins = list(self._joins)
+        qb._limit = self._limit
+        qb._order_by = self._order_by
+        return qb
+
+
 class AsyncSQLiteResource(Resource):
     def __init__(self, resource: ResourceModel, client: "AsyncWrapper"):
         super().__init__(resource)

@@ -39,3 +662,34 @@ class AsyncSQLiteResource(Resource):
             json=request.model_dump(),
         )
         return QueryResponse(**response.json())
+
+    def table(self, table_name: str) -> AsyncQueryBuilder:
+        """Create a query builder for the specified table."""
+        return AsyncQueryBuilder(self, table_name)
+
+    async def snapshot(self, name: str | None = None) -> AsyncDatabaseSnapshot:
+        """Create a snapshot of the current database state."""
+        snapshot = AsyncDatabaseSnapshot(self, name)
+        await snapshot._ensure_fetched()
+        return snapshot
+
+    async def diff(
+        self,
+        other: "AsyncSQLiteResource",
+        ignore_config: IgnoreConfig | None = None,
+    ) -> AsyncSnapshotDiff:
+        """Compare this database with another AsyncSQLiteResource.
+
+        Args:
+            other: Another AsyncSQLiteResource to compare against
+            ignore_config: Optional configuration for ignoring specific tables/fields
+
+        Returns:
+            AsyncSnapshotDiff: Object containing the differences between the two databases
+        """
+        # Create snapshots of both databases
+        before_snapshot = await self.snapshot(name=f"before_{datetime.utcnow().isoformat()}")
+        after_snapshot = await other.snapshot(name=f"after_{datetime.utcnow().isoformat()}")
+
+        # Return the diff between the snapshots
+        return await before_snapshot.diff(after_snapshot, ignore_config)
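For orientation, here is a minimal usage sketch of the API this file adds. It is not from the package itself: the `verify_task` function name is invented, and how an `AsyncSQLiteResource` instance (`sqlite` below) is obtained from a Fleet environment is outside this diff and assumed available. The `table()`, `eq()`, `neq()`, `count()`, `first()`, `snapshot()`, `diff()`, `expect_only()`, and `assert_eq()` calls mirror the methods added above.

# Hypothetical example; `sqlite` is assumed to be an AsyncSQLiteResource
# (fleet._async.resources.sqlite) obtained from a Fleet environment instance.
async def verify_task(sqlite):
    # DSL queries compiled to SQL and executed through the API (AsyncQueryBuilder):
    user = await sqlite.table("users").eq("email", "a@example.com").first()
    open_tickets = await sqlite.table("tickets").neq("status", "closed").count()

    # Snapshot/diff workflow (AsyncDatabaseSnapshot / AsyncSnapshotDiff):
    before = await sqlite.snapshot("before")
    # ... perform the action under test against the environment ...
    after = await sqlite.snapshot("after")

    diff = await before.diff(after)
    # Raises a detailed AssertionError unless the only change is
    # tickets.status -> 'closed' on the row whose primary key is 42:
    await diff.expect_only([
        {"table": "tickets", "pk": 42, "field": "status", "after": "closed"},
    ])

    # Row-level assertions include the compiled query in the failure message:
    await sqlite.table("tickets").eq("id", 42).assert_eq("status", "closed")

Note that `AsyncSQLiteResource.diff()` snapshots both resources at call time, so comparing one database before and after an action requires two explicit `snapshot()` calls as sketched above.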