sqlspec 0.13.1__py3-none-any.whl → 0.14.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (112) hide show
  1. sqlspec/__init__.py +39 -1
  2. sqlspec/__main__.py +12 -0
  3. sqlspec/adapters/adbc/config.py +16 -40
  4. sqlspec/adapters/adbc/driver.py +43 -16
  5. sqlspec/adapters/adbc/transformers.py +108 -0
  6. sqlspec/adapters/aiosqlite/config.py +2 -20
  7. sqlspec/adapters/aiosqlite/driver.py +36 -18
  8. sqlspec/adapters/asyncmy/config.py +2 -33
  9. sqlspec/adapters/asyncmy/driver.py +23 -16
  10. sqlspec/adapters/asyncpg/config.py +5 -39
  11. sqlspec/adapters/asyncpg/driver.py +41 -18
  12. sqlspec/adapters/bigquery/config.py +2 -43
  13. sqlspec/adapters/bigquery/driver.py +26 -14
  14. sqlspec/adapters/duckdb/config.py +2 -49
  15. sqlspec/adapters/duckdb/driver.py +35 -16
  16. sqlspec/adapters/oracledb/config.py +4 -83
  17. sqlspec/adapters/oracledb/driver.py +54 -27
  18. sqlspec/adapters/psqlpy/config.py +2 -55
  19. sqlspec/adapters/psqlpy/driver.py +28 -8
  20. sqlspec/adapters/psycopg/config.py +4 -73
  21. sqlspec/adapters/psycopg/driver.py +69 -24
  22. sqlspec/adapters/sqlite/config.py +3 -21
  23. sqlspec/adapters/sqlite/driver.py +50 -26
  24. sqlspec/cli.py +248 -0
  25. sqlspec/config.py +18 -20
  26. sqlspec/driver/_async.py +28 -10
  27. sqlspec/driver/_common.py +5 -4
  28. sqlspec/driver/_sync.py +28 -10
  29. sqlspec/driver/mixins/__init__.py +6 -0
  30. sqlspec/driver/mixins/_cache.py +114 -0
  31. sqlspec/driver/mixins/_pipeline.py +0 -4
  32. sqlspec/{service/base.py → driver/mixins/_query_tools.py} +86 -421
  33. sqlspec/driver/mixins/_result_utils.py +0 -2
  34. sqlspec/driver/mixins/_sql_translator.py +0 -2
  35. sqlspec/driver/mixins/_storage.py +4 -18
  36. sqlspec/driver/mixins/_type_coercion.py +0 -2
  37. sqlspec/driver/parameters.py +4 -4
  38. sqlspec/extensions/aiosql/adapter.py +4 -4
  39. sqlspec/extensions/litestar/__init__.py +2 -1
  40. sqlspec/extensions/litestar/cli.py +48 -0
  41. sqlspec/extensions/litestar/plugin.py +3 -0
  42. sqlspec/loader.py +1 -1
  43. sqlspec/migrations/__init__.py +23 -0
  44. sqlspec/migrations/base.py +390 -0
  45. sqlspec/migrations/commands.py +525 -0
  46. sqlspec/migrations/runner.py +215 -0
  47. sqlspec/migrations/tracker.py +153 -0
  48. sqlspec/migrations/utils.py +89 -0
  49. sqlspec/protocols.py +37 -3
  50. sqlspec/statement/builder/__init__.py +8 -8
  51. sqlspec/statement/builder/{column.py → _column.py} +82 -52
  52. sqlspec/statement/builder/{ddl.py → _ddl.py} +5 -5
  53. sqlspec/statement/builder/_ddl_utils.py +1 -1
  54. sqlspec/statement/builder/{delete.py → _delete.py} +1 -1
  55. sqlspec/statement/builder/{insert.py → _insert.py} +1 -1
  56. sqlspec/statement/builder/{merge.py → _merge.py} +1 -1
  57. sqlspec/statement/builder/_parsing_utils.py +5 -3
  58. sqlspec/statement/builder/{select.py → _select.py} +59 -61
  59. sqlspec/statement/builder/{update.py → _update.py} +2 -2
  60. sqlspec/statement/builder/mixins/__init__.py +24 -30
  61. sqlspec/statement/builder/mixins/{_set_ops.py → _cte_and_set_ops.py} +86 -2
  62. sqlspec/statement/builder/mixins/{_delete_from.py → _delete_operations.py} +2 -0
  63. sqlspec/statement/builder/mixins/{_insert_values.py → _insert_operations.py} +70 -1
  64. sqlspec/statement/builder/mixins/{_merge_clauses.py → _merge_operations.py} +2 -0
  65. sqlspec/statement/builder/mixins/_order_limit_operations.py +123 -0
  66. sqlspec/statement/builder/mixins/{_pivot.py → _pivot_operations.py} +71 -2
  67. sqlspec/statement/builder/mixins/_select_operations.py +612 -0
  68. sqlspec/statement/builder/mixins/{_update_set.py → _update_operations.py} +73 -2
  69. sqlspec/statement/builder/mixins/_where_clause.py +536 -0
  70. sqlspec/statement/cache.py +50 -0
  71. sqlspec/statement/filters.py +37 -8
  72. sqlspec/statement/parameters.py +143 -54
  73. sqlspec/statement/pipelines/__init__.py +1 -1
  74. sqlspec/statement/pipelines/context.py +4 -10
  75. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +3 -3
  76. sqlspec/statement/pipelines/validators/_parameter_style.py +22 -22
  77. sqlspec/statement/pipelines/validators/_performance.py +1 -5
  78. sqlspec/statement/sql.py +246 -176
  79. sqlspec/utils/__init__.py +2 -1
  80. sqlspec/utils/statement_hashing.py +203 -0
  81. sqlspec/utils/type_guards.py +32 -0
  82. {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/METADATA +1 -1
  83. sqlspec-0.14.1.dist-info/RECORD +145 -0
  84. sqlspec-0.14.1.dist-info/entry_points.txt +2 -0
  85. sqlspec/service/__init__.py +0 -4
  86. sqlspec/service/_util.py +0 -147
  87. sqlspec/service/pagination.py +0 -26
  88. sqlspec/statement/builder/mixins/_aggregate_functions.py +0 -250
  89. sqlspec/statement/builder/mixins/_case_builder.py +0 -91
  90. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -90
  91. sqlspec/statement/builder/mixins/_from.py +0 -63
  92. sqlspec/statement/builder/mixins/_group_by.py +0 -118
  93. sqlspec/statement/builder/mixins/_having.py +0 -35
  94. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -47
  95. sqlspec/statement/builder/mixins/_insert_into.py +0 -36
  96. sqlspec/statement/builder/mixins/_limit_offset.py +0 -53
  97. sqlspec/statement/builder/mixins/_order_by.py +0 -46
  98. sqlspec/statement/builder/mixins/_returning.py +0 -37
  99. sqlspec/statement/builder/mixins/_select_columns.py +0 -61
  100. sqlspec/statement/builder/mixins/_unpivot.py +0 -77
  101. sqlspec/statement/builder/mixins/_update_from.py +0 -55
  102. sqlspec/statement/builder/mixins/_update_table.py +0 -29
  103. sqlspec/statement/builder/mixins/_where.py +0 -401
  104. sqlspec/statement/builder/mixins/_window_functions.py +0 -86
  105. sqlspec/statement/parameter_manager.py +0 -220
  106. sqlspec/statement/sql_compiler.py +0 -140
  107. sqlspec-0.13.1.dist-info/RECORD +0 -150
  108. /sqlspec/statement/builder/{base.py → _base.py} +0 -0
  109. /sqlspec/statement/builder/mixins/{_join.py → _join_operations.py} +0 -0
  110. {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/WHEEL +0 -0
  111. {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/licenses/LICENSE +0 -0
  112. {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,215 @@
1
+ """Migration execution engine for SQLSpec.
2
+
3
+ This module handles migration file loading and execution using SQLFileLoader.
4
+ """
5
+
6
+ import time
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING, Any, Optional
9
+
10
+ from sqlspec.migrations.base import BaseMigrationRunner
11
+ from sqlspec.utils.logging import get_logger
12
+
13
+ if TYPE_CHECKING:
14
+ from sqlspec.driver import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
15
+ from sqlspec.statement.sql import SQL
16
+
17
+ __all__ = ("AsyncMigrationRunner", "SyncMigrationRunner")
18
+
19
+ logger = get_logger("migrations.runner")
20
+
21
+
22
class SyncMigrationRunner(BaseMigrationRunner["SyncDriverAdapterProtocol[Any]"]):
    """Sync version - executes migrations using SQLFileLoader."""

    def get_migration_files(self) -> "list[tuple[str, Path]]":
        """Get all migration files.

        Returns:
            List of (version, path) tuples as produced by the shared
            ``_get_migration_files_sync`` helper.
        """
        return self._get_migration_files_sync()

    def load_migration(self, file_path: Path) -> "dict[str, Any]":
        """Load a migration file and extract its components.

        Args:
            file_path: Path to the migration file.

        Returns:
            Dictionary containing migration metadata and queries.
        """
        return self._load_migration_metadata(file_path)

    def execute_upgrade(
        self, driver: "SyncDriverAdapterProtocol[Any]", migration: "dict[str, Any]"
    ) -> "tuple[Optional[str], int]":
        """Execute an upgrade migration.

        Args:
            driver: The database driver to use.
            migration: Migration metadata dictionary.

        Returns:
            Tuple of (sql_content, execution_time_ms). The first element is
            always ``None`` in this implementation; the elapsed time is ``0``
            when the migration has no "up" SQL.
        """
        upgrade_sql = self._get_migration_sql(migration, "up")
        if upgrade_sql is None:
            return None, 0

        # perf_counter() is monotonic, so the measured duration cannot go
        # negative or jump if the wall clock is adjusted mid-migration.
        start_time = time.perf_counter()
        driver.execute(upgrade_sql)
        execution_time = int((time.perf_counter() - start_time) * 1000)
        return None, execution_time

    def execute_downgrade(
        self, driver: "SyncDriverAdapterProtocol[Any]", migration: "dict[str, Any]"
    ) -> "tuple[Optional[str], int]":
        """Execute a downgrade migration.

        Args:
            driver: The database driver to use.
            migration: Migration metadata dictionary.

        Returns:
            Tuple of (sql_content, execution_time_ms). The first element is
            always ``None`` in this implementation; the elapsed time is ``0``
            when the migration has no "down" SQL.
        """
        downgrade_sql = self._get_migration_sql(migration, "down")
        if downgrade_sql is None:
            return None, 0

        # Monotonic clock for the same reason as in execute_upgrade().
        start_time = time.perf_counter()
        driver.execute(downgrade_sql)
        execution_time = int((time.perf_counter() - start_time) * 1000)
        return None, execution_time

    def load_all_migrations(self) -> "dict[str, SQL]":
        """Load all migrations into a single namespace for bulk operations.

        Useful for:
        - Migration analysis tools
        - Documentation generation
        - Validation and linting
        - Migration squashing

        Returns:
            Dictionary mapping query names to SQL objects.
        """
        all_queries: "dict[str, SQL]" = {}
        migrations = self.get_migration_files()

        for _version, file_path in migrations:
            self.loader.load_sql(file_path)

            # NOTE(review): list_queries() is asked again after every file;
            # presumably it reports everything the loader holds so far, so
            # earlier entries are simply overwritten with equal values - confirm.
            for query_name in self.loader.list_queries():
                # Keyed by the loader's query name.
                all_queries[query_name] = self.loader.get_sql(query_name)

        return all_queries
117
+
118
+
119
class AsyncMigrationRunner(BaseMigrationRunner["AsyncDriverAdapterProtocol[Any]"]):
    """Async version - executes migrations using SQLFileLoader."""

    async def get_migration_files(self) -> "list[tuple[str, Path]]":
        """Get all migration files.

        Returns:
            List of tuples containing (version, file_path).
        """
        # File discovery has no async variant here; delegate to the shared
        # synchronous helper.
        return self._get_migration_files_sync()

    async def load_migration(self, file_path: Path) -> "dict[str, Any]":
        """Load a migration file and extract its components.

        Args:
            file_path: Path to the migration file.

        Returns:
            Dictionary containing migration metadata.
        """
        # Loading is done by the synchronous base implementation.
        return self._load_migration_metadata(file_path)

    async def execute_upgrade(
        self, driver: "AsyncDriverAdapterProtocol[Any]", migration: "dict[str, Any]"
    ) -> "tuple[Optional[str], int]":
        """Execute an upgrade migration.

        Args:
            driver: The async database driver to use.
            migration: Migration metadata dictionary.

        Returns:
            Tuple of (sql_content, execution_time_ms). The first element is
            always ``None`` in this implementation; the elapsed time is ``0``
            when the migration has no "up" SQL.
        """
        upgrade_sql = self._get_migration_sql(migration, "up")
        if upgrade_sql is None:
            return None, 0

        # perf_counter() is monotonic, so the measured duration cannot go
        # negative or jump if the wall clock is adjusted mid-migration.
        start_time = time.perf_counter()
        await driver.execute(upgrade_sql)
        execution_time = int((time.perf_counter() - start_time) * 1000)
        return None, execution_time

    async def execute_downgrade(
        self, driver: "AsyncDriverAdapterProtocol[Any]", migration: "dict[str, Any]"
    ) -> "tuple[Optional[str], int]":
        """Execute a downgrade migration.

        Args:
            driver: The async database driver to use.
            migration: Migration metadata dictionary.

        Returns:
            Tuple of (sql_content, execution_time_ms). The first element is
            always ``None`` in this implementation; the elapsed time is ``0``
            when the migration has no "down" SQL.
        """
        downgrade_sql = self._get_migration_sql(migration, "down")
        if downgrade_sql is None:
            return None, 0

        # Monotonic clock for the same reason as in execute_upgrade().
        start_time = time.perf_counter()
        await driver.execute(downgrade_sql)
        execution_time = int((time.perf_counter() - start_time) * 1000)
        return None, execution_time

    async def load_all_migrations(self) -> "dict[str, SQL]":
        """Load all migrations into a single namespace for bulk operations.

        Useful for:
        - Migration analysis tools
        - Documentation generation
        - Validation and linting
        - Migration squashing

        Returns:
            Dictionary mapping query names to SQL objects.
        """
        all_queries: "dict[str, SQL]" = {}
        migrations = await self.get_migration_files()

        for _version, file_path in migrations:
            self.loader.load_sql(file_path)

            # NOTE(review): list_queries() is asked again after every file;
            # presumably it reports everything the loader holds so far, so
            # earlier entries are simply overwritten with equal values - confirm.
            for query_name in self.loader.list_queries():
                # Keyed by the loader's query name.
                all_queries[query_name] = self.loader.get_sql(query_name)

        return all_queries
@@ -0,0 +1,153 @@
1
+ """Migration version tracking for SQLSpec.
2
+
3
+ This module provides functionality to track applied migrations in the database.
4
+ """
5
+
6
+ from typing import TYPE_CHECKING, Any, Optional
7
+
8
+ from sqlspec.migrations.base import BaseMigrationTracker
9
+
10
+ if TYPE_CHECKING:
11
+ from sqlspec.driver import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
12
+
13
+ __all__ = ("AsyncMigrationTracker", "SyncMigrationTracker")
14
+
15
+
16
class SyncMigrationTracker(BaseMigrationTracker["SyncDriverAdapterProtocol[Any]"]):
    """Sync version - tracks applied migrations in the database."""

    def ensure_tracking_table(self, driver: "SyncDriverAdapterProtocol[Any]") -> None:
        """Create the migration tracking table if it doesn't exist.

        Args:
            driver: The database driver to use.
        """
        driver.execute(self._get_create_table_sql())

    def get_current_version(self, driver: "SyncDriverAdapterProtocol[Any]") -> Optional[str]:
        """Get the latest applied migration version.

        Args:
            driver: The database driver to use.

        Returns:
            The ``version_num`` of the first returned row, or None if the
            query returns no rows.
        """
        result = driver.execute(self._get_current_version_sql())
        return result.data[0]["version_num"] if result.data else None

    def get_applied_migrations(self, driver: "SyncDriverAdapterProtocol[Any]") -> "list[dict[str, Any]]":
        """Get all applied migrations.

        Args:
            driver: The database driver to use.

        Returns:
            List of migration records, in the order the query returns them.
        """
        result = driver.execute(self._get_applied_migrations_sql())
        return result.data

    def record_migration(
        self,
        driver: "SyncDriverAdapterProtocol[Any]",
        version: str,
        description: str,
        execution_time_ms: int,
        checksum: str,
    ) -> None:
        """Record a successfully applied migration.

        Args:
            driver: The database driver to use.
            version: Version number of the migration.
            description: Description of the migration.
            execution_time_ms: Execution time in milliseconds.
            checksum: Checksum of the migration content.
        """
        import os

        # USER is commonly unset on Windows (which exposes USERNAME instead),
        # so fall back to the other conventional variables before giving up.
        applied_by = (
            os.environ.get("USER") or os.environ.get("USERNAME") or os.environ.get("LOGNAME") or "unknown"
        )

        driver.execute(self._get_record_migration_sql(version, description, execution_time_ms, checksum, applied_by))

    def remove_migration(self, driver: "SyncDriverAdapterProtocol[Any]", version: str) -> None:
        """Remove a migration record (used during downgrade).

        Args:
            driver: The database driver to use.
            version: Version number to remove.
        """
        driver.execute(self._get_remove_migration_sql(version))
84
+
85
+
86
class AsyncMigrationTracker(BaseMigrationTracker["AsyncDriverAdapterProtocol[Any]"]):
    """Async version - tracks applied migrations in the database."""

    async def ensure_tracking_table(self, driver: "AsyncDriverAdapterProtocol[Any]") -> None:
        """Create the migration tracking table if it doesn't exist.

        Args:
            driver: The async database driver to use.
        """
        await driver.execute(self._get_create_table_sql())

    async def get_current_version(self, driver: "AsyncDriverAdapterProtocol[Any]") -> Optional[str]:
        """Get the latest applied migration version.

        Args:
            driver: The async database driver to use.

        Returns:
            The ``version_num`` of the first returned row, or None if the
            query returns no rows.
        """
        result = await driver.execute(self._get_current_version_sql())
        return result.data[0]["version_num"] if result.data else None

    async def get_applied_migrations(self, driver: "AsyncDriverAdapterProtocol[Any]") -> "list[dict[str, Any]]":
        """Get all applied migrations.

        Args:
            driver: The async database driver to use.

        Returns:
            List of migration records, in the order the query returns them.
        """
        result = await driver.execute(self._get_applied_migrations_sql())
        return result.data

    async def record_migration(
        self,
        driver: "AsyncDriverAdapterProtocol[Any]",
        version: str,
        description: str,
        execution_time_ms: int,
        checksum: str,
    ) -> None:
        """Record a successfully applied migration.

        Args:
            driver: The async database driver to use.
            version: Version number of the migration.
            description: Description of the migration.
            execution_time_ms: Execution time in milliseconds.
            checksum: Checksum of the migration content.
        """
        import os

        # USER is commonly unset on Windows (which exposes USERNAME instead),
        # so fall back to the other conventional variables before giving up.
        applied_by = (
            os.environ.get("USER") or os.environ.get("USERNAME") or os.environ.get("LOGNAME") or "unknown"
        )

        await driver.execute(
            self._get_record_migration_sql(version, description, execution_time_ms, checksum, applied_by)
        )

    async def remove_migration(self, driver: "AsyncDriverAdapterProtocol[Any]", version: str) -> None:
        """Remove a migration record (used during downgrade).

        Args:
            driver: The async database driver to use.
            version: Version number to remove.
        """
        await driver.execute(self._get_remove_migration_sql(version))
@@ -0,0 +1,89 @@
1
+ """Utility functions for SQLSpec migrations.
2
+
3
+ This module provides helper functions for migration operations.
4
+ """
5
+
6
+ import os
7
+ from datetime import datetime, timezone
8
+ from pathlib import Path
9
+ from typing import TYPE_CHECKING, Any, Optional
10
+
11
+ if TYPE_CHECKING:
12
+ from sqlspec.driver import AsyncDriverAdapterProtocol
13
+
14
+ __all__ = ("create_migration_file", "drop_all", "get_author")
15
+
16
+
17
def create_migration_file(migrations_dir: Path, version: str, message: str) -> Path:
    """Create a new migration file from template.

    The file is named ``{version}_{slug}.sql`` where the slug is the
    lower-cased message reduced to alphanumerics joined by single
    underscores and capped at 50 characters. An existing file with the
    same name is overwritten.

    Args:
        migrations_dir: Directory to create the migration in.
        version: Version number for the migration.
        message: Description message for the migration.

    Returns:
        Path to the created migration file.
    """
    # Sanitize message for filename: drop everything except alphanumerics,
    # spaces and hyphens, then collapse the separators into single underscores.
    safe_message = message.lower()
    safe_message = "".join(c if c.isalnum() or c in " -" else "" for c in safe_message)
    safe_message = safe_message.replace(" ", "_").replace("-", "_")
    safe_message = "_".join(filter(None, safe_message.split("_")))[:50]

    filename = f"{version}_{safe_message}.sql"
    file_path = migrations_dir / filename

    # A multi-line message would spill out of the "-- Description:" comment
    # and land in the file as raw SQL, so flatten it onto a single line.
    description = " ".join(message.splitlines())

    # Generate template content
    template = f"""-- SQLSpec Migration
-- Version: {version}
-- Description: {description}
-- Created: {datetime.now(timezone.utc).isoformat()}
-- Author: {get_author()}

-- name: migrate-{version}-up
-- TODO: Add your upgrade SQL statements here
-- Example:
-- CREATE TABLE example (
-- id INTEGER PRIMARY KEY,
-- name TEXT NOT NULL
-- );

-- name: migrate-{version}-down
-- TODO: Add your downgrade SQL statements here (optional)
-- Example:
-- DROP TABLE example;
"""

    # Explicit encoding: the platform default may be unable to encode
    # non-ASCII characters from the message or author name.
    file_path.write_text(template, encoding="utf-8")
    return file_path
60
+
61
+
62
def get_author() -> str:
    """Get current user for migration metadata.

    Returns:
        The first non-empty value among the ``USER``, ``USERNAME`` and
        ``LOGNAME`` environment variables, or 'unknown' if none is set.
    """
    # USER keeps priority for backward compatibility; USERNAME covers Windows,
    # where USER is commonly unset, and LOGNAME is the remaining POSIX fallback.
    for variable in ("USER", "USERNAME", "LOGNAME"):
        value = os.environ.get(variable)
        if value:
            return value
    return "unknown"
69
+
70
+
71
async def drop_all(
    engine: "AsyncDriverAdapterProtocol[Any]", version_table_name: str, metadata: Optional[Any] = None
) -> None:
    """Placeholder for dropping every table in the database.

    No generic implementation is provided: the arguments are accepted but
    never used, and every call fails.

    Args:
        engine: The database engine/driver.
        version_table_name: Name of the version tracking table.
        metadata: Optional metadata object.

    Raises:
        NotImplementedError: Always, as this requires database-specific logic.
    """
    raise NotImplementedError("drop_all functionality requires database-specific implementation")
sqlspec/protocols.py CHANGED
@@ -30,12 +30,16 @@ __all__ = (
30
30
  "DictProtocol",
31
31
  "FilterAppenderProtocol",
32
32
  "FilterParameterProtocol",
33
+ "HasExpressionProtocol",
33
34
  "HasExpressionsProtocol",
34
35
  "HasLimitProtocol",
35
36
  "HasOffsetProtocol",
36
37
  "HasOrderByProtocol",
38
+ "HasParameterBuilderProtocol",
37
39
  "HasRiskLevelProtocol",
40
+ "HasSQLGlotExpressionProtocol",
38
41
  "HasSQLMethodProtocol",
42
+ "HasToStatementProtocol",
39
43
  "HasWhereProtocol",
40
44
  "IndexableRow",
41
45
  "IterableParameters",
@@ -501,9 +505,39 @@ class ObjectStoreProtocol(Protocol):
501
505
  raise NotImplementedError(msg)
502
506
 
503
507
 
504
- # =============================================================================
505
- # SQL Builder Protocols
506
- # =============================================================================
508
@runtime_checkable
class HasSQLGlotExpressionProtocol(Protocol):
    """Structural type for objects exposing a ``sqlglot_expression`` property."""

    @property
    def sqlglot_expression(self) -> "Optional[exp.Expression]":
        """SQLGlot expression associated with this object, or ``None``."""
516
+
517
+
518
@runtime_checkable
class HasParameterBuilderProtocol(Protocol):
    """Structural type for builders that expose ``add_parameter``."""

    def add_parameter(self, value: Any, name: "Optional[str]" = None) -> tuple[Any, str]:
        """Register ``value`` as a parameter, optionally under ``name``."""
525
+
526
+
527
@runtime_checkable
class HasExpressionProtocol(Protocol):
    """Structural type for objects carrying an ``_expression`` attribute."""

    # Declared as a data member only; the protocol defines no methods.
    _expression: "Optional[exp.Expression]"
532
+
533
+
534
@runtime_checkable
class HasToStatementProtocol(Protocol):
    """Structural type for objects that provide a ``to_statement`` method."""

    def to_statement(self) -> Any:
        """Return this object converted to a SQL statement."""
507
541
 
508
542
 
509
543
  @runtime_checkable
@@ -7,9 +7,9 @@ parameter binding and validation.
7
7
  """
8
8
 
9
9
  from sqlspec.exceptions import SQLBuilderError
10
- from sqlspec.statement.builder.base import QueryBuilder, SafeQuery
11
- from sqlspec.statement.builder.column import Column, ColumnExpression, FunctionColumn
12
- from sqlspec.statement.builder.ddl import (
10
+ from sqlspec.statement.builder._base import QueryBuilder, SafeQuery
11
+ from sqlspec.statement.builder._column import Column, ColumnExpression, FunctionColumn
12
+ from sqlspec.statement.builder._ddl import (
13
13
  AlterTable,
14
14
  CommentOn,
15
15
  CreateIndex,
@@ -26,12 +26,12 @@ from sqlspec.statement.builder.ddl import (
26
26
  RenameTable,
27
27
  TruncateTable,
28
28
  )
29
- from sqlspec.statement.builder.delete import Delete
30
- from sqlspec.statement.builder.insert import Insert
31
- from sqlspec.statement.builder.merge import Merge
29
+ from sqlspec.statement.builder._delete import Delete
30
+ from sqlspec.statement.builder._insert import Insert
31
+ from sqlspec.statement.builder._merge import Merge
32
+ from sqlspec.statement.builder._select import Select
33
+ from sqlspec.statement.builder._update import Update
32
34
  from sqlspec.statement.builder.mixins import WhereClauseMixin
33
- from sqlspec.statement.builder.select import Select
34
- from sqlspec.statement.builder.update import Update
35
35
 
36
36
  __all__ = (
37
37
  "AlterTable",