sqlspec 0.12.1__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of sqlspec might be problematic.
- sqlspec/adapters/aiosqlite/driver.py +16 -11
- sqlspec/adapters/bigquery/driver.py +113 -21
- sqlspec/adapters/duckdb/driver.py +18 -13
- sqlspec/adapters/psycopg/config.py +20 -3
- sqlspec/adapters/psycopg/driver.py +82 -1
- sqlspec/adapters/sqlite/driver.py +50 -10
- sqlspec/driver/mixins/_storage.py +83 -36
- sqlspec/loader.py +8 -30
- sqlspec/statement/builder/base.py +3 -1
- sqlspec/statement/builder/ddl.py +14 -1
- sqlspec/statement/pipelines/analyzers/_analyzer.py +1 -5
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +56 -2
- sqlspec/statement/sql.py +40 -6
- sqlspec/storage/backends/fsspec.py +29 -27
- sqlspec/storage/backends/obstore.py +55 -34
- sqlspec/storage/protocol.py +28 -25
- {sqlspec-0.12.1.dist-info → sqlspec-0.12.2.dist-info}/METADATA +1 -1
- {sqlspec-0.12.1.dist-info → sqlspec-0.12.2.dist-info}/RECORD +21 -21
- {sqlspec-0.12.1.dist-info → sqlspec-0.12.2.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.12.2.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.12.2.dist-info}/licenses/NOTICE +0 -0
--- a/sqlspec/adapters/aiosqlite/driver.py
+++ b/sqlspec/adapters/aiosqlite/driver.py
@@ -203,8 +203,7 @@ class AiosqliteDriver(
         return result

     async def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
-        """Database-specific bulk load implementation."""
-        # TODO: convert this to use the storage backend. it has async support
+        """Database-specific bulk load implementation using storage backend."""
         if format != "csv":
             msg = f"aiosqlite driver only supports CSV for bulk loading, not {format}."
             raise NotImplementedError(msg)
@@ -215,15 +214,21 @@ class AiosqliteDriver(
             if mode == "replace":
                 await cursor.execute(f"DELETE FROM {table_name}")

-            #
-
-
-
-
-
-
-
-
+            # Use async storage backend to read the file
+            file_path_str = str(file_path)
+            backend = self._get_storage_backend(file_path_str)
+            content = await backend.read_text_async(file_path_str, encoding="utf-8")
+            # Parse CSV content
+            import io
+
+            csv_file = io.StringIO(content)
+            reader = csv.reader(csv_file, **options)
+            header = next(reader)  # Skip header
+            placeholders = ", ".join("?" for _ in header)
+            sql = f"INSERT INTO {table_name} VALUES ({placeholders})"
+            data_iter = list(reader)
+            await cursor.executemany(sql, data_iter)
+            rowcount = cursor.rowcount
             await conn.commit()
             return rowcount
         finally:
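For context: the new aiosqlite bulk-load path reads the CSV through the storage backend and then relies on executemany with placeholders derived from the header row. Below is a minimal standalone sketch of that CSV-to-executemany pattern using aiosqlite directly; the load_csv helper and the temp-file database are illustrative, not sqlspec API.

import asyncio
import csv
import io
import tempfile
from pathlib import Path

import aiosqlite  # the library the adapter wraps


async def load_csv(conn: aiosqlite.Connection, table_name: str, csv_text: str) -> int:
    """Insert CSV rows, deriving the placeholder list from the header row."""
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader)  # header is only used to count columns
    placeholders = ", ".join("?" for _ in header)
    cursor = await conn.executemany(f"INSERT INTO {table_name} VALUES ({placeholders})", list(reader))
    await conn.commit()
    return cursor.rowcount


async def main() -> None:
    db_path = Path(tempfile.gettempdir()) / "bulk_load_demo.sqlite"
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute("CREATE TABLE IF NOT EXISTS t (a INTEGER, b TEXT)")
        print(await load_csv(conn, "t", "a,b\n1,x\n2,y\n"))  # 2


if __name__ == "__main__":
    asyncio.run(main())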
--- a/sqlspec/adapters/bigquery/driver.py
+++ b/sqlspec/adapters/bigquery/driver.py
@@ -1,6 +1,8 @@
+import contextlib
 import datetime
 import io
 import logging
+import uuid
 from collections.abc import Iterator
 from decimal import Decimal
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, cast
@@ -8,10 +10,12 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, cast
 from google.cloud.bigquery import (
     ArrayQueryParameter,
     Client,
+    ExtractJobConfig,
     LoadJobConfig,
     QueryJob,
     QueryJobConfig,
     ScalarQueryParameter,
+    SourceFormat,
     WriteDisposition,
 )
 from google.cloud.bigquery.table import Row as BigQueryRow
@@ -32,6 +36,8 @@ from sqlspec.typing import DictRow, ModelDTOT, RowT
 from sqlspec.utils.serializers import to_json

 if TYPE_CHECKING:
+    from pathlib import Path
+
     from sqlglot.dialects.dialect import DialectType


@@ -258,23 +264,17 @@ class BigQueryDriver(
                 param_value,
                 type(param_value),
             )
-        # Let BigQuery generate the job ID to avoid collisions
-        # This is the recommended approach for production code and works better with emulators
-        logger.warning("About to send to BigQuery - SQL: %r", sql_str)
-        logger.warning("Query parameters in job config: %r", final_job_config.query_parameters)
         query_job = conn.query(sql_str, job_config=final_job_config)

         # Get the auto-generated job ID for callbacks
         if self.on_job_start and query_job.job_id:
-            try:
+            with contextlib.suppress(Exception):
+                # Callback errors should not interfere with job execution
                 self.on_job_start(query_job.job_id)
-            except Exception as e:
-                logger.warning("Job start callback failed: %s", str(e), extra={"adapter": "bigquery"})
         if self.on_job_complete and query_job.job_id:
-            try:
+            with contextlib.suppress(Exception):
+                # Callback errors should not interfere with job execution
                 self.on_job_complete(query_job.job_id, query_job)
-            except Exception as e:
-                logger.warning("Job complete callback failed: %s", str(e), extra={"adapter": "bigquery"})

         return query_job

@@ -529,28 +529,120 @@ class BigQueryDriver(
     # BigQuery Native Export Support
     # ============================================================================

-    def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
-        """BigQuery native export implementation.
+    def _export_native(self, query: str, destination_uri: "Union[str, Path]", format: str, **options: Any) -> int:
+        """BigQuery native export implementation with automatic GCS staging.

-        For
-
+        For GCS URIs, uses direct export. For other locations, automatically stages
+        through a temporary GCS location and transfers to the final destination.

         Args:
             query: SQL query to execute
-            destination_uri: Destination URI (local file path
+            destination_uri: Destination URI (local file path, gs:// URI, or Path object)
             format: Export format (parquet, csv, json, avro)
-            **options: Additional export options
+            **options: Additional export options including 'gcs_staging_bucket'

         Returns:
             Number of rows exported

         Raises:
-            NotImplementedError:
+            NotImplementedError: If no staging bucket is configured for non-GCS destinations
         """
-
-
-
-
+        destination_str = str(destination_uri)
+
+        # If it's already a GCS URI, use direct export
+        if destination_str.startswith("gs://"):
+            return self._export_to_gcs_native(query, destination_str, format, **options)
+
+        # For non-GCS destinations, check if staging is configured
+        staging_bucket = options.get("gcs_staging_bucket") or getattr(self.config, "gcs_staging_bucket", None)
+        if not staging_bucket:
+            # Fall back to fetch + write for non-GCS destinations without staging
+            msg = "BigQuery native export requires GCS staging bucket for non-GCS destinations"
+            raise NotImplementedError(msg)
+
+        # Generate temporary GCS path
+        from datetime import timezone
+
+        timestamp = datetime.datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
+        temp_filename = f"bigquery_export_{timestamp}_{uuid.uuid4().hex[:8]}.{format}"
+        temp_gcs_uri = f"gs://{staging_bucket}/temp_exports/{temp_filename}"
+
+        try:
+            # Export to temporary GCS location
+            rows_exported = self._export_to_gcs_native(query, temp_gcs_uri, format, **options)
+
+            # Transfer from GCS to final destination using storage backend
+            backend, path = self._resolve_backend_and_path(destination_str)
+            gcs_backend = self._get_storage_backend(temp_gcs_uri)
+
+            # Download from GCS and upload to final destination
+            data = gcs_backend.read_bytes(temp_gcs_uri)
+            backend.write_bytes(path, data)
+
+            return rows_exported
+        finally:
+            # Clean up temporary file
+            try:
+                gcs_backend = self._get_storage_backend(temp_gcs_uri)
+                gcs_backend.delete(temp_gcs_uri)
+            except Exception as e:
+                logger.warning("Failed to clean up temporary GCS file %s: %s", temp_gcs_uri, e)
+
+    def _export_to_gcs_native(self, query: str, gcs_uri: str, format: str, **options: Any) -> int:
+        """Direct BigQuery export to GCS.
+
+        Args:
+            query: SQL query to execute
+            gcs_uri: GCS destination URI (must start with gs://)
+            format: Export format (parquet, csv, json, avro)
+            **options: Additional export options
+
+        Returns:
+            Number of rows exported
+        """
+        # First, run the query and store results in a temporary table
+
+        temp_table_id = f"temp_export_{uuid.uuid4().hex[:8]}"
+        dataset_id = getattr(self.connection, "default_dataset", None) or options.get("dataset", "temp")
+
+        # Create a temporary table with query results
+        query_with_table = f"CREATE OR REPLACE TABLE `{dataset_id}.{temp_table_id}` AS {query}"
+        create_job = self._run_query_job(query_with_table, [])
+        create_job.result()
+
+        # Get row count
+        count_query = f"SELECT COUNT(*) as cnt FROM `{dataset_id}.{temp_table_id}`"
+        count_job = self._run_query_job(count_query, [])
+        count_result = list(count_job.result())
+        row_count = count_result[0]["cnt"] if count_result else 0
+
+        try:
+            # Configure extract job
+            extract_config = ExtractJobConfig(**options)  # type: ignore[no-untyped-call]
+
+            # Set format
+            format_mapping = {
+                "parquet": SourceFormat.PARQUET,
+                "csv": SourceFormat.CSV,
+                "json": SourceFormat.NEWLINE_DELIMITED_JSON,
+                "avro": SourceFormat.AVRO,
+            }
+            extract_config.destination_format = format_mapping.get(format, SourceFormat.PARQUET)
+
+            # Extract table to GCS
+            table_ref = self.connection.dataset(dataset_id).table(temp_table_id)
+            extract_job = self.connection.extract_table(table_ref, gcs_uri, job_config=extract_config)
+            extract_job.result()
+
+            return row_count
+        finally:
+            # Clean up temporary table
+            try:
+                delete_query = f"DROP TABLE IF EXISTS `{dataset_id}.{temp_table_id}`"
+                delete_job = self._run_query_job(delete_query, [])
+                delete_job.result()
+            except Exception as e:
+                logger.warning("Failed to clean up temporary table %s: %s", temp_table_id, e)

     # ============================================================================
     # BigQuery Native Arrow Support
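For context: for non-GCS destinations the new export path stages results in a temporary gs:// object (timestamp plus a short uuid suffix under the configured gcs_staging_bucket), copies the bytes to the final destination via the storage backend, and deletes the staging object. The GCS leg is an ordinary BigQuery extract job; below is a minimal sketch of that step with the google-cloud-bigquery client, where the project, dataset, table, and bucket names are placeholders and credentials come from Application Default Credentials.

from google.cloud import bigquery  # assumes google-cloud-bigquery is installed


def extract_table_to_gcs(project: str, dataset: str, table: str, gcs_uri: str, fmt: str = "parquet") -> None:
    """Run a BigQuery extract job that writes a table to a gs:// destination."""
    client = bigquery.Client(project=project)
    job_config = bigquery.ExtractJobConfig()
    job_config.destination_format = {
        "parquet": bigquery.SourceFormat.PARQUET,
        "csv": bigquery.SourceFormat.CSV,
        "json": bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
        "avro": bigquery.SourceFormat.AVRO,
    }.get(fmt, bigquery.SourceFormat.PARQUET)
    extract_job = client.extract_table(f"{project}.{dataset}.{table}", gcs_uri, job_config=job_config)
    extract_job.result()  # block until the export job finishes


# extract_table_to_gcs("my-project", "analytics", "events", "gs://my-bucket/exports/events.parquet")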
--- a/sqlspec/adapters/duckdb/driver.py
+++ b/sqlspec/adapters/duckdb/driver.py
@@ -2,6 +2,7 @@ import contextlib
 import uuid
 from collections.abc import Generator
 from contextlib import contextmanager
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast

 from duckdb import DuckDBPyConnection
@@ -251,7 +252,7 @@ class DuckDBDriver(
             return True
         return False

-    def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
+    def _export_native(self, query: str, destination_uri: Union[str, Path], format: str, **options: Any) -> int:
         conn = self._connection(None)
         copy_options: list[str] = []

@@ -283,19 +284,21 @@ class DuckDBDriver(
             raise ValueError(msg)

         options_str = f"({', '.join(copy_options)})" if copy_options else ""
-        copy_sql = f"COPY ({query}) TO '{destination_uri}' {options_str}"
+        copy_sql = f"COPY ({query}) TO '{destination_uri!s}' {options_str}"
         result_rel = conn.execute(copy_sql)
         result = result_rel.fetchone() if result_rel else None
         return result[0] if result else 0

-    def _import_native(
+    def _import_native(
+        self, source_uri: Union[str, Path], table_name: str, format: str, mode: str, **options: Any
+    ) -> int:
         conn = self._connection(None)
         if format == "parquet":
-            read_func = f"read_parquet('{source_uri}')"
+            read_func = f"read_parquet('{source_uri!s}')"
         elif format == "csv":
-            read_func = f"read_csv_auto('{source_uri}')"
+            read_func = f"read_csv_auto('{source_uri!s}')"
         elif format == "json":
-            read_func = f"read_json_auto('{source_uri}')"
+            read_func = f"read_json_auto('{source_uri!s}')"
         else:
             msg = f"Unsupported format for DuckDB native import: {format}"
             raise ValueError(msg)
@@ -320,16 +323,16 @@ class DuckDBDriver(
         return int(count_result[0]) if count_result else 0

     def _read_parquet_native(
-        self, source_uri: str, columns: Optional[list[str]] = None, **options: Any
+        self, source_uri: Union[str, Path], columns: Optional[list[str]] = None, **options: Any
     ) -> "SQLResult[dict[str, Any]]":
         conn = self._connection(None)
         if isinstance(source_uri, list):
             file_list = "[" + ", ".join(f"'{f}'" for f in source_uri) + "]"
             read_func = f"read_parquet({file_list})"
-        elif "*" in source_uri or "?" in source_uri:
-            read_func = f"read_parquet('{source_uri}')"
+        elif "*" in str(source_uri) or "?" in str(source_uri):
+            read_func = f"read_parquet('{source_uri!s}')"
         else:
-            read_func = f"read_parquet('{source_uri}')"
+            read_func = f"read_parquet('{source_uri!s}')"

         column_list = ", ".join(columns) if columns else "*"
         query = f"SELECT {column_list} FROM {read_func}"
@@ -353,7 +356,9 @@ class DuckDBDriver(
             statement=SQL(query), data=rows, column_names=column_names, rows_affected=num_rows, operation_type="SELECT"
         )

-    def _write_parquet_native(
+    def _write_parquet_native(
+        self, data: Union[str, "ArrowTable"], destination_uri: Union[str, Path], **options: Any
+    ) -> None:
         conn = self._connection(None)
         copy_options: list[str] = ["FORMAT PARQUET"]
         if "compression" in options:
@@ -364,13 +369,13 @@ class DuckDBDriver(
         options_str = f"({', '.join(copy_options)})"

         if isinstance(data, str):
-            copy_sql = f"COPY ({data}) TO '{destination_uri}' {options_str}"
+            copy_sql = f"COPY ({data}) TO '{destination_uri!s}' {options_str}"
             conn.execute(copy_sql)
         else:
             temp_name = f"_arrow_data_{uuid.uuid4().hex[:8]}"
             conn.register(temp_name, data)
             try:
-                copy_sql = f"COPY {temp_name} TO '{destination_uri}' {options_str}"
+                copy_sql = f"COPY {temp_name} TO '{destination_uri!s}' {options_str}"
                 conn.execute(copy_sql)
             finally:
                 with contextlib.suppress(Exception):
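For context: the DuckDB changes widen the URI parameters to accept pathlib.Path and use an explicit !s conversion when the value is interpolated into COPY and read_* SQL. A small standalone sketch of the COPY (...) TO pattern the driver builds, run against an in-memory DuckDB connection; the output path is illustrative.

from pathlib import Path

import duckdb


def export_query_to_parquet(conn: duckdb.DuckDBPyConnection, query: str, destination: Path) -> int:
    """COPY a query's result set to a Parquet file and return the row count DuckDB reports."""
    copy_sql = f"COPY ({query}) TO '{destination!s}' (FORMAT PARQUET)"
    result = conn.execute(copy_sql).fetchone()
    return result[0] if result else 0


conn = duckdb.connect()  # in-memory database
conn.execute("CREATE TABLE t AS SELECT * FROM range(5) AS r(i)")
print(export_query_to_parquet(conn, "SELECT * FROM t", Path("/tmp/t.parquet")))  # 5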
--- a/sqlspec/adapters/psycopg/config.py
+++ b/sqlspec/adapters/psycopg/config.py
@@ -304,7 +304,7 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
             if conninfo:
                 # If conninfo is provided, use it directly
                 # Don't pass kwargs when using conninfo string
-                pool = ConnectionPool(conninfo, **pool_params)
+                pool = ConnectionPool(conninfo, open=True, **pool_params)
             else:
                 # Otherwise, pass connection parameters via kwargs
                 # Remove any non-connection parameters
@@ -312,7 +312,7 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
                 all_config.pop("row_factory", None)
                 # Remove pool-specific settings that may have been left
                 all_config.pop("kwargs", None)
-                pool = ConnectionPool("", kwargs=all_config, **pool_params)
+                pool = ConnectionPool("", kwargs=all_config, open=True, **pool_params)

             logger.info("Psycopg connection pool created successfully", extra={"adapter": "psycopg"})
         except Exception as e:
@@ -328,11 +328,19 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
         logger.info("Closing Psycopg connection pool", extra={"adapter": "psycopg"})

         try:
+            # Set a flag to prevent __del__ from running cleanup
+            # This avoids the "cannot join current thread" error during garbage collection
+            if hasattr(self.pool_instance, "_closed"):
+                self.pool_instance._closed = True
+
             self.pool_instance.close()
             logger.info("Psycopg connection pool closed successfully", extra={"adapter": "psycopg"})
         except Exception as e:
             logger.exception("Failed to close Psycopg connection pool", extra={"adapter": "psycopg", "error": str(e)})
             raise
+        finally:
+            # Clear the reference to help garbage collection
+            self.pool_instance = None

     def create_connection(self) -> "PsycopgSyncConnection":
         """Create a single connection (not from pool).
@@ -657,7 +665,16 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
         if not self.pool_instance:
             return

-
+        try:
+            # Set a flag to prevent __del__ from running cleanup
+            # This avoids the "cannot join current thread" error during garbage collection
+            if hasattr(self.pool_instance, "_closed"):
+                self.pool_instance._closed = True
+
+            await self.pool_instance.close()
+        finally:
+            # Clear the reference to help garbage collection
+            self.pool_instance = None

     async def create_connection(self) -> "PsycopgAsyncConnection":  # pyright: ignore
         """Create a single async connection (not from pool).
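For context: the pool is now opened explicitly at construction (open=True), and the instance reference is cleared after close() so garbage collection at interpreter shutdown does not try to tear the pool down a second time. A minimal sketch of the explicit open/close lifecycle with psycopg_pool, using a placeholder DSN.

from psycopg_pool import ConnectionPool  # assumes psycopg[pool] is installed

pool = ConnectionPool("postgresql://user:pass@localhost:5432/db", min_size=1, max_size=4, open=True)
try:
    with pool.connection() as conn:  # borrow a connection, returned to the pool on exit
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
            print(cur.fetchone())
finally:
    pool.close()  # close explicitly rather than relying on __del__; this joins the pool's worker threads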
--- a/sqlspec/adapters/psycopg/driver.py
+++ b/sqlspec/adapters/psycopg/driver.py
@@ -20,6 +20,7 @@ from sqlspec.driver.mixins import (
     ToSchemaMixin,
     TypeCoercionMixin,
 )
+from sqlspec.exceptions import PipelineExecutionError
 from sqlspec.statement.parameters import ParameterStyle
 from sqlspec.statement.result import ArrowResult, DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
 from sqlspec.statement.splitter import split_sql_script
@@ -113,6 +114,12 @@ class PsycopgSyncDriver(
         **kwargs: Any,
     ) -> Union[SelectResultDict, DMLResultDict]:
         conn = self._connection(connection)
+
+        # Check if this is a COPY command
+        sql_upper = sql.strip().upper()
+        if sql_upper.startswith("COPY") and ("FROM STDIN" in sql_upper or "TO STDOUT" in sql_upper):
+            return self._handle_copy_command(sql, parameters, conn)
+
         with conn.cursor() as cursor:
             cursor.execute(cast("Query", sql), parameters)
             # Check if the statement returns rows by checking cursor.description
@@ -123,6 +130,38 @@ class PsycopgSyncDriver(
                 return {"data": fetched_data, "column_names": column_names, "rows_affected": len(fetched_data)}
             return {"rows_affected": cursor.rowcount, "status_message": cursor.statusmessage or "OK"}

+    def _handle_copy_command(
+        self, sql: str, data: Any, connection: PsycopgSyncConnection
+    ) -> Union[SelectResultDict, DMLResultDict]:
+        """Handle PostgreSQL COPY commands using cursor.copy() method."""
+        sql_upper = sql.strip().upper()
+
+        with connection.cursor() as cursor:
+            if "TO STDOUT" in sql_upper:
+                # COPY TO STDOUT - read data from the database
+                output_data: list[Any] = []
+                with cursor.copy(cast("Query", sql)) as copy:
+                    output_data.extend(row for row in copy)
+
+                # Return as SelectResultDict with the raw COPY data
+                return {"data": output_data, "column_names": ["copy_data"], "rows_affected": len(output_data)}
+            # COPY FROM STDIN - write data to the database
+            with cursor.copy(cast("Query", sql)) as copy:
+                if data:
+                    # If data is provided, write it to the copy stream
+                    if isinstance(data, (str, bytes)):
+                        copy.write(data)
+                    elif isinstance(data, (list, tuple)):
+                        # If data is a list/tuple of rows, write each row
+                        for row in data:
+                            copy.write_row(row)
+                    else:
+                        # Single row
+                        copy.write_row(data)
+
+            # For COPY operations, cursor.rowcount contains the number of rows affected
+            return {"rows_affected": cursor.rowcount or -1, "status_message": cursor.statusmessage or "COPY COMPLETE"}
+
     def _execute_many(
         self, sql: str, param_list: Any, connection: Optional[PsycopgSyncConnection] = None, **kwargs: Any
     ) -> DMLResultDict:
@@ -242,7 +281,6 @@ class PsycopgSyncDriver(
         Returns:
             List of SQLResult objects from all operations
         """
-        from sqlspec.exceptions import PipelineExecutionError

         results = []
         connection = self._connection()
@@ -489,6 +527,12 @@ class PsycopgAsyncDriver(
         **kwargs: Any,
     ) -> Union[SelectResultDict, DMLResultDict]:
         conn = self._connection(connection)
+
+        # Check if this is a COPY command
+        sql_upper = sql.strip().upper()
+        if sql_upper.startswith("COPY") and ("FROM STDIN" in sql_upper or "TO STDOUT" in sql_upper):
+            return await self._handle_copy_command(sql, parameters, conn)
+
         async with conn.cursor() as cursor:
             await cursor.execute(cast("Query", sql), parameters)

@@ -510,6 +554,38 @@ class PsycopgAsyncDriver(
             }
             return dml_result

+    async def _handle_copy_command(
+        self, sql: str, data: Any, connection: PsycopgAsyncConnection
+    ) -> Union[SelectResultDict, DMLResultDict]:
+        """Handle PostgreSQL COPY commands using cursor.copy() method."""
+        sql_upper = sql.strip().upper()
+
+        async with connection.cursor() as cursor:
+            if "TO STDOUT" in sql_upper:
+                # COPY TO STDOUT - read data from the database
+                output_data = []
+                async with cursor.copy(cast("Query", sql)) as copy:
+                    output_data.extend([row async for row in copy])
+
+                # Return as SelectResultDict with the raw COPY data
+                return {"data": output_data, "column_names": ["copy_data"], "rows_affected": len(output_data)}
+            # COPY FROM STDIN - write data to the database
+            async with cursor.copy(cast("Query", sql)) as copy:
+                if data:
+                    # If data is provided, write it to the copy stream
+                    if isinstance(data, (str, bytes)):
+                        await copy.write(data)
+                    elif isinstance(data, (list, tuple)):
+                        # If data is a list/tuple of rows, write each row
+                        for row in data:
+                            await copy.write_row(row)
+                    else:
+                        # Single row
+                        await copy.write_row(data)
+
+            # For COPY operations, cursor.rowcount contains the number of rows affected
+            return {"rows_affected": cursor.rowcount or -1, "status_message": cursor.statusmessage or "COPY COMPLETE"}
+
     async def _execute_many(
         self, sql: str, param_list: Any, connection: Optional[PsycopgAsyncConnection] = None, **kwargs: Any
     ) -> DMLResultDict:
@@ -595,6 +671,11 @@ class PsycopgAsyncDriver(
         if statement.expression:
             operation_type = str(statement.expression.key).upper()

+        # Handle case where we got a SelectResultDict but it was routed here due to parsing being disabled
+        if is_dict_with_field(result, "data") and is_dict_with_field(result, "column_names"):
+            # This is actually a SELECT result, wrap it properly
+            return await self._wrap_select_result(statement, cast("SelectResultDict", result), **kwargs)
+
         if is_dict_with_field(result, "statements_executed"):
             return SQLResult[RowT](
                 statement=statement,
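For context: COPY ... FROM STDIN / TO STDOUT statements are now routed to cursor.copy() instead of cursor.execute(). A standalone sketch of the psycopg 3 copy protocol the new handlers rely on, with a placeholder DSN and a temporary table.

import psycopg  # psycopg 3

with psycopg.connect("postgresql://user:pass@localhost:5432/db") as conn:
    with conn.cursor() as cur:
        cur.execute("CREATE TEMP TABLE demo (id int, name text)")

        # COPY FROM STDIN: stream rows into the table through the copy protocol
        with cur.copy("COPY demo (id, name) FROM STDIN") as copy:
            for row in [(1, "a"), (2, "b")]:
                copy.write_row(row)

        # COPY TO STDOUT: iterate the raw COPY data coming back from the server
        with cur.copy("COPY demo TO STDOUT") as copy:
            for chunk in copy:
                print(bytes(chunk))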
--- a/sqlspec/adapters/sqlite/driver.py
+++ b/sqlspec/adapters/sqlite/driver.py
@@ -197,8 +197,41 @@ class SqliteDriver(
         result: ScriptResultDict = {"statements_executed": -1, "status_message": "SCRIPT EXECUTED"}
         return result

+    def _ingest_arrow_table(self, table: Any, table_name: str, mode: str = "create", **options: Any) -> int:
+        """SQLite-specific Arrow table ingestion using CSV conversion.
+
+        Since SQLite only supports CSV bulk loading, we convert the Arrow table
+        to CSV format first using the storage backend for efficient operations.
+        """
+        import io
+        import tempfile
+
+        import pyarrow.csv as pa_csv
+
+        # Convert Arrow table to CSV in memory
+        csv_buffer = io.BytesIO()
+        pa_csv.write_csv(table, csv_buffer)
+        csv_content = csv_buffer.getvalue()
+
+        # Create a temporary file path
+        temp_filename = f"sqlspec_temp_{table_name}_{id(self)}.csv"
+        temp_path = Path(tempfile.gettempdir()) / temp_filename
+
+        # Use storage backend to write the CSV content
+        backend = self._get_storage_backend(temp_path)
+        backend.write_bytes(str(temp_path), csv_content)
+
+        try:
+            # Use SQLite's CSV bulk load
+            return self._bulk_load_file(temp_path, table_name, "csv", mode, **options)
+        finally:
+            # Clean up using storage backend
+            with contextlib.suppress(Exception):
+                # Best effort cleanup
+                backend.delete(str(temp_path))
+
     def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
-        """Database-specific bulk load implementation."""
+        """Database-specific bulk load implementation using storage backend."""
         if format != "csv":
             msg = f"SQLite driver only supports CSV for bulk loading, not {format}."
             raise NotImplementedError(msg)
@@ -208,16 +241,23 @@ class SqliteDriver(
             if mode == "replace":
                 cursor.execute(f"DELETE FROM {table_name}")

-
-
-
-
-
+            # Use storage backend to read the file
+            backend = self._get_storage_backend(file_path)
+            content = backend.read_text(str(file_path), encoding="utf-8")
+
+            # Parse CSV content
+            import io
+
+            csv_file = io.StringIO(content)
+            reader = csv.reader(csv_file, **options)
+            header = next(reader)  # Skip header
+            placeholders = ", ".join("?" for _ in header)
+            sql = f"INSERT INTO {table_name} VALUES ({placeholders})"

-
-
-
-
+            # executemany is efficient for bulk inserts
+            data_iter = list(reader)  # Read all data into memory
+            cursor.executemany(sql, data_iter)
+            return cursor.rowcount

     def _wrap_select_result(
         self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any