sqlspec 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; the original registry page linked to further details, which are not reproduced here.

Files changed (199)
  1. sqlspec/__init__.py +7 -15
  2. sqlspec/_serialization.py +256 -24
  3. sqlspec/_typing.py +71 -52
  4. sqlspec/adapters/adbc/_types.py +1 -1
  5. sqlspec/adapters/adbc/adk/__init__.py +5 -0
  6. sqlspec/adapters/adbc/adk/store.py +870 -0
  7. sqlspec/adapters/adbc/config.py +69 -12
  8. sqlspec/adapters/adbc/data_dictionary.py +340 -0
  9. sqlspec/adapters/adbc/driver.py +266 -58
  10. sqlspec/adapters/adbc/litestar/__init__.py +5 -0
  11. sqlspec/adapters/adbc/litestar/store.py +504 -0
  12. sqlspec/adapters/adbc/type_converter.py +153 -0
  13. sqlspec/adapters/aiosqlite/_types.py +1 -1
  14. sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
  15. sqlspec/adapters/aiosqlite/adk/store.py +527 -0
  16. sqlspec/adapters/aiosqlite/config.py +88 -15
  17. sqlspec/adapters/aiosqlite/data_dictionary.py +149 -0
  18. sqlspec/adapters/aiosqlite/driver.py +143 -40
  19. sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
  20. sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
  21. sqlspec/adapters/aiosqlite/pool.py +7 -7
  22. sqlspec/adapters/asyncmy/__init__.py +7 -1
  23. sqlspec/adapters/asyncmy/_types.py +2 -2
  24. sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
  25. sqlspec/adapters/asyncmy/adk/store.py +493 -0
  26. sqlspec/adapters/asyncmy/config.py +68 -23
  27. sqlspec/adapters/asyncmy/data_dictionary.py +161 -0
  28. sqlspec/adapters/asyncmy/driver.py +313 -58
  29. sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
  30. sqlspec/adapters/asyncmy/litestar/store.py +296 -0
  31. sqlspec/adapters/asyncpg/__init__.py +2 -1
  32. sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
  33. sqlspec/adapters/asyncpg/_types.py +11 -7
  34. sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
  35. sqlspec/adapters/asyncpg/adk/store.py +450 -0
  36. sqlspec/adapters/asyncpg/config.py +59 -35
  37. sqlspec/adapters/asyncpg/data_dictionary.py +173 -0
  38. sqlspec/adapters/asyncpg/driver.py +170 -25
  39. sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
  40. sqlspec/adapters/asyncpg/litestar/store.py +253 -0
  41. sqlspec/adapters/bigquery/_types.py +1 -1
  42. sqlspec/adapters/bigquery/adk/__init__.py +5 -0
  43. sqlspec/adapters/bigquery/adk/store.py +576 -0
  44. sqlspec/adapters/bigquery/config.py +27 -10
  45. sqlspec/adapters/bigquery/data_dictionary.py +149 -0
  46. sqlspec/adapters/bigquery/driver.py +368 -142
  47. sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
  48. sqlspec/adapters/bigquery/litestar/store.py +327 -0
  49. sqlspec/adapters/bigquery/type_converter.py +125 -0
  50. sqlspec/adapters/duckdb/_types.py +1 -1
  51. sqlspec/adapters/duckdb/adk/__init__.py +14 -0
  52. sqlspec/adapters/duckdb/adk/store.py +553 -0
  53. sqlspec/adapters/duckdb/config.py +80 -20
  54. sqlspec/adapters/duckdb/data_dictionary.py +163 -0
  55. sqlspec/adapters/duckdb/driver.py +167 -45
  56. sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
  57. sqlspec/adapters/duckdb/litestar/store.py +332 -0
  58. sqlspec/adapters/duckdb/pool.py +4 -4
  59. sqlspec/adapters/duckdb/type_converter.py +133 -0
  60. sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
  61. sqlspec/adapters/oracledb/_types.py +20 -2
  62. sqlspec/adapters/oracledb/adk/__init__.py +5 -0
  63. sqlspec/adapters/oracledb/adk/store.py +1745 -0
  64. sqlspec/adapters/oracledb/config.py +122 -32
  65. sqlspec/adapters/oracledb/data_dictionary.py +509 -0
  66. sqlspec/adapters/oracledb/driver.py +353 -91
  67. sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
  68. sqlspec/adapters/oracledb/litestar/store.py +767 -0
  69. sqlspec/adapters/oracledb/migrations.py +348 -73
  70. sqlspec/adapters/oracledb/type_converter.py +207 -0
  71. sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
  72. sqlspec/adapters/psqlpy/_types.py +2 -1
  73. sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
  74. sqlspec/adapters/psqlpy/adk/store.py +482 -0
  75. sqlspec/adapters/psqlpy/config.py +46 -17
  76. sqlspec/adapters/psqlpy/data_dictionary.py +172 -0
  77. sqlspec/adapters/psqlpy/driver.py +123 -209
  78. sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
  79. sqlspec/adapters/psqlpy/litestar/store.py +272 -0
  80. sqlspec/adapters/psqlpy/type_converter.py +102 -0
  81. sqlspec/adapters/psycopg/_type_handlers.py +80 -0
  82. sqlspec/adapters/psycopg/_types.py +2 -1
  83. sqlspec/adapters/psycopg/adk/__init__.py +5 -0
  84. sqlspec/adapters/psycopg/adk/store.py +944 -0
  85. sqlspec/adapters/psycopg/config.py +69 -35
  86. sqlspec/adapters/psycopg/data_dictionary.py +331 -0
  87. sqlspec/adapters/psycopg/driver.py +238 -81
  88. sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
  89. sqlspec/adapters/psycopg/litestar/store.py +554 -0
  90. sqlspec/adapters/sqlite/__init__.py +2 -1
  91. sqlspec/adapters/sqlite/_type_handlers.py +86 -0
  92. sqlspec/adapters/sqlite/_types.py +1 -1
  93. sqlspec/adapters/sqlite/adk/__init__.py +5 -0
  94. sqlspec/adapters/sqlite/adk/store.py +572 -0
  95. sqlspec/adapters/sqlite/config.py +87 -15
  96. sqlspec/adapters/sqlite/data_dictionary.py +149 -0
  97. sqlspec/adapters/sqlite/driver.py +137 -54
  98. sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
  99. sqlspec/adapters/sqlite/litestar/store.py +318 -0
  100. sqlspec/adapters/sqlite/pool.py +18 -9
  101. sqlspec/base.py +45 -26
  102. sqlspec/builder/__init__.py +73 -4
  103. sqlspec/builder/_base.py +162 -89
  104. sqlspec/builder/_column.py +62 -29
  105. sqlspec/builder/_ddl.py +180 -121
  106. sqlspec/builder/_delete.py +5 -4
  107. sqlspec/builder/_dml.py +388 -0
  108. sqlspec/{_sql.py → builder/_factory.py} +53 -94
  109. sqlspec/builder/_insert.py +32 -131
  110. sqlspec/builder/_join.py +375 -0
  111. sqlspec/builder/_merge.py +446 -11
  112. sqlspec/builder/_parsing_utils.py +111 -17
  113. sqlspec/builder/_select.py +1457 -24
  114. sqlspec/builder/_update.py +11 -42
  115. sqlspec/cli.py +307 -194
  116. sqlspec/config.py +252 -67
  117. sqlspec/core/__init__.py +5 -4
  118. sqlspec/core/cache.py +17 -17
  119. sqlspec/core/compiler.py +62 -9
  120. sqlspec/core/filters.py +37 -37
  121. sqlspec/core/hashing.py +9 -9
  122. sqlspec/core/parameters.py +83 -48
  123. sqlspec/core/result.py +102 -46
  124. sqlspec/core/splitter.py +16 -17
  125. sqlspec/core/statement.py +36 -30
  126. sqlspec/core/type_conversion.py +235 -0
  127. sqlspec/driver/__init__.py +7 -6
  128. sqlspec/driver/_async.py +188 -151
  129. sqlspec/driver/_common.py +285 -80
  130. sqlspec/driver/_sync.py +188 -152
  131. sqlspec/driver/mixins/_result_tools.py +20 -236
  132. sqlspec/driver/mixins/_sql_translator.py +4 -4
  133. sqlspec/exceptions.py +75 -7
  134. sqlspec/extensions/adk/__init__.py +53 -0
  135. sqlspec/extensions/adk/_types.py +51 -0
  136. sqlspec/extensions/adk/converters.py +172 -0
  137. sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
  138. sqlspec/extensions/adk/migrations/__init__.py +0 -0
  139. sqlspec/extensions/adk/service.py +181 -0
  140. sqlspec/extensions/adk/store.py +536 -0
  141. sqlspec/extensions/aiosql/adapter.py +73 -53
  142. sqlspec/extensions/litestar/__init__.py +21 -4
  143. sqlspec/extensions/litestar/cli.py +54 -10
  144. sqlspec/extensions/litestar/config.py +59 -266
  145. sqlspec/extensions/litestar/handlers.py +46 -17
  146. sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
  147. sqlspec/extensions/litestar/migrations/__init__.py +3 -0
  148. sqlspec/extensions/litestar/plugin.py +324 -223
  149. sqlspec/extensions/litestar/providers.py +25 -25
  150. sqlspec/extensions/litestar/store.py +265 -0
  151. sqlspec/loader.py +30 -49
  152. sqlspec/migrations/__init__.py +4 -3
  153. sqlspec/migrations/base.py +302 -39
  154. sqlspec/migrations/commands.py +611 -144
  155. sqlspec/migrations/context.py +142 -0
  156. sqlspec/migrations/fix.py +199 -0
  157. sqlspec/migrations/loaders.py +68 -23
  158. sqlspec/migrations/runner.py +543 -107
  159. sqlspec/migrations/tracker.py +237 -21
  160. sqlspec/migrations/utils.py +51 -3
  161. sqlspec/migrations/validation.py +177 -0
  162. sqlspec/protocols.py +66 -36
  163. sqlspec/storage/_utils.py +98 -0
  164. sqlspec/storage/backends/fsspec.py +134 -106
  165. sqlspec/storage/backends/local.py +78 -51
  166. sqlspec/storage/backends/obstore.py +278 -162
  167. sqlspec/storage/registry.py +75 -39
  168. sqlspec/typing.py +16 -84
  169. sqlspec/utils/config_resolver.py +153 -0
  170. sqlspec/utils/correlation.py +4 -5
  171. sqlspec/utils/data_transformation.py +3 -2
  172. sqlspec/utils/deprecation.py +9 -8
  173. sqlspec/utils/fixtures.py +4 -4
  174. sqlspec/utils/logging.py +46 -6
  175. sqlspec/utils/module_loader.py +2 -2
  176. sqlspec/utils/schema.py +288 -0
  177. sqlspec/utils/serializers.py +50 -2
  178. sqlspec/utils/sync_tools.py +21 -17
  179. sqlspec/utils/text.py +1 -2
  180. sqlspec/utils/type_guards.py +111 -20
  181. sqlspec/utils/version.py +433 -0
  182. {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/METADATA +40 -21
  183. sqlspec-0.27.0.dist-info/RECORD +207 -0
  184. sqlspec/builder/mixins/__init__.py +0 -55
  185. sqlspec/builder/mixins/_cte_and_set_ops.py +0 -254
  186. sqlspec/builder/mixins/_delete_operations.py +0 -50
  187. sqlspec/builder/mixins/_insert_operations.py +0 -282
  188. sqlspec/builder/mixins/_join_operations.py +0 -389
  189. sqlspec/builder/mixins/_merge_operations.py +0 -592
  190. sqlspec/builder/mixins/_order_limit_operations.py +0 -152
  191. sqlspec/builder/mixins/_pivot_operations.py +0 -157
  192. sqlspec/builder/mixins/_select_operations.py +0 -936
  193. sqlspec/builder/mixins/_update_operations.py +0 -218
  194. sqlspec/builder/mixins/_where_clause.py +0 -1304
  195. sqlspec-0.25.0.dist-info/RECORD +0 -139
  196. sqlspec-0.25.0.dist-info/licenses/NOTICE +0 -29
  197. {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/WHEEL +0 -0
  198. {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/entry_points.txt +0 -0
  199. {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/licenses/LICENSE +0 -0
sqlspec/extensions/litestar/providers.py CHANGED
@@ -7,7 +7,7 @@ This module contains functions to create dependency providers for services and f
7
7
  import datetime
8
8
  import inspect
9
9
  from collections.abc import Callable
10
- from typing import Any, Literal, NamedTuple, Optional, TypedDict, Union, cast
10
+ from typing import Any, Literal, NamedTuple, TypedDict, cast
11
11
  from uuid import UUID
12
12
 
13
13
  from litestar.di import Provide
@@ -44,15 +44,15 @@ __all__ = (
44
44
  "dep_cache",
45
45
  )
46
46
 
47
- DTorNone = Optional[datetime.datetime]
48
- StringOrNone = Optional[str]
49
- UuidOrNone = Optional[UUID]
50
- IntOrNone = Optional[int]
51
- BooleanOrNone = Optional[bool]
47
+ DTorNone = datetime.datetime | None
48
+ StringOrNone = str | None
49
+ UuidOrNone = UUID | None
50
+ IntOrNone = int | None
51
+ BooleanOrNone = bool | None
52
52
  SortOrder = Literal["asc", "desc"]
53
- SortOrderOrNone = Optional[SortOrder]
54
- HashableValue = Union[str, int, float, bool, None]
55
- HashableType = Union[HashableValue, tuple[Any, ...], tuple[tuple[str, Any], ...], tuple[HashableValue, ...]]
53
+ SortOrderOrNone = SortOrder | None
54
+ HashableValue = str | int | float | bool | None
55
+ HashableType = HashableValue | tuple[Any, ...] | tuple[tuple[str, Any], ...] | tuple[HashableValue, ...]
56
56
 
57
57
 
58
58
  class DependencyDefaults:
@@ -79,30 +79,30 @@ class FieldNameType(NamedTuple):
79
79
  class FilterConfig(TypedDict):
80
80
  """Configuration for generating dynamic filters."""
81
81
 
82
- id_filter: NotRequired[type[Union[UUID, int, str]]]
82
+ id_filter: NotRequired[type[UUID | int | str]]
83
83
  id_field: NotRequired[str]
84
84
  sort_field: NotRequired[str]
85
85
  sort_order: NotRequired[SortOrder]
86
86
  pagination_type: NotRequired[Literal["limit_offset"]]
87
87
  pagination_size: NotRequired[int]
88
- search: NotRequired[Union[str, set[str], list[str]]]
88
+ search: NotRequired[str | set[str] | list[str]]
89
89
  search_ignore_case: NotRequired[bool]
90
90
  created_at: NotRequired[bool]
91
91
  updated_at: NotRequired[bool]
92
- not_in_fields: NotRequired[Union[FieldNameType, set[FieldNameType], list[Union[str, FieldNameType]]]]
93
- in_fields: NotRequired[Union[FieldNameType, set[FieldNameType], list[Union[str, FieldNameType]]]]
92
+ not_in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
93
+ in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
94
94
 
95
95
 
96
96
  class DependencyCache(metaclass=SingletonMeta):
97
97
  """Dependency cache for memoizing dynamically generated dependencies."""
98
98
 
99
99
  def __init__(self) -> None:
100
- self.dependencies: dict[Union[int, str], dict[str, Provide]] = {}
100
+ self.dependencies: dict[int | str, dict[str, Provide]] = {}
101
101
 
102
- def add_dependencies(self, key: Union[int, str], dependencies: dict[str, Provide]) -> None:
102
+ def add_dependencies(self, key: int | str, dependencies: dict[str, Provide]) -> None:
103
103
  self.dependencies[key] = dependencies
104
104
 
105
- def get_dependencies(self, key: Union[int, str]) -> Optional[dict[str, Provide]]:
105
+ def get_dependencies(self, key: int | str) -> dict[str, Provide] | None:
106
106
  return self.dependencies.get(key)
107
107
 
108
108
 
@@ -169,7 +169,7 @@ def _create_statement_filters(
169
169
  if config.get("id_filter", False):
170
170
 
171
171
  def provide_id_filter( # pyright: ignore[reportUnknownParameterType]
172
- ids: Optional[list[str]] = Parameter(query="ids", default=None, required=False),
172
+ ids: list[str] | None = Parameter(query="ids", default=None, required=False),
173
173
  ) -> InCollectionFilter: # pyright: ignore[reportMissingTypeArgument]
174
174
  return InCollectionFilter(field_name=config.get("id_field", "id"), values=ids)
175
175
 
@@ -257,12 +257,12 @@ def _create_statement_filters(
257
257
 
258
258
  def create_not_in_filter_provider( # pyright: ignore
259
259
  field_name: FieldNameType,
260
- ) -> Callable[..., Optional[NotInCollectionFilter[field_def.type_hint]]]: # type: ignore
260
+ ) -> Callable[..., NotInCollectionFilter[field_def.type_hint] | None]: # type: ignore
261
261
  def provide_not_in_filter( # pyright: ignore
262
- values: Optional[list[field_name.type_hint]] = Parameter( # type: ignore
262
+ values: list[field_name.type_hint] | None = Parameter( # type: ignore
263
263
  query=camelize(f"{field_name.name}_not_in"), default=None, required=False
264
264
  ),
265
- ) -> Optional[NotInCollectionFilter[field_name.type_hint]]: # type: ignore
265
+ ) -> NotInCollectionFilter[field_name.type_hint] | None: # type: ignore
266
266
  return (
267
267
  NotInCollectionFilter[field_name.type_hint](field_name=field_name.name, values=values) # type: ignore
268
268
  if values
@@ -282,12 +282,12 @@ def _create_statement_filters(
282
282
 
283
283
  def create_in_filter_provider( # pyright: ignore
284
284
  field_name: FieldNameType,
285
- ) -> Callable[..., Optional[InCollectionFilter[field_def.type_hint]]]: # type: ignore # pyright: ignore
285
+ ) -> Callable[..., InCollectionFilter[field_def.type_hint] | None]: # type: ignore # pyright: ignore
286
286
  def provide_in_filter( # pyright: ignore
287
- values: Optional[list[field_name.type_hint]] = Parameter( # type: ignore # pyright: ignore
287
+ values: list[field_name.type_hint] | None = Parameter( # type: ignore # pyright: ignore
288
288
  query=camelize(f"{field_name.name}_in"), default=None, required=False
289
289
  ),
290
- ) -> Optional[InCollectionFilter[field_name.type_hint]]: # type: ignore # pyright: ignore
290
+ ) -> InCollectionFilter[field_name.type_hint] | None: # type: ignore # pyright: ignore
291
291
  return (
292
292
  InCollectionFilter[field_name.type_hint](field_name=field_name.name, values=values) # type: ignore # pyright: ignore
293
293
  if values
@@ -415,14 +415,14 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
415
415
  if updated_filter := kwargs.get("updated_filter"):
416
416
  filters.append(updated_filter)
417
417
  if (
418
- (search_filter := cast("Optional[SearchFilter]", kwargs.get("search_filter")))
418
+ (search_filter := cast("SearchFilter | None", kwargs.get("search_filter")))
419
419
  and search_filter is not None # pyright: ignore[reportUnnecessaryComparison]
420
420
  and search_filter.field_name is not None # pyright: ignore[reportUnnecessaryComparison]
421
421
  and search_filter.value is not None # pyright: ignore[reportUnnecessaryComparison]
422
422
  ):
423
423
  filters.append(search_filter)
424
424
  if (
425
- (order_by := cast("Optional[OrderByFilter]", kwargs.get("order_by_filter")))
425
+ (order_by := cast("OrderByFilter | None", kwargs.get("order_by_filter")))
426
426
  and order_by is not None # pyright: ignore[reportUnnecessaryComparison]
427
427
  and order_by.field_name is not None # pyright: ignore[reportUnnecessaryComparison]
428
428
  ):
sqlspec/extensions/litestar/store.py ADDED
@@ -0,0 +1,265 @@
1
+ """Base session store classes for Litestar integration."""
2
+
3
+ import re
4
+ from abc import ABC, abstractmethod
5
+ from datetime import datetime, timedelta, timezone
6
+ from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar, cast
7
+
8
+ from sqlspec.utils.logging import get_logger
9
+
10
+ if TYPE_CHECKING:
11
+ from types import TracebackType
12
+
13
+
14
+ ConfigT = TypeVar("ConfigT")
15
+
16
+
17
+ logger = get_logger("extensions.litestar.store")
18
+
19
+ __all__ = ("BaseSQLSpecStore",)
20
+
21
+ VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
22
+ MAX_TABLE_NAME_LENGTH: Final = 63
23
+
24
+
25
+ class BaseSQLSpecStore(ABC, Generic[ConfigT]):
26
+ """Base class for SQLSpec-backed Litestar session stores.
27
+
28
+ Implements the litestar.stores.base.Store protocol for server-side session
29
+ storage using SQLSpec database adapters.
30
+
31
+ This abstract base class provides common functionality for all database-specific
32
+ store implementations including:
33
+ - Connection management via SQLSpec configs
34
+ - Session expiration calculation
35
+ - Table creation utilities
36
+
37
+ Subclasses must implement dialect-specific SQL queries.
38
+
39
+ Args:
40
+ config: SQLSpec database configuration with extension_config["litestar"] settings.
41
+
42
+ Example:
43
+ from sqlspec.adapters.asyncpg import AsyncpgConfig
44
+ from sqlspec.adapters.asyncpg.litestar.store import AsyncpgStore
45
+
46
+ config = AsyncpgConfig(
47
+ pool_config={"dsn": "postgresql://..."},
48
+ extension_config={"litestar": {"session_table": "my_sessions"}}
49
+ )
50
+ store = AsyncpgStore(config)
51
+ await store.create_table()
52
+
53
+ Notes:
54
+ Configuration is read from config.extension_config["litestar"]:
55
+ - session_table: Table name (default: "litestar_session")
56
+ """
57
+
58
+ __slots__ = ("_config", "_table_name")
59
+
60
+ def __init__(self, config: ConfigT) -> None:
61
+ """Initialize the session store.
62
+
63
+ Args:
64
+ config: SQLSpec database configuration.
65
+
66
+ Notes:
67
+ Reads table_name from config.extension_config["litestar"]["session_table"].
68
+ Defaults to "litestar_session" if not specified.
69
+ """
70
+ self._config = config
71
+ self._table_name = self._get_table_name_from_config()
72
+ self._validate_table_name(self._table_name)
73
+
74
+ def _get_table_name_from_config(self) -> str:
75
+ """Extract table name from config.extension_config.
76
+
77
+ Returns:
78
+ Table name for the session store.
79
+ """
80
+ if hasattr(self._config, "extension_config"):
81
+ extension_config = cast("dict[str, dict[str, Any]]", self._config.extension_config) # pyright: ignore
82
+ litestar_config: dict[str, Any] = extension_config.get("litestar", {})
83
+ return str(litestar_config.get("session_table", "litestar_session"))
84
+ return "litestar_session"
85
+
86
+ @property
87
+ def config(self) -> ConfigT:
88
+ """Return the database configuration."""
89
+ return self._config
90
+
91
+ @property
92
+ def table_name(self) -> str:
93
+ """Return the session table name."""
94
+ return self._table_name
95
+
96
+ @abstractmethod
97
+ async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None":
98
+ """Get a session value by key.
99
+
100
+ Args:
101
+ key: Session ID to retrieve.
102
+ renew_for: If given and the value had an initial expiry time set, renew the
103
+ expiry time for ``renew_for`` seconds. If the value has not been set
104
+ with an expiry time this is a no-op.
105
+
106
+ Returns:
107
+ Session data as bytes if found and not expired, None otherwise.
108
+ """
109
+ raise NotImplementedError
110
+
111
+ @abstractmethod
112
+ async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None:
113
+ """Store a session value.
114
+
115
+ Args:
116
+ key: Session ID.
117
+ value: Session data (will be converted to bytes if string).
118
+ expires_in: Time in seconds or timedelta before expiration.
119
+ """
120
+ raise NotImplementedError
121
+
122
+ @abstractmethod
123
+ async def delete(self, key: str) -> None:
124
+ """Delete a session by key.
125
+
126
+ Args:
127
+ key: Session ID to delete.
128
+ """
129
+ raise NotImplementedError
130
+
131
+ @abstractmethod
132
+ async def delete_all(self) -> None:
133
+ """Delete all sessions from the store."""
134
+ raise NotImplementedError
135
+
136
+ @abstractmethod
137
+ async def exists(self, key: str) -> bool:
138
+ """Check if a session key exists and is not expired.
139
+
140
+ Args:
141
+ key: Session ID to check.
142
+
143
+ Returns:
144
+ True if the session exists and is not expired.
145
+ """
146
+ raise NotImplementedError
147
+
148
+ @abstractmethod
149
+ async def expires_in(self, key: str) -> "int | None":
150
+ """Get the time in seconds until the session expires.
151
+
152
+ Args:
153
+ key: Session ID to check.
154
+
155
+ Returns:
156
+ Seconds until expiration, or None if no expiry or key doesn't exist.
157
+ """
158
+ raise NotImplementedError
159
+
160
+ @abstractmethod
161
+ async def delete_expired(self) -> int:
162
+ """Delete all expired sessions.
163
+
164
+ Returns:
165
+ Number of sessions deleted.
166
+ """
167
+ raise NotImplementedError
168
+
169
+ @abstractmethod
170
+ async def create_table(self) -> None:
171
+ """Create the session table if it doesn't exist."""
172
+ raise NotImplementedError
173
+
174
+ @abstractmethod
175
+ def _get_create_table_sql(self) -> str:
176
+ """Get the CREATE TABLE SQL for this database dialect.
177
+
178
+ Returns:
179
+ SQL statement to create the sessions table.
180
+ """
181
+ raise NotImplementedError
182
+
183
+ @abstractmethod
184
+ def _get_drop_table_sql(self) -> "list[str]":
185
+ """Get the DROP TABLE SQL statements for this database dialect.
186
+
187
+ Returns:
188
+ List of SQL statements to drop the table and all indexes.
189
+ Order matters: drop indexes before table.
190
+
191
+ Notes:
192
+ Should use IF EXISTS or dialect-specific error handling
193
+ to allow idempotent migrations.
194
+ """
195
+ raise NotImplementedError
196
+
197
+ async def __aenter__(self) -> "BaseSQLSpecStore":
198
+ """Enter context manager."""
199
+ return self
200
+
201
+ async def __aexit__(
202
+ self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None"
203
+ ) -> None:
204
+ """Exit context manager."""
205
+ return
206
+
207
+ def _calculate_expires_at(self, expires_in: "int | timedelta | None") -> "datetime | None":
208
+ """Calculate expiration timestamp from expires_in.
209
+
210
+ Args:
211
+ expires_in: Seconds or timedelta until expiration.
212
+
213
+ Returns:
214
+ UTC datetime of expiration, or None if no expiration.
215
+ """
216
+ if expires_in is None:
217
+ return None
218
+
219
+ expires_in_seconds = int(expires_in.total_seconds()) if isinstance(expires_in, timedelta) else expires_in
220
+
221
+ if expires_in_seconds <= 0:
222
+ return None
223
+
224
+ return datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds)
225
+
226
+ def _value_to_bytes(self, value: "str | bytes") -> bytes:
227
+ """Convert value to bytes if needed.
228
+
229
+ Args:
230
+ value: String or bytes value.
231
+
232
+ Returns:
233
+ Value as bytes.
234
+ """
235
+ if isinstance(value, str):
236
+ return value.encode("utf-8")
237
+ return value
238
+
239
+ @staticmethod
240
+ def _validate_table_name(table_name: str) -> None:
241
+ """Validate table name for SQL safety.
242
+
243
+ Args:
244
+ table_name: Table name to validate.
245
+
246
+ Raises:
247
+ ValueError: If table name is invalid.
248
+
249
+ Notes:
250
+ - Must start with letter or underscore
251
+ - Can only contain letters, numbers, and underscores
252
+ - Maximum length is 63 characters (PostgreSQL limit)
253
+ - Prevents SQL injection in table names
254
+ """
255
+ if not table_name:
256
+ msg = "Table name cannot be empty"
257
+ raise ValueError(msg)
258
+
259
+ if len(table_name) > MAX_TABLE_NAME_LENGTH:
260
+ msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})"
261
+ raise ValueError(msg)
262
+
263
+ if not VALID_TABLE_NAME_PATTERN.match(table_name):
264
+ msg = f"Invalid table name: {table_name!r}. Must start with letter/underscore and contain only alphanumeric characters and underscores"
265
+ raise ValueError(msg)
sqlspec/loader.py CHANGED
@@ -9,7 +9,7 @@ import re
9
9
  import time
10
10
  from datetime import datetime, timezone
11
11
  from pathlib import Path
12
- from typing import TYPE_CHECKING, Any, Final, Optional, Union
12
+ from typing import TYPE_CHECKING, Any, Final
13
13
  from urllib.parse import unquote, urlparse
14
14
 
15
15
  from sqlspec.core.cache import get_cache, get_cache_config
@@ -95,7 +95,7 @@ class NamedStatement:
95
95
 
96
96
  __slots__ = ("dialect", "name", "sql", "start_line")
97
97
 
98
- def __init__(self, name: str, sql: str, dialect: "Optional[str]" = None, start_line: int = 0) -> None:
98
+ def __init__(self, name: str, sql: str, dialect: "str | None" = None, start_line: int = 0) -> None:
99
99
  self.name = name
100
100
  self.sql = sql
101
101
  self.dialect = dialect
@@ -112,11 +112,7 @@ class SQLFile:
112
112
  __slots__ = ("checksum", "content", "loaded_at", "metadata", "path")
113
113
 
114
114
  def __init__(
115
- self,
116
- content: str,
117
- path: str,
118
- metadata: "Optional[dict[str, Any]]" = None,
119
- loaded_at: "Optional[datetime]" = None,
115
+ self, content: str, path: str, metadata: "dict[str, Any] | None" = None, loaded_at: "datetime | None" = None
120
116
  ) -> None:
121
117
  """Initialize SQLFile.
122
118
 
@@ -163,7 +159,7 @@ class SQLFileLoader:
163
159
 
164
160
  __slots__ = ("_files", "_queries", "_query_to_file", "encoding", "storage_registry")
165
161
 
166
- def __init__(self, *, encoding: str = "utf-8", storage_registry: "Optional[StorageRegistry]" = None) -> None:
162
+ def __init__(self, *, encoding: str = "utf-8", storage_registry: "StorageRegistry | None" = None) -> None:
167
163
  """Initialize the SQL file loader.
168
164
 
169
165
  Args:
@@ -188,7 +184,7 @@ class SQLFileLoader:
188
184
  """
189
185
  raise SQLFileNotFoundError(path)
190
186
 
191
- def _generate_file_cache_key(self, path: Union[str, Path]) -> str:
187
+ def _generate_file_cache_key(self, path: str | Path) -> str:
192
188
  """Generate cache key for a file path.
193
189
 
194
190
  Args:
@@ -201,7 +197,7 @@ class SQLFileLoader:
201
197
  path_hash = hashlib.md5(path_str.encode(), usedforsecurity=False).hexdigest()
202
198
  return f"file:{path_hash[:16]}"
203
199
 
204
- def _calculate_file_checksum(self, path: Union[str, Path]) -> str:
200
+ def _calculate_file_checksum(self, path: str | Path) -> str:
205
201
  """Calculate checksum for file content validation.
206
202
 
207
203
  Args:
@@ -218,7 +214,7 @@ class SQLFileLoader:
218
214
  except Exception as e:
219
215
  raise SQLFileParseError(str(path), str(path), e) from e
220
216
 
221
- def _is_file_unchanged(self, path: Union[str, Path], cached_file: CachedSQLFile) -> bool:
217
+ def _is_file_unchanged(self, path: str | Path, cached_file: CachedSQLFile) -> bool:
222
218
  """Check if file has changed since caching.
223
219
 
224
220
  Args:
@@ -235,7 +231,7 @@ class SQLFileLoader:
235
231
  else:
236
232
  return current_checksum == cached_file.sql_file.checksum
237
233
 
238
- def _read_file_content(self, path: Union[str, Path]) -> str:
234
+ def _read_file_content(self, path: str | Path) -> str:
239
235
  """Read file content using storage backend.
240
236
 
241
237
  Args:
@@ -349,7 +345,7 @@ class SQLFileLoader:
349
345
 
350
346
  return statements
351
347
 
352
- def load_sql(self, *paths: Union[str, Path]) -> None:
348
+ def load_sql(self, *paths: str | Path) -> None:
353
349
  """Load SQL files and parse named queries.
354
350
 
355
351
  Args:
@@ -358,43 +354,20 @@ class SQLFileLoader:
358
354
  correlation_id = CorrelationContext.get()
359
355
  start_time = time.perf_counter()
360
356
 
361
- logger.info("Loading SQL files", extra={"file_count": len(paths), "correlation_id": correlation_id})
362
-
363
- loaded_count = 0
364
- query_count_before = len(self._queries)
365
-
366
357
  try:
367
358
  for path in paths:
368
359
  path_str = str(path)
369
360
  if "://" in path_str:
370
361
  self._load_single_file(path, None)
371
- loaded_count += 1
372
362
  else:
373
363
  path_obj = Path(path)
374
364
  if path_obj.is_dir():
375
- loaded_count += self._load_directory(path_obj)
365
+ self._load_directory(path_obj)
376
366
  elif path_obj.exists():
377
367
  self._load_single_file(path_obj, None)
378
- loaded_count += 1
379
368
  elif path_obj.suffix:
380
369
  self._raise_file_not_found(str(path))
381
370
 
382
- duration = time.perf_counter() - start_time
383
- new_queries = len(self._queries) - query_count_before
384
-
385
- logger.info(
386
- "Loaded %d SQL files with %d new queries in %.3fms",
387
- loaded_count,
388
- new_queries,
389
- duration * 1000,
390
- extra={
391
- "files_loaded": loaded_count,
392
- "new_queries": new_queries,
393
- "duration_ms": duration * 1000,
394
- "correlation_id": correlation_id,
395
- },
396
- )
397
-
398
371
  except Exception as e:
399
372
  duration = time.perf_counter() - start_time
400
373
  logger.exception(
@@ -408,34 +381,40 @@ class SQLFileLoader:
408
381
  )
409
382
  raise
410
383
 
411
- def _load_directory(self, dir_path: Path) -> int:
412
- """Load all SQL files from a directory."""
384
+ def _load_directory(self, dir_path: Path) -> None:
385
+ """Load all SQL files from a directory.
386
+
387
+ Args:
388
+ dir_path: Directory path to load SQL files from.
389
+ """
413
390
  sql_files = list(dir_path.rglob("*.sql"))
414
391
  if not sql_files:
415
- return 0
392
+ return
416
393
 
417
394
  for file_path in sql_files:
418
395
  relative_path = file_path.relative_to(dir_path)
419
396
  namespace_parts = relative_path.parent.parts
420
397
  self._load_single_file(file_path, ".".join(namespace_parts) if namespace_parts else None)
421
- return len(sql_files)
422
398
 
423
- def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
399
+ def _load_single_file(self, file_path: str | Path, namespace: str | None) -> bool:
424
400
  """Load a single SQL file with optional namespace.
425
401
 
426
402
  Args:
427
403
  file_path: Path to the SQL file.
428
404
  namespace: Optional namespace prefix for queries.
405
+
406
+ Returns:
407
+ True if file was newly loaded, False if already cached.
429
408
  """
430
409
  path_str = str(file_path)
431
410
 
432
411
  if path_str in self._files:
433
- return
412
+ return False
434
413
 
435
414
  cache_config = get_cache_config()
436
415
  if not cache_config.compiled_cache_enabled:
437
416
  self._load_file_without_cache(file_path, namespace)
438
- return
417
+ return True
439
418
 
440
419
  cache_key_str = self._generate_file_cache_key(file_path)
441
420
  cache = get_cache()
@@ -459,7 +438,7 @@ class SQLFileLoader:
459
438
  )
460
439
  self._queries[namespaced_name] = statement
461
440
  self._query_to_file[namespaced_name] = path_str
462
- return
441
+ return True
463
442
 
464
443
  self._load_file_without_cache(file_path, namespace)
465
444
 
@@ -476,7 +455,9 @@ class SQLFileLoader:
476
455
  cached_file_data = CachedSQLFile(sql_file=sql_file, parsed_statements=file_statements)
477
456
  cache.put("file", cache_key_str, cached_file_data)
478
457
 
479
- def _load_file_without_cache(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
458
+ return True
459
+
460
+ def _load_file_without_cache(self, file_path: str | Path, namespace: str | None) -> None:
480
461
  """Load a single SQL file without using cache.
481
462
 
482
463
  Args:
@@ -503,7 +484,7 @@ class SQLFileLoader:
503
484
  self._queries[namespaced_name] = statement
504
485
  self._query_to_file[namespaced_name] = path_str
505
486
 
506
- def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) -> None:
487
+ def add_named_sql(self, name: str, sql: str, dialect: "str | None" = None) -> None:
507
488
  """Add a named SQL query directly without loading from a file.
508
489
 
509
490
  Args:
@@ -529,7 +510,7 @@ class SQLFileLoader:
529
510
  self._queries[normalized_name] = statement
530
511
  self._query_to_file[normalized_name] = "<directly added>"
531
512
 
532
- def get_file(self, path: Union[str, Path]) -> "Optional[SQLFile]":
513
+ def get_file(self, path: str | Path) -> "SQLFile | None":
533
514
  """Get a loaded SQLFile object by path.
534
515
 
535
516
  Args:
@@ -540,7 +521,7 @@ class SQLFileLoader:
540
521
  """
541
522
  return self._files.get(str(path))
542
523
 
543
- def get_file_for_query(self, name: str) -> "Optional[SQLFile]":
524
+ def get_file_for_query(self, name: str) -> "SQLFile | None":
544
525
  """Get the SQLFile object containing a query.
545
526
 
546
527
  Args:
sqlspec/migrations/__init__.py CHANGED
@@ -4,7 +4,7 @@ A native migration system for SQLSpec that leverages the SQLFileLoader
4
4
  and driver system for database versioning.
5
5
  """
6
6
 
7
- from sqlspec.migrations.commands import AsyncMigrationCommands, MigrationCommands, SyncMigrationCommands
7
+ from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands, create_migration_commands
8
8
  from sqlspec.migrations.loaders import (
9
9
  BaseMigrationLoader,
10
10
  MigrationLoadError,
@@ -12,7 +12,7 @@ from sqlspec.migrations.loaders import (
12
12
  SQLFileLoader,
13
13
  get_migration_loader,
14
14
  )
15
- from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner
15
+ from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner, create_migration_runner
16
16
  from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker
17
17
  from sqlspec.migrations.utils import create_migration_file, drop_all, get_author
18
18
 
@@ -21,14 +21,15 @@ __all__ = (
21
21
  "AsyncMigrationRunner",
22
22
  "AsyncMigrationTracker",
23
23
  "BaseMigrationLoader",
24
- "MigrationCommands",
25
24
  "MigrationLoadError",
26
25
  "PythonFileLoader",
27
26
  "SQLFileLoader",
28
27
  "SyncMigrationCommands",
29
28
  "SyncMigrationRunner",
30
29
  "SyncMigrationTracker",
30
+ "create_migration_commands",
31
31
  "create_migration_file",
32
+ "create_migration_runner",
32
33
  "drop_all",
33
34
  "get_author",
34
35
  "get_migration_loader",