sqlspec 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. sqlspec/__init__.py +104 -0
  2. sqlspec/__main__.py +12 -0
  3. sqlspec/__metadata__.py +14 -0
  4. sqlspec/_serialization.py +312 -0
  5. sqlspec/_typing.py +784 -0
  6. sqlspec/adapters/__init__.py +0 -0
  7. sqlspec/adapters/adbc/__init__.py +5 -0
  8. sqlspec/adapters/adbc/_types.py +12 -0
  9. sqlspec/adapters/adbc/adk/__init__.py +5 -0
  10. sqlspec/adapters/adbc/adk/store.py +880 -0
  11. sqlspec/adapters/adbc/config.py +436 -0
  12. sqlspec/adapters/adbc/data_dictionary.py +537 -0
  13. sqlspec/adapters/adbc/driver.py +841 -0
  14. sqlspec/adapters/adbc/litestar/__init__.py +5 -0
  15. sqlspec/adapters/adbc/litestar/store.py +504 -0
  16. sqlspec/adapters/adbc/type_converter.py +153 -0
  17. sqlspec/adapters/aiosqlite/__init__.py +29 -0
  18. sqlspec/adapters/aiosqlite/_types.py +13 -0
  19. sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
  20. sqlspec/adapters/aiosqlite/adk/store.py +536 -0
  21. sqlspec/adapters/aiosqlite/config.py +310 -0
  22. sqlspec/adapters/aiosqlite/data_dictionary.py +260 -0
  23. sqlspec/adapters/aiosqlite/driver.py +463 -0
  24. sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
  25. sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
  26. sqlspec/adapters/aiosqlite/pool.py +500 -0
  27. sqlspec/adapters/asyncmy/__init__.py +25 -0
  28. sqlspec/adapters/asyncmy/_types.py +12 -0
  29. sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
  30. sqlspec/adapters/asyncmy/adk/store.py +503 -0
  31. sqlspec/adapters/asyncmy/config.py +246 -0
  32. sqlspec/adapters/asyncmy/data_dictionary.py +241 -0
  33. sqlspec/adapters/asyncmy/driver.py +632 -0
  34. sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
  35. sqlspec/adapters/asyncmy/litestar/store.py +296 -0
  36. sqlspec/adapters/asyncpg/__init__.py +23 -0
  37. sqlspec/adapters/asyncpg/_type_handlers.py +76 -0
  38. sqlspec/adapters/asyncpg/_types.py +23 -0
  39. sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
  40. sqlspec/adapters/asyncpg/adk/store.py +460 -0
  41. sqlspec/adapters/asyncpg/config.py +464 -0
  42. sqlspec/adapters/asyncpg/data_dictionary.py +321 -0
  43. sqlspec/adapters/asyncpg/driver.py +720 -0
  44. sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
  45. sqlspec/adapters/asyncpg/litestar/store.py +253 -0
  46. sqlspec/adapters/bigquery/__init__.py +18 -0
  47. sqlspec/adapters/bigquery/_types.py +12 -0
  48. sqlspec/adapters/bigquery/adk/__init__.py +5 -0
  49. sqlspec/adapters/bigquery/adk/store.py +585 -0
  50. sqlspec/adapters/bigquery/config.py +298 -0
  51. sqlspec/adapters/bigquery/data_dictionary.py +256 -0
  52. sqlspec/adapters/bigquery/driver.py +1073 -0
  53. sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
  54. sqlspec/adapters/bigquery/litestar/store.py +327 -0
  55. sqlspec/adapters/bigquery/type_converter.py +125 -0
  56. sqlspec/adapters/duckdb/__init__.py +24 -0
  57. sqlspec/adapters/duckdb/_types.py +12 -0
  58. sqlspec/adapters/duckdb/adk/__init__.py +14 -0
  59. sqlspec/adapters/duckdb/adk/store.py +563 -0
  60. sqlspec/adapters/duckdb/config.py +396 -0
  61. sqlspec/adapters/duckdb/data_dictionary.py +264 -0
  62. sqlspec/adapters/duckdb/driver.py +604 -0
  63. sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
  64. sqlspec/adapters/duckdb/litestar/store.py +332 -0
  65. sqlspec/adapters/duckdb/pool.py +273 -0
  66. sqlspec/adapters/duckdb/type_converter.py +133 -0
  67. sqlspec/adapters/oracledb/__init__.py +32 -0
  68. sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
  69. sqlspec/adapters/oracledb/_types.py +39 -0
  70. sqlspec/adapters/oracledb/_uuid_handlers.py +130 -0
  71. sqlspec/adapters/oracledb/adk/__init__.py +5 -0
  72. sqlspec/adapters/oracledb/adk/store.py +1632 -0
  73. sqlspec/adapters/oracledb/config.py +469 -0
  74. sqlspec/adapters/oracledb/data_dictionary.py +717 -0
  75. sqlspec/adapters/oracledb/driver.py +1493 -0
  76. sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
  77. sqlspec/adapters/oracledb/litestar/store.py +765 -0
  78. sqlspec/adapters/oracledb/migrations.py +532 -0
  79. sqlspec/adapters/oracledb/type_converter.py +207 -0
  80. sqlspec/adapters/psqlpy/__init__.py +16 -0
  81. sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
  82. sqlspec/adapters/psqlpy/_types.py +12 -0
  83. sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
  84. sqlspec/adapters/psqlpy/adk/store.py +483 -0
  85. sqlspec/adapters/psqlpy/config.py +271 -0
  86. sqlspec/adapters/psqlpy/data_dictionary.py +179 -0
  87. sqlspec/adapters/psqlpy/driver.py +892 -0
  88. sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
  89. sqlspec/adapters/psqlpy/litestar/store.py +272 -0
  90. sqlspec/adapters/psqlpy/type_converter.py +102 -0
  91. sqlspec/adapters/psycopg/__init__.py +32 -0
  92. sqlspec/adapters/psycopg/_type_handlers.py +90 -0
  93. sqlspec/adapters/psycopg/_types.py +18 -0
  94. sqlspec/adapters/psycopg/adk/__init__.py +5 -0
  95. sqlspec/adapters/psycopg/adk/store.py +962 -0
  96. sqlspec/adapters/psycopg/config.py +487 -0
  97. sqlspec/adapters/psycopg/data_dictionary.py +630 -0
  98. sqlspec/adapters/psycopg/driver.py +1336 -0
  99. sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
  100. sqlspec/adapters/psycopg/litestar/store.py +554 -0
  101. sqlspec/adapters/spanner/__init__.py +38 -0
  102. sqlspec/adapters/spanner/_type_handlers.py +186 -0
  103. sqlspec/adapters/spanner/_types.py +12 -0
  104. sqlspec/adapters/spanner/adk/__init__.py +5 -0
  105. sqlspec/adapters/spanner/adk/store.py +435 -0
  106. sqlspec/adapters/spanner/config.py +241 -0
  107. sqlspec/adapters/spanner/data_dictionary.py +95 -0
  108. sqlspec/adapters/spanner/dialect/__init__.py +6 -0
  109. sqlspec/adapters/spanner/dialect/_spangres.py +52 -0
  110. sqlspec/adapters/spanner/dialect/_spanner.py +123 -0
  111. sqlspec/adapters/spanner/driver.py +366 -0
  112. sqlspec/adapters/spanner/litestar/__init__.py +5 -0
  113. sqlspec/adapters/spanner/litestar/store.py +266 -0
  114. sqlspec/adapters/spanner/type_converter.py +46 -0
  115. sqlspec/adapters/sqlite/__init__.py +18 -0
  116. sqlspec/adapters/sqlite/_type_handlers.py +86 -0
  117. sqlspec/adapters/sqlite/_types.py +11 -0
  118. sqlspec/adapters/sqlite/adk/__init__.py +5 -0
  119. sqlspec/adapters/sqlite/adk/store.py +582 -0
  120. sqlspec/adapters/sqlite/config.py +221 -0
  121. sqlspec/adapters/sqlite/data_dictionary.py +256 -0
  122. sqlspec/adapters/sqlite/driver.py +527 -0
  123. sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
  124. sqlspec/adapters/sqlite/litestar/store.py +318 -0
  125. sqlspec/adapters/sqlite/pool.py +140 -0
  126. sqlspec/base.py +811 -0
  127. sqlspec/builder/__init__.py +146 -0
  128. sqlspec/builder/_base.py +900 -0
  129. sqlspec/builder/_column.py +517 -0
  130. sqlspec/builder/_ddl.py +1642 -0
  131. sqlspec/builder/_delete.py +84 -0
  132. sqlspec/builder/_dml.py +381 -0
  133. sqlspec/builder/_expression_wrappers.py +46 -0
  134. sqlspec/builder/_factory.py +1537 -0
  135. sqlspec/builder/_insert.py +315 -0
  136. sqlspec/builder/_join.py +375 -0
  137. sqlspec/builder/_merge.py +848 -0
  138. sqlspec/builder/_parsing_utils.py +297 -0
  139. sqlspec/builder/_select.py +1615 -0
  140. sqlspec/builder/_update.py +161 -0
  141. sqlspec/builder/_vector_expressions.py +259 -0
  142. sqlspec/cli.py +764 -0
  143. sqlspec/config.py +1540 -0
  144. sqlspec/core/__init__.py +305 -0
  145. sqlspec/core/cache.py +785 -0
  146. sqlspec/core/compiler.py +603 -0
  147. sqlspec/core/filters.py +872 -0
  148. sqlspec/core/hashing.py +274 -0
  149. sqlspec/core/metrics.py +83 -0
  150. sqlspec/core/parameters/__init__.py +64 -0
  151. sqlspec/core/parameters/_alignment.py +266 -0
  152. sqlspec/core/parameters/_converter.py +413 -0
  153. sqlspec/core/parameters/_processor.py +341 -0
  154. sqlspec/core/parameters/_registry.py +201 -0
  155. sqlspec/core/parameters/_transformers.py +226 -0
  156. sqlspec/core/parameters/_types.py +430 -0
  157. sqlspec/core/parameters/_validator.py +123 -0
  158. sqlspec/core/pipeline.py +187 -0
  159. sqlspec/core/result.py +1124 -0
  160. sqlspec/core/splitter.py +940 -0
  161. sqlspec/core/stack.py +163 -0
  162. sqlspec/core/statement.py +835 -0
  163. sqlspec/core/type_conversion.py +235 -0
  164. sqlspec/driver/__init__.py +36 -0
  165. sqlspec/driver/_async.py +1027 -0
  166. sqlspec/driver/_common.py +1236 -0
  167. sqlspec/driver/_sync.py +1025 -0
  168. sqlspec/driver/mixins/__init__.py +7 -0
  169. sqlspec/driver/mixins/_result_tools.py +61 -0
  170. sqlspec/driver/mixins/_sql_translator.py +122 -0
  171. sqlspec/driver/mixins/_storage.py +311 -0
  172. sqlspec/exceptions.py +321 -0
  173. sqlspec/extensions/__init__.py +0 -0
  174. sqlspec/extensions/adk/__init__.py +53 -0
  175. sqlspec/extensions/adk/_types.py +51 -0
  176. sqlspec/extensions/adk/converters.py +172 -0
  177. sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
  178. sqlspec/extensions/adk/migrations/__init__.py +0 -0
  179. sqlspec/extensions/adk/service.py +181 -0
  180. sqlspec/extensions/adk/store.py +536 -0
  181. sqlspec/extensions/aiosql/__init__.py +10 -0
  182. sqlspec/extensions/aiosql/adapter.py +471 -0
  183. sqlspec/extensions/fastapi/__init__.py +19 -0
  184. sqlspec/extensions/fastapi/extension.py +341 -0
  185. sqlspec/extensions/fastapi/providers.py +543 -0
  186. sqlspec/extensions/flask/__init__.py +36 -0
  187. sqlspec/extensions/flask/_state.py +72 -0
  188. sqlspec/extensions/flask/_utils.py +40 -0
  189. sqlspec/extensions/flask/extension.py +402 -0
  190. sqlspec/extensions/litestar/__init__.py +23 -0
  191. sqlspec/extensions/litestar/_utils.py +52 -0
  192. sqlspec/extensions/litestar/cli.py +92 -0
  193. sqlspec/extensions/litestar/config.py +90 -0
  194. sqlspec/extensions/litestar/handlers.py +316 -0
  195. sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
  196. sqlspec/extensions/litestar/migrations/__init__.py +3 -0
  197. sqlspec/extensions/litestar/plugin.py +638 -0
  198. sqlspec/extensions/litestar/providers.py +454 -0
  199. sqlspec/extensions/litestar/store.py +265 -0
  200. sqlspec/extensions/otel/__init__.py +58 -0
  201. sqlspec/extensions/prometheus/__init__.py +107 -0
  202. sqlspec/extensions/starlette/__init__.py +10 -0
  203. sqlspec/extensions/starlette/_state.py +26 -0
  204. sqlspec/extensions/starlette/_utils.py +52 -0
  205. sqlspec/extensions/starlette/extension.py +257 -0
  206. sqlspec/extensions/starlette/middleware.py +154 -0
  207. sqlspec/loader.py +716 -0
  208. sqlspec/migrations/__init__.py +36 -0
  209. sqlspec/migrations/base.py +728 -0
  210. sqlspec/migrations/commands.py +1140 -0
  211. sqlspec/migrations/context.py +142 -0
  212. sqlspec/migrations/fix.py +203 -0
  213. sqlspec/migrations/loaders.py +450 -0
  214. sqlspec/migrations/runner.py +1024 -0
  215. sqlspec/migrations/templates.py +234 -0
  216. sqlspec/migrations/tracker.py +403 -0
  217. sqlspec/migrations/utils.py +256 -0
  218. sqlspec/migrations/validation.py +203 -0
  219. sqlspec/observability/__init__.py +22 -0
  220. sqlspec/observability/_config.py +228 -0
  221. sqlspec/observability/_diagnostics.py +67 -0
  222. sqlspec/observability/_dispatcher.py +151 -0
  223. sqlspec/observability/_observer.py +180 -0
  224. sqlspec/observability/_runtime.py +381 -0
  225. sqlspec/observability/_spans.py +158 -0
  226. sqlspec/protocols.py +530 -0
  227. sqlspec/py.typed +0 -0
  228. sqlspec/storage/__init__.py +46 -0
  229. sqlspec/storage/_utils.py +104 -0
  230. sqlspec/storage/backends/__init__.py +1 -0
  231. sqlspec/storage/backends/base.py +163 -0
  232. sqlspec/storage/backends/fsspec.py +398 -0
  233. sqlspec/storage/backends/local.py +377 -0
  234. sqlspec/storage/backends/obstore.py +580 -0
  235. sqlspec/storage/errors.py +104 -0
  236. sqlspec/storage/pipeline.py +604 -0
  237. sqlspec/storage/registry.py +289 -0
  238. sqlspec/typing.py +219 -0
  239. sqlspec/utils/__init__.py +31 -0
  240. sqlspec/utils/arrow_helpers.py +95 -0
  241. sqlspec/utils/config_resolver.py +153 -0
  242. sqlspec/utils/correlation.py +132 -0
  243. sqlspec/utils/data_transformation.py +114 -0
  244. sqlspec/utils/dependencies.py +79 -0
  245. sqlspec/utils/deprecation.py +113 -0
  246. sqlspec/utils/fixtures.py +250 -0
  247. sqlspec/utils/logging.py +172 -0
  248. sqlspec/utils/module_loader.py +273 -0
  249. sqlspec/utils/portal.py +325 -0
  250. sqlspec/utils/schema.py +288 -0
  251. sqlspec/utils/serializers.py +396 -0
  252. sqlspec/utils/singleton.py +41 -0
  253. sqlspec/utils/sync_tools.py +277 -0
  254. sqlspec/utils/text.py +108 -0
  255. sqlspec/utils/type_converters.py +99 -0
  256. sqlspec/utils/type_guards.py +1324 -0
  257. sqlspec/utils/version.py +444 -0
  258. sqlspec-0.32.0.dist-info/METADATA +202 -0
  259. sqlspec-0.32.0.dist-info/RECORD +262 -0
  260. sqlspec-0.32.0.dist-info/WHEEL +4 -0
  261. sqlspec-0.32.0.dist-info/entry_points.txt +2 -0
  262. sqlspec-0.32.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,161 @@
