sqlspec 0.16.1__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +11 -1
- sqlspec/_sql.py +18 -412
- sqlspec/adapters/aiosqlite/__init__.py +11 -1
- sqlspec/adapters/aiosqlite/config.py +137 -165
- sqlspec/adapters/aiosqlite/driver.py +21 -10
- sqlspec/adapters/aiosqlite/pool.py +492 -0
- sqlspec/adapters/duckdb/__init__.py +2 -0
- sqlspec/adapters/duckdb/config.py +11 -235
- sqlspec/adapters/duckdb/pool.py +243 -0
- sqlspec/adapters/sqlite/__init__.py +2 -0
- sqlspec/adapters/sqlite/config.py +4 -115
- sqlspec/adapters/sqlite/pool.py +140 -0
- sqlspec/base.py +147 -26
- sqlspec/builder/__init__.py +6 -0
- sqlspec/builder/_insert.py +177 -12
- sqlspec/builder/_parsing_utils.py +53 -2
- sqlspec/builder/mixins/_join_operations.py +148 -7
- sqlspec/builder/mixins/_merge_operations.py +102 -16
- sqlspec/builder/mixins/_select_operations.py +311 -6
- sqlspec/builder/mixins/_update_operations.py +49 -34
- sqlspec/builder/mixins/_where_clause.py +85 -13
- sqlspec/core/compiler.py +7 -5
- sqlspec/driver/_common.py +9 -1
- sqlspec/loader.py +27 -54
- sqlspec/storage/registry.py +2 -2
- sqlspec/typing.py +53 -99
- {sqlspec-0.16.1.dist-info → sqlspec-0.17.0.dist-info}/METADATA +1 -1
- {sqlspec-0.16.1.dist-info → sqlspec-0.17.0.dist-info}/RECORD +32 -29
- {sqlspec-0.16.1.dist-info → sqlspec-0.17.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.16.1.dist-info → sqlspec-0.17.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.16.1.dist-info → sqlspec-0.17.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.16.1.dist-info → sqlspec-0.17.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""SQLite database configuration with thread-local connections."""
|
|
2
|
+
|
|
3
|
+
import sqlite3
|
|
4
|
+
import threading
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast
|
|
7
|
+
|
|
8
|
+
from typing_extensions import NotRequired
|
|
9
|
+
|
|
10
|
+
from sqlspec.adapters.sqlite._types import SqliteConnection
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from collections.abc import Generator
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SqliteConnectionParams(TypedDict, total=False):
    """SQLite connection parameters.

    Field names match the keyword arguments of ``sqlite3.connect``. Because the
    class is declared with ``total=False`` every key is already optional; the
    ``NotRequired[...]`` wrappers state that explicitly per field.
    """

    database: NotRequired[str]
    timeout: NotRequired[float]
    detect_types: NotRequired[int]
    # Optional[...]: None is an accepted value, not just a missing key.
    isolation_level: "NotRequired[Optional[str]]"
    check_same_thread: NotRequired[bool]
    factory: "NotRequired[Optional[type[SqliteConnection]]]"
    cached_statements: NotRequired[int]
    uri: NotRequired[bool]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Public export list for this module.
# NOTE(review): SqliteConnectionParams is defined in this module but not listed here — confirm that is intended.
__all__ = ("SqliteConnectionPool",)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SqliteConnectionPool:
    """Thread-local connection manager for SQLite.

    SQLite connections aren't thread-safe, so each thread lazily receives its
    own connection, cached in ``threading.local`` storage. This is simpler and
    more efficient than a traditional pool for SQLite's constraints.

    Note:
        ``close()`` only closes the connection owned by the *calling* thread.
        Connections created by other threads are not tracked by this object.
    """

    __slots__ = ("_connection_parameters", "_enable_optimizations", "_thread_local")

    def __init__(
        self,
        connection_parameters: "dict[str, Any]",
        enable_optimizations: bool = True,
        **kwargs: Any,  # Accept and ignore pool parameters for compatibility
    ) -> None:
        """Initialize the thread-local connection manager.

        Args:
            connection_parameters: Keyword arguments forwarded to ``sqlite3.connect``.
            enable_optimizations: Whether to apply performance PRAGMAs to new connections.
            **kwargs: Ignored pool parameters, accepted for compatibility.
        """
        self._connection_parameters = connection_parameters
        self._thread_local = threading.local()
        self._enable_optimizations = enable_optimizations

    def _is_memory_database(self) -> bool:
        """Return True when the configured database is an in-memory database.

        ``database`` may be a ``str`` or a path-like object (``sqlite3.connect``
        accepts both), so it is normalized with ``str()`` before inspection.
        """
        database = str(self._connection_parameters.get("database", ":memory:"))
        if database == ":memory:" or database.startswith("file::memory:"):
            return True
        # Named in-memory URIs such as "file:name?mode=memory" (only meaningful with uri=True).
        return bool(self._connection_parameters.get("uri")) and "mode=memory" in database

    def _create_connection(self) -> "SqliteConnection":
        """Open a new SQLite connection and apply the optional PRAGMAs.

        Returns:
            A freshly opened connection.
        """
        connection = sqlite3.connect(**self._connection_parameters)

        if self._enable_optimizations:
            if not self._is_memory_database():
                # WAL mode doesn't work with in-memory databases
                connection.execute("PRAGMA journal_mode = WAL")
                # sqlite3.connect(timeout=...) already installs a busy handler. Only
                # default it to 5s when the caller did not pick a timeout; otherwise
                # this PRAGMA would silently override the caller's value.
                if "timeout" not in self._connection_parameters:
                    connection.execute("PRAGMA busy_timeout = 5000")
                connection.execute("PRAGMA optimize")
            # These work for all database types
            connection.execute("PRAGMA foreign_keys = ON")
            connection.execute("PRAGMA synchronous = NORMAL")

        return connection  # type: ignore[no-any-return]

    def _get_thread_connection(self) -> "SqliteConnection":
        """Return the calling thread's connection, creating it on first use."""
        try:
            return cast("SqliteConnection", self._thread_local.connection)
        except AttributeError:
            # Connection doesn't exist for this thread yet
            connection = self._create_connection()
            self._thread_local.connection = connection
            return connection

    def _close_thread_connection(self) -> None:
        """Close the calling thread's connection, if it has one."""
        connection = getattr(self._thread_local, "connection", None)
        if connection is None:
            return
        # Drop the cached reference first so a failing close() cannot leave a
        # half-closed connection behind for the next acquire().
        del self._thread_local.connection
        connection.close()

    @contextmanager
    def get_connection(self) -> "Generator[SqliteConnection, None, None]":
        """Get a thread-local connection.

        The connection is not closed when the context exits; it stays cached
        for the calling thread.

        Yields:
            SqliteConnection: A thread-local connection.
        """
        yield self._get_thread_connection()

    def close(self) -> None:
        """Close the calling thread's connection if it exists."""
        self._close_thread_connection()

    def acquire(self) -> "SqliteConnection":
        """Acquire a thread-local connection.

        Returns:
            SqliteConnection: A thread-local connection
        """
        return self._get_thread_connection()

    def release(self, connection: "SqliteConnection") -> None:
        """Release a connection (no-op for thread-local connections).

        Args:
            connection: The connection to release (ignored)
        """
        # No-op: thread-local connections are managed per-thread

    # Compatibility methods mirroring a conventional pool interface
    def size(self) -> int:
        """Return 1 if the calling thread currently holds a connection, else 0."""
        return 1 if hasattr(self._thread_local, "connection") else 0

    def checked_out(self) -> int:
        """Get number of checked out connections (always 0)."""
        return 0
|
sqlspec/base.py
CHANGED
|
@@ -26,7 +26,10 @@ from sqlspec.utils.logging import get_logger
|
|
|
26
26
|
|
|
27
27
|
if TYPE_CHECKING:
|
|
28
28
|
from contextlib import AbstractAsyncContextManager, AbstractContextManager
|
|
29
|
+
from pathlib import Path
|
|
29
30
|
|
|
31
|
+
from sqlspec.core.statement import SQL
|
|
32
|
+
from sqlspec.loader import SQLFileLoader
|
|
30
33
|
from sqlspec.typing import ConnectionT, PoolT
|
|
31
34
|
|
|
32
35
|
|
|
@@ -38,54 +41,77 @@ logger = get_logger()
|
|
|
38
41
|
class SQLSpec:
|
|
39
42
|
"""Configuration manager and registry for database connections and pools."""
|
|
40
43
|
|
|
41
|
-
__slots__ = ("
|
|
44
|
+
__slots__ = ("_configs", "_instance_cache_config", "_sql_loader")
|
|
42
45
|
|
|
43
|
-
def __init__(self) -> None:
|
|
46
|
+
def __init__(self, *, loader: "Optional[SQLFileLoader]" = None) -> None:
    """Create an empty configuration registry.

    Args:
        loader: Optional pre-built SQL file loader. When omitted, a default
            loader is created lazily by ``_ensure_sql_loader`` on first use.
    """
    self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {}
    self._instance_cache_config: Optional[CacheConfig] = None
    self._sql_loader: Optional[SQLFileLoader] = loader
    # Only sync pools are closed at interpreter exit; async pools need close_all_pools().
    atexit.register(self._cleanup_sync_pools)
|
|
48
52
|
|
|
49
53
|
@staticmethod
|
|
50
54
|
def _get_config_name(obj: Any) -> str:
|
|
51
55
|
"""Get display name for configuration object."""
|
|
52
56
|
return getattr(obj, "__name__", str(obj))
|
|
53
57
|
|
|
54
|
-
def
|
|
55
|
-
"""Clean up
|
|
58
|
+
def _cleanup_sync_pools(self) -> None:
    """Close synchronous connection pools at exit (registered via ``atexit``).

    Async pools are skipped here; ``close_all_pools()`` handles those. A failure
    in one pool is logged and does not prevent the remaining pools from closing.
    """
    cleaned_count = 0

    for config_type, config in self._configs.items():
        if not config.supports_connection_pooling or config.is_async:
            continue
        try:
            config.close_pool()
        except Exception as e:
            # _get_config_name tolerates keys without __name__, so this handler cannot itself raise.
            logger.warning("Failed to clean up sync pool for config %s: %s", self._get_config_name(config_type), e)
        else:
            cleaned_count += 1

    if cleaned_count > 0:
        logger.debug("Sync pool cleanup completed. Cleaned %d pools.", cleaned_count)
|
|
72
|
+
|
|
73
|
+
async def close_all_pools(self) -> None:
    """Explicitly close all connection pools (async and sync).

    This method should be called before application shutdown for proper cleanup.
    Async pools are closed concurrently first. Every sync pool is then given a
    chance to close, even if an earlier one fails.

    Raises:
        Exception: The first exception raised by a sync pool's ``close_pool()``,
            re-raised after all sync pools have been attempted.
    """
    cleanup_tasks: "list[Coroutine[Any, Any, None]]" = []
    cleanup_names: "list[str]" = []
    sync_configs: "list[tuple[Any, Any]]" = []

    for config_type, config in self._configs.items():
        if not config.supports_connection_pooling:
            continue
        try:
            if config.is_async:
                close_pool_awaitable = config.close_pool()
                if close_pool_awaitable is not None:
                    cleanup_tasks.append(cast("Coroutine[Any, Any, None]", close_pool_awaitable))
                    cleanup_names.append(self._get_config_name(config_type))
            else:
                sync_configs.append((config_type, config))
        except Exception as e:
            logger.warning("Failed to prepare cleanup for config %s: %s", self._get_config_name(config_type), e)

    # Close async pools concurrently. With return_exceptions=True, gather() returns
    # failures as results instead of raising, so they must be inspected to be reported.
    if cleanup_tasks:
        try:
            results = await asyncio.gather(*cleanup_tasks, return_exceptions=True)
        except Exception as e:
            logger.warning("Failed to complete async pool cleanup: %s", e)
        else:
            async_failures = 0
            for name, result in zip(cleanup_names, results):
                if isinstance(result, BaseException):
                    async_failures += 1
                    logger.warning("Failed to close async pool for config %s: %s", name, result)
            logger.debug("Async pool cleanup completed. Cleaned %d pools.", len(cleanup_tasks) - async_failures)

    # Close sync pools. Keep going after a failure, but remember the first error
    # so it still propagates to the caller.
    first_error: "Optional[Exception]" = None
    sync_failures = 0
    for config_type, config in sync_configs:
        try:
            config.close_pool()
        except Exception as e:
            sync_failures += 1
            logger.warning("Failed to close sync pool for config %s: %s", self._get_config_name(config_type), e)
            if first_error is None:
                first_error = e

    if sync_configs:
        logger.debug("Sync pool cleanup completed. Cleaned %d pools.", len(sync_configs) - sync_failures)

    if first_error is not None:
        raise first_error
|
|
86
107
|
|
|
87
|
-
|
|
88
|
-
|
|
108
|
+
async def __aenter__(self) -> "SQLSpec":
    """Enter an ``async with`` block; the registry itself is the bound value."""
    entered = self
    return entered

async def __aexit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
    """Leave an ``async with`` block by closing every registered pool.

    Returns None, so any exception raised inside the block is not suppressed.
    """
    await self.close_all_pools()
|
|
89
115
|
|
|
90
116
|
@overload
|
|
91
117
|
def add_config(self, config: "SyncConfigT") -> "type[SyncConfigT]": # pyright: ignore[reportInvalidTypeVarUse]
|
|
@@ -569,3 +595,98 @@ class SQLSpec:
|
|
|
569
595
|
else current_config.optimized_cache_enabled,
|
|
570
596
|
)
|
|
571
597
|
)
|
|
598
|
+
|
|
599
|
+
# SQL File Loading Integration
|
|
600
|
+
|
|
601
|
+
def _ensure_sql_loader(self) -> "SQLFileLoader":
    """Return the SQL file loader, creating a default instance on first use."""
    current = self._sql_loader
    if current is not None:
        return current
    # Deferred import to avoid a circular import with sqlspec.loader.
    from sqlspec.loader import SQLFileLoader

    current = SQLFileLoader()
    self._sql_loader = current
    return current
|
|
609
|
+
|
|
610
|
+
def load_sql_files(self, *paths: "Union[str, Path]") -> None:
    """Load named SQL statements from files or directories.

    Args:
        *paths: One or more file paths or directory paths to load.
    """
    self._ensure_sql_loader().load_sql(*paths)
    logger.debug("Loaded SQL files: %s", paths)
|
|
619
|
+
|
|
620
|
+
def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) -> None:
    """Register a named SQL statement without reading it from a file.

    Args:
        name: Name for the SQL query.
        sql: Raw SQL content.
        dialect: Optional dialect for the SQL statement.
    """
    self._ensure_sql_loader().add_named_sql(name, sql, dialect)
    logger.debug("Added named SQL: %s", name)
|
|
631
|
+
|
|
632
|
+
def get_sql(self, name: str) -> "SQL":
    """Look up a stored statement by name and return it as a SQL object.

    Args:
        name: Name of the statement (from -- name: in SQL file).
            Hyphens in names are converted to underscores.

    Returns:
        SQL object ready for execution.
    """
    sql_loader = self._ensure_sql_loader()
    statement = sql_loader.get_sql(name)
    return statement
|
|
644
|
+
|
|
645
|
+
def list_sql_queries(self) -> "list[str]":
    """Return the names of all stored queries.

    Returns:
        Sorted list of query names; empty when no loader has been created.
    """
    sql_loader = self._sql_loader
    return [] if sql_loader is None else sql_loader.list_queries()
|
|
654
|
+
|
|
655
|
+
def has_sql_query(self, name: str) -> bool:
    """Report whether a query with the given name is available.

    Args:
        name: Query name to check.

    Returns:
        True if query exists; False when no loader has been created.
    """
    sql_loader = self._sql_loader
    if sql_loader is not None:
        return sql_loader.has_query(name)
    return False
|
|
667
|
+
|
|
668
|
+
def clear_sql_cache(self) -> None:
    """Clear the SQL file cache; does nothing when no loader has been created."""
    sql_loader = self._sql_loader
    if sql_loader is None:
        return
    sql_loader.clear_cache()
    logger.debug("Cleared SQL cache")
|
|
673
|
+
|
|
674
|
+
def reload_sql_files(self) -> None:
    """Prepare for a reload by clearing the loader's cache.

    Note: This only clears the cache; it does not re-read any files. Call
    load_sql_files again afterwards to make the queries available.
    """
    sql_loader = self._sql_loader
    if sql_loader is None:
        return
    # Clearing is the whole "reload"; callers re-load files themselves.
    sql_loader.clear_cache()
    logger.debug("Cleared SQL cache for reload")
|
|
683
|
+
|
|
684
|
+
def get_sql_files(self) -> "list[str]":
    """Return the files the loader has read so far.

    Returns:
        Sorted list of file paths; empty when no loader has been created.
    """
    sql_loader = self._sql_loader
    if sql_loader is not None:
        return sql_loader.list_files()
    return []
|
sqlspec/builder/__init__.py
CHANGED
|
@@ -29,10 +29,13 @@ from sqlspec.builder._merge import Merge
|
|
|
29
29
|
from sqlspec.builder._select import Select
|
|
30
30
|
from sqlspec.builder._update import Update
|
|
31
31
|
from sqlspec.builder.mixins import WhereClauseMixin
|
|
32
|
+
from sqlspec.builder.mixins._join_operations import JoinBuilder
|
|
33
|
+
from sqlspec.builder.mixins._select_operations import Case, SubqueryBuilder, WindowFunctionBuilder
|
|
32
34
|
from sqlspec.exceptions import SQLBuilderError
|
|
33
35
|
|
|
34
36
|
__all__ = (
|
|
35
37
|
"AlterTable",
|
|
38
|
+
"Case",
|
|
36
39
|
"Column",
|
|
37
40
|
"ColumnExpression",
|
|
38
41
|
"CommentOn",
|
|
@@ -50,13 +53,16 @@ __all__ = (
|
|
|
50
53
|
"DropView",
|
|
51
54
|
"FunctionColumn",
|
|
52
55
|
"Insert",
|
|
56
|
+
"JoinBuilder",
|
|
53
57
|
"Merge",
|
|
54
58
|
"QueryBuilder",
|
|
55
59
|
"RenameTable",
|
|
56
60
|
"SQLBuilderError",
|
|
57
61
|
"SafeQuery",
|
|
58
62
|
"Select",
|
|
63
|
+
"SubqueryBuilder",
|
|
59
64
|
"Truncate",
|
|
60
65
|
"Update",
|
|
61
66
|
"WhereClauseMixin",
|
|
67
|
+
"WindowFunctionBuilder",
|
|
62
68
|
)
|
sqlspec/builder/_insert.py
CHANGED
|
@@ -142,6 +142,29 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
|
|
|
142
142
|
for i, value in enumerate(values):
|
|
143
143
|
if isinstance(value, exp.Expression):
|
|
144
144
|
value_placeholders.append(value)
|
|
145
|
+
elif hasattr(value, "expression") and hasattr(value, "sql"):
|
|
146
|
+
# Handle SQL objects (from sql.raw with parameters)
|
|
147
|
+
expression = getattr(value, "expression", None)
|
|
148
|
+
if expression is not None and isinstance(expression, exp.Expression):
|
|
149
|
+
# Merge parameters from SQL object into builder
|
|
150
|
+
if hasattr(value, "parameters"):
|
|
151
|
+
sql_parameters = getattr(value, "parameters", {})
|
|
152
|
+
for param_name, param_value in sql_parameters.items():
|
|
153
|
+
self.add_parameter(param_value, name=param_name)
|
|
154
|
+
value_placeholders.append(expression)
|
|
155
|
+
else:
|
|
156
|
+
# If expression is None, fall back to parsing the raw SQL
|
|
157
|
+
sql_text = getattr(value, "sql", "")
|
|
158
|
+
# Merge parameters even when parsing raw SQL
|
|
159
|
+
if hasattr(value, "parameters"):
|
|
160
|
+
sql_parameters = getattr(value, "parameters", {})
|
|
161
|
+
for param_name, param_value in sql_parameters.items():
|
|
162
|
+
self.add_parameter(param_value, name=param_name)
|
|
163
|
+
# Check if sql_text is callable (like Expression.sql method)
|
|
164
|
+
if callable(sql_text):
|
|
165
|
+
sql_text = str(value)
|
|
166
|
+
value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
|
|
167
|
+
value_placeholders.append(value_expr)
|
|
145
168
|
else:
|
|
146
169
|
if self._columns and i < len(self._columns):
|
|
147
170
|
column_str = str(self._columns[i])
|
|
@@ -228,29 +251,171 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
|
|
|
228
251
|
|
|
229
252
|
return self
|
|
230
253
|
|
|
231
|
-
def
|
|
232
|
-
"""Adds an ON CONFLICT
|
|
254
|
+
def on_conflict(self, *columns: str) -> "ConflictBuilder":
    """Start an ON CONFLICT clause targeting the given columns.

    Args:
        *columns: Column names that define the conflict. If no columns provided,
            creates an ON CONFLICT without specific columns (catches all conflicts).

    Returns:
        A ConflictBuilder instance for chaining conflict resolution methods.

    Example:
        ```python
        # ON CONFLICT (id) DO NOTHING
        sql.insert("users").values(id=1, name="John").on_conflict(
            "id"
        ).do_nothing()

        # ON CONFLICT (email, username) DO UPDATE SET updated_at = NOW()
        sql.insert("users").values(...).on_conflict(
            "email", "username"
        ).do_update(updated_at=sql.raw("NOW()"))

        # ON CONFLICT DO NOTHING (catches all conflicts)
        sql.insert("users").values(...).on_conflict().do_nothing()
        ```
    """
    builder = ConflictBuilder(self, columns)
    return builder
|
|
281
|
+
|
|
282
|
+
def on_conflict_do_nothing(self, *columns: str) -> "Insert":
    """Shortcut for ``on_conflict(*columns).do_nothing()``.

    Args:
        *columns: Column names that define the conflict. If no columns provided,
            creates an ON CONFLICT without specific columns.

    Returns:
        The current builder instance for method chaining.

    Note:
        This is a convenience method. For more control, use on_conflict().do_nothing().
    """
    conflict = self.on_conflict(*columns)
    return conflict.do_nothing()
|
|
246
296
|
|
|
247
|
-
def on_duplicate_key_update(self, **
|
|
248
|
-
"""Adds
|
|
297
|
+
def on_duplicate_key_update(self, **kwargs: Any) -> "Insert":
    """Adds conflict resolution using the ON CONFLICT syntax (cross-database compatible).

    Args:
        **kwargs: Column-value pairs to update on conflict. With no pairs,
            the builder is returned unchanged.

    Returns:
        The current builder instance for method chaining.

    Note:
        This method uses PostgreSQL-style ON CONFLICT syntax, which SQLGlot
        transpiles to the target database's equivalent (e.g. MySQL's
        ON DUPLICATE KEY UPDATE). No conflict-target columns are passed, so
        the clause is emitted without a target.
        NOTE(review): PostgreSQL rejects DO UPDATE without a conflict target;
        confirm the intended dialects, or use on_conflict(*cols).do_update(...).
    """
    if kwargs:
        return self.on_conflict().do_update(**kwargs)
    return self
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
class ConflictBuilder:
    """Builder for ON CONFLICT clauses in INSERT statements.

    This builder provides a fluent interface for constructing conflict resolution
    clauses using PostgreSQL-style syntax, which SQLGlot can transpile to other dialects.
    """

    __slots__ = ("_columns", "_insert_builder")

    def __init__(self, insert_builder: "Insert", columns: tuple[str, ...]) -> None:
        """Initialize ConflictBuilder.

        Args:
            insert_builder: The parent Insert builder
            columns: Column names that define the conflict
        """
        self._insert_builder = insert_builder
        self._columns = columns

    def _conflict_keys(self) -> "list[exp.Identifier]":
        """Identifiers for the conflict target; empty when no columns were given."""
        return [exp.to_identifier(col) for col in self._columns]

    def _merge_parameters(self, value: Any) -> None:
        """Copy named parameters carried by a SQL object into the parent Insert builder."""
        sql_parameters = getattr(value, "parameters", None)
        if not sql_parameters:
            return
        for param_name, param_value in sql_parameters.items():
            self._insert_builder.add_parameter(param_value, name=param_name)

    def _to_value_expression(self, column: str, value: Any) -> "exp.Expression":
        """Convert one SET value into a sqlglot expression.

        Args:
            column: Target column name, used to derive a parameter name for plain values.
            value: A sqlglot expression, a SQL object, or a plain value to bind.

        Returns:
            The expression to place on the right-hand side of ``column = ...``.
        """
        if isinstance(value, exp.Expression):
            # Must be checked before the SQL-object branch. sqlglot expressions also
            # expose ``expression`` and ``sql`` attributes, and taking ``.expression``
            # would keep only the right-hand child (e.g. ``1`` of ``col + 1``).
            return value
        if hasattr(value, "expression") and hasattr(value, "sql"):
            # Handle SQL objects (from sql.raw with parameters)
            self._merge_parameters(value)
            expression = getattr(value, "expression", None)
            if isinstance(expression, exp.Expression):
                return expression
            # If expression is None, fall back to parsing the raw SQL
            sql_text = getattr(value, "sql", "")
            if callable(sql_text):
                sql_text = str(value)
            return exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
        # Create parameter for regular values
        param_name = self._insert_builder._generate_unique_parameter_name(column)
        _, param_name = self._insert_builder.add_parameter(value, name=param_name)
        return exp.Placeholder(this=param_name)

    def do_nothing(self) -> "Insert":
        """Add DO NOTHING conflict resolution.

        Returns:
            The parent Insert builder for method chaining.

        Example:
            ```python
            sql.insert("users").values(id=1, name="John").on_conflict(
                "id"
            ).do_nothing()
            ```
        """
        insert_expr = self._insert_builder._get_insert_expression()
        on_conflict = exp.OnConflict(conflict_keys=self._conflict_keys() or None, action=exp.var("DO NOTHING"))
        insert_expr.set("conflict", on_conflict)
        return self._insert_builder

    def do_update(self, **kwargs: Any) -> "Insert":
        """Add DO UPDATE conflict resolution with SET clauses.

        Args:
            **kwargs: Column-value pairs to update on conflict. Values may be
                sqlglot expressions, SQL objects (their parameters are merged into
                the parent builder), or plain values (bound as parameters).

        Returns:
            The parent Insert builder for method chaining.

        Example:
            ```python
            sql.insert("users").values(id=1, name="John").on_conflict(
                "id"
            ).do_update(
                name="Updated Name", updated_at=sql.raw("NOW()")
            )
            ```
        """
        insert_expr = self._insert_builder._get_insert_expression()

        set_expressions = [
            exp.EQ(this=exp.column(col), expression=self._to_value_expression(col, val))
            for col, val in kwargs.items()
        ]

        on_conflict = exp.OnConflict(
            conflict_keys=self._conflict_keys() or None,
            action=exp.var("DO UPDATE"),
            expressions=set_expressions or None,
        )

        insert_expr.set("conflict", on_conflict)
        return self._insert_builder
|