sqlspec 0.10.1__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

@@ -0,0 +1,521 @@
1
+ # ruff: noqa: B008
2
+ """Application dependency providers generators.
3
+
4
+ This module contains functions to create dependency providers for services and filters.
5
+
6
+ You should not have to modify this module very often; it is only invoked under normal usage.
7
+ """
8
+
9
+ import datetime
10
+ import inspect
11
+ from collections.abc import Callable
12
+ from typing import (
13
+ Any,
14
+ Literal,
15
+ NamedTuple,
16
+ Optional,
17
+ TypedDict,
18
+ Union,
19
+ cast,
20
+ )
21
+ from uuid import UUID
22
+
23
+ from litestar.di import Provide
24
+ from litestar.params import Dependency, Parameter
25
+ from typing_extensions import NotRequired
26
+
27
+ from sqlspec.filters import (
28
+ BeforeAfter,
29
+ CollectionFilter,
30
+ FilterTypes,
31
+ LimitOffset,
32
+ NotInCollectionFilter,
33
+ OrderBy,
34
+ SearchFilter,
35
+ )
36
+ from sqlspec.utils.singleton import SingletonMeta
37
+ from sqlspec.utils.text import camelize
38
+
39
# Names exported by ``from <module> import *``. The ``DependencyCache`` class itself is
# not listed here; only its shared instance ``dep_cache`` is.
__all__ = (
    "DEPENDENCY_DEFAULTS",
    "BooleanOrNone",
    "DTorNone",
    "DependencyDefaults",
    "FieldNameType",
    "FilterConfig",
    "HashableType",
    "HashableValue",
    "IntOrNone",
    "SortOrder",
    "SortOrderOrNone",
    "StringOrNone",
    "UuidOrNone",
    "create_filter_dependencies",
    "dep_cache",
)
56
+
57
# Nullable scalar aliases used as query-parameter annotations by the filter providers.
DTorNone = Union[datetime.datetime, None]
StringOrNone = Union[str, None]
UuidOrNone = Union[UUID, None]
IntOrNone = Union[int, None]
BooleanOrNone = Union[bool, None]

# The two accepted sort directions, plus a nullable variant.
SortOrder = Literal["asc", "desc"]
SortOrderOrNone = Union[SortOrder, None]

# Building blocks for the return type of ``_make_hashable``.
HashableValue = Union[str, int, float, bool, None]
HashableType = Union[HashableValue, tuple[Any, ...], tuple[tuple[str, Any], ...], tuple[HashableValue, ...]]
66
+
67
+
68
class DependencyDefaults:
    """Dependency key names and pagination default used when generating filter providers.

    An instance is passed as ``dep_defaults`` to the provider factories, which register
    each generated provider under the corresponding ``*_DEPENDENCY_KEY`` value.
    """

    #: Key for the filters dependency.
    FILTERS_DEPENDENCY_KEY: str = "filters"
    #: Key for the created filter dependency.
    CREATED_FILTER_DEPENDENCY_KEY: str = "created_filter"
    #: Key for the id filter dependency.
    ID_FILTER_DEPENDENCY_KEY: str = "id_filter"
    #: Key for the limit offset dependency.
    LIMIT_OFFSET_FILTER_DEPENDENCY_KEY: str = "limit_offset_filter"
    #: Key for the updated filter dependency.
    UPDATED_FILTER_DEPENDENCY_KEY: str = "updated_filter"
    #: Key for the order by dependency.
    ORDER_BY_FILTER_DEPENDENCY_KEY: str = "order_by_filter"
    #: Key for the search filter dependency.
    SEARCH_FILTER_DEPENDENCY_KEY: str = "search_filter"
    #: Default pagination size.
    DEFAULT_PAGINATION_SIZE: int = 20


# Shared default instance used when callers do not supply their own ``dep_defaults``.
DEPENDENCY_DEFAULTS = DependencyDefaults()
88
+
89
+
90
class FieldNameType(NamedTuple):
    """A filterable field name together with the Python type of its filter values.

    The ``in_fields`` / ``not_in_fields`` filter options use ``type_hint`` to type the
    list of query values accepted for the field.
    """

    #: Name of the field to filter on.
    name: str
    #: Type of each filter value; ``str`` when not given.
    type_hint: type[Any] = str
