sqlspec 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +7 -15
- sqlspec/_serialization.py +55 -25
- sqlspec/_typing.py +155 -52
- sqlspec/adapters/adbc/_types.py +1 -1
- sqlspec/adapters/adbc/adk/__init__.py +5 -0
- sqlspec/adapters/adbc/adk/store.py +880 -0
- sqlspec/adapters/adbc/config.py +62 -12
- sqlspec/adapters/adbc/data_dictionary.py +74 -2
- sqlspec/adapters/adbc/driver.py +226 -58
- sqlspec/adapters/adbc/litestar/__init__.py +5 -0
- sqlspec/adapters/adbc/litestar/store.py +504 -0
- sqlspec/adapters/adbc/type_converter.py +44 -50
- sqlspec/adapters/aiosqlite/_types.py +1 -1
- sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/adk/store.py +536 -0
- sqlspec/adapters/aiosqlite/config.py +86 -16
- sqlspec/adapters/aiosqlite/data_dictionary.py +34 -2
- sqlspec/adapters/aiosqlite/driver.py +127 -38
- sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
- sqlspec/adapters/aiosqlite/pool.py +7 -7
- sqlspec/adapters/asyncmy/__init__.py +7 -1
- sqlspec/adapters/asyncmy/_types.py +1 -1
- sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
- sqlspec/adapters/asyncmy/adk/store.py +503 -0
- sqlspec/adapters/asyncmy/config.py +59 -17
- sqlspec/adapters/asyncmy/data_dictionary.py +41 -2
- sqlspec/adapters/asyncmy/driver.py +293 -62
- sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncmy/litestar/store.py +296 -0
- sqlspec/adapters/asyncpg/__init__.py +2 -1
- sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
- sqlspec/adapters/asyncpg/_types.py +11 -7
- sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
- sqlspec/adapters/asyncpg/adk/store.py +460 -0
- sqlspec/adapters/asyncpg/config.py +57 -36
- sqlspec/adapters/asyncpg/data_dictionary.py +48 -2
- sqlspec/adapters/asyncpg/driver.py +153 -23
- sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncpg/litestar/store.py +253 -0
- sqlspec/adapters/bigquery/_types.py +1 -1
- sqlspec/adapters/bigquery/adk/__init__.py +5 -0
- sqlspec/adapters/bigquery/adk/store.py +585 -0
- sqlspec/adapters/bigquery/config.py +36 -11
- sqlspec/adapters/bigquery/data_dictionary.py +42 -2
- sqlspec/adapters/bigquery/driver.py +489 -144
- sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
- sqlspec/adapters/bigquery/litestar/store.py +327 -0
- sqlspec/adapters/bigquery/type_converter.py +55 -23
- sqlspec/adapters/duckdb/_types.py +2 -2
- sqlspec/adapters/duckdb/adk/__init__.py +14 -0
- sqlspec/adapters/duckdb/adk/store.py +563 -0
- sqlspec/adapters/duckdb/config.py +79 -21
- sqlspec/adapters/duckdb/data_dictionary.py +41 -2
- sqlspec/adapters/duckdb/driver.py +225 -44
- sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
- sqlspec/adapters/duckdb/litestar/store.py +332 -0
- sqlspec/adapters/duckdb/pool.py +5 -5
- sqlspec/adapters/duckdb/type_converter.py +51 -21
- sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
- sqlspec/adapters/oracledb/_types.py +20 -2
- sqlspec/adapters/oracledb/adk/__init__.py +5 -0
- sqlspec/adapters/oracledb/adk/store.py +1628 -0
- sqlspec/adapters/oracledb/config.py +120 -36
- sqlspec/adapters/oracledb/data_dictionary.py +87 -20
- sqlspec/adapters/oracledb/driver.py +475 -86
- sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
- sqlspec/adapters/oracledb/litestar/store.py +765 -0
- sqlspec/adapters/oracledb/migrations.py +316 -25
- sqlspec/adapters/oracledb/type_converter.py +91 -16
- sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
- sqlspec/adapters/psqlpy/_types.py +2 -1
- sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
- sqlspec/adapters/psqlpy/adk/store.py +483 -0
- sqlspec/adapters/psqlpy/config.py +45 -19
- sqlspec/adapters/psqlpy/data_dictionary.py +48 -2
- sqlspec/adapters/psqlpy/driver.py +108 -41
- sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
- sqlspec/adapters/psqlpy/litestar/store.py +272 -0
- sqlspec/adapters/psqlpy/type_converter.py +40 -11
- sqlspec/adapters/psycopg/_type_handlers.py +80 -0
- sqlspec/adapters/psycopg/_types.py +2 -1
- sqlspec/adapters/psycopg/adk/__init__.py +5 -0
- sqlspec/adapters/psycopg/adk/store.py +962 -0
- sqlspec/adapters/psycopg/config.py +65 -37
- sqlspec/adapters/psycopg/data_dictionary.py +91 -3
- sqlspec/adapters/psycopg/driver.py +200 -78
- sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
- sqlspec/adapters/psycopg/litestar/store.py +554 -0
- sqlspec/adapters/sqlite/__init__.py +2 -1
- sqlspec/adapters/sqlite/_type_handlers.py +86 -0
- sqlspec/adapters/sqlite/_types.py +1 -1
- sqlspec/adapters/sqlite/adk/__init__.py +5 -0
- sqlspec/adapters/sqlite/adk/store.py +582 -0
- sqlspec/adapters/sqlite/config.py +85 -16
- sqlspec/adapters/sqlite/data_dictionary.py +34 -2
- sqlspec/adapters/sqlite/driver.py +120 -52
- sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/sqlite/litestar/store.py +318 -0
- sqlspec/adapters/sqlite/pool.py +5 -5
- sqlspec/base.py +45 -26
- sqlspec/builder/__init__.py +73 -4
- sqlspec/builder/_base.py +91 -58
- sqlspec/builder/_column.py +5 -5
- sqlspec/builder/_ddl.py +98 -89
- sqlspec/builder/_delete.py +5 -4
- sqlspec/builder/_dml.py +388 -0
- sqlspec/{_sql.py → builder/_factory.py} +41 -44
- sqlspec/builder/_insert.py +5 -82
- sqlspec/builder/{mixins/_join_operations.py → _join.py} +145 -143
- sqlspec/builder/_merge.py +446 -11
- sqlspec/builder/_parsing_utils.py +9 -11
- sqlspec/builder/_select.py +1313 -25
- sqlspec/builder/_update.py +11 -42
- sqlspec/cli.py +76 -69
- sqlspec/config.py +331 -62
- sqlspec/core/__init__.py +5 -4
- sqlspec/core/cache.py +18 -18
- sqlspec/core/compiler.py +6 -8
- sqlspec/core/filters.py +55 -47
- sqlspec/core/hashing.py +9 -9
- sqlspec/core/parameters.py +76 -45
- sqlspec/core/result.py +234 -47
- sqlspec/core/splitter.py +16 -17
- sqlspec/core/statement.py +32 -31
- sqlspec/core/type_conversion.py +3 -2
- sqlspec/driver/__init__.py +1 -3
- sqlspec/driver/_async.py +183 -160
- sqlspec/driver/_common.py +197 -109
- sqlspec/driver/_sync.py +189 -161
- sqlspec/driver/mixins/_result_tools.py +20 -236
- sqlspec/driver/mixins/_sql_translator.py +4 -4
- sqlspec/exceptions.py +70 -7
- sqlspec/extensions/adk/__init__.py +53 -0
- sqlspec/extensions/adk/_types.py +51 -0
- sqlspec/extensions/adk/converters.py +172 -0
- sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
- sqlspec/extensions/adk/migrations/__init__.py +0 -0
- sqlspec/extensions/adk/service.py +181 -0
- sqlspec/extensions/adk/store.py +536 -0
- sqlspec/extensions/aiosql/adapter.py +69 -61
- sqlspec/extensions/fastapi/__init__.py +21 -0
- sqlspec/extensions/fastapi/extension.py +331 -0
- sqlspec/extensions/fastapi/providers.py +543 -0
- sqlspec/extensions/flask/__init__.py +36 -0
- sqlspec/extensions/flask/_state.py +71 -0
- sqlspec/extensions/flask/_utils.py +40 -0
- sqlspec/extensions/flask/extension.py +389 -0
- sqlspec/extensions/litestar/__init__.py +21 -4
- sqlspec/extensions/litestar/cli.py +54 -10
- sqlspec/extensions/litestar/config.py +56 -266
- sqlspec/extensions/litestar/handlers.py +46 -17
- sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
- sqlspec/extensions/litestar/migrations/__init__.py +3 -0
- sqlspec/extensions/litestar/plugin.py +349 -224
- sqlspec/extensions/litestar/providers.py +25 -25
- sqlspec/extensions/litestar/store.py +265 -0
- sqlspec/extensions/starlette/__init__.py +10 -0
- sqlspec/extensions/starlette/_state.py +25 -0
- sqlspec/extensions/starlette/_utils.py +52 -0
- sqlspec/extensions/starlette/extension.py +254 -0
- sqlspec/extensions/starlette/middleware.py +154 -0
- sqlspec/loader.py +30 -49
- sqlspec/migrations/base.py +200 -76
- sqlspec/migrations/commands.py +591 -62
- sqlspec/migrations/context.py +6 -9
- sqlspec/migrations/fix.py +199 -0
- sqlspec/migrations/loaders.py +47 -19
- sqlspec/migrations/runner.py +241 -75
- sqlspec/migrations/tracker.py +237 -21
- sqlspec/migrations/utils.py +51 -3
- sqlspec/migrations/validation.py +177 -0
- sqlspec/protocols.py +106 -36
- sqlspec/storage/_utils.py +85 -0
- sqlspec/storage/backends/fsspec.py +133 -107
- sqlspec/storage/backends/local.py +78 -51
- sqlspec/storage/backends/obstore.py +276 -168
- sqlspec/storage/registry.py +75 -39
- sqlspec/typing.py +30 -84
- sqlspec/utils/__init__.py +25 -4
- sqlspec/utils/arrow_helpers.py +81 -0
- sqlspec/utils/config_resolver.py +6 -6
- sqlspec/utils/correlation.py +4 -5
- sqlspec/utils/data_transformation.py +3 -2
- sqlspec/utils/deprecation.py +9 -8
- sqlspec/utils/fixtures.py +4 -4
- sqlspec/utils/logging.py +46 -6
- sqlspec/utils/module_loader.py +205 -5
- sqlspec/utils/portal.py +311 -0
- sqlspec/utils/schema.py +288 -0
- sqlspec/utils/serializers.py +113 -4
- sqlspec/utils/sync_tools.py +36 -22
- sqlspec/utils/text.py +1 -2
- sqlspec/utils/type_guards.py +136 -20
- sqlspec/utils/version.py +433 -0
- {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/METADATA +41 -22
- sqlspec-0.28.0.dist-info/RECORD +221 -0
- sqlspec/builder/mixins/__init__.py +0 -55
- sqlspec/builder/mixins/_cte_and_set_ops.py +0 -253
- sqlspec/builder/mixins/_delete_operations.py +0 -50
- sqlspec/builder/mixins/_insert_operations.py +0 -282
- sqlspec/builder/mixins/_merge_operations.py +0 -698
- sqlspec/builder/mixins/_order_limit_operations.py +0 -145
- sqlspec/builder/mixins/_pivot_operations.py +0 -157
- sqlspec/builder/mixins/_select_operations.py +0 -930
- sqlspec/builder/mixins/_update_operations.py +0 -199
- sqlspec/builder/mixins/_where_clause.py +0 -1298
- sqlspec-0.26.0.dist-info/RECORD +0 -157
- sqlspec-0.26.0.dist-info/licenses/NOTICE +0 -29
- {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -7,7 +7,7 @@ This module contains functions to create dependency providers for services and f
|
|
|
7
7
|
import datetime
|
|
8
8
|
import inspect
|
|
9
9
|
from collections.abc import Callable
|
|
10
|
-
from typing import Any, Literal, NamedTuple,
|
|
10
|
+
from typing import Any, Literal, NamedTuple, TypedDict, cast
|
|
11
11
|
from uuid import UUID
|
|
12
12
|
|
|
13
13
|
from litestar.di import Provide
|
|
@@ -44,15 +44,15 @@ __all__ = (
|
|
|
44
44
|
"dep_cache",
|
|
45
45
|
)
|
|
46
46
|
|
|
47
|
-
DTorNone =
|
|
48
|
-
StringOrNone =
|
|
49
|
-
UuidOrNone =
|
|
50
|
-
IntOrNone =
|
|
51
|
-
BooleanOrNone =
|
|
47
|
+
DTorNone = datetime.datetime | None
|
|
48
|
+
StringOrNone = str | None
|
|
49
|
+
UuidOrNone = UUID | None
|
|
50
|
+
IntOrNone = int | None
|
|
51
|
+
BooleanOrNone = bool | None
|
|
52
52
|
SortOrder = Literal["asc", "desc"]
|
|
53
|
-
SortOrderOrNone =
|
|
54
|
-
HashableValue =
|
|
55
|
-
HashableType =
|
|
53
|
+
SortOrderOrNone = SortOrder | None
|
|
54
|
+
HashableValue = str | int | float | bool | None
|
|
55
|
+
HashableType = HashableValue | tuple[Any, ...] | tuple[tuple[str, Any], ...] | tuple[HashableValue, ...]
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
class DependencyDefaults:
|
|
@@ -79,30 +79,30 @@ class FieldNameType(NamedTuple):
|
|
|
79
79
|
class FilterConfig(TypedDict):
|
|
80
80
|
"""Configuration for generating dynamic filters."""
|
|
81
81
|
|
|
82
|
-
id_filter: NotRequired[type[
|
|
82
|
+
id_filter: NotRequired[type[UUID | int | str]]
|
|
83
83
|
id_field: NotRequired[str]
|
|
84
84
|
sort_field: NotRequired[str]
|
|
85
85
|
sort_order: NotRequired[SortOrder]
|
|
86
86
|
pagination_type: NotRequired[Literal["limit_offset"]]
|
|
87
87
|
pagination_size: NotRequired[int]
|
|
88
|
-
search: NotRequired[
|
|
88
|
+
search: NotRequired[str | set[str] | list[str]]
|
|
89
89
|
search_ignore_case: NotRequired[bool]
|
|
90
90
|
created_at: NotRequired[bool]
|
|
91
91
|
updated_at: NotRequired[bool]
|
|
92
|
-
not_in_fields: NotRequired[
|
|
93
|
-
in_fields: NotRequired[
|
|
92
|
+
not_in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
|
|
93
|
+
in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
|
|
94
94
|
|
|
95
95
|
|
|
96
96
|
class DependencyCache(metaclass=SingletonMeta):
|
|
97
97
|
"""Dependency cache for memoizing dynamically generated dependencies."""
|
|
98
98
|
|
|
99
99
|
def __init__(self) -> None:
|
|
100
|
-
self.dependencies: dict[
|
|
100
|
+
self.dependencies: dict[int | str, dict[str, Provide]] = {}
|
|
101
101
|
|
|
102
|
-
def add_dependencies(self, key:
|
|
102
|
+
def add_dependencies(self, key: int | str, dependencies: dict[str, Provide]) -> None:
|
|
103
103
|
self.dependencies[key] = dependencies
|
|
104
104
|
|
|
105
|
-
def get_dependencies(self, key:
|
|
105
|
+
def get_dependencies(self, key: int | str) -> dict[str, Provide] | None:
|
|
106
106
|
return self.dependencies.get(key)
|
|
107
107
|
|
|
108
108
|
|
|
@@ -169,7 +169,7 @@ def _create_statement_filters(
|
|
|
169
169
|
if config.get("id_filter", False):
|
|
170
170
|
|
|
171
171
|
def provide_id_filter( # pyright: ignore[reportUnknownParameterType]
|
|
172
|
-
ids:
|
|
172
|
+
ids: list[str] | None = Parameter(query="ids", default=None, required=False),
|
|
173
173
|
) -> InCollectionFilter: # pyright: ignore[reportMissingTypeArgument]
|
|
174
174
|
return InCollectionFilter(field_name=config.get("id_field", "id"), values=ids)
|
|
175
175
|
|
|
@@ -257,12 +257,12 @@ def _create_statement_filters(
|
|
|
257
257
|
|
|
258
258
|
def create_not_in_filter_provider( # pyright: ignore
|
|
259
259
|
field_name: FieldNameType,
|
|
260
|
-
) -> Callable[...,
|
|
260
|
+
) -> Callable[..., NotInCollectionFilter[field_def.type_hint] | None]: # type: ignore
|
|
261
261
|
def provide_not_in_filter( # pyright: ignore
|
|
262
|
-
values:
|
|
262
|
+
values: list[field_name.type_hint] | None = Parameter( # type: ignore
|
|
263
263
|
query=camelize(f"{field_name.name}_not_in"), default=None, required=False
|
|
264
264
|
),
|
|
265
|
-
) ->
|
|
265
|
+
) -> NotInCollectionFilter[field_name.type_hint] | None: # type: ignore
|
|
266
266
|
return (
|
|
267
267
|
NotInCollectionFilter[field_name.type_hint](field_name=field_name.name, values=values) # type: ignore
|
|
268
268
|
if values
|
|
@@ -282,12 +282,12 @@ def _create_statement_filters(
|
|
|
282
282
|
|
|
283
283
|
def create_in_filter_provider( # pyright: ignore
|
|
284
284
|
field_name: FieldNameType,
|
|
285
|
-
) -> Callable[...,
|
|
285
|
+
) -> Callable[..., InCollectionFilter[field_def.type_hint] | None]: # type: ignore # pyright: ignore
|
|
286
286
|
def provide_in_filter( # pyright: ignore
|
|
287
|
-
values:
|
|
287
|
+
values: list[field_name.type_hint] | None = Parameter( # type: ignore # pyright: ignore
|
|
288
288
|
query=camelize(f"{field_name.name}_in"), default=None, required=False
|
|
289
289
|
),
|
|
290
|
-
) ->
|
|
290
|
+
) -> InCollectionFilter[field_name.type_hint] | None: # type: ignore # pyright: ignore
|
|
291
291
|
return (
|
|
292
292
|
InCollectionFilter[field_name.type_hint](field_name=field_name.name, values=values) # type: ignore # pyright: ignore
|
|
293
293
|
if values
|
|
@@ -415,14 +415,14 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
|
|
|
415
415
|
if updated_filter := kwargs.get("updated_filter"):
|
|
416
416
|
filters.append(updated_filter)
|
|
417
417
|
if (
|
|
418
|
-
(search_filter := cast("
|
|
418
|
+
(search_filter := cast("SearchFilter | None", kwargs.get("search_filter")))
|
|
419
419
|
and search_filter is not None # pyright: ignore[reportUnnecessaryComparison]
|
|
420
420
|
and search_filter.field_name is not None # pyright: ignore[reportUnnecessaryComparison]
|
|
421
421
|
and search_filter.value is not None # pyright: ignore[reportUnnecessaryComparison]
|
|
422
422
|
):
|
|
423
423
|
filters.append(search_filter)
|
|
424
424
|
if (
|
|
425
|
-
(order_by := cast("
|
|
425
|
+
(order_by := cast("OrderByFilter | None", kwargs.get("order_by_filter")))
|
|
426
426
|
and order_by is not None # pyright: ignore[reportUnnecessaryComparison]
|
|
427
427
|
and order_by.field_name is not None # pyright: ignore[reportUnnecessaryComparison]
|
|
428
428
|
):
|
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
"""Base session store classes for Litestar integration."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from datetime import datetime, timedelta, timezone
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar, cast
|
|
7
|
+
|
|
8
|
+
from sqlspec.utils.logging import get_logger
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from types import TracebackType
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
ConfigT = TypeVar("ConfigT")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
logger = get_logger("extensions.litestar.store")
|
|
18
|
+
|
|
19
|
+
__all__ = ("BaseSQLSpecStore",)
|
|
20
|
+
|
|
21
|
+
VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
|
|
22
|
+
MAX_TABLE_NAME_LENGTH: Final = 63
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class BaseSQLSpecStore(ABC, Generic[ConfigT]):
|
|
26
|
+
"""Base class for SQLSpec-backed Litestar session stores.
|
|
27
|
+
|
|
28
|
+
Implements the litestar.stores.base.Store protocol for server-side session
|
|
29
|
+
storage using SQLSpec database adapters.
|
|
30
|
+
|
|
31
|
+
This abstract base class provides common functionality for all database-specific
|
|
32
|
+
store implementations including:
|
|
33
|
+
- Connection management via SQLSpec configs
|
|
34
|
+
- Session expiration calculation
|
|
35
|
+
- Table creation utilities
|
|
36
|
+
|
|
37
|
+
Subclasses must implement dialect-specific SQL queries.
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
config: SQLSpec database configuration with extension_config["litestar"] settings.
|
|
41
|
+
|
|
42
|
+
Example:
|
|
43
|
+
from sqlspec.adapters.asyncpg import AsyncpgConfig
|
|
44
|
+
from sqlspec.adapters.asyncpg.litestar.store import AsyncpgStore
|
|
45
|
+
|
|
46
|
+
config = AsyncpgConfig(
|
|
47
|
+
pool_config={"dsn": "postgresql://..."},
|
|
48
|
+
extension_config={"litestar": {"session_table": "my_sessions"}}
|
|
49
|
+
)
|
|
50
|
+
store = AsyncpgStore(config)
|
|
51
|
+
await store.create_table()
|
|
52
|
+
|
|
53
|
+
Notes:
|
|
54
|
+
Configuration is read from config.extension_config["litestar"]:
|
|
55
|
+
- session_table: Table name (default: "litestar_session")
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
__slots__ = ("_config", "_table_name")
|
|
59
|
+
|
|
60
|
+
def __init__(self, config: ConfigT) -> None:
|
|
61
|
+
"""Initialize the session store.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
config: SQLSpec database configuration.
|
|
65
|
+
|
|
66
|
+
Notes:
|
|
67
|
+
Reads table_name from config.extension_config["litestar"]["session_table"].
|
|
68
|
+
Defaults to "litestar_session" if not specified.
|
|
69
|
+
"""
|
|
70
|
+
self._config = config
|
|
71
|
+
self._table_name = self._get_table_name_from_config()
|
|
72
|
+
self._validate_table_name(self._table_name)
|
|
73
|
+
|
|
74
|
+
def _get_table_name_from_config(self) -> str:
|
|
75
|
+
"""Extract table name from config.extension_config.
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
Table name for the session store.
|
|
79
|
+
"""
|
|
80
|
+
if hasattr(self._config, "extension_config"):
|
|
81
|
+
extension_config = cast("dict[str, dict[str, Any]]", self._config.extension_config) # pyright: ignore
|
|
82
|
+
litestar_config: dict[str, Any] = extension_config.get("litestar", {})
|
|
83
|
+
return str(litestar_config.get("session_table", "litestar_session"))
|
|
84
|
+
return "litestar_session"
|
|
85
|
+
|
|
86
|
+
@property
|
|
87
|
+
def config(self) -> ConfigT:
|
|
88
|
+
"""Return the database configuration."""
|
|
89
|
+
return self._config
|
|
90
|
+
|
|
91
|
+
@property
|
|
92
|
+
def table_name(self) -> str:
|
|
93
|
+
"""Return the session table name."""
|
|
94
|
+
return self._table_name
|
|
95
|
+
|
|
96
|
+
@abstractmethod
|
|
97
|
+
async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None":
|
|
98
|
+
"""Get a session value by key.
|
|
99
|
+
|
|
100
|
+
Args:
|
|
101
|
+
key: Session ID to retrieve.
|
|
102
|
+
renew_for: If given and the value had an initial expiry time set, renew the
|
|
103
|
+
expiry time for ``renew_for`` seconds. If the value has not been set
|
|
104
|
+
with an expiry time this is a no-op.
|
|
105
|
+
|
|
106
|
+
Returns:
|
|
107
|
+
Session data as bytes if found and not expired, None otherwise.
|
|
108
|
+
"""
|
|
109
|
+
raise NotImplementedError
|
|
110
|
+
|
|
111
|
+
@abstractmethod
|
|
112
|
+
async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None:
|
|
113
|
+
"""Store a session value.
|
|
114
|
+
|
|
115
|
+
Args:
|
|
116
|
+
key: Session ID.
|
|
117
|
+
value: Session data (will be converted to bytes if string).
|
|
118
|
+
expires_in: Time in seconds or timedelta before expiration.
|
|
119
|
+
"""
|
|
120
|
+
raise NotImplementedError
|
|
121
|
+
|
|
122
|
+
@abstractmethod
|
|
123
|
+
async def delete(self, key: str) -> None:
|
|
124
|
+
"""Delete a session by key.
|
|
125
|
+
|
|
126
|
+
Args:
|
|
127
|
+
key: Session ID to delete.
|
|
128
|
+
"""
|
|
129
|
+
raise NotImplementedError
|
|
130
|
+
|
|
131
|
+
@abstractmethod
|
|
132
|
+
async def delete_all(self) -> None:
|
|
133
|
+
"""Delete all sessions from the store."""
|
|
134
|
+
raise NotImplementedError
|
|
135
|
+
|
|
136
|
+
@abstractmethod
|
|
137
|
+
async def exists(self, key: str) -> bool:
|
|
138
|
+
"""Check if a session key exists and is not expired.
|
|
139
|
+
|
|
140
|
+
Args:
|
|
141
|
+
key: Session ID to check.
|
|
142
|
+
|
|
143
|
+
Returns:
|
|
144
|
+
True if the session exists and is not expired.
|
|
145
|
+
"""
|
|
146
|
+
raise NotImplementedError
|
|
147
|
+
|
|
148
|
+
@abstractmethod
|
|
149
|
+
async def expires_in(self, key: str) -> "int | None":
|
|
150
|
+
"""Get the time in seconds until the session expires.
|
|
151
|
+
|
|
152
|
+
Args:
|
|
153
|
+
key: Session ID to check.
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
Seconds until expiration, or None if no expiry or key doesn't exist.
|
|
157
|
+
"""
|
|
158
|
+
raise NotImplementedError
|
|
159
|
+
|
|
160
|
+
@abstractmethod
|
|
161
|
+
async def delete_expired(self) -> int:
|
|
162
|
+
"""Delete all expired sessions.
|
|
163
|
+
|
|
164
|
+
Returns:
|
|
165
|
+
Number of sessions deleted.
|
|
166
|
+
"""
|
|
167
|
+
raise NotImplementedError
|
|
168
|
+
|
|
169
|
+
@abstractmethod
|
|
170
|
+
async def create_table(self) -> None:
|
|
171
|
+
"""Create the session table if it doesn't exist."""
|
|
172
|
+
raise NotImplementedError
|
|
173
|
+
|
|
174
|
+
@abstractmethod
|
|
175
|
+
def _get_create_table_sql(self) -> str:
|
|
176
|
+
"""Get the CREATE TABLE SQL for this database dialect.
|
|
177
|
+
|
|
178
|
+
Returns:
|
|
179
|
+
SQL statement to create the sessions table.
|
|
180
|
+
"""
|
|
181
|
+
raise NotImplementedError
|
|
182
|
+
|
|
183
|
+
@abstractmethod
|
|
184
|
+
def _get_drop_table_sql(self) -> "list[str]":
|
|
185
|
+
"""Get the DROP TABLE SQL statements for this database dialect.
|
|
186
|
+
|
|
187
|
+
Returns:
|
|
188
|
+
List of SQL statements to drop the table and all indexes.
|
|
189
|
+
Order matters: drop indexes before table.
|
|
190
|
+
|
|
191
|
+
Notes:
|
|
192
|
+
Should use IF EXISTS or dialect-specific error handling
|
|
193
|
+
to allow idempotent migrations.
|
|
194
|
+
"""
|
|
195
|
+
raise NotImplementedError
|
|
196
|
+
|
|
197
|
+
async def __aenter__(self) -> "BaseSQLSpecStore":
|
|
198
|
+
"""Enter context manager."""
|
|
199
|
+
return self
|
|
200
|
+
|
|
201
|
+
async def __aexit__(
|
|
202
|
+
self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None"
|
|
203
|
+
) -> None:
|
|
204
|
+
"""Exit context manager."""
|
|
205
|
+
return
|
|
206
|
+
|
|
207
|
+
def _calculate_expires_at(self, expires_in: "int | timedelta | None") -> "datetime | None":
|
|
208
|
+
"""Calculate expiration timestamp from expires_in.
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
expires_in: Seconds or timedelta until expiration.
|
|
212
|
+
|
|
213
|
+
Returns:
|
|
214
|
+
UTC datetime of expiration, or None if no expiration.
|
|
215
|
+
"""
|
|
216
|
+
if expires_in is None:
|
|
217
|
+
return None
|
|
218
|
+
|
|
219
|
+
expires_in_seconds = int(expires_in.total_seconds()) if isinstance(expires_in, timedelta) else expires_in
|
|
220
|
+
|
|
221
|
+
if expires_in_seconds <= 0:
|
|
222
|
+
return None
|
|
223
|
+
|
|
224
|
+
return datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds)
|
|
225
|
+
|
|
226
|
+
def _value_to_bytes(self, value: "str | bytes") -> bytes:
|
|
227
|
+
"""Convert value to bytes if needed.
|
|
228
|
+
|
|
229
|
+
Args:
|
|
230
|
+
value: String or bytes value.
|
|
231
|
+
|
|
232
|
+
Returns:
|
|
233
|
+
Value as bytes.
|
|
234
|
+
"""
|
|
235
|
+
if isinstance(value, str):
|
|
236
|
+
return value.encode("utf-8")
|
|
237
|
+
return value
|
|
238
|
+
|
|
239
|
+
@staticmethod
|
|
240
|
+
def _validate_table_name(table_name: str) -> None:
|
|
241
|
+
"""Validate table name for SQL safety.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
table_name: Table name to validate.
|
|
245
|
+
|
|
246
|
+
Raises:
|
|
247
|
+
ValueError: If table name is invalid.
|
|
248
|
+
|
|
249
|
+
Notes:
|
|
250
|
+
- Must start with letter or underscore
|
|
251
|
+
- Can only contain letters, numbers, and underscores
|
|
252
|
+
- Maximum length is 63 characters (PostgreSQL limit)
|
|
253
|
+
- Prevents SQL injection in table names
|
|
254
|
+
"""
|
|
255
|
+
if not table_name:
|
|
256
|
+
msg = "Table name cannot be empty"
|
|
257
|
+
raise ValueError(msg)
|
|
258
|
+
|
|
259
|
+
if len(table_name) > MAX_TABLE_NAME_LENGTH:
|
|
260
|
+
msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})"
|
|
261
|
+
raise ValueError(msg)
|
|
262
|
+
|
|
263
|
+
if not VALID_TABLE_NAME_PATTERN.match(table_name):
|
|
264
|
+
msg = f"Invalid table name: {table_name!r}. Must start with letter/underscore and contain only alphanumeric characters and underscores"
|
|
265
|
+
raise ValueError(msg)
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""Starlette extension for SQLSpec.
|
|
2
|
+
|
|
3
|
+
Provides middleware-based session management, automatic transaction handling,
|
|
4
|
+
and connection pooling lifecycle management for Starlette applications.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from sqlspec.extensions.starlette.extension import SQLSpecPlugin
|
|
8
|
+
from sqlspec.extensions.starlette.middleware import SQLSpecAutocommitMiddleware, SQLSpecManualMiddleware
|
|
9
|
+
|
|
10
|
+
__all__ = ("SQLSpecAutocommitMiddleware", "SQLSpecManualMiddleware", "SQLSpecPlugin")
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from sqlspec.config import DatabaseConfigProtocol
|
|
6
|
+
|
|
7
|
+
__all__ = ("CommitMode", "SQLSpecConfigState")
|
|
8
|
+
|
|
9
|
+
CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
|
|
13
|
+
class SQLSpecConfigState:
|
|
14
|
+
"""Internal state for each database configuration.
|
|
15
|
+
|
|
16
|
+
Tracks all configuration parameters needed for middleware and session management.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
config: "DatabaseConfigProtocol[Any, Any, Any]"
|
|
20
|
+
connection_key: str
|
|
21
|
+
pool_key: str
|
|
22
|
+
session_key: str
|
|
23
|
+
commit_mode: CommitMode
|
|
24
|
+
extra_commit_statuses: "set[int] | None"
|
|
25
|
+
extra_rollback_statuses: "set[int] | None"
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any
|
|
2
|
+
|
|
3
|
+
if TYPE_CHECKING:
|
|
4
|
+
from starlette.requests import Request
|
|
5
|
+
|
|
6
|
+
from sqlspec.extensions.starlette._state import SQLSpecConfigState
|
|
7
|
+
|
|
8
|
+
__all__ = ("get_connection_from_request", "get_or_create_session")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_connection_from_request(request: "Request", config_state: "SQLSpecConfigState") -> Any:
|
|
12
|
+
"""Get database connection from request state.
|
|
13
|
+
|
|
14
|
+
Args:
|
|
15
|
+
request: Starlette request instance.
|
|
16
|
+
config_state: Configuration state for the database.
|
|
17
|
+
|
|
18
|
+
Returns:
|
|
19
|
+
Database connection object.
|
|
20
|
+
"""
|
|
21
|
+
return getattr(request.state, config_state.connection_key)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_or_create_session(request: "Request", config_state: "SQLSpecConfigState") -> Any:
|
|
25
|
+
"""Get or create database session for request.
|
|
26
|
+
|
|
27
|
+
Sessions are cached per request to ensure the same session instance
|
|
28
|
+
is returned for multiple calls within the same request.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
request: Starlette request instance.
|
|
32
|
+
config_state: Configuration state for the database.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
Database session (driver instance).
|
|
36
|
+
"""
|
|
37
|
+
session_instance_key = f"{config_state.session_key}_instance"
|
|
38
|
+
|
|
39
|
+
existing_session = getattr(request.state, session_instance_key, None)
|
|
40
|
+
if existing_session is not None:
|
|
41
|
+
return existing_session
|
|
42
|
+
|
|
43
|
+
connection = get_connection_from_request(request, config_state)
|
|
44
|
+
|
|
45
|
+
session = config_state.config.driver_type(
|
|
46
|
+
connection=connection,
|
|
47
|
+
statement_config=config_state.config.statement_config,
|
|
48
|
+
driver_features=config_state.config.driver_features,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
setattr(request.state, session_instance_key, session)
|
|
52
|
+
return session
|