1
+ """UPDATE statement builder.
2
+
3
+ Provides a fluent interface for building SQL UPDATE queries with
4
+ parameter binding and validation.
5
+ """
6
+
7
+ from typing import TYPE_CHECKING, Any, cast
8
+
9
+ from sqlglot import exp
10
+ from typing_extensions import Self
11
+
12
+ from sqlspec.builder._base import QueryBuilder, SafeQuery
13
+ from sqlspec.builder._dml import UpdateFromClauseMixin, UpdateSetClauseMixin, UpdateTableClauseMixin
14
+ from sqlspec.builder._join import build_join_clause
15
+ from sqlspec.builder._select import ReturningClauseMixin, WhereClauseMixin
16
+ from sqlspec.core import SQLResult
17
+ from sqlspec.exceptions import SQLBuilderError
18
+
19
+ if TYPE_CHECKING:
20
+ from sqlglot.dialects.dialect import DialectType
21
+
22
+ from sqlspec.builder._select import Select
23
+ from sqlspec.protocols import SQLBuilderProtocol
24
+
25
__all__ = ("Update",)


class Update(
    QueryBuilder,
    WhereClauseMixin,
    ReturningClauseMixin,
    UpdateSetClauseMixin,
    UpdateFromClauseMixin,
    UpdateTableClauseMixin,
):
    """Builder for UPDATE statements.

    Constructs SQL UPDATE statements with parameter binding and validation.

    Example:
        ```python
        update_query = (
            Update()
            .table("users")
            .set_(name="John Doe")
            .set_(email="john@example.com")
            .where("id = 1")
        )

        update_query = (
            Update("users").set_(name="John Doe").where("id = 1")
        )

        update_query = (
            Update()
            .table("users")
            .set_(status="active")
            .where_eq("id", 123)
        )

        update_query = (
            Update()
            .table("users", "u")
            .set_(name="Updated Name")
            .from_("profiles", "p")
            .where("u.id = p.user_id AND p.is_verified = true")
        )
        ```
    """

    __slots__ = ()
    _expression: exp.Expression | None

    def __init__(self, table: str | None = None, **kwargs: Any) -> None:
        """Initialize UPDATE with optional table.

        Args:
            table: Target table name. Falsy values (``None`` or ``""``) leave the
                table unset so it can be supplied later via ``.table(...)``.
            **kwargs: Additional QueryBuilder arguments
        """
        super().__init__(**kwargs)
        self._initialize_expression()

        if table:
            self.table(table)

    @property
    def _expected_result_type(self) -> "type[SQLResult]":
        """Return the expected result type for this builder."""
        return SQLResult

    def _create_base_expression(self) -> exp.Update:
        """Create a base UPDATE expression.

        Returns:
            A new sqlglot Update expression with no target table, no SET
            assignments and no joins.
        """
        return exp.Update(this=None, expressions=[], joins=[])

    def join(
        self,
        table: "str | exp.Expression | Select",
        on: "str | exp.Expression",
        alias: "str | None" = None,
        join_type: str = "INNER",
    ) -> "Self":
        """Add JOIN clause to the UPDATE statement.

        Args:
            table: The table name, expression, or subquery to join.
            on: The JOIN condition.
            alias: Optional alias for the joined table.
            join_type: Type of join (INNER, LEFT, RIGHT, FULL).

        Returns:
            The current builder instance for method chaining.

        Raises:
            SQLBuilderError: If the current expression is not an UPDATE statement.
        """
        update_expr = self._expression
        # isinstance() is False for None too, so this also covers an uninitialized builder.
        if not isinstance(update_expr, exp.Update):
            msg = "Cannot add JOIN clause to non-UPDATE expression."
            raise SQLBuilderError(msg)

        join_expr = build_join_clause(cast("SQLBuilderProtocol", self), table, on, alias, join_type)

        # Expression.append creates the "joins" list when it is missing and, unlike
        # appending to ``args["joins"]`` directly, links the join node's parent/index
        # back to the UPDATE so sqlglot tree traversal sees a consistent tree.
        update_expr.append("joins", join_expr)

        return self

    def build(self, dialect: "DialectType" = None) -> "SafeQuery":
        """Build the UPDATE query with validation.

        Args:
            dialect: Optional dialect override for SQL generation.

        Returns:
            SafeQuery: The built query with SQL and parameters.

        Raises:
            SQLBuilderError: If the expression is missing or is not an UPDATE,
                if no table is set, or if no SET assignment has been added.
        """
        if self._expression is None:
            msg = "UPDATE expression not initialized."
            raise SQLBuilderError(msg)

        if not isinstance(self._expression, exp.Update):
            msg = "No UPDATE expression to build or expression is of the wrong type."
            raise SQLBuilderError(msg)

        if self._expression.this is None:
            msg = "No table specified for UPDATE statement."
            raise SQLBuilderError(msg)

        if not self._expression.args.get("expressions"):
            msg = "At least one SET clause must be specified for UPDATE statement."
            raise SQLBuilderError(msg)

        return super().build(dialect=dialect)
