sqlspec 0.14.1__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of sqlspec has been flagged as potentially problematic; consult the package registry's advisory page for details.

Files changed (159):
  1. sqlspec/__init__.py +50 -25
  2. sqlspec/__main__.py +1 -1
  3. sqlspec/__metadata__.py +1 -3
  4. sqlspec/_serialization.py +1 -2
  5. sqlspec/_sql.py +480 -121
  6. sqlspec/_typing.py +278 -142
  7. sqlspec/adapters/adbc/__init__.py +4 -3
  8. sqlspec/adapters/adbc/_types.py +12 -0
  9. sqlspec/adapters/adbc/config.py +115 -260
  10. sqlspec/adapters/adbc/driver.py +462 -367
  11. sqlspec/adapters/aiosqlite/__init__.py +18 -3
  12. sqlspec/adapters/aiosqlite/_types.py +13 -0
  13. sqlspec/adapters/aiosqlite/config.py +199 -129
  14. sqlspec/adapters/aiosqlite/driver.py +230 -269
  15. sqlspec/adapters/asyncmy/__init__.py +18 -3
  16. sqlspec/adapters/asyncmy/_types.py +12 -0
  17. sqlspec/adapters/asyncmy/config.py +80 -168
  18. sqlspec/adapters/asyncmy/driver.py +260 -225
  19. sqlspec/adapters/asyncpg/__init__.py +19 -4
  20. sqlspec/adapters/asyncpg/_types.py +17 -0
  21. sqlspec/adapters/asyncpg/config.py +82 -181
  22. sqlspec/adapters/asyncpg/driver.py +285 -383
  23. sqlspec/adapters/bigquery/__init__.py +17 -3
  24. sqlspec/adapters/bigquery/_types.py +12 -0
  25. sqlspec/adapters/bigquery/config.py +191 -258
  26. sqlspec/adapters/bigquery/driver.py +474 -646
  27. sqlspec/adapters/duckdb/__init__.py +14 -3
  28. sqlspec/adapters/duckdb/_types.py +12 -0
  29. sqlspec/adapters/duckdb/config.py +415 -351
  30. sqlspec/adapters/duckdb/driver.py +343 -413
  31. sqlspec/adapters/oracledb/__init__.py +19 -5
  32. sqlspec/adapters/oracledb/_types.py +14 -0
  33. sqlspec/adapters/oracledb/config.py +123 -379
  34. sqlspec/adapters/oracledb/driver.py +507 -560
  35. sqlspec/adapters/psqlpy/__init__.py +13 -3
  36. sqlspec/adapters/psqlpy/_types.py +11 -0
  37. sqlspec/adapters/psqlpy/config.py +93 -254
  38. sqlspec/adapters/psqlpy/driver.py +505 -234
  39. sqlspec/adapters/psycopg/__init__.py +19 -5
  40. sqlspec/adapters/psycopg/_types.py +17 -0
  41. sqlspec/adapters/psycopg/config.py +143 -403
  42. sqlspec/adapters/psycopg/driver.py +706 -872
  43. sqlspec/adapters/sqlite/__init__.py +14 -3
  44. sqlspec/adapters/sqlite/_types.py +11 -0
  45. sqlspec/adapters/sqlite/config.py +202 -118
  46. sqlspec/adapters/sqlite/driver.py +264 -303
  47. sqlspec/base.py +105 -9
  48. sqlspec/{statement/builder → builder}/__init__.py +12 -14
  49. sqlspec/{statement/builder → builder}/_base.py +120 -55
  50. sqlspec/{statement/builder → builder}/_column.py +17 -6
  51. sqlspec/{statement/builder → builder}/_ddl.py +46 -79
  52. sqlspec/{statement/builder → builder}/_ddl_utils.py +5 -10
  53. sqlspec/{statement/builder → builder}/_delete.py +6 -25
  54. sqlspec/{statement/builder → builder}/_insert.py +18 -65
  55. sqlspec/builder/_merge.py +56 -0
  56. sqlspec/{statement/builder → builder}/_parsing_utils.py +8 -11
  57. sqlspec/{statement/builder → builder}/_select.py +11 -56
  58. sqlspec/{statement/builder → builder}/_update.py +12 -18
  59. sqlspec/{statement/builder → builder}/mixins/__init__.py +10 -14
  60. sqlspec/{statement/builder → builder}/mixins/_cte_and_set_ops.py +48 -59
  61. sqlspec/{statement/builder → builder}/mixins/_insert_operations.py +34 -18
  62. sqlspec/{statement/builder → builder}/mixins/_join_operations.py +1 -3
  63. sqlspec/{statement/builder → builder}/mixins/_merge_operations.py +19 -9
  64. sqlspec/{statement/builder → builder}/mixins/_order_limit_operations.py +3 -3
  65. sqlspec/{statement/builder → builder}/mixins/_pivot_operations.py +4 -8
  66. sqlspec/{statement/builder → builder}/mixins/_select_operations.py +25 -38
  67. sqlspec/{statement/builder → builder}/mixins/_update_operations.py +15 -16
  68. sqlspec/{statement/builder → builder}/mixins/_where_clause.py +210 -137
  69. sqlspec/cli.py +4 -5
  70. sqlspec/config.py +180 -133
  71. sqlspec/core/__init__.py +63 -0
  72. sqlspec/core/cache.py +873 -0
  73. sqlspec/core/compiler.py +396 -0
  74. sqlspec/core/filters.py +830 -0
  75. sqlspec/core/hashing.py +310 -0
  76. sqlspec/core/parameters.py +1209 -0
  77. sqlspec/core/result.py +664 -0
  78. sqlspec/{statement → core}/splitter.py +321 -191
  79. sqlspec/core/statement.py +666 -0
  80. sqlspec/driver/__init__.py +7 -10
  81. sqlspec/driver/_async.py +387 -176
  82. sqlspec/driver/_common.py +527 -289
  83. sqlspec/driver/_sync.py +390 -172
  84. sqlspec/driver/mixins/__init__.py +2 -19
  85. sqlspec/driver/mixins/_result_tools.py +164 -0
  86. sqlspec/driver/mixins/_sql_translator.py +6 -3
  87. sqlspec/exceptions.py +5 -252
  88. sqlspec/extensions/aiosql/adapter.py +93 -96
  89. sqlspec/extensions/litestar/cli.py +1 -1
  90. sqlspec/extensions/litestar/config.py +0 -1
  91. sqlspec/extensions/litestar/handlers.py +15 -26
  92. sqlspec/extensions/litestar/plugin.py +18 -16
  93. sqlspec/extensions/litestar/providers.py +17 -52
  94. sqlspec/loader.py +424 -105
  95. sqlspec/migrations/__init__.py +12 -0
  96. sqlspec/migrations/base.py +92 -68
  97. sqlspec/migrations/commands.py +24 -106
  98. sqlspec/migrations/loaders.py +402 -0
  99. sqlspec/migrations/runner.py +49 -51
  100. sqlspec/migrations/tracker.py +31 -44
  101. sqlspec/migrations/utils.py +64 -24
  102. sqlspec/protocols.py +7 -183
  103. sqlspec/storage/__init__.py +1 -1
  104. sqlspec/storage/backends/base.py +37 -40
  105. sqlspec/storage/backends/fsspec.py +136 -112
  106. sqlspec/storage/backends/obstore.py +138 -160
  107. sqlspec/storage/capabilities.py +5 -4
  108. sqlspec/storage/registry.py +57 -106
  109. sqlspec/typing.py +136 -115
  110. sqlspec/utils/__init__.py +2 -3
  111. sqlspec/utils/correlation.py +0 -3
  112. sqlspec/utils/deprecation.py +6 -6
  113. sqlspec/utils/fixtures.py +6 -6
  114. sqlspec/utils/logging.py +0 -2
  115. sqlspec/utils/module_loader.py +7 -12
  116. sqlspec/utils/singleton.py +0 -1
  117. sqlspec/utils/sync_tools.py +17 -38
  118. sqlspec/utils/text.py +12 -51
  119. sqlspec/utils/type_guards.py +443 -232
  120. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/METADATA +7 -2
  121. sqlspec-0.16.0.dist-info/RECORD +134 -0
  122. sqlspec/adapters/adbc/transformers.py +0 -108
  123. sqlspec/driver/connection.py +0 -207
  124. sqlspec/driver/mixins/_cache.py +0 -114
  125. sqlspec/driver/mixins/_csv_writer.py +0 -91
  126. sqlspec/driver/mixins/_pipeline.py +0 -508
  127. sqlspec/driver/mixins/_query_tools.py +0 -796
  128. sqlspec/driver/mixins/_result_utils.py +0 -138
  129. sqlspec/driver/mixins/_storage.py +0 -912
  130. sqlspec/driver/mixins/_type_coercion.py +0 -128
  131. sqlspec/driver/parameters.py +0 -138
  132. sqlspec/statement/__init__.py +0 -21
  133. sqlspec/statement/builder/_merge.py +0 -95
  134. sqlspec/statement/cache.py +0 -50
  135. sqlspec/statement/filters.py +0 -625
  136. sqlspec/statement/parameters.py +0 -956
  137. sqlspec/statement/pipelines/__init__.py +0 -210
  138. sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
  139. sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
  140. sqlspec/statement/pipelines/context.py +0 -109
  141. sqlspec/statement/pipelines/transformers/__init__.py +0 -7
  142. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
  143. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
  144. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
  145. sqlspec/statement/pipelines/validators/__init__.py +0 -23
  146. sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
  147. sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
  148. sqlspec/statement/pipelines/validators/_performance.py +0 -714
  149. sqlspec/statement/pipelines/validators/_security.py +0 -967
  150. sqlspec/statement/result.py +0 -435
  151. sqlspec/statement/sql.py +0 -1774
  152. sqlspec/utils/cached_property.py +0 -25
  153. sqlspec/utils/statement_hashing.py +0 -203
  154. sqlspec-0.14.1.dist-info/RECORD +0 -145
  155. /sqlspec/{statement/builder → builder}/mixins/_delete_operations.py +0 -0
  156. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/WHEEL +0 -0
  157. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/entry_points.txt +0 -0
  158. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/licenses/LICENSE +0 -0
  159. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/licenses/NOTICE +0 -0
