sqlspec 0.24.1__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (95)
  1. sqlspec/_serialization.py +223 -21
  2. sqlspec/_sql.py +20 -62
  3. sqlspec/_typing.py +11 -0
  4. sqlspec/adapters/adbc/config.py +8 -1
  5. sqlspec/adapters/adbc/data_dictionary.py +290 -0
  6. sqlspec/adapters/adbc/driver.py +129 -20
  7. sqlspec/adapters/adbc/type_converter.py +159 -0
  8. sqlspec/adapters/aiosqlite/config.py +3 -0
  9. sqlspec/adapters/aiosqlite/data_dictionary.py +117 -0
  10. sqlspec/adapters/aiosqlite/driver.py +17 -3
  11. sqlspec/adapters/asyncmy/_types.py +1 -1
  12. sqlspec/adapters/asyncmy/config.py +11 -8
  13. sqlspec/adapters/asyncmy/data_dictionary.py +122 -0
  14. sqlspec/adapters/asyncmy/driver.py +31 -7
  15. sqlspec/adapters/asyncpg/config.py +3 -0
  16. sqlspec/adapters/asyncpg/data_dictionary.py +134 -0
  17. sqlspec/adapters/asyncpg/driver.py +19 -4
  18. sqlspec/adapters/bigquery/config.py +3 -0
  19. sqlspec/adapters/bigquery/data_dictionary.py +109 -0
  20. sqlspec/adapters/bigquery/driver.py +21 -3
  21. sqlspec/adapters/bigquery/type_converter.py +93 -0
  22. sqlspec/adapters/duckdb/_types.py +1 -1
  23. sqlspec/adapters/duckdb/config.py +2 -0
  24. sqlspec/adapters/duckdb/data_dictionary.py +124 -0
  25. sqlspec/adapters/duckdb/driver.py +32 -5
  26. sqlspec/adapters/duckdb/pool.py +1 -1
  27. sqlspec/adapters/duckdb/type_converter.py +103 -0
  28. sqlspec/adapters/oracledb/config.py +6 -0
  29. sqlspec/adapters/oracledb/data_dictionary.py +442 -0
  30. sqlspec/adapters/oracledb/driver.py +68 -9
  31. sqlspec/adapters/oracledb/migrations.py +51 -67
  32. sqlspec/adapters/oracledb/type_converter.py +132 -0
  33. sqlspec/adapters/psqlpy/config.py +3 -0
  34. sqlspec/adapters/psqlpy/data_dictionary.py +133 -0
  35. sqlspec/adapters/psqlpy/driver.py +23 -179
  36. sqlspec/adapters/psqlpy/type_converter.py +73 -0
  37. sqlspec/adapters/psycopg/config.py +8 -4
  38. sqlspec/adapters/psycopg/data_dictionary.py +257 -0
  39. sqlspec/adapters/psycopg/driver.py +40 -5
  40. sqlspec/adapters/sqlite/config.py +3 -0
  41. sqlspec/adapters/sqlite/data_dictionary.py +117 -0
  42. sqlspec/adapters/sqlite/driver.py +18 -3
  43. sqlspec/adapters/sqlite/pool.py +13 -4
  44. sqlspec/base.py +3 -4
  45. sqlspec/builder/_base.py +130 -48
  46. sqlspec/builder/_column.py +66 -24
  47. sqlspec/builder/_ddl.py +91 -41
  48. sqlspec/builder/_insert.py +40 -58
  49. sqlspec/builder/_parsing_utils.py +127 -12
  50. sqlspec/builder/_select.py +147 -2
  51. sqlspec/builder/_update.py +1 -1
  52. sqlspec/builder/mixins/_cte_and_set_ops.py +31 -23
  53. sqlspec/builder/mixins/_delete_operations.py +12 -7
  54. sqlspec/builder/mixins/_insert_operations.py +50 -36
  55. sqlspec/builder/mixins/_join_operations.py +15 -30
  56. sqlspec/builder/mixins/_merge_operations.py +210 -78
  57. sqlspec/builder/mixins/_order_limit_operations.py +4 -10
  58. sqlspec/builder/mixins/_pivot_operations.py +1 -0
  59. sqlspec/builder/mixins/_select_operations.py +44 -22
  60. sqlspec/builder/mixins/_update_operations.py +30 -37
  61. sqlspec/builder/mixins/_where_clause.py +52 -70
  62. sqlspec/cli.py +246 -140
  63. sqlspec/config.py +33 -19
  64. sqlspec/core/__init__.py +3 -2
  65. sqlspec/core/cache.py +298 -352
  66. sqlspec/core/compiler.py +61 -4
  67. sqlspec/core/filters.py +246 -213
  68. sqlspec/core/hashing.py +9 -11
  69. sqlspec/core/parameters.py +27 -10
  70. sqlspec/core/statement.py +72 -12
  71. sqlspec/core/type_conversion.py +234 -0
  72. sqlspec/driver/__init__.py +6 -3
  73. sqlspec/driver/_async.py +108 -5
  74. sqlspec/driver/_common.py +186 -17
  75. sqlspec/driver/_sync.py +108 -5
  76. sqlspec/driver/mixins/_result_tools.py +60 -7
  77. sqlspec/exceptions.py +5 -0
  78. sqlspec/loader.py +8 -9
  79. sqlspec/migrations/__init__.py +4 -3
  80. sqlspec/migrations/base.py +153 -14
  81. sqlspec/migrations/commands.py +34 -96
  82. sqlspec/migrations/context.py +145 -0
  83. sqlspec/migrations/loaders.py +25 -8
  84. sqlspec/migrations/runner.py +352 -82
  85. sqlspec/storage/backends/fsspec.py +1 -0
  86. sqlspec/typing.py +4 -0
  87. sqlspec/utils/config_resolver.py +153 -0
  88. sqlspec/utils/serializers.py +50 -2
  89. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/METADATA +1 -1
  90. sqlspec-0.26.0.dist-info/RECORD +157 -0
  91. sqlspec-0.24.1.dist-info/RECORD +0 -139
  92. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/WHEEL +0 -0
  93. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/entry_points.txt +0 -0
  94. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/licenses/LICENSE +0 -0
  95. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/licenses/NOTICE +0 -0
