sqlspec-0.14.1-py3-none-any.whl → sqlspec-0.15.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the registry's release page for more details.

Files changed (158)
  1. sqlspec/__init__.py +50 -25
  2. sqlspec/__main__.py +1 -1
  3. sqlspec/__metadata__.py +1 -3
  4. sqlspec/_serialization.py +1 -2
  5. sqlspec/_sql.py +256 -120
  6. sqlspec/_typing.py +278 -142
  7. sqlspec/adapters/adbc/__init__.py +4 -3
  8. sqlspec/adapters/adbc/_types.py +12 -0
  9. sqlspec/adapters/adbc/config.py +115 -260
  10. sqlspec/adapters/adbc/driver.py +462 -367
  11. sqlspec/adapters/aiosqlite/__init__.py +18 -3
  12. sqlspec/adapters/aiosqlite/_types.py +13 -0
  13. sqlspec/adapters/aiosqlite/config.py +199 -129
  14. sqlspec/adapters/aiosqlite/driver.py +230 -269
  15. sqlspec/adapters/asyncmy/__init__.py +18 -3
  16. sqlspec/adapters/asyncmy/_types.py +12 -0
  17. sqlspec/adapters/asyncmy/config.py +80 -168
  18. sqlspec/adapters/asyncmy/driver.py +260 -225
  19. sqlspec/adapters/asyncpg/__init__.py +19 -4
  20. sqlspec/adapters/asyncpg/_types.py +17 -0
  21. sqlspec/adapters/asyncpg/config.py +82 -181
  22. sqlspec/adapters/asyncpg/driver.py +285 -383
  23. sqlspec/adapters/bigquery/__init__.py +17 -3
  24. sqlspec/adapters/bigquery/_types.py +12 -0
  25. sqlspec/adapters/bigquery/config.py +191 -258
  26. sqlspec/adapters/bigquery/driver.py +474 -646
  27. sqlspec/adapters/duckdb/__init__.py +14 -3
  28. sqlspec/adapters/duckdb/_types.py +12 -0
  29. sqlspec/adapters/duckdb/config.py +415 -351
  30. sqlspec/adapters/duckdb/driver.py +343 -413
  31. sqlspec/adapters/oracledb/__init__.py +19 -5
  32. sqlspec/adapters/oracledb/_types.py +14 -0
  33. sqlspec/adapters/oracledb/config.py +123 -379
  34. sqlspec/adapters/oracledb/driver.py +507 -560
  35. sqlspec/adapters/psqlpy/__init__.py +13 -3
  36. sqlspec/adapters/psqlpy/_types.py +11 -0
  37. sqlspec/adapters/psqlpy/config.py +93 -254
  38. sqlspec/adapters/psqlpy/driver.py +505 -234
  39. sqlspec/adapters/psycopg/__init__.py +19 -5
  40. sqlspec/adapters/psycopg/_types.py +17 -0
  41. sqlspec/adapters/psycopg/config.py +143 -403
  42. sqlspec/adapters/psycopg/driver.py +706 -872
  43. sqlspec/adapters/sqlite/__init__.py +14 -3
  44. sqlspec/adapters/sqlite/_types.py +11 -0
  45. sqlspec/adapters/sqlite/config.py +202 -118
  46. sqlspec/adapters/sqlite/driver.py +264 -303
  47. sqlspec/base.py +105 -9
  48. sqlspec/{statement/builder → builder}/__init__.py +12 -14
  49. sqlspec/{statement/builder → builder}/_base.py +120 -55
  50. sqlspec/{statement/builder → builder}/_column.py +17 -6
  51. sqlspec/{statement/builder → builder}/_ddl.py +46 -79
  52. sqlspec/{statement/builder → builder}/_ddl_utils.py +5 -10
  53. sqlspec/{statement/builder → builder}/_delete.py +6 -25
  54. sqlspec/{statement/builder → builder}/_insert.py +6 -64
  55. sqlspec/builder/_merge.py +56 -0
  56. sqlspec/{statement/builder → builder}/_parsing_utils.py +3 -10
  57. sqlspec/{statement/builder → builder}/_select.py +11 -56
  58. sqlspec/{statement/builder → builder}/_update.py +12 -18
  59. sqlspec/{statement/builder → builder}/mixins/__init__.py +10 -14
  60. sqlspec/{statement/builder → builder}/mixins/_cte_and_set_ops.py +48 -59
  61. sqlspec/{statement/builder → builder}/mixins/_insert_operations.py +22 -16
  62. sqlspec/{statement/builder → builder}/mixins/_join_operations.py +1 -3
  63. sqlspec/{statement/builder → builder}/mixins/_merge_operations.py +3 -5
  64. sqlspec/{statement/builder → builder}/mixins/_order_limit_operations.py +3 -3
  65. sqlspec/{statement/builder → builder}/mixins/_pivot_operations.py +4 -8
  66. sqlspec/{statement/builder → builder}/mixins/_select_operations.py +21 -36
  67. sqlspec/{statement/builder → builder}/mixins/_update_operations.py +3 -14
  68. sqlspec/{statement/builder → builder}/mixins/_where_clause.py +52 -79
  69. sqlspec/cli.py +4 -5
  70. sqlspec/config.py +180 -133
  71. sqlspec/core/__init__.py +63 -0
  72. sqlspec/core/cache.py +873 -0
  73. sqlspec/core/compiler.py +396 -0
  74. sqlspec/core/filters.py +828 -0
  75. sqlspec/core/hashing.py +310 -0
  76. sqlspec/core/parameters.py +1209 -0
  77. sqlspec/core/result.py +664 -0
  78. sqlspec/{statement → core}/splitter.py +321 -191
  79. sqlspec/core/statement.py +651 -0
  80. sqlspec/driver/__init__.py +7 -10
  81. sqlspec/driver/_async.py +387 -176
  82. sqlspec/driver/_common.py +527 -289
  83. sqlspec/driver/_sync.py +390 -172
  84. sqlspec/driver/mixins/__init__.py +2 -19
  85. sqlspec/driver/mixins/_result_tools.py +168 -0
  86. sqlspec/driver/mixins/_sql_translator.py +6 -3
  87. sqlspec/exceptions.py +5 -252
  88. sqlspec/extensions/aiosql/adapter.py +93 -96
  89. sqlspec/extensions/litestar/config.py +0 -1
  90. sqlspec/extensions/litestar/handlers.py +15 -26
  91. sqlspec/extensions/litestar/plugin.py +16 -14
  92. sqlspec/extensions/litestar/providers.py +17 -52
  93. sqlspec/loader.py +424 -105
  94. sqlspec/migrations/__init__.py +12 -0
  95. sqlspec/migrations/base.py +92 -68
  96. sqlspec/migrations/commands.py +24 -106
  97. sqlspec/migrations/loaders.py +402 -0
  98. sqlspec/migrations/runner.py +49 -51
  99. sqlspec/migrations/tracker.py +31 -44
  100. sqlspec/migrations/utils.py +64 -24
  101. sqlspec/protocols.py +7 -183
  102. sqlspec/storage/__init__.py +1 -1
  103. sqlspec/storage/backends/base.py +37 -40
  104. sqlspec/storage/backends/fsspec.py +136 -112
  105. sqlspec/storage/backends/obstore.py +138 -160
  106. sqlspec/storage/capabilities.py +5 -4
  107. sqlspec/storage/registry.py +57 -106
  108. sqlspec/typing.py +136 -115
  109. sqlspec/utils/__init__.py +2 -3
  110. sqlspec/utils/correlation.py +0 -3
  111. sqlspec/utils/deprecation.py +6 -6
  112. sqlspec/utils/fixtures.py +6 -6
  113. sqlspec/utils/logging.py +0 -2
  114. sqlspec/utils/module_loader.py +7 -12
  115. sqlspec/utils/singleton.py +0 -1
  116. sqlspec/utils/sync_tools.py +16 -37
  117. sqlspec/utils/text.py +12 -51
  118. sqlspec/utils/type_guards.py +443 -232
  119. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/METADATA +7 -2
  120. sqlspec-0.15.0.dist-info/RECORD +134 -0
  121. sqlspec/adapters/adbc/transformers.py +0 -108
  122. sqlspec/driver/connection.py +0 -207
  123. sqlspec/driver/mixins/_cache.py +0 -114
  124. sqlspec/driver/mixins/_csv_writer.py +0 -91
  125. sqlspec/driver/mixins/_pipeline.py +0 -508
  126. sqlspec/driver/mixins/_query_tools.py +0 -796
  127. sqlspec/driver/mixins/_result_utils.py +0 -138
  128. sqlspec/driver/mixins/_storage.py +0 -912
  129. sqlspec/driver/mixins/_type_coercion.py +0 -128
  130. sqlspec/driver/parameters.py +0 -138
  131. sqlspec/statement/__init__.py +0 -21
  132. sqlspec/statement/builder/_merge.py +0 -95
  133. sqlspec/statement/cache.py +0 -50
  134. sqlspec/statement/filters.py +0 -625
  135. sqlspec/statement/parameters.py +0 -956
  136. sqlspec/statement/pipelines/__init__.py +0 -210
  137. sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
  138. sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
  139. sqlspec/statement/pipelines/context.py +0 -109
  140. sqlspec/statement/pipelines/transformers/__init__.py +0 -7
  141. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
  142. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
  143. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
  144. sqlspec/statement/pipelines/validators/__init__.py +0 -23
  145. sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
  146. sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
  147. sqlspec/statement/pipelines/validators/_performance.py +0 -714
  148. sqlspec/statement/pipelines/validators/_security.py +0 -967
  149. sqlspec/statement/result.py +0 -435
  150. sqlspec/statement/sql.py +0 -1774
  151. sqlspec/utils/cached_property.py +0 -25
  152. sqlspec/utils/statement_hashing.py +0 -203
  153. sqlspec-0.14.1.dist-info/RECORD +0 -145
  154. sqlspec/{statement/builder → builder}/mixins/_delete_operations.py +0 -0
  155. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/WHEEL +0 -0
  156. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/entry_points.txt +0 -0
  157. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/licenses/LICENSE +0 -0
  158. {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,828 @@
1
+ """Filter system for SQL statement manipulation.
2
+
3
+ This module provides filters that can be applied to SQL statements to add
4
+ WHERE clauses, ORDER BY clauses, LIMIT/OFFSET, and other modifications.
5
+
6
+ Components:
7
+ - StatementFilter: Abstract base class for all filters
8
+ - BeforeAfterFilter: Date range filtering
9
+ - InCollectionFilter: IN clause filtering
10
+ - LimitOffsetFilter: Pagination support
11
+ - OrderByFilter: Sorting support
12
+ - SearchFilter: Text search filtering
13
+ - Various collection and negation filters
14
+
15
+ Features:
16
+ - Parameter conflict resolution
17
+ - Type-safe filter application
18
+ - Cacheable filter configurations
19
+ """
20
+
21
+ import uuid
22
+ from abc import ABC, abstractmethod
23
+ from collections import abc
24
+ from collections.abc import Sequence
25
+ from datetime import datetime
26
+ from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union
27
+
28
+ from sqlglot import exp
29
+ from typing_extensions import TypeAlias, TypeVar
30
+
31
+ if TYPE_CHECKING:
32
+ from sqlglot.expressions import Condition
33
+
34
+ from sqlspec.core.statement import SQL
35
+
36
# Public API of this module (some names are defined further down the file).
__all__ = (
    "AnyCollectionFilter",
    "BeforeAfterFilter",
    "FilterTypeT",
    "FilterTypes",
    "InAnyFilter",
    "InCollectionFilter",
    "LimitOffsetFilter",
    "NotAnyCollectionFilter",
    "NotInCollectionFilter",
    "NotInSearchFilter",
    "OffsetPagination",
    "OnBeforeAfterFilter",
    "OrderByFilter",
    "PaginationFilter",
    "SearchFilter",
    "StatementFilter",
    "apply_filter",
)

# Element type of the values held by the generic collection filters.
T = TypeVar("T")
# Type variable restricted to StatementFilter and its subclasses.
FilterTypeT = TypeVar("FilterTypeT", bound="StatementFilter")
58
+
59
+
60
class StatementFilter(ABC):
    """Abstract base class for filters that can be appended to a statement.

    Subclasses implement :meth:`append_to_statement` (apply the filter to a
    statement) and :meth:`get_cache_key` (describe their configuration as a
    hashable tuple).
    """

    __slots__ = ()

    @abstractmethod
    def append_to_statement(self, statement: "SQL") -> "SQL":
        """Return ``statement`` with this filter applied.

        The concrete filters in this module rewrite the SQL expression and also
        bind their values on the returned statement via ``add_named_parameter``.
        :meth:`extract_parameters` reports the filter's values independently of
        any statement.
        """

    def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
        """Extract parameters that this filter contributes.

        The base implementation contributes nothing.

        Returns:
            Tuple of (positional_parameters, named_parameters) where:
            - positional_parameters: List of positional parameter values
            - named_parameters: Dict of parameter name to value
        """
        return [], {}

    def _resolve_parameter_conflicts(self, statement: "SQL", proposed_names: list[str]) -> list[str]:
        """Resolve parameter name conflicts.

        A proposed name that is already used on ``statement``, or that was handed
        out earlier in the same call, gets an ``_<8 hex chars>`` random suffix.
        The suffix is redrawn until the candidate is unused. This guarantees that
        the returned names are unique among themselves and with respect to the
        statement's existing parameters.

        Args:
            statement: The SQL statement to check for existing parameters
            proposed_names: List of proposed parameter names

        Returns:
            List of resolved parameter names (same length and order as proposed_names)
        """
        existing_params = set(statement._named_parameters.keys())
        current_parameters = statement.parameters
        # Only mapping-style parameters carry names that can clash.
        if isinstance(current_parameters, dict):
            existing_params.update(current_parameters.keys())

        resolved_names: list[str] = []
        for name in proposed_names:
            resolved_name = name
            # Loop rather than suffixing once: a random suffix may itself collide.
            while resolved_name in existing_params:
                resolved_name = f"{name}_{uuid.uuid4().hex[:8]}"
            resolved_names.append(resolved_name)
            existing_params.add(resolved_name)

        return resolved_names

    @abstractmethod
    def get_cache_key(self) -> tuple[Any, ...]:
        """Return a cache key for this filter's configuration.

        Returns:
            Tuple of hashable values representing the filter's configuration
        """
115
+
116
+
117
class BeforeAfterFilter(StatementFilter):
    """Filter for datetime range queries.

    Adds ``field < :before`` and/or ``field > :after`` (strict bounds) to the
    statement's WHERE clause. A bound that is falsy (e.g. ``None``) is skipped.
    """

    __slots__ = ("_param_name_after", "_param_name_before", "after", "before", "field_name")

    field_name: str
    before: Optional[datetime]
    after: Optional[datetime]

    def __init__(self, field_name: str, before: Optional[datetime] = None, after: Optional[datetime] = None) -> None:
        """Initialize the BeforeAfterFilter.

        Args:
            field_name: Name of the model attribute to filter on.
            before: Filter results where the field is earlier than this.
            after: Filter results where the field is later than this.
        """
        self.field_name = field_name
        self.before = before
        self.after = after

        self._param_name_before: Optional[str] = None
        self._param_name_after: Optional[str] = None

        if self.before:
            self._param_name_before = f"{self.field_name}_before"
        if self.after:
            self._param_name_after = f"{self.field_name}_after"

    def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
        """Extract filter parameters keyed by their unresolved placeholder names."""
        named_parameters = {}
        if self.before and self._param_name_before:
            named_parameters[self._param_name_before] = self.before
        if self.after and self._param_name_after:
            named_parameters[self._param_name_after] = self.after
        return [], named_parameters

    def append_to_statement(self, statement: "SQL") -> "SQL":
        """Return ``statement`` with the strict range predicate(s) added to WHERE.

        Placeholder names are de-duplicated against ``statement`` first. When both
        bounds are active, the predicates are AND-ed together (``before`` first).
        """
        # (proposed placeholder name, bound value, comparison node class)
        bounds = []
        if self.before and self._param_name_before:
            bounds.append((self._param_name_before, self.before, exp.LT))
        if self.after and self._param_name_after:
            bounds.append((self._param_name_after, self.after, exp.GT))

        if not bounds:
            return statement

        resolved_names = self._resolve_parameter_conflicts(statement, [name for name, _, _ in bounds])

        final_condition: Optional[exp.Expression] = None
        result = statement
        for param_name, (_, value, comparison) in zip(resolved_names, bounds):
            # Build a fresh Column per predicate: sqlglot expression nodes have a
            # single parent, so one node must not be shared by two comparisons.
            predicate = comparison(this=exp.column(self.field_name), expression=exp.Placeholder(this=param_name))
            final_condition = predicate if final_condition is None else exp.And(this=final_condition, expression=predicate)
            result = result.add_named_parameter(param_name, value)

        return result.where(final_condition)

    def get_cache_key(self) -> tuple[Any, ...]:
        """Return cache key for this filter configuration."""
        return ("BeforeAfterFilter", self.field_name, self.before, self.after)
196
+
197
+
198
class OnBeforeAfterFilter(StatementFilter):
    """Inclusive datetime range filter.

    Adds ``field <= :on_or_before`` and/or ``field >= :on_or_after`` to the
    WHERE clause. Falsy bounds (e.g. ``None``) are ignored.
    """

    __slots__ = ("_param_name_on_or_after", "_param_name_on_or_before", "field_name", "on_or_after", "on_or_before")

    field_name: str
    on_or_before: Optional[datetime]
    on_or_after: Optional[datetime]

    def __init__(
        self, field_name: str, on_or_before: Optional[datetime] = None, on_or_after: Optional[datetime] = None
    ) -> None:
        """Store the column name and the inclusive bounds.

        Args:
            field_name: Name of the model attribute to filter on.
            on_or_before: Keep rows whose field is on or earlier than this.
            on_or_after: Keep rows whose field is on or later than this.
        """
        self.field_name = field_name
        self.on_or_before = on_or_before
        self.on_or_after = on_or_after
        # Placeholder names are only assigned for bounds that were supplied.
        self._param_name_on_or_before: Optional[str] = (
            f"{field_name}_on_or_before" if on_or_before else None
        )
        self._param_name_on_or_after: Optional[str] = f"{field_name}_on_or_after" if on_or_after else None

    def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
        """Return no positional values and a name->value map of the active bounds."""
        candidates = (
            (self.on_or_before, self._param_name_on_or_before),
            (self.on_or_after, self._param_name_on_or_after),
        )
        return [], {name: bound for bound, name in candidates if bound and name}

    def append_to_statement(self, statement: "SQL") -> "SQL":
        """Return ``statement`` with the inclusive range predicate(s) added to WHERE."""
        active = [
            (name, bound, comparator)
            for bound, name, comparator in (
                (self.on_or_before, self._param_name_on_or_before, exp.LTE),
                (self.on_or_after, self._param_name_on_or_after, exp.GTE),
            )
            if bound and name
        ]
        if not active:
            return statement

        placeholders = self._resolve_parameter_conflicts(statement, [name for name, _, _ in active])

        combined = None
        updated = statement
        for placeholder, (_, bound, comparator) in zip(placeholders, active):
            predicate = comparator(this=exp.column(self.field_name), expression=exp.Placeholder(this=placeholder))
            combined = predicate if combined is None else exp.And(this=combined, expression=predicate)
            updated = updated.add_named_parameter(placeholder, bound)
        return updated.where(combined)

    def get_cache_key(self) -> tuple[Any, ...]:
        """Describe this filter as a hashable tuple."""
        return ("OnBeforeAfterFilter", self.field_name, self.on_or_before, self.on_or_after)
278
+
279
+
280
class InAnyFilter(StatementFilter, ABC, Generic[T]):
    """Generic base shared by the IN / NOT IN / ANY / NOT ANY collection filters."""

    __slots__ = ()

    def append_to_statement(self, statement: "SQL") -> "SQL":
        """Placeholder implementation; concrete subclasses override it."""
        raise NotImplementedError
287
+
288
+
289
class InCollectionFilter(InAnyFilter[T]):
    """``WHERE field IN (:p0, :p1, ...)`` filter.

    ``values=None`` leaves the statement untouched. An empty collection adds an
    always-false predicate, so no rows match.
    """

    __slots__ = ("_param_names", "field_name", "values")

    field_name: str
    values: Optional[abc.Collection[T]]

    def __init__(self, field_name: str, values: Optional[abc.Collection[T]]) -> None:
        """Store the column name and the candidate values.

        Args:
            field_name: Name of the model attribute to filter on.
            values: Values for ``IN`` clause. An empty list will return an empty result set,
                however, if ``None``, the filter is not applied to the query, and all rows are returned.
        """
        self.field_name = field_name
        self.values = values
        # One positional placeholder name per value.
        self._param_names: list[str] = (
            [f"{field_name}_in_{position}" for position, _ in enumerate(values)] if values else []
        )

    def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
        """Return no positional values and a placeholder-name -> value map."""
        if not self.values:
            return [], {}
        return [], {self._param_names[position]: item for position, item in enumerate(self.values)}

    def append_to_statement(self, statement: "SQL") -> "SQL":
        """Return ``statement`` restricted to rows whose field is one of ``values``."""
        if self.values is None:
            return statement
        if not self.values:
            return statement.where(exp.false())

        names = self._resolve_parameter_conflicts(statement, self._param_names)
        membership = exp.In(
            this=exp.column(self.field_name),
            expressions=[exp.Placeholder(this=name) for name in names],
        )

        updated = statement.where(membership)
        for name, item in zip(names, self.values):
            updated = updated.add_named_parameter(name, item)
        return updated

    def get_cache_key(self) -> tuple[Any, ...]:
        """Describe this filter as a hashable tuple."""
        return ("InCollectionFilter", self.field_name, None if self.values is None else tuple(self.values))
349
+
350
+
351
class NotInCollectionFilter(InAnyFilter[T]):
    """``WHERE NOT (field IN (:p0, :p1, ...))`` filter.

    ``None`` or an empty collection leaves the statement untouched (all rows).
    """

    __slots__ = ("_param_names", "field_name", "values")

    field_name: str
    values: Optional[abc.Collection[T]]

    def __init__(self, field_name: str, values: Optional[abc.Collection[T]]) -> None:
        """Store the column name and the excluded values.

        Args:
            field_name: Name of the model attribute to filter on.
            values: Values for ``NOT IN`` clause. An empty list or ``None`` will return all rows.
        """
        self.field_name = field_name
        self.values = values
        # NOTE(review): unlike the sibling filters, names embed id(self). Presumably
        # this keeps two NOT IN filters on one field from clashing in
        # extract_parameters(). It also makes the generated SQL text differ per
        # instance; confirm that this is intended.
        self._param_names: list[str] = (
            [f"{field_name}_notin_{position}_{id(self)}" for position, _ in enumerate(values)] if values else []
        )

    def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
        """Return no positional values and a placeholder-name -> value map."""
        if not self.values:
            return [], {}
        return [], {self._param_names[position]: item for position, item in enumerate(self.values)}

    def append_to_statement(self, statement: "SQL") -> "SQL":
        """Return ``statement`` excluding rows whose field is one of ``values``."""
        if not self.values:
            return statement

        names = self._resolve_parameter_conflicts(statement, self._param_names)
        membership = exp.In(
            this=exp.column(self.field_name),
            expressions=[exp.Placeholder(this=name) for name in names],
        )

        updated = statement.where(exp.Not(this=membership))
        for name, item in zip(names, self.values):
            updated = updated.add_named_parameter(name, item)
        return updated

    def get_cache_key(self) -> tuple[Any, ...]:
        """Describe this filter as a hashable tuple."""
        return ("NotInCollectionFilter", self.field_name, None if self.values is None else tuple(self.values))
406
+
407
+
408
class AnyCollectionFilter(InAnyFilter[T]):
    """``WHERE field = ANY (ARRAY[:p0, :p1, ...])`` filter.

    ``values=None`` leaves the statement untouched. An empty collection adds an
    always-false predicate, so no rows match.
    """

    __slots__ = ("_param_names", "field_name", "values")

    field_name: str
    values: Optional[abc.Collection[T]]

    def __init__(self, field_name: str, values: Optional[abc.Collection[T]]) -> None:
        """Store the column name and the candidate values.

        Args:
            field_name: Name of the model attribute to filter on.
            values: Values for ``= ANY (...)`` clause. An empty list will result in a condition
                that is always false (no rows returned). If ``None``, the filter is not applied
                to the query, and all rows are returned.
        """
        self.field_name = field_name
        self.values = values
        self._param_names: list[str] = (
            [f"{field_name}_any_{position}" for position, _ in enumerate(values)] if values else []
        )

    def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
        """Return no positional values and a placeholder-name -> value map."""
        if not self.values:
            return [], {}
        return [], {self._param_names[position]: item for position, item in enumerate(self.values)}

    def append_to_statement(self, statement: "SQL") -> "SQL":
        """Return ``statement`` restricted to rows whose field equals any of ``values``."""
        if self.values is None:
            return statement
        if not self.values:
            return statement.where(exp.false())

        names = self._resolve_parameter_conflicts(statement, self._param_names)
        candidates = exp.Array(expressions=[exp.Placeholder(this=name) for name in names])
        predicate = exp.EQ(this=exp.column(self.field_name), expression=exp.Any(this=candidates))

        updated = statement.where(predicate)
        for name, item in zip(names, self.values):
            updated = updated.add_named_parameter(name, item)
        return updated

    def get_cache_key(self) -> tuple[Any, ...]:
        """Describe this filter as a hashable tuple."""
        return ("AnyCollectionFilter", self.field_name, None if self.values is None else tuple(self.values))
467
+
468
+
469
class NotAnyCollectionFilter(InAnyFilter[T]):
    """``WHERE NOT (field = ANY (ARRAY[:p0, :p1, ...]))`` filter.

    ``None`` or an empty collection leaves the statement untouched (all rows).
    """

    __slots__ = ("_param_names", "field_name", "values")

    field_name: str
    values: Optional[abc.Collection[T]]

    def __init__(self, field_name: str, values: Optional[abc.Collection[T]]) -> None:
        """Store the column name and the excluded values.

        Args:
            field_name: Name of the model attribute to filter on.
            values: Values for ``NOT (... = ANY (...))`` clause. An empty list will result in a
                condition that is always true (all rows returned, filter effectively ignored).
                If ``None``, the filter is not applied to the query, and all rows are returned.
        """
        self.field_name = field_name
        self.values = values
        self._param_names: list[str] = (
            [f"{field_name}_not_any_{position}" for position, _ in enumerate(values)] if values else []
        )

    def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
        """Return no positional values and a placeholder-name -> value map."""
        if not self.values:
            return [], {}
        return [], {self._param_names[position]: item for position, item in enumerate(self.values)}

    def append_to_statement(self, statement: "SQL") -> "SQL":
        """Return ``statement`` excluding rows whose field equals any of ``values``."""
        if not self.values:
            return statement

        names = self._resolve_parameter_conflicts(statement, self._param_names)
        candidates = exp.Array(expressions=[exp.Placeholder(this=name) for name in names])
        matches_any = exp.EQ(this=exp.column(self.field_name), expression=exp.Any(this=candidates))

        updated = statement.where(exp.Not(this=matches_any))
        for name, item in zip(names, self.values):
            updated = updated.add_named_parameter(name, item)
        return updated

    def get_cache_key(self) -> tuple[Any, ...]:
        """Describe this filter as a hashable tuple."""
        return ("NotAnyCollectionFilter", self.field_name, None if self.values is None else tuple(self.values))
523
+
524
+
525
class PaginationFilter(StatementFilter, ABC):
    """Abstract base for filters that paginate a statement (e.g. LIMIT/OFFSET)."""

    __slots__ = ()

    @abstractmethod
    def append_to_statement(self, statement: "SQL") -> "SQL":
        """Return ``statement`` with pagination applied; subclasses must implement this."""
        raise NotImplementedError
533
+
534
+
535
class LimitOffsetFilter(PaginationFilter):
    """Data required to add limit/offset filtering to a query.

    Renders ``LIMIT :limit OFFSET :offset`` and binds both values as named
    parameters. The names are de-duplicated against the statement first.
    """

    __slots__ = ("_limit_param_name", "_offset_param_name", "limit", "offset")

    limit: int
    offset: int

    def __init__(self, limit: int, offset: int) -> None:
        """Initialize the LimitOffsetFilter.

        Args:
            limit: Value for ``LIMIT`` clause of query.
            offset: Value for ``OFFSET`` clause of query.
        """
        self.limit = limit
        self.offset = offset

        self._limit_param_name = "limit"
        self._offset_param_name = "offset"

    def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
        """Extract filter parameters keyed by their unresolved placeholder names."""
        return [], {self._limit_param_name: self.limit, self._offset_param_name: self.offset}

    def append_to_statement(self, statement: "SQL") -> "SQL":
        """Return a copy of ``statement`` with LIMIT/OFFSET placeholders and values bound."""
        limit_param_name, offset_param_name = self._resolve_parameter_conflicts(
            statement, [self._limit_param_name, self._offset_param_name]
        )
        limit_placeholder = exp.Placeholder(this=limit_param_name)
        offset_placeholder = exp.Placeholder(this=offset_param_name)

        current = statement._statement
        if current is None:
            select = exp.Select()
        elif isinstance(current, exp.Select):
            select = current
        else:
            # Wrap non-SELECT statements so LIMIT/OFFSET apply to their result set.
            # The wrapper needs an explicit projection; without it sqlglot renders
            # ``SELECT FROM ...``.
            # NOTE(review): set operations may also need a parenthesised/aliased
            # subquery for some dialects; confirm.
            select = exp.Select().select(exp.Star()).from_(current)

        new_statement = select.limit(limit_placeholder).offset(offset_placeholder)

        result = statement.copy(statement=new_statement)
        result = result.add_named_parameter(limit_param_name, self.limit)
        return result.add_named_parameter(offset_param_name, self.offset)

    def get_cache_key(self) -> tuple[Any, ...]:
        """Return cache key for this filter configuration."""
        return ("LimitOffsetFilter", self.limit, self.offset)
590
+
591
+
592
class OrderByFilter(StatementFilter):
    """Data required to construct a ``ORDER BY ...`` clause.

    Adds no parameters. An unrecognised ``sort_order`` (case-insensitive) falls
    back to ascending.
    """

    __slots__ = ("field_name", "sort_order")

    field_name: str
    sort_order: Literal["asc", "desc"]

    def __init__(self, field_name: str, sort_order: Literal["asc", "desc"] = "asc") -> None:
        """Initialize the OrderByFilter.

        Args:
            field_name: Name of the model attribute to sort on.
            sort_order: Sort ascending or descending.
        """
        self.field_name = field_name
        self.sort_order = sort_order

    def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
        """Extract filter parameters (sorting contributes none)."""
        return [], {}

    def append_to_statement(self, statement: "SQL") -> "SQL":
        """Return a copy of ``statement`` ordered by ``field_name``."""
        converted_sort_order = self.sort_order.lower()
        if converted_sort_order not in {"asc", "desc"}:
            converted_sort_order = "asc"

        col_expr = exp.column(self.field_name)
        order_expr = col_expr.desc() if converted_sort_order == "desc" else col_expr.asc()

        current = statement._statement
        if current is None:
            new_statement = exp.Select().order_by(order_expr)
        elif isinstance(current, exp.Select):
            new_statement = current.order_by(order_expr)
        else:
            # Wrap non-SELECT statements. The wrapper needs an explicit projection;
            # without it sqlglot renders ``SELECT FROM ...``.
            new_statement = exp.Select().select(exp.Star()).from_(current).order_by(order_expr)

        return statement.copy(statement=new_statement)

    def get_cache_key(self) -> tuple[Any, ...]:
        """Return cache key for this filter configuration."""
        return ("OrderByFilter", self.field_name, self.sort_order)
634
+
635
+
636
+ class SearchFilter(StatementFilter):
637
+ """Filter for text search queries.
638
+
639
+ Constructs WHERE field_name LIKE '%value%' clauses.
640
+ """
641
+
642
+ __slots__ = ("_param_name", "field_name", "ignore_case", "value")
643
+
644
+ field_name: Union[str, set[str]]
645
+ value: str
646
+ ignore_case: Optional[bool]
647
+
648
+ def __init__(self, field_name: Union[str, set[str]], value: str, ignore_case: Optional[bool] = False) -> None:
649
+ """Initialize the SearchFilter.
650
+
651
+ Args:
652
+ field_name: Name of the model attribute to search on.
653
+ value: Search value.
654
+ ignore_case: Should the search be case insensitive.
655
+ """
656
+ self.field_name = field_name
657
+ self.value = value
658
+ self.ignore_case = ignore_case
659
+
660
+ self._param_name: Optional[str] = None
661
+ if self.value:
662
+ if isinstance(self.field_name, str):
663
+ self._param_name = f"{self.field_name}_search"
664
+ else:
665
+ self._param_name = "search_value"
666
+
667
+ def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
668
+ """Extract filter parameters."""
669
+ named_parameters = {}
670
+ if self.value and self._param_name:
671
+ search_value_with_wildcards = f"%{self.value}%"
672
+ named_parameters[self._param_name] = search_value_with_wildcards
673
+ return [], named_parameters
674
+
675
+ def append_to_statement(self, statement: "SQL") -> "SQL":
676
+ if not self.value or not self._param_name:
677
+ return statement
678
+
679
+ # Resolve parameter name conflicts
680
+ resolved_names = self._resolve_parameter_conflicts(statement, [self._param_name])
681
+ param_name = resolved_names[0]
682
+
683
+ pattern_expr = exp.Placeholder(this=param_name)
684
+ like_op = exp.ILike if self.ignore_case else exp.Like
685
+
686
+ if isinstance(self.field_name, str):
687
+ result = statement.where(like_op(this=exp.column(self.field_name), expression=pattern_expr))
688
+ elif isinstance(self.field_name, set) and self.field_name:
689
+ field_conditions: list[Condition] = [
690
+ like_op(this=exp.column(field), expression=pattern_expr) for field in self.field_name
691
+ ]
692
+ if not field_conditions:
693
+ return statement
694
+
695
+ final_condition: Condition = field_conditions[0]
696
+ for cond in field_conditions[1:]:
697
+ final_condition = exp.Or(this=final_condition, expression=cond)
698
+ result = statement.where(final_condition)
699
+ else:
700
+ result = statement
701
+
702
+ # Add parameter with resolved name
703
+ search_value_with_wildcards = f"%{self.value}%"
704
+ return result.add_named_parameter(param_name, search_value_with_wildcards)
705
+
706
+ def get_cache_key(self) -> tuple[Any, ...]:
707
+ """Return cache key for this filter configuration."""
708
+ field_names = tuple(sorted(self.field_name)) if isinstance(self.field_name, set) else self.field_name
709
+ return ("SearchFilter", field_names, self.value, self.ignore_case)
710
+
711
+
712
+ class NotInSearchFilter(SearchFilter):
713
+ """Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause."""
714
+
715
+ __slots__ = ()
716
+
717
+ def __init__(self, field_name: Union[str, set[str]], value: str, ignore_case: Optional[bool] = False) -> None:
718
+ """Initialize the NotInSearchFilter.
719
+
720
+ Args:
721
+ field_name: Name of the model attribute to search on.
722
+ value: Search value.
723
+ ignore_case: Should the search be case insensitive.
724
+ """
725
+ super().__init__(field_name, value, ignore_case)
726
+
727
+ self._param_name: Optional[str] = None
728
+ if self.value:
729
+ if isinstance(self.field_name, str):
730
+ self._param_name = f"{self.field_name}_not_search"
731
+ else:
732
+ self._param_name = "not_search_value"
733
+
734
+ def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
735
+ """Extract filter parameters."""
736
+ named_parameters = {}
737
+ if self.value and self._param_name:
738
+ search_value_with_wildcards = f"%{self.value}%"
739
+ named_parameters[self._param_name] = search_value_with_wildcards
740
+ return [], named_parameters
741
+
742
+ def append_to_statement(self, statement: "SQL") -> "SQL":
743
+ if not self.value or not self._param_name:
744
+ return statement
745
+
746
+ # Resolve parameter name conflicts
747
+ resolved_names = self._resolve_parameter_conflicts(statement, [self._param_name])
748
+ param_name = resolved_names[0]
749
+
750
+ pattern_expr = exp.Placeholder(this=param_name)
751
+ like_op = exp.ILike if self.ignore_case else exp.Like
752
+
753
+ result = statement
754
+ if isinstance(self.field_name, str):
755
+ result = statement.where(exp.Not(this=like_op(this=exp.column(self.field_name), expression=pattern_expr)))
756
+ elif isinstance(self.field_name, set) and self.field_name:
757
+ field_conditions: list[Condition] = [
758
+ exp.Not(this=like_op(this=exp.column(field), expression=pattern_expr)) for field in self.field_name
759
+ ]
760
+ if not field_conditions:
761
+ return statement
762
+
763
+ final_condition: Condition = field_conditions[0]
764
+ if len(field_conditions) > 1:
765
+ for cond in field_conditions[1:]:
766
+ final_condition = exp.And(this=final_condition, expression=cond)
767
+ result = statement.where(final_condition)
768
+
769
+ # Add parameter with resolved name
770
+ search_value_with_wildcards = f"%{self.value}%"
771
+ return result.add_named_parameter(param_name, search_value_with_wildcards)
772
+
773
+ def get_cache_key(self) -> tuple[Any, ...]:
774
+ """Return cache key for this filter configuration."""
775
+ field_names = tuple(sorted(self.field_name)) if isinstance(self.field_name, set) else self.field_name
776
+ return ("NotInSearchFilter", field_names, self.value, self.ignore_case)
777
+
778
+
779
+ class OffsetPagination(Generic[T]):
780
+ """Container for data returned using limit/offset pagination."""
781
+
782
+ __slots__ = ("items", "limit", "offset", "total")
783
+
784
+ items: Sequence[T]
785
+ limit: int
786
+ offset: int
787
+ total: int
788
+
789
+ def __init__(self, items: Sequence[T], limit: int, offset: int, total: int) -> None:
790
+ """Initialize OffsetPagination.
791
+
792
+ Args:
793
+ items: List of data being sent as part of the response.
794
+ limit: Maximal number of items to send.
795
+ offset: Offset from the beginning of the query. Identical to an index.
796
+ total: Total number of items.
797
+ """
798
+ self.items = items
799
+ self.limit = limit
800
+ self.offset = offset
801
+ self.total = total
802
+
803
+
804
+ def apply_filter(statement: "SQL", filter_obj: StatementFilter) -> "SQL":
805
+ """Apply a statement filter to a SQL query object.
806
+
807
+ Args:
808
+ statement: The SQL query object to modify.
809
+ filter_obj: The filter to apply.
810
+
811
+ Returns:
812
+ The modified query object.
813
+ """
814
+ return filter_obj.append_to_statement(statement)
815
+
816
+
817
+ FilterTypes: TypeAlias = Union[
818
+ BeforeAfterFilter,
819
+ OnBeforeAfterFilter,
820
+ InCollectionFilter[Any],
821
+ LimitOffsetFilter,
822
+ OrderByFilter,
823
+ SearchFilter,
824
+ NotInCollectionFilter[Any],
825
+ NotInSearchFilter,
826
+ AnyCollectionFilter[Any],
827
+ NotAnyCollectionFilter[Any],
828
+ ]