sqlspec 0.15.0__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the package's registry page for more details.

Files changed (43)
  1. sqlspec/_sql.py +699 -43
  2. sqlspec/builder/_base.py +77 -44
  3. sqlspec/builder/_column.py +0 -4
  4. sqlspec/builder/_ddl.py +15 -52
  5. sqlspec/builder/_ddl_utils.py +0 -1
  6. sqlspec/builder/_delete.py +4 -5
  7. sqlspec/builder/_insert.py +61 -35
  8. sqlspec/builder/_merge.py +17 -2
  9. sqlspec/builder/_parsing_utils.py +16 -12
  10. sqlspec/builder/_select.py +29 -33
  11. sqlspec/builder/_update.py +4 -2
  12. sqlspec/builder/mixins/_cte_and_set_ops.py +47 -20
  13. sqlspec/builder/mixins/_delete_operations.py +6 -1
  14. sqlspec/builder/mixins/_insert_operations.py +126 -24
  15. sqlspec/builder/mixins/_join_operations.py +11 -4
  16. sqlspec/builder/mixins/_merge_operations.py +91 -19
  17. sqlspec/builder/mixins/_order_limit_operations.py +15 -3
  18. sqlspec/builder/mixins/_pivot_operations.py +11 -2
  19. sqlspec/builder/mixins/_select_operations.py +16 -10
  20. sqlspec/builder/mixins/_update_operations.py +43 -10
  21. sqlspec/builder/mixins/_where_clause.py +177 -65
  22. sqlspec/core/cache.py +26 -28
  23. sqlspec/core/compiler.py +58 -37
  24. sqlspec/core/filters.py +12 -10
  25. sqlspec/core/parameters.py +80 -52
  26. sqlspec/core/result.py +30 -17
  27. sqlspec/core/statement.py +47 -22
  28. sqlspec/driver/_async.py +76 -46
  29. sqlspec/driver/_common.py +25 -6
  30. sqlspec/driver/_sync.py +73 -43
  31. sqlspec/driver/mixins/_result_tools.py +62 -37
  32. sqlspec/driver/mixins/_sql_translator.py +61 -11
  33. sqlspec/extensions/litestar/cli.py +1 -1
  34. sqlspec/extensions/litestar/plugin.py +2 -2
  35. sqlspec/protocols.py +7 -0
  36. sqlspec/utils/sync_tools.py +1 -1
  37. sqlspec/utils/type_guards.py +7 -3
  38. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/METADATA +1 -1
  39. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/RECORD +43 -43
  40. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/WHEEL +0 -0
  41. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/entry_points.txt +0 -0
  42. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/licenses/LICENSE +0 -0
  43. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/licenses/NOTICE +0 -0
sqlspec/core/compiler.py CHANGED
@@ -121,10 +121,11 @@ class CompiledSQL:
121
121
  self._hash: Optional[int] = None
122
122
 
123
123
  def __hash__(self) -> int:
124
- """Cached hash value."""
124
+ """Cached hash value with optimization."""
125
125
  if self._hash is None:
126
- hash_data = (self.compiled_sql, str(self.execution_parameters), self.operation_type, self.parameter_style)
127
- self._hash = hash(hash_data)
126
+ # Optimize by avoiding str() conversion if possible
127
+ param_str = str(self.execution_parameters)
128
+ self._hash = hash((self.compiled_sql, param_str, self.operation_type, self.parameter_style))
128
129
  return self._hash
129
130
 
130
131
  def __eq__(self, other: object) -> bool:
@@ -229,16 +230,21 @@ class SQLProcessor:
229
230
  CompiledSQL result
