sqlspec 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (155)
  1. sqlspec/__init__.py +16 -3
  2. sqlspec/_serialization.py +3 -10
  3. sqlspec/_sql.py +1147 -0
  4. sqlspec/_typing.py +343 -41
  5. sqlspec/adapters/adbc/__init__.py +2 -6
  6. sqlspec/adapters/adbc/config.py +474 -149
  7. sqlspec/adapters/adbc/driver.py +330 -621
  8. sqlspec/adapters/aiosqlite/__init__.py +2 -6
  9. sqlspec/adapters/aiosqlite/config.py +143 -57
  10. sqlspec/adapters/aiosqlite/driver.py +269 -431
  11. sqlspec/adapters/asyncmy/__init__.py +3 -8
  12. sqlspec/adapters/asyncmy/config.py +247 -202
  13. sqlspec/adapters/asyncmy/driver.py +218 -436
  14. sqlspec/adapters/asyncpg/__init__.py +4 -7
  15. sqlspec/adapters/asyncpg/config.py +329 -176
  16. sqlspec/adapters/asyncpg/driver.py +417 -487
  17. sqlspec/adapters/bigquery/__init__.py +2 -2
  18. sqlspec/adapters/bigquery/config.py +407 -0
  19. sqlspec/adapters/bigquery/driver.py +600 -553
  20. sqlspec/adapters/duckdb/__init__.py +4 -1
  21. sqlspec/adapters/duckdb/config.py +432 -321
  22. sqlspec/adapters/duckdb/driver.py +392 -406
  23. sqlspec/adapters/oracledb/__init__.py +3 -8
  24. sqlspec/adapters/oracledb/config.py +625 -0
  25. sqlspec/adapters/oracledb/driver.py +548 -921
  26. sqlspec/adapters/psqlpy/__init__.py +4 -7
  27. sqlspec/adapters/psqlpy/config.py +372 -203
  28. sqlspec/adapters/psqlpy/driver.py +197 -533
  29. sqlspec/adapters/psycopg/__init__.py +3 -8
  30. sqlspec/adapters/psycopg/config.py +741 -0
  31. sqlspec/adapters/psycopg/driver.py +734 -694
  32. sqlspec/adapters/sqlite/__init__.py +2 -6
  33. sqlspec/adapters/sqlite/config.py +146 -81
  34. sqlspec/adapters/sqlite/driver.py +242 -405
  35. sqlspec/base.py +220 -784
  36. sqlspec/config.py +354 -0
  37. sqlspec/driver/__init__.py +22 -0
  38. sqlspec/driver/_async.py +252 -0
  39. sqlspec/driver/_common.py +338 -0
  40. sqlspec/driver/_sync.py +261 -0
  41. sqlspec/driver/mixins/__init__.py +17 -0
  42. sqlspec/driver/mixins/_pipeline.py +523 -0
  43. sqlspec/driver/mixins/_result_utils.py +122 -0
  44. sqlspec/driver/mixins/_sql_translator.py +35 -0
  45. sqlspec/driver/mixins/_storage.py +993 -0
  46. sqlspec/driver/mixins/_type_coercion.py +131 -0
  47. sqlspec/exceptions.py +299 -7
  48. sqlspec/extensions/aiosql/__init__.py +10 -0
  49. sqlspec/extensions/aiosql/adapter.py +474 -0
  50. sqlspec/extensions/litestar/__init__.py +1 -6
  51. sqlspec/extensions/litestar/_utils.py +1 -5
  52. sqlspec/extensions/litestar/config.py +5 -6
  53. sqlspec/extensions/litestar/handlers.py +13 -12
  54. sqlspec/extensions/litestar/plugin.py +22 -24
  55. sqlspec/extensions/litestar/providers.py +37 -55
  56. sqlspec/loader.py +528 -0
  57. sqlspec/service/__init__.py +3 -0
  58. sqlspec/service/base.py +24 -0
  59. sqlspec/service/pagination.py +26 -0
  60. sqlspec/statement/__init__.py +21 -0
  61. sqlspec/statement/builder/__init__.py +54 -0
  62. sqlspec/statement/builder/_ddl_utils.py +119 -0
  63. sqlspec/statement/builder/_parsing_utils.py +135 -0
  64. sqlspec/statement/builder/base.py +328 -0
  65. sqlspec/statement/builder/ddl.py +1379 -0
  66. sqlspec/statement/builder/delete.py +80 -0
  67. sqlspec/statement/builder/insert.py +274 -0
  68. sqlspec/statement/builder/merge.py +95 -0
  69. sqlspec/statement/builder/mixins/__init__.py +65 -0
  70. sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
  71. sqlspec/statement/builder/mixins/_case_builder.py +91 -0
  72. sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
  73. sqlspec/statement/builder/mixins/_delete_from.py +34 -0
  74. sqlspec/statement/builder/mixins/_from.py +61 -0
  75. sqlspec/statement/builder/mixins/_group_by.py +119 -0
  76. sqlspec/statement/builder/mixins/_having.py +35 -0
  77. sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
  78. sqlspec/statement/builder/mixins/_insert_into.py +36 -0
  79. sqlspec/statement/builder/mixins/_insert_values.py +69 -0
  80. sqlspec/statement/builder/mixins/_join.py +110 -0
  81. sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
  82. sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
  83. sqlspec/statement/builder/mixins/_order_by.py +46 -0
  84. sqlspec/statement/builder/mixins/_pivot.py +82 -0
  85. sqlspec/statement/builder/mixins/_returning.py +37 -0
  86. sqlspec/statement/builder/mixins/_select_columns.py +60 -0
  87. sqlspec/statement/builder/mixins/_set_ops.py +122 -0
  88. sqlspec/statement/builder/mixins/_unpivot.py +80 -0
  89. sqlspec/statement/builder/mixins/_update_from.py +54 -0
  90. sqlspec/statement/builder/mixins/_update_set.py +91 -0
  91. sqlspec/statement/builder/mixins/_update_table.py +29 -0
  92. sqlspec/statement/builder/mixins/_where.py +374 -0
  93. sqlspec/statement/builder/mixins/_window_functions.py +86 -0
  94. sqlspec/statement/builder/protocols.py +20 -0
  95. sqlspec/statement/builder/select.py +206 -0
  96. sqlspec/statement/builder/update.py +178 -0
  97. sqlspec/statement/filters.py +571 -0
  98. sqlspec/statement/parameters.py +736 -0
  99. sqlspec/statement/pipelines/__init__.py +67 -0
  100. sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
  101. sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
  102. sqlspec/statement/pipelines/base.py +315 -0
  103. sqlspec/statement/pipelines/context.py +119 -0
  104. sqlspec/statement/pipelines/result_types.py +41 -0
  105. sqlspec/statement/pipelines/transformers/__init__.py +8 -0
  106. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
  107. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
  108. sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
  109. sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
  110. sqlspec/statement/pipelines/validators/__init__.py +23 -0
  111. sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
  112. sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
  113. sqlspec/statement/pipelines/validators/_performance.py +703 -0
  114. sqlspec/statement/pipelines/validators/_security.py +990 -0
  115. sqlspec/statement/pipelines/validators/base.py +67 -0
  116. sqlspec/statement/result.py +527 -0
  117. sqlspec/statement/splitter.py +701 -0
  118. sqlspec/statement/sql.py +1198 -0
  119. sqlspec/storage/__init__.py +15 -0
  120. sqlspec/storage/backends/__init__.py +0 -0
  121. sqlspec/storage/backends/base.py +166 -0
  122. sqlspec/storage/backends/fsspec.py +315 -0
  123. sqlspec/storage/backends/obstore.py +464 -0
  124. sqlspec/storage/protocol.py +170 -0
  125. sqlspec/storage/registry.py +315 -0
  126. sqlspec/typing.py +157 -36
  127. sqlspec/utils/correlation.py +155 -0
  128. sqlspec/utils/deprecation.py +3 -6
  129. sqlspec/utils/fixtures.py +6 -11
  130. sqlspec/utils/logging.py +135 -0
  131. sqlspec/utils/module_loader.py +45 -43
  132. sqlspec/utils/serializers.py +4 -0
  133. sqlspec/utils/singleton.py +6 -8
  134. sqlspec/utils/sync_tools.py +15 -27
  135. sqlspec/utils/text.py +58 -26
  136. {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/METADATA +97 -26
  137. sqlspec-0.12.0.dist-info/RECORD +145 -0
  138. sqlspec/adapters/bigquery/config/__init__.py +0 -3
  139. sqlspec/adapters/bigquery/config/_common.py +0 -40
  140. sqlspec/adapters/bigquery/config/_sync.py +0 -87
  141. sqlspec/adapters/oracledb/config/__init__.py +0 -9
  142. sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
  143. sqlspec/adapters/oracledb/config/_common.py +0 -131
  144. sqlspec/adapters/oracledb/config/_sync.py +0 -186
  145. sqlspec/adapters/psycopg/config/__init__.py +0 -19
  146. sqlspec/adapters/psycopg/config/_async.py +0 -169
  147. sqlspec/adapters/psycopg/config/_common.py +0 -56
  148. sqlspec/adapters/psycopg/config/_sync.py +0 -168
  149. sqlspec/filters.py +0 -331
  150. sqlspec/mixins.py +0 -305
  151. sqlspec/statement.py +0 -378
  152. sqlspec-0.11.1.dist-info/RECORD +0 -69
  153. {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
  154. {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
  155. {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,701 @@
1
+ """SQL script statement splitter with dialect-aware lexer-driven state machine.
2
+
3
+ This module provides a robust way to split SQL scripts into individual statements,
4
+ handling complex constructs like PL/SQL blocks, T-SQL batches, and nested blocks.
5
+ """
6
+
7
+ import re
8
+ from abc import ABC, abstractmethod
9
+ from collections.abc import Generator
10
+ from dataclasses import dataclass
11
+ from enum import Enum
12
+ from re import Pattern
13
+ from typing import Callable, Optional, Union
14
+
15
+ from typing_extensions import TypeAlias
16
+
17
+ from sqlspec.utils.logging import get_logger
18
+
19
# Public API of this module. All dialect configuration classes defined here
# are exported so callers can pass or subclass them directly; previously the
# Generic/MySQL/SQLite/DuckDB/BigQuery configs were defined but not exported.
__all__ = (
    "BigQueryDialectConfig",
    "DialectConfig",
    "DuckDBDialectConfig",
    "GenericDialectConfig",
    "MySQLDialectConfig",
    "OracleDialectConfig",
    "PostgreSQLDialectConfig",
    "SQLiteDialectConfig",
    "StatementSplitter",
    "TSQLDialectConfig",
    "Token",
    "TokenType",
    "split_sql_script",
)
29
+
30
+
31
+ logger = get_logger("sqlspec")
32
+
33
+
34
class TokenType(Enum):
    """Types of tokens recognized by the SQL lexer."""

    COMMENT_LINE = "COMMENT_LINE"  # single-line comment: -- ... to end of line
    COMMENT_BLOCK = "COMMENT_BLOCK"  # block comment: /* ... */
    STRING_LITERAL = "STRING_LITERAL"  # '...' (and dialect forms, e.g. $tag$...$tag$)
    QUOTED_IDENTIFIER = "QUOTED_IDENTIFIER"  # "ident" or T-SQL [ident]
    KEYWORD = "KEYWORD"  # dialect block starters/enders and batch separators only
    TERMINATOR = "TERMINATOR"  # statement terminators such as ; or Oracle /
    BATCH_SEPARATOR = "BATCH_SEPARATOR"  # reserved for batch separators (e.g. GO)
    WHITESPACE = "WHITESPACE"  # runs of spaces/tabs/newlines
    OTHER = "OTHER"  # catch-all: any single character not matched above
46
+
47
+
48
@dataclass
class Token:
    """Represents a single token in the SQL script."""

    type: TokenType  # lexical category of this token
    value: str  # exact source text the token covers
    line: int  # 1-based line number where the token starts
    column: int  # 1-based column within that line
    position: int  # Absolute position in the script
57
+
58
+
59
# A handler is invoked as handler(sql, position, line, column) and returns a
# Token starting at ``position``, or None when the text there does not match.
TokenHandler: TypeAlias = Callable[[str, int, int, int], Optional[Token]]
# A lexer pattern is either a regex source string or a custom handler callable.
TokenPattern: TypeAlias = Union[str, TokenHandler]
# After compilation, regex sources become Pattern objects; handlers pass through.
CompiledTokenPattern: TypeAlias = Union[Pattern[str], TokenHandler]
62
+
63
+
64
class DialectConfig(ABC):
    """Abstract base class for SQL dialect configurations.

    Subclasses declare the keywords and terminators that drive the splitter's
    state machine; :meth:`get_all_token_patterns` assembles them into an
    ordered list of lexer patterns (earlier patterns win).
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the dialect (e.g., 'oracle', 'tsql')."""

    @property
    @abstractmethod
    def block_starters(self) -> set[str]:
        """Keywords that start a block (e.g., BEGIN, DECLARE)."""

    @property
    @abstractmethod
    def block_enders(self) -> set[str]:
        """Keywords that end a block (e.g., END)."""

    @property
    @abstractmethod
    def statement_terminators(self) -> set[str]:
        """Characters that terminate statements (e.g., ;)."""

    @property
    def batch_separators(self) -> set[str]:
        """Keywords that separate batches (e.g., GO for T-SQL)."""
        return set()

    @property
    def special_terminators(self) -> dict[str, Callable[[list[Token], int], bool]]:
        """Special terminators that need custom handling.

        Maps terminator text to a validator ``(tokens, index) -> bool``.
        """
        return {}

    @property
    def max_nesting_depth(self) -> int:
        """Maximum allowed nesting depth for blocks."""
        return 256

    def get_all_token_patterns(self) -> list[tuple[TokenType, TokenPattern]]:
        """Assembles the complete, ordered list of token regex patterns.

        Returns:
            Ordered (token_type, pattern) pairs; the lexer tries them in order.
        """
        # 1. Start with high-precedence patterns
        patterns: list[tuple[TokenType, TokenPattern]] = [
            (TokenType.COMMENT_LINE, r"--[^\n]*"),
            (TokenType.COMMENT_BLOCK, r"/\*[\s\S]*?\*/"),
            (TokenType.STRING_LITERAL, r"'(?:[^']|'')*'"),
            (TokenType.QUOTED_IDENTIFIER, r'"[^"]*"|\[[^\]]*\]'),  # Standard and T-SQL
        ]

        # 2. Add dialect-specific patterns (can be overridden)
        patterns.extend(self._get_dialect_specific_patterns())

        # 3. Dynamically build and insert keyword/separator patterns
        all_keywords = self.block_starters | self.block_enders | self.batch_separators
        if all_keywords:
            # Regex alternation is first-match, so longer keywords must come
            # first; the name tie-break makes the pattern deterministic
            # regardless of set iteration order.
            sorted_keywords = sorted(all_keywords, key=lambda kw: (-len(kw), kw))
            patterns.append((TokenType.KEYWORD, r"\b(" + "|".join(re.escape(kw) for kw in sorted_keywords) + r")\b"))

        # 4. Add terminators
        all_terminators = self.statement_terminators | set(self.special_terminators.keys())
        if all_terminators:
            # Sort longest-first for the same reason as keywords: a short
            # terminator sharing a prefix with a longer one (e.g. "\g" vs a
            # hypothetical "\gx") must not shadow it, and joining a raw set
            # would make the alternation order nondeterministic.
            sorted_terminators = sorted(all_terminators, key=lambda t: (-len(t), t))
            patterns.append((TokenType.TERMINATOR, "|".join(re.escape(t) for t in sorted_terminators)))

        # 5. Add low-precedence patterns
        patterns.extend([(TokenType.WHITESPACE, r"\s+"), (TokenType.OTHER, r".")])

        return patterns

    def _get_dialect_specific_patterns(self) -> list[tuple[TokenType, TokenPattern]]:
        """Override to add dialect-specific token patterns."""
        return []

    @staticmethod
    def is_real_block_ender(tokens: list[Token], current_pos: int) -> bool:
        """Check if this END keyword is actually a block ender.

        Override in dialect configs to handle cases like END IF, END LOOP, etc.
        that are not true block enders.
        """
        _ = tokens, current_pos  # Default implementation doesn't use these
        return True

    def should_delay_semicolon_termination(self, tokens: list[Token], current_pos: int) -> bool:
        """Check if semicolon termination should be delayed.

        Override in dialect configs to handle special cases like Oracle END; /
        """
        _ = tokens, current_pos  # Default implementation doesn't use these
        return False
154
+
155
+
156
class OracleDialectConfig(DialectConfig):
    """Configuration for Oracle PL/SQL dialect.

    Supports PL/SQL blocks (BEGIN/DECLARE ... END;) and the SQL*Plus-style
    ``/`` terminator, which is only honored when it stands alone on a line.
    """

    @property
    def name(self) -> str:
        return "oracle"

    @property
    def block_starters(self) -> set[str]:
        return {"BEGIN", "DECLARE", "CASE"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}

    @property
    def special_terminators(self) -> dict[str, Callable[[list[Token], int], bool]]:
        # "/" is only a terminator when _handle_slash_terminator approves it.
        return {"/": self._handle_slash_terminator}

    def should_delay_semicolon_termination(self, tokens: list[Token], current_pos: int) -> bool:
        """Check if we should delay semicolon termination to look for a slash.

        In Oracle, after END; we should check if there's a / coming up on its own line.

        Args:
            tokens: Full token list for the script being split.
            current_pos: Index of the semicolon token under consideration.

        Returns:
            True to delay termination (the upcoming / will terminate instead).
        """
        # Look backwards to see if we just processed an END token
        pos = current_pos - 1
        while pos >= 0:
            token = tokens[pos]
            if token.type == TokenType.WHITESPACE:
                pos -= 1
                continue
            if token.type == TokenType.KEYWORD and token.value.upper() == "END":
                # We found END just before this semicolon
                # Now look ahead to see if there's a / on its own line
                return self._has_upcoming_slash(tokens, current_pos)
            # Found something else, not an END
            break

        return False

    def _has_upcoming_slash(self, tokens: list[Token], current_pos: int) -> bool:
        """Check if there's a / terminator coming up on its own line."""
        pos = current_pos + 1
        # Tracks whether at least one newline separates the ; from the /.
        found_newline = False

        while pos < len(tokens):
            token = tokens[pos]
            if token.type == TokenType.WHITESPACE:
                if "\n" in token.value:
                    found_newline = True
                pos += 1
                continue
            if token.type == TokenType.TERMINATOR and token.value == "/":
                # Found a /, check if it's valid (on its own line)
                return found_newline and self._handle_slash_terminator(tokens, pos)
            if token.type in {TokenType.COMMENT_LINE, TokenType.COMMENT_BLOCK}:
                # Skip comments
                pos += 1
                continue
            # Found non-whitespace, non-comment content
            break

        return False

    @staticmethod
    def is_real_block_ender(tokens: list[Token], current_pos: int) -> bool:
        """Check if this END keyword is actually a block ender.

        In Oracle PL/SQL, END followed by IF, LOOP, CASE etc. are not block enders
        for BEGIN blocks - they terminate control structures.
        """
        # Look ahead for the next non-whitespace token(s)
        pos = current_pos + 1
        while pos < len(tokens):
            next_token = tokens[pos]

            if next_token.type == TokenType.WHITESPACE:
                pos += 1
                continue
            if next_token.type == TokenType.OTHER:
                # Collect consecutive OTHER tokens to form a word. Words like
                # IF/LOOP/WHILE are lexed as single-character OTHER tokens
                # because only dialect keywords get KEYWORD tokens.
                # NOTE(review): "CASE" can never be matched here since CASE is
                # a block starter and is lexed as a KEYWORD — confirm intended.
                word_chars = []
                word_pos = pos
                while word_pos < len(tokens) and tokens[word_pos].type == TokenType.OTHER:
                    word_chars.append(tokens[word_pos].value)
                    word_pos += 1

                word = "".join(word_chars).upper()
                if word in {"IF", "LOOP", "CASE", "WHILE"}:
                    return False  # This is not a block ender
            # Found a non-whitespace token that's not one of our special cases
            break
        return True  # This is a real block ender

    @staticmethod
    def _handle_slash_terminator(tokens: list[Token], current_pos: int) -> bool:
        """Oracle / must be on its own line after whitespace only."""
        if current_pos == 0:
            return True  # / at start is valid

        # Look backwards to find the start of the line
        pos = current_pos - 1
        while pos >= 0:
            token = tokens[pos]
            if "\n" in token.value:
                # Found newline, check if only whitespace between newline and /
                break
            if token.type not in {TokenType.WHITESPACE, TokenType.COMMENT_LINE}:
                return False  # Non-whitespace before / on same line
            pos -= 1

        return True
272
+
273
+
274
class TSQLDialectConfig(DialectConfig):
    """Dialect configuration for Microsoft T-SQL (SQL Server) scripts."""

    @property
    def name(self) -> str:
        return "tsql"

    @property
    def block_starters(self) -> set[str]:
        # TRY opens BEGIN TRY ... END TRY sections alongside plain BEGIN.
        return {"TRY", "BEGIN"}

    @property
    def block_enders(self) -> set[str]:
        return {"CATCH", "END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}

    @property
    def batch_separators(self) -> set[str]:
        # GO is a client-side batch separator rather than a T-SQL statement.
        return {"GO"}

    @staticmethod
    def validate_batch_separator(tokens: list[Token], current_pos: int) -> bool:
        """GO must be the only keyword on its line.

        NOTE(review): currently a stub that accepts every GO; a complete
        implementation would mirror the Oracle slash handler's line check.
        """
        _ = tokens, current_pos  # Simplified implementation
        return True  # Simplified for now
304
+
305
+
306
class PostgreSQLDialectConfig(DialectConfig):
    """PostgreSQL dialect: standard blocks plus dollar-quoted string literals."""

    @property
    def name(self) -> str:
        return "postgresql"

    @property
    def block_starters(self) -> set[str]:
        # DO introduces anonymous code blocks in PostgreSQL.
        return {"DO", "BEGIN", "DECLARE", "CASE"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}

    def _get_dialect_specific_patterns(self) -> list[tuple[TokenType, TokenPattern]]:
        """Register the custom lexer for $tag$ ... $tag$ literals."""
        return [(TokenType.STRING_LITERAL, self._handle_dollar_quoted_string)]

    @staticmethod
    def _handle_dollar_quoted_string(text: str, position: int, line: int, column: int) -> Optional[Token]:
        """Lex a dollar-quoted string such as ``$tag$...$tag$`` at ``position``.

        Args:
            text: The full SQL script.
            position: Absolute offset where the candidate literal starts.
            line: 1-based line number of that offset.
            column: 1-based column of that offset.

        Returns:
            A STRING_LITERAL token covering the whole literal, or None when
            the text here is not a complete dollar-quoted literal.
        """
        opening = re.match(r"\$([a-zA-Z_][a-zA-Z0-9_]*)?\$", text[position:])
        if opening is None:
            return None

        delimiter = opening.group(0)  # the full opening tag, e.g. "$tag$" or "$$"
        body_start = position + len(delimiter)

        # The literal ends at the next occurrence of the identical delimiter.
        close_at = text.find(delimiter, body_start)
        if close_at == -1:
            # Unterminated: decline so other patterns may consume the text.
            return None

        literal = text[position : close_at + len(delimiter)]
        return Token(type=TokenType.STRING_LITERAL, value=literal, line=line, column=column, position=position)
349
+
350
+
351
class GenericDialectConfig(DialectConfig):
    """Fallback configuration covering plain ANSI-style SQL scripts."""

    @property
    def name(self) -> str:
        return "generic"

    @property
    def block_starters(self) -> set[str]:
        # Standard compound-statement openers shared by most dialects.
        return {"CASE", "BEGIN", "DECLARE"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        # Semicolon only; no dialect-specific terminators.
        return {";"}
369
+
370
+
371
class MySQLDialectConfig(DialectConfig):
    """Dialect configuration for MySQL scripts."""

    @property
    def name(self) -> str:
        return "mysql"

    @property
    def block_starters(self) -> set[str]:
        return {"CASE", "BEGIN", "DECLARE"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}

    @property
    def special_terminators(self) -> dict[str, Callable[[list[Token], int], bool]]:
        """MySQL supports DELIMITER command for changing terminators."""

        # The client-side \g and \G terminators are accepted wherever they
        # appear, so both map to the same always-true validator.
        def _always_valid(_tokens: list[Token], _pos: int) -> bool:
            return True

        return {"\\g": _always_valid, "\\G": _always_valid}
394
+
395
+
396
class SQLiteDialectConfig(DialectConfig):
    """Dialect configuration for SQLite scripts."""

    @property
    def name(self) -> str:
        return "sqlite"

    @property
    def block_starters(self) -> set[str]:
        # SQLite has limited block support: BEGIN (triggers/transactions)
        # and CASE expressions only — no DECLARE.
        return {"CASE", "BEGIN"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}
415
+
416
+
417
class DuckDBDialectConfig(DialectConfig):
    """Dialect configuration for DuckDB scripts."""

    @property
    def name(self) -> str:
        return "duckdb"

    @property
    def block_starters(self) -> set[str]:
        # BEGIN (transactions) and CASE expressions.
        return {"CASE", "BEGIN"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}
435
+
436
+
437
class BigQueryDialectConfig(DialectConfig):
    """Dialect configuration for Google BigQuery scripts."""

    @property
    def name(self) -> str:
        return "bigquery"

    @property
    def block_starters(self) -> set[str]:
        # BEGIN (scripting blocks) and CASE expressions.
        return {"CASE", "BEGIN"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}
455
+
456
+
457
class StatementSplitter:
    """Splits SQL scripts into individual statements using a lexer-driven state machine."""

    def __init__(self, dialect: DialectConfig, strip_trailing_semicolon: bool = False) -> None:
        """Initialize the splitter with a specific dialect configuration.

        Args:
            dialect: The dialect configuration to use
            strip_trailing_semicolon: If True, remove trailing semicolons from statements
        """
        self.dialect = dialect
        self.strip_trailing_semicolon = strip_trailing_semicolon
        self.token_patterns = dialect.get_all_token_patterns()
        self._compiled_patterns = self._compile_patterns()

    def _compile_patterns(self) -> list[tuple[TokenType, CompiledTokenPattern]]:
        """Compile regex patterns for efficiency; handler callables pass through unchanged."""
        compiled: list[tuple[TokenType, CompiledTokenPattern]] = []
        for token_type, pattern in self.token_patterns:
            if isinstance(pattern, str):
                compiled.append((token_type, re.compile(pattern, re.IGNORECASE | re.DOTALL)))
            else:
                # It's a callable handler; used as-is.
                compiled.append((token_type, pattern))
        return compiled

    def _tokenize(self, sql: str) -> Generator[Token, None, None]:
        """Tokenize the SQL script into a stream of tokens.

        Args:
            sql: The SQL script to tokenize

        Yields:
            Token objects representing the recognized tokens in the script.
        """
        pos = 0
        line = 1
        line_start = 0  # absolute offset of the first character of the current line

        while pos < len(sql):
            matched = False

            # Patterns are tried in precedence order; the first match wins.
            for token_type, pattern in self._compiled_patterns:
                if callable(pattern):
                    # Call the handler function
                    column = pos - line_start + 1
                    token = pattern(sql, pos, line, column)
                    if token:
                        # Update position tracking
                        newlines = token.value.count("\n")
                        if newlines > 0:
                            line += newlines
                            last_newline = token.value.rfind("\n")
                            line_start = pos + last_newline + 1

                        yield token
                        pos += len(token.value)
                        matched = True
                        break
                else:
                    # Use regex
                    match = pattern.match(sql, pos)
                    if match:
                        value = match.group(0)
                        # Record where the token *starts* before advancing the
                        # line counter. Previously ``line`` was incremented
                        # first, so multi-line tokens (block comments,
                        # dollar-quoted strings) reported their ending line
                        # while ``column`` referred to their start, and the
                        # handler branch above already reports the start line.
                        token_line = line
                        column = pos - line_start + 1

                        # Update line tracking
                        newlines = value.count("\n")
                        if newlines > 0:
                            line += newlines
                            last_newline = value.rfind("\n")
                            line_start = pos + last_newline + 1

                        yield Token(type=token_type, value=value, line=token_line, column=column, position=pos)
                        pos = match.end()
                        matched = True
                        break

            if not matched:
                # This should never happen with our catch-all OTHER pattern
                logger.error("Failed to tokenize at position %d: %s", pos, sql[pos : pos + 20])
                pos += 1  # Skip the problematic character

    def split(self, sql: str) -> list[str]:
        """Split the SQL script into individual statements.

        Args:
            sql: The SQL script to split.

        Returns:
            The individual statements, with original text (including comments
            and whitespace) preserved inside each statement.

        Raises:
            ValueError: If block nesting exceeds the dialect's maximum depth.
        """
        statements: list[str] = []
        current_statement_tokens: list[Token] = []
        current_statement_chars: list[str] = []
        block_stack: list[str] = []  # stack of open block-starter keywords

        # Convert token generator to list so we can look ahead
        all_tokens = list(self._tokenize(sql))

        for token_idx, token in enumerate(all_tokens):
            # Always accumulate the original text
            current_statement_chars.append(token.value)

            # Whitespace and comments are kept in the output but never drive
            # the state machine.
            if token.type in {TokenType.WHITESPACE, TokenType.COMMENT_LINE, TokenType.COMMENT_BLOCK}:
                current_statement_tokens.append(token)
                continue

            current_statement_tokens.append(token)
            token_upper = token.value.upper()

            # Update block nesting
            if token.type == TokenType.KEYWORD:
                if token_upper in self.dialect.block_starters:
                    block_stack.append(token_upper)
                    if len(block_stack) > self.dialect.max_nesting_depth:
                        msg = f"Maximum nesting depth ({self.dialect.max_nesting_depth}) exceeded"
                        raise ValueError(msg)
                elif token_upper in self.dialect.block_enders:
                    # Check if this is actually a block ender (not END IF, END LOOP, etc.)
                    if block_stack and self.dialect.is_real_block_ender(all_tokens, token_idx):
                        block_stack.pop()

            # Check for statement termination
            is_terminator = False
            if not block_stack:  # Only terminate when not inside a block
                if token.type == TokenType.TERMINATOR:
                    if token.value in self.dialect.statement_terminators:
                        # Check if we should delay this termination (e.g., for Oracle END; /)
                        should_delay = self.dialect.should_delay_semicolon_termination(all_tokens, token_idx)

                        # Also check if there's a batch separator coming up (for T-SQL GO)
                        if not should_delay and token.value == ";" and self.dialect.batch_separators:
                            # In dialects with batch separators, semicolons don't terminate
                            # statements - only batch separators do
                            should_delay = True

                        if not should_delay:
                            is_terminator = True
                    elif token.value in self.dialect.special_terminators:
                        # Call the handler to validate
                        handler = self.dialect.special_terminators[token.value]
                        if handler(all_tokens, token_idx):
                            is_terminator = True

                elif token.type == TokenType.KEYWORD and token_upper in self.dialect.batch_separators:
                    # Batch separators like GO should be included with the preceding statement
                    is_terminator = True

            if is_terminator:
                # Save the statement
                statement = "".join(current_statement_chars).strip()

                # Determine if this is a PL/SQL block
                is_plsql_block = self._is_plsql_block(current_statement_tokens)

                # Optionally strip the trailing terminator
                # For PL/SQL blocks, never strip the semicolon as it's syntactically required
                if (
                    self.strip_trailing_semicolon
                    and token.type == TokenType.TERMINATOR
                    and statement.endswith(token.value)
                    and not is_plsql_block
                ):
                    statement = statement[: -len(token.value)].rstrip()

                if statement and self._contains_executable_content(statement):
                    statements.append(statement)
                current_statement_tokens = []
                current_statement_chars = []

        # Handle any remaining content
        if current_statement_chars:
            statement = "".join(current_statement_chars).strip()
            if statement and self._contains_executable_content(statement):
                statements.append(statement)

        return statements

    @staticmethod
    def _is_plsql_block(tokens: list[Token]) -> bool:
        """Check if the token list represents a PL/SQL block.

        Args:
            tokens: List of tokens for the current statement

        Returns:
            True if this is a PL/SQL block (BEGIN...END or DECLARE...END)
        """
        # Find the first meaningful keyword token (skip whitespace and comments)
        for token in tokens:
            if token.type == TokenType.KEYWORD:
                return token.value.upper() in {"BEGIN", "DECLARE"}
        return False

    def _contains_executable_content(self, statement: str) -> bool:
        """Check if a statement contains actual executable content (not just comments/whitespace).

        Args:
            statement: The statement string to check

        Returns:
            True if the statement contains executable SQL, False if it's only comments/whitespace
        """
        # Any token that is neither whitespace nor a comment counts as
        # executable; any() short-circuits instead of materializing the list.
        return any(
            token.type not in {TokenType.WHITESPACE, TokenType.COMMENT_LINE, TokenType.COMMENT_BLOCK}
            for token in self._tokenize(statement)
        )
664
+
665
+
666
def split_sql_script(script: str, dialect: str = "generic", strip_trailing_semicolon: bool = False) -> list[str]:
    """Split a SQL script into statements using the appropriate dialect.

    Args:
        script: The SQL script to split
        dialect: The SQL dialect name ('oracle', 'tsql', 'postgresql', etc.)
        strip_trailing_semicolon: If True, remove trailing terminators from statements

    Returns:
        List of individual SQL statements
    """
    # Map names to config *classes* so only the selected dialect's config is
    # instantiated, instead of building every config object on each call.
    dialect_configs: dict[str, type[DialectConfig]] = {
        # Standard dialects
        "generic": GenericDialectConfig,
        # Major databases
        "oracle": OracleDialectConfig,
        "tsql": TSQLDialectConfig,
        "mssql": TSQLDialectConfig,  # Alias for tsql
        "sqlserver": TSQLDialectConfig,  # Alias for tsql
        "postgresql": PostgreSQLDialectConfig,
        "postgres": PostgreSQLDialectConfig,  # Common alias
        "mysql": MySQLDialectConfig,
        "sqlite": SQLiteDialectConfig,
        # Modern analytical databases
        "duckdb": DuckDBDialectConfig,
        "bigquery": BigQueryDialectConfig,
    }

    config_cls = dialect_configs.get(dialect.lower())
    if config_cls is None:
        # Fall back to generic config for unknown dialects
        logger.warning("Unknown dialect '%s', using generic SQL splitter", dialect)
        config_cls = GenericDialectConfig

    splitter = StatementSplitter(config_cls(), strip_trailing_semicolon=strip_trailing_semicolon)
    return splitter.split(script)