@@ -1,19 +1,35 @@
1
- """SQL script statement splitter with dialect-aware lexer-driven state machine.
2
-
3
- This module provides a robust way to split SQL scripts into individual statements,
4
- handling complex constructs like PL/SQL blocks, T-SQL batches, and nested blocks.
1
+ """SQL statement splitter with caching and dialect support.
2
+
3
+ This module provides a SQL script statement splitter with caching and
4
+ multiple dialect support.
5
+
6
+ Components:
7
+ - StatementSplitter: SQL splitter with caching
8
+ - DialectConfig: Dialect configuration system
9
+ - Token/TokenType: Tokenization system
10
+ - Caching: LRU caching for split results
11
+ - Pattern compilation caching
12
+
13
+ Features:
14
+ - Support for multiple SQL dialects (Oracle, T-SQL, PostgreSQL, MySQL, SQLite, DuckDB, BigQuery)
15
+ - Cached pattern compilation
16
+ - LRU caching for split results
17
+ - Optimized tokenization
18
+ - Complete preservation of split_sql_script function
5
19
  """
6
20
 
7
21
  import re
22
+ import threading
8
23
  from abc import ABC, abstractmethod
9
24
  from collections.abc import Generator
10
- from dataclasses import dataclass
11
25
  from enum import Enum
12
26
  from re import Pattern
13
- from typing import Callable, Optional, Union
27
+ from typing import Any, Callable, Optional, Union
14
28
 
29
+ from mypy_extensions import mypyc_attr
15
30
  from typing_extensions import TypeAlias
16
31
 
32
+ from sqlspec.core.cache import CacheKey, UnifiedCache
17
33
  from sqlspec.utils.logging import get_logger
18
34
 
19
35
  __all__ = (
@@ -27,8 +43,33 @@ __all__ = (
27
43
  "split_sql_script",
28
44
  )
29
45
 
46
+ logger = get_logger("sqlspec.core.splitter")
47
+
48
+ DEFAULT_PATTERN_CACHE_SIZE = 1000 # Compiled regex patterns
49
+ DEFAULT_RESULT_CACHE_SIZE = 5000 # Split results
50
+ DEFAULT_CACHE_TTL = 3600 # 1 hour TTL
30
51
 
31
- logger = get_logger("sqlspec")
52
+ DIALECT_CONFIG_SLOTS = (
53
+ "_block_starters",
54
+ "_block_enders",
55
+ "_statement_terminators",
56
+ "_batch_separators",
57
+ "_special_terminators",
58
+ "_max_nesting_depth",
59
+ "_name",
60
+ )
61
+
62
+ TOKEN_SLOTS = ("type", "value", "line", "column", "position")
63
+
64
+ SPLITTER_SLOTS = (
65
+ "_dialect",
66
+ "_strip_trailing_semicolon",
67
+ "_token_patterns",
68
+ "_compiled_patterns",
69
+ "_pattern_cache_key",
70
+ "_result_cache",
71
+ "_pattern_cache",
72
+ )
32
73
 
33
74
 
34
75
  class TokenType(Enum):
@@ -45,15 +86,21 @@ class TokenType(Enum):
45
86
  OTHER = "OTHER"
46
87
 
47
88
 
48
- @dataclass
89
+ @mypyc_attr(allow_interpreted_subclasses=True)
49
90
  class Token:
50
- """Represents a single token in the SQL script."""
91
+ """SQL token with metadata."""
92
+
93
+ __slots__ = TOKEN_SLOTS
51
94
 
52
- type: TokenType
53
- value: str
54
- line: int
55
- column: int
56
- position: int # Absolute position in the script
95
+ def __init__(self, type: TokenType, value: str, line: int, column: int, position: int) -> None:
96
+ self.type = type
97
+ self.value = value
98
+ self.line = line
99
+ self.column = column
100
+ self.position = position
101
+
102
+ def __repr__(self) -> str:
103
+ return f"Token({self.type.value}, {self.value!r}, {self.line}:{self.column})"
57
104
 
58
105
 
59
106
  TokenHandler: TypeAlias = Callable[[str, int, int, int], Optional[Token]]
@@ -61,9 +108,22 @@ TokenPattern: TypeAlias = Union[str, TokenHandler]
61
108
  CompiledTokenPattern: TypeAlias = Union[Pattern[str], TokenHandler]
62
109
 
63
110
 
111
+ @mypyc_attr(allow_interpreted_subclasses=True)
64
112
  class DialectConfig(ABC):
65
113
  """Abstract base class for SQL dialect configurations."""
66
114
 
115
+ __slots__ = DIALECT_CONFIG_SLOTS
116
+
117
+ def __init__(self) -> None:
118
+ """Initialize dialect configuration."""
119
+ self._name: Optional[str] = None
120
+ self._block_starters: Optional[set[str]] = None
121
+ self._block_enders: Optional[set[str]] = None
122
+ self._statement_terminators: Optional[set[str]] = None
123
+ self._batch_separators: Optional[set[str]] = None
124
+ self._special_terminators: Optional[dict[str, Callable[[list[Token], int], bool]]] = None
125
+ self._max_nesting_depth: Optional[int] = None
126
+
67
127
  @property
68
128
  @abstractmethod
69
129
  def name(self) -> str:
@@ -87,44 +147,44 @@ class DialectConfig(ABC):
87
147
  @property
88
148
  def batch_separators(self) -> set[str]:
89
149
  """Keywords that separate batches (e.g., GO for T-SQL)."""
90
- return set()
150
+ if self._batch_separators is None:
151
+ self._batch_separators = set()
152
+ return self._batch_separators
91
153
 
92
154
  @property
93
155
  def special_terminators(self) -> dict[str, Callable[[list[Token], int], bool]]:
94
156
  """Special terminators that need custom handling."""
95
- return {}
157
+ if self._special_terminators is None:
158
+ self._special_terminators = {}
159
+ return self._special_terminators
96
160
 
97
161
  @property
98
162
  def max_nesting_depth(self) -> int:
99
163
  """Maximum allowed nesting depth for blocks."""
100
- return 256
164
+ if self._max_nesting_depth is None:
165
+ self._max_nesting_depth = 256
166
+ return self._max_nesting_depth
101
167
 
102
168
  def get_all_token_patterns(self) -> list[tuple[TokenType, TokenPattern]]:
103
169
  """Assembles the complete, ordered list of token regex patterns."""
104
- # 1. Start with high-precedence patterns
105
170
  patterns: list[tuple[TokenType, TokenPattern]] = [
106
171
  (TokenType.COMMENT_LINE, r"--[^\n]*"),
107
172
  (TokenType.COMMENT_BLOCK, r"/\*[\s\S]*?\*/"),
108
173
  (TokenType.STRING_LITERAL, r"'(?:[^']|'')*'"),
109
- (TokenType.QUOTED_IDENTIFIER, r'"[^"]*"|\[[^\]]*\]'), # Standard and T-SQL
174
+ (TokenType.QUOTED_IDENTIFIER, r'"[^"]*"|\[[^\]]*\]'),
110
175
  ]
111
176
 
112
- # 2. Add dialect-specific patterns (can be overridden)
113
177
  patterns.extend(self._get_dialect_specific_patterns())
114
178
 
115
- # 3. Dynamically build and insert keyword/separator patterns
116
179
  all_keywords = self.block_starters | self.block_enders | self.batch_separators
117
180
  if all_keywords:
118
181
  sorted_keywords = sorted(all_keywords, key=len, reverse=True)
119
182
  patterns.append((TokenType.KEYWORD, r"\b(" + "|".join(re.escape(kw) for kw in sorted_keywords) + r")\b"))
120
183
 
121
- # 4. Add terminators
122
184
  all_terminators = self.statement_terminators | set(self.special_terminators.keys())
123
185
  if all_terminators:
124
- # Escape special regex characters
125
186
  patterns.append((TokenType.TERMINATOR, "|".join(re.escape(t) for t in all_terminators)))
126
187
 
127
- # 5. Add low-precedence patterns
128
188
  patterns.extend([(TokenType.WHITESPACE, r"\s+"), (TokenType.OTHER, r".")])
129
189
 
130
190
  return patterns
@@ -134,21 +194,12 @@ class DialectConfig(ABC):
134
194
  return []
135
195
 
136
196
  @staticmethod
137
- def is_real_block_ender(tokens: list[Token], current_pos: int) -> bool:
138
- """Check if this END keyword is actually a block ender.
139
-
140
- Override in dialect configs to handle cases like END IF, END LOOP, etc.
141
- that are not true block enders.
142
- """
143
- _ = tokens, current_pos # Default implementation doesn't use these
197
+ def is_real_block_ender(tokens: list[Token], current_pos: int) -> bool: # noqa: ARG004
198
+ """Check if this END keyword is actually a block ender."""
144
199
  return True
145
200
 
146
201
  def should_delay_semicolon_termination(self, tokens: list[Token], current_pos: int) -> bool:
147
- """Check if semicolon termination should be delayed.
148
-
149
- Override in dialect configs to handle special cases like Oracle END; /
150
- """
151
- _ = tokens, current_pos # Default implementation doesn't use these
202
+ """Check if semicolon termination should be delayed."""
152
203
  return False
153
204
 
154
205
 
@@ -157,30 +208,36 @@ class OracleDialectConfig(DialectConfig):
157
208
 
158
209
  @property
159
210
  def name(self) -> str:
160
- return "oracle"
211
+ if self._name is None:
212
+ self._name = "oracle"
213
+ return self._name
161
214
 
162
215
  @property
163
216
  def block_starters(self) -> set[str]:
164
- return {"BEGIN", "DECLARE", "CASE"}
217
+ if self._block_starters is None:
218
+ self._block_starters = {"BEGIN", "DECLARE", "CASE"}
219
+ return self._block_starters
165
220
 
166
221
  @property
167
222
  def block_enders(self) -> set[str]:
168
- return {"END"}
223
+ if self._block_enders is None:
224
+ self._block_enders = {"END"}
225
+ return self._block_enders
169
226
 
170
227
  @property
171
228
  def statement_terminators(self) -> set[str]:
172
- return {";"}
229
+ if self._statement_terminators is None:
230
+ self._statement_terminators = {";"}
231
+ return self._statement_terminators
173
232
 
174
233
  @property
175
234
  def special_terminators(self) -> dict[str, Callable[[list[Token], int], bool]]:
176
- return {"/": self._handle_slash_terminator}
235
+ if self._special_terminators is None:
236
+ self._special_terminators = {"/": self._handle_slash_terminator}
237
+ return self._special_terminators
177
238
 
178
239
  def should_delay_semicolon_termination(self, tokens: list[Token], current_pos: int) -> bool:
179
- """Check if we should delay semicolon termination to look for a slash.
180
-
181
- In Oracle, after END; we should check if there's a / coming up on its own line.
182
- """
183
- # Look backwards to see if we just processed an END token
240
+ """Check if we should delay semicolon termination to look for a slash."""
184
241
  pos = current_pos - 1
185
242
  while pos >= 0:
186
243
  token = tokens[pos]
@@ -188,10 +245,7 @@ class OracleDialectConfig(DialectConfig):
188
245
  pos -= 1
189
246
  continue
190
247
  if token.type == TokenType.KEYWORD and token.value.upper() == "END":
191
- # We found END just before this semicolon
192
- # Now look ahead to see if there's a / on its own line
193
248
  return self._has_upcoming_slash(tokens, current_pos)
194
- # Found something else, not an END
195
249
  break
196
250
 
197
251
  return False
@@ -209,25 +263,17 @@ class OracleDialectConfig(DialectConfig):
209
263
  pos += 1
210
264
  continue
211
265
  if token.type == TokenType.TERMINATOR and token.value == "/":
212
- # Found a /, check if it's valid (on its own line)
213
266
  return found_newline and self._handle_slash_terminator(tokens, pos)
214
267
  if token.type in {TokenType.COMMENT_LINE, TokenType.COMMENT_BLOCK}:
215
- # Skip comments
216
268
  pos += 1
217
269
  continue
218
- # Found non-whitespace, non-comment content
219
270
  break
220
271
 
221
272
  return False
222
273
 
223
274
  @staticmethod
224
275
  def is_real_block_ender(tokens: list[Token], current_pos: int) -> bool:
225
- """Check if this END keyword is actually a block ender.
226
-
227
- In Oracle PL/SQL, END followed by IF, LOOP, CASE etc. are not block enders
228
- for BEGIN blocks - they terminate control structures.
229
- """
230
- # Look ahead for the next non-whitespace token(s)
276
+ """Check if this END keyword is actually a block ender for Oracle PL/SQL."""
231
277
  pos = current_pos + 1
232
278
  while pos < len(tokens):
233
279
  next_token = tokens[pos]
@@ -236,7 +282,6 @@ class OracleDialectConfig(DialectConfig):
236
282
  pos += 1
237
283
  continue
238
284
  if next_token.type == TokenType.OTHER:
239
- # Collect consecutive OTHER tokens to form a word
240
285
  word_chars = []
241
286
  word_pos = pos
242
287
  while word_pos < len(tokens) and tokens[word_pos].type == TokenType.OTHER:
@@ -245,26 +290,23 @@ class OracleDialectConfig(DialectConfig):
245
290
 
246
291
  word = "".join(word_chars).upper()
247
292
  if word in {"IF", "LOOP", "CASE", "WHILE"}:
248
- return False # This is not a block ender
249
- # Found a non-whitespace token that's not one of our special cases
293
+ return False
250
294
  break
251
- return True # This is a real block ender
295
+ return True
252
296
 
253
297
  @staticmethod
254
298
  def _handle_slash_terminator(tokens: list[Token], current_pos: int) -> bool:
255
299
  """Oracle / must be on its own line after whitespace only."""
256
300
  if current_pos == 0:
257
- return True # / at start is valid
301
+ return True
258
302
 
259
- # Look backwards to find the start of the line
260
303
  pos = current_pos - 1
261
304
  while pos >= 0:
262
305
  token = tokens[pos]
263
306
  if "\n" in token.value:
264
- # Found newline, check if only whitespace between newline and /
265
307
  break
266
308
  if token.type not in {TokenType.WHITESPACE, TokenType.COMMENT_LINE}:
267
- return False # Non-whitespace before / on same line
309
+ return False
268
310
  pos -= 1
269
311
 
270
312
  return True
@@ -275,31 +317,33 @@ class TSQLDialectConfig(DialectConfig):
275
317
 
276
318
  @property
277
319
  def name(self) -> str:
278
- return "tsql"
320
+ if self._name is None:
321
+ self._name = "tsql"
322
+ return self._name
279
323
 
280
324
  @property
281
325
  def block_starters(self) -> set[str]:
282
- return {"BEGIN", "TRY"}
326
+ if self._block_starters is None:
327
+ self._block_starters = {"BEGIN", "TRY"}
328
+ return self._block_starters
283
329
 
284
330
  @property
285
331
  def block_enders(self) -> set[str]:
286
- return {"END", "CATCH"}
332
+ if self._block_enders is None:
333
+ self._block_enders = {"END", "CATCH"}
334
+ return self._block_enders
287
335
 
288
336
  @property
289
337
  def statement_terminators(self) -> set[str]:
290
- return {";"}
338
+ if self._statement_terminators is None:
339
+ self._statement_terminators = {";"}
340
+ return self._statement_terminators
291
341
 
292
342
  @property
293
343
  def batch_separators(self) -> set[str]:
294
- return {"GO"}
295
-
296
- @staticmethod
297
- def validate_batch_separator(tokens: list[Token], current_pos: int) -> bool:
298
- """GO must be the only keyword on its line."""
299
- # Look for non-whitespace tokens on the same line
300
- # Implementation similar to Oracle slash handler
301
- _ = tokens, current_pos # Simplified implementation
302
- return True # Simplified for now
344
+ if self._batch_separators is None:
345
+ self._batch_separators = {"GO"}
346
+ return self._batch_separators
303
347
 
304
348
 
305
349
  class PostgreSQLDialectConfig(DialectConfig):
@@ -307,19 +351,27 @@ class PostgreSQLDialectConfig(DialectConfig):
307
351
 
308
352
  @property
309
353
  def name(self) -> str:
310
- return "postgresql"
354
+ if self._name is None:
355
+ self._name = "postgresql"
356
+ return self._name
311
357
 
312
358
  @property
313
359
  def block_starters(self) -> set[str]:
314
- return {"BEGIN", "DECLARE", "CASE", "DO"}
360
+ if self._block_starters is None:
361
+ self._block_starters = {"BEGIN", "DECLARE", "CASE", "DO"}
362
+ return self._block_starters
315
363
 
316
364
  @property
317
365
  def block_enders(self) -> set[str]:
318
- return {"END"}
366
+ if self._block_enders is None:
367
+ self._block_enders = {"END"}
368
+ return self._block_enders
319
369
 
320
370
  @property
321
371
  def statement_terminators(self) -> set[str]:
322
- return {";"}
372
+ if self._statement_terminators is None:
373
+ self._statement_terminators = {";"}
374
+ return self._statement_terminators
323
375
 
324
376
  def _get_dialect_specific_patterns(self) -> list[tuple[TokenType, TokenPattern]]:
325
377
  """Add PostgreSQL-specific patterns like dollar-quoted strings."""
@@ -328,12 +380,11 @@ class PostgreSQLDialectConfig(DialectConfig):
328
380
  @staticmethod
329
381
  def _handle_dollar_quoted_string(text: str, position: int, line: int, column: int) -> Optional[Token]:
330
382
  """Handle PostgreSQL dollar-quoted strings like $tag$...$tag$."""
331
- # Match opening tag
332
383
  start_match = re.match(r"\$([a-zA-Z_][a-zA-Z0-9_]*)?\$", text[position:])
333
384
  if not start_match:
334
385
  return None
335
386
 
336
- tag = start_match.group(0) # The full opening tag, e.g., "$tag$"
387
+ tag = start_match.group(0)
337
388
  content_start = position + len(tag)
338
389
 
339
390
  try:
@@ -342,7 +393,6 @@ class PostgreSQLDialectConfig(DialectConfig):
342
393
 
343
394
  return Token(type=TokenType.STRING_LITERAL, value=full_value, line=line, column=column, position=position)
344
395
  except ValueError:
345
- # Closing tag not found
346
396
  return None
347
397
 
348
398
 
@@ -351,19 +401,27 @@ class GenericDialectConfig(DialectConfig):
351
401
 
352
402
  @property
353
403
  def name(self) -> str:
354
- return "generic"
404
+ if self._name is None:
405
+ self._name = "generic"
406
+ return self._name
355
407
 
356
408
  @property
357
409
  def block_starters(self) -> set[str]:
358
- return {"BEGIN", "DECLARE", "CASE"}
410
+ if self._block_starters is None:
411
+ self._block_starters = {"BEGIN", "DECLARE", "CASE"}
412
+ return self._block_starters
359
413
 
360
414
  @property
361
415
  def block_enders(self) -> set[str]:
362
- return {"END"}
416
+ if self._block_enders is None:
417
+ self._block_enders = {"END"}
418
+ return self._block_enders
363
419
 
364
420
  @property
365
421
  def statement_terminators(self) -> set[str]:
366
- return {";"}
422
+ if self._statement_terminators is None:
423
+ self._statement_terminators = {";"}
424
+ return self._statement_terminators
367
425
 
368
426
 
369
427
  class MySQLDialectConfig(DialectConfig):
@@ -371,24 +429,33 @@ class MySQLDialectConfig(DialectConfig):
371
429
 
372
430
  @property
373
431
  def name(self) -> str:
374
- return "mysql"
432
+ if self._name is None:
433
+ self._name = "mysql"
434
+ return self._name
375
435
 
376
436
  @property
377
437
  def block_starters(self) -> set[str]:
378
- return {"BEGIN", "DECLARE", "CASE"}
438
+ if self._block_starters is None:
439
+ self._block_starters = {"BEGIN", "DECLARE", "CASE"}
440
+ return self._block_starters
379
441
 
380
442
  @property
381
443
  def block_enders(self) -> set[str]:
382
- return {"END"}
444
+ if self._block_enders is None:
445
+ self._block_enders = {"END"}
446
+ return self._block_enders
383
447
 
384
448
  @property
385
449
  def statement_terminators(self) -> set[str]:
386
- return {";"}
450
+ if self._statement_terminators is None:
451
+ self._statement_terminators = {";"}
452
+ return self._statement_terminators
387
453
 
388
454
  @property
389
455
  def special_terminators(self) -> dict[str, Callable[[list[Token], int], bool]]:
390
- """MySQL supports DELIMITER command for changing terminators."""
391
- return {"\\g": lambda _tokens, _pos: True, "\\G": lambda _tokens, _pos: True}
456
+ if self._special_terminators is None:
457
+ self._special_terminators = {"\\g": lambda _tokens, _pos: True, "\\G": lambda _tokens, _pos: True}
458
+ return self._special_terminators
392
459
 
393
460
 
394
461
  class SQLiteDialectConfig(DialectConfig):
@@ -396,20 +463,27 @@ class SQLiteDialectConfig(DialectConfig):
396
463
 
397
464
  @property
398
465
  def name(self) -> str:
399
- return "sqlite"
466
+ if self._name is None:
467
+ self._name = "sqlite"
468
+ return self._name
400
469
 
401
470
  @property
402
471
  def block_starters(self) -> set[str]:
403
- # SQLite has limited block support
404
- return {"BEGIN", "CASE"}
472
+ if self._block_starters is None:
473
+ self._block_starters = {"BEGIN", "CASE"}
474
+ return self._block_starters
405
475
 
406
476
  @property
407
477
  def block_enders(self) -> set[str]:
408
- return {"END"}
478
+ if self._block_enders is None:
479
+ self._block_enders = {"END"}
480
+ return self._block_enders
409
481
 
410
482
  @property
411
483
  def statement_terminators(self) -> set[str]:
412
- return {";"}
484
+ if self._statement_terminators is None:
485
+ self._statement_terminators = {";"}
486
+ return self._statement_terminators
413
487
 
414
488
 
415
489
  class DuckDBDialectConfig(DialectConfig):
@@ -417,19 +491,27 @@ class DuckDBDialectConfig(DialectConfig):
417
491
 
418
492
  @property
419
493
  def name(self) -> str:
420
- return "duckdb"
494
+ if self._name is None:
495
+ self._name = "duckdb"
496
+ return self._name
421
497
 
422
498
  @property
423
499
  def block_starters(self) -> set[str]:
424
- return {"BEGIN", "CASE"}
500
+ if self._block_starters is None:
501
+ self._block_starters = {"BEGIN", "CASE"}
502
+ return self._block_starters
425
503
 
426
504
  @property
427
505
  def block_enders(self) -> set[str]:
428
- return {"END"}
506
+ if self._block_enders is None:
507
+ self._block_enders = {"END"}
508
+ return self._block_enders
429
509
 
430
510
  @property
431
511
  def statement_terminators(self) -> set[str]:
432
- return {";"}
512
+ if self._statement_terminators is None:
513
+ self._statement_terminators = {";"}
514
+ return self._statement_terminators
433
515
 
434
516
 
435
517
  class BigQueryDialectConfig(DialectConfig):
@@ -437,56 +519,97 @@ class BigQueryDialectConfig(DialectConfig):
437
519
 
438
520
  @property
439
521
  def name(self) -> str:
440
- return "bigquery"
522
+ if self._name is None:
523
+ self._name = "bigquery"
524
+ return self._name
441
525
 
442
526
  @property
443
527
  def block_starters(self) -> set[str]:
444
- return {"BEGIN", "CASE"}
528
+ if self._block_starters is None:
529
+ self._block_starters = {"BEGIN", "CASE"}
530
+ return self._block_starters
445
531
 
446
532
  @property
447
533
  def block_enders(self) -> set[str]:
448
- return {"END"}
534
+ if self._block_enders is None:
535
+ self._block_enders = {"END"}
536
+ return self._block_enders
449
537
 
450
538
  @property
451
539
  def statement_terminators(self) -> set[str]:
452
- return {";"}
540
+ if self._statement_terminators is None:
541
+ self._statement_terminators = {";"}
542
+ return self._statement_terminators
543
+
544
+
545
+ _pattern_cache: Optional[UnifiedCache[list[tuple[TokenType, CompiledTokenPattern]]]] = None
546
+ _result_cache: Optional[UnifiedCache[list[str]]] = None
547
+ _cache_lock = threading.Lock()
548
+
453
549
 
550
+ def _get_pattern_cache() -> UnifiedCache[list[tuple[TokenType, CompiledTokenPattern]]]:
551
+ """Get or create the pattern compilation cache."""
552
+ global _pattern_cache
553
+ if _pattern_cache is None:
554
+ with _cache_lock:
555
+ if _pattern_cache is None:
556
+ _pattern_cache = UnifiedCache[list[tuple[TokenType, CompiledTokenPattern]]](
557
+ max_size=DEFAULT_PATTERN_CACHE_SIZE, ttl_seconds=DEFAULT_CACHE_TTL
558
+ )
559
+ return _pattern_cache
454
560
 
561
+
562
+ def _get_result_cache() -> UnifiedCache[list[str]]:
563
+ """Get or create the result cache."""
564
+ global _result_cache
565
+ if _result_cache is None:
566
+ with _cache_lock:
567
+ if _result_cache is None:
568
+ _result_cache = UnifiedCache[list[str]](
569
+ max_size=DEFAULT_RESULT_CACHE_SIZE, ttl_seconds=DEFAULT_CACHE_TTL
570
+ )
571
+ return _result_cache
572
+
573
+
574
+ @mypyc_attr(allow_interpreted_subclasses=False)
455
575
  class StatementSplitter:
456
- """Splits SQL scripts into individual statements using a lexer-driven state machine."""
576
+ """SQL script splitter with caching and dialect support."""
577
+
578
+ __slots__ = SPLITTER_SLOTS
457
579
 
458
580
  def __init__(self, dialect: DialectConfig, strip_trailing_semicolon: bool = False) -> None:
459
- """Initialize the splitter with a specific dialect configuration.
460
-
461
- Args:
462
- dialect: The dialect configuration to use
463
- strip_trailing_semicolon: If True, remove trailing semicolons from statements
464
- """
465
- self.dialect = dialect
466
- self.strip_trailing_semicolon = strip_trailing_semicolon
467
- self.token_patterns = dialect.get_all_token_patterns()
468
- self._compiled_patterns = self._compile_patterns()
469
-
470
- def _compile_patterns(self) -> list[tuple[TokenType, CompiledTokenPattern]]:
471
- """Compile regex patterns for efficiency."""
581
+ """Initialize the splitter with caching and dialect support."""
582
+ self._dialect = dialect
583
+ self._strip_trailing_semicolon = strip_trailing_semicolon
584
+ self._token_patterns = dialect.get_all_token_patterns()
585
+
586
+ self._pattern_cache_key = f"{dialect.name}:{hash(tuple(str(p) for _, p in self._token_patterns))}"
587
+
588
+ self._pattern_cache = _get_pattern_cache()
589
+ self._result_cache = _get_result_cache()
590
+
591
+ self._compiled_patterns = self._get_or_compile_patterns()
592
+
593
+ def _get_or_compile_patterns(self) -> list[tuple[TokenType, CompiledTokenPattern]]:
594
+ """Get compiled patterns from cache or compile and cache them."""
595
+ cache_key = CacheKey(("pattern", self._pattern_cache_key))
596
+
597
+ cached_patterns = self._pattern_cache.get(cache_key)
598
+ if cached_patterns is not None:
599
+ return cached_patterns
600
+
472
601
  compiled: list[tuple[TokenType, CompiledTokenPattern]] = []
473
- for token_type, pattern in self.token_patterns:
602
+ for token_type, pattern in self._token_patterns:
474
603
  if isinstance(pattern, str):
475
604
  compiled.append((token_type, re.compile(pattern, re.IGNORECASE | re.DOTALL)))
476
605
  else:
477
- # It's a callable
478
606
  compiled.append((token_type, pattern))
607
+
608
+ self._pattern_cache.put(cache_key, compiled)
479
609
  return compiled
480
610
 
481
611
  def _tokenize(self, sql: str) -> Generator[Token, None, None]:
482
- """Tokenize the SQL script into a stream of tokens.
483
-
484
- sql: The SQL script to tokenize
485
-
486
- Yields:
487
- Token objects representing the recognized tokens in the script.
488
-
489
- """
612
+ """Tokenize SQL string."""
490
613
  pos = 0
491
614
  line = 1
492
615
  line_start = 0
@@ -496,7 +619,6 @@ class StatementSplitter:
496
619
 
497
620
  for token_type, pattern in self._compiled_patterns:
498
621
  if callable(pattern):
499
- # Call the handler function
500
622
  column = pos - line_start + 1
501
623
  token = pattern(sql, pos, line, column)
502
624
  if token:
@@ -511,7 +633,6 @@ class StatementSplitter:
511
633
  matched = True
512
634
  break
513
635
  else:
514
- # Use regex
515
636
  match = pattern.match(sql, pos)
516
637
  if match:
517
638
  value = match.group(0)
@@ -529,12 +650,25 @@ class StatementSplitter:
529
650
  break
530
651
 
531
652
  if not matched:
532
- # This should never happen with our catch-all OTHER pattern
533
653
  logger.error("Failed to tokenize at position %d: %s", pos, sql[pos : pos + 20])
534
- pos += 1 # Skip the problematic character
654
+ pos += 1
535
655
 
536
656
  def split(self, sql: str) -> list[str]:
537
- """Split the SQL script into individual statements."""
657
+ """Split SQL script with result caching."""
658
+ script_hash = hash(sql)
659
+ cache_key = CacheKey(("split", self._dialect.name, script_hash, self._strip_trailing_semicolon))
660
+
661
+ cached_result = self._result_cache.get(cache_key)
662
+ if cached_result is not None:
663
+ return cached_result
664
+
665
+ statements = self._do_split(sql)
666
+
667
+ self._result_cache.put(cache_key, statements)
668
+ return statements
669
+
670
+ def _do_split(self, sql: str) -> list[str]:
671
+ """Perform SQL script splitting."""
538
672
  statements = []
539
673
  current_statement_tokens = []
540
674
  current_statement_chars = []
@@ -543,10 +677,8 @@ class StatementSplitter:
543
677
  all_tokens = list(self._tokenize(sql))
544
678
 
545
679
  for token_idx, token in enumerate(all_tokens):
546
- # Always accumulate the original text
547
680
  current_statement_chars.append(token.value)
548
681
 
549
- # Skip whitespace and comments for logic (but keep in output)
550
682
  if token.type in {TokenType.WHITESPACE, TokenType.COMMENT_LINE, TokenType.COMMENT_BLOCK}:
551
683
  current_statement_tokens.append(token)
552
684
  continue
@@ -555,50 +687,41 @@ class StatementSplitter:
555
687
  token_upper = token.value.upper()
556
688
 
557
689
  if token.type == TokenType.KEYWORD:
558
- if token_upper in self.dialect.block_starters:
690
+ if token_upper in self._dialect.block_starters:
559
691
  block_stack.append(token_upper)
560
- if len(block_stack) > self.dialect.max_nesting_depth:
561
- msg = f"Maximum nesting depth ({self.dialect.max_nesting_depth}) exceeded"
692
+ if len(block_stack) > self._dialect.max_nesting_depth:
693
+ msg = f"Maximum nesting depth ({self._dialect.max_nesting_depth}) exceeded"
562
694
  raise ValueError(msg)
563
- elif token_upper in self.dialect.block_enders:
564
- if block_stack and self.dialect.is_real_block_ender(all_tokens, token_idx):
695
+ elif token_upper in self._dialect.block_enders:
696
+ if block_stack and self._dialect.is_real_block_ender(all_tokens, token_idx):
565
697
  block_stack.pop()
566
698
 
567
- # Check for statement termination
568
699
  is_terminator = False
569
- if not block_stack: # Only terminate when not inside a block
700
+ if not block_stack:
570
701
  if token.type == TokenType.TERMINATOR:
571
- if token.value in self.dialect.statement_terminators:
572
- should_delay = self.dialect.should_delay_semicolon_termination(all_tokens, token_idx)
702
+ if token.value in self._dialect.statement_terminators:
703
+ should_delay = self._dialect.should_delay_semicolon_termination(all_tokens, token_idx)
573
704
 
574
- # Also check if there's a batch separator coming up (for T-SQL GO)
575
- if not should_delay and token.value == ";" and self.dialect.batch_separators:
576
- # In dialects with batch separators, semicolons don't terminate
577
- # statements - only batch separators do
705
+ if not should_delay and token.value == ";" and self._dialect.batch_separators:
578
706
  should_delay = True
579
707
 
580
708
  if not should_delay:
581
709
  is_terminator = True
582
- elif token.value in self.dialect.special_terminators:
583
- # Call the handler to validate
584
- handler = self.dialect.special_terminators[token.value]
710
+ elif token.value in self._dialect.special_terminators:
711
+ handler = self._dialect.special_terminators[token.value]
585
712
  if handler(all_tokens, token_idx):
586
713
  is_terminator = True
587
714
 
588
- elif token.type == TokenType.KEYWORD and token_upper in self.dialect.batch_separators:
589
- # Batch separators like GO should be included with the preceding statement
715
+ elif token.type == TokenType.KEYWORD and token_upper in self._dialect.batch_separators:
590
716
  is_terminator = True
591
717
 
592
718
  if is_terminator:
593
- # Save the statement
594
719
  statement = "".join(current_statement_chars).strip()
595
720
 
596
721
  is_plsql_block = self._is_plsql_block(current_statement_tokens)
597
722
 
598
- # Optionally strip the trailing terminator
599
- # For PL/SQL blocks, never strip the semicolon as it's syntactically required
600
723
  if (
601
- self.strip_trailing_semicolon
724
+ self._strip_trailing_semicolon
602
725
  and token.type == TokenType.TERMINATOR
603
726
  and statement.endswith(token.value)
604
727
  and not is_plsql_block
@@ -619,29 +742,14 @@ class StatementSplitter:
619
742
 
620
743
  @staticmethod
621
744
  def _is_plsql_block(tokens: list[Token]) -> bool:
622
- """Check if the token list represents a PL/SQL block.
623
-
624
- Args:
625
- tokens: List of tokens for the current statement
626
-
627
- Returns:
628
- True if this is a PL/SQL block (BEGIN...END or DECLARE...END)
629
- """
745
+ """Check if the token list represents a PL/SQL block."""
630
746
  for token in tokens:
631
747
  if token.type == TokenType.KEYWORD:
632
748
  return token.value.upper() in {"BEGIN", "DECLARE"}
633
749
  return False
634
750
 
635
751
  def _contains_executable_content(self, statement: str) -> bool:
636
- """Check if a statement contains actual executable content (not just comments/whitespace).
637
-
638
- Args:
639
- statement: The statement string to check
640
-
641
- Returns:
642
- True if the statement contains executable SQL, False if it's only comments/whitespace
643
- """
644
- # Tokenize the statement to check its content
752
+ """Check if a statement contains actual executable content."""
645
753
  tokens = list(self._tokenize(statement))
646
754
 
647
755
  for token in tokens:
@@ -651,39 +759,61 @@ class StatementSplitter:
651
759
  return False
652
760
 
653
761
 
654
- def split_sql_script(script: str, dialect: str = "generic", strip_trailing_semicolon: bool = False) -> list[str]:
655
- """Split a SQL script into statements using the appropriate dialect.
762
+ def split_sql_script(script: str, dialect: Optional[str] = None, strip_trailing_terminator: bool = False) -> list[str]:
763
+ """Split SQL script into individual statements.
656
764
 
657
765
  Args:
658
766
  script: The SQL script to split
659
- dialect: The SQL dialect name ('oracle', 'tsql', 'postgresql', etc.)
660
- strip_trailing_semicolon: If True, remove trailing terminators from statements
767
+ dialect: The SQL dialect name
768
+ strip_trailing_terminator: If True, remove trailing terminators from statements
661
769
 
662
770
  Returns:
663
771
  List of individual SQL statements
664
772
  """