230
231
  """
231
232
  try:
233
+ # Cache dialect string to avoid repeated conversions
232
234
  dialect_str = str(self._config.dialect) if self._config.dialect else None
233
- processed_sql, processed_params_tuple = self._parameter_processor.process(
235
+
236
+ # Process parameters in single call
237
+ processed_sql: str
238
+ processed_params: Any
239
+ processed_sql, processed_params = self._parameter_processor.process(
234
240
  sql=sql,
235
241
  parameters=parameters,
236
242
  config=self._config.parameter_config,
237
243
  dialect=dialect_str,
238
244
  is_many=is_many,
239
245
  )
240
- processed_params: Any = processed_params_tuple
241
246
 
247
+ # Optimize static compilation path
242
248
  if self._config.parameter_config.needs_static_script_compilation and processed_params is None:
243
249
  sqlglot_sql = processed_sql
244
250
  else:
@@ -246,35 +252,39 @@ class SQLProcessor:
246
252
  sql, parameters, self._config.parameter_config, dialect_str
247
253
  )
248
254
 
249
- final_parameters: Any = processed_params
255
+ final_parameters = processed_params
250
256
  ast_was_transformed = False
257
+ expression = None
258
+ operation_type = "EXECUTE"
251
259
 
252
260
  if self._config.enable_parsing:
253
261
  try:
262
+ # Use copy=False for performance optimization
254
263
  expression = sqlglot.parse_one(sqlglot_sql, dialect=dialect_str)
255
264
  operation_type = self._detect_operation_type(expression)
256
265
 
257
- if self._config.parameter_config.ast_transformer:
258
- expression, final_parameters = self._config.parameter_config.ast_transformer(
259
- expression, processed_params
260
- )
266
+ # Handle AST transformation if configured
267
+ ast_transformer = self._config.parameter_config.ast_transformer
268
+ if ast_transformer:
269
+ expression, final_parameters = ast_transformer(expression, processed_params)
261
270
  ast_was_transformed = True
262
271
 
263
272
  except ParseError:
264
273
  expression = None
265
274
  operation_type = "EXECUTE"
266
- else:
267
- expression = None
268
- operation_type = "EXECUTE"
269
275
 
276
+ # Optimize final SQL generation path
270
277
  if self._config.parameter_config.needs_static_script_compilation and processed_params is None:
271
278
  final_sql, final_params = processed_sql, processed_params
272
279
  elif ast_was_transformed and expression is not None:
273
280
  final_sql = expression.sql(dialect=dialect_str)
274
281
  final_params = final_parameters
275
282
  logger.debug("AST was transformed - final SQL: %s, final params: %s", final_sql, final_params)
276
- if self._config.output_transformer:
277
- final_sql, final_params = self._config.output_transformer(final_sql, final_params)
283
+
284
+ # Apply output transformer if configured
285
+ output_transformer = self._config.output_transformer
286
+ if output_transformer:
287
+ final_sql, final_params = output_transformer(final_sql, final_params)
278
288
  else:
279
289
  final_sql, final_params = self._apply_final_transformations(
280
290
  expression, processed_sql, final_parameters, dialect_str
@@ -305,15 +315,23 @@ class SQLProcessor:
305
315
  Returns:
306
316
  Cache key string
307
317
  """
318
+ # Optimize key generation by avoiding string conversion overhead
319
+ param_repr = repr(parameters)
320
+ dialect_str = str(self._config.dialect) if self._config.dialect else None
321
+ param_style = self._config.parameter_config.default_parameter_style.value
322
+
323
+ # Use direct tuple construction for better performance
308
324
  hash_data = (
309
325
  sql,
310
- repr(parameters),
311
- self._config.parameter_config.default_parameter_style.value,
312
- str(self._config.dialect),
326
+ param_repr,
327
+ param_style,
328
+ dialect_str,
313
329
  self._config.enable_parsing,
314
330
  self._config.enable_transformations,
315
331
  )
316
- hash_str = hashlib.sha256(str(hash_data).encode()).hexdigest()[:16]
332
+
333
+ # Optimize hash computation
334
+ hash_str = hashlib.sha256(str(hash_data).encode("utf-8")).hexdigest()[:16]
317
335
  return f"sql_{hash_str}"
