sqlspec 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlspec/__init__.py +104 -0
- sqlspec/__main__.py +12 -0
- sqlspec/__metadata__.py +14 -0
- sqlspec/_serialization.py +312 -0
- sqlspec/_typing.py +784 -0
- sqlspec/adapters/__init__.py +0 -0
- sqlspec/adapters/adbc/__init__.py +5 -0
- sqlspec/adapters/adbc/_types.py +12 -0
- sqlspec/adapters/adbc/adk/__init__.py +5 -0
- sqlspec/adapters/adbc/adk/store.py +880 -0
- sqlspec/adapters/adbc/config.py +436 -0
- sqlspec/adapters/adbc/data_dictionary.py +537 -0
- sqlspec/adapters/adbc/driver.py +841 -0
- sqlspec/adapters/adbc/litestar/__init__.py +5 -0
- sqlspec/adapters/adbc/litestar/store.py +504 -0
- sqlspec/adapters/adbc/type_converter.py +153 -0
- sqlspec/adapters/aiosqlite/__init__.py +29 -0
- sqlspec/adapters/aiosqlite/_types.py +13 -0
- sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/adk/store.py +536 -0
- sqlspec/adapters/aiosqlite/config.py +310 -0
- sqlspec/adapters/aiosqlite/data_dictionary.py +260 -0
- sqlspec/adapters/aiosqlite/driver.py +463 -0
- sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
- sqlspec/adapters/aiosqlite/pool.py +500 -0
- sqlspec/adapters/asyncmy/__init__.py +25 -0
- sqlspec/adapters/asyncmy/_types.py +12 -0
- sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
- sqlspec/adapters/asyncmy/adk/store.py +503 -0
- sqlspec/adapters/asyncmy/config.py +246 -0
- sqlspec/adapters/asyncmy/data_dictionary.py +241 -0
- sqlspec/adapters/asyncmy/driver.py +632 -0
- sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncmy/litestar/store.py +296 -0
- sqlspec/adapters/asyncpg/__init__.py +23 -0
- sqlspec/adapters/asyncpg/_type_handlers.py +76 -0
- sqlspec/adapters/asyncpg/_types.py +23 -0
- sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
- sqlspec/adapters/asyncpg/adk/store.py +460 -0
- sqlspec/adapters/asyncpg/config.py +464 -0
- sqlspec/adapters/asyncpg/data_dictionary.py +321 -0
- sqlspec/adapters/asyncpg/driver.py +720 -0
- sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncpg/litestar/store.py +253 -0
- sqlspec/adapters/bigquery/__init__.py +18 -0
- sqlspec/adapters/bigquery/_types.py +12 -0
- sqlspec/adapters/bigquery/adk/__init__.py +5 -0
- sqlspec/adapters/bigquery/adk/store.py +585 -0
- sqlspec/adapters/bigquery/config.py +298 -0
- sqlspec/adapters/bigquery/data_dictionary.py +256 -0
- sqlspec/adapters/bigquery/driver.py +1073 -0
- sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
- sqlspec/adapters/bigquery/litestar/store.py +327 -0
- sqlspec/adapters/bigquery/type_converter.py +125 -0
- sqlspec/adapters/duckdb/__init__.py +24 -0
- sqlspec/adapters/duckdb/_types.py +12 -0
- sqlspec/adapters/duckdb/adk/__init__.py +14 -0
- sqlspec/adapters/duckdb/adk/store.py +563 -0
- sqlspec/adapters/duckdb/config.py +396 -0
- sqlspec/adapters/duckdb/data_dictionary.py +264 -0
- sqlspec/adapters/duckdb/driver.py +604 -0
- sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
- sqlspec/adapters/duckdb/litestar/store.py +332 -0
- sqlspec/adapters/duckdb/pool.py +273 -0
- sqlspec/adapters/duckdb/type_converter.py +133 -0
- sqlspec/adapters/oracledb/__init__.py +32 -0
- sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
- sqlspec/adapters/oracledb/_types.py +39 -0
- sqlspec/adapters/oracledb/_uuid_handlers.py +130 -0
- sqlspec/adapters/oracledb/adk/__init__.py +5 -0
- sqlspec/adapters/oracledb/adk/store.py +1632 -0
- sqlspec/adapters/oracledb/config.py +469 -0
- sqlspec/adapters/oracledb/data_dictionary.py +717 -0
- sqlspec/adapters/oracledb/driver.py +1493 -0
- sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
- sqlspec/adapters/oracledb/litestar/store.py +765 -0
- sqlspec/adapters/oracledb/migrations.py +532 -0
- sqlspec/adapters/oracledb/type_converter.py +207 -0
- sqlspec/adapters/psqlpy/__init__.py +16 -0
- sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
- sqlspec/adapters/psqlpy/_types.py +12 -0
- sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
- sqlspec/adapters/psqlpy/adk/store.py +483 -0
- sqlspec/adapters/psqlpy/config.py +271 -0
- sqlspec/adapters/psqlpy/data_dictionary.py +179 -0
- sqlspec/adapters/psqlpy/driver.py +892 -0
- sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
- sqlspec/adapters/psqlpy/litestar/store.py +272 -0
- sqlspec/adapters/psqlpy/type_converter.py +102 -0
- sqlspec/adapters/psycopg/__init__.py +32 -0
- sqlspec/adapters/psycopg/_type_handlers.py +90 -0
- sqlspec/adapters/psycopg/_types.py +18 -0
- sqlspec/adapters/psycopg/adk/__init__.py +5 -0
- sqlspec/adapters/psycopg/adk/store.py +962 -0
- sqlspec/adapters/psycopg/config.py +487 -0
- sqlspec/adapters/psycopg/data_dictionary.py +630 -0
- sqlspec/adapters/psycopg/driver.py +1336 -0
- sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
- sqlspec/adapters/psycopg/litestar/store.py +554 -0
- sqlspec/adapters/spanner/__init__.py +38 -0
- sqlspec/adapters/spanner/_type_handlers.py +186 -0
- sqlspec/adapters/spanner/_types.py +12 -0
- sqlspec/adapters/spanner/adk/__init__.py +5 -0
- sqlspec/adapters/spanner/adk/store.py +435 -0
- sqlspec/adapters/spanner/config.py +241 -0
- sqlspec/adapters/spanner/data_dictionary.py +95 -0
- sqlspec/adapters/spanner/dialect/__init__.py +6 -0
- sqlspec/adapters/spanner/dialect/_spangres.py +52 -0
- sqlspec/adapters/spanner/dialect/_spanner.py +123 -0
- sqlspec/adapters/spanner/driver.py +366 -0
- sqlspec/adapters/spanner/litestar/__init__.py +5 -0
- sqlspec/adapters/spanner/litestar/store.py +266 -0
- sqlspec/adapters/spanner/type_converter.py +46 -0
- sqlspec/adapters/sqlite/__init__.py +18 -0
- sqlspec/adapters/sqlite/_type_handlers.py +86 -0
- sqlspec/adapters/sqlite/_types.py +11 -0
- sqlspec/adapters/sqlite/adk/__init__.py +5 -0
- sqlspec/adapters/sqlite/adk/store.py +582 -0
- sqlspec/adapters/sqlite/config.py +221 -0
- sqlspec/adapters/sqlite/data_dictionary.py +256 -0
- sqlspec/adapters/sqlite/driver.py +527 -0
- sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/sqlite/litestar/store.py +318 -0
- sqlspec/adapters/sqlite/pool.py +140 -0
- sqlspec/base.py +811 -0
- sqlspec/builder/__init__.py +146 -0
- sqlspec/builder/_base.py +900 -0
- sqlspec/builder/_column.py +517 -0
- sqlspec/builder/_ddl.py +1642 -0
- sqlspec/builder/_delete.py +84 -0
- sqlspec/builder/_dml.py +381 -0
- sqlspec/builder/_expression_wrappers.py +46 -0
- sqlspec/builder/_factory.py +1537 -0
- sqlspec/builder/_insert.py +315 -0
- sqlspec/builder/_join.py +375 -0
- sqlspec/builder/_merge.py +848 -0
- sqlspec/builder/_parsing_utils.py +297 -0
- sqlspec/builder/_select.py +1615 -0
- sqlspec/builder/_update.py +161 -0
- sqlspec/builder/_vector_expressions.py +259 -0
- sqlspec/cli.py +764 -0
- sqlspec/config.py +1540 -0
- sqlspec/core/__init__.py +305 -0
- sqlspec/core/cache.py +785 -0
- sqlspec/core/compiler.py +603 -0
- sqlspec/core/filters.py +872 -0
- sqlspec/core/hashing.py +274 -0
- sqlspec/core/metrics.py +83 -0
- sqlspec/core/parameters/__init__.py +64 -0
- sqlspec/core/parameters/_alignment.py +266 -0
- sqlspec/core/parameters/_converter.py +413 -0
- sqlspec/core/parameters/_processor.py +341 -0
- sqlspec/core/parameters/_registry.py +201 -0
- sqlspec/core/parameters/_transformers.py +226 -0
- sqlspec/core/parameters/_types.py +430 -0
- sqlspec/core/parameters/_validator.py +123 -0
- sqlspec/core/pipeline.py +187 -0
- sqlspec/core/result.py +1124 -0
- sqlspec/core/splitter.py +940 -0
- sqlspec/core/stack.py +163 -0
- sqlspec/core/statement.py +835 -0
- sqlspec/core/type_conversion.py +235 -0
- sqlspec/driver/__init__.py +36 -0
- sqlspec/driver/_async.py +1027 -0
- sqlspec/driver/_common.py +1236 -0
- sqlspec/driver/_sync.py +1025 -0
- sqlspec/driver/mixins/__init__.py +7 -0
- sqlspec/driver/mixins/_result_tools.py +61 -0
- sqlspec/driver/mixins/_sql_translator.py +122 -0
- sqlspec/driver/mixins/_storage.py +311 -0
- sqlspec/exceptions.py +321 -0
- sqlspec/extensions/__init__.py +0 -0
- sqlspec/extensions/adk/__init__.py +53 -0
- sqlspec/extensions/adk/_types.py +51 -0
- sqlspec/extensions/adk/converters.py +172 -0
- sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
- sqlspec/extensions/adk/migrations/__init__.py +0 -0
- sqlspec/extensions/adk/service.py +181 -0
- sqlspec/extensions/adk/store.py +536 -0
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +471 -0
- sqlspec/extensions/fastapi/__init__.py +19 -0
- sqlspec/extensions/fastapi/extension.py +341 -0
- sqlspec/extensions/fastapi/providers.py +543 -0
- sqlspec/extensions/flask/__init__.py +36 -0
- sqlspec/extensions/flask/_state.py +72 -0
- sqlspec/extensions/flask/_utils.py +40 -0
- sqlspec/extensions/flask/extension.py +402 -0
- sqlspec/extensions/litestar/__init__.py +23 -0
- sqlspec/extensions/litestar/_utils.py +52 -0
- sqlspec/extensions/litestar/cli.py +92 -0
- sqlspec/extensions/litestar/config.py +90 -0
- sqlspec/extensions/litestar/handlers.py +316 -0
- sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
- sqlspec/extensions/litestar/migrations/__init__.py +3 -0
- sqlspec/extensions/litestar/plugin.py +638 -0
- sqlspec/extensions/litestar/providers.py +454 -0
- sqlspec/extensions/litestar/store.py +265 -0
- sqlspec/extensions/otel/__init__.py +58 -0
- sqlspec/extensions/prometheus/__init__.py +107 -0
- sqlspec/extensions/starlette/__init__.py +10 -0
- sqlspec/extensions/starlette/_state.py +26 -0
- sqlspec/extensions/starlette/_utils.py +52 -0
- sqlspec/extensions/starlette/extension.py +257 -0
- sqlspec/extensions/starlette/middleware.py +154 -0
- sqlspec/loader.py +716 -0
- sqlspec/migrations/__init__.py +36 -0
- sqlspec/migrations/base.py +728 -0
- sqlspec/migrations/commands.py +1140 -0
- sqlspec/migrations/context.py +142 -0
- sqlspec/migrations/fix.py +203 -0
- sqlspec/migrations/loaders.py +450 -0
- sqlspec/migrations/runner.py +1024 -0
- sqlspec/migrations/templates.py +234 -0
- sqlspec/migrations/tracker.py +403 -0
- sqlspec/migrations/utils.py +256 -0
- sqlspec/migrations/validation.py +203 -0
- sqlspec/observability/__init__.py +22 -0
- sqlspec/observability/_config.py +228 -0
- sqlspec/observability/_diagnostics.py +67 -0
- sqlspec/observability/_dispatcher.py +151 -0
- sqlspec/observability/_observer.py +180 -0
- sqlspec/observability/_runtime.py +381 -0
- sqlspec/observability/_spans.py +158 -0
- sqlspec/protocols.py +530 -0
- sqlspec/py.typed +0 -0
- sqlspec/storage/__init__.py +46 -0
- sqlspec/storage/_utils.py +104 -0
- sqlspec/storage/backends/__init__.py +1 -0
- sqlspec/storage/backends/base.py +163 -0
- sqlspec/storage/backends/fsspec.py +398 -0
- sqlspec/storage/backends/local.py +377 -0
- sqlspec/storage/backends/obstore.py +580 -0
- sqlspec/storage/errors.py +104 -0
- sqlspec/storage/pipeline.py +604 -0
- sqlspec/storage/registry.py +289 -0
- sqlspec/typing.py +219 -0
- sqlspec/utils/__init__.py +31 -0
- sqlspec/utils/arrow_helpers.py +95 -0
- sqlspec/utils/config_resolver.py +153 -0
- sqlspec/utils/correlation.py +132 -0
- sqlspec/utils/data_transformation.py +114 -0
- sqlspec/utils/dependencies.py +79 -0
- sqlspec/utils/deprecation.py +113 -0
- sqlspec/utils/fixtures.py +250 -0
- sqlspec/utils/logging.py +172 -0
- sqlspec/utils/module_loader.py +273 -0
- sqlspec/utils/portal.py +325 -0
- sqlspec/utils/schema.py +288 -0
- sqlspec/utils/serializers.py +396 -0
- sqlspec/utils/singleton.py +41 -0
- sqlspec/utils/sync_tools.py +277 -0
- sqlspec/utils/text.py +108 -0
- sqlspec/utils/type_converters.py +99 -0
- sqlspec/utils/type_guards.py +1324 -0
- sqlspec/utils/version.py +444 -0
- sqlspec-0.32.0.dist-info/METADATA +202 -0
- sqlspec-0.32.0.dist-info/RECORD +262 -0
- sqlspec-0.32.0.dist-info/WHEEL +4 -0
- sqlspec-0.32.0.dist-info/entry_points.txt +2 -0
- sqlspec-0.32.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""UPDATE statement builder.
|
|
2
|
+
|
|
3
|
+
Provides a fluent interface for building SQL UPDATE queries with
|
|
4
|
+
parameter binding and validation.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
8
|
+
|
|
9
|
+
from sqlglot import exp
|
|
10
|
+
from typing_extensions import Self
|
|
11
|
+
|
|
12
|
+
from sqlspec.builder._base import QueryBuilder, SafeQuery
|
|
13
|
+
from sqlspec.builder._dml import UpdateFromClauseMixin, UpdateSetClauseMixin, UpdateTableClauseMixin
|
|
14
|
+
from sqlspec.builder._join import build_join_clause
|
|
15
|
+
from sqlspec.builder._select import ReturningClauseMixin, WhereClauseMixin
|
|
16
|
+
from sqlspec.core import SQLResult
|
|
17
|
+
from sqlspec.exceptions import SQLBuilderError
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from sqlglot.dialects.dialect import DialectType
|
|
21
|
+
|
|
22
|
+
from sqlspec.builder._select import Select
|
|
23
|
+
from sqlspec.protocols import SQLBuilderProtocol
|
|
24
|
+
|
|
25
|
+
# Public API of this module: only the UPDATE statement builder is exported.
__all__ = ("Update",)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Update(
    QueryBuilder,
    WhereClauseMixin,
    ReturningClauseMixin,
    UpdateSetClauseMixin,
    UpdateFromClauseMixin,
    UpdateTableClauseMixin,
):
    """Builder for UPDATE statements.

    Constructs SQL UPDATE statements with parameter binding and validation.

    Example:
        ```python
        update_query = (
            Update()
            .table("users")
            .set_(name="John Doe")
            .set_(email="john@example.com")
            .where("id = 1")
        )

        update_query = (
            Update("users").set_(name="John Doe").where("id = 1")
        )

        update_query = (
            Update()
            .table("users")
            .set_(status="active")
            .where_eq("id", 123)
        )

        update_query = (
            Update()
            .table("users", "u")
            .set_(name="Updated Name")
            .from_("profiles", "p")
            .where("u.id = p.user_id AND p.is_verified = true")
        )
        ```
    """

    __slots__ = ()
    _expression: exp.Expression | None

    def __init__(self, table: str | None = None, **kwargs: Any) -> None:
        """Initialize UPDATE with optional table.

        Args:
            table: Target table name. Falsy values (``None`` or ``""``) leave
                the target unset so it can be supplied later via ``table()``.
            **kwargs: Additional QueryBuilder arguments
        """
        super().__init__(**kwargs)
        self._initialize_expression()

        if table:
            self.table(table)

    @property
    def _expected_result_type(self) -> "type[SQLResult]":
        """Return the expected result type for this builder."""
        return SQLResult

    def _create_base_expression(self) -> exp.Update:
        """Create a base UPDATE expression.

        Returns:
            A new sqlglot Update expression with empty clauses.
        """
        return exp.Update(this=None, expressions=[], joins=[])

    def join(
        self,
        table: "str | exp.Expression | Select",
        on: "str | exp.Expression",
        alias: "str | None" = None,
        join_type: str = "INNER",
    ) -> "Self":
        """Add JOIN clause to the UPDATE statement.

        Args:
            table: The table name, expression, or subquery to join.
            on: The JOIN condition.
            alias: Optional alias for the joined table.
            join_type: Type of join (INNER, LEFT, RIGHT, FULL).

        Returns:
            The current builder instance for method chaining.

        Raises:
            SQLBuilderError: If the current expression is not an UPDATE statement.
        """
        # isinstance() is False for None, so this also covers an uninitialized expression.
        if not isinstance(self._expression, exp.Update):
            msg = "Cannot add JOIN clause to non-UPDATE expression."
            raise SQLBuilderError(msg)

        join_expr = build_join_clause(cast("SQLBuilderProtocol", self), table, on, alias, join_type)

        # Expression.append creates the "joins" list when missing and, unlike a raw
        # args["joins"].append(...), links join_expr into the tree (parent/arg_key/index)
        # so tree operations such as replace()/find_ancestor() work on the join node.
        self._expression.append("joins", join_expr)

        return self

    def build(self, dialect: "DialectType" = None) -> "SafeQuery":
        """Build the UPDATE query with validation.

        Args:
            dialect: Optional dialect override for SQL generation.

        Returns:
            SafeQuery: The built query with SQL and parameters.

        Raises:
            SQLBuilderError: If no table is set, no SET clause was added, or the
                expression is missing or not an UPDATE.
        """
        if self._expression is None:
            msg = "UPDATE expression not initialized."
            raise SQLBuilderError(msg)

        if not isinstance(self._expression, exp.Update):
            msg = "No UPDATE expression to build or expression is of the wrong type."
            raise SQLBuilderError(msg)

        if self._expression.this is None:
            msg = "No table specified for UPDATE statement."
            raise SQLBuilderError(msg)

        if not self._expression.args.get("expressions"):
            msg = "At least one SET clause must be specified for UPDATE statement."
            raise SQLBuilderError(msg)

        return super().build(dialect=dialect)