@@ -0,0 +1,259 @@
1
+ """Custom SQLGlot expressions for vector distance operations.
2
+
3
+ Provides dialect-specific SQL generation for vector similarity search
4
+ across PostgreSQL (pgvector), MySQL 9+, Oracle 23ai+, BigQuery, and Spanner.
5
+ """
6
+
7
+ from contextlib import suppress
8
+ from typing import Any
9
+
10
+ from sqlglot import exp
11
+
12
__all__ = ("VectorDistance",)


class VectorDistance(exp.Expression):
    """Vector distance expression with dialect-specific generation.

    Generates database-specific SQL for vector distance calculations:
    - PostgreSQL (pgvector) and Spangres: operators <->, <=>, <#>
    - MySQL 9+: DISTANCE(col, vec, 'METRIC') function
    - Oracle 23ai+: VECTOR_DISTANCE(col, vec, METRIC) function
    - BigQuery and Spanner: EUCLIDEAN_DISTANCE / COSINE_DISTANCE / DOT_PRODUCT functions
    - DuckDB (VSS extension): array_distance / array_cosine_distance /
      array_negative_inner_product functions
    - Generic: VECTOR_DISTANCE(col, vec, 'METRIC') function

    The metric is stored as a raw string attribute (not parametrized) and drives
    dialect-specific generation at SQL build time.
    """

    arg_types = {"this": True, "expression": True, "metric": False}

    def __init__(self, **args: Any) -> None:
        """Initialize VectorDistance with the metric normalized into args.

        Args:
            **args: sqlglot expression args. ``this`` is the left operand (column),
                ``expression`` the right operand (vector value). ``metric`` may be a
                ``str``, ``exp.Literal`` or ``exp.Identifier`` and is lowercased; a
                missing or other-typed metric becomes ``"euclidean"``.
        """
        metric_value = args.get("metric", "euclidean")
        if isinstance(metric_value, exp.Literal):
            metric_value = str(metric_value.this).lower()
        elif isinstance(metric_value, exp.Identifier):
            metric_value = metric_value.this.lower()
        elif isinstance(metric_value, str):
            metric_value = metric_value.lower()
        else:
            metric_value = "euclidean"

        # Kept as an Identifier (not a Literal) so the metric stays a raw attribute
        # rather than a value that could be parametrized.
        args["metric"] = exp.Identifier(this=metric_value)
        super().__init__(**args)

    @property
    def left(self) -> "exp.Expression":
        """Get the left operand (column)."""
        result: exp.Expression = self.this
        return result

    @property
    def right(self) -> "exp.Expression":
        """Get the right operand (vector value)."""
        result: exp.Expression = self.expression
        return result

    @property
    def metric(self) -> str:
        """Get the distance metric as a lowercase raw string (not parametrized).

        Falls back to ``"euclidean"`` if the stored metric is not an Identifier.
        """
        metric_expr = self.args.get("metric")
        if isinstance(metric_expr, exp.Identifier):
            metric_name: str = metric_expr.this
            return metric_name.lower()
        return "euclidean"

    @staticmethod
    def _dialect_name(dialect: Any) -> str:
        """Return the lowercase dialect name used to select a renderer.

        Handles the forms sqlglot accepts for a dialect: ``None``/empty (generic),
        a name string optionally followed by ``", setting=value"`` pairs, a dialect
        class, or a dialect instance. Classes and instances are identified by their
        class name (e.g. ``Postgres`` -> ``"postgres"``).
        """
        if not dialect:
            return "generic"
        if isinstance(dialect, str):
            # Only the leading name selects the generator; settings do not matter here.
            return dialect.split(",", 1)[0].strip().lower()
        if isinstance(dialect, type):
            return dialect.__name__.lower()
        return type(dialect).__name__.lower()

    def sql(self, dialect: "Any | None" = None, **opts: Any) -> str:
        """Generate dialect-specific SQL.

        This overrides the default sql() method to provide custom
        dialect-specific generation for vector distance operations. The dialect
        routing mirrors the generator transforms registered for this expression
        (Spanner uses the BigQuery form, Spangres the PostgreSQL form).

        Args:
            dialect: Target SQL dialect as a name string, dialect class or dialect
                instance (postgres, mysql, oracle, bigquery, duckdb, spanner, ...).
            **opts: Additional SQL generation options

        Returns:
            Dialect-specific SQL string
        """
        dialect_name = self._dialect_name(dialect)

        left_sql = self.left.sql(dialect=dialect, **opts)
        right_sql = self.right.sql(dialect=dialect, **opts)
        metric = self.metric

        renderers = {
            "postgres": self._sql_postgres,
            "postgresql": self._sql_postgres,
            "spangres": self._sql_postgres,
            "mysql": self._sql_mysql,
            "oracle": self._sql_oracle,
            "bigquery": self._sql_bigquery,
            "spanner": self._sql_bigquery,
            "duckdb": self._sql_duckdb,
        }
        render = renderers.get(dialect_name, self._sql_generic)
        return render(left_sql, right_sql, metric)

    def _sql_postgres(self, left: str, right: str, metric: str) -> str:
        """Generate PostgreSQL pgvector operator syntax (generic form for unmapped metrics)."""
        operator_map = {"euclidean": "<->", "cosine": "<=>", "inner_product": "<#>"}

        operator = operator_map.get(metric)
        if operator:
            return f"{left} {operator} {right}"

        return self._sql_generic(left, right, metric)

    def _sql_mysql(self, left: str, right: str, metric: str) -> str:
        """Generate MySQL DISTANCE function syntax.

        Unmapped metrics fall back to EUCLIDEAN. A right operand whose SQL looks
        like an array literal is wrapped in STRING_TO_VECTOR().
        """
        metric_map = {"euclidean": "EUCLIDEAN", "cosine": "COSINE", "inner_product": "DOT"}

        mysql_metric = metric_map.get(metric, "EUCLIDEAN")

        if ("ARRAY" in right or "[" in right) and "STRING_TO_VECTOR" not in right:
            right = f"STRING_TO_VECTOR({right})"

        return f"DISTANCE({left}, {right}, '{mysql_metric}')"

    def _sql_oracle(self, left: str, right: str, metric: str) -> str:
        """Generate Oracle VECTOR_DISTANCE function syntax.

        The metric is emitted unquoted and only from the fixed map below;
        unmapped metrics fall back to EUCLIDEAN. An ``exp.Array`` right operand is
        rendered as a TO_VECTOR('[...]') string literal.
        """
        metric_map = {
            "euclidean": "EUCLIDEAN",
            "cosine": "COSINE",
            "inner_product": "DOT",
            "euclidean_squared": "EUCLIDEAN_SQUARED",
            "manhattan": "MANHATTAN",
            "hamming": "HAMMING",
            "jaccard": "JACCARD",
        }

        oracle_metric = metric_map.get(metric, "EUCLIDEAN")

        if isinstance(self.expression, exp.Array):
            values = []
            for expr in self.expression.expressions:
                if isinstance(expr, exp.Literal):
                    values.append(str(expr.this))
                else:  # pragma: no cover - defensive
                    values.append(expr.sql(dialect="oracle"))
            right = f"TO_VECTOR('[{', '.join(values)}]')"
        elif ("ARRAY" in right or "[" in right) and "TO_VECTOR" not in right:
            right = f"TO_VECTOR({right})"

        return f"VECTOR_DISTANCE({left}, {right}, {oracle_metric})"

    def _sql_bigquery(self, left: str, right: str, metric: str) -> str:
        """Generate BigQuery vector distance function syntax (generic form for unmapped metrics)."""
        function_map = {"euclidean": "EUCLIDEAN_DISTANCE", "cosine": "COSINE_DISTANCE", "inner_product": "DOT_PRODUCT"}

        function_name = function_map.get(metric)
        if function_name:
            return f"{function_name}({left}, {right})"

        return self._sql_generic(left, right, metric)

    def _sql_duckdb(self, left: str, right: str, metric: str) -> str:
        """Generate DuckDB VSS extension function syntax.

        DuckDB's VSS extension provides:
        - array_distance(): L2 squared distance (euclidean)
        - array_cosine_distance(): Cosine distance (1 - cosine_similarity)
        - array_negative_inner_product(): Negative inner product

        Note: Array literals must be cast to DOUBLE[] since DuckDB infers
        decimal literals as DECIMAL type, but VSS functions require DOUBLE[].
        A non-empty ``exp.Array`` operand is cast to a fixed-size DOUBLE[n].
        """
        function_map = {
            "euclidean": "array_distance",
            "cosine": "array_cosine_distance",
            "inner_product": "array_negative_inner_product",
        }
        target_type = "DOUBLE[]"
        if isinstance(self.expression, exp.Array) and self.expression.expressions:
            target_type = f"DOUBLE[{len(self.expression.expressions)}]"

        function_name = function_map.get(metric)
        if function_name:
            right_cast = f"CAST({right} AS {target_type})"
            return f"{function_name}({left}, {right_cast})"

        return self._sql_generic(left, right, metric)

    def _sql_generic(self, left: str, right: str, metric: str) -> str:
        """Generate generic VECTOR_DISTANCE function syntax.

        The metric is emitted as a single-quoted string literal. Embedded single
        quotes are doubled so an unusual metric value cannot terminate the
        literal early.
        """
        metric_literal = metric.upper().replace("'", "''")
        return f"VECTOR_DISTANCE({left}, {right}, '{metric_literal}')"