@@ -10,6 +10,7 @@ from sqlglot import exp
10
10
  from typing_extensions import Self
11
11
 
12
12
  from sqlspec.builder._base import QueryBuilder
13
+ from sqlspec.builder._parsing_utils import extract_sql_object_expression
13
14
  from sqlspec.builder.mixins import InsertFromSelectMixin, InsertIntoClauseMixin, InsertValuesMixin, ReturningClauseMixin
14
15
  from sqlspec.core.result import SQLResult
15
16
  from sqlspec.exceptions import SQLBuilderError
@@ -46,7 +47,6 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
46
47
  """
47
48
  super().__init__(**kwargs)
48
49
 
49
- # Initialize Insert-specific attributes
50
50
  self._table: Optional[str] = None
51
51
  self._columns: list[str] = []
52
52
  self._values_added_count: int = 0
@@ -90,6 +90,10 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
90
90
  raise SQLBuilderError(ERR_MSG_INTERNAL_EXPRESSION_TYPE)
91
91
  return self._expression
92
92
 
def get_insert_expression(self) -> exp.Insert:
    """Return the INSERT expression held by this builder.

    This is the public counterpart of ``_get_insert_expression()`` and simply
    forwards to it. Any error raised there (it may raise ``SQLBuilderError``
    when the internal expression has the wrong type) propagates unchanged.
    """
    insert_expression = self._get_insert_expression()
    return insert_expression
96
+
93
97
  def values(self, *values: Any, **kwargs: Any) -> "Self":
94
98
  """Adds a row of values to the INSERT statement.
95
99
 
@@ -126,10 +130,10 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
126
130
 
127
131
  if len(values) == 1:
128
132
  values_0 = values[0]
129
- if hasattr(values_0, "items") and hasattr(values_0, "keys"):
133
+ if isinstance(values_0, dict):
130
134
  return self.values_from_dict(values_0)
