fleet-python 0.2.42__py3-none-any.whl → 0.2.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fleet-python might be problematic.
- fleet/__init__.py +9 -7
- fleet/_async/__init__.py +28 -16
- fleet/_async/client.py +161 -69
- fleet/_async/env/client.py +9 -2
- fleet/_async/instance/client.py +1 -1
- fleet/_async/resources/sqlite.py +34 -34
- fleet/_async/tasks.py +42 -41
- fleet/_async/verifiers/verifier.py +3 -3
- fleet/client.py +164 -61
- fleet/env/client.py +9 -2
- fleet/instance/client.py +2 -4
- fleet/models.py +3 -1
- fleet/resources/sqlite.py +37 -43
- fleet/tasks.py +49 -65
- fleet/verifiers/__init__.py +1 -1
- fleet/verifiers/db.py +41 -36
- fleet/verifiers/parse.py +4 -1
- fleet/verifiers/sql_differ.py +8 -8
- fleet/verifiers/verifier.py +19 -7
- {fleet_python-0.2.42.dist-info → fleet_python-0.2.44.dist-info}/METADATA +1 -1
- {fleet_python-0.2.42.dist-info → fleet_python-0.2.44.dist-info}/RECORD +24 -24
- {fleet_python-0.2.42.dist-info → fleet_python-0.2.44.dist-info}/WHEEL +0 -0
- {fleet_python-0.2.42.dist-info → fleet_python-0.2.44.dist-info}/licenses/LICENSE +0 -0
- {fleet_python-0.2.42.dist-info → fleet_python-0.2.44.dist-info}/top_level.txt +0 -0
fleet/client.py
CHANGED

@@ -16,6 +16,7 @@

 import base64
 import cloudpickle
+import concurrent.futures
 import httpx
 import json
 import logging

@@ -129,7 +130,7 @@ class SyncEnv(EnvironmentBase):
         return self.instance.verify(validator)

     def verify_raw(
-        self, function_code: str, function_name: str
+        self, function_code: str, function_name: Optional[str] = None
     ) -> ExecuteFunctionResponse:
         return self.instance.verify_raw(function_code, function_name)

@@ -206,24 +207,39 @@ class Fleet:
     def make(
         self,
         env_key: str,
+        data_key: Optional[str] = None,
         region: Optional[str] = None,
         env_variables: Optional[Dict[str, Any]] = None,
     ) -> SyncEnv:
         if ":" in env_key:
-            env_key_part,
+            env_key_part, env_version = env_key.split(":", 1)
             if (
-                not
-                and len(
-                and
+                not env_version.startswith("v")
+                and len(env_version) != 0
+                and env_version[0].isdigit()
             ):
-
+                env_version = f"v{env_version}"
         else:
             env_key_part = env_key
-
+            env_version = None
+
+        if data_key is not None and ":" in data_key:
+            data_key_part, data_version = data_key.split(":", 1)
+            if (
+                not data_version.startswith("v")
+                and len(data_version) != 0
+                and data_version[0].isdigit()
+            ):
+                data_version = f"v{data_version}"
+        else:
+            data_key_part = data_key
+            data_version = None

         request = InstanceRequest(
             env_key=env_key_part,
-
+            env_version=env_version,
+            data_key=data_key_part,
+            data_version=data_version,
             region=region,
             env_variables=env_variables,
             created_from="sdk",
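Note: the env_key / data_key parsing above splits "name:version" and normalizes bare numeric versions to a "v" prefix before building the InstanceRequest. A minimal standalone sketch of that behavior (the helper name split_versioned_key and the key values are illustrative, not part of the SDK):

from typing import Optional, Tuple

def split_versioned_key(key: str) -> Tuple[str, Optional[str]]:
    """Split "name:version" and normalize a bare numeric version to "vN"."""
    if ":" not in key:
        return key, None
    name, version = key.split(":", 1)
    if not version.startswith("v") and len(version) != 0 and version[0].isdigit():
        version = f"v{version}"
    return name, version

# Mirrors what make() now sends as env_key/env_version (and data_key/data_version):
assert split_versioned_key("browser-env") == ("browser-env", None)
assert split_versioned_key("browser-env:3") == ("browser-env", "v3")
assert split_versioned_key("browser-env:v3") == ("browser-env", "v3")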
@@ -286,10 +302,19 @@ class Fleet:

         return self.load_task_array_from_string(tasks_data)

-    def load_task_array_from_string(self, serialized_tasks:
+    def load_task_array_from_string(self, serialized_tasks: str) -> List[Task]:
         tasks = []

-
+        parsed_data = json.loads(serialized_tasks)
+        if isinstance(parsed_data, list):
+            json_tasks = parsed_data
+        elif isinstance(parsed_data, dict) and "tasks" in parsed_data:
+            json_tasks = parsed_data["tasks"]
+        else:
+            raise ValueError(
+                "Invalid JSON structure: expected array or object with 'tasks' key"
+            )
+
         for json_task in json_tasks:
             parsed_task = self.load_task_from_json(json_task)
             tasks.append(parsed_task)
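Note: load_task_array_from_string now accepts either a bare JSON array of tasks or an object wrapping them under a "tasks" key. A small standalone illustration of the two accepted shapes (the task fields shown are placeholders):

import json

# Both payloads yield the same list of task dicts under the new parsing rules.
array_payload = '[{"key": "task-1", "prompt": "Do the thing"}]'
wrapped_payload = '{"tasks": [{"key": "task-1", "prompt": "Do the thing"}]}'

def extract_tasks(serialized: str) -> list:
    parsed = json.loads(serialized)
    if isinstance(parsed, list):
        return parsed
    if isinstance(parsed, dict) and "tasks" in parsed:
        return parsed["tasks"]
    raise ValueError("Invalid JSON structure: expected array or object with 'tasks' key")

assert extract_tasks(array_payload) == extract_tasks(wrapped_payload)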
@@ -300,25 +325,47 @@ class Fleet:
         return self.load_task_from_json(task_json)

     def load_task_from_json(self, task_json: Dict) -> Task:
+        verifier = None
+        verifier_code = task_json.get("verifier_func") or task_json.get("verifier_code")
+
+        # Try to find verifier_id in multiple locations
+        verifier_id = task_json.get("verifier_id")
+        if (
+            not verifier_id
+            and "metadata" in task_json
+            and isinstance(task_json["metadata"], dict)
+        ):
+            verifier_metadata = task_json["metadata"].get("verifier", {})
+            if isinstance(verifier_metadata, dict):
+                verifier_id = verifier_metadata.get("verifier_id")
+
+        # If no verifier_id found, use the task key/id as fallback
+        if not verifier_id:
+            verifier_id = task_json.get("key", task_json.get("id"))
+
         try:
-            if
+            if verifier_id and verifier_code:
                 verifier = self._create_verifier_from_data(
-                    verifier_id=
-                    verifier_key=task_json
-                    verifier_code=
+                    verifier_id=verifier_id,
+                    verifier_key=task_json.get("key", task_json.get("id")),
+                    verifier_code=verifier_code,
                     verifier_sha=task_json.get("verifier_sha", ""),
                 )
         except Exception as e:
-            logger.warning(
+            logger.warning(
+                f"Failed to create verifier {task_json.get('key', task_json.get('id'))}: {e}"
+            )

         task = Task(
-            key=task_json
+            key=task_json.get("key", task_json.get("id")),
             prompt=task_json["prompt"],
-            env_id=task_json
-
+            env_id=task_json.get(
+                "env_id", task_json.get("env_key")
+            ),  # Use env_id or fallback to env_key
+            created_at=task_json.get("created_at"),
             version=task_json.get("version"),
             env_variables=task_json.get("env_variables", {}),
-            verifier_func=
+            verifier_func=verifier_code,  # Set verifier code
             verifier=verifier,  # Use created verifier or None
             metadata=task_json.get("metadata", {}),  # Default empty metadata
         )
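Note: the new lookup resolves a verifier id from several places in the task JSON, in order: a top-level "verifier_id", then metadata.verifier.verifier_id, then the task's key/id as a last resort. A standalone sketch of that precedence (the function name and sample data are illustrative only):

from typing import Dict, Optional

def resolve_verifier_id(task_json: Dict) -> Optional[str]:
    """Mirror the fallback order used by load_task_from_json."""
    verifier_id = task_json.get("verifier_id")
    if not verifier_id and isinstance(task_json.get("metadata"), dict):
        verifier_meta = task_json["metadata"].get("verifier", {})
        if isinstance(verifier_meta, dict):
            verifier_id = verifier_meta.get("verifier_id")
    if not verifier_id:
        verifier_id = task_json.get("key", task_json.get("id"))
    return verifier_id

assert resolve_verifier_id({"verifier_id": "vf_1", "key": "t1"}) == "vf_1"
assert resolve_verifier_id({"metadata": {"verifier": {"verifier_id": "vf_2"}}, "key": "t1"}) == "vf_2"
assert resolve_verifier_id({"key": "t1"}) == "t1"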
@@ -353,48 +400,107 @@ class Fleet:
         response = self.client.request("GET", "/v1/tasks", params=params)
         task_list_response = TaskListResponse(**response.json())

-        #
-
-
-            # Create verifier function if verifier data is present
-            verifier = None
-            verifier_func = task_response.verifier_func
+        # Prepare verifier loading tasks
+        verifier_tasks = []
+        task_responses_with_indices = []

+        for idx, task_response in enumerate(task_list_response.tasks):
             if task_response.verifier:
                 embedded_code = task_response.verifier.code or ""
                 is_embedded_error = embedded_code.strip().startswith(
                     "<error loading code:"
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+                def create_verifier_with_fallback(tr, emb_code, is_error):
+                    """Create verifier with fallback logic."""
+                    if not is_error:
+                        # Try to create from embedded data
+                        try:
+                            return self._create_verifier_from_data(
+                                verifier_id=tr.verifier.verifier_id,
+                                verifier_key=tr.verifier.key,
+                                verifier_code=emb_code,
+                                verifier_sha=tr.verifier.sha256,
+                            )
+                        except Exception as e:
+                            logger.warning(
+                                f"Failed to create verifier {tr.verifier.key}: {e}"
+                            )
+                            return None
+                    else:
+                        # Fallback: try fetching by ID
+                        try:
+                            logger.warning(
+                                f"Embedded verifier code missing for {tr.verifier.key} (NoSuchKey). "
+                                f"Attempting to refetch by id {tr.verifier.verifier_id}"
+                            )
+                            return self._load_verifier(tr.verifier.verifier_id)
+                        except Exception as e:
+                            logger.warning(
+                                f"Refetch by verifier id failed for {tr.verifier.key}: {e}. "
+                                "Leaving verifier unset."
+                            )
+                            return None
+
+                # Add the task for parallel execution
+                verifier_tasks.append(
+                    (
+                        create_verifier_with_fallback,
+                        task_response,
+                        embedded_code,
+                        is_embedded_error,
+                    )
+                )
+                task_responses_with_indices.append((idx, task_response))
+            else:
+                # No verifier needed
+                verifier_tasks.append(None)
+                task_responses_with_indices.append((idx, task_response))
+
+        # Execute all verifier loading in parallel using ThreadPoolExecutor
+        verifier_results = []
+        if verifier_tasks:
+            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+                futures = []
+                for task in verifier_tasks:
+                    if task is not None:
+                        func, tr, emb_code, is_error = task
+                        future = executor.submit(func, tr, emb_code, is_error)
+                        futures.append(future)
+                    else:
+                        futures.append(None)
+
+                # Collect results
+                for future in futures:
+                    if future is None:
+                        verifier_results.append(None)
+                    else:
+                        try:
+                            result = future.result()
+                            verifier_results.append(result)
+                        except Exception as e:
+                            logger.warning(f"Verifier loading failed: {e}")
+                            verifier_results.append(None)
+
+        # Build tasks with results
+        tasks = []
+        for (idx, task_response), verifier_result in zip(
+            task_responses_with_indices, verifier_results
+        ):
+            # Handle verifier result
+            verifier = None
+            verifier_func = task_response.verifier_func
+
+            if task_response.verifier:
+                # Process verifier result
+                if verifier_result is not None:
+                    verifier = verifier_result
+                embedded_code = task_response.verifier.code or ""
+                is_embedded_error = embedded_code.strip().startswith(
+                    "<error loading code:"
+                )
+                if not is_embedded_error:
+                    verifier_func = embedded_code

             task = Task(
                 key=task_response.key,
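Note: the reworked task-listing path fans verifier loading out to a thread pool and preserves per-task ordering by keeping None placeholders for tasks that have no verifier. A self-contained sketch of that submit/collect pattern with a dummy loader (the names and the simulated failure are illustrative, not SDK code):

import concurrent.futures
import logging

logger = logging.getLogger(__name__)

def load_verifier(key: str):
    """Stand-in for the real per-task verifier loading (may raise)."""
    if key == "broken":
        raise RuntimeError("embedded code missing")
    return f"verifier-for-{key}"

# None marks tasks that have no verifier; result order must match task order.
jobs = ["task-a", None, "task-b", "broken"]

results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
    futures = [executor.submit(load_verifier, j) if j is not None else None for j in jobs]
    for future in futures:
        if future is None:
            results.append(None)
            continue
        try:
            results.append(future.result())
        except Exception as e:  # failures degrade to "no verifier" rather than aborting
            logger.warning("Verifier loading failed: %s", e)
            results.append(None)

assert results == ["verifier-for-task-a", None, "verifier-for-task-b", None]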
@@ -503,7 +609,7 @@ class Fleet:
         self,
         task_key: str,
         prompt: Optional[str] = None,
-        verifier_code: Optional[str] = None
+        verifier_code: Optional[str] = None,
     ) -> TaskResponse:
         """Update an existing task.

@@ -515,10 +621,7 @@ class Fleet:
         Returns:
             TaskResponse containing the updated task details
         """
-        payload = TaskUpdateRequest(
-            prompt=prompt,
-            verifier_code=verifier_code
-        )
+        payload = TaskUpdateRequest(prompt=prompt, verifier_code=verifier_code)
         response = self.client.request(
             "PUT", f"/v1/tasks/{task_key}", json=payload.model_dump(exclude_none=True)
         )

@@ -564,7 +667,7 @@ class Fleet:
             AsyncVerifierFunction created from the verifier code
         """
         # Fetch verifier from API
-        response = self.client.request("GET", f"/v1/
+        response = self.client.request("GET", f"/v1/verifiers/{verifier_id}")
         verifier_data = response.json()

         # Use the common method to create verifier
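Note: update_task builds its payload with TaskUpdateRequest and dumps it with exclude_none=True, so a partial update only sends the fields that were actually provided. A small standalone sketch of that behavior, re-declaring a model with just the two fields visible above (an illustration, not the SDK's own model definition):

from typing import Optional
from pydantic import BaseModel

class TaskUpdateRequest(BaseModel):
    prompt: Optional[str] = None
    verifier_code: Optional[str] = None

payload = TaskUpdateRequest(prompt="New prompt", verifier_code=None)
# Only the provided field ends up in the PUT body:
assert payload.model_dump(exclude_none=True) == {"prompt": "New prompt"}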
fleet/env/client.py
CHANGED

@@ -3,8 +3,15 @@ from ..models import Environment as EnvironmentModel, AccountResponse
 from typing import List, Optional, Dict, Any


-def make(
-
+def make(
+    env_key: str,
+    data_key: Optional[str] = None,
+    region: Optional[str] = None,
+    env_variables: Optional[Dict[str, Any]] = None,
+) -> SyncEnv:
+    return Fleet().make(
+        env_key, data_key=data_key, region=region, env_variables=env_variables
+    )


 def make_for_task_async(task: Task) -> SyncEnv:
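Note: for reference, a hedged usage sketch of the updated module-level make(). It assumes fleet-python 0.2.44 is installed and valid Fleet credentials are configured; the import path follows the file layout shown above, and the env/data key values are placeholders.

# Sketch only: key names are placeholders; requires valid Fleet credentials.
from fleet.env.client import make

env = make(
    "browser-env:3",          # bare version "3" is normalized to "v3" before the request is built
    data_key="seed-data:v2",  # new in 0.2.44: optional dataset key, also "name:version"
    region="us-west-1",
)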
fleet/instance/client.py
CHANGED

@@ -63,9 +63,7 @@ class InstanceClient:
     def load(self) -> None:
         self._load_resources()

-    def reset(
-        self, reset_request: Optional[ResetRequest] = None
-    ) -> ResetResponse:
+    def reset(self, reset_request: Optional[ResetRequest] = None) -> ResetResponse:
         response = self.client.request(
             "POST", "/reset", json=reset_request.model_dump() if reset_request else None
         )

@@ -108,7 +106,7 @@ class InstanceClient:
         return self.verify_raw(function_code, function_name)

     def verify_raw(
-        self, function_code: str, function_name: str
+        self, function_code: str, function_name: Optional[str] = None
     ) -> ExecuteFunctionResponse:
         try:
             function_code = convert_verifier_string(function_code)
fleet/models.py
CHANGED

@@ -55,7 +55,9 @@ class Instance(BaseModel):

 class InstanceRequest(BaseModel):
     env_key: str = Field(..., title="Env Key")
-
+    env_version: Optional[str] = Field(None, title="Version")
+    data_key: Optional[str] = Field(None, title="Data Key")
+    data_version: Optional[str] = Field(None, title="Data Version")
     region: Optional[str] = Field("us-west-1", title="Region")
     seed: Optional[int] = Field(None, title="Seed")
     timestamp: Optional[int] = Field(None, title="Timestamp")
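Note: the new optional fields line up with what make() now computes from env_key and data_key. A standalone sketch of the request shape, re-declaring only the fields visible in this hunk (illustrative, not the full SDK model; the key values are placeholders):

from typing import Optional
from pydantic import BaseModel, Field

class InstanceRequest(BaseModel):
    env_key: str = Field(..., title="Env Key")
    env_version: Optional[str] = Field(None, title="Version")
    data_key: Optional[str] = Field(None, title="Data Key")
    data_version: Optional[str] = Field(None, title="Data Version")
    region: Optional[str] = Field("us-west-1", title="Region")

req = InstanceRequest(env_key="browser-env", env_version="v3", data_key="seed-data", data_version="v2")
print(req.model_dump(exclude_none=True))
# {'env_key': 'browser-env', 'env_version': 'v3', 'data_key': 'seed-data', 'data_version': 'v2', 'region': 'us-west-1'}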
fleet/resources/sqlite.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Dict, Tuple
 from ..instance.models import Resource as ResourceModel
 from ..instance.models import DescribeResponse, QueryRequest, QueryResponse
 from .base import Resource

@@ -25,12 +25,12 @@ from fleet.verifiers.db import (
 class SyncDatabaseSnapshot:
     """Async database snapshot that fetches data through API and stores locally for diffing."""

-    def __init__(self, resource: "SQLiteResource", name: str
+    def __init__(self, resource: "SQLiteResource", name: Optional[str] = None):
         self.resource = resource
         self.name = name or f"snapshot_{datetime.utcnow().isoformat()}"
         self.created_at = datetime.utcnow()
-        self._data:
-        self._schemas:
+        self._data: Dict[str, List[Dict[str, Any]]] = {}
+        self._schemas: Dict[str, List[str]] = {}
         self._fetched = False

     def _ensure_fetched(self):

@@ -69,7 +69,7 @@ class SyncDatabaseSnapshot:

         self._fetched = True

-    def tables(self) ->
+    def tables(self) -> List[str]:
         """Get list of all tables in the snapshot."""
         self._ensure_fetched()
         return list(self._data.keys())

@@ -81,7 +81,7 @@ class SyncDatabaseSnapshot:
     def diff(
         self,
         other: "SyncDatabaseSnapshot",
-        ignore_config: IgnoreConfig
+        ignore_config: Optional[IgnoreConfig] = None,
     ) -> "SyncSnapshotDiff":
         """Compare this snapshot with another."""
         self._ensure_fetched()

@@ -95,13 +95,13 @@ class SyncSnapshotQueryBuilder:
     def __init__(self, snapshot: SyncDatabaseSnapshot, table: str):
         self._snapshot = snapshot
         self._table = table
-        self._select_cols:
-        self._conditions:
-        self._limit: int
-        self._order_by: str
+        self._select_cols: List[str] = ["*"]
+        self._conditions: List[Tuple[str, str, Any]] = []
+        self._limit: Optional[int] = None
+        self._order_by: Optional[str] = None
         self._order_desc: bool = False

-    def _get_data(self) ->
+    def _get_data(self) -> List[Dict[str, Any]]:
         """Get table data from snapshot."""
         self._snapshot._ensure_fetched()
         return self._snapshot._data.get(self._table, [])

@@ -122,11 +122,11 @@ class SyncSnapshotQueryBuilder:
         qb._order_desc = desc
         return qb

-    def first(self) ->
+    def first(self) -> Optional[Dict[str, Any]]:
         rows = self.all()
         return rows[0] if rows else None

-    def all(self) ->
+    def all(self) -> List[Dict[str, Any]]:
         data = self._get_data()

         # Apply filters

@@ -185,14 +185,14 @@ class SyncSnapshotDiff:
         self,
         before: SyncDatabaseSnapshot,
         after: SyncDatabaseSnapshot,
-        ignore_config: IgnoreConfig
+        ignore_config: Optional[IgnoreConfig] = None,
     ):
         self.before = before
         self.after = after
         self.ignore_config = ignore_config or IgnoreConfig()
-        self._cached:
+        self._cached: Optional[Dict[str, Any]] = None

-    def _get_primary_key_columns(self, table: str) ->
+    def _get_primary_key_columns(self, table: str) -> List[str]:
         """Get primary key columns for a table."""
         # Try to get from schema
         schema_response = self.after.resource.query(f"PRAGMA table_info({table})")

@@ -222,7 +222,7 @@ class SyncSnapshotDiff:
             return self._cached

         all_tables = set(self.before.tables()) | set(self.after.tables())
-        diff:
+        diff: Dict[str, Dict[str, Any]] = {}

         for tbl in all_tables:
             if self.ignore_config.should_ignore_table(tbl):

@@ -236,7 +236,7 @@ class SyncSnapshotDiff:
             after_data = self.after._data.get(tbl, [])

             # Create indexes by primary key
-            def make_key(row: dict, pk_cols:
+            def make_key(row: dict, pk_cols: List[str]) -> Any:
                 if len(pk_cols) == 1:
                     return row.get(pk_cols[0])
                 return tuple(row.get(col) for col in pk_cols)

@@ -304,12 +304,12 @@ class SyncSnapshotDiff:
         self._cached = diff
         return diff

-    def expect_only(self, allowed_changes:
+    def expect_only(self, allowed_changes: List[Dict[str, Any]]):
         """Ensure only specified changes occurred."""
         diff = self._collect()

         def _is_change_allowed(
-            table: str, row_id: Any, field: str
+            table: str, row_id: Any, field: Optional[str], after_value: Any
         ) -> bool:
             """Check if a change is in the allowed list using semantic comparison."""
             for allowed in allowed_changes:

@@ -440,11 +440,11 @@ class SyncQueryBuilder:
     def __init__(self, resource: "SQLiteResource", table: str):
         self._resource = resource
         self._table = table
-        self._select_cols:
-        self._conditions:
-        self._joins:
-        self._limit: int
-        self._order_by: str
+        self._select_cols: List[str] = ["*"]
+        self._conditions: List[Tuple[str, str, Any]] = []
+        self._joins: List[Tuple[str, Dict[str, str]]] = []
+        self._limit: Optional[int] = None
+        self._order_by: Optional[str] = None

     # Column projection / limiting / ordering
     def select(self, *columns: str) -> "SyncQueryBuilder":

@@ -486,12 +486,12 @@ class SyncQueryBuilder:
     def lte(self, column: str, value: Any) -> "SyncQueryBuilder":
         return self._add_condition(column, "<=", value)

-    def in_(self, column: str, values:
+    def in_(self, column: str, values: List[Any]) -> "SyncQueryBuilder":
         qb = self._clone()
         qb._conditions.append((column, "IN", tuple(values)))
         return qb

-    def not_in(self, column: str, values:
+    def not_in(self, column: str, values: List[Any]) -> "SyncQueryBuilder":
         qb = self._clone()
         qb._conditions.append((column, "NOT IN", tuple(values)))
         return qb

@@ -508,16 +508,16 @@ class SyncQueryBuilder:
         return qb

     # JOIN
-    def join(self, other_table: str, on:
+    def join(self, other_table: str, on: Dict[str, str]) -> "SyncQueryBuilder":
         qb = self._clone()
         qb._joins.append((other_table, on))
         return qb

     # Compile to SQL
-    def _compile(self) ->
+    def _compile(self) -> Tuple[str, List[Any]]:
         cols = ", ".join(self._select_cols)
         sql = [f"SELECT {cols} FROM {self._table}"]
-        params:
+        params: List[Any] = []

         # Joins
         for tbl, onmap in self._joins:

@@ -558,11 +558,11 @@ class SyncQueryBuilder:
             return row_dict.get("__cnt__", 0)
         return 0

-    def first(self) ->
+    def first(self) -> Optional[Dict[str, Any]]:
         rows = self.limit(1).all()
         return rows[0] if rows else None

-    def all(self) ->
+    def all(self) -> List[Dict[str, Any]]:
         sql, params = self._compile()
         response = self._resource.query(sql, params)
         if not response.rows:

@@ -651,9 +651,7 @@ class SQLiteResource(Resource):
         )
         return DescribeResponse(**response.json())

-    def query(
-        self, query: str, args: Optional[List[Any]] = None
-    ) -> QueryResponse:
+    def query(self, query: str, args: Optional[List[Any]] = None) -> QueryResponse:
         return self._query(query, args, read_only=True)

     def exec(self, query: str, args: Optional[List[Any]] = None) -> QueryResponse:

@@ -674,7 +672,7 @@ class SQLiteResource(Resource):
         """Create a query builder for the specified table."""
         return SyncQueryBuilder(self, table_name)

-    def snapshot(self, name: str
+    def snapshot(self, name: Optional[str] = None) -> SyncDatabaseSnapshot:
         """Create a snapshot of the current database state."""
         snapshot = SyncDatabaseSnapshot(self, name)
         snapshot._ensure_fetched()

@@ -683,7 +681,7 @@ class SQLiteResource(Resource):
     def diff(
         self,
         other: "SQLiteResource",
-        ignore_config: IgnoreConfig
+        ignore_config: Optional[IgnoreConfig] = None,
    ) -> SyncSnapshotDiff:
         """Compare this database with another AsyncSQLiteResource.

@@ -695,12 +693,8 @@ class SQLiteResource(Resource):
             AsyncSnapshotDiff: Object containing the differences between the two databases
         """
         # Create snapshots of both databases
-        before_snapshot = self.snapshot(
-
-        )
-        after_snapshot = other.snapshot(
-            name=f"after_{datetime.utcnow().isoformat()}"
-        )
+        before_snapshot = self.snapshot(name=f"before_{datetime.utcnow().isoformat()}")
+        after_snapshot = other.snapshot(name=f"after_{datetime.utcnow().isoformat()}")

         # Return the diff between the snapshots
         return before_snapshot.diff(after_snapshot, ignore_config)
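Note: the annotations added here make the query-builder internals explicit: conditions are stored as (column, operator, value) tuples and _compile() returns a SQL string plus a parameter list. A standalone sketch of that compilation idea, independent of the SDK classes (the helper, table, and column names are illustrative):

from typing import Any, List, Optional, Tuple

def compile_select(
    table: str,
    select_cols: List[str],
    conditions: List[Tuple[str, str, Any]],
    limit: Optional[int] = None,
) -> Tuple[str, List[Any]]:
    """Build a parameterized SELECT from the same shapes the builder stores."""
    sql = [f"SELECT {', '.join(select_cols)} FROM {table}"]
    params: List[Any] = []
    where = []
    for column, op, value in conditions:
        if op in ("IN", "NOT IN"):
            placeholders = ", ".join("?" for _ in value)
            where.append(f"{column} {op} ({placeholders})")
            params.extend(value)
        else:
            where.append(f"{column} {op} ?")
            params.append(value)
    if where:
        sql.append("WHERE " + " AND ".join(where))
    if limit is not None:
        sql.append(f"LIMIT {limit}")
    return " ".join(sql), params

sql, params = compile_select(
    "users", ["*"], [("status", "IN", ("active", "trial")), ("age", "<=", 30)], limit=5
)
# sql == "SELECT * FROM users WHERE status IN (?, ?) AND age <= ? LIMIT 5"
assert params == ["active", "trial", 30]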