187
+
188
+
189
def _register_with_sqlglot() -> None:
    """Register VectorDistance with SQLGlot's generator dispatch system.

    Adds a ``VectorDistance`` entry to ``Generator.TRANSFORMS`` for the base
    generator and for the Postgres, MySQL, Oracle, BigQuery and DuckDB generators.
    Entries are also added for sqlspec's Spanner (BigQuery form) and Spangres
    (PostgreSQL form) dialects when they can be imported.

    A ``<key>_sql`` fallback method is attached to the base ``Generator`` as
    well. sqlglot uses it when a dialect generator has no TRANSFORMS entry, so
    every other dialect renders the generic form instead of failing.
    """
    from sqlglot.dialects.bigquery import BigQuery
    from sqlglot.dialects.duckdb import DuckDB
    from sqlglot.dialects.mysql import MySQL
    from sqlglot.dialects.oracle import Oracle
    from sqlglot.dialects.postgres import Postgres
    from sqlglot.generator import Generator

    spanner_dialect: type | None = None
    spangres_dialect: type | None = None
    with suppress(ImportError):
        from sqlspec.adapters.spanner.dialect import Spangres, Spanner

        spanner_dialect = Spanner
        spangres_dialect = Spangres

    def _make_transform(render: Any) -> Any:
        """Adapt an unbound ``VectorDistance._sql_*`` renderer to sqlglot's transform signature."""

        def transform(generator: "Generator", expression: "VectorDistance") -> str:
            # Operands are rendered through the active generator so they follow its dialect.
            return render(
                expression, generator.sql(expression.left), generator.sql(expression.right), expression.metric
            )

        return transform

    vector_distance_sql_base = _make_transform(VectorDistance._sql_generic)  # pyright: ignore[reportPrivateUsage]
    vector_distance_sql_postgres = _make_transform(VectorDistance._sql_postgres)  # pyright: ignore[reportPrivateUsage]
    vector_distance_sql_mysql = _make_transform(VectorDistance._sql_mysql)  # pyright: ignore[reportPrivateUsage]
    vector_distance_sql_oracle = _make_transform(VectorDistance._sql_oracle)  # pyright: ignore[reportPrivateUsage]
    vector_distance_sql_bigquery = _make_transform(VectorDistance._sql_bigquery)  # pyright: ignore[reportPrivateUsage]
    vector_distance_sql_duckdb = _make_transform(VectorDistance._sql_duckdb)  # pyright: ignore[reportPrivateUsage]

    Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_base
    # Dialect generators build their own TRANSFORMS dicts by copying the base dict
    # when their class is created. A dialect loaded before this point therefore
    # never sees the entry above. sqlglot falls back to a "<expression.key>_sql"
    # method on the generator, and defining it on the base class makes it
    # reachable from every dialect generator through inheritance.
    setattr(Generator, f"{VectorDistance.key}_sql", vector_distance_sql_base)

    Postgres.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_postgres
    MySQL.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_mysql
    Oracle.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_oracle
    BigQuery.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_bigquery
    DuckDB.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_duckdb
    if spanner_dialect is not None:
        # Spanner (GoogleSQL) shares BigQuery's distance function names.
        spanner_dialect.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_bigquery  # type: ignore[attr-defined]
    if spangres_dialect is not None:
        spangres_dialect.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_postgres  # type: ignore[attr-defined]


# Registration runs once, at import time of this module.
_register_with_sqlglot()