131
135
 
132
- insert_expr = self._get_insert_expression()
136
+ insert_expr = self.get_insert_expression()
133
137
 
134
138
  if self._columns and len(values) != len(self._columns):
135
139
  msg = ERR_MSG_VALUES_COLUMNS_MISMATCH.format(values_len=len(values), columns_len=len(self._columns))
@@ -140,29 +144,15 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
140
144
  if isinstance(value, exp.Expression):
141
145
  value_placeholders.append(value)
142
146
  elif has_expression_and_sql(value):
143
- # Handle SQL objects (from sql.raw with parameters)
144
- expression = getattr(value, "expression", None)
145
- if expression is not None and isinstance(expression, exp.Expression):
146
- # Merge parameters from SQL object into builder
147
- self._merge_sql_object_parameters(value)
148
- value_placeholders.append(expression)
149
- else:
150
- # If expression is None, fall back to parsing the raw SQL
151
- sql_text = getattr(value, "sql", "")
152
- # Merge parameters even when parsing raw SQL
153
- self._merge_sql_object_parameters(value)
154
- # Check if sql_text is callable (like Expression.sql method)
155
- if callable(sql_text):
156
- sql_text = str(value)
157
- value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
158
- value_placeholders.append(value_expr)
147
+ value_expr = extract_sql_object_expression(value, builder=self)
148
+ value_placeholders.append(value_expr)
159
149
  else:
160
150
  if self._columns and i < len(self._columns):
161
151
  column_str = str(self._columns[i])
162
152
  column_name = column_str.rsplit(".", maxsplit=1)[-1] if "." in column_str else column_str
163
- param_name = self._generate_unique_parameter_name(column_name)
153
+ param_name = self.generate_unique_parameter_name(column_name)
164
154
  else:
165
- param_name = self._generate_unique_parameter_name(f"value_{i + 1}")
155
+ param_name = self.generate_unique_parameter_name(f"value_{i + 1}")
166
156
  _, param_name = self.add_parameter(value, name=param_name)
167
157
  value_placeholders.append(exp.Placeholder(this=param_name))
168
158
 
@@ -254,17 +244,14 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
254
244
 
255
245
  Example:
256
246
  ```python
257
- # ON CONFLICT (id) DO NOTHING
258
247
  sql.insert("users").values(id=1, name="John").on_conflict(
259
248
  "id"
260
249
  ).do_nothing()
261
250
 
262
- # ON CONFLICT (email, username) DO UPDATE SET updated_at = NOW()
263
251
  sql.insert("users").values(...).on_conflict(
264
252
  "email", "username"
265
253
  ).do_update(updated_at=sql.raw("NOW()"))
266
254
 
267
- # ON CONFLICT DO NOTHING (catches all conflicts)
268
255
  sql.insert("users").values(...).on_conflict().do_nothing()
269
256
  ```
270
257
  """
@@ -286,22 +273,41 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
286
273
  return self.on_conflict(*columns).do_nothing()
287
274
 
def on_duplicate_key_update(self, **kwargs: Any) -> "Insert":
    """Adds MySQL-style ON DUPLICATE KEY UPDATE clause.

    Args:
        **kwargs: Column-value pairs to update on duplicate key. A value may be
            a SQL object (e.g. from ``sql.raw()``, whose parameters are merged
            into this builder), a SQLGlot expression (used verbatim), or any
            other Python value (bound as a named parameter).

    Returns:
        The current builder instance for method chaining.

    Note:
        This method creates MySQL-specific ON DUPLICATE KEY UPDATE syntax.
        For PostgreSQL, use on_conflict() instead. The clause is stored in the
        INSERT expression's ``conflict`` slot, so it replaces any conflict
        clause set earlier on this builder.
    """
    if not kwargs:
        return self

    # Public accessor, matching values() and ConflictBuilder.
    insert_expr = self.get_insert_expression()

    set_expressions = []
    for col, val in kwargs.items():
        if has_expression_and_sql(val):
            value_expr = extract_sql_object_expression(val, builder=self)
        elif isinstance(val, exp.Expression):
            value_expr = val
        else:
            param_name = self.generate_unique_parameter_name(col)
            # add_parameter may return a different (final) name; use that one.
            _, param_name = self.add_parameter(val, name=param_name)
            value_expr = exp.Placeholder(this=param_name)

        set_expressions.append(exp.EQ(this=exp.column(col), expression=value_expr))

    # kwargs is non-empty here, so set_expressions always has at least one entry.
    on_conflict = exp.OnConflict(duplicate=True, action=exp.var("UPDATE"), expressions=set_expressions)
    insert_expr.set("conflict", on_conflict)

    return self
305
311
 
306
312
 
307
313
  class ConflictBuilder:
@@ -336,9 +342,8 @@ class ConflictBuilder:
336
342
  ).do_nothing()
337
343
  ```
338
344
  """