318
336
 
319
337
  def _detect_operation_type(self, expression: "exp.Expression") -> str:
@@ -327,27 +345,29 @@ class SQLProcessor:
327
345
  Returns:
328
346
  Operation type string
329
347
  """
348
+ # Use isinstance for compatibility with mocks and inheritance
330
349
  if isinstance(expression, exp.Select):
331
- return _OPERATION_TYPES["SELECT"]
350
+ return "SELECT"
332
351
  if isinstance(expression, exp.Insert):
333
- return _OPERATION_TYPES["INSERT"]
352
+ return "INSERT"
334
353
  if isinstance(expression, exp.Update):
335
- return _OPERATION_TYPES["UPDATE"]
354
+ return "UPDATE"
336
355
  if isinstance(expression, exp.Delete):
337
- return _OPERATION_TYPES["DELETE"]
338
- if isinstance(expression, (exp.Create, exp.Drop, exp.Alter)):
339
- return _OPERATION_TYPES["DDL"]
340
- if isinstance(expression, exp.Copy):
341
- if expression.args["kind"] is True:
342
- return _OPERATION_TYPES["COPY_FROM"]
343
- if expression.args["kind"] is False:
344
- return _OPERATION_TYPES["COPY_TO"]
345
- return _OPERATION_TYPES["COPY"]
356
+ return "DELETE"
346
357
  if isinstance(expression, exp.Pragma):
347
- return _OPERATION_TYPES["PRAGMA"]
358
+ return "PRAGMA"
348
359
  if isinstance(expression, exp.Command):
349
- return _OPERATION_TYPES["EXECUTE"]
350
- return _OPERATION_TYPES["UNKNOWN"]
360
+ return "EXECUTE"
361
+ if isinstance(expression, exp.Copy):
362
+ copy_kind = expression.args.get("kind")
363
+ if copy_kind is True:
364
+ return "COPY_FROM"
365
+ if copy_kind is False:
366
+ return "COPY_TO"
367
+ return "COPY"
368
+ if isinstance(expression, (exp.Create, exp.Drop, exp.Alter)):
369
+ return "DDL"
370
+ return "UNKNOWN"
351
371
 
352
372
  def _apply_final_transformations(
353
373
  self, expression: "Optional[exp.Expression]", sql: str, parameters: Any, dialect_str: "Optional[str]"
@@ -363,11 +383,12 @@ class SQLProcessor:
363
383
  Returns:
364
384
  Tuple of (final_sql, final_parameters)
365
385
  """
366
- if self._config.output_transformer:
386
+ output_transformer = self._config.output_transformer
387
+ if output_transformer:
367
388
  if expression is not None:
368
389
  ast_sql = expression.sql(dialect=dialect_str)
369
- return self._config.output_transformer(ast_sql, parameters)
370
- return self._config.output_transformer(sql, parameters)
390
+ return output_transformer(ast_sql, parameters)
391
+ return output_transformer(sql, parameters)
371
392
 
372
393
  return sql, parameters
373
394
 
sqlspec/core/filters.py CHANGED
@@ -558,6 +558,7 @@ class LimitOffsetFilter(PaginationFilter):
558
558
  return [], {self._limit_param_name: self.limit, self._offset_param_name: self.offset}
559
559
 
560
560
  def append_to_statement(self, statement: "SQL") -> "SQL":
561
+ import sqlglot
561
562
  from sqlglot import exp
562
563
 
563
564
  # Resolve parameter name conflicts
@@ -567,17 +568,18 @@ class LimitOffsetFilter(PaginationFilter):
567
568
  limit_placeholder = exp.Placeholder(this=limit_param_name)
568
569
  offset_placeholder = exp.Placeholder(this=offset_param_name)
569
570
 