773
+ if dialect is None:
774
+ dialect = "generic"
775
+
665
776
  dialect_configs = {
666
- # Standard dialects
667
777
  "generic": GenericDialectConfig(),
668
- # Major databases
669
778
  "oracle": OracleDialectConfig(),
670
779
  "tsql": TSQLDialectConfig(),
671
- "mssql": TSQLDialectConfig(), # Alias for tsql
672
- "sqlserver": TSQLDialectConfig(), # Alias for tsql
780
+ "mssql": TSQLDialectConfig(),
781
+ "sqlserver": TSQLDialectConfig(),
673
782
  "postgresql": PostgreSQLDialectConfig(),
674
- "postgres": PostgreSQLDialectConfig(), # Common alias
783
+ "postgres": PostgreSQLDialectConfig(),
675
784
  "mysql": MySQLDialectConfig(),
676
785
  "sqlite": SQLiteDialectConfig(),
677
- # Modern analytical databases
678
786
  "duckdb": DuckDBDialectConfig(),
679
787
  "bigquery": BigQueryDialectConfig(),
680
788
  }
681
789
 
682
790
  config = dialect_configs.get(dialect.lower())
683
791
  if not config:
684
- # Fall back to generic config for unknown dialects
685
792
  logger.warning("Unknown dialect '%s', using generic SQL splitter", dialect)
686
793
  config = GenericDialectConfig()
687
794
 
688
- splitter = StatementSplitter(config, strip_trailing_semicolon=strip_trailing_semicolon)
795
+ splitter = StatementSplitter(config, strip_trailing_semicolon=strip_trailing_terminator)
689
796
  return splitter.split(script)
797
+
798
+
799
+ def clear_splitter_caches() -> None:
800
+ """Clear all splitter caches for memory management."""
801
+ pattern_cache = _get_pattern_cache()
802
+ result_cache = _get_result_cache()
803
+ pattern_cache.clear()
804
+ result_cache.clear()
805
+
806
+
807
+ def get_splitter_cache_stats() -> dict[str, Any]:
808
+ """Get statistics from splitter caches.
809
+
810
+ Returns:
811
+ Dictionary containing cache statistics
812
+ """
813
+ pattern_cache = _get_pattern_cache()
814
+ result_cache = _get_result_cache()
815
+
816
+ return {
817
+ "pattern_cache": {"size": pattern_cache.size(), "stats": pattern_cache.get_stats()},
818
+ "result_cache": {"size": result_cache.size(), "stats": result_cache.get_stats()},
819
+ }