sqlspec 0.15.0__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (43)
  1. sqlspec/_sql.py +699 -43
  2. sqlspec/builder/_base.py +77 -44
  3. sqlspec/builder/_column.py +0 -4
  4. sqlspec/builder/_ddl.py +15 -52
  5. sqlspec/builder/_ddl_utils.py +0 -1
  6. sqlspec/builder/_delete.py +4 -5
  7. sqlspec/builder/_insert.py +61 -35
  8. sqlspec/builder/_merge.py +17 -2
  9. sqlspec/builder/_parsing_utils.py +16 -12
  10. sqlspec/builder/_select.py +29 -33
  11. sqlspec/builder/_update.py +4 -2
  12. sqlspec/builder/mixins/_cte_and_set_ops.py +47 -20
  13. sqlspec/builder/mixins/_delete_operations.py +6 -1
  14. sqlspec/builder/mixins/_insert_operations.py +126 -24
  15. sqlspec/builder/mixins/_join_operations.py +11 -4
  16. sqlspec/builder/mixins/_merge_operations.py +91 -19
  17. sqlspec/builder/mixins/_order_limit_operations.py +15 -3
  18. sqlspec/builder/mixins/_pivot_operations.py +11 -2
  19. sqlspec/builder/mixins/_select_operations.py +16 -10
  20. sqlspec/builder/mixins/_update_operations.py +43 -10
  21. sqlspec/builder/mixins/_where_clause.py +177 -65
  22. sqlspec/core/cache.py +26 -28
  23. sqlspec/core/compiler.py +58 -37
  24. sqlspec/core/filters.py +12 -10
  25. sqlspec/core/parameters.py +80 -52
  26. sqlspec/core/result.py +30 -17
  27. sqlspec/core/statement.py +47 -22
  28. sqlspec/driver/_async.py +76 -46
  29. sqlspec/driver/_common.py +25 -6
  30. sqlspec/driver/_sync.py +73 -43
  31. sqlspec/driver/mixins/_result_tools.py +62 -37
  32. sqlspec/driver/mixins/_sql_translator.py +61 -11
  33. sqlspec/extensions/litestar/cli.py +1 -1
  34. sqlspec/extensions/litestar/plugin.py +2 -2
  35. sqlspec/protocols.py +7 -0
  36. sqlspec/utils/sync_tools.py +1 -1
  37. sqlspec/utils/type_guards.py +7 -3
  38. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/METADATA +1 -1
  39. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/RECORD +43 -43
  40. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/WHEEL +0 -0
  41. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/entry_points.txt +0 -0
  42. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/licenses/LICENSE +0 -0
  43. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/licenses/NOTICE +0 -0
sqlspec/core/statement.py CHANGED
@@ -220,13 +220,20 @@ class SQL:
220
220
  if "is_script" in kwargs:
221
221
  self._is_script = bool(kwargs.pop("is_script"))
222
222
 
223
- filters = [p for p in parameters if is_statement_filter(p)]
224
- actual_params = [p for p in parameters if not is_statement_filter(p)]
223
+ # Optimize parameter filtering with direct iteration
224
+ filters: list[StatementFilter] = []
225
+ actual_params: list[Any] = []
226
+ for p in parameters:
227
+ if is_statement_filter(p):
228
+ filters.append(p)
229
+ else:
230
+ actual_params.append(p)
225
231
 
226
232
  self._filters.extend(filters)
227
233
 
228
234
  if actual_params:
229
- if len(actual_params) == 1:
235
+ param_count = len(actual_params)
236
+ if param_count == 1:
230
237
  param = actual_params[0]
231
238
  if isinstance(param, dict):
232
239
  self._named_parameters.update(param)
@@ -339,10 +346,11 @@ class SQL:
339
346
  """Explicitly compile the SQL statement."""
340
347
  if self._processed_state is Empty:
341
348
  try:
342
- current_parameters = self._named_parameters or self._positional_parameters
349
+ # Avoid unnecessary variable assignment
343
350
  processor = SQLProcessor(self._statement_config)