100
+
101
+
102
class FilterConfig(TypedDict):
    """Options selecting which filter dependencies are generated.

    Every key is optional. An omitted key either leaves the corresponding filter
    disabled or falls back to the default noted on that key.
    """

    #: Enables the id filter; the given type parameterizes the :class:`CollectionFilter`.
    id_filter: NotRequired[type[Union[UUID, int, str]]]
    #: Model field that stores the primary key or identifier (``"id"`` when omitted).
    id_field: NotRequired[str]
    #: Default field for the order-by filter; setting it enables that filter.
    sort_field: NotRequired[str]
    #: Default direction for the order-by filter (``"desc"`` when omitted).
    sort_order: NotRequired[SortOrder]
    #: Enables pagination of the given kind (only ``"limit_offset"`` is supported).
    pagination_type: NotRequired[Literal["limit_offset"]]
    #: Default page size (``DEFAULT_PAGINATION_SIZE`` when omitted).
    pagination_size: NotRequired[int]
    #: Fields to search: a comma-separated string, or a set/list of field names.
    search: NotRequired[Union[str, set[str], list[str]]]
    #: Default for case-insensitive search (``False`` when omitted).
    search_ignore_case: NotRequired[bool]
    #: Enables the ``created_at`` before/after filter.
    created_at: NotRequired[bool]
    #: Enables the ``updated_at`` before/after filter.
    updated_at: NotRequired[bool]
    #: Fields supporting "not in" collection filters: one field, or a set/list of fields
    #: (list entries may be plain names, which are typed as ``str``).
    not_in_fields: NotRequired[Union[FieldNameType, set[FieldNameType], list[Union[str, FieldNameType]]]]
    #: Fields supporting "in" collection filters: one field, or a set/list of fields
    #: (list entries may be plain names, which are typed as ``str``).
    in_fields: NotRequired[Union[FieldNameType, set[FieldNameType], list[Union[str, FieldNameType]]]]
129
+
130
+
131
class DependencyCache(metaclass=SingletonMeta):
    """Memo of dynamically generated dependency providers.

    Maps a cache key (``int`` or ``str``) to the provider mapping built for it.
    NOTE(review): ``SingletonMeta`` presumably makes all instantiations share one
    object — confirm in ``sqlspec.utils.singleton``.
    """

    def __init__(self) -> None:
        self.dependencies: dict[Union[int, str], dict[str, Provide]] = {}

    def add_dependencies(self, key: Union[int, str], dependencies: dict[str, Provide]) -> None:
        """Store ``dependencies`` under ``key``, replacing any existing entry."""
        self.dependencies.update({key: dependencies})

    def get_dependencies(self, key: Union[int, str]) -> Optional[dict[str, Provide]]:
        """Return the providers stored under ``key``, or ``None`` if there are none."""
        cached = self.dependencies.get(key)
        return cached


