sqlspec 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of sqlspec might be problematic.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -621
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -431
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +218 -436
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +417 -487
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +600 -553
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +392 -406
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +548 -921
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -533
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +741 -0
- sqlspec/adapters/psycopg/driver.py +734 -694
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +242 -405
- sqlspec/base.py +220 -784
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/METADATA +97 -26
- sqlspec-0.12.0.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -331
- sqlspec/mixins.py +0 -305
- sqlspec/statement.py +0 -378
- sqlspec-0.11.1.dist-info/RECORD +0 -69
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/statement/splitter.py (new file)
@@ -0,0 +1,701 @@
"""SQL script statement splitter with dialect-aware lexer-driven state machine.

This module provides a robust way to split SQL scripts into individual statements,
handling complex constructs like PL/SQL blocks, T-SQL batches, and nested blocks.
"""

import re
from abc import ABC, abstractmethod
from collections.abc import Generator
from dataclasses import dataclass
from enum import Enum
from re import Pattern
from typing import Callable, Optional, Union

from typing_extensions import TypeAlias

from sqlspec.utils.logging import get_logger

__all__ = (
    "DialectConfig",
    "OracleDialectConfig",
    "PostgreSQLDialectConfig",
    "StatementSplitter",
    "TSQLDialectConfig",
    "Token",
    "TokenType",
    "split_sql_script",
)


logger = get_logger("sqlspec")


class TokenType(Enum):
    """Types of tokens recognized by the SQL lexer."""

    COMMENT_LINE = "COMMENT_LINE"
    COMMENT_BLOCK = "COMMENT_BLOCK"
    STRING_LITERAL = "STRING_LITERAL"
    QUOTED_IDENTIFIER = "QUOTED_IDENTIFIER"
    KEYWORD = "KEYWORD"
    TERMINATOR = "TERMINATOR"
    BATCH_SEPARATOR = "BATCH_SEPARATOR"
    WHITESPACE = "WHITESPACE"
    OTHER = "OTHER"


@dataclass
class Token:
    """Represents a single token in the SQL script."""

    type: TokenType
    value: str
    line: int
    column: int
    position: int  # Absolute position in the script


TokenHandler: TypeAlias = Callable[[str, int, int, int], Optional[Token]]
TokenPattern: TypeAlias = Union[str, TokenHandler]
CompiledTokenPattern: TypeAlias = Union[Pattern[str], TokenHandler]


class DialectConfig(ABC):
    """Abstract base class for SQL dialect configurations."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the dialect (e.g., 'oracle', 'tsql')."""

    @property
    @abstractmethod
    def block_starters(self) -> set[str]:
        """Keywords that start a block (e.g., BEGIN, DECLARE)."""

    @property
    @abstractmethod
    def block_enders(self) -> set[str]:
        """Keywords that end a block (e.g., END)."""

    @property
    @abstractmethod
    def statement_terminators(self) -> set[str]:
        """Characters that terminate statements (e.g., ;)."""

    @property
    def batch_separators(self) -> set[str]:
        """Keywords that separate batches (e.g., GO for T-SQL)."""
        return set()

    @property
    def special_terminators(self) -> dict[str, Callable[[list[Token], int], bool]]:
        """Special terminators that need custom handling."""
        return {}

    @property
    def max_nesting_depth(self) -> int:
        """Maximum allowed nesting depth for blocks."""
        return 256

    def get_all_token_patterns(self) -> list[tuple[TokenType, TokenPattern]]:
        """Assembles the complete, ordered list of token regex patterns."""
        # 1. Start with high-precedence patterns
        patterns: list[tuple[TokenType, TokenPattern]] = [
            (TokenType.COMMENT_LINE, r"--[^\n]*"),
            (TokenType.COMMENT_BLOCK, r"/\*[\s\S]*?\*/"),
            (TokenType.STRING_LITERAL, r"'(?:[^']|'')*'"),
            (TokenType.QUOTED_IDENTIFIER, r'"[^"]*"|\[[^\]]*\]'),  # Standard and T-SQL
        ]

        # 2. Add dialect-specific patterns (can be overridden)
        patterns.extend(self._get_dialect_specific_patterns())

        # 3. Dynamically build and insert keyword/separator patterns
        all_keywords = self.block_starters | self.block_enders | self.batch_separators
        if all_keywords:
            # Sort by length descending to match longer keywords first
            sorted_keywords = sorted(all_keywords, key=len, reverse=True)
            patterns.append((TokenType.KEYWORD, r"\b(" + "|".join(re.escape(kw) for kw in sorted_keywords) + r")\b"))

        # 4. Add terminators
        all_terminators = self.statement_terminators | set(self.special_terminators.keys())
        if all_terminators:
            # Escape special regex characters
            patterns.append((TokenType.TERMINATOR, "|".join(re.escape(t) for t in all_terminators)))

        # 5. Add low-precedence patterns
        patterns.extend([(TokenType.WHITESPACE, r"\s+"), (TokenType.OTHER, r".")])

        return patterns

    def _get_dialect_specific_patterns(self) -> list[tuple[TokenType, TokenPattern]]:
        """Override to add dialect-specific token patterns."""
        return []

    @staticmethod
    def is_real_block_ender(tokens: list[Token], current_pos: int) -> bool:
        """Check if this END keyword is actually a block ender.

        Override in dialect configs to handle cases like END IF, END LOOP, etc.
        that are not true block enders.
        """
        _ = tokens, current_pos  # Default implementation doesn't use these
        return True

    def should_delay_semicolon_termination(self, tokens: list[Token], current_pos: int) -> bool:
        """Check if semicolon termination should be delayed.

        Override in dialect configs to handle special cases like Oracle END; /
        """
        _ = tokens, current_pos  # Default implementation doesn't use these
        return False


class OracleDialectConfig(DialectConfig):
    """Configuration for Oracle PL/SQL dialect."""

    @property
    def name(self) -> str:
        return "oracle"

    @property
    def block_starters(self) -> set[str]:
        return {"BEGIN", "DECLARE", "CASE"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}

    @property
    def special_terminators(self) -> dict[str, Callable[[list[Token], int], bool]]:
        return {"/": self._handle_slash_terminator}

    def should_delay_semicolon_termination(self, tokens: list[Token], current_pos: int) -> bool:
        """Check if we should delay semicolon termination to look for a slash.

        In Oracle, after END; we should check if there's a / coming up on its own line.
        """
        # Look backwards to see if we just processed an END token
        pos = current_pos - 1
        while pos >= 0:
            token = tokens[pos]
            if token.type == TokenType.WHITESPACE:
                pos -= 1
                continue
            if token.type == TokenType.KEYWORD and token.value.upper() == "END":
                # We found END just before this semicolon
                # Now look ahead to see if there's a / on its own line
                return self._has_upcoming_slash(tokens, current_pos)
            # Found something else, not an END
            break

        return False

    def _has_upcoming_slash(self, tokens: list[Token], current_pos: int) -> bool:
        """Check if there's a / terminator coming up on its own line."""
        pos = current_pos + 1
        found_newline = False

        while pos < len(tokens):
            token = tokens[pos]
            if token.type == TokenType.WHITESPACE:
                if "\n" in token.value:
                    found_newline = True
                pos += 1
                continue
            if token.type == TokenType.TERMINATOR and token.value == "/":
                # Found a /, check if it's valid (on its own line)
                return found_newline and self._handle_slash_terminator(tokens, pos)
            if token.type in {TokenType.COMMENT_LINE, TokenType.COMMENT_BLOCK}:
                # Skip comments
                pos += 1
                continue
            # Found non-whitespace, non-comment content
            break

        return False

    @staticmethod
    def is_real_block_ender(tokens: list[Token], current_pos: int) -> bool:
        """Check if this END keyword is actually a block ender.

        In Oracle PL/SQL, END followed by IF, LOOP, CASE etc. are not block enders
        for BEGIN blocks - they terminate control structures.
        """
        # Look ahead for the next non-whitespace token(s)
        pos = current_pos + 1
        while pos < len(tokens):
            next_token = tokens[pos]

            if next_token.type == TokenType.WHITESPACE:
                pos += 1
                continue
            if next_token.type == TokenType.OTHER:
                # Collect consecutive OTHER tokens to form a word
                word_chars = []
                word_pos = pos
                while word_pos < len(tokens) and tokens[word_pos].type == TokenType.OTHER:
                    word_chars.append(tokens[word_pos].value)
                    word_pos += 1

                word = "".join(word_chars).upper()
                if word in {"IF", "LOOP", "CASE", "WHILE"}:
                    return False  # This is not a block ender
            # Found a non-whitespace token that's not one of our special cases
            break
        return True  # This is a real block ender

    @staticmethod
    def _handle_slash_terminator(tokens: list[Token], current_pos: int) -> bool:
        """Oracle / must be on its own line after whitespace only."""
        if current_pos == 0:
            return True  # / at start is valid

        # Look backwards to find the start of the line
        pos = current_pos - 1
        while pos >= 0:
            token = tokens[pos]
            if "\n" in token.value:
                # Found newline, check if only whitespace between newline and /
                break
            if token.type not in {TokenType.WHITESPACE, TokenType.COMMENT_LINE}:
                return False  # Non-whitespace before / on same line
            pos -= 1

        return True


class TSQLDialectConfig(DialectConfig):
    """Configuration for T-SQL (SQL Server) dialect."""

    @property
    def name(self) -> str:
        return "tsql"

    @property
    def block_starters(self) -> set[str]:
        return {"BEGIN", "TRY"}

    @property
    def block_enders(self) -> set[str]:
        return {"END", "CATCH"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}

    @property
    def batch_separators(self) -> set[str]:
        return {"GO"}

    @staticmethod
    def validate_batch_separator(tokens: list[Token], current_pos: int) -> bool:
        """GO must be the only keyword on its line."""
        # Look for non-whitespace tokens on the same line
        # Implementation similar to Oracle slash handler
        _ = tokens, current_pos  # Simplified implementation
        return True  # Simplified for now


class PostgreSQLDialectConfig(DialectConfig):
    """Configuration for PostgreSQL dialect with dollar-quoted strings."""

    @property
    def name(self) -> str:
        return "postgresql"

    @property
    def block_starters(self) -> set[str]:
        return {"BEGIN", "DECLARE", "CASE", "DO"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}

    def _get_dialect_specific_patterns(self) -> list[tuple[TokenType, TokenPattern]]:
        """Add PostgreSQL-specific patterns like dollar-quoted strings."""
        return [(TokenType.STRING_LITERAL, self._handle_dollar_quoted_string)]

    @staticmethod
    def _handle_dollar_quoted_string(text: str, position: int, line: int, column: int) -> Optional[Token]:
        """Handle PostgreSQL dollar-quoted strings like $tag$...$tag$."""
        # Match opening tag
        start_match = re.match(r"\$([a-zA-Z_][a-zA-Z0-9_]*)?\$", text[position:])
        if not start_match:
            return None

        tag = start_match.group(0)  # The full opening tag, e.g., "$tag$"
        content_start = position + len(tag)

        # Find the corresponding closing tag
        try:
            content_end = text.index(tag, content_start)
            full_value = text[position : content_end + len(tag)]

            return Token(type=TokenType.STRING_LITERAL, value=full_value, line=line, column=column, position=position)
        except ValueError:
            # Closing tag not found
            return None


class GenericDialectConfig(DialectConfig):
    """Generic SQL dialect configuration for standard SQL."""

    @property
    def name(self) -> str:
        return "generic"

    @property
    def block_starters(self) -> set[str]:
        return {"BEGIN", "DECLARE", "CASE"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}


class MySQLDialectConfig(DialectConfig):
    """Configuration for MySQL dialect."""

    @property
    def name(self) -> str:
        return "mysql"

    @property
    def block_starters(self) -> set[str]:
        return {"BEGIN", "DECLARE", "CASE"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}

    @property
    def special_terminators(self) -> dict[str, Callable[[list[Token], int], bool]]:
        """MySQL supports DELIMITER command for changing terminators."""
        return {"\\g": lambda _tokens, _pos: True, "\\G": lambda _tokens, _pos: True}


class SQLiteDialectConfig(DialectConfig):
    """Configuration for SQLite dialect."""

    @property
    def name(self) -> str:
        return "sqlite"

    @property
    def block_starters(self) -> set[str]:
        # SQLite has limited block support
        return {"BEGIN", "CASE"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}


class DuckDBDialectConfig(DialectConfig):
    """Configuration for DuckDB dialect."""

    @property
    def name(self) -> str:
        return "duckdb"

    @property
    def block_starters(self) -> set[str]:
        return {"BEGIN", "CASE"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}


class BigQueryDialectConfig(DialectConfig):
    """Configuration for BigQuery dialect."""

    @property
    def name(self) -> str:
        return "bigquery"

    @property
    def block_starters(self) -> set[str]:
        return {"BEGIN", "CASE"}

    @property
    def block_enders(self) -> set[str]:
        return {"END"}

    @property
    def statement_terminators(self) -> set[str]:
        return {";"}


class StatementSplitter:
    """Splits SQL scripts into individual statements using a lexer-driven state machine."""

    def __init__(self, dialect: DialectConfig, strip_trailing_semicolon: bool = False) -> None:
        """Initialize the splitter with a specific dialect configuration.

        Args:
            dialect: The dialect configuration to use
            strip_trailing_semicolon: If True, remove trailing semicolons from statements
        """
        self.dialect = dialect
        self.strip_trailing_semicolon = strip_trailing_semicolon
        self.token_patterns = dialect.get_all_token_patterns()
        self._compiled_patterns = self._compile_patterns()

    def _compile_patterns(self) -> list[tuple[TokenType, CompiledTokenPattern]]:
        """Compile regex patterns for efficiency."""
        compiled: list[tuple[TokenType, CompiledTokenPattern]] = []
        for token_type, pattern in self.token_patterns:
            if isinstance(pattern, str):
                compiled.append((token_type, re.compile(pattern, re.IGNORECASE | re.DOTALL)))
            else:
                # It's a callable
                compiled.append((token_type, pattern))
        return compiled

    def _tokenize(self, sql: str) -> Generator[Token, None, None]:
        """Tokenize the SQL script into a stream of tokens.

        sql: The SQL script to tokenize

        Yields:
            Token objects representing the recognized tokens in the script.

        """
        pos = 0
        line = 1
        line_start = 0

        while pos < len(sql):
            matched = False

            for token_type, pattern in self._compiled_patterns:
                if callable(pattern):
                    # Call the handler function
                    column = pos - line_start + 1
                    token = pattern(sql, pos, line, column)
                    if token:
                        # Update position tracking
                        newlines = token.value.count("\n")
                        if newlines > 0:
                            line += newlines
                            last_newline = token.value.rfind("\n")
                            line_start = pos + last_newline + 1

                        yield token
                        pos += len(token.value)
                        matched = True
                        break
                else:
                    # Use regex
                    match = pattern.match(sql, pos)
                    if match:
                        value = match.group(0)
                        column = pos - line_start + 1

                        # Update line tracking
                        newlines = value.count("\n")
                        if newlines > 0:
                            line += newlines
                            last_newline = value.rfind("\n")
                            line_start = pos + last_newline + 1

                        yield Token(type=token_type, value=value, line=line, column=column, position=pos)
                        pos = match.end()
                        matched = True
                        break

            if not matched:
                # This should never happen with our catch-all OTHER pattern
                logger.error("Failed to tokenize at position %d: %s", pos, sql[pos : pos + 20])
                pos += 1  # Skip the problematic character

    def split(self, sql: str) -> list[str]:
        """Split the SQL script into individual statements."""
        statements = []
        current_statement_tokens = []
        current_statement_chars = []
        block_stack = []

        # Convert token generator to list so we can look ahead
        all_tokens = list(self._tokenize(sql))

        for token_idx, token in enumerate(all_tokens):
            # Always accumulate the original text
            current_statement_chars.append(token.value)

            # Skip whitespace and comments for logic (but keep in output)
            if token.type in {TokenType.WHITESPACE, TokenType.COMMENT_LINE, TokenType.COMMENT_BLOCK}:
                current_statement_tokens.append(token)
                continue

            current_statement_tokens.append(token)
            token_upper = token.value.upper()

            # Update block nesting
            if token.type == TokenType.KEYWORD:
                if token_upper in self.dialect.block_starters:
                    block_stack.append(token_upper)
                    if len(block_stack) > self.dialect.max_nesting_depth:
                        msg = f"Maximum nesting depth ({self.dialect.max_nesting_depth}) exceeded"
                        raise ValueError(msg)
                elif token_upper in self.dialect.block_enders:
                    # Check if this is actually a block ender (not END IF, END LOOP, etc.)
                    if block_stack and self.dialect.is_real_block_ender(all_tokens, token_idx):
                        block_stack.pop()

            # Check for statement termination
            is_terminator = False
            if not block_stack:  # Only terminate when not inside a block
                if token.type == TokenType.TERMINATOR:
                    if token.value in self.dialect.statement_terminators:
                        # Check if we should delay this termination (e.g., for Oracle END; /)
                        should_delay = self.dialect.should_delay_semicolon_termination(all_tokens, token_idx)

                        # Also check if there's a batch separator coming up (for T-SQL GO)
                        if not should_delay and token.value == ";" and self.dialect.batch_separators:
                            # In dialects with batch separators, semicolons don't terminate
                            # statements - only batch separators do
                            should_delay = True

                        if not should_delay:
                            is_terminator = True
                    elif token.value in self.dialect.special_terminators:
                        # Call the handler to validate
                        handler = self.dialect.special_terminators[token.value]
                        if handler(all_tokens, token_idx):
                            is_terminator = True

                elif token.type == TokenType.KEYWORD and token_upper in self.dialect.batch_separators:
                    # Batch separators like GO should be included with the preceding statement
                    is_terminator = True

            if is_terminator:
                # Save the statement
                statement = "".join(current_statement_chars).strip()

                # Determine if this is a PL/SQL block
                is_plsql_block = self._is_plsql_block(current_statement_tokens)

                # Optionally strip the trailing terminator
                # For PL/SQL blocks, never strip the semicolon as it's syntactically required
                if (
                    self.strip_trailing_semicolon
                    and token.type == TokenType.TERMINATOR
                    and statement.endswith(token.value)
                    and not is_plsql_block
                ):
                    statement = statement[: -len(token.value)].rstrip()

                if statement and self._contains_executable_content(statement):
                    statements.append(statement)
                current_statement_tokens = []
                current_statement_chars = []

        # Handle any remaining content
        if current_statement_chars:
            statement = "".join(current_statement_chars).strip()
            if statement and self._contains_executable_content(statement):
                statements.append(statement)

        return statements

    @staticmethod
    def _is_plsql_block(tokens: list[Token]) -> bool:
        """Check if the token list represents a PL/SQL block.

        Args:
            tokens: List of tokens for the current statement

        Returns:
            True if this is a PL/SQL block (BEGIN...END or DECLARE...END)
        """
        # Find the first meaningful keyword token (skip whitespace and comments)
        for token in tokens:
            if token.type == TokenType.KEYWORD:
                return token.value.upper() in {"BEGIN", "DECLARE"}
        return False

    def _contains_executable_content(self, statement: str) -> bool:
        """Check if a statement contains actual executable content (not just comments/whitespace).

        Args:
            statement: The statement string to check

        Returns:
            True if the statement contains executable SQL, False if it's only comments/whitespace
        """
        # Tokenize the statement to check its content
        tokens = list(self._tokenize(statement))

        # Check if there are any non-comment, non-whitespace tokens
        for token in tokens:
            if token.type not in {TokenType.WHITESPACE, TokenType.COMMENT_LINE, TokenType.COMMENT_BLOCK}:
                return True

        return False


def split_sql_script(script: str, dialect: str = "generic", strip_trailing_semicolon: bool = False) -> list[str]:
    """Split a SQL script into statements using the appropriate dialect.

    Args:
        script: The SQL script to split
        dialect: The SQL dialect name ('oracle', 'tsql', 'postgresql', etc.)
        strip_trailing_semicolon: If True, remove trailing terminators from statements

    Returns:
        List of individual SQL statements
    """
    dialect_configs = {
        # Standard dialects
        "generic": GenericDialectConfig(),
        # Major databases
        "oracle": OracleDialectConfig(),
        "tsql": TSQLDialectConfig(),
        "mssql": TSQLDialectConfig(),  # Alias for tsql
        "sqlserver": TSQLDialectConfig(),  # Alias for tsql
        "postgresql": PostgreSQLDialectConfig(),
        "postgres": PostgreSQLDialectConfig(),  # Common alias
        "mysql": MySQLDialectConfig(),
        "sqlite": SQLiteDialectConfig(),
        # Modern analytical databases
        "duckdb": DuckDBDialectConfig(),
        "bigquery": BigQueryDialectConfig(),
    }

    config = dialect_configs.get(dialect.lower())
    if not config:
        # Fall back to generic config for unknown dialects
        logger.warning("Unknown dialect '%s', using generic SQL splitter", dialect)
        config = GenericDialectConfig()

    splitter = StatementSplitter(config, strip_trailing_semicolon=strip_trailing_semicolon)
    return splitter.split(script)
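For orientation, a minimal usage sketch of the new public entry point split_sql_script. The import path and signature come from the diff above; the SQL snippets and the expected results in the comments are illustrative, based on a reading of StatementSplitter.split:

from sqlspec.statement.splitter import split_sql_script

# Generic dialect: each semicolon outside a block ends a statement.
split_sql_script("SELECT 1; SELECT 2;")
# -> ["SELECT 1;", "SELECT 2;"]

# Oracle: semicolons inside BEGIN...END do not split the block, and the
# block's terminating semicolon is held until the slash on its own line,
# so the whole block (including the "/") comes back as one statement.
split_sql_script("BEGIN\n    NULL;\nEND;\n/\n", dialect="oracle")

# T-SQL: GO is the batch separator and plain semicolons do not split,
# so this yields two batches, each ending with its GO.
split_sql_script("CREATE TABLE t (id INT);\nGO\nSELECT 1;\nGO", dialect="tsql")

# PostgreSQL: the semicolon inside the dollar-quoted body is part of the
# string literal, so the CREATE FUNCTION arrives as a single statement.
split_sql_script(
    "CREATE FUNCTION f() RETURNS int AS $$ SELECT 1; $$ LANGUAGE sql;",
    dialect="postgres",
)

Passing strip_trailing_semicolon=True removes the final terminator from each returned statement, except for PL/SQL blocks, where the semicolon is syntactically required.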