344
-
345
- compiled_result = processor.compile(self._raw_sql, current_parameters, is_many=self._is_many)
351
+ compiled_result = processor.compile(
352
+ self._raw_sql, self._named_parameters or self._positional_parameters, is_many=self._is_many
353
+ )
346
354
 
347
355
  self._processed_state = ProcessedState(
348
356
  compiled_sql=compiled_result.compiled_sql,
@@ -368,6 +376,10 @@ class SQL:
368
376
  new_sql = SQL(
369
377
  self._raw_sql, *self._original_parameters, statement_config=self._statement_config, is_many=self._is_many
370
378
  )
379
+ # Preserve accumulated parameters when marking as script
380
+ new_sql._named_parameters.update(self._named_parameters)
381
+ new_sql._positional_parameters = self._positional_parameters.copy()
382
+ new_sql._filters = self._filters.copy()
371
383
  new_sql._is_script = True
372
384
  return new_sql
373
385
 
@@ -375,13 +387,19 @@ class SQL:
375
387
  self, statement: "Optional[Union[str, exp.Expression]]" = None, parameters: Optional[Any] = None, **kwargs: Any
376
388
  ) -> "SQL":
377
389
  """Create copy with modifications."""
378
- return SQL(
390
+ new_sql = SQL(
379
391
  statement or self._raw_sql,
380
392
  *(parameters if parameters is not None else self._original_parameters),
381
393
  statement_config=self._statement_config,
382
394
  is_many=self._is_many,
383
395
  **kwargs,
384
396
  )
397
+ # Only preserve accumulated parameters when no explicit parameters are provided
398
+ if parameters is None:
399
+ new_sql._named_parameters.update(self._named_parameters)
400
+ new_sql._positional_parameters = self._positional_parameters.copy()
401
+ new_sql._filters = self._filters.copy()
402
+ return new_sql
385
403
 
386
404
  def add_named_parameter(self, name: str, value: Any) -> "SQL":
387
405
  """Add a named parameter and return a new SQL instance.
@@ -411,6 +429,7 @@ class SQL:
411
429
  Returns:
412
430
  New SQL instance with the WHERE condition applied
413
431
  """
432
+ # Parse current SQL with copy=False optimization
414
433
  current_expr = None
415
434
  with contextlib.suppress(ParseError):
416
435
  current_expr = sqlglot.parse_one(self._raw_sql, dialect=self._dialect)
@@ -419,8 +438,11 @@ class SQL:
419
438
  try:
420
439
  current_expr = sqlglot.parse_one(self._raw_sql, dialect=self._dialect)
421
440
  except ParseError:
422
- current_expr = sqlglot.parse_one(f"SELECT * FROM ({self._raw_sql}) AS subquery", dialect=self._dialect)
441
+ # Use f-string optimization and copy=False
442
+ subquery_sql = f"SELECT * FROM ({self._raw_sql}) AS subquery"
443
+ current_expr = sqlglot.parse_one(subquery_sql, dialect=self._dialect)
423
444
 
445
+ # Parse condition with copy=False optimization
424
446
  condition_expr: exp.Expression
425
447
  if isinstance(condition, str):
426
448
  try:
@@ -430,29 +452,32 @@ class SQL:
430
452
  else:
431
453
  condition_expr = condition
432
454
 
455
+ # Apply WHERE clause
433
456
  if isinstance(current_expr, exp.Select) or supports_where(current_expr):
434
- new_expr = current_expr.where(condition_expr)
457
+ new_expr = current_expr.where(condition_expr, copy=False)
435
458
  else:
436
- new_expr = exp.Select().from_(current_expr).where(condition_expr)
459
+ new_expr = exp.Select().from_(current_expr).where(condition_expr, copy=False)
437
460
 
461
+ # Generate SQL and create new instance
438
462
  new_sql_text = new_expr.sql(dialect=self._dialect)
439
-
440
- return SQL(
463
+ new_sql = SQL(
441
464
  new_sql_text, *self._original_parameters, statement_config=self._statement_config, is_many=self._is_many
442
465
  )
443
466
 
467
+ # Preserve state efficiently
468
+ new_sql._named_parameters.update(self._named_parameters)
469
+ new_sql._positional_parameters = self._positional_parameters.copy()
470
+ new_sql._filters = self._filters.copy()
471
+ return new_sql
472
+
444
473
  def __hash__(self) -> int:
445
- """Hash value."""
474
+ """Hash value with optimized computation."""
446
475
  if self._hash is None:
447
- self._hash = hash(
448
- (
449
- self._raw_sql,
450
- tuple(self._positional_parameters),
451
- tuple(sorted(self._named_parameters.items())),
452
- self._is_many,
453
- self._is_script,
454
- )
455
- )
476
+ # Pre-compute tuple components to avoid multiple tuple() calls
477
+ positional_tuple = tuple(self._positional_parameters)
478
+ named_tuple = tuple(sorted(self._named_parameters.items())) if self._named_parameters else ()
479
+
480
+ self._hash = hash((self._raw_sql, positional_tuple, named_tuple, self._is_many, self._is_script))
456
481
  return self._hash
457
482
 
458
483
  def __eq__(self, other: object) -> bool:
sqlspec/driver/_async.py CHANGED
@@ -5,7 +5,7 @@ including connection management, transaction support, and result processing.
5
5
  """
6
6
 
7
7
  from abc import abstractmethod
8
- from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
8
+ from typing import TYPE_CHECKING, Any, Final, NoReturn, Optional, Union, cast, overload
9
9
 
10
10
  from sqlspec.core import SQL, Statement
11
11
  from sqlspec.driver._common import CommonDriverAttributesMixin, ExecutionResult
@@ -20,14 +20,15 @@ if TYPE_CHECKING:
20
20
 
21
21
  from sqlspec.builder import QueryBuilder
22
22
  from sqlspec.core import SQLResult, StatementConfig, StatementFilter
23
- from sqlspec.typing import ModelDTOT, ModelT, RowT, StatementParameters
23
+ from sqlspec.typing import ModelDTOT, StatementParameters
24
24
 
25
- logger = get_logger("sqlspec")
25
+ _LOGGER_NAME: Final[str] = "sqlspec"
26
+ logger = get_logger(_LOGGER_NAME)
26
27
 
27
28
  __all__ = ("AsyncDriverAdapterBase",)
28
29
 
29
30
 
30
- EMPTY_FILTERS: "list[StatementFilter]" = []
31
+ EMPTY_FILTERS: Final["list[StatementFilter]"] = []
31
32
 
32
33
 
33
34
  class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToSchemaMixin):
@@ -128,12 +129,16 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
128
129
  sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
129
130
  statements = self.split_script_statements(sql, self.statement_config, strip_trailing_semicolon=True)
130
131
 
132
+ statement_count: int = len(statements)
133
+ successful_count: int = 0
134
+
131
135
  for stmt in statements:
132
136
  single_stmt = statement.copy(statement=stmt, parameters=prepared_parameters)
133
137
  await self._execute_statement(cursor, single_stmt)
138
+ successful_count += 1
134
139
 
135
140
  return self.create_execution_result(
136
- cursor, statement_count=len(statements), successful_statements=len(statements), is_script_result=True
141
+ cursor, statement_count=statement_count, successful_statements=successful_count, is_script_result=True
137
142
  )
138
143
 
139
144
  @abstractmethod
@@ -214,8 +219,8 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
214
219
  By default, validates each statement and logs warnings for dangerous
215
220
  operations. Use suppress_warnings=True for migrations and admin scripts.
216
221
  """
217
- script_config = statement_config or self.statement_config
218
- sql_statement = self.prepare_statement(statement, parameters, statement_config=script_config, kwargs=kwargs)
222
+ config = statement_config or self.statement_config
223
+ sql_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
219
224
 
220
225
  return await self.dispatch_statement_execution(statement=sql_statement.as_script(), connection=self.connection)
221
226
 
@@ -239,7 +244,7 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
239
244
  schema_type: None = None,
240
245
  statement_config: "Optional[StatementConfig]" = None,
241
246
  **kwargs: Any,
242
- ) -> "Union[ModelT, RowT, dict[str, Any]]": ... # pyright: ignore[reportInvalidTypeVarUse]
247
+ ) -> "dict[str, Any]": ...
243
248
 