# Module-level cache consulted by ``create_filter_dependencies``.
dep_cache = DependencyCache()
145
+
146
+
147
def create_filter_dependencies(
    config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> dict[str, Provide]:
    """Create (or fetch from ``dep_cache``) the filter dependency providers for ``config``.

    The generated providers are registered under the key names taken from
    ``dep_defaults``. The cache key therefore covers both the filter configuration and
    those settings. Keying on ``config`` alone would return providers built for a
    different ``dep_defaults``.

    Args:
        config: FilterConfig instance with desired settings.
        dep_defaults: Dependency defaults to use for the filter dependencies.

    Returns:
        A mapping of dependency key to provider, including the combined filter provider.
    """
    # Capture the upper-case settings (dependency keys, pagination size), including
    # values overridden on a subclass or on the instance itself.
    defaults_settings = {name: getattr(dep_defaults, name) for name in dir(dep_defaults) if name.isupper()}
    cache_key = hash((_make_hashable(config), _make_hashable(defaults_settings)))

    deps = dep_cache.get_dependencies(cache_key)
    if deps is not None:
        return deps
    deps = _create_statement_filters(config, dep_defaults)
    dep_cache.add_dependencies(cache_key, deps)
    return deps
166
+
167
+
168
def _make_hashable(value: Any) -> HashableType:
    """Return a hashable, order-insensitive representation of ``value``.

    Conversion rules:

    - ``dict``: a tuple of ``(str(key), converted_value)`` pairs ordered by key.
    - ``list`` / ``set``: a tuple of the converted items, with ``None`` results dropped,
      ordered by their ``str`` form.
    - ``str``, ``int``, ``float``, ``bool`` and ``None``: returned unchanged.
    - anything else: its ``str`` representation.

    Args:
        value: Any value that needs to be made hashable.

    Returns:
        A hashable version of the value.
    """
    if isinstance(value, dict):
        return tuple((str(key), _make_hashable(value[key])) for key in sorted(value))  # pyright: ignore

    if isinstance(value, (list, set)):
        converted = (_make_hashable(member) for member in value)  # pyright: ignore
        return tuple(sorted((item for item in converted if item is not None), key=str))

    if value is None or isinstance(value, (str, int, float, bool)):
        return value

    return str(value)
197
+
198
+
199
def _normalize_field_definitions(
    fields: Union[str, FieldNameType, set[FieldNameType], list[Union[str, FieldNameType]]],
) -> list[FieldNameType]:
    """Turn an ``in_fields`` / ``not_in_fields`` config value into a list of field definitions.

    A single name or :class:`FieldNameType` is treated as a one-element collection, and
    plain names are promoted to ``FieldNameType(name=..., type_hint=str)``.

    Args:
        fields: The raw configuration value.

    Returns:
        The field definitions, in the iteration order of ``fields``.
    """
    # FieldNameType is itself a tuple, so a single definition must be wrapped before
    # iterating; otherwise its (name, type_hint) members would be treated as two fields.
    items = [fields] if isinstance(fields, (str, FieldNameType)) else fields
    return [FieldNameType(name=item, type_hint=str) if isinstance(item, str) else item for item in items]


def _create_not_in_filter_provider(
    field_def: FieldNameType,
) -> Callable[..., Optional[NotInCollectionFilter[Any]]]:  # type: ignore # pyright: ignore
    """Build the "not in" provider for one field.

    The provider reads the query parameter ``camelize(f"{field_def.name}_not_in")`` as
    ``list[field_def.type_hint]``.

    Args:
        field_def: The field to filter on and the type of its values.

    Returns:
        A provider returning a :class:`NotInCollectionFilter`, or ``None`` when the
        parameter is absent or empty.
    """

    def provide_not_in_filter(  # pyright: ignore
        values: Optional[list[field_def.type_hint]] = Parameter(  # type: ignore
            query=camelize(f"{field_def.name}_not_in"), default=None, required=False
        ),
    ) -> Optional[NotInCollectionFilter[field_def.type_hint]]:  # type: ignore
        return (
            NotInCollectionFilter[field_def.type_hint](field_name=field_def.name, values=values)  # type: ignore
            if values
            else None
        )

    return provide_not_in_filter  # pyright: ignore


def _create_in_filter_provider(
    field_def: FieldNameType,
) -> Callable[..., Optional[CollectionFilter[Any]]]:  # type: ignore # pyright: ignore
    """Build the "in" provider for one field.

    The provider reads the query parameter ``camelize(f"{field_def.name}_in")`` as
    ``list[field_def.type_hint]``.

    Args:
        field_def: The field to filter on and the type of its values.

    Returns:
        A provider returning a :class:`CollectionFilter`, or ``None`` when the parameter
        is absent or empty.
    """

    def provide_in_filter(  # pyright: ignore
        values: Optional[list[field_def.type_hint]] = Parameter(  # type: ignore # pyright: ignore
            query=camelize(f"{field_def.name}_in"), default=None, required=False
        ),
    ) -> Optional[CollectionFilter[field_def.type_hint]]:  # type: ignore # pyright: ignore
        return (
            CollectionFilter[field_def.type_hint](field_name=field_def.name, values=values)  # type: ignore # pyright: ignore
            if values
            else None
        )

    return provide_in_filter  # pyright: ignore


def _create_statement_filters(
    config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> dict[str, Provide]:
    """Create filter dependencies based on configuration.

    Args:
        config (FilterConfig): Configuration dictionary specifying which filters to enable
        dep_defaults (DependencyDefaults): Dependency defaults to use for the filter dependencies

    Returns:
        dict[str, Provide]: Dictionary of filter provider functions. When at least one
        filter is enabled it also contains the combined provider under
        ``dep_defaults.FILTERS_DEPENDENCY_KEY``.
    """
    filters: dict[str, Provide] = {}

    if config.get("id_filter", False):

        def provide_id_filter(  # pyright: ignore[reportUnknownParameterType]
            ids: Optional[list[str]] = Parameter(query="ids", default=None, required=False),
        ) -> CollectionFilter:  # pyright: ignore[reportMissingTypeArgument]
            # NOTE(review): ids are parsed as ``str`` even when ``id_filter`` names another
            # type (e.g. UUID/int) — confirm whether they should be parsed as that type.
            return CollectionFilter(field_name=config.get("id_field", "id"), values=ids)

        filters[dep_defaults.ID_FILTER_DEPENDENCY_KEY] = Provide(provide_id_filter, sync_to_thread=False)  # pyright: ignore[reportUnknownArgumentType]

    if config.get("created_at", False):

        def provide_created_filter(
            before: DTorNone = Parameter(query="createdBefore", default=None, required=False),
            after: DTorNone = Parameter(query="createdAfter", default=None, required=False),
        ) -> BeforeAfter:
            return BeforeAfter("created_at", before, after)

        filters[dep_defaults.CREATED_FILTER_DEPENDENCY_KEY] = Provide(provide_created_filter, sync_to_thread=False)

    if config.get("updated_at", False):

        def provide_updated_filter(
            before: DTorNone = Parameter(query="updatedBefore", default=None, required=False),
            after: DTorNone = Parameter(query="updatedAfter", default=None, required=False),
        ) -> BeforeAfter:
            return BeforeAfter("updated_at", before, after)

        filters[dep_defaults.UPDATED_FILTER_DEPENDENCY_KEY] = Provide(provide_updated_filter, sync_to_thread=False)

    if config.get("pagination_type") == "limit_offset":

        def provide_limit_offset_pagination(
            current_page: int = Parameter(ge=1, query="currentPage", default=1, required=False),
            page_size: int = Parameter(
                query="pageSize",
                ge=1,
                default=config.get("pagination_size", dep_defaults.DEFAULT_PAGINATION_SIZE),
                required=False,
            ),
        ) -> LimitOffset:
            # Pages are 1-based, so page N starts after (N - 1) full pages.
            return LimitOffset(page_size, page_size * (current_page - 1))

        filters[dep_defaults.LIMIT_OFFSET_FILTER_DEPENDENCY_KEY] = Provide(
            provide_limit_offset_pagination, sync_to_thread=False
        )

    if search_fields := config.get("search"):
        # Normalize once when the provider is built. Names are stripped so that
        # "name, email" searches "name" and "email" rather than " email".
        raw_field_names = search_fields.split(",") if isinstance(search_fields, str) else search_fields
        search_field_names = {name.strip() for name in raw_field_names if name.strip()}

        def provide_search_filter(
            search_string: StringOrNone = Parameter(
                title="Value to search for",
                query="searchString",
                default=None,
                required=False,
            ),
            ignore_case: BooleanOrNone = Parameter(
                title="Search should be case insensitive",
                query="searchIgnoreCase",
                default=config.get("search_ignore_case", False),
                required=False,
            ),
        ) -> SearchFilter:
            return SearchFilter(
                field_name=set(search_field_names),  # fresh set per request
                value=search_string,  # type: ignore[arg-type]
                ignore_case=ignore_case or False,
            )

        filters[dep_defaults.SEARCH_FILTER_DEPENDENCY_KEY] = Provide(provide_search_filter, sync_to_thread=False)

    if sort_field := config.get("sort_field"):

        def provide_order_by(
            field_name: StringOrNone = Parameter(
                title="Order by field",
                query="orderBy",
                default=sort_field,
                required=False,
            ),
            sort_order: SortOrderOrNone = Parameter(
                title="Sort order",
                query="sortOrder",
                default=config.get("sort_order", "desc"),
                required=False,
            ),
        ) -> OrderBy:
            return OrderBy(field_name=field_name, sort_order=sort_order)  # type: ignore[arg-type]

        filters[dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY] = Provide(provide_order_by, sync_to_thread=False)

    if not_in_fields := config.get("not_in_fields"):
        for field_def in _normalize_field_definitions(not_in_fields):
            filters[f"{field_def.name}_not_in_filter"] = Provide(
                _create_not_in_filter_provider(field_def), sync_to_thread=False
            )  # pyright: ignore

    if in_fields := config.get("in_fields"):
        for field_def in _normalize_field_definitions(in_fields):
            filters[f"{field_def.name}_in_filter"] = Provide(
                _create_in_filter_provider(field_def), sync_to_thread=False
            )  # pyright: ignore

    if filters:
        filters[dep_defaults.FILTERS_DEPENDENCY_KEY] = Provide(
            _create_filter_aggregate_function(config, dep_defaults), sync_to_thread=False
        )

    return filters


def _create_filter_aggregate_function(
    config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> Callable[..., list[FilterTypes]]:
    """Create a function that collects the enabled filters into a list.

    The returned function gets a synthetic ``__signature__`` with one parameter per
    enabled filter. Each parameter is named exactly like the key its provider is
    registered under by ``_create_statement_filters``, so the same ``dep_defaults``
    must be used for both.

    Args:
        config: The filter configuration.
        dep_defaults: Dependency defaults supplying the filter dependency key names.

    Returns:
        A function that returns a list of filters based on the configuration.
    """
    parameters: dict[str, inspect.Parameter] = {}
    annotations: dict[str, Any] = {}

    def add_parameter(name: str, annotation: Any) -> None:
        # Each filter arrives as an injected dependency; validation is skipped because
        # the providers already construct the filter objects.
        parameters[name] = inspect.Parameter(
            name=name,
            kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=Dependency(skip_validation=True),
            annotation=annotation,
        )
        annotations[name] = annotation

    id_key = dep_defaults.ID_FILTER_DEPENDENCY_KEY
    created_key = dep_defaults.CREATED_FILTER_DEPENDENCY_KEY
    updated_key = dep_defaults.UPDATED_FILTER_DEPENDENCY_KEY
    search_key = dep_defaults.SEARCH_FILTER_DEPENDENCY_KEY
    limit_offset_key = dep_defaults.LIMIT_OFFSET_FILTER_DEPENDENCY_KEY
    order_by_key = dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY

    if cls := config.get("id_filter"):
        add_parameter(id_key, CollectionFilter[cls])  # type: ignore[valid-type]
    if config.get("created_at"):
        add_parameter(created_key, BeforeAfter)
    if config.get("updated_at"):
        add_parameter(updated_key, BeforeAfter)
    if config.get("search"):
        add_parameter(search_key, SearchFilter)
    if config.get("pagination_type") == "limit_offset":
        add_parameter(limit_offset_key, LimitOffset)
    if config.get("sort_field"):
        add_parameter(order_by_key, OrderBy)

    # Collection-filter keys are computed once here instead of on every request.
    not_in_keys: list[str] = []
    if not_in_fields := config.get("not_in_fields"):
        for field_def in _normalize_field_definitions(not_in_fields):
            key = f"{field_def.name}_not_in_filter"
            add_parameter(key, NotInCollectionFilter[field_def.type_hint])  # type: ignore
            not_in_keys.append(key)

    in_keys: list[str] = []
    if in_fields := config.get("in_fields"):
        for field_def in _normalize_field_definitions(in_fields):
            key = f"{field_def.name}_in_filter"
            add_parameter(key, CollectionFilter[field_def.type_hint])  # type: ignore
            in_keys.append(key)

    def provide_filters(**kwargs: FilterTypes) -> list[FilterTypes]:
        """Provide filter dependencies based on configuration.

        Args:
            **kwargs: Filter parameters dynamically provided based on configuration.

        Returns:
            list[FilterTypes]: List of configured filters.
        """
        filters: list[FilterTypes] = []
        # Order: id, created, limit/offset, updated, search, order by, not-in, in.
        for key in (id_key, created_key, limit_offset_key, updated_key):
            if simple_filter := kwargs.get(key):
                filters.append(simple_filter)

        search_filter = cast("Optional[SearchFilter]", kwargs.get(search_key))
        if (
            search_filter
            and search_filter.field_name is not None  # pyright: ignore[reportUnnecessaryComparison]
            and search_filter.value is not None  # pyright: ignore[reportUnnecessaryComparison]
        ):
            filters.append(search_filter)

        order_by = cast("Optional[OrderBy]", kwargs.get(order_by_key))
        if order_by and order_by.field_name is not None:  # pyright: ignore[reportUnnecessaryComparison]
            filters.append(order_by)

        # Collection providers return None when no values were supplied.
        for key in (*not_in_keys, *in_keys):
            if (collection_filter := kwargs.get(key)) is not None:
                filters.append(collection_filter)
        return filters

    # Litestar resolves dependencies from the signature, so publish the dynamic one.
    provide_filters.__signature__ = inspect.Signature(  # type: ignore
        parameters=list(parameters.values()),
        return_annotation=list[FilterTypes],
    )
    provide_filters.__annotations__ = annotations
    provide_filters.__annotations__["return"] = list[FilterTypes]

    return provide_filters