sqlspec 0.21.1__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the package's page on its public registry for more details.

sqlspec/_sql.py CHANGED
@@ -628,6 +628,42 @@ class SQLFactory:
628
628
  """Create a CROSS JOIN builder."""
629
629
  return JoinBuilder("cross join")
630
630
 
631
+ @property
632
+ def lateral_join_(self) -> "JoinBuilder":
633
+ """Create a LATERAL JOIN builder.
634
+
635
+ Returns:
636
+ JoinBuilder configured for LATERAL JOIN
637
+
638
+ Example:
639
+ ```python
640
+ query = (
641
+ sql.select("u.name", "arr.value")
642
+ .from_("users u")
643
+ .join(sql.lateral_join_("UNNEST(u.tags)").on("true"))
644
+ )
645
+ ```
646
+ """
647
+ return JoinBuilder("lateral join", lateral=True)
648
+
649
+ @property
650
+ def left_lateral_join_(self) -> "JoinBuilder":
651
+ """Create a LEFT LATERAL JOIN builder.
652
+
653
+ Returns:
654
+ JoinBuilder configured for LEFT LATERAL JOIN
655
+ """
656
+ return JoinBuilder("left join", lateral=True)
657
+
658
+ @property
659
+ def cross_lateral_join_(self) -> "JoinBuilder":
660
+ """Create a CROSS LATERAL JOIN builder.
661
+
662
+ Returns:
663
+ JoinBuilder configured for CROSS LATERAL JOIN
664
+ """
665
+ return JoinBuilder("cross join", lateral=True)
666
+
631
667
  def __getattr__(self, name: str) -> "Column":
632
668
  """Dynamically create column references.
633
669
 
sqlspec/base.py CHANGED
@@ -64,7 +64,7 @@ class SQLSpec:
64
64
  config.close_pool()
65
65
  cleaned_count += 1
66
66
  except Exception as e:
67
- logger.warning("Failed to clean up sync pool for config %s: %s", config_type.__name__, e)
67
+ logger.debug("Failed to clean up sync pool for config %s: %s", config_type.__name__, e)
68
68
 
69
69
  if cleaned_count > 0:
70
70
  logger.debug("Sync pool cleanup completed. Cleaned %d pools.", cleaned_count)
@@ -87,14 +87,14 @@ class SQLSpec:
87
87
  else:
88
88
  sync_configs.append((config_type, config))
89
89
  except Exception as e:
90
- logger.warning("Failed to prepare cleanup for config %s: %s", config_type.__name__, e)
90
+ logger.debug("Failed to prepare cleanup for config %s: %s", config_type.__name__, e)
91
91
 
92
92
  if cleanup_tasks:
93
93
  try:
94
94
  await asyncio.gather(*cleanup_tasks, return_exceptions=True)
95
95
  logger.debug("Async pool cleanup completed. Cleaned %d pools.", len(cleanup_tasks))
96
96
  except Exception as e:
97
- logger.warning("Failed to complete async pool cleanup: %s", e)
97
+ logger.debug("Failed to complete async pool cleanup: %s", e)
98
98
 
99
99
  for _config_type, config in sync_configs:
100
100
  config.close_pool()
@@ -129,7 +129,7 @@ class SQLSpec:
129
129
  """
130
130
  config_type = type(config)
131
131
  if config_type in self._configs:
132
- logger.warning("Configuration for %s already exists. Overwriting.", config_type.__name__)
132
+ logger.debug("Configuration for %s already exists. Overwriting.", config_type.__name__)
133
133
  self._configs[config_type] = config
134
134
  return config_type
135
135
 
@@ -14,7 +14,6 @@ from sqlspec.exceptions import SQLBuilderError
14
14
  from sqlspec.utils.type_guards import has_query_builder_parameters
15
15
 
16
16
  if TYPE_CHECKING:
17
- from sqlspec.builder._column import ColumnExpression
18
17
  from sqlspec.core.statement import SQL
19
18
  from sqlspec.protocols import SQLBuilderProtocol
20
19
 
@@ -36,74 +35,133 @@ class JoinClauseMixin:
36
35
  on: Optional[Union[str, exp.Expression, "SQL"]] = None,
37
36
  alias: Optional[str] = None,
38
37
  join_type: str = "INNER",
38
+ lateral: bool = False,
39
39
  ) -> Self:
40
40
  builder = cast("SQLBuilderProtocol", self)
41
+ self._validate_join_context(builder)
42
+
43
+ # Handle Join expressions directly (from JoinBuilder.on() calls)
44
+ if isinstance(table, exp.Join):
45
+ if builder._expression is not None and isinstance(builder._expression, exp.Select):
46
+ builder._expression = builder._expression.join(table, copy=False)
47
+ return cast("Self", builder)
48
+
49
+ table_expr = self._parse_table_expression(table, alias, builder)
50
+ on_expr = self._parse_on_condition(on, builder)
51
+ join_expr = self._create_join_expression(table_expr, on_expr, join_type)
52
+
53
+ if lateral:
54
+ self._apply_lateral_modifier(join_expr)
55
+
56
+ if builder._expression is not None and isinstance(builder._expression, exp.Select):
57
+ builder._expression = builder._expression.join(join_expr, copy=False)
58
+ return cast("Self", builder)
59
+
60
+ def _validate_join_context(self, builder: "SQLBuilderProtocol") -> None:
61
+ """Validate that the join can be applied to the current expression."""
41
62
  if builder._expression is None:
42
63
  builder._expression = exp.Select()
43
64
  if not isinstance(builder._expression, exp.Select):
44
65
  msg = "JOIN clause is only supported for SELECT statements."
45
66
  raise SQLBuilderError(msg)
46
- table_expr: exp.Expression
67
+
68
+ def _parse_table_expression(
69
+ self, table: Union[str, exp.Expression, Any], alias: Optional[str], builder: "SQLBuilderProtocol"
70
+ ) -> exp.Expression:
71
+ """Parse table parameter into a SQLGlot expression."""
47
72
  if isinstance(table, str):
48
- table_expr = parse_table_expression(table, alias)
49
- elif has_query_builder_parameters(table):
50
- if hasattr(table, "_expression") and getattr(table, "_expression", None) is not None:
51
- table_expr_value = getattr(table, "_expression", None)
52
- if table_expr_value is not None:
53
- subquery_exp = exp.paren(table_expr_value)
54
- else:
55
- subquery_exp = exp.paren(exp.Anonymous(this=""))
56
- table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
73
+ return parse_table_expression(table, alias)
74
+ if has_query_builder_parameters(table):
75
+ return self._handle_query_builder_table(table, alias, builder)
76
+ if isinstance(table, exp.Expression):
77
+ return table
78
+ return cast("exp.Expression", table)
79
+
80
+ def _handle_query_builder_table(
81
+ self, table: Any, alias: Optional[str], builder: "SQLBuilderProtocol"
82
+ ) -> exp.Expression:
83
+ """Handle table parameters that are query builders."""
84
+ if hasattr(table, "_expression") and getattr(table, "_expression", None) is not None:
85
+ table_expr_value = getattr(table, "_expression", None)
86
+ if table_expr_value is not None:
87
+ subquery_exp = exp.paren(table_expr_value)
57
88
  else:
58
- subquery = table.build()
59
- sql_str = subquery.sql if hasattr(subquery, "sql") and not callable(subquery.sql) else str(subquery)
60
- subquery_exp = exp.paren(exp.maybe_parse(sql_str, dialect=getattr(builder, "dialect", None)))
61
- table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
62
- else:
63
- table_expr = table
64
- on_expr: Optional[exp.Expression] = None
65
- if on is not None:
66
- if isinstance(on, str):
67
- on_expr = exp.condition(on)
68
- elif hasattr(on, "expression") and hasattr(on, "sql"):
69
- # Handle SQL objects (from sql.raw with parameters)
70
- expression = getattr(on, "expression", None)
71
- if expression is not None and isinstance(expression, exp.Expression):
72
- # Merge parameters from SQL object into builder
73
- if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
74
- sql_parameters = getattr(on, "parameters", {})
75
- for param_name, param_value in sql_parameters.items():
76
- builder.add_parameter(param_value, name=param_name)
77
- on_expr = expression
78
- else:
79
- # If expression is None, fall back to parsing the raw SQL
80
- sql_text = getattr(on, "sql", "")
81
- # Merge parameters even when parsing raw SQL
82
- if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
83
- sql_parameters = getattr(on, "parameters", {})
84
- for param_name, param_value in sql_parameters.items():
85
- builder.add_parameter(param_value, name=param_name)
86
- on_expr = exp.maybe_parse(sql_text) or exp.condition(str(sql_text))
87
- # For other types (should be exp.Expression)
88
- elif isinstance(on, exp.Expression):
89
- on_expr = on
90
- else:
91
- # Last resort - convert to string and parse
92
- on_expr = exp.condition(str(on))
89
+ subquery_exp = exp.paren(exp.Anonymous(this=""))
90
+ return exp.alias_(subquery_exp, alias) if alias else subquery_exp
91
+ subquery = table.build()
92
+ sql_str = subquery.sql if hasattr(subquery, "sql") and not callable(subquery.sql) else str(subquery)
93
+ subquery_exp = exp.paren(exp.maybe_parse(sql_str, dialect=getattr(builder, "dialect", None)))
94
+ return exp.alias_(subquery_exp, alias) if alias else subquery_exp
95
+
96
+ def _parse_on_condition(
97
+ self, on: Optional[Union[str, exp.Expression, "SQL"]], builder: "SQLBuilderProtocol"
98
+ ) -> Optional[exp.Expression]:
99
+ """Parse ON condition into a SQLGlot expression."""
100
+ if on is None:
101
+ return None
102
+
103
+ if isinstance(on, str):
104
+ return exp.condition(on)
105
+ if hasattr(on, "expression") and hasattr(on, "sql"):
106
+ return self._handle_sql_object_condition(on, builder)
107
+ if isinstance(on, exp.Expression):
108
+ return on
109
+ # Last resort - convert to string and parse
110
+ return exp.condition(str(on))
111
+
112
+ def _handle_sql_object_condition(self, on: Any, builder: "SQLBuilderProtocol") -> exp.Expression:
113
+ """Handle SQL object conditions with parameter binding."""
114
+ expression = getattr(on, "expression", None)
115
+ if expression is not None and isinstance(expression, exp.Expression):
116
+ # Merge parameters from SQL object into builder
117
+ if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
118
+ sql_parameters = getattr(on, "parameters", {})
119
+ for param_name, param_value in sql_parameters.items():
120
+ builder.add_parameter(param_value, name=param_name)
121
+ return cast("exp.Expression", expression)
122
+ # If expression is None, fall back to parsing the raw SQL
123
+ sql_text = getattr(on, "sql", "")
124
+ # Merge parameters even when parsing raw SQL
125
+ if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
126
+ sql_parameters = getattr(on, "parameters", {})
127
+ for param_name, param_value in sql_parameters.items():
128
+ builder.add_parameter(param_value, name=param_name)
129
+ parsed_expr = exp.maybe_parse(sql_text)
130
+ return parsed_expr if parsed_expr is not None else exp.condition(str(sql_text))
131
+
132
+ def _create_join_expression(
133
+ self, table_expr: exp.Expression, on_expr: Optional[exp.Expression], join_type: str
134
+ ) -> exp.Join:
135
+ """Create the appropriate JOIN expression based on join type."""
93
136
  join_type_upper = join_type.upper()
94
137
  if join_type_upper == "INNER":
95
- join_expr = exp.Join(this=table_expr, on=on_expr)
96
- elif join_type_upper == "LEFT":
97
- join_expr = exp.Join(this=table_expr, on=on_expr, side="LEFT")
98
- elif join_type_upper == "RIGHT":
99
- join_expr = exp.Join(this=table_expr, on=on_expr, side="RIGHT")
100
- elif join_type_upper == "FULL":
101
- join_expr = exp.Join(this=table_expr, on=on_expr, side="FULL", kind="OUTER")
138
+ return exp.Join(this=table_expr, on=on_expr)
139
+ if join_type_upper == "LEFT":
140
+ return exp.Join(this=table_expr, on=on_expr, side="LEFT")
141
+ if join_type_upper == "RIGHT":
142
+ return exp.Join(this=table_expr, on=on_expr, side="RIGHT")
143
+ if join_type_upper == "FULL":
144
+ return exp.Join(this=table_expr, on=on_expr, side="FULL", kind="OUTER")
145
+ if join_type_upper == "CROSS":
146
+ return exp.Join(this=table_expr, kind="CROSS")
147
+ msg = f"Unsupported join type: {join_type}"
148
+ raise SQLBuilderError(msg)
149
+
150
+ def _apply_lateral_modifier(self, join_expr: exp.Join) -> None:
151
+ """Apply LATERAL modifier to the join expression."""
152
+ current_kind = join_expr.args.get("kind")
153
+ current_side = join_expr.args.get("side")
154
+
155
+ if current_kind == "CROSS":
156
+ join_expr.set("kind", "CROSS LATERAL")
157
+ elif current_kind == "OUTER" and current_side == "FULL":
158
+ join_expr.set("side", "FULL") # Keep side
159
+ join_expr.set("kind", "OUTER LATERAL")
160
+ elif current_side:
161
+ join_expr.set("kind", f"{current_side} LATERAL")
162
+ join_expr.set("side", None) # Clear side to avoid duplication
102
163
  else:
103
- msg = f"Unsupported join type: {join_type}"
104
- raise SQLBuilderError(msg)
105
- builder._expression = builder._expression.join(join_expr, copy=False)
106
- return cast("Self", builder)
164
+ join_expr.set("kind", "LATERAL")
107
165
 
108
166
  def inner_join(
109
167
  self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
@@ -154,6 +212,63 @@ class JoinClauseMixin:
154
212
  builder._expression = builder._expression.join(join_expr, copy=False)
155
213
  return cast("Self", builder)
156
214
 
215
+ def lateral_join(
216
+ self,
217
+ table: Union[str, exp.Expression, Any],
218
+ on: Optional[Union[str, exp.Expression, "SQL"]] = None,
219
+ alias: Optional[str] = None,
220
+ ) -> Self:
221
+ """Create a LATERAL JOIN.
222
+
223
+ Args:
224
+ table: Table, subquery, or table function to join
225
+ on: Optional join condition (for LATERAL JOINs with ON clause)
226
+ alias: Optional alias for the joined table/subquery
227
+
228
+ Returns:
229
+ Self for method chaining
230
+
231
+ Example:
232
+ ```python
233
+ query = (
234
+ sql.select("u.name", "arr.value")
235
+ .from_("users u")
236
+ .lateral_join("UNNEST(u.tags)", alias="arr")
237
+ )
238
+ ```
239
+ """
240
+ return self.join(table, on=on, alias=alias, join_type="INNER", lateral=True)
241
+
242
+ def left_lateral_join(
243
+ self,
244
+ table: Union[str, exp.Expression, Any],
245
+ on: Optional[Union[str, exp.Expression, "SQL"]] = None,
246
+ alias: Optional[str] = None,
247
+ ) -> Self:
248
+ """Create a LEFT LATERAL JOIN.
249
+
250
+ Args:
251
+ table: Table, subquery, or table function to join
252
+ on: Optional join condition
253
+ alias: Optional alias for the joined table/subquery
254
+
255
+ Returns:
256
+ Self for method chaining
257
+ """
258
+ return self.join(table, on=on, alias=alias, join_type="LEFT", lateral=True)
259
+
260
+ def cross_lateral_join(self, table: Union[str, exp.Expression, Any], alias: Optional[str] = None) -> Self:
261
+ """Create a CROSS LATERAL JOIN (no ON condition).
262
+
263
+ Args:
264
+ table: Table, subquery, or table function to join
265
+ alias: Optional alias for the joined table/subquery
266
+
267
+ Returns:
268
+ Self for method chaining
269
+ """
270
+ return self.join(table, on=None, alias=alias, join_type="CROSS", lateral=True)
271
+
157
272
 
158
273
  @trait
159
274
  class JoinBuilder:
@@ -181,32 +296,19 @@ class JoinBuilder:
181
296
  ```
182
297
  """
183
298
 
184
- def __init__(self, join_type: str) -> None:
299
+ def __init__(self, join_type: str, lateral: bool = False) -> None:
185
300
  """Initialize the join builder.
186
301
 
187
302
  Args:
188
- join_type: Type of join (inner, left, right, full, cross)
303
+ join_type: Type of join (inner, left, right, full, cross, lateral)
304
+ lateral: Whether this is a LATERAL join
189
305
  """
190
306
  self._join_type = join_type.upper()
307
+ self._lateral = lateral
191
308
  self._table: Optional[Union[str, exp.Expression]] = None
192
309
  self._condition: Optional[exp.Expression] = None
193
310
  self._alias: Optional[str] = None
194
311
 
195
- def __eq__(self, other: object) -> "ColumnExpression": # type: ignore[override]
196
- """Equal to (==) - not typically used but needed for type consistency."""
197
- from sqlspec.builder._column import ColumnExpression
198
-
199
- # JoinBuilder doesn't have a direct expression, so this is a placeholder
200
- # In practice, this shouldn't be called as joins are used differently
201
- placeholder_expr = exp.Literal.string(f"join_{self._join_type.lower()}")
202
- if other is None:
203
- return ColumnExpression(exp.Is(this=placeholder_expr, expression=exp.Null()))
204
- return ColumnExpression(exp.EQ(this=placeholder_expr, expression=exp.convert(other)))
205
-
206
- def __hash__(self) -> int:
207
- """Make JoinBuilder hashable."""
208
- return hash(id(self))
209
-
210
312
  def __call__(self, table: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
211
313
  """Set the table to join.
212
314
 
@@ -254,15 +356,33 @@ class JoinBuilder:
254
356
  table_expr = exp.alias_(table_expr, self._alias)
255
357
 
256
358
  # Create the appropriate join type using same pattern as existing JoinClauseMixin
257
- if self._join_type == "INNER JOIN":
258
- return exp.Join(this=table_expr, on=condition_expr)
259
- if self._join_type == "LEFT JOIN":
260
- return exp.Join(this=table_expr, on=condition_expr, side="LEFT")
261
- if self._join_type == "RIGHT JOIN":
262
- return exp.Join(this=table_expr, on=condition_expr, side="RIGHT")
263
- if self._join_type == "FULL JOIN":
264
- return exp.Join(this=table_expr, on=condition_expr, side="FULL", kind="OUTER")
265
- if self._join_type == "CROSS JOIN":
359
+ if self._join_type in {"INNER JOIN", "INNER", "LATERAL JOIN"}:
360
+ join_expr = exp.Join(this=table_expr, on=condition_expr)
361
+ elif self._join_type in {"LEFT JOIN", "LEFT"}:
362
+ join_expr = exp.Join(this=table_expr, on=condition_expr, side="LEFT")
363
+ elif self._join_type in {"RIGHT JOIN", "RIGHT"}:
364
+ join_expr = exp.Join(this=table_expr, on=condition_expr, side="RIGHT")
365
+ elif self._join_type in {"FULL JOIN", "FULL"}:
366
+ join_expr = exp.Join(this=table_expr, on=condition_expr, side="FULL", kind="OUTER")
367
+ elif self._join_type in {"CROSS JOIN", "CROSS"}:
266
368
  # CROSS JOIN doesn't use ON condition
267
- return exp.Join(this=table_expr, kind="CROSS")
268
- return exp.Join(this=table_expr, on=condition_expr)
369
+ join_expr = exp.Join(this=table_expr, kind="CROSS")
370
+ else:
371
+ join_expr = exp.Join(this=table_expr, on=condition_expr)
372
+
373
+ if self._lateral or self._join_type == "LATERAL JOIN":
374
+ current_kind = join_expr.args.get("kind")
375
+ current_side = join_expr.args.get("side")
376
+
377
+ if current_kind == "CROSS":
378
+ join_expr.set("kind", "CROSS LATERAL")
379
+ elif current_kind == "OUTER" and current_side == "FULL":
380
+ join_expr.set("side", "FULL") # Keep side
381
+ join_expr.set("kind", "OUTER LATERAL")
382
+ elif current_side:
383
+ join_expr.set("kind", f"{current_side} LATERAL")
384
+ join_expr.set("side", None) # Clear side to avoid duplication
385
+ else:
386
+ join_expr.set("kind", "LATERAL")
387
+
388
+ return join_expr
sqlspec/loader.py CHANGED
@@ -10,18 +10,15 @@ import time
10
10
  from datetime import datetime, timezone
11
11
  from pathlib import Path
12
12
  from typing import TYPE_CHECKING, Any, Final, Optional, Union
13
+ from urllib.parse import unquote, urlparse
13
14
 
14
15
  from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
15
16
  from sqlspec.core.statement import SQL
16
- from sqlspec.exceptions import (
17
- MissingDependencyError,
18
- SQLFileNotFoundError,
19
- SQLFileParseError,
20
- StorageOperationFailedError,
21
- )
17
+ from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError, StorageOperationFailedError
22
18
  from sqlspec.storage.registry import storage_registry as default_storage_registry
23
19
  from sqlspec.utils.correlation import CorrelationContext
24
20
  from sqlspec.utils.logging import get_logger
21
+ from sqlspec.utils.text import slugify
25
22
 
26
23
  if TYPE_CHECKING:
27
24
  from sqlspec.storage.registry import StorageRegistry
@@ -54,13 +51,25 @@ MIN_QUERY_PARTS: Final = 3
54
51
  def _normalize_query_name(name: str) -> str:
55
52
  """Normalize query name to be a valid Python identifier.
56
53
 
54
+ Convert hyphens to underscores, preserve dots for namespacing,
55
+ and remove invalid characters.
56
+
57
57
  Args:
58
58
  name: Raw query name from SQL file.
59
59
 
60
60
  Returns:
61
61
  Normalized query name suitable as Python identifier.
62
62
  """
63
- return TRIM_SPECIAL_CHARS.sub("", name).replace("-", "_")
63
+ # Handle namespace parts separately to preserve dots
64
+ parts = name.split(".")
65
+ normalized_parts = []
66
+
67
+ for part in parts:
68
+ # Use slugify with underscore separator and remove any remaining invalid chars
69
+ normalized_part = slugify(part, separator="_")
70
+ normalized_parts.append(normalized_part)
71
+
72
+ return ".".join(normalized_parts)
64
73
 
65
74
 
66
75
  def _normalize_dialect(dialect: str) -> str:
@@ -76,19 +85,6 @@ def _normalize_dialect(dialect: str) -> str:
76
85
  return DIALECT_ALIASES.get(normalized, normalized)
77
86
 
78
87
 
79
- def _normalize_dialect_for_sqlglot(dialect: str) -> str:
80
- """Normalize dialect name for SQLGlot compatibility.
81
-
82
- Args:
83
- dialect: Dialect name from SQL file or parameter.
84
-
85
- Returns:
86
- SQLGlot-compatible dialect name.
87
- """
88
- normalized = dialect.lower().strip()
89
- return DIALECT_ALIASES.get(normalized, normalized)
90
-
91
-
92
88
  class NamedStatement:
93
89
  """Represents a parsed SQL statement with metadata.
94
90
 
@@ -218,8 +214,7 @@ class SQLFileLoader:
218
214
  SQLFileParseError: If file cannot be read.
219
215
  """
220
216
  try:
221
- content = self._read_file_content(path)
222
- return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest()
217
+ return hashlib.md5(self._read_file_content(path).encode(), usedforsecurity=False).hexdigest()
223
218
  except Exception as e:
224
219
  raise SQLFileParseError(str(path), str(path), e) from e
225
220
 
@@ -253,19 +248,22 @@ class SQLFileLoader:
253
248
  SQLFileNotFoundError: If file does not exist.
254
249
  SQLFileParseError: If file cannot be read or parsed.
255
250
  """
256
-
257
251
  path_str = str(path)
258
252
 
259
253
  try:
260
254
  backend = self.storage_registry.get(path)
255
+ # For file:// URIs, extract just the filename for the backend call
256
+ if path_str.startswith("file://"):
257
+ parsed = urlparse(path_str)
258
+ file_path = unquote(parsed.path)
259
+ # Handle Windows paths (file:///C:/path)
260
+ if file_path and len(file_path) > 2 and file_path[2] == ":": # noqa: PLR2004
261
+ file_path = file_path[1:] # Remove leading slash for Windows
262
+ filename = Path(file_path).name
263
+ return backend.read_text(filename, encoding=self.encoding)
261
264
  return backend.read_text(path_str, encoding=self.encoding)
262
265
  except KeyError as e:
263
266
  raise SQLFileNotFoundError(path_str) from e
264
- except MissingDependencyError:
265
- try:
266
- return path.read_text(encoding=self.encoding) # type: ignore[union-attr]
267
- except FileNotFoundError as e:
268
- raise SQLFileNotFoundError(path_str) from e
269
267
  except StorageOperationFailedError as e:
270
268
  if "not found" in str(e).lower() or "no such file" in str(e).lower():
271
269
  raise SQLFileNotFoundError(path_str) from e
@@ -419,8 +417,7 @@ class SQLFileLoader:
419
417
  for file_path in sql_files:
420
418
  relative_path = file_path.relative_to(dir_path)
421
419
  namespace_parts = relative_path.parent.parts
422
- namespace = ".".join(namespace_parts) if namespace_parts else None
423
- self._load_single_file(file_path, namespace)
420
+ self._load_single_file(file_path, ".".join(namespace_parts) if namespace_parts else None)
424
421
  return len(sql_files)
425
422
 
426
423
  def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
@@ -533,44 +530,6 @@ class SQLFileLoader:
533
530
  self._queries[normalized_name] = statement
534
531
  self._query_to_file[normalized_name] = "<directly added>"
535
532
 
536
- def get_sql(self, name: str) -> "SQL":
537
- """Get a SQL object by statement name.
538
-
539
- Args:
540
- name: Name of the statement (from -- name: in SQL file).
541
- Hyphens in names are converted to underscores.
542
-
543
- Returns:
544
- SQL object ready for execution.
545
-
546
- Raises:
547
- SQLFileNotFoundError: If statement name not found.
548
- """
549
- correlation_id = CorrelationContext.get()
550
-
551
- safe_name = _normalize_query_name(name)
552
-
553
- if safe_name not in self._queries:
554
- available = ", ".join(sorted(self._queries.keys())) if self._queries else "none"
555
- logger.error(
556
- "Statement not found: %s",
557
- name,
558
- extra={
559
- "statement_name": name,
560
- "safe_name": safe_name,
561
- "available_statements": len(self._queries),
562
- "correlation_id": correlation_id,
563
- },
564
- )
565
- raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")
566
-
567
- parsed_statement = self._queries[safe_name]
568
- sqlglot_dialect = None
569
- if parsed_statement.dialect:
570
- sqlglot_dialect = _normalize_dialect_for_sqlglot(parsed_statement.dialect)
571
-
572
- return SQL(parsed_statement.sql, dialect=sqlglot_dialect)
573
-
574
533
  def get_file(self, path: Union[str, Path]) -> "Optional[SQLFile]":
575
534
  """Get a loaded SQLFile object by path.
576
535
 
@@ -659,3 +618,41 @@ class SQLFileLoader:
659
618
  if safe_name not in self._queries:
660
619
  raise SQLFileNotFoundError(name)
661
620
  return self._queries[safe_name].sql
621
+
622
+ def get_sql(self, name: str) -> "SQL":
623
+ """Get a SQL object by statement name.
624
+
625
+ Args:
626
+ name: Name of the statement (from -- name: in SQL file).
627
+ Hyphens in names are converted to underscores.
628
+
629
+ Returns:
630
+ SQL object ready for execution.
631
+
632
+ Raises:
633
+ SQLFileNotFoundError: If statement name not found.
634
+ """
635
+ correlation_id = CorrelationContext.get()
636
+
637
+ safe_name = _normalize_query_name(name)
638
+
639
+ if safe_name not in self._queries:
640
+ available = ", ".join(sorted(self._queries.keys())) if self._queries else "none"
641
+ logger.error(
642
+ "Statement not found: %s",
643
+ name,
644
+ extra={
645
+ "statement_name": name,
646
+ "safe_name": safe_name,
647
+ "available_statements": len(self._queries),
648
+ "correlation_id": correlation_id,
649
+ },
650
+ )
651
+ raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")
652
+
653
+ parsed_statement = self._queries[safe_name]
654
+ sqlglot_dialect = None
655
+ if parsed_statement.dialect:
656
+ sqlglot_dialect = _normalize_dialect(parsed_statement.dialect)
657
+
658
+ return SQL(parsed_statement.sql, dialect=sqlglot_dialect)
sqlspec/protocols.py CHANGED
@@ -4,7 +4,7 @@ This module provides protocols that can be used for static type checking
4
4
  and runtime isinstance() checks.
5
5
  """
6
6
 
7
- from typing import TYPE_CHECKING, Any, ClassVar, Optional, Protocol, Union, runtime_checkable
7
+ from typing import TYPE_CHECKING, Any, Optional, Protocol, Union, runtime_checkable
8
8
 
9
9
  from typing_extensions import Self
10
10
 
@@ -14,7 +14,6 @@ if TYPE_CHECKING:
14
14
 
15
15
  from sqlglot import exp
16
16
 
17
- from sqlspec.storage.capabilities import StorageCapabilities
18
17
  from sqlspec.typing import ArrowRecordBatch, ArrowTable
19
18
 
20
19
  __all__ = (
@@ -194,9 +193,8 @@ class ObjectStoreItemProtocol(Protocol):
194
193
  class ObjectStoreProtocol(Protocol):
195
194
  """Protocol for object storage operations."""
196
195
 
197
- capabilities: ClassVar["StorageCapabilities"]
198
-
199
196
  protocol: str
197
+ backend_type: str
200
198
 
201
199
  def __init__(self, uri: str, **kwargs: Any) -> None:
202
200
  return
@@ -330,7 +328,7 @@ class ObjectStoreProtocol(Protocol):
330
328
  msg = "Async arrow writing not implemented"
331
329
  raise NotImplementedError(msg)
332
330
 
333
- async def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[ArrowRecordBatch]":
331
+ def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[ArrowRecordBatch]":
334
332
  """Async stream Arrow record batches from matching objects."""
335
333
  msg = "Async arrow streaming not implemented"
336
334
  raise NotImplementedError(msg)