244
249
  async def select_one(
245
250
  self,
@@ -249,23 +254,20 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
249
254
  schema_type: "Optional[type[ModelDTOT]]" = None,
250
255
  statement_config: "Optional[StatementConfig]" = None,
251
256
  **kwargs: Any,
252
- ) -> "Union[ModelT, RowT,ModelDTOT]": # pyright: ignore[reportInvalidTypeVarUse]
257
+ ) -> "Union[dict[str, Any], ModelDTOT]":
253
258
  """Execute a select statement and return exactly one row.
254
259
 
255
260
  Raises an exception if no rows or more than one row is returned.
256
261
  """
257
262
  result = await self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
258
263
  data = result.get_data()
259
- if not data:
260
- msg = "No rows found"
261
- raise NotFoundError(msg)
262
- if len(data) > 1:
263
- msg = f"Expected exactly one row, found {len(data)}"
264
- raise ValueError(msg)
265
- return cast(
266
- "Union[ModelT, RowT, ModelDTOT]",
267
- self.to_schema(data[0], schema_type=schema_type) if schema_type else data[0],
268
- )
264
+ data_len: int = len(data)
265
+ if data_len == 0:
266
+ self._raise_no_rows_found()
267
+ if data_len > 1:
268
+ self._raise_expected_one_row(data_len)
269
+ first_row = data[0]
270
+ return self.to_schema(first_row, schema_type=schema_type) if schema_type else first_row
269
271
 
270
272
  @overload
271
273
  async def select_one_or_none(
@@ -287,7 +289,7 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
287
289
  schema_type: None = None,
288
290
  statement_config: "Optional[StatementConfig]" = None,
289
291
  **kwargs: Any,
290
- ) -> "Optional[ModelT]": ... # pyright: ignore[reportInvalidTypeVarUse]
292
+ ) -> "Optional[dict[str, Any]]": ...
291
293
 
292
294
  async def select_one_or_none(
293
295
  self,
@@ -297,7 +299,7 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
297
299
  schema_type: "Optional[type[ModelDTOT]]" = None,
298
300
  statement_config: "Optional[StatementConfig]" = None,
299
301
  **kwargs: Any,
300
- ) -> "Optional[Union[ModelT, ModelDTOT]]": # pyright: ignore[reportInvalidTypeVarUse]
302
+ ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
301
303
  """Execute a select statement and return at most one row.
302
304
 
303
305
  Returns None if no rows are found.
@@ -305,12 +307,16 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
305
307
  """
306
308
  result = await self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
307
309
  data = result.get_data()
308
- if not data:
310
+ data_len: int = len(data)
311
+ if data_len == 0:
309
312
  return None
310
- if len(data) > 1:
311
- msg = f"Expected at most one row, found {len(data)}"
312
- raise ValueError(msg)
313
- return cast("Optional[Union[ModelT, ModelDTOT]]", self.to_schema(data[0], schema_type=schema_type))
313
+ if data_len > 1:
314
+ self._raise_expected_at_most_one_row(data_len)
315
+ first_row = data[0]
316
+ return cast(
317
+ "Optional[Union[dict[str, Any], ModelDTOT]]",
318
+ self.to_schema(first_row, schema_type=schema_type) if schema_type else first_row,
319
+ )
314
320
 
315
321
  @overload
316
322
  async def select(
@@ -332,7 +338,8 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
332
338
  schema_type: None = None,
333
339
  statement_config: "Optional[StatementConfig]" = None,
334
340
  **kwargs: Any,
335
- ) -> "list[ModelT]": ... # pyright: ignore[reportInvalidTypeVarUse]
341
+ ) -> "list[dict[str, Any]]": ...
342
+
336
343
  async def select(
337
344
  self,
338
345
  statement: "Union[Statement, QueryBuilder]",
@@ -341,12 +348,11 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
341
348
  schema_type: "Optional[type[ModelDTOT]]" = None,
342
349
  statement_config: "Optional[StatementConfig]" = None,
343
350
  **kwargs: Any,
344
- ) -> "Union[list[ModelT], list[ModelDTOT]]": # pyright: ignore[reportInvalidTypeVarUse]
351
+ ) -> "Union[list[dict[str, Any]], list[ModelDTOT]]":
345
352
  """Execute a select statement and return all rows."""
346
353
  result = await self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
347
354
  return cast(
348
- "Union[list[ModelT], list[ModelDTOT]]",
349
- self.to_schema(cast("list[ModelT]", result.get_data()), schema_type=schema_type),
355
+ "Union[list[dict[str, Any]], list[ModelDTOT]]", self.to_schema(result.get_data(), schema_type=schema_type)
350
356
  )
351
357
 
352
358
  async def select_value(
@@ -366,23 +372,19 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
366
372
  try:
367
373
  row = result.one()
368
374
  except ValueError as e:
369
- msg = "No rows found"
370
- raise NotFoundError(msg) from e
375
+ self._raise_no_rows_found_from_exception(e)
371
376
  if not row:
372
- msg = "No rows found"
373
- raise NotFoundError(msg)
377
+ self._raise_no_rows_found()
374
378
  if is_dict_row(row):
375
379
  if not row:
376
- msg = "Row has no columns"
377
- raise ValueError(msg)
380
+ self._raise_row_no_columns()
378
381
  return next(iter(row.values()))
379
382
  if is_indexable_row(row):
380
383
  if not row:
381
- msg = "Row has no columns"
382
- raise ValueError(msg)
384
+ self._raise_row_no_columns()
383
385
  return row[0]
384
- msg = f"Unexpected row type: {type(row)}"
385
- raise ValueError(msg)
386
+ self._raise_unexpected_row_type(type(row))
387
+ return None
386
388
 
387
389
  async def select_value_or_none(
388
390
  self,
@@ -400,11 +402,11 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
400
402
  """
401
403
  result = await self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
402
404
  data = result.get_data()
403
- if not data:
405
+ data_len: int = len(data)
406
+ if data_len == 0:
404
407
  return None
405
- if len(data) > 1:
406
- msg = f"Expected at most one row, found {len(data)}"
407
- raise ValueError(msg)
408
+ if data_len > 1:
409
+ self._raise_expected_at_most_one_row(data_len)
408
410
  row = data[0]
409
411
  if is_dict_row(row):
410
412
  if not row:
@@ -412,8 +414,8 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
412
414
  return next(iter(row.values()))
413
415
  if is_indexable_row(row):
414
416
  return row[0]
415
- msg = f"Cannot extract value from row type {type(row).__name__}"
416
- raise TypeError(msg)
417
+ self._raise_cannot_extract_value_from_row_type(type(row).__name__)
418
+ return None
417
419
 
418
420
  @overload
419
421
  async def select_with_total(
@@ -470,3 +472,31 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
470
472
  select_result = await self.execute(sql_statement)
471
473
 
472
474
  return (self.to_schema(select_result.get_data(), schema_type=schema_type), count_result.scalar())
475
+
476
+ def _raise_no_rows_found(self) -> NoReturn:
477
+ msg = "No rows found"
478
+ raise NotFoundError(msg)
479
+
480
+ def _raise_no_rows_found_from_exception(self, e: ValueError) -> NoReturn:
481
+ msg = "No rows found"
482
+ raise NotFoundError(msg) from e
483
+
484
+ def _raise_expected_one_row(self, data_len: int) -> NoReturn:
485
+ msg = f"Expected exactly one row, found {data_len}"
486
+ raise ValueError(msg)
487
+
488
+ def _raise_expected_at_most_one_row(self, data_len: int) -> NoReturn:
489
+ msg = f"Expected at most one row, found {data_len}"
490
+ raise ValueError(msg)
491
+
492
+ def _raise_row_no_columns(self) -> NoReturn:
493
+ msg = "Row has no columns"
494
+ raise ValueError(msg)
495
+
496
+ def _raise_unexpected_row_type(self, row_type: type) -> NoReturn:
497
+ msg = f"Unexpected row type: {row_type}"
498
+ raise ValueError(msg)
499
+
500
+ def _raise_cannot_extract_value_from_row_type(self, type_name: str) -> NoReturn:
501
+ msg = f"Cannot extract value from row type {type_name}"
502
+ raise TypeError(msg)
sqlspec/driver/_common.py CHANGED
@@ -17,7 +17,9 @@ from sqlspec.exceptions import ImproperConfigurationError
17
17
  from sqlspec.utils.logging import get_logger
18
18
 
19
19
  if TYPE_CHECKING:
20
- from sqlspec.core.filters import StatementFilter
20
+ from collections.abc import Sequence
21
+
22
+ from sqlspec.core.filters import FilterTypeT, StatementFilter
21
23
  from sqlspec.typing import StatementParameters
22
24
 
23
25
 
@@ -424,10 +426,9 @@ class CommonDriverAttributesMixin:
424
426
  if isinstance(parameters, dict):
425
427
  if not parameters:
426
428
  return []
427
- if (
428
- statement_config.parameter_config.supported_execution_parameter_styles
429
- and ParameterStyle.NAMED_PYFORMAT
430
- in statement_config.parameter_config.supported_execution_parameter_styles
429
+ if statement_config.parameter_config.supported_execution_parameter_styles and (
430
+ ParameterStyle.NAMED_PYFORMAT in statement_config.parameter_config.supported_execution_parameter_styles
431
+ or ParameterStyle.NAMED_COLON in statement_config.parameter_config.supported_execution_parameter_styles
431
432
  ):
432
433
  return {k: apply_type_coercion(v) for k, v in parameters.items()}
433
434
  if statement_config.parameter_config.default_parameter_style in {
@@ -577,6 +578,24 @@ class CommonDriverAttributesMixin:
577
578
 
578
579
  return max(style_counts.keys(), key=lambda style: (style_counts[style], -precedence.get(style, 99)))
579
580
 
581
+ @staticmethod
582
+ def find_filter(
583
+ filter_type: "type[FilterTypeT]",
584
+ filters: "Sequence[StatementFilter | StatementParameters] | Sequence[StatementFilter]",
585
+ ) -> "FilterTypeT | None":
586
+ """Get the filter specified by filter type from the filters.
587
+
588
+ Args:
589
+ filter_type: The type of filter to find.
590
+ filters: filter types to apply to the query
591
+
592
+ Returns:
593
+ The match filter instance or None
594
+ """
595
+ return next(
596
+ (cast("FilterTypeT | None", filter_) for filter_ in filters if isinstance(filter_, filter_type)), None
597
+ )
598
+
580
599
  def _create_count_query(self, original_sql: "SQL") -> "SQL":
581
600
  """Create a COUNT query from the original SQL statement.
582
601
 
@@ -586,7 +605,7 @@ class CommonDriverAttributesMixin:
586
605
  if not original_sql.expression:
587
606
  msg = "Cannot create COUNT query from empty SQL expression"
588
607
  raise ImproperConfigurationError(msg)
589
- expr = original_sql.expression.copy()
608
+ expr = original_sql.expression
590
609
 
591
610
  if isinstance(expr, exp.Select):
592
611
  if expr.args.get("group"):