|
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
"""Custom SQLGlot expressions for vector distance operations.
|
|
2
|
+
|
|
3
|
+
Provides dialect-specific SQL generation for vector similarity search
|
|
4
|
+
across PostgreSQL (pgvector), MySQL 9+, Oracle 23ai+, BigQuery, and Spanner.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from contextlib import suppress
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from sqlglot import exp
|
|
11
|
+
|
|
12
|
+
# Only the expression class is public; registration with sqlglot happens at import time.
__all__ = ("VectorDistance",)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class VectorDistance(exp.Expression):
    """Vector distance expression with dialect-specific generation.

    Generates database-specific SQL for vector distance calculations:
    - PostgreSQL (pgvector) / Spangres: Operators <->, <=>, <#>
    - MySQL 9+: DISTANCE(col, vec, 'METRIC') function
    - Oracle 23ai+: VECTOR_DISTANCE(col, vec, METRIC) function
    - BigQuery / Spanner: EUCLIDEAN_DISTANCE / COSINE_DISTANCE / DOT_PRODUCT
    - DuckDB: array_distance / array_cosine_distance / array_negative_inner_product
    - Generic: VECTOR_DISTANCE(col, vec, 'METRIC') function

    The metric is stored as a raw string attribute (not parametrized) and drives
    dialect-specific generation at SQL build time.
    """

    arg_types = {"this": True, "expression": True, "metric": False}

    def __init__(self, **args: Any) -> None:
        """Initialize VectorDistance with metric stored in args.

        The ``metric`` argument may be a string, a sqlglot Literal or Identifier;
        it is normalized to a lowercase Identifier. Anything else (including an
        explicit ``None``) becomes ``"euclidean"``.
        """
        metric_value = args.get("metric", "euclidean")
        if isinstance(metric_value, exp.Literal):
            metric_value = str(metric_value.this).lower()
        elif isinstance(metric_value, exp.Identifier):
            metric_value = metric_value.this.lower()
        elif isinstance(metric_value, str):
            metric_value = metric_value.lower()
        else:
            metric_value = "euclidean"

        args["metric"] = exp.Identifier(this=metric_value)
        super().__init__(**args)

    @property
    def left(self) -> "exp.Expression":
        """Get the left operand (column)."""
        result: exp.Expression = self.this
        return result

    @property
    def right(self) -> "exp.Expression":
        """Get the right operand (vector value)."""
        result: exp.Expression = self.expression
        return result

    @property
    def metric(self) -> str:
        """Get the distance metric as raw string (not parametrized)."""
        metric_expr = self.args.get("metric")
        if isinstance(metric_expr, exp.Identifier):
            metric_name: str = metric_expr.this
            return metric_name.lower()
        return "euclidean"

    @staticmethod
    def _resolve_dialect_name(dialect: Any) -> str:
        """Normalize a sqlglot dialect argument to a lowercase dialect name.

        sqlglot accepts a dialect as a name string (optionally followed by
        ``", setting=value"`` pairs), a ``Dialect`` subclass, or a ``Dialect``
        instance. ``str()`` of a class or instance is only its repr, so those are
        resolved through the class hierarchy instead; subclasses of a known
        dialect resolve to that dialect's name.

        Args:
            dialect: The dialect value passed to :meth:`sql`.

        Returns:
            Lowercase dialect name, or ``"generic"`` when no dialect was given.
        """
        if not dialect:
            return "generic"
        if isinstance(dialect, str):
            return dialect.split(",", 1)[0].strip().lower() or "generic"

        known_names = {"postgres", "spangres", "mysql", "oracle", "bigquery", "spanner", "duckdb"}
        dialect_cls = dialect if isinstance(dialect, type) else type(dialect)
        for klass in dialect_cls.__mro__:
            name = klass.__name__.lower()
            if name in known_names:
                return name
        return dialect_cls.__name__.lower()

    def sql(self, dialect: "Any | None" = None, **opts: Any) -> str:
        """Generate dialect-specific SQL.

        This overrides the default sql() method to provide custom
        dialect-specific generation for vector distance operations. The
        dialect-to-renderer mapping mirrors the one registered in sqlglot's
        generator TRANSFORMS, so a VectorDistance renders the same whether it is
        generated directly or as part of a larger statement.

        Args:
            dialect: Target SQL dialect as a name string, Dialect class or
                Dialect instance (postgres, mysql, oracle, bigquery, duckdb, etc.)
            **opts: Additional SQL generation options

        Returns:
            Dialect-specific SQL string
        """
        dialect_name = self._resolve_dialect_name(dialect)

        left_sql = self.left.sql(dialect=dialect, **opts)
        right_sql = self.right.sql(dialect=dialect, **opts)
        metric = self.metric

        if dialect_name in {"postgres", "postgresql", "spangres"}:
            return self._sql_postgres(left_sql, right_sql, metric)

        if dialect_name == "mysql":
            return self._sql_mysql(left_sql, right_sql, metric)

        if dialect_name == "oracle":
            return self._sql_oracle(left_sql, right_sql, metric)

        # Spanner (GoogleSQL) shares BigQuery's distance function names.
        if dialect_name in {"bigquery", "spanner"}:
            return self._sql_bigquery(left_sql, right_sql, metric)

        if dialect_name == "duckdb":
            return self._sql_duckdb(left_sql, right_sql, metric)

        return self._sql_generic(left_sql, right_sql, metric)

    def _sql_postgres(self, left: str, right: str, metric: str) -> str:
        """Generate PostgreSQL pgvector operator syntax.

        Metrics without a pgvector operator fall back to the generic form.
        """
        operator_map = {"euclidean": "<->", "cosine": "<=>", "inner_product": "<#>"}

        operator = operator_map.get(metric)
        if operator:
            return f"{left} {operator} {right}"

        return self._sql_generic(left, right, metric)

    def _sql_mysql(self, left: str, right: str, metric: str) -> str:
        """Generate MySQL DISTANCE function syntax.

        Unknown metrics map to EUCLIDEAN. A right operand whose SQL text looks like
        an array literal is wrapped in STRING_TO_VECTOR (text heuristic).
        """
        metric_map = {"euclidean": "EUCLIDEAN", "cosine": "COSINE", "inner_product": "DOT"}

        mysql_metric = metric_map.get(metric, "EUCLIDEAN")

        if ("ARRAY" in right or "[" in right) and "STRING_TO_VECTOR" not in right:
            right = f"STRING_TO_VECTOR({right})"

        return f"DISTANCE({left}, {right}, '{mysql_metric}')"

    def _sql_oracle(self, left: str, right: str, metric: str) -> str:
        """Generate Oracle VECTOR_DISTANCE function syntax.

        Unknown metrics map to EUCLIDEAN. An array-literal right operand is
        rendered as ``TO_VECTOR('[v1, v2, ...]')``.
        """
        metric_map = {
            "euclidean": "EUCLIDEAN",
            "cosine": "COSINE",
            "inner_product": "DOT",
            "euclidean_squared": "EUCLIDEAN_SQUARED",
        }

        oracle_metric = metric_map.get(metric, "EUCLIDEAN")

        if isinstance(self.expression, exp.Array):
            values = []
            for expr in self.expression.expressions:
                if isinstance(expr, exp.Literal):
                    values.append(str(expr.this))
                else:  # pragma: no cover - non-literal elements (e.g. negated numbers)
                    values.append(expr.sql(dialect="oracle"))
            right = f"TO_VECTOR('[{', '.join(values)}]')"
        elif ("ARRAY" in right or "[" in right) and "TO_VECTOR" not in right:
            right = f"TO_VECTOR({right})"

        return f"VECTOR_DISTANCE({left}, {right}, {oracle_metric})"

    def _sql_bigquery(self, left: str, right: str, metric: str) -> str:
        """Generate BigQuery vector distance function syntax.

        Metrics without a dedicated function fall back to the generic form.
        """
        function_map = {"euclidean": "EUCLIDEAN_DISTANCE", "cosine": "COSINE_DISTANCE", "inner_product": "DOT_PRODUCT"}

        function_name = function_map.get(metric)
        if function_name:
            return f"{function_name}({left}, {right})"

        return self._sql_generic(left, right, metric)

    def _sql_duckdb(self, left: str, right: str, metric: str) -> str:
        """Generate DuckDB array distance function syntax.

        DuckDB's array functions (also used by the VSS extension) provide:
        - array_distance(): Euclidean (L2) distance
        - array_cosine_distance(): Cosine distance (1 - cosine_similarity)
        - array_negative_inner_product(): Negative inner product

        Note: Array literals must be cast to DOUBLE[] since DuckDB infers
        decimal literals as DECIMAL type, but these functions require DOUBLE[].
        When the right operand is a non-empty array literal, the fixed-size
        ARRAY type ``DOUBLE[n]`` is used.
        """
        function_map = {
            "euclidean": "array_distance",
            "cosine": "array_cosine_distance",
            "inner_product": "array_negative_inner_product",
        }
        target_type = "DOUBLE[]"
        if isinstance(self.expression, exp.Array) and self.expression.expressions:
            target_type = f"DOUBLE[{len(self.expression.expressions)}]"

        function_name = function_map.get(metric)
        if function_name:
            right_cast = f"CAST({right} AS {target_type})"
            return f"{function_name}({left}, {right_cast})"

        return self._sql_generic(left, right, metric)

    def _sql_generic(self, left: str, right: str, metric: str) -> str:
        """Generate generic VECTOR_DISTANCE function syntax.

        The metric is emitted as a quoted SQL string literal; embedded single
        quotes are doubled so a caller-supplied metric cannot break out of it.
        """
        metric_literal = metric.upper().replace("'", "''")
        return f"VECTOR_DISTANCE({left}, {right}, '{metric_literal}')"
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _register_with_sqlglot() -> None:
    """Install VectorDistance renderers into sqlglot generator TRANSFORMS tables.

    The base ``Generator`` receives the generic ``VECTOR_DISTANCE(...)`` renderer;
    the Postgres, MySQL, Oracle, BigQuery and DuckDB generators receive their
    dialect-specific renderers. If sqlspec's Spanner dialects can be imported,
    Spanner reuses the BigQuery renderer and Spangres the PostgreSQL one.
    """
    from sqlglot.dialects.bigquery import BigQuery
    from sqlglot.dialects.duckdb import DuckDB
    from sqlglot.dialects.mysql import MySQL
    from sqlglot.dialects.oracle import Oracle
    from sqlglot.dialects.postgres import Postgres
    from sqlglot.generator import Generator

    # (dialect class, VectorDistance renderer method) pairs for optional dialects.
    optional_targets: "list[tuple[Any, str]]" = []
    with suppress(ImportError):
        from sqlspec.adapters.spanner.dialect import Spangres, Spanner

        optional_targets = [(Spanner, "_sql_bigquery"), (Spangres, "_sql_postgres")]

    def make_transform(renderer: str) -> Any:
        """Return a TRANSFORMS handler delegating to ``VectorDistance.<renderer>``."""

        def transform(generator: "Generator", expression: "VectorDistance") -> str:
            render = getattr(expression, renderer)
            return render(generator.sql(expression.left), generator.sql(expression.right), expression.metric)

        return transform

    core_targets = (
        (Generator, "_sql_generic"),
        (Postgres.Generator, "_sql_postgres"),
        (MySQL.Generator, "_sql_mysql"),
        (Oracle.Generator, "_sql_oracle"),
        (BigQuery.Generator, "_sql_bigquery"),
        (DuckDB.Generator, "_sql_duckdb"),
    )
    for generator_cls, renderer in core_targets:
        generator_cls.TRANSFORMS[VectorDistance] = make_transform(renderer)

    for dialect_cls, renderer in optional_targets:
        dialect_cls.Generator.TRANSFORMS[VectorDistance] = make_transform(renderer)  # type: ignore[attr-defined]
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
# Register at import time so the base generator and the Postgres, MySQL, Oracle,
# BigQuery and DuckDB generators (plus Spanner/Spangres when importable) can
# render VectorDistance nodes embedded in larger statements.
_register_with_sqlglot()
|