sqlspec 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec has been flagged as potentially problematic; consult the registry's advisory page for this release for details.

sqlspec/base.py CHANGED
@@ -64,7 +64,7 @@ class SQLSpec:
64
64
  config.close_pool()
65
65
  cleaned_count += 1
66
66
  except Exception as e:
67
- logger.warning("Failed to clean up sync pool for config %s: %s", config_type.__name__, e)
67
+ logger.debug("Failed to clean up sync pool for config %s: %s", config_type.__name__, e)
68
68
 
69
69
  if cleaned_count > 0:
70
70
  logger.debug("Sync pool cleanup completed. Cleaned %d pools.", cleaned_count)
@@ -87,14 +87,14 @@ class SQLSpec:
87
87
  else:
88
88
  sync_configs.append((config_type, config))
89
89
  except Exception as e:
90
- logger.warning("Failed to prepare cleanup for config %s: %s", config_type.__name__, e)
90
+ logger.debug("Failed to prepare cleanup for config %s: %s", config_type.__name__, e)
91
91
 
92
92
  if cleanup_tasks:
93
93
  try:
94
94
  await asyncio.gather(*cleanup_tasks, return_exceptions=True)
95
95
  logger.debug("Async pool cleanup completed. Cleaned %d pools.", len(cleanup_tasks))
96
96
  except Exception as e:
97
- logger.warning("Failed to complete async pool cleanup: %s", e)
97
+ logger.debug("Failed to complete async pool cleanup: %s", e)
98
98
 
99
99
  for _config_type, config in sync_configs:
100
100
  config.close_pool()