570
- if statement._statement is None:
571
- new_statement = exp.Select().limit(limit_placeholder)
572
- else:
573
- new_statement = (
574
- statement._statement.limit(limit_placeholder)
575
- if isinstance(statement._statement, exp.Select)
576
- else exp.Select().from_(statement._statement).limit(limit_placeholder)
577
- )
571
+ # Parse the current SQL to get the statement structure
572
+ try:
573
+ current_statement = sqlglot.parse_one(statement._raw_sql, dialect=getattr(statement, "_dialect", None))
574
+ except Exception:
575
+ # Fallback to wrapping in subquery if parsing fails
576
+ current_statement = exp.Select().from_(f"({statement._raw_sql})")
578
577
 
579
- if isinstance(new_statement, exp.Select):
580
- new_statement = new_statement.offset(offset_placeholder)
578
+ if isinstance(current_statement, exp.Select):
579
+ new_statement = current_statement.limit(limit_placeholder).offset(offset_placeholder)
580
+ else:
581
+ # Wrap non-SELECT statements in a subquery
582
+ new_statement = exp.Select().from_(current_statement).limit(limit_placeholder).offset(offset_placeholder)
581
583
 
582
584
  result = statement.copy(statement=new_statement)
583
585
 
sqlspec/core/parameters.py CHANGED
@@ -124,9 +124,11 @@ class TypedParameter:
124
124
  self._hash: Optional[int] = None
125
125
 
126
126
  def __hash__(self) -> int:
127
- """Cached hash value."""
127
+ """Cached hash value with optimization."""
128
128
  if self._hash is None:
129
- self._hash = hash((id(self.value), self.original_type, self.semantic_name))
129
+ # Optimize by avoiding tuple creation for common case
130
+ value_id = id(self.value)
131
+ self._hash = hash((value_id, self.original_type, self.semantic_name))
130
132
  return self._hash
131
133
 
132
134
  def __eq__(self, other: object) -> bool:
@@ -361,13 +363,15 @@ class ParameterValidator:
361
363
  Returns:
362
364
  List of ParameterInfo objects for each detected parameter