339
- insert_expr = self._insert_builder._get_insert_expression()
345
+ insert_expr = self._insert_builder.get_insert_expression()
340
346
 
341
- # Create ON CONFLICT with proper structure
342
347
  conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
343
348
  on_conflict = exp.OnConflict(conflict_keys=conflict_keys, action=exp.var("DO NOTHING"))
344
349
 
@@ -363,44 +368,21 @@ class ConflictBuilder:
363
368
  )
364
369
  ```
365
370
  """
366
- insert_expr = self._insert_builder._get_insert_expression()
371
+ insert_expr = self._insert_builder.get_insert_expression()
367
372
 
368
- # Create SET expressions for the UPDATE
369
373
  set_expressions = []
370
374
  for col, val in kwargs.items():
371
375
  if has_expression_and_sql(val):
372
- # Handle SQL objects (from sql.raw with parameters)
373
- expression = getattr(val, "expression", None)
374
- if expression is not None and isinstance(expression, exp.Expression):
375
- # Merge parameters from SQL object into builder
376
- if hasattr(val, "parameters"):
377
- sql_parameters = getattr(val, "parameters", {})
378
- for param_name, param_value in sql_parameters.items():
379
- self._insert_builder.add_parameter(param_value, name=param_name)
380
- value_expr = expression
381
- else:
382
- # If expression is None, fall back to parsing the raw SQL
383
- sql_text = getattr(val, "sql", "")
384
- # Merge parameters even when parsing raw SQL
385
- if hasattr(val, "parameters"):
386
- sql_parameters = getattr(val, "parameters", {})
387
- for param_name, param_value in sql_parameters.items():
388
- self._insert_builder.add_parameter(param_value, name=param_name)
389
- # Check if sql_text is callable (like Expression.sql method)
390
- if callable(sql_text):
391
- sql_text = str(val)
392
- value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
376
+ value_expr = extract_sql_object_expression(val, builder=self._insert_builder)
393
377
  elif isinstance(val, exp.Expression):
394
378
  value_expr = val
395
379
  else:
396
- # Create parameter for regular values
397
- param_name = self._insert_builder._generate_unique_parameter_name(col)
380
+ param_name = self._insert_builder.generate_unique_parameter_name(col)
398
381
  _, param_name = self._insert_builder.add_parameter(val, name=param_name)
399
382
  value_expr = exp.Placeholder(this=param_name)
400
383
 
401
384
  set_expressions.append(exp.EQ(this=exp.column(col), expression=value_expr))
402
385
 
403
- # Create ON CONFLICT with proper structure
404
386
  conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
405
387
  on_conflict = exp.OnConflict(
406
388
  conflict_keys=conflict_keys, action=exp.var("DO UPDATE"), expressions=set_expressions or None
@@ -9,7 +9,7 @@ from typing import Any, Final, Optional, Union, cast
9
9
 
10
10
  from sqlglot import exp, maybe_parse, parse_one
11
11
 
12
- from sqlspec.core.parameters import ParameterStyle
12
+ from sqlspec.core.parameters import ParameterStyle, ParameterValidator
13
13
  from sqlspec.utils.type_guards import (
14
14
  has_expression_and_parameters,
15
15
  has_expression_and_sql,
@@ -18,6 +18,27 @@ from sqlspec.utils.type_guards import (
18
18
  )
19
19
 
20
20
 
def extract_column_name(column: Union[str, exp.Column]) -> str:
    """Return the bare column name used when naming a bound parameter.

    Args:
        column: A column reference, either as a string (optionally
            table-qualified, e.g. ``"users.email"``) or as a SQLGlot ``Column``.

    Returns:
        - For strings: the text after the last ``"."``, or the whole string
          when it contains no dot.
        - For a ``Column``: ``str()`` of the wrapped node's ``.this``. If the
          wrapped value has no ``.this``, it is ``str()`` of that value, or
          ``"column"`` when the value is falsy.
        - For any other input: ``"column"``.
    """
    if isinstance(column, str):
        # split() always yields at least one item, so [-1] also covers the
        # unqualified case.
        return column.split(".")[-1]
    if not isinstance(column, exp.Column):
        return "column"

    inner = column.this
    try:
        return str(inner.this)
    except AttributeError:
        # The wrapped value has no ``.this`` attribute (e.g. it is None).
        return str(inner) if inner else "column"
40
+
41
+
21
42
  def parse_column_expression(
22
43
  column_input: Union[str, exp.Expression, Any], builder: Optional[Any] = None
23
44
  ) -> exp.Expression:
@@ -42,9 +63,7 @@ def parse_column_expression(
42
63
  if isinstance(column_input, exp.Expression):
43
64
  return column_input
44
65
 
45
- # Handle SQL objects (from sql.raw with parameters)
46
66
  if has_expression_and_sql(column_input):
47
- # This is likely a SQL object
48
67
  expression = getattr(column_input, "expression", None)
49
68
  if expression is not None and isinstance(expression, exp.Expression):
50
69
  # Merge parameters from SQL object into builder if available
@@ -53,9 +72,7 @@ def parse_column_expression(
53
72
  for param_name, param_value in sql_parameters.items():
54
73
  builder.add_parameter(param_value, name=param_name)
55
74
  return cast("exp.Expression", expression)
56
- # If expression is None, fall back to parsing the raw SQL
57
75
  sql_text = getattr(column_input, "sql", "")
58
- # Merge parameters even when parsing raw SQL
59
76
  if builder and has_expression_and_parameters(column_input) and hasattr(builder, "add_parameter"):
60
77
  sql_parameters = getattr(column_input, "parameters", {})
61
78
  for param_name, param_value in sql_parameters.items():
@@ -139,10 +156,8 @@ def parse_condition_expression(
139
156
  if value is None:
140
157
  return exp.Is(this=column_expr, expression=exp.null())
141
158
  if builder and has_parameter_builder(builder):
142
- from sqlspec.builder.mixins._where_clause import _extract_column_name
143
-
144
- column_name = _extract_column_name(column)
145
- param_name = builder._generate_unique_parameter_name(column_name)
159
+ column_name = extract_column_name(column)
160
+ param_name = builder.generate_unique_parameter_name(column_name)
146
161
  _, param_name = builder.add_parameter(value, name=param_name)
147
162
  return exp.EQ(this=column_expr, expression=exp.Placeholder(this=param_name))
148
163
  if isinstance(value, str):
@@ -156,8 +171,6 @@ def parse_condition_expression(
156
171
 
157
172
  # Convert database-specific parameter styles to SQLGlot-compatible format
158
173
  # This ensures that placeholders like $1, %s, :1 are properly recognized as parameters
159
- from sqlspec.core.parameters import ParameterValidator
160
-
161
174
  validator = ParameterValidator()
162
175
  param_info = validator.extract_parameters(condition_input)
163
176
 
@@ -186,4 +199,106 @@ def parse_condition_expression(
186
199
  return exp.condition(condition_input)
187
200
 
188
201
 
189
- __all__ = ("parse_column_expression", "parse_condition_expression", "parse_order_expression", "parse_table_expression")
def _merge_sql_object_parameters_into(value: Any, builder: Optional[Any]) -> None:
    """Copy the parameters carried by a SQL object onto ``builder``.

    Does nothing when ``builder`` is falsy, lacks ``add_parameter``, or when
    ``value`` has no ``parameters`` attribute. A ``parameters`` attribute that
    is ``None`` or empty is treated as "no parameters" instead of crashing.
    """
    if not builder or not hasattr(value, "parameters") or not hasattr(builder, "add_parameter"):
        return
    sql_parameters = getattr(value, "parameters", None) or {}
    for param_name, param_value in sql_parameters.items():
        builder.add_parameter(param_value, name=param_name)


def extract_sql_object_expression(value: Any, builder: Optional[Any] = None) -> exp.Expression:
    """Extract SQLGlot expression from SQL object value with parameter merging.

    Resolution order:

    1. If ``value.expression`` is a SQLGlot ``Expression``, it is returned as-is.
    2. Otherwise ``value.sql`` is parsed with ``exp.maybe_parse``, falling back
       to ``exp.convert`` when parsing yields nothing. If ``value.sql`` is
       callable (for example a bound ``Expression.sql`` method), ``str(value)``
       is parsed instead.

    In both cases the object's parameters, when present, are merged into
    ``builder`` via ``builder.add_parameter(value, name=name)``.

    This consolidates duplicated logic across builder files that process
    SQL objects (like those from sql.raw() calls).

    Args:
        value: The SQL object value to process.
        builder: Optional builder instance for parameter merging (must have an
            ``add_parameter`` method to receive parameters).

    Returns:
        SQLGlot Expression extracted from the SQL object.

    Raises:
        ValueError: If the value doesn't appear to be a SQL object.
    """
    if not has_expression_and_sql(value):
        msg = f"Value does not have both expression and sql attributes: {type(value)}"
        raise ValueError(msg)

    # Attribute access order (expression, then parameters) matches the
    # original implementation in case these are computed properties.
    expression = getattr(value, "expression", None)
    if isinstance(expression, exp.Expression):
        _merge_sql_object_parameters_into(value, builder)
        return expression

    sql_text = getattr(value, "sql", "")
    # Parameters are merged even when we have to parse the raw SQL.
    _merge_sql_object_parameters_into(value, builder)

    if callable(sql_text):
        sql_text = str(value)

    return exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
254
+
255
+
def extract_expression(value: Any) -> exp.Expression:
    """Unwrap ``value`` into a plain SQLGlot expression.

    Rules, checked in this order:

    * ``str``: ``exp.column(value)``, i.e. treated as a column reference.
    * builder ``Column``: its ``sqlglot_expression``.
    * ``ExpressionWrapper``: its wrapped ``expression``.
    * builder ``Case``: a new ``exp.Case`` built from its ``conditions`` and ``default``.
    * ``exp.Expression``: returned unchanged.
    * anything else: ``exp.convert(value)``.

    Args:
        value: String, SQLGlot expression, or builder wrapper type.

    Returns:
        Raw SQLGlot expression.
    """
    # Function-local imports; presumably to avoid a circular import with
    # these builder modules (NOTE(review): confirm).
    from sqlspec.builder._column import Column
    from sqlspec.builder._expression_wrappers import ExpressionWrapper
    from sqlspec.builder.mixins._select_operations import Case

    result: exp.Expression
    if isinstance(value, str):
        result = exp.column(value)
    elif isinstance(value, Column):
        result = value.sqlglot_expression
    elif isinstance(value, ExpressionWrapper):
        result = value.expression
    elif isinstance(value, Case):
        result = exp.Case(ifs=value.conditions, default=value.default)
    elif isinstance(value, exp.Expression):
        result = value
    else:
        result = exp.convert(value)
    return result
280
+
281
+
def to_expression(value: Any) -> exp.Expression:
    """Coerce ``value`` into a raw SQLGlot expression.

    Args:
        value: A SQLGlot expression (passed through untouched) or any other
            Python value (handed to ``exp.convert``).

    Returns:
        Raw SQLGlot expression.
    """
    return value if isinstance(value, exp.Expression) else exp.convert(value)
294
+
295
+
# Public helpers of this module; extract_column_name is a public (non-underscore)
# helper defined here, so it is exported alongside the others.
__all__ = (
    "extract_column_name",
    "extract_expression",
    "extract_sql_object_expression",
    "parse_column_expression",
    "parse_condition_expression",
    "parse_order_expression",
    "parse_table_expression",
    "to_expression",
)
@@ -5,7 +5,7 @@ parameter binding and validation.
5
5
  """
6
6
 
7
7
  import re
8
- from typing import Any, Callable, Final, Optional, Union
8
+ from typing import Any, Callable, Final, Optional, Union, cast
9
9
 
10
10
  from sqlglot import exp
11
11
  from typing_extensions import Self
@@ -24,6 +24,7 @@ from sqlspec.builder.mixins import (
24
24
  WhereClauseMixin,
25
25
  )
26
26
  from sqlspec.core.result import SQLResult
27
+ from sqlspec.exceptions import SQLBuilderError
27
28
 
28
29
  __all__ = ("Select",)
29
30
 
@@ -73,7 +74,6 @@ class Select(
73
74
  """
74
75
  super().__init__(**kwargs)
75
76
 
76
- # Initialize Select-specific attributes
77
77
  self._with_parts: dict[str, Union[exp.CTE, Select]] = {}
78
78
  self._hints: list[dict[str, object]] = []
79
79
 
@@ -169,3 +169,148 @@ class Select(
169
169
  )
170
170
 
171
171
  return SafeQuery(sql=modified_sql, parameters=safe_query.parameters, dialect=safe_query.dialect)
172
+
def _validate_select_expression(self) -> None:
    """Validate that current expression is a valid SELECT statement.

    Raises:
        SQLBuilderError: If expression is None or not a SELECT statement
    """
    if self._expression is None or not isinstance(self._expression, exp.Select):
        msg = "Locking clauses can only be applied to SELECT statements"
        raise SQLBuilderError(msg)

def _validate_lock_parameters(self, skip_locked: bool, nowait: bool) -> None:
    """Validate locking parameters for conflicting options.

    Args:
        skip_locked: Whether SKIP LOCKED option is enabled
        nowait: Whether NOWAIT option is enabled

    Raises:
        SQLBuilderError: If both skip_locked and nowait are True
    """
    if skip_locked and nowait:
        msg = "Cannot use both skip_locked and nowait"
        raise SQLBuilderError(msg)

def _add_lock(
    self,
    *,
    update: bool,
    key: bool = False,
    skip_locked: bool = False,
    nowait: bool = False,
    of: "Optional[Union[str, list[str]]]" = None,
) -> "Self":
    """Validate and append one row-level locking clause to the SELECT.

    Args:
        update: True for UPDATE-strength locks, False for SHARE-strength locks.
        key: Selects the key variant of the lock: FOR NO KEY UPDATE when
            ``update`` is True, FOR KEY SHARE when it is False.
        skip_locked: Skip rows that are already locked (SKIP LOCKED)
        nowait: Return immediately if row is locked (NOWAIT)
        of: Table name/alias, or list of them, to restrict the lock to (OF ...)

    Returns:
        Self for method chaining

    Raises:
        SQLBuilderError: If the builder does not hold a SELECT, or if both
            skip_locked and nowait are requested.
    """
    self._validate_select_expression()
    self._validate_lock_parameters(skip_locked, nowait)

    # Safe: _validate_select_expression guarantees a SELECT is present.
    select_expr = cast("exp.Select", self._expression)

    lock_args: dict[str, Any] = {"update": update}
    if key:
        lock_args["key"] = True

    # Wait policy: skip_locked -> wait=False, nowait -> wait=True; with
    # neither flag the arg is omitted and the lock is a plain blocking one.
    if skip_locked:
        lock_args["wait"] = False
    elif nowait:
        lock_args["wait"] = True

    if of:
        tables = [of] if isinstance(of, str) else of
        lock_args["expressions"] = [exp.table_(t) for t in tables]

    # Build a fresh list; also tolerates a "locks" arg explicitly holding None.
    locks = list(select_expr.args.get("locks") or [])
    locks.append(exp.Lock(**lock_args))
    select_expr.set("locks", locks)

    return self

def for_update(
    self, *, skip_locked: bool = False, nowait: bool = False, of: "Optional[Union[str, list[str]]]" = None
) -> "Self":
    """Add FOR UPDATE clause to SELECT statement for row-level locking.

    Args:
        skip_locked: Skip rows that are already locked (SKIP LOCKED)
        nowait: Return immediately if row is locked (NOWAIT)
        of: Table names/aliases to lock (FOR UPDATE OF table)

    Returns:
        Self for method chaining

    Raises:
        SQLBuilderError: If not a SELECT, or both skip_locked and nowait are set.
    """
    return self._add_lock(update=True, skip_locked=skip_locked, nowait=nowait, of=of)

def for_share(
    self, *, skip_locked: bool = False, nowait: bool = False, of: "Optional[Union[str, list[str]]]" = None
) -> "Self":
    """Add FOR SHARE clause for shared row-level locking.

    Args:
        skip_locked: Skip rows that are already locked (SKIP LOCKED)
        nowait: Return immediately if row is locked (NOWAIT)
        of: Table names/aliases to lock (FOR SHARE OF table)

    Returns:
        Self for method chaining

    Raises:
        SQLBuilderError: If not a SELECT, or both skip_locked and nowait are set.
    """
    return self._add_lock(update=False, skip_locked=skip_locked, nowait=nowait, of=of)

def for_key_share(
    self, *, skip_locked: bool = False, nowait: bool = False, of: "Optional[Union[str, list[str]]]" = None
) -> "Self":
    """Add FOR KEY SHARE clause (PostgreSQL-specific).

    FOR KEY SHARE is like FOR SHARE, but the lock is weaker:
    SELECT FOR UPDATE is blocked, but not SELECT FOR NO KEY UPDATE.

    Args:
        skip_locked: Skip rows that are already locked (SKIP LOCKED)
        nowait: Return immediately if row is locked (NOWAIT)
        of: Table names/aliases to lock (FOR KEY SHARE OF table)

    Returns:
        Self for method chaining

    Raises:
        SQLBuilderError: If not a SELECT, or both skip_locked and nowait are set.
    """
    return self._add_lock(update=False, key=True, skip_locked=skip_locked, nowait=nowait, of=of)

def for_no_key_update(
    self, *, skip_locked: bool = False, nowait: bool = False, of: "Optional[Union[str, list[str]]]" = None
) -> "Self":
    """Add FOR NO KEY UPDATE clause (PostgreSQL-specific).

    FOR NO KEY UPDATE is like FOR UPDATE, but the lock is weaker:
    it does not block SELECT FOR KEY SHARE commands that attempt to
    acquire a share lock on the same rows.

    Args:
        skip_locked: Skip rows that are already locked (SKIP LOCKED)
        nowait: Return immediately if row is locked (NOWAIT)
        of: Table names/aliases to lock (FOR NO KEY UPDATE OF table)

    Returns:
        Self for method chaining

    Raises:
        SQLBuilderError: If not a SELECT, or both skip_locked and nowait are set.
    """
    # key=True is required: Lock(update=True, key=False) renders as plain FOR UPDATE.
    return self._add_lock(update=True, key=True, skip_locked=skip_locked, nowait=nowait, of=of)
@@ -131,7 +131,7 @@ class Update(
131
131
  subquery_exp = exp.paren(exp.maybe_parse(subquery.sql, dialect=self.dialect))
132
132
  table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
133
133
 
134
- subquery_parameters = table._parameters
134
+ subquery_parameters = table.parameters
135
135
  if subquery_parameters:
136
136
  for p_name, p_value in subquery_parameters.items():
137
137
  self.add_parameter(p_value, name=p_name)