cinchdb 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cinchdb/__init__.py +7 -0
- cinchdb/__main__.py +6 -0
- cinchdb/api/__init__.py +5 -0
- cinchdb/api/app.py +76 -0
- cinchdb/api/auth.py +290 -0
- cinchdb/api/main.py +137 -0
- cinchdb/api/routers/__init__.py +25 -0
- cinchdb/api/routers/auth.py +135 -0
- cinchdb/api/routers/branches.py +368 -0
- cinchdb/api/routers/codegen.py +164 -0
- cinchdb/api/routers/columns.py +290 -0
- cinchdb/api/routers/data.py +479 -0
- cinchdb/api/routers/databases.py +177 -0
- cinchdb/api/routers/projects.py +133 -0
- cinchdb/api/routers/query.py +156 -0
- cinchdb/api/routers/tables.py +349 -0
- cinchdb/api/routers/tenants.py +216 -0
- cinchdb/api/routers/views.py +219 -0
- cinchdb/cli/__init__.py +0 -0
- cinchdb/cli/commands/__init__.py +1 -0
- cinchdb/cli/commands/branch.py +479 -0
- cinchdb/cli/commands/codegen.py +176 -0
- cinchdb/cli/commands/column.py +308 -0
- cinchdb/cli/commands/database.py +212 -0
- cinchdb/cli/commands/query.py +136 -0
- cinchdb/cli/commands/remote.py +144 -0
- cinchdb/cli/commands/table.py +289 -0
- cinchdb/cli/commands/tenant.py +173 -0
- cinchdb/cli/commands/view.py +189 -0
- cinchdb/cli/handlers/__init__.py +5 -0
- cinchdb/cli/handlers/codegen_handler.py +189 -0
- cinchdb/cli/main.py +137 -0
- cinchdb/cli/utils.py +182 -0
- cinchdb/config.py +177 -0
- cinchdb/core/__init__.py +5 -0
- cinchdb/core/connection.py +175 -0
- cinchdb/core/database.py +537 -0
- cinchdb/core/maintenance.py +73 -0
- cinchdb/core/path_utils.py +153 -0
- cinchdb/managers/__init__.py +26 -0
- cinchdb/managers/branch.py +167 -0
- cinchdb/managers/change_applier.py +414 -0
- cinchdb/managers/change_comparator.py +194 -0
- cinchdb/managers/change_tracker.py +182 -0
- cinchdb/managers/codegen.py +523 -0
- cinchdb/managers/column.py +579 -0
- cinchdb/managers/data.py +455 -0
- cinchdb/managers/merge_manager.py +429 -0
- cinchdb/managers/query.py +214 -0
- cinchdb/managers/table.py +383 -0
- cinchdb/managers/tenant.py +258 -0
- cinchdb/managers/view.py +252 -0
- cinchdb/models/__init__.py +27 -0
- cinchdb/models/base.py +44 -0
- cinchdb/models/branch.py +26 -0
- cinchdb/models/change.py +47 -0
- cinchdb/models/database.py +20 -0
- cinchdb/models/project.py +20 -0
- cinchdb/models/table.py +86 -0
- cinchdb/models/tenant.py +19 -0
- cinchdb/models/view.py +15 -0
- cinchdb/utils/__init__.py +15 -0
- cinchdb/utils/sql_validator.py +137 -0
- cinchdb-0.1.0.dist-info/METADATA +195 -0
- cinchdb-0.1.0.dist-info/RECORD +68 -0
- cinchdb-0.1.0.dist-info/WHEEL +4 -0
- cinchdb-0.1.0.dist-info/entry_points.txt +3 -0
- cinchdb-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,429 @@
|
|
1
|
+
"""Branch merging functionality for CinchDB."""
|
2
|
+
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import List, Dict, Any
|
5
|
+
from cinchdb.models import Change, ChangeType
|
6
|
+
from cinchdb.managers.change_tracker import ChangeTracker
|
7
|
+
from cinchdb.managers.change_applier import ChangeApplier
|
8
|
+
from cinchdb.managers.change_comparator import ChangeComparator
|
9
|
+
from cinchdb.managers.branch import BranchManager
|
10
|
+
from cinchdb.core.path_utils import list_tenants
|
11
|
+
|
12
|
+
|
13
|
+
class MergeError(Exception):
    """Signals that checking, previewing or applying a branch merge failed."""
|
17
|
+
|
18
|
+
|
19
|
+
class MergeManager:
    """Manages merging operations between branches of one database."""

    def __init__(self, project_root: Path, database_name: str):
        """Initialize the merge manager.

        Args:
            project_root: Path to the project root
            database_name: Name of the database
        """
        self.project_root = Path(project_root)
        self.database_name = database_name
        self.comparator = ChangeComparator(project_root, database_name)
        self.branch_manager = BranchManager(project_root, database_name)

    def can_merge(self, source_branch: str, target_branch: str) -> Dict[str, Any]:
        """Check if source branch can be merged into target branch.

        Checks run in order and the first failing one decides the result:
        both branches must exist, there must be no conflicts, and the source
        must contain at least one change the target does not have.

        Args:
            source_branch: Name of the source branch
            target_branch: Name of the target branch

        Returns:
            Dictionary with a boolean ``can_merge`` key. When False it also
            holds ``reason`` (plus ``conflicts`` when conflicts were found);
            when True it holds ``merge_type`` ("fast_forward" or
            "three_way"), ``changes_to_merge`` and ``target_changes``.
        """
        if not self.branch_manager.branch_exists(source_branch):
            return {
                "can_merge": False,
                "reason": f"Source branch '{source_branch}' does not exist",
            }

        if not self.branch_manager.branch_exists(target_branch):
            return {
                "can_merge": False,
                "reason": f"Target branch '{target_branch}' does not exist",
            }

        conflicts = self.comparator.detect_conflicts(source_branch, target_branch)
        if conflicts:
            return {
                "can_merge": False,
                "reason": "Merge conflicts detected",
                "conflicts": conflicts,
            }

        source_only, target_only = self.comparator.get_divergent_changes(
            source_branch, target_branch
        )
        if not source_only:
            return {
                "can_merge": False,
                "reason": "No changes to merge from source branch",
            }

        is_fast_forward = self.comparator.can_fast_forward_merge(
            source_branch, target_branch
        )

        return {
            "can_merge": True,
            "merge_type": "fast_forward" if is_fast_forward else "three_way",
            "changes_to_merge": len(source_only),
            "target_changes": len(target_only),
        }

    def _merge_branches_internal(
        self,
        source_branch: str,
        target_branch: str,
        force: bool = False,
        dry_run: bool = False,
    ) -> Dict[str, Any]:
        """Internal merge method without main branch protection.

        Args:
            source_branch: Name of the source branch
            target_branch: Name of the target branch
            force: If True, proceed even when ``can_merge`` reports a
                problem, except for conflicts or a missing branch, which
                always abort the merge
            dry_run: If True, return SQL statements without executing

        Returns:
            Dictionary with merge result details

        Raises:
            MergeError: If merge cannot be completed
        """
        merge_check = self.can_merge(source_branch, target_branch)
        if not merge_check["can_merge"]:
            if not force:
                raise MergeError(f"Cannot merge: {merge_check['reason']}")
            if "conflicts" in merge_check:
                # Conflict entries are not guaranteed to be strings; stringify
                # them so building the message cannot itself fail.
                conflict_text = ", ".join(str(c) for c in merge_check["conflicts"])
                raise MergeError(
                    f"Cannot force merge due to conflicts: {conflict_text}"
                )
            # Forcing must not bypass the existence checks: with a missing
            # branch there is nothing meaningful to compare or apply.
            if not (
                self.branch_manager.branch_exists(source_branch)
                and self.branch_manager.branch_exists(target_branch)
            ):
                raise MergeError(f"Cannot merge: {merge_check['reason']}")

        source_only, _ = self.comparator.get_divergent_changes(
            source_branch, target_branch
        )
        if not source_only:
            return {
                "success": True,
                "message": "No changes to merge",
                "changes_merged": 0,
            }

        # Order changes so dependencies are applied before dependents.
        ordered_changes = self.comparator.get_merge_order(source_only)
        merge_type = merge_check.get("merge_type", "unknown")

        if dry_run:
            sql_statements = self._collect_sql_statements(
                ordered_changes, target_branch
            )
            return {
                "success": True,
                "dry_run": True,
                "message": f"Dry run: would merge {len(ordered_changes)} changes from '{source_branch}' to '{target_branch}'",
                "changes_to_merge": len(ordered_changes),
                "merge_type": merge_type,
                "sql_statements": sql_statements,
            }

        try:
            self._apply_changes_to_branch(ordered_changes, target_branch)
        except Exception as e:
            raise MergeError(f"Merge failed during application: {str(e)}") from e

        return {
            "success": True,
            "message": f"Successfully merged {len(ordered_changes)} changes from '{source_branch}' to '{target_branch}'",
            "changes_merged": len(ordered_changes),
            "merge_type": merge_type,
        }

    def merge_branches(
        self,
        source_branch: str,
        target_branch: str,
        force: bool = False,
        dry_run: bool = False,
    ) -> Dict[str, Any]:
        """Merge source branch into target branch.

        Merges into "main" are rejected here; use ``merge_into_main`` instead.

        Args:
            source_branch: Name of the source branch
            target_branch: Name of the target branch
            force: If True, proceed even when ``can_merge`` reports a
                problem, except for conflicts or a missing branch
            dry_run: If True, return SQL statements without executing

        Returns:
            Dictionary with merge result details

        Raises:
            MergeError: If merge cannot be completed
        """
        if target_branch == "main":
            raise MergeError(
                "Cannot merge directly into main branch. Main branch is protected."
            )

        return self._merge_branches_internal(
            source_branch, target_branch, force, dry_run
        )

    def _apply_changes_to_branch(
        self, changes: List[Change], target_branch: str
    ) -> None:
        """Record and apply changes on the target branch, one at a time.

        Each change is copied, marked unapplied, added to the target
        branch's change history and then applied via ``ChangeApplier``, in
        the given order.

        If a change fails, only that in-flight change is removed from the
        target history (best effort, and only when the tracker provides
        ``remove_change``). Changes that were already applied are kept in the
        history so it stays consistent with the schema that was actually
        modified, so the merge is not fully atomic.

        Args:
            changes: List of changes to apply
            target_branch: Name of the target branch

        Raises:
            MergeError: If changes cannot be applied
        """
        target_tracker = ChangeTracker(
            self.project_root, self.database_name, target_branch
        )
        target_applier = ChangeApplier(
            self.project_root, self.database_name, target_branch
        )

        applied_count = 0
        pending_id: Any = None

        try:
            for change in changes:
                # Copy and mark unapplied so it gets executed in the target
                # branch's databases.
                change_copy = change.model_copy()
                change_copy.applied = False

                target_tracker.add_change(change_copy)
                pending_id = change_copy.id

                target_applier.apply_change(change_copy.id)
                pending_id = None
                applied_count += 1
        except Exception as e:
            if pending_id is not None:
                self._discard_change(target_tracker, pending_id)
            raise MergeError(
                f"Failed to apply changes: {str(e)} "
                f"({applied_count} of {len(changes)} changes were applied before the failure)"
            ) from e

    @staticmethod
    def _discard_change(tracker: ChangeTracker, change_id: Any) -> None:
        """Best-effort removal of one change from a tracker's history."""
        remove_change = getattr(tracker, "remove_change", None)
        if remove_change is None:
            return
        try:
            remove_change(change_id)
        except Exception:
            # Deliberately best effort: the original failure is what the
            # caller gets reported.
            pass

    @staticmethod
    def _describe_change(change: Change) -> Dict[str, Any]:
        """Return the identifying fields shared by every SQL preview entry."""
        return {
            "change_id": change.id,
            "change_type": getattr(change.type, "value", change.type),
            "entity_name": change.entity_name,
        }

    def _collect_sql_statements(
        self, changes: List[Change], target_branch: str
    ) -> List[Dict[str, Any]]:
        """Collect SQL statements that would be executed for the given changes.

        Args:
            changes: List of changes to collect SQL for
            target_branch: Name of the target branch (currently unused)

        Returns:
            List of dictionaries with ``change_id``, ``change_type``,
            ``entity_name``, ``sql`` and, for multi-step changes, ``step``
        """
        sql_statements: List[Dict[str, Any]] = []

        for change in changes:
            header = self._describe_change(change)

            if change.details and "statements" in change.details:
                # Multiple (step_name, sql) pairs executed in one transaction.
                for step_name, sql in change.details["statements"]:
                    sql_statements.append({**header, "step": step_name, "sql": sql})
            elif change.type == ChangeType.UPDATE_VIEW:
                # A view update is replayed as DROP followed by CREATE.
                sql_statements.append(
                    {
                        **header,
                        "step": "drop_existing",
                        "sql": f"DROP VIEW IF EXISTS {change.entity_name}",
                    }
                )
                sql_statements.append(
                    {**header, "step": "create_view", "sql": change.sql}
                )
            elif (
                change.type == ChangeType.CREATE_TABLE
                and change.details
                and change.details.get("copy_sql")
            ):
                # Table copy: create the table, then copy the data over.
                sql_statements.append(
                    {**header, "step": "create_table", "sql": change.sql}
                )
                sql_statements.append(
                    {
                        **header,
                        "step": "copy_data",
                        "sql": change.details["copy_sql"],
                    }
                )
            else:
                sql_statements.append({**header, "sql": change.sql})

        return sql_statements

    def merge_into_main(
        self, source_branch: str, dry_run: bool = False
    ) -> Dict[str, Any]:
        """Merge a branch into main branch with additional validation.

        This is the primary way to get changes into main branch. The source
        must be mergeable and must already contain every change on main.

        Args:
            source_branch: Name of the source branch to merge
            dry_run: If True, return SQL statements without executing

        Returns:
            Dictionary with merge result details

        Raises:
            MergeError: If merge cannot be completed
        """
        if source_branch == "main":
            raise MergeError("Cannot merge main branch into itself")

        merge_check = self.can_merge(source_branch, "main")
        if not merge_check["can_merge"]:
            raise MergeError(f"Cannot merge into main: {merge_check['reason']}")

        # Ensure the source branch is up to date with main.
        main_only, _ = self.comparator.get_divergent_changes("main", source_branch)
        if main_only:
            raise MergeError(
                f"Source branch '{source_branch}' is not up to date with main. "
                "Pull latest changes from main first."
            )

        # Bypass the main-branch protection for this sanctioned path.
        return self._merge_branches_internal(source_branch, "main", dry_run=dry_run)

    def get_merge_preview(
        self, source_branch: str, target_branch: str
    ) -> Dict[str, Any]:
        """Get a preview of what would happen during a merge.

        Args:
            source_branch: Name of the source branch
            target_branch: Name of the target branch

        Returns:
            Dictionary with merge preview details; source-only changes are
            grouped by ``entity_type`` under ``changes_by_type``
        """
        merge_check = self.can_merge(source_branch, target_branch)

        if not merge_check["can_merge"]:
            return {
                "can_merge": False,
                "reason": merge_check["reason"],
                "conflicts": merge_check.get("conflicts", []),
            }

        source_only, target_only = self.comparator.get_divergent_changes(
            source_branch, target_branch
        )

        changes_by_type: Dict[Any, List[Dict[str, Any]]] = {}
        for change in source_only:
            changes_by_type.setdefault(change.entity_type, []).append(
                {
                    "id": change.id,
                    "entity_name": change.entity_name,
                    "operation": change.type,
                    "timestamp": change.created_at.isoformat(),
                }
            )

        return {
            "can_merge": True,
            "merge_type": merge_check.get("merge_type", "unknown"),
            "changes_to_merge": len(source_only),
            "target_has_changes": len(target_only) > 0,
            "changes_by_type": changes_by_type,
            "common_ancestor": self.comparator.find_common_ancestor(
                source_branch, target_branch
            ),
        }
|
@@ -0,0 +1,214 @@
|
|
1
|
+
"""Query execution manager for CinchDB - handles SQL queries with type-safe returns."""
|
2
|
+
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import List, Dict, Any, Optional, Type, TypeVar, Union
|
5
|
+
|
6
|
+
from pydantic import BaseModel, ValidationError
|
7
|
+
|
8
|
+
from cinchdb.core.connection import DatabaseConnection
|
9
|
+
from cinchdb.core.path_utils import get_tenant_db_path
|
10
|
+
from cinchdb.utils import validate_query_safe, SQLValidationError
|
11
|
+
|
12
|
+
# Generic type parameter for the typed query helpers: T is the pydantic
# model class passed as ``model`` and the type of the returned instances.
T = TypeVar("T", bound=BaseModel)
|
13
|
+
|
14
|
+
|
15
|
+
class QueryManager:
    """Manages SQL query execution against one tenant database, with typed returns."""

    def __init__(
        self, project_root: Path, database: str, branch: str, tenant: str = "main"
    ):
        """Initialize query manager.

        Args:
            project_root: Path to project root
            database: Database name
            branch: Branch name
            tenant: Tenant name (default: main)
        """
        self.project_root = Path(project_root)
        self.database = database
        self.branch = branch
        self.tenant = tenant
        self.db_path = get_tenant_db_path(project_root, database, branch, tenant)

    def execute(
        self, sql: str, params: Optional[Union[tuple, dict]] = None, skip_validation: bool = False
    ) -> List[Dict[str, Any]]:
        """Execute a SELECT query and return results as dictionaries.

        Args:
            sql: SQL query to execute
            params: Optional query parameters (tuple for positional, dict for named)
            skip_validation: Skip SQL validation (default: False)

        Returns:
            List of dictionaries representing rows

        Raises:
            SQLValidationError: If query contains restricted operations
            ValueError: If the query does not start with SELECT
            Exception: If query execution fails
        """
        if not skip_validation:
            validate_query_safe(sql)

        # Only SELECT statements are accepted (leading keyword checked
        # case-insensitively after stripping whitespace). Writes must go
        # through execute_non_query(), which commits.
        if not sql.strip().upper().startswith("SELECT"):
            raise ValueError(
                "execute() can only be used with SELECT queries. Use execute_non_query() for INSERT/UPDATE/DELETE operations."
            )

        with DatabaseConnection(self.db_path) as conn:
            cursor = conn.execute(sql, params)
            rows = cursor.fetchall()
            return [dict(row) for row in rows]

    def execute_typed(
        self,
        sql: str,
        model: Type[T],
        params: Optional[Union[tuple, dict]] = None,
        strict: bool = True,
    ) -> List[T]:
        """Execute a SQL query and return results as typed model instances.

        Args:
            sql: SQL query to execute
            model: Pydantic model class to validate results against
            params: Optional query parameters
            strict: If True, raise on validation errors; if False, skip invalid rows

        Returns:
            List of model instances

        Raises:
            ValueError: If query is not a SELECT query, or if strict=True and a
                row fails model validation (the pydantic ValidationError is
                chained as ``__cause__``)
            Exception: If query execution fails
        """
        if not sql.strip().upper().startswith("SELECT"):
            raise ValueError("execute_typed can only be used with SELECT queries")

        rows = self.execute(sql, params)

        typed_results: List[T] = []
        for index, row in enumerate(rows):
            try:
                typed_results.append(model(**row))
            except ValidationError as e:
                if strict:
                    raise ValueError(
                        f"Row {index} failed validation for {model.__name__}: {str(e)}"
                    ) from e
                # Non-strict mode: rows that do not fit the model are skipped.

        return typed_results

    def execute_one(
        self, sql: str, params: Optional[Union[tuple, dict]] = None
    ) -> Optional[Dict[str, Any]]:
        """Execute a SQL query and return at most one result as a dictionary.

        Args:
            sql: SQL query to execute
            params: Optional query parameters

        Returns:
            Dictionary representing the first row, or None if no results

        Raises:
            Exception: If query execution fails
        """
        results = self.execute(sql, params)
        return results[0] if results else None

    def execute_one_typed(
        self,
        sql: str,
        model: Type[T],
        params: Optional[Union[tuple, dict]] = None,
        strict: bool = True,
    ) -> Optional[T]:
        """Execute a SQL query and return at most one result as a typed model instance.

        Args:
            sql: SQL query to execute
            model: Pydantic model class to validate result against
            params: Optional query parameters
            strict: If True, raise on validation errors

        Returns:
            First model instance, or None if no results

        Raises:
            ValueError: If query is not a SELECT query, or if strict=True and
                validation fails
            Exception: If query execution fails
        """
        results = self.execute_typed(sql, model, params, strict)
        return results[0] if results else None

    def execute_non_query(
        self, sql: str, params: Optional[Union[tuple, dict]] = None, skip_validation: bool = False
    ) -> int:
        """Execute a non-SELECT SQL query (INSERT, UPDATE, DELETE, etc.) and commit.

        Args:
            sql: SQL query to execute
            params: Optional query parameters
            skip_validation: Skip SQL validation (default: False)

        Returns:
            Number of rows affected, as reported by the cursor

        Raises:
            SQLValidationError: If query contains restricted operations
            Exception: If query execution fails
        """
        if not skip_validation:
            validate_query_safe(sql)

        with DatabaseConnection(self.db_path) as conn:
            cursor = conn.execute(sql, params)
            affected_rows = cursor.rowcount
            conn.commit()
            return affected_rows

    def execute_many(
        self,
        sql: str,
        params_list: List[Union[tuple, dict]],
        skip_validation: bool = False,
    ) -> int:
        """Execute the same SQL query once per parameter set, in one transaction.

        All executions are committed together; on any failure the transaction
        is rolled back and the error re-raised.

        Args:
            sql: SQL query to execute
            params_list: List of parameter sets
            skip_validation: Skip SQL validation (default: False), matching
                execute() and execute_non_query()

        Returns:
            Total number of rows affected

        Raises:
            SQLValidationError: If query contains restricted operations
            Exception: If query execution fails
        """
        # Validate once up front, like the other write path does, so this
        # method is not a way around the SQL safety checks.
        if not skip_validation:
            validate_query_safe(sql)

        total_affected = 0

        with DatabaseConnection(self.db_path) as conn:
            try:
                for params in params_list:
                    cursor = conn.execute(sql, params)
                    # Presumably a sqlite3 cursor: rowcount is -1 for statements
                    # that do not report modified rows, so only add real counts.
                    if cursor.rowcount > 0:
                        total_affected += cursor.rowcount
                conn.commit()
                return total_affected
            except Exception:
                conn.rollback()
                raise
|