@@ -129,7 +129,7 @@ class SQLSpec:
129
129
  """
130
130
  config_type = type(config)
131
131
  if config_type in self._configs:
132
- logger.warning("Configuration for %s already exists. Overwriting.", config_type.__name__)
132
+ logger.debug("Configuration for %s already exists. Overwriting.", config_type.__name__)
133
133
  self._configs[config_type] = config
134
134
  return config_type
135
135
 
@@ -1,3 +1,4 @@
1
+ # ruff: noqa: C901
1
2
  """Result handling and schema conversion mixins for database drivers."""
2
3
 
3
4
  import datetime
@@ -22,7 +23,16 @@ from sqlspec.typing import (
22
23
  convert,
23
24
  get_type_adapter,
24
25
  )
25
- from sqlspec.utils.type_guards import is_attrs_schema, is_dataclass, is_msgspec_struct, is_pydantic_model
26
+ from sqlspec.utils.data_transformation import transform_dict_keys
27
+ from sqlspec.utils.text import camelize, kebabize, pascalize
28
+ from sqlspec.utils.type_guards import (
29
+ get_msgspec_rename_config,
30
+ is_attrs_schema,
31
+ is_dataclass,
32
+ is_dict,
33
+ is_msgspec_struct,
34
+ is_pydantic_model,
35
+ )
26
36
 
27
37
  __all__ = ("_DEFAULT_TYPE_DECODERS", "_default_msgspec_deserializer")
28
38
 
@@ -143,21 +153,46 @@ class ToSchemaMixin:
143
153
  if isinstance(data, list):
144
154
  result: list[Any] = []
145
155
  for item in data:
146
- if hasattr(item, "keys"):
156
+ if is_dict(item):
147
157
  result.append(schema_type(**dict(item))) # type: ignore[operator]
148
158
  else:
149
159
  result.append(item)
150
160
  return result
151
- if hasattr(data, "keys"):
161
+ if is_dict(data):
152
162
  return schema_type(**dict(data)) # type: ignore[operator]
153
163
  if isinstance(data, dict):
154
164
  return schema_type(**data) # type: ignore[operator]
155
165
  return data
156
166
  if is_msgspec_struct(schema_type):
167
+ rename_config = get_msgspec_rename_config(schema_type) # type: ignore[arg-type]
157
168
  deserializer = partial(_default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS)
158
- if not isinstance(data, Sequence):
159
- return convert(obj=data, type=schema_type, from_attributes=True, dec_hook=deserializer)
160
- return convert(obj=data, type=list[schema_type], from_attributes=True, dec_hook=deserializer) # type: ignore[valid-type]
169
+
170
+ # Transform field names if rename configuration exists
171
+ transformed_data = data
172
+ if (rename_config and is_dict(data)) or (isinstance(data, Sequence) and data and is_dict(data[0])):
173
+ try:
174
+ converter = None
175
+ if rename_config == "camel":
176
+ converter = camelize
177
+ elif rename_config == "kebab":
178
+ converter = kebabize
179
+ elif rename_config == "pascal":
180
+ converter = pascalize
181
+
182
+ if converter is not None:
183
+ if isinstance(data, Sequence):
184
+ transformed_data = [
185
+ transform_dict_keys(item, converter) if is_dict(item) else item for item in data
186
+ ]
187
+ else:
188
+ transformed_data = transform_dict_keys(data, converter) if is_dict(data) else data
189
+ except Exception as e:
190
+ logger.debug("Field name transformation failed for msgspec schema: %s", e)
191
+ transformed_data = data
192
+
193
+ if not isinstance(transformed_data, Sequence):
194
+ return convert(obj=transformed_data, type=schema_type, from_attributes=True, dec_hook=deserializer)
195
+ return convert(obj=transformed_data, type=list[schema_type], from_attributes=True, dec_hook=deserializer) # type: ignore[valid-type]
161
196
  if is_pydantic_model(schema_type):
162
197
  if not isinstance(data, Sequence):
163
198
  adapter = get_type_adapter(schema_type)
sqlspec/loader.py CHANGED
@@ -10,18 +10,15 @@ import time
10
10
  from datetime import datetime, timezone
11
11
  from pathlib import Path
12
12
  from typing import TYPE_CHECKING, Any, Final, Optional, Union
13
+ from urllib.parse import unquote, urlparse
13
14
 
14
15
  from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
15
16
  from sqlspec.core.statement import SQL
16
- from sqlspec.exceptions import (
17
- MissingDependencyError,
18
- SQLFileNotFoundError,
19
- SQLFileParseError,
20
- StorageOperationFailedError,
21
- )
17
+ from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError, StorageOperationFailedError
22
18
  from sqlspec.storage.registry import storage_registry as default_storage_registry
23
19
  from sqlspec.utils.correlation import CorrelationContext
24
20
  from sqlspec.utils.logging import get_logger
21
+ from sqlspec.utils.text import slugify
25
22
 
26
23
  if TYPE_CHECKING:
27
24
  from sqlspec.storage.registry import StorageRegistry
@@ -54,13 +51,25 @@ MIN_QUERY_PARTS: Final = 3
54
51
  def _normalize_query_name(name: str) -> str:
55
52
  """Normalize query name to be a valid Python identifier.
56
53
 
54
+ Convert hyphens to underscores, preserve dots for namespacing,
55
+ and remove invalid characters.
56
+
57
57
  Args:
58
58
  name: Raw query name from SQL file.
59
59
 
60
60
  Returns:
61
61
  Normalized query name suitable as Python identifier.
62
62
  """
63
- return TRIM_SPECIAL_CHARS.sub("", name).replace("-", "_")
63
+ # Handle namespace parts separately to preserve dots
64
+ parts = name.split(".")
65
+ normalized_parts = []
66
+
67
+ for part in parts:
68
+ # Use slugify with underscore separator and remove any remaining invalid chars
69
+ normalized_part = slugify(part, separator="_")
70
+ normalized_parts.append(normalized_part)
71
+
72
+ return ".".join(normalized_parts)
64
73
 
65
74
 
66
75
  def _normalize_dialect(dialect: str) -> str:
@@ -76,19 +85,6 @@ def _normalize_dialect(dialect: str) -> str:
76
85
  return DIALECT_ALIASES.get(normalized, normalized)
77
86
 
78
87
 
79
- def _normalize_dialect_for_sqlglot(dialect: str) -> str:
80
- """Normalize dialect name for SQLGlot compatibility.
81
-
82
- Args:
83
- dialect: Dialect name from SQL file or parameter.
84
-
85
- Returns:
86
- SQLGlot-compatible dialect name.
87
- """
88
- normalized = dialect.lower().strip()
89
- return DIALECT_ALIASES.get(normalized, normalized)
90
-
91
-
92
88
  class NamedStatement:
93
89
  """Represents a parsed SQL statement with metadata.
94
90
 
@@ -218,8 +214,7 @@ class SQLFileLoader:
218
214
  SQLFileParseError: If file cannot be read.
219
215
  """
220
216
  try:
221
- content = self._read_file_content(path)
222
- return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest()
217
+ return hashlib.md5(self._read_file_content(path).encode(), usedforsecurity=False).hexdigest()
223
218
  except Exception as e:
224
219
  raise SQLFileParseError(str(path), str(path), e) from e
225
220
 
@@ -253,19 +248,22 @@ class SQLFileLoader:
253
248
  SQLFileNotFoundError: If file does not exist.
254
249
  SQLFileParseError: If file cannot be read or parsed.
255
250
  """
256
-
257
251
  path_str = str(path)
258
252
 
259
253
  try:
260
254
  backend = self.storage_registry.get(path)
255
+ # For file:// URIs, extract just the filename for the backend call
256
+ if path_str.startswith("file://"):
257
+ parsed = urlparse(path_str)
258
+ file_path = unquote(parsed.path)
259
+ # Handle Windows paths (file:///C:/path)
260
+ if file_path and len(file_path) > 2 and file_path[2] == ":": # noqa: PLR2004
261
+ file_path = file_path[1:] # Remove leading slash for Windows
262
+ filename = Path(file_path).name
263
+ return backend.read_text(filename, encoding=self.encoding)
261
264
  return backend.read_text(path_str, encoding=self.encoding)
262
265
  except KeyError as e:
263
266
  raise SQLFileNotFoundError(path_str) from e
264
- except MissingDependencyError:
265
- try:
266
- return path.read_text(encoding=self.encoding) # type: ignore[union-attr]
267
- except FileNotFoundError as e:
268
- raise SQLFileNotFoundError(path_str) from e
269
267
  except StorageOperationFailedError as e:
270
268
  if "not found" in str(e).lower() or "no such file" in str(e).lower():
271
269
  raise SQLFileNotFoundError(path_str) from e
@@ -419,8 +417,7 @@ class SQLFileLoader:
419
417
  for file_path in sql_files:
420
418
  relative_path = file_path.relative_to(dir_path)
421
419
  namespace_parts = relative_path.parent.parts
422
- namespace = ".".join(namespace_parts) if namespace_parts else None
423
- self._load_single_file(file_path, namespace)
420
+ self._load_single_file(file_path, ".".join(namespace_parts) if namespace_parts else None)
424
421
  return len(sql_files)
425
422
 
426
423
  def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
@@ -533,44 +530,6 @@ class SQLFileLoader:
533
530
  self._queries[normalized_name] = statement
534
531
  self._query_to_file[normalized_name] = "<directly added>"
535
532
 
536
- def get_sql(self, name: str) -> "SQL":
537
- """Get a SQL object by statement name.
538
-
539
- Args:
540
- name: Name of the statement (from -- name: in SQL file).
541
- Hyphens in names are converted to underscores.
542
-
543
- Returns:
544
- SQL object ready for execution.
545
-
546
- Raises:
547
- SQLFileNotFoundError: If statement name not found.
548
- """
549
- correlation_id = CorrelationContext.get()
550
-
551
- safe_name = _normalize_query_name(name)
552
-
553
- if safe_name not in self._queries:
554
- available = ", ".join(sorted(self._queries.keys())) if self._queries else "none"
555
- logger.error(
556
- "Statement not found: %s",
557
- name,
558
- extra={
559
- "statement_name": name,
560
- "safe_name": safe_name,
561
- "available_statements": len(self._queries),
562
- "correlation_id": correlation_id,
563
- },
564
- )
565
- raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")
566
-
567
- parsed_statement = self._queries[safe_name]
568
- sqlglot_dialect = None
569
- if parsed_statement.dialect:
570
- sqlglot_dialect = _normalize_dialect_for_sqlglot(parsed_statement.dialect)
571
-
572
- return SQL(parsed_statement.sql, dialect=sqlglot_dialect)
573
-
574
533
  def get_file(self, path: Union[str, Path]) -> "Optional[SQLFile]":
575
534
  """Get a loaded SQLFile object by path.
576
535
 
@@ -659,3 +618,41 @@ class SQLFileLoader:
659
618
  if safe_name not in self._queries:
660
619
  raise SQLFileNotFoundError(name)
661
620
  return self._queries[safe_name].sql
621
+
622
+ def get_sql(self, name: str) -> "SQL":
623
+ """Get a SQL object by statement name.
624
+
625
+ Args:
626
+ name: Name of the statement (from -- name: in SQL file).
627
+ Hyphens in names are converted to underscores.
628
+
629
+ Returns:
630
+ SQL object ready for execution.
631
+
632
+ Raises:
633
+ SQLFileNotFoundError: If statement name not found.
634
+ """
635
+ correlation_id = CorrelationContext.get()
636
+
637
+ safe_name = _normalize_query_name(name)
638
+
639
+ if safe_name not in self._queries:
640
+ available = ", ".join(sorted(self._queries.keys())) if self._queries else "none"
641
+ logger.error(
642
+ "Statement not found: %s",
643
+ name,
644
+ extra={
645
+ "statement_name": name,
646
+ "safe_name": safe_name,
647
+ "available_statements": len(self._queries),
648
+ "correlation_id": correlation_id,
649
+ },
650
+ )
651
+ raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")
652
+
653
+ parsed_statement = self._queries[safe_name]
654
+ sqlglot_dialect = None
655
+ if parsed_statement.dialect:
656
+ sqlglot_dialect = _normalize_dialect(parsed_statement.dialect)
657
+
658
+ return SQL(parsed_statement.sql, dialect=sqlglot_dialect)
sqlspec/protocols.py CHANGED
@@ -4,7 +4,7 @@ This module provides protocols that can be used for static type checking
4
4
  and runtime isinstance() checks.
5
5
  """
6
6
 
7
- from typing import TYPE_CHECKING, Any, ClassVar, Optional, Protocol, Union, runtime_checkable
7
+ from typing import TYPE_CHECKING, Any, Optional, Protocol, Union, runtime_checkable
8
8
 
9
9
  from typing_extensions import Self
10
10
 
@@ -14,7 +14,6 @@ if TYPE_CHECKING:
14
14
 
15
15
  from sqlglot import exp
16
16
 
17
- from sqlspec.storage.capabilities import StorageCapabilities
18
17
  from sqlspec.typing import ArrowRecordBatch, ArrowTable
19
18
 
20
19
  __all__ = (
@@ -194,9 +193,8 @@ class ObjectStoreItemProtocol(Protocol):
194
193
  class ObjectStoreProtocol(Protocol):
195
194
  """Protocol for object storage operations."""
196
195
 
197
- capabilities: ClassVar["StorageCapabilities"]
198
-
199
196
  protocol: str
197
+ backend_type: str
200
198
 
201
199
  def __init__(self, uri: str, **kwargs: Any) -> None:
202
200
  return
@@ -330,7 +328,7 @@ class ObjectStoreProtocol(Protocol):
330
328
  msg = "Async arrow writing not implemented"
331
329
  raise NotImplementedError(msg)
332
330
 
333
- async def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[ArrowRecordBatch]":
331
+ def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[ArrowRecordBatch]":
334
332
  """Async stream Arrow record batches from matching objects."""
335
333
  msg = "Async arrow streaming not implemented"
336
334
  raise NotImplementedError(msg)
@@ -8,16 +8,6 @@ Provides a storage system with:
8
8
  - Capability-based backend selection
9
9
  """
10
10
 
11
- from sqlspec.protocols import ObjectStoreProtocol
12
- from sqlspec.storage.capabilities import HasStorageCapabilities, StorageCapabilities
13
- from sqlspec.storage.registry import StorageRegistry
11
+ from sqlspec.storage.registry import StorageRegistry, storage_registry
14
12
 
15
- storage_registry = StorageRegistry()
16
-
17
- __all__ = (
18
- "HasStorageCapabilities",
19
- "ObjectStoreProtocol",
20
- "StorageCapabilities",
21
- "StorageRegistry",
22
- "storage_registry",
23
- )
13
+ __all__ = ("StorageRegistry", "storage_registry")
@@ -0,0 +1 @@
1
+ """Storage backends."""