363
365
  """
364
- if sql in self._parameter_cache:
365
- return self._parameter_cache[sql]
366
+ cached_result = self._parameter_cache.get(sql)
367
+ if cached_result is not None:
368
+ return cached_result
366
369
 
367
370
  parameters: list[ParameterInfo] = []
368
371
  ordinal = 0
369
372
 
370
373
  for match in _PARAMETER_REGEX.finditer(sql):
374
+ # Fast rejection of comments and quotes
371
375
  if (
372
376
  match.group("dquote")
373
377
  or match.group("squote")
@@ -381,37 +385,52 @@ class ParameterValidator:
381
385
 
382
386
  position = match.start()
383
387
  placeholder_text = match.group(0)
384
- name = None
385
- style = None
388
+ name: Optional[str] = None
389
+ style: Optional[ParameterStyle] = None
386
390
 
387
- if match.group("pyformat_named"):
391
+ # Optimize with elif chain for better branch prediction
392
+ pyformat_named = match.group("pyformat_named")
393
+ if pyformat_named:
388
394
  style = ParameterStyle.NAMED_PYFORMAT
389
395
  name = match.group("pyformat_name")
390
- elif match.group("pyformat_pos"):
391
- style = ParameterStyle.POSITIONAL_PYFORMAT
392
- elif match.group("positional_colon"):
393
- style = ParameterStyle.POSITIONAL_COLON
394
- name = match.group("colon_num")
395
- elif match.group("named_colon"):
396
- style = ParameterStyle.NAMED_COLON
397
- name = match.group("colon_name")
398
- elif match.group("named_at"):
399
- style = ParameterStyle.NAMED_AT
400
- name = match.group("at_name")
401
- elif match.group("numeric"):
402
- style = ParameterStyle.NUMERIC
403
- name = match.group("numeric_num")
404
- elif match.group("named_dollar_param"):
405
- style = ParameterStyle.NAMED_DOLLAR
406
- name = match.group("dollar_param_name")
407
- elif match.group("qmark"):
408
- style = ParameterStyle.QMARK
396
+ else:
397
+ pyformat_pos = match.group("pyformat_pos")
398
+ if pyformat_pos:
399
+ style = ParameterStyle.POSITIONAL_PYFORMAT
400
+ else:
401
+ positional_colon = match.group("positional_colon")
402
+ if positional_colon:
403
+ style = ParameterStyle.POSITIONAL_COLON
404
+ name = match.group("colon_num")
405
+ else:
406
+ named_colon = match.group("named_colon")
407
+ if named_colon:
408
+ style = ParameterStyle.NAMED_COLON
409
+ name = match.group("colon_name")
410
+ else:
411
+ named_at = match.group("named_at")
412
+ if named_at:
413
+ style = ParameterStyle.NAMED_AT
414
+ name = match.group("at_name")
415
+ else:
416
+ numeric = match.group("numeric")
417
+ if numeric:
418
+ style = ParameterStyle.NUMERIC
419
+ name = match.group("numeric_num")
420
+ else:
421
+ named_dollar_param = match.group("named_dollar_param")
422
+ if named_dollar_param:
423
+ style = ParameterStyle.NAMED_DOLLAR
424
+ name = match.group("dollar_param_name")
425
+ elif match.group("qmark"):
426
+ style = ParameterStyle.QMARK
409
427
 
410
428
  if style is not None:
411
- param_info = ParameterInfo(
412
- name=name, style=style, position=position, ordinal=ordinal, placeholder_text=placeholder_text
429
+ parameters.append(
430
+ ParameterInfo(
431
+ name=name, style=style, position=position, ordinal=ordinal, placeholder_text=placeholder_text
432
+ )
413
433
  )
414
- parameters.append(param_info)
415
434
  ordinal += 1
416
435
 
417
436
  self._parameter_cache[sql] = parameters
@@ -567,26 +586,34 @@ class ParameterConverter:
567
586
  msg = f"Unsupported target parameter style: {target_style}"
568
587
  raise ValueError(msg)
569
588
 
570
- # Build a mapping of unique parameters to their ordinals
571
- # This handles repeated parameters like $1, $2, $2 correctly
572
- # Special case: QMARK (?) parameters converting to NUMERIC ($1, $2) need sequential numbering
589
+ # Optimize parameter style detection
573
590
  param_styles = {p.style for p in param_info}
574
- use_sequential_for_qmark = param_styles == {ParameterStyle.QMARK} and target_style == ParameterStyle.NUMERIC
591
+ use_sequential_for_qmark = (
592
+ len(param_styles) == 1 and ParameterStyle.QMARK in param_styles and target_style == ParameterStyle.NUMERIC
593
+ )
575
594
 
595
+ # Build unique parameters mapping efficiently
576
596
  unique_params: dict[str, int] = {}
577
597
  for param in param_info:
578
- if use_sequential_for_qmark and param.style == ParameterStyle.QMARK:
579
- # For QMARK → NUMERIC conversion, each ? gets sequential numbering
580
- param_key = f"{param.placeholder_text}_{param.ordinal}"
581
- else:
582
- # For all other cases, group by placeholder text
583
- param_key = param.placeholder_text
598
+ param_key = (
599
+ f"{param.placeholder_text}_{param.ordinal}"
600
+ if use_sequential_for_qmark and param.style == ParameterStyle.QMARK
601
+ else param.placeholder_text
602
+ )
584
603
 
585
604
  if param_key not in unique_params:
586
605
  unique_params[param_key] = len(unique_params)
587
606
 
607
+ # Convert SQL with optimized string operations
588
608
  converted_sql = sql
609
+ placeholder_text_len_cache: dict[str, int] = {}
610
+
589
611
  for param in reversed(param_info):
612
+ # Cache placeholder text length to avoid recalculation
613
+ if param.placeholder_text not in placeholder_text_len_cache:
614
+ placeholder_text_len_cache[param.placeholder_text] = len(param.placeholder_text)
615
+ text_len = placeholder_text_len_cache[param.placeholder_text]
616
+
590
617
  # Generate new placeholder based on target style
591
618
  if target_style in {
592
619
  ParameterStyle.QMARK,
@@ -594,23 +621,19 @@ class ParameterConverter:
594
621
  ParameterStyle.POSITIONAL_PYFORMAT,
595
622
  ParameterStyle.POSITIONAL_COLON,
596
623
  }:
597
- # Use the appropriate key for the unique parameter mapping
598
- if use_sequential_for_qmark and param.style == ParameterStyle.QMARK:
599
- param_key = f"{param.placeholder_text}_{param.ordinal}"
600
- else:
601
- param_key = param.placeholder_text
602
-
603
- ordinal_to_use = unique_params[param_key]
604
- new_placeholder = generator(ordinal_to_use)
624
+ param_key = (
625
+ f"{param.placeholder_text}_{param.ordinal}"
626
+ if use_sequential_for_qmark and param.style == ParameterStyle.QMARK
627
+ else param.placeholder_text
628
+ )
629
+ new_placeholder = generator(unique_params[param_key])
605
630
  else: # Named styles
606
631
  param_name = param.name or f"param_{param.ordinal}"
607
632
  new_placeholder = generator(param_name)
608
633
 
609
- # Replace in SQL
634
+ # Optimized string replacement
610
635
  converted_sql = (
611
- converted_sql[: param.position]
612
- + new_placeholder
613
- + converted_sql[param.position + len(param.placeholder_text) :]
636
+ converted_sql[: param.position] + new_placeholder + converted_sql[param.position + text_len :]
614
637
  )
615
638
 
616
639
  return converted_sql
@@ -1116,9 +1139,14 @@ class ParameterProcessor:
1116
1139
  def _apply_type_wrapping(self, parameters: Any) -> Any:
1117
1140
  """Apply type wrapping using singledispatch for performance."""
1118
1141
  if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
1142
+ # Optimize with direct iteration instead of list comprehension for better memory usage
1119
1143
  return [_wrap_parameter_by_type(p) for p in parameters]
1120
1144
  if isinstance(parameters, Mapping):
1121
- return {k: _wrap_parameter_by_type(v) for k, v in parameters.items()}
1145
+ # Optimize dict comprehension with items() iteration
1146
+ wrapped_dict = {}
1147
+ for k, v in parameters.items():
1148
+ wrapped_dict[k] = _wrap_parameter_by_type(v)
1149
+ return wrapped_dict
1122
1150
  return _wrap_parameter_by_type(parameters)
1123
1151
 
1124
1152
  def _apply_type_coercions(
sqlspec/core/result.py CHANGED
@@ -188,19 +188,22 @@ class SQLResult(StatementResult):
188
188
  self._operation_type = operation_type
189
189
  self.operation_index = operation_index
190
190
  self.parameters = parameters
191
- self.column_names = column_names if column_names is not None else []
191
+
192
+ # Optimize list initialization to avoid unnecessary object creation
193
+ self.column_names = column_names or []
192
194
  self.total_count = total_count
193
195
  self.has_more = has_more
194
- self.inserted_ids = inserted_ids if inserted_ids is not None else []
195
- self.statement_results: list[SQLResult] = statement_results if statement_results is not None else []
196
- self.errors = errors if errors is not None else []
196
+ self.inserted_ids = inserted_ids or []
197
+ self.statement_results = statement_results or []
198
+ self.errors = errors or []
197
199
  self.total_statements = total_statements
198
200
  self.successful_statements = successful_statements
199
201
 
200
- if not self.column_names and self.data is not None and self.data:
201
- self.column_names = list(self.data[0].keys())
202
+ # Optimize column name extraction and count calculation
203
+ if not self.column_names and data and len(data) > 0:
204
+ self.column_names = list(data[0].keys())
202
205
  if self.total_count is None:
203
- self.total_count = len(self.data) if self.data is not None else 0
206
+ self.total_count = len(data) if data is not None else 0
204
207
 
205
208
  @property
206
209
  def operation_type(self) -> "OperationType":
@@ -256,18 +259,21 @@ class SQLResult(StatementResult):
256
259
  Returns:
257
260
  List of result rows or script summary.
258
261
  """
259
- if self.operation_type.upper() == "SCRIPT":
262
+ op_type_upper = self.operation_type.upper()
263
+ if op_type_upper == "SCRIPT":
264
+ # Cache calculation to avoid redundant work
265
+ failed_statements = self.total_statements - self.successful_statements
260
266
  return [
261
267
  {
262
268
  "total_statements": self.total_statements,
263
269
  "successful_statements": self.successful_statements,
264
- "failed_statements": self.total_statements - self.successful_statements,
270
+ "failed_statements": failed_statements,
265
271
  "errors": self.errors,
266
272
  "statement_results": self.statement_results,
267
273
  "total_rows_affected": self.get_total_rows_affected(),
268
274
  }
269
275
  ]
270
- return self.data if self.data is not None else []
276
+ return self.data or []
271
277
 
272
278
  def add_statement_result(self, result: "SQLResult") -> None:
273
279
  """Add a statement result to the script execution results.
@@ -287,9 +293,11 @@ class SQLResult(StatementResult):
287
293
  Total rows affected.
288
294
  """
289
295
  if self.statement_results:
290
- return sum(
291
- stmt.rows_affected for stmt in self.statement_results if stmt.rows_affected and stmt.rows_affected > 0
292
- )
296
+ total = 0
297
+ for stmt in self.statement_results:
298
+ if stmt.rows_affected and stmt.rows_affected > 0:
299
+ total += stmt.rows_affected
300
+ return total
293
301
  return self.rows_affected if self.rows_affected and self.rows_affected > 0 else 0
294
302
 
295
303
  @property
@@ -394,9 +402,7 @@ class SQLResult(StatementResult):
394
402
  Returns:
395
403
  Iterator that yields each row as a dictionary
396
404
  """
397
- if self.data is None:
398
- return iter([])
399
- return iter(self.data)
405
+ return iter(self.data or [])
400
406
 
401
407
  def all(self) -> list[dict[str, Any]]:
402
408
  """Return all rows as a list.
@@ -415,14 +421,18 @@ class SQLResult(StatementResult):
415
421
  Raises:
416
422
  ValueError: If no results or more than one result
417
423
  """
418
- data_len = 0 if self.data is None else len(self.data)
424
+ if not self.data:
425
+ msg = "No result found, exactly one row expected"
426
+ raise ValueError(msg)
419
427
 
428
+ data_len = len(self.data)
420
429
  if data_len == 0:
421
430
  msg = "No result found, exactly one row expected"
422
431
  raise ValueError(msg)
423
432
  if data_len > 1:
424
433
  msg = f"Multiple results found ({data_len}), exactly one row expected"
425
434
  raise ValueError(msg)
435
+
426
436
  return cast("dict[str, Any]", self.data[0])
427
437
 
428
438
  def one_or_none(self) -> "Optional[dict[str, Any]]":
@@ -438,9 +448,12 @@ class SQLResult(StatementResult):
438
448
  return None
439
449
 
440
450
  data_len = len(self.data)
451
+ if data_len == 0:
452
+ return None
441
453
  if data_len > 1:
442
454
  msg = f"Multiple results found ({data_len}), at most one row expected"
443
455
  raise ValueError(msg)
456
+
444
457
  return cast("dict[str, Any]", self.data[0])
445
458
 
446
459
  def scalar(self) -> Any: