sqlspec 0.16.2__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +11 -1
- sqlspec/_sql.py +16 -412
- sqlspec/adapters/aiosqlite/__init__.py +11 -1
- sqlspec/adapters/aiosqlite/config.py +137 -165
- sqlspec/adapters/aiosqlite/driver.py +21 -10
- sqlspec/adapters/aiosqlite/pool.py +492 -0
- sqlspec/adapters/duckdb/__init__.py +2 -0
- sqlspec/adapters/duckdb/config.py +11 -235
- sqlspec/adapters/duckdb/pool.py +243 -0
- sqlspec/adapters/sqlite/__init__.py +2 -0
- sqlspec/adapters/sqlite/config.py +4 -115
- sqlspec/adapters/sqlite/pool.py +140 -0
- sqlspec/base.py +147 -26
- sqlspec/builder/__init__.py +6 -0
- sqlspec/builder/_parsing_utils.py +27 -0
- sqlspec/builder/mixins/_join_operations.py +115 -1
- sqlspec/builder/mixins/_select_operations.py +307 -3
- sqlspec/builder/mixins/_where_clause.py +60 -11
- sqlspec/core/compiler.py +7 -5
- sqlspec/driver/_common.py +9 -1
- sqlspec/loader.py +27 -54
- sqlspec/storage/registry.py +2 -2
- sqlspec/typing.py +53 -99
- {sqlspec-0.16.2.dist-info → sqlspec-0.17.0.dist-info}/METADATA +1 -1
- {sqlspec-0.16.2.dist-info → sqlspec-0.17.0.dist-info}/RECORD +29 -26
- {sqlspec-0.16.2.dist-info → sqlspec-0.17.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.16.2.dist-info → sqlspec-0.17.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.16.2.dist-info → sqlspec-0.17.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.16.2.dist-info → sqlspec-0.17.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""SQLite database configuration with thread-local connections."""
|
|
2
|
+
|
|
3
|
+
import sqlite3
|
|
4
|
+
import threading
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast
|
|
7
|
+
|
|
8
|
+
from typing_extensions import NotRequired
|
|
9
|
+
|
|
10
|
+
from sqlspec.adapters.sqlite._types import SqliteConnection
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from collections.abc import Generator
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SqliteConnectionParams(TypedDict, total=False):
|
|
17
|
+
"""SQLite connection parameters."""
|
|
18
|
+
|
|
19
|
+
database: NotRequired[str]
|
|
20
|
+
timeout: NotRequired[float]
|
|
21
|
+
detect_types: NotRequired[int]
|
|
22
|
+
isolation_level: "NotRequired[Optional[str]]"
|
|
23
|
+
check_same_thread: NotRequired[bool]
|
|
24
|
+
factory: "NotRequired[Optional[type[SqliteConnection]]]"
|
|
25
|
+
cached_statements: NotRequired[int]
|
|
26
|
+
uri: NotRequired[bool]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
__all__ = ("SqliteConnectionPool",)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SqliteConnectionPool:
    """Thread-local connection manager for SQLite.

    SQLite connections aren't thread-safe, so we use thread-local storage
    to ensure each thread has its own connection. This is simpler and more
    efficient than a traditional pool for SQLite's constraints.

    Note:
        Connections are tracked per thread only: ``close()`` closes the
        connection owned by the *calling* thread and leaves connections opened
        by other threads untouched.
    """

    __slots__ = ("_connection_parameters", "_enable_optimizations", "_thread_local")

    # Busy timeout (milliseconds) used when no ``timeout`` connection parameter
    # is supplied; equal to sqlite3.connect()'s own 5.0 second default.
    _DEFAULT_BUSY_TIMEOUT_MS = 5000

    def __init__(
        self,
        connection_parameters: "dict[str, Any]",
        enable_optimizations: bool = True,
        **kwargs: Any,  # Accept and ignore pool parameters for compatibility
    ) -> None:
        """Initialize the thread-local connection manager.

        Args:
            connection_parameters: Keyword arguments passed to ``sqlite3.connect``.
            enable_optimizations: Whether to apply performance PRAGMAs.
            **kwargs: Ignored pool parameters for compatibility.
        """
        self._connection_parameters = connection_parameters
        self._thread_local = threading.local()
        self._enable_optimizations = enable_optimizations

    @staticmethod
    def _is_memory_database(database: Any) -> bool:
        """Return True if ``database`` names an in-memory SQLite database.

        Recognises ``":memory:"``, ``"file::memory:..."`` URIs and ``file:`` URIs
        that request ``mode=memory`` (e.g. named shared-cache databases).
        Non-string values such as ``pathlib.Path`` are converted with ``str()``.
        """
        name = str(database)
        if name == ":memory:":
            return True
        if name.startswith("file:"):
            return name.startswith("file::memory:") or "mode=memory" in name
        return False

    def _busy_timeout_ms(self) -> int:
        """Busy timeout in ms, derived from the ``timeout`` parameter (seconds) when given."""
        timeout = self._connection_parameters.get("timeout")
        if timeout is None:
            return self._DEFAULT_BUSY_TIMEOUT_MS
        return int(float(timeout) * 1000)

    def _create_connection(self) -> "SqliteConnection":
        """Create a new SQLite connection, applying optimization PRAGMAs if enabled.

        Raises:
            Whatever ``sqlite3.connect`` or a PRAGMA raises; on PRAGMA failure the
            freshly opened connection is closed before re-raising.
        """
        connection = sqlite3.connect(**self._connection_parameters)
        if not self._enable_optimizations:
            return connection  # type: ignore[no-any-return]

        try:
            database = self._connection_parameters.get("database", ":memory:")
            if not self._is_memory_database(database):
                # WAL mode doesn't work with in-memory databases
                connection.execute("PRAGMA journal_mode = WAL")
                # Honour the caller's ``timeout`` instead of overriding it with a
                # fixed value (sqlite3.connect already set it as busy timeout).
                connection.execute(f"PRAGMA busy_timeout = {self._busy_timeout_ms()}")
                connection.execute("PRAGMA optimize")
            # These work for all database types
            connection.execute("PRAGMA foreign_keys = ON")
            connection.execute("PRAGMA synchronous = NORMAL")
        except Exception:
            # Don't leak the handle if configuring it failed.
            connection.close()
            raise

        return connection  # type: ignore[no-any-return]

    def _get_thread_connection(self) -> "SqliteConnection":
        """Get or create the connection for the current thread."""
        connection = getattr(self._thread_local, "connection", None)
        if connection is None:
            # Connection doesn't exist for this thread yet
            connection = self._create_connection()
            self._thread_local.connection = connection
        return cast("SqliteConnection", connection)

    def _close_thread_connection(self) -> None:
        """Close the connection for the current thread, if one exists."""
        connection = getattr(self._thread_local, "connection", None)
        if connection is None:
            return
        # Forget the handle first so a failing close() cannot leave a dead
        # connection cached for this thread.
        del self._thread_local.connection
        connection.close()

    @contextmanager
    def get_connection(self) -> "Generator[SqliteConnection, None, None]":
        """Get a thread-local connection.

        The connection is not closed when the context exits; it stays cached
        for the current thread.

        Yields:
            SqliteConnection: A thread-local connection.
        """
        yield self._get_thread_connection()

    def close(self) -> None:
        """Close the calling thread's connection if it exists."""
        self._close_thread_connection()

    def acquire(self) -> "SqliteConnection":
        """Acquire a thread-local connection.

        Returns:
            SqliteConnection: A thread-local connection
        """
        return self._get_thread_connection()

    def release(self, connection: "SqliteConnection") -> None:
        """Release a connection (no-op for thread-local connections).

        Args:
            connection: The connection to release (ignored)
        """
        # No-op: thread-local connections are managed per-thread

    # Compatibility methods that return dummy values
    def size(self) -> int:
        """Get pool size: 1 if the calling thread has a connection, else 0."""
        return 1 if hasattr(self._thread_local, "connection") else 0

    def checked_out(self) -> int:
        """Get number of checked out connections (always 0)."""
        return 0
|
sqlspec/base.py
CHANGED
|
@@ -26,7 +26,10 @@ from sqlspec.utils.logging import get_logger
|
|
|
26
26
|
|
|
27
27
|
if TYPE_CHECKING:
|
|
28
28
|
from contextlib import AbstractAsyncContextManager, AbstractContextManager
|
|
29
|
+
from pathlib import Path
|
|
29
30
|
|
|
31
|
+
from sqlspec.core.statement import SQL
|
|
32
|
+
from sqlspec.loader import SQLFileLoader
|
|
30
33
|
from sqlspec.typing import ConnectionT, PoolT
|
|
31
34
|
|
|
32
35
|
|
|
@@ -38,54 +41,77 @@ logger = get_logger()
|
|
|
38
41
|
class SQLSpec:
|
|
39
42
|
"""Configuration manager and registry for database connections and pools."""
|
|
40
43
|
|
|
41
|
-
__slots__ = ("
|
|
44
|
+
__slots__ = ("_configs", "_instance_cache_config", "_sql_loader")
|
|
42
45
|
|
|
43
|
-
def __init__(self) -> None:
|
|
46
|
+
def __init__(self, *, loader: "Optional[SQLFileLoader]" = None) -> None:
    """Create an empty configuration registry.

    Args:
        loader: Optional pre-built SQL file loader. When omitted, the SQL-file
            helpers create one lazily via ``_ensure_sql_loader``.
    """
    self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {}
    self._instance_cache_config: Optional[CacheConfig] = None
    self._sql_loader: Optional[SQLFileLoader] = loader
    # Only synchronous pools are closed automatically at interpreter exit;
    # async pools must be closed with close_all_pools().
    atexit.register(self._cleanup_sync_pools)
|
|
48
52
|
|
|
49
53
|
@staticmethod
|
|
50
54
|
def _get_config_name(obj: Any) -> str:
|
|
51
55
|
"""Get display name for configuration object."""
|
|
52
56
|
return getattr(obj, "__name__", str(obj))
|
|
53
57
|
|
|
54
|
-
def
|
|
55
|
-
"""Clean up
|
|
58
|
+
def _cleanup_sync_pools(self) -> None:
    """Close every registered synchronous connection pool.

    Registered with ``atexit`` in ``__init__``. Async pools are skipped; they
    are handled by ``close_all_pools()``. A failure in one pool is logged and
    does not stop the remaining pools from being closed.
    """
    cleaned_count = 0

    for config_type, config in self._configs.items():
        if not config.supports_connection_pooling or config.is_async:
            continue
        try:
            config.close_pool()
        except Exception as e:
            # _get_config_name never raises on keys lacking __name__, so the
            # handler itself cannot abort the exit-time cleanup.
            logger.warning(
                "Failed to clean up sync pool for config %s: %s", self._get_config_name(config_type), e
            )
        else:
            cleaned_count += 1

    if cleaned_count > 0:
        logger.debug("Sync pool cleanup completed. Cleaned %d pools.", cleaned_count)
|
|
72
|
+
|
|
73
|
+
async def close_all_pools(self) -> None:
    """Explicitly close all connection pools (async and sync).

    This method should be called before application shutdown for proper cleanup.

    Async pools are closed concurrently with ``asyncio.gather``. A failure in
    one of them is logged per config and does not stop the others. Sync pools
    are closed afterwards in registration order. An exception raised by one of
    them propagates to the caller, and any remaining sync pools are not closed.
    """
    async_closers: "list[tuple[Any, Coroutine[Any, Any, None]]]" = []
    sync_configs = []

    for config_type, config in self._configs.items():
        if not config.supports_connection_pooling:
            continue
        try:
            if config.is_async:
                close_pool_awaitable = config.close_pool()
                if close_pool_awaitable is not None:
                    async_closers.append((config_type, cast("Coroutine[Any, Any, None]", close_pool_awaitable)))
            else:
                sync_configs.append((config_type, config))
        except Exception as e:
            logger.warning("Failed to prepare cleanup for config %s: %s", self._get_config_name(config_type), e)

    # Close async pools concurrently
    if async_closers:
        try:
            results = await asyncio.gather(*(closer for _, closer in async_closers), return_exceptions=True)
        except Exception as e:
            logger.warning("Failed to complete async pool cleanup: %s", e)
        else:
            closed = 0
            for (config_type, _), result in zip(async_closers, results):
                if isinstance(result, BaseException):
                    # return_exceptions=True returns failures as values; surface them
                    # instead of silently discarding them.
                    logger.warning(
                        "Failed to close async pool for config %s: %s", self._get_config_name(config_type), result
                    )
                else:
                    closed += 1
            logger.debug("Async pool cleanup completed. Cleaned %d pools.", closed)

    # Close sync pools
    for _config_type, config in sync_configs:
        config.close_pool()  # Let exceptions propagate for proper logging

    if sync_configs:
        logger.debug("Sync pool cleanup completed. Cleaned %d pools.", len(sync_configs))
|
|
86
107
|
|
|
87
|
-
|
|
88
|
-
|
|
108
|
+
async def __aenter__(self) -> "SQLSpec":
    """Enter an ``async with`` block, yielding this registry as the context value."""
    registry = self
    return registry
|
|
111
|
+
|
|
112
|
+
async def __aexit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
    """Leave an ``async with`` block by closing every registered pool.

    Returns ``None``, so an exception raised inside the block is not suppressed.
    """
    closer = self.close_all_pools
    await closer()
|
|
89
115
|
|
|
90
116
|
@overload
|
|
91
117
|
def add_config(self, config: "SyncConfigT") -> "type[SyncConfigT]": # pyright: ignore[reportInvalidTypeVarUse]
|
|
@@ -569,3 +595,98 @@ class SQLSpec:
|
|
|
569
595
|
else current_config.optimized_cache_enabled,
|
|
570
596
|
)
|
|
571
597
|
)
|
|
598
|
+
|
|
599
|
+
# SQL File Loading Integration
|
|
600
|
+
|
|
601
|
+
def _ensure_sql_loader(self) -> "SQLFileLoader":
    """Return the SQL file loader, creating a default ``SQLFileLoader`` on first use."""
    loader = self._sql_loader
    if loader is None:
        # Deferred import to avoid a circular import at module load time.
        from sqlspec.loader import SQLFileLoader

        loader = SQLFileLoader()
        self._sql_loader = loader
    return loader
|
|
609
|
+
|
|
610
|
+
def load_sql_files(self, *paths: "Union[str, Path]") -> None:
    """Load named SQL statements from one or more files or directories.

    Args:
        *paths: File or directory paths, forwarded unchanged to the loader's
            ``load_sql``. The loader is created first if necessary.
    """
    self._ensure_sql_loader().load_sql(*paths)
    logger.debug("Loaded SQL files: %s", paths)
|
|
619
|
+
|
|
620
|
+
def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) -> None:
    """Register a named SQL statement without reading it from a file.

    Args:
        name: Name under which the statement is stored.
        sql: Raw SQL text.
        dialect: Optional SQL dialect for the statement.
    """
    sql_loader = self._ensure_sql_loader()
    sql_loader.add_named_sql(name, sql, dialect)
    logger.debug("Added named SQL: %s", name)
|
|
631
|
+
|
|
632
|
+
def get_sql(self, name: str) -> "SQL":
    """Look up a named statement and return it as a SQL object.

    Args:
        name: Statement name (from ``-- name:`` in a SQL file). Per the
            loader's conventions, hyphens in names are converted to underscores.

    Returns:
        The SQL object returned by the loader's ``get_sql``.
    """
    return self._ensure_sql_loader().get_sql(name)
|
|
644
|
+
|
|
645
|
+
def list_sql_queries(self) -> "list[str]":
    """Return the available query names.

    Returns:
        The result of the loader's ``list_queries()``, or an empty list when no
        loader has been created yet.
    """
    loader = self._sql_loader
    return [] if loader is None else loader.list_queries()
|
|
654
|
+
|
|
655
|
+
def has_sql_query(self, name: str) -> bool:
    """Report whether a query with ``name`` is known.

    Args:
        name: Query name to look up.

    Returns:
        The loader's ``has_query(name)`` result, or False when no loader has
        been created yet.
    """
    loader = self._sql_loader
    return False if loader is None else loader.has_query(name)
|
|
667
|
+
|
|
668
|
+
def clear_sql_cache(self) -> None:
    """Clear the SQL loader's cache; does nothing if no loader exists yet."""
    loader = self._sql_loader
    if loader is None:
        return
    loader.clear_cache()
    logger.debug("Cleared SQL cache")
|
|
673
|
+
|
|
674
|
+
def reload_sql_files(self) -> None:
    """Clear cached SQL so that files can be loaded again.

    Despite its name, this method does not re-read any files itself: it only
    clears the loader's cache. Call ``load_sql_files()`` afterwards to load the
    files again. Does nothing if no loader has been created yet.
    """
    loader = self._sql_loader
    if loader is None:
        return
    # Clear cache to force reload
    loader.clear_cache()
    logger.debug("Cleared SQL cache for reload")
|
|
683
|
+
|
|
684
|
+
def get_sql_files(self) -> "list[str]":
    """Return the loaded SQL file paths.

    Returns:
        The loader's ``list_files()`` result, or an empty list when no loader
        has been created yet.
    """
    loader = self._sql_loader
    return [] if loader is None else loader.list_files()
|
sqlspec/builder/__init__.py
CHANGED
|
@@ -29,10 +29,13 @@ from sqlspec.builder._merge import Merge
|
|
|
29
29
|
from sqlspec.builder._select import Select
|
|
30
30
|
from sqlspec.builder._update import Update
|
|
31
31
|
from sqlspec.builder.mixins import WhereClauseMixin
|
|
32
|
+
from sqlspec.builder.mixins._join_operations import JoinBuilder
|
|
33
|
+
from sqlspec.builder.mixins._select_operations import Case, SubqueryBuilder, WindowFunctionBuilder
|
|
32
34
|
from sqlspec.exceptions import SQLBuilderError
|
|
33
35
|
|
|
34
36
|
__all__ = (
|
|
35
37
|
"AlterTable",
|
|
38
|
+
"Case",
|
|
36
39
|
"Column",
|
|
37
40
|
"ColumnExpression",
|
|
38
41
|
"CommentOn",
|
|
@@ -50,13 +53,16 @@ __all__ = (
|
|
|
50
53
|
"DropView",
|
|
51
54
|
"FunctionColumn",
|
|
52
55
|
"Insert",
|
|
56
|
+
"JoinBuilder",
|
|
53
57
|
"Merge",
|
|
54
58
|
"QueryBuilder",
|
|
55
59
|
"RenameTable",
|
|
56
60
|
"SQLBuilderError",
|
|
57
61
|
"SafeQuery",
|
|
58
62
|
"Select",
|
|
63
|
+
"SubqueryBuilder",
|
|
59
64
|
"Truncate",
|
|
60
65
|
"Update",
|
|
61
66
|
"WhereClauseMixin",
|
|
67
|
+
"WindowFunctionBuilder",
|
|
62
68
|
)
|
|
@@ -9,6 +9,7 @@ from typing import Any, Final, Optional, Union, cast
|
|
|
9
9
|
|
|
10
10
|
from sqlglot import exp, maybe_parse, parse_one
|
|
11
11
|
|
|
12
|
+
from sqlspec.core.parameters import ParameterStyle
|
|
12
13
|
from sqlspec.utils.type_guards import has_expression_attr, has_parameter_builder
|
|
13
14
|
|
|
14
15
|
|
|
@@ -151,6 +152,32 @@ def parse_condition_expression(
|
|
|
151
152
|
if not isinstance(condition_input, str):
|
|
152
153
|
condition_input = str(condition_input)
|
|
153
154
|
|
|
155
|
+
# Convert database-specific parameter styles to SQLGlot-compatible format
|
|
156
|
+
# This ensures that placeholders like $1, %s, :1 are properly recognized as parameters
|
|
157
|
+
from sqlspec.core.parameters import ParameterValidator
|
|
158
|
+
|
|
159
|
+
validator = ParameterValidator()
|
|
160
|
+
param_info = validator.extract_parameters(condition_input)
|
|
161
|
+
|
|
162
|
+
# If we found parameters, convert incompatible ones to SQLGlot-compatible format
|
|
163
|
+
if param_info:
|
|
164
|
+
# Convert problematic parameter styles to :param_N format for SQLGlot
|
|
165
|
+
converted_condition = condition_input
|
|
166
|
+
for param in reversed(param_info): # Reverse to preserve positions
|
|
167
|
+
if param.style in {
|
|
168
|
+
ParameterStyle.NUMERIC,
|
|
169
|
+
ParameterStyle.POSITIONAL_PYFORMAT,
|
|
170
|
+
ParameterStyle.POSITIONAL_COLON,
|
|
171
|
+
}:
|
|
172
|
+
# Convert $1, %s, :1 to :param_0, :param_1, etc.
|
|
173
|
+
placeholder = f":param_{param.ordinal}"
|
|
174
|
+
converted_condition = (
|
|
175
|
+
converted_condition[: param.position]
|
|
176
|
+
+ placeholder
|
|
177
|
+
+ converted_condition[param.position + len(param.placeholder_text) :]
|
|
178
|
+
)
|
|
179
|
+
condition_input = converted_condition
|
|
180
|
+
|
|
154
181
|
try:
|
|
155
182
|
return exp.condition(condition_input)
|
|
156
183
|
except Exception:
|
|
@@ -9,10 +9,11 @@ from sqlspec.exceptions import SQLBuilderError
|
|
|
9
9
|
from sqlspec.utils.type_guards import has_query_builder_parameters
|
|
10
10
|
|
|
11
11
|
if TYPE_CHECKING:
|
|
12
|
+
from sqlspec.builder._column import ColumnExpression
|
|
12
13
|
from sqlspec.core.statement import SQL
|
|
13
14
|
from sqlspec.protocols import SQLBuilderProtocol
|
|
14
15
|
|
|
15
|
-
__all__ = ("JoinClauseMixin"
|
|
16
|
+
__all__ = ("JoinBuilder", "JoinClauseMixin")
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
@trait
|
|
@@ -147,3 +148,116 @@ class JoinClauseMixin:
|
|
|
147
148
|
join_expr = exp.Join(this=table_expr, kind="CROSS")
|
|
148
149
|
builder._expression = builder._expression.join(join_expr, copy=False)
|
|
149
150
|
return cast("Self", builder)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@trait
class JoinBuilder:
    """Builder for JOIN operations with fluent syntax.

    Example:
        ```python
        from sqlspec import sql

        # sql.left_join_("posts").on("users.id = posts.user_id")
        join_clause = sql.left_join_("posts").on(
            "users.id = posts.user_id"
        )

        # Or with query builder
        query = (
            sql.select("users.name", "posts.title")
            .from_("users")
            .join(
                sql.left_join_("posts").on(
                    "users.id = posts.user_id"
                )
            )
        )
        ```
    """

    def __init__(self, join_type: str) -> None:
        """Initialize the join builder.

        Args:
            join_type: Type of join, case-insensitive. It may be given with or
                without a trailing ``"join"`` (``"left"`` or ``"left join"``),
                and an ``"outer"`` keyword is accepted (``"left outer join"``).
                Supported kinds: inner, left, right, full, cross. Unrecognised
                kinds produce an INNER join.
        """
        self._join_type = join_type.upper()
        self._table: Optional[Union[str, exp.Expression]] = None
        self._condition: Optional[exp.Expression] = None
        self._alias: Optional[str] = None

    def __eq__(self, other: object) -> "ColumnExpression":  # type: ignore[override]
        """Equal to (==) - not typically used but needed for type consistency."""
        from sqlspec.builder._column import ColumnExpression

        # JoinBuilder doesn't have a direct expression, so this is a placeholder
        # In practice, this shouldn't be called as joins are used differently
        placeholder_expr = exp.Literal.string(f"join_{self._join_type.lower()}")
        if other is None:
            return ColumnExpression(exp.Is(this=placeholder_expr, expression=exp.Null()))
        return ColumnExpression(exp.EQ(this=placeholder_expr, expression=exp.convert(other)))

    def __hash__(self) -> int:
        """Make JoinBuilder hashable."""
        return hash(id(self))

    def __call__(self, table: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
        """Set the table to join.

        Args:
            table: Table name or expression to join
            alias: Optional alias for the table

        Returns:
            Self for method chaining
        """
        self._table = table
        self._alias = alias
        return self

    def _join_kind(self) -> str:
        """Return the join kind without trailing JOIN / OUTER keywords (e.g. ``"LEFT"``)."""
        words = self._join_type.split()
        if words and words[-1] == "JOIN":
            words = words[:-1]
        if len(words) == 2 and words[1] == "OUTER":
            words = words[:1]
        return " ".join(words)

    def on(self, condition: Union[str, exp.Expression]) -> exp.Expression:
        """Set the join condition and build the JOIN expression.

        Args:
            condition: JOIN condition (e.g., "users.id = posts.user_id").
                Ignored for CROSS joins.

        Returns:
            Complete JOIN expression

        Raises:
            SQLBuilderError: If no table was set via ``__call__`` first.
        """
        if self._table is None or (isinstance(self._table, str) and not self._table):
            msg = "Table must be set before calling .on()"
            raise SQLBuilderError(msg)

        # Parse the condition
        condition_expr: exp.Expression
        if isinstance(condition, str):
            parsed: Optional[exp.Expression] = exp.maybe_parse(condition)
            condition_expr = parsed or exp.condition(condition)
        else:
            condition_expr = condition

        # Build table expression
        table_expr: exp.Expression = exp.to_table(self._table) if isinstance(self._table, str) else self._table
        if self._alias:
            table_expr = exp.alias_(table_expr, self._alias)

        # Normalise so both "left" and "left join" select the LEFT branch; the
        # raw upper-cased string previously only matched the "... JOIN" spelling.
        kind = self._join_kind()
        if kind == "LEFT":
            return exp.Join(this=table_expr, on=condition_expr, side="LEFT")
        if kind == "RIGHT":
            return exp.Join(this=table_expr, on=condition_expr, side="RIGHT")
        if kind == "FULL":
            return exp.Join(this=table_expr, on=condition_expr, side="FULL", kind="OUTER")
        if kind == "CROSS":
            # CROSS JOIN doesn't use ON condition
            return exp.Join(this=table_expr, kind="CROSS")
        # INNER and any unrecognised kind
        return exp.Join(this=table_expr, on=condition_expr)
|