sqlspec 0.21.1__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +36 -0
- sqlspec/base.py +4 -4
- sqlspec/builder/mixins/_join_operations.py +205 -85
- sqlspec/loader.py +65 -68
- sqlspec/protocols.py +3 -5
- sqlspec/storage/__init__.py +2 -12
- sqlspec/storage/backends/__init__.py +1 -0
- sqlspec/storage/backends/fsspec.py +87 -147
- sqlspec/storage/backends/local.py +310 -0
- sqlspec/storage/backends/obstore.py +210 -192
- sqlspec/storage/registry.py +101 -70
- sqlspec/utils/sync_tools.py +8 -5
- {sqlspec-0.21.1.dist-info → sqlspec-0.23.0.dist-info}/METADATA +1 -1
- {sqlspec-0.21.1.dist-info → sqlspec-0.23.0.dist-info}/RECORD +18 -18
- sqlspec/storage/capabilities.py +0 -102
- {sqlspec-0.21.1.dist-info → sqlspec-0.23.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.21.1.dist-info → sqlspec-0.23.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.21.1.dist-info → sqlspec-0.23.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.21.1.dist-info → sqlspec-0.23.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/_sql.py
CHANGED
|
@@ -628,6 +628,42 @@ class SQLFactory:
|
|
|
628
628
|
"""Create a CROSS JOIN builder."""
|
|
629
629
|
return JoinBuilder("cross join")
|
|
630
630
|
|
|
631
|
+
@property
|
|
632
|
+
def lateral_join_(self) -> "JoinBuilder":
|
|
633
|
+
"""Create a LATERAL JOIN builder.
|
|
634
|
+
|
|
635
|
+
Returns:
|
|
636
|
+
JoinBuilder configured for LATERAL JOIN
|
|
637
|
+
|
|
638
|
+
Example:
|
|
639
|
+
```python
|
|
640
|
+
query = (
|
|
641
|
+
sql.select("u.name", "arr.value")
|
|
642
|
+
.from_("users u")
|
|
643
|
+
.join(sql.lateral_join_("UNNEST(u.tags)").on("true"))
|
|
644
|
+
)
|
|
645
|
+
```
|
|
646
|
+
"""
|
|
647
|
+
return JoinBuilder("lateral join", lateral=True)
|
|
648
|
+
|
|
649
|
+
@property
|
|
650
|
+
def left_lateral_join_(self) -> "JoinBuilder":
|
|
651
|
+
"""Create a LEFT LATERAL JOIN builder.
|
|
652
|
+
|
|
653
|
+
Returns:
|
|
654
|
+
JoinBuilder configured for LEFT LATERAL JOIN
|
|
655
|
+
"""
|
|
656
|
+
return JoinBuilder("left join", lateral=True)
|
|
657
|
+
|
|
658
|
+
@property
|
|
659
|
+
def cross_lateral_join_(self) -> "JoinBuilder":
|
|
660
|
+
"""Create a CROSS LATERAL JOIN builder.
|
|
661
|
+
|
|
662
|
+
Returns:
|
|
663
|
+
JoinBuilder configured for CROSS LATERAL JOIN
|
|
664
|
+
"""
|
|
665
|
+
return JoinBuilder("cross join", lateral=True)
|
|
666
|
+
|
|
631
667
|
def __getattr__(self, name: str) -> "Column":
|
|
632
668
|
"""Dynamically create column references.
|
|
633
669
|
|
sqlspec/base.py
CHANGED
|
@@ -64,7 +64,7 @@ class SQLSpec:
|
|
|
64
64
|
config.close_pool()
|
|
65
65
|
cleaned_count += 1
|
|
66
66
|
except Exception as e:
|
|
67
|
-
logger.
|
|
67
|
+
logger.debug("Failed to clean up sync pool for config %s: %s", config_type.__name__, e)
|
|
68
68
|
|
|
69
69
|
if cleaned_count > 0:
|
|
70
70
|
logger.debug("Sync pool cleanup completed. Cleaned %d pools.", cleaned_count)
|
|
@@ -87,14 +87,14 @@ class SQLSpec:
|
|
|
87
87
|
else:
|
|
88
88
|
sync_configs.append((config_type, config))
|
|
89
89
|
except Exception as e:
|
|
90
|
-
logger.
|
|
90
|
+
logger.debug("Failed to prepare cleanup for config %s: %s", config_type.__name__, e)
|
|
91
91
|
|
|
92
92
|
if cleanup_tasks:
|
|
93
93
|
try:
|
|
94
94
|
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
|
|
95
95
|
logger.debug("Async pool cleanup completed. Cleaned %d pools.", len(cleanup_tasks))
|
|
96
96
|
except Exception as e:
|
|
97
|
-
logger.
|
|
97
|
+
logger.debug("Failed to complete async pool cleanup: %s", e)
|
|
98
98
|
|
|
99
99
|
for _config_type, config in sync_configs:
|
|
100
100
|
config.close_pool()
|
|
@@ -129,7 +129,7 @@ class SQLSpec:
|
|
|
129
129
|
"""
|
|
130
130
|
config_type = type(config)
|
|
131
131
|
if config_type in self._configs:
|
|
132
|
-
logger.
|
|
132
|
+
logger.debug("Configuration for %s already exists. Overwriting.", config_type.__name__)
|
|
133
133
|
self._configs[config_type] = config
|
|
134
134
|
return config_type
|
|
135
135
|
|
|
@@ -14,7 +14,6 @@ from sqlspec.exceptions import SQLBuilderError
|
|
|
14
14
|
from sqlspec.utils.type_guards import has_query_builder_parameters
|
|
15
15
|
|
|
16
16
|
if TYPE_CHECKING:
|
|
17
|
-
from sqlspec.builder._column import ColumnExpression
|
|
18
17
|
from sqlspec.core.statement import SQL
|
|
19
18
|
from sqlspec.protocols import SQLBuilderProtocol
|
|
20
19
|
|
|
@@ -36,74 +35,133 @@ class JoinClauseMixin:
|
|
|
36
35
|
on: Optional[Union[str, exp.Expression, "SQL"]] = None,
|
|
37
36
|
alias: Optional[str] = None,
|
|
38
37
|
join_type: str = "INNER",
|
|
38
|
+
lateral: bool = False,
|
|
39
39
|
) -> Self:
|
|
40
40
|
builder = cast("SQLBuilderProtocol", self)
|
|
41
|
+
self._validate_join_context(builder)
|
|
42
|
+
|
|
43
|
+
# Handle Join expressions directly (from JoinBuilder.on() calls)
|
|
44
|
+
if isinstance(table, exp.Join):
|
|
45
|
+
if builder._expression is not None and isinstance(builder._expression, exp.Select):
|
|
46
|
+
builder._expression = builder._expression.join(table, copy=False)
|
|
47
|
+
return cast("Self", builder)
|
|
48
|
+
|
|
49
|
+
table_expr = self._parse_table_expression(table, alias, builder)
|
|
50
|
+
on_expr = self._parse_on_condition(on, builder)
|
|
51
|
+
join_expr = self._create_join_expression(table_expr, on_expr, join_type)
|
|
52
|
+
|
|
53
|
+
if lateral:
|
|
54
|
+
self._apply_lateral_modifier(join_expr)
|
|
55
|
+
|
|
56
|
+
if builder._expression is not None and isinstance(builder._expression, exp.Select):
|
|
57
|
+
builder._expression = builder._expression.join(join_expr, copy=False)
|
|
58
|
+
return cast("Self", builder)
|
|
59
|
+
|
|
60
|
+
def _validate_join_context(self, builder: "SQLBuilderProtocol") -> None:
|
|
61
|
+
"""Validate that the join can be applied to the current expression."""
|
|
41
62
|
if builder._expression is None:
|
|
42
63
|
builder._expression = exp.Select()
|
|
43
64
|
if not isinstance(builder._expression, exp.Select):
|
|
44
65
|
msg = "JOIN clause is only supported for SELECT statements."
|
|
45
66
|
raise SQLBuilderError(msg)
|
|
46
|
-
|
|
67
|
+
|
|
68
|
+
def _parse_table_expression(
|
|
69
|
+
self, table: Union[str, exp.Expression, Any], alias: Optional[str], builder: "SQLBuilderProtocol"
|
|
70
|
+
) -> exp.Expression:
|
|
71
|
+
"""Parse table parameter into a SQLGlot expression."""
|
|
47
72
|
if isinstance(table, str):
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
73
|
+
return parse_table_expression(table, alias)
|
|
74
|
+
if has_query_builder_parameters(table):
|
|
75
|
+
return self._handle_query_builder_table(table, alias, builder)
|
|
76
|
+
if isinstance(table, exp.Expression):
|
|
77
|
+
return table
|
|
78
|
+
return cast("exp.Expression", table)
|
|
79
|
+
|
|
80
|
+
def _handle_query_builder_table(
|
|
81
|
+
self, table: Any, alias: Optional[str], builder: "SQLBuilderProtocol"
|
|
82
|
+
) -> exp.Expression:
|
|
83
|
+
"""Handle table parameters that are query builders."""
|
|
84
|
+
if hasattr(table, "_expression") and getattr(table, "_expression", None) is not None:
|
|
85
|
+
table_expr_value = getattr(table, "_expression", None)
|
|
86
|
+
if table_expr_value is not None:
|
|
87
|
+
subquery_exp = exp.paren(table_expr_value)
|
|
57
88
|
else:
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
89
|
+
subquery_exp = exp.paren(exp.Anonymous(this=""))
|
|
90
|
+
return exp.alias_(subquery_exp, alias) if alias else subquery_exp
|
|
91
|
+
subquery = table.build()
|
|
92
|
+
sql_str = subquery.sql if hasattr(subquery, "sql") and not callable(subquery.sql) else str(subquery)
|
|
93
|
+
subquery_exp = exp.paren(exp.maybe_parse(sql_str, dialect=getattr(builder, "dialect", None)))
|
|
94
|
+
return exp.alias_(subquery_exp, alias) if alias else subquery_exp
|
|
95
|
+
|
|
96
|
+
def _parse_on_condition(
|
|
97
|
+
self, on: Optional[Union[str, exp.Expression, "SQL"]], builder: "SQLBuilderProtocol"
|
|
98
|
+
) -> Optional[exp.Expression]:
|
|
99
|
+
"""Parse ON condition into a SQLGlot expression."""
|
|
100
|
+
if on is None:
|
|
101
|
+
return None
|
|
102
|
+
|
|
103
|
+
if isinstance(on, str):
|
|
104
|
+
return exp.condition(on)
|
|
105
|
+
if hasattr(on, "expression") and hasattr(on, "sql"):
|
|
106
|
+
return self._handle_sql_object_condition(on, builder)
|
|
107
|
+
if isinstance(on, exp.Expression):
|
|
108
|
+
return on
|
|
109
|
+
# Last resort - convert to string and parse
|
|
110
|
+
return exp.condition(str(on))
|
|
111
|
+
|
|
112
|
+
def _handle_sql_object_condition(self, on: Any, builder: "SQLBuilderProtocol") -> exp.Expression:
|
|
113
|
+
"""Handle SQL object conditions with parameter binding."""
|
|
114
|
+
expression = getattr(on, "expression", None)
|
|
115
|
+
if expression is not None and isinstance(expression, exp.Expression):
|
|
116
|
+
# Merge parameters from SQL object into builder
|
|
117
|
+
if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
|
|
118
|
+
sql_parameters = getattr(on, "parameters", {})
|
|
119
|
+
for param_name, param_value in sql_parameters.items():
|
|
120
|
+
builder.add_parameter(param_value, name=param_name)
|
|
121
|
+
return cast("exp.Expression", expression)
|
|
122
|
+
# If expression is None, fall back to parsing the raw SQL
|
|
123
|
+
sql_text = getattr(on, "sql", "")
|
|
124
|
+
# Merge parameters even when parsing raw SQL
|
|
125
|
+
if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
|
|
126
|
+
sql_parameters = getattr(on, "parameters", {})
|
|
127
|
+
for param_name, param_value in sql_parameters.items():
|
|
128
|
+
builder.add_parameter(param_value, name=param_name)
|
|
129
|
+
parsed_expr = exp.maybe_parse(sql_text)
|
|
130
|
+
return parsed_expr if parsed_expr is not None else exp.condition(str(sql_text))
|
|
131
|
+
|
|
132
|
+
def _create_join_expression(
|
|
133
|
+
self, table_expr: exp.Expression, on_expr: Optional[exp.Expression], join_type: str
|
|
134
|
+
) -> exp.Join:
|
|
135
|
+
"""Create the appropriate JOIN expression based on join type."""
|
|
93
136
|
join_type_upper = join_type.upper()
|
|
94
137
|
if join_type_upper == "INNER":
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
138
|
+
return exp.Join(this=table_expr, on=on_expr)
|
|
139
|
+
if join_type_upper == "LEFT":
|
|
140
|
+
return exp.Join(this=table_expr, on=on_expr, side="LEFT")
|
|
141
|
+
if join_type_upper == "RIGHT":
|
|
142
|
+
return exp.Join(this=table_expr, on=on_expr, side="RIGHT")
|
|
143
|
+
if join_type_upper == "FULL":
|
|
144
|
+
return exp.Join(this=table_expr, on=on_expr, side="FULL", kind="OUTER")
|
|
145
|
+
if join_type_upper == "CROSS":
|
|
146
|
+
return exp.Join(this=table_expr, kind="CROSS")
|
|
147
|
+
msg = f"Unsupported join type: {join_type}"
|
|
148
|
+
raise SQLBuilderError(msg)
|
|
149
|
+
|
|
150
|
+
def _apply_lateral_modifier(self, join_expr: exp.Join) -> None:
|
|
151
|
+
"""Apply LATERAL modifier to the join expression."""
|
|
152
|
+
current_kind = join_expr.args.get("kind")
|
|
153
|
+
current_side = join_expr.args.get("side")
|
|
154
|
+
|
|
155
|
+
if current_kind == "CROSS":
|
|
156
|
+
join_expr.set("kind", "CROSS LATERAL")
|
|
157
|
+
elif current_kind == "OUTER" and current_side == "FULL":
|
|
158
|
+
join_expr.set("side", "FULL") # Keep side
|
|
159
|
+
join_expr.set("kind", "OUTER LATERAL")
|
|
160
|
+
elif current_side:
|
|
161
|
+
join_expr.set("kind", f"{current_side} LATERAL")
|
|
162
|
+
join_expr.set("side", None) # Clear side to avoid duplication
|
|
102
163
|
else:
|
|
103
|
-
|
|
104
|
-
raise SQLBuilderError(msg)
|
|
105
|
-
builder._expression = builder._expression.join(join_expr, copy=False)
|
|
106
|
-
return cast("Self", builder)
|
|
164
|
+
join_expr.set("kind", "LATERAL")
|
|
107
165
|
|
|
108
166
|
def inner_join(
|
|
109
167
|
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
|
|
@@ -154,6 +212,63 @@ class JoinClauseMixin:
|
|
|
154
212
|
builder._expression = builder._expression.join(join_expr, copy=False)
|
|
155
213
|
return cast("Self", builder)
|
|
156
214
|
|
|
215
|
+
def lateral_join(
|
|
216
|
+
self,
|
|
217
|
+
table: Union[str, exp.Expression, Any],
|
|
218
|
+
on: Optional[Union[str, exp.Expression, "SQL"]] = None,
|
|
219
|
+
alias: Optional[str] = None,
|
|
220
|
+
) -> Self:
|
|
221
|
+
"""Create a LATERAL JOIN.
|
|
222
|
+
|
|
223
|
+
Args:
|
|
224
|
+
table: Table, subquery, or table function to join
|
|
225
|
+
on: Optional join condition (for LATERAL JOINs with ON clause)
|
|
226
|
+
alias: Optional alias for the joined table/subquery
|
|
227
|
+
|
|
228
|
+
Returns:
|
|
229
|
+
Self for method chaining
|
|
230
|
+
|
|
231
|
+
Example:
|
|
232
|
+
```python
|
|
233
|
+
query = (
|
|
234
|
+
sql.select("u.name", "arr.value")
|
|
235
|
+
.from_("users u")
|
|
236
|
+
.lateral_join("UNNEST(u.tags)", alias="arr")
|
|
237
|
+
)
|
|
238
|
+
```
|
|
239
|
+
"""
|
|
240
|
+
return self.join(table, on=on, alias=alias, join_type="INNER", lateral=True)
|
|
241
|
+
|
|
242
|
+
def left_lateral_join(
|
|
243
|
+
self,
|
|
244
|
+
table: Union[str, exp.Expression, Any],
|
|
245
|
+
on: Optional[Union[str, exp.Expression, "SQL"]] = None,
|
|
246
|
+
alias: Optional[str] = None,
|
|
247
|
+
) -> Self:
|
|
248
|
+
"""Create a LEFT LATERAL JOIN.
|
|
249
|
+
|
|
250
|
+
Args:
|
|
251
|
+
table: Table, subquery, or table function to join
|
|
252
|
+
on: Optional join condition
|
|
253
|
+
alias: Optional alias for the joined table/subquery
|
|
254
|
+
|
|
255
|
+
Returns:
|
|
256
|
+
Self for method chaining
|
|
257
|
+
"""
|
|
258
|
+
return self.join(table, on=on, alias=alias, join_type="LEFT", lateral=True)
|
|
259
|
+
|
|
260
|
+
def cross_lateral_join(self, table: Union[str, exp.Expression, Any], alias: Optional[str] = None) -> Self:
|
|
261
|
+
"""Create a CROSS LATERAL JOIN (no ON condition).
|
|
262
|
+
|
|
263
|
+
Args:
|
|
264
|
+
table: Table, subquery, or table function to join
|
|
265
|
+
alias: Optional alias for the joined table/subquery
|
|
266
|
+
|
|
267
|
+
Returns:
|
|
268
|
+
Self for method chaining
|
|
269
|
+
"""
|
|
270
|
+
return self.join(table, on=None, alias=alias, join_type="CROSS", lateral=True)
|
|
271
|
+
|
|
157
272
|
|
|
158
273
|
@trait
|
|
159
274
|
class JoinBuilder:
|
|
@@ -181,32 +296,19 @@ class JoinBuilder:
|
|
|
181
296
|
```
|
|
182
297
|
"""
|
|
183
298
|
|
|
184
|
-
def __init__(self, join_type: str) -> None:
|
|
299
|
+
def __init__(self, join_type: str, lateral: bool = False) -> None:
|
|
185
300
|
"""Initialize the join builder.
|
|
186
301
|
|
|
187
302
|
Args:
|
|
188
|
-
join_type: Type of join (inner, left, right, full, cross)
|
|
303
|
+
join_type: Type of join (inner, left, right, full, cross, lateral)
|
|
304
|
+
lateral: Whether this is a LATERAL join
|
|
189
305
|
"""
|
|
190
306
|
self._join_type = join_type.upper()
|
|
307
|
+
self._lateral = lateral
|
|
191
308
|
self._table: Optional[Union[str, exp.Expression]] = None
|
|
192
309
|
self._condition: Optional[exp.Expression] = None
|
|
193
310
|
self._alias: Optional[str] = None
|
|
194
311
|
|
|
195
|
-
def __eq__(self, other: object) -> "ColumnExpression": # type: ignore[override]
|
|
196
|
-
"""Equal to (==) - not typically used but needed for type consistency."""
|
|
197
|
-
from sqlspec.builder._column import ColumnExpression
|
|
198
|
-
|
|
199
|
-
# JoinBuilder doesn't have a direct expression, so this is a placeholder
|
|
200
|
-
# In practice, this shouldn't be called as joins are used differently
|
|
201
|
-
placeholder_expr = exp.Literal.string(f"join_{self._join_type.lower()}")
|
|
202
|
-
if other is None:
|
|
203
|
-
return ColumnExpression(exp.Is(this=placeholder_expr, expression=exp.Null()))
|
|
204
|
-
return ColumnExpression(exp.EQ(this=placeholder_expr, expression=exp.convert(other)))
|
|
205
|
-
|
|
206
|
-
def __hash__(self) -> int:
|
|
207
|
-
"""Make JoinBuilder hashable."""
|
|
208
|
-
return hash(id(self))
|
|
209
|
-
|
|
210
312
|
def __call__(self, table: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
|
|
211
313
|
"""Set the table to join.
|
|
212
314
|
|
|
@@ -254,15 +356,33 @@ class JoinBuilder:
|
|
|
254
356
|
table_expr = exp.alias_(table_expr, self._alias)
|
|
255
357
|
|
|
256
358
|
# Create the appropriate join type using same pattern as existing JoinClauseMixin
|
|
257
|
-
if self._join_type
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
359
|
+
if self._join_type in {"INNER JOIN", "INNER", "LATERAL JOIN"}:
|
|
360
|
+
join_expr = exp.Join(this=table_expr, on=condition_expr)
|
|
361
|
+
elif self._join_type in {"LEFT JOIN", "LEFT"}:
|
|
362
|
+
join_expr = exp.Join(this=table_expr, on=condition_expr, side="LEFT")
|
|
363
|
+
elif self._join_type in {"RIGHT JOIN", "RIGHT"}:
|
|
364
|
+
join_expr = exp.Join(this=table_expr, on=condition_expr, side="RIGHT")
|
|
365
|
+
elif self._join_type in {"FULL JOIN", "FULL"}:
|
|
366
|
+
join_expr = exp.Join(this=table_expr, on=condition_expr, side="FULL", kind="OUTER")
|
|
367
|
+
elif self._join_type in {"CROSS JOIN", "CROSS"}:
|
|
266
368
|
# CROSS JOIN doesn't use ON condition
|
|
267
|
-
|
|
268
|
-
|
|
369
|
+
join_expr = exp.Join(this=table_expr, kind="CROSS")
|
|
370
|
+
else:
|
|
371
|
+
join_expr = exp.Join(this=table_expr, on=condition_expr)
|
|
372
|
+
|
|
373
|
+
if self._lateral or self._join_type == "LATERAL JOIN":
|
|
374
|
+
current_kind = join_expr.args.get("kind")
|
|
375
|
+
current_side = join_expr.args.get("side")
|
|
376
|
+
|
|
377
|
+
if current_kind == "CROSS":
|
|
378
|
+
join_expr.set("kind", "CROSS LATERAL")
|
|
379
|
+
elif current_kind == "OUTER" and current_side == "FULL":
|
|
380
|
+
join_expr.set("side", "FULL") # Keep side
|
|
381
|
+
join_expr.set("kind", "OUTER LATERAL")
|
|
382
|
+
elif current_side:
|
|
383
|
+
join_expr.set("kind", f"{current_side} LATERAL")
|
|
384
|
+
join_expr.set("side", None) # Clear side to avoid duplication
|
|
385
|
+
else:
|
|
386
|
+
join_expr.set("kind", "LATERAL")
|
|
387
|
+
|
|
388
|
+
return join_expr
|
sqlspec/loader.py
CHANGED
|
@@ -10,18 +10,15 @@ import time
|
|
|
10
10
|
from datetime import datetime, timezone
|
|
11
11
|
from pathlib import Path
|
|
12
12
|
from typing import TYPE_CHECKING, Any, Final, Optional, Union
|
|
13
|
+
from urllib.parse import unquote, urlparse
|
|
13
14
|
|
|
14
15
|
from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
|
|
15
16
|
from sqlspec.core.statement import SQL
|
|
16
|
-
from sqlspec.exceptions import
|
|
17
|
-
MissingDependencyError,
|
|
18
|
-
SQLFileNotFoundError,
|
|
19
|
-
SQLFileParseError,
|
|
20
|
-
StorageOperationFailedError,
|
|
21
|
-
)
|
|
17
|
+
from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError, StorageOperationFailedError
|
|
22
18
|
from sqlspec.storage.registry import storage_registry as default_storage_registry
|
|
23
19
|
from sqlspec.utils.correlation import CorrelationContext
|
|
24
20
|
from sqlspec.utils.logging import get_logger
|
|
21
|
+
from sqlspec.utils.text import slugify
|
|
25
22
|
|
|
26
23
|
if TYPE_CHECKING:
|
|
27
24
|
from sqlspec.storage.registry import StorageRegistry
|
|
@@ -54,13 +51,25 @@ MIN_QUERY_PARTS: Final = 3
|
|
|
54
51
|
def _normalize_query_name(name: str) -> str:
|
|
55
52
|
"""Normalize query name to be a valid Python identifier.
|
|
56
53
|
|
|
54
|
+
Convert hyphens to underscores, preserve dots for namespacing,
|
|
55
|
+
and remove invalid characters.
|
|
56
|
+
|
|
57
57
|
Args:
|
|
58
58
|
name: Raw query name from SQL file.
|
|
59
59
|
|
|
60
60
|
Returns:
|
|
61
61
|
Normalized query name suitable as Python identifier.
|
|
62
62
|
"""
|
|
63
|
-
|
|
63
|
+
# Handle namespace parts separately to preserve dots
|
|
64
|
+
parts = name.split(".")
|
|
65
|
+
normalized_parts = []
|
|
66
|
+
|
|
67
|
+
for part in parts:
|
|
68
|
+
# Use slugify with underscore separator and remove any remaining invalid chars
|
|
69
|
+
normalized_part = slugify(part, separator="_")
|
|
70
|
+
normalized_parts.append(normalized_part)
|
|
71
|
+
|
|
72
|
+
return ".".join(normalized_parts)
|
|
64
73
|
|
|
65
74
|
|
|
66
75
|
def _normalize_dialect(dialect: str) -> str:
|
|
@@ -76,19 +85,6 @@ def _normalize_dialect(dialect: str) -> str:
|
|
|
76
85
|
return DIALECT_ALIASES.get(normalized, normalized)
|
|
77
86
|
|
|
78
87
|
|
|
79
|
-
def _normalize_dialect_for_sqlglot(dialect: str) -> str:
|
|
80
|
-
"""Normalize dialect name for SQLGlot compatibility.
|
|
81
|
-
|
|
82
|
-
Args:
|
|
83
|
-
dialect: Dialect name from SQL file or parameter.
|
|
84
|
-
|
|
85
|
-
Returns:
|
|
86
|
-
SQLGlot-compatible dialect name.
|
|
87
|
-
"""
|
|
88
|
-
normalized = dialect.lower().strip()
|
|
89
|
-
return DIALECT_ALIASES.get(normalized, normalized)
|
|
90
|
-
|
|
91
|
-
|
|
92
88
|
class NamedStatement:
|
|
93
89
|
"""Represents a parsed SQL statement with metadata.
|
|
94
90
|
|
|
@@ -218,8 +214,7 @@ class SQLFileLoader:
|
|
|
218
214
|
SQLFileParseError: If file cannot be read.
|
|
219
215
|
"""
|
|
220
216
|
try:
|
|
221
|
-
|
|
222
|
-
return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest()
|
|
217
|
+
return hashlib.md5(self._read_file_content(path).encode(), usedforsecurity=False).hexdigest()
|
|
223
218
|
except Exception as e:
|
|
224
219
|
raise SQLFileParseError(str(path), str(path), e) from e
|
|
225
220
|
|
|
@@ -253,19 +248,22 @@ class SQLFileLoader:
|
|
|
253
248
|
SQLFileNotFoundError: If file does not exist.
|
|
254
249
|
SQLFileParseError: If file cannot be read or parsed.
|
|
255
250
|
"""
|
|
256
|
-
|
|
257
251
|
path_str = str(path)
|
|
258
252
|
|
|
259
253
|
try:
|
|
260
254
|
backend = self.storage_registry.get(path)
|
|
255
|
+
# For file:// URIs, extract just the filename for the backend call
|
|
256
|
+
if path_str.startswith("file://"):
|
|
257
|
+
parsed = urlparse(path_str)
|
|
258
|
+
file_path = unquote(parsed.path)
|
|
259
|
+
# Handle Windows paths (file:///C:/path)
|
|
260
|
+
if file_path and len(file_path) > 2 and file_path[2] == ":": # noqa: PLR2004
|
|
261
|
+
file_path = file_path[1:] # Remove leading slash for Windows
|
|
262
|
+
filename = Path(file_path).name
|
|
263
|
+
return backend.read_text(filename, encoding=self.encoding)
|
|
261
264
|
return backend.read_text(path_str, encoding=self.encoding)
|
|
262
265
|
except KeyError as e:
|
|
263
266
|
raise SQLFileNotFoundError(path_str) from e
|
|
264
|
-
except MissingDependencyError:
|
|
265
|
-
try:
|
|
266
|
-
return path.read_text(encoding=self.encoding) # type: ignore[union-attr]
|
|
267
|
-
except FileNotFoundError as e:
|
|
268
|
-
raise SQLFileNotFoundError(path_str) from e
|
|
269
267
|
except StorageOperationFailedError as e:
|
|
270
268
|
if "not found" in str(e).lower() or "no such file" in str(e).lower():
|
|
271
269
|
raise SQLFileNotFoundError(path_str) from e
|
|
@@ -419,8 +417,7 @@ class SQLFileLoader:
|
|
|
419
417
|
for file_path in sql_files:
|
|
420
418
|
relative_path = file_path.relative_to(dir_path)
|
|
421
419
|
namespace_parts = relative_path.parent.parts
|
|
422
|
-
|
|
423
|
-
self._load_single_file(file_path, namespace)
|
|
420
|
+
self._load_single_file(file_path, ".".join(namespace_parts) if namespace_parts else None)
|
|
424
421
|
return len(sql_files)
|
|
425
422
|
|
|
426
423
|
def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
|
|
@@ -533,44 +530,6 @@ class SQLFileLoader:
|
|
|
533
530
|
self._queries[normalized_name] = statement
|
|
534
531
|
self._query_to_file[normalized_name] = "<directly added>"
|
|
535
532
|
|
|
536
|
-
def get_sql(self, name: str) -> "SQL":
|
|
537
|
-
"""Get a SQL object by statement name.
|
|
538
|
-
|
|
539
|
-
Args:
|
|
540
|
-
name: Name of the statement (from -- name: in SQL file).
|
|
541
|
-
Hyphens in names are converted to underscores.
|
|
542
|
-
|
|
543
|
-
Returns:
|
|
544
|
-
SQL object ready for execution.
|
|
545
|
-
|
|
546
|
-
Raises:
|
|
547
|
-
SQLFileNotFoundError: If statement name not found.
|
|
548
|
-
"""
|
|
549
|
-
correlation_id = CorrelationContext.get()
|
|
550
|
-
|
|
551
|
-
safe_name = _normalize_query_name(name)
|
|
552
|
-
|
|
553
|
-
if safe_name not in self._queries:
|
|
554
|
-
available = ", ".join(sorted(self._queries.keys())) if self._queries else "none"
|
|
555
|
-
logger.error(
|
|
556
|
-
"Statement not found: %s",
|
|
557
|
-
name,
|
|
558
|
-
extra={
|
|
559
|
-
"statement_name": name,
|
|
560
|
-
"safe_name": safe_name,
|
|
561
|
-
"available_statements": len(self._queries),
|
|
562
|
-
"correlation_id": correlation_id,
|
|
563
|
-
},
|
|
564
|
-
)
|
|
565
|
-
raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")
|
|
566
|
-
|
|
567
|
-
parsed_statement = self._queries[safe_name]
|
|
568
|
-
sqlglot_dialect = None
|
|
569
|
-
if parsed_statement.dialect:
|
|
570
|
-
sqlglot_dialect = _normalize_dialect_for_sqlglot(parsed_statement.dialect)
|
|
571
|
-
|
|
572
|
-
return SQL(parsed_statement.sql, dialect=sqlglot_dialect)
|
|
573
|
-
|
|
574
533
|
def get_file(self, path: Union[str, Path]) -> "Optional[SQLFile]":
|
|
575
534
|
"""Get a loaded SQLFile object by path.
|
|
576
535
|
|
|
@@ -659,3 +618,41 @@ class SQLFileLoader:
|
|
|
659
618
|
if safe_name not in self._queries:
|
|
660
619
|
raise SQLFileNotFoundError(name)
|
|
661
620
|
return self._queries[safe_name].sql
|
|
621
|
+
|
|
622
|
+
def get_sql(self, name: str) -> "SQL":
|
|
623
|
+
"""Get a SQL object by statement name.
|
|
624
|
+
|
|
625
|
+
Args:
|
|
626
|
+
name: Name of the statement (from -- name: in SQL file).
|
|
627
|
+
Hyphens in names are converted to underscores.
|
|
628
|
+
|
|
629
|
+
Returns:
|
|
630
|
+
SQL object ready for execution.
|
|
631
|
+
|
|
632
|
+
Raises:
|
|
633
|
+
SQLFileNotFoundError: If statement name not found.
|
|
634
|
+
"""
|
|
635
|
+
correlation_id = CorrelationContext.get()
|
|
636
|
+
|
|
637
|
+
safe_name = _normalize_query_name(name)
|
|
638
|
+
|
|
639
|
+
if safe_name not in self._queries:
|
|
640
|
+
available = ", ".join(sorted(self._queries.keys())) if self._queries else "none"
|
|
641
|
+
logger.error(
|
|
642
|
+
"Statement not found: %s",
|
|
643
|
+
name,
|
|
644
|
+
extra={
|
|
645
|
+
"statement_name": name,
|
|
646
|
+
"safe_name": safe_name,
|
|
647
|
+
"available_statements": len(self._queries),
|
|
648
|
+
"correlation_id": correlation_id,
|
|
649
|
+
},
|
|
650
|
+
)
|
|
651
|
+
raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")
|
|
652
|
+
|
|
653
|
+
parsed_statement = self._queries[safe_name]
|
|
654
|
+
sqlglot_dialect = None
|
|
655
|
+
if parsed_statement.dialect:
|
|
656
|
+
sqlglot_dialect = _normalize_dialect(parsed_statement.dialect)
|
|
657
|
+
|
|
658
|
+
return SQL(parsed_statement.sql, dialect=sqlglot_dialect)
|
sqlspec/protocols.py
CHANGED
|
@@ -4,7 +4,7 @@ This module provides protocols that can be used for static type checking
|
|
|
4
4
|
and runtime isinstance() checks.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from typing import TYPE_CHECKING, Any,
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Optional, Protocol, Union, runtime_checkable
|
|
8
8
|
|
|
9
9
|
from typing_extensions import Self
|
|
10
10
|
|
|
@@ -14,7 +14,6 @@ if TYPE_CHECKING:
|
|
|
14
14
|
|
|
15
15
|
from sqlglot import exp
|
|
16
16
|
|
|
17
|
-
from sqlspec.storage.capabilities import StorageCapabilities
|
|
18
17
|
from sqlspec.typing import ArrowRecordBatch, ArrowTable
|
|
19
18
|
|
|
20
19
|
__all__ = (
|
|
@@ -194,9 +193,8 @@ class ObjectStoreItemProtocol(Protocol):
|
|
|
194
193
|
class ObjectStoreProtocol(Protocol):
|
|
195
194
|
"""Protocol for object storage operations."""
|
|
196
195
|
|
|
197
|
-
capabilities: ClassVar["StorageCapabilities"]
|
|
198
|
-
|
|
199
196
|
protocol: str
|
|
197
|
+
backend_type: str
|
|
200
198
|
|
|
201
199
|
def __init__(self, uri: str, **kwargs: Any) -> None:
|
|
202
200
|
return
|
|
@@ -330,7 +328,7 @@ class ObjectStoreProtocol(Protocol):
|
|
|
330
328
|
msg = "Async arrow writing not implemented"
|
|
331
329
|
raise NotImplementedError(msg)
|
|
332
330
|
|
|
333
|
-
|
|
331
|
+
def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[ArrowRecordBatch]":
|
|
334
332
|
"""Async stream Arrow record batches from matching objects."""
|
|
335
333
|
msg = "Async arrow streaming not implemented"
|
|
336
334
|
raise NotImplementedError(msg)
|