sqlalchemy-cratedb 0.38.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlalchemy_cratedb/__init__.py +62 -0
- sqlalchemy_cratedb/compat/__init__.py +0 -0
- sqlalchemy_cratedb/compat/api13.py +156 -0
- sqlalchemy_cratedb/compat/core10.py +264 -0
- sqlalchemy_cratedb/compat/core14.py +359 -0
- sqlalchemy_cratedb/compat/core20.py +447 -0
- sqlalchemy_cratedb/compiler.py +372 -0
- sqlalchemy_cratedb/dialect.py +381 -0
- sqlalchemy_cratedb/predicate.py +99 -0
- sqlalchemy_cratedb/sa_version.py +28 -0
- sqlalchemy_cratedb/support/__init__.py +14 -0
- sqlalchemy_cratedb/support/pandas.py +111 -0
- sqlalchemy_cratedb/support/polyfill.py +125 -0
- sqlalchemy_cratedb/support/util.py +41 -0
- sqlalchemy_cratedb/type/__init__.py +4 -0
- sqlalchemy_cratedb/type/array.py +144 -0
- sqlalchemy_cratedb/type/geo.py +48 -0
- sqlalchemy_cratedb/type/object.py +92 -0
- sqlalchemy_cratedb/type/vector.py +173 -0
- sqlalchemy_cratedb-0.38.0.dist-info/LICENSE +178 -0
- sqlalchemy_cratedb-0.38.0.dist-info/METADATA +143 -0
- sqlalchemy_cratedb-0.38.0.dist-info/NOTICE +24 -0
- sqlalchemy_cratedb-0.38.0.dist-info/RECORD +26 -0
- sqlalchemy_cratedb-0.38.0.dist-info/WHEEL +5 -0
- sqlalchemy_cratedb-0.38.0.dist-info/entry_points.txt +2 -0
- sqlalchemy_cratedb-0.38.0.dist-info/top_level.txt +1 -0
sqlalchemy_cratedb/support/polyfill.py
@@ -0,0 +1,125 @@
import sqlalchemy as sa
from sqlalchemy.event import listen
import typing as t

from sqlalchemy_cratedb.support.util import refresh_dirty, refresh_table


def patch_autoincrement_timestamp():
    """
    Configure SQLAlchemy model columns with an alternative to `autoincrement=True`.
    Use the current timestamp instead.

    This is used by CrateDB's MLflow adapter.

    TODO: Maybe enable through a dialect parameter `crate_polyfill_autoincrement` or such.
    """
    import sqlalchemy.sql.schema as schema

    init_dist = schema.Column.__init__

    def __init__(self, *args, **kwargs):
        if "autoincrement" in kwargs:
            del kwargs["autoincrement"]
            if "default" not in kwargs:
                kwargs["default"] = sa.func.now()
        init_dist(self, *args, **kwargs)

    schema.Column.__init__ = __init__  # type: ignore[method-assign]


def check_uniqueness_factory(sa_entity, *attribute_names):
    """
    Run a manual column value uniqueness check on a table, and raise an IntegrityError if applicable.

    CrateDB does not support the UNIQUE constraint on columns. This attempts to emulate it.

    https://github.com/crate/sqlalchemy-cratedb/issues/76

    This is used by CrateDB's MLflow adapter.

    TODO: Maybe enable through a dialect parameter `crate_polyfill_unique` or such.
    """

    # Synthesize a canonical "name" for the constraint,
    # composed of all column names involved.
    constraint_name: str = "-".join(attribute_names)

    def check_uniqueness(mapper, connection, target):
        from sqlalchemy.exc import IntegrityError

        if isinstance(target, sa_entity):
            # TODO: How to use `session.query(SqlExperiment)` here?
            stmt = mapper.selectable.select()
            for attribute_name in attribute_names:
                stmt = stmt.filter(getattr(sa_entity, attribute_name) == getattr(target, attribute_name))
            stmt = stmt.compile(bind=connection.engine)
            results = connection.execute(stmt)
            if results.rowcount > 0:
                raise IntegrityError(
                    statement=stmt,
                    params=[],
                    orig=Exception(
                        f"DuplicateKeyException in table '{target.__tablename__}' "
                        f"on constraint '{constraint_name}'"
                    ),
                )

    return check_uniqueness


def refresh_after_dml_session(session: sa.orm.Session):
    """
    Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE).

    CrateDB is eventually consistent, i.e. write operations are not flushed to
    disk immediately, so readers may see stale data. In a traditional OLTP-like
    application, this is not applicable.

    This SQLAlchemy extension makes sure that data is synchronized after each
    operation manipulating data.

    > `after_{insert,update,delete}` events only apply to the session flush operation
    > and do not apply to the ORM DML operations described at ORM-Enabled INSERT,
    > UPDATE, and DELETE statements. To intercept ORM DML events, use
    > `SessionEvents.do_orm_execute().`
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.MapperEvents.after_insert

    > Intercept statement executions that occur on behalf of an ORM Session object.
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.do_orm_execute

    > Execute after flush has completed, but before commit has been called.
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.after_flush

    This is used by CrateDB's LangChain adapter.

    TODO: Maybe enable through a dialect parameter `crate_dml_refresh` or such.
    """  # noqa: E501
    listen(session, "after_flush", refresh_dirty)


def refresh_after_dml_engine(engine: sa.engine.Engine):
    """
    Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE).

    This is used by CrateDB's Singer/Meltano and `rdflib-sqlalchemy` adapters.
    """
    def receive_after_execute(
        conn: sa.engine.Connection, clauseelement, multiparams, params, execution_options, result
    ):
        if isinstance(clauseelement, (sa.sql.Insert, sa.sql.Update, sa.sql.Delete)):
            if not isinstance(clauseelement.table, sa.sql.Join):
                refresh_table(conn, clauseelement.table)

    sa.event.listen(engine, "after_execute", receive_after_execute)


def refresh_after_dml(engine_or_session: t.Union[sa.engine.Engine, sa.orm.Session]):
    """
    Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE).
    """
    if isinstance(engine_or_session, sa.engine.Engine):
        refresh_after_dml_engine(engine_or_session)
    elif isinstance(engine_or_session, (sa.orm.Session, sa.orm.scoping.scoped_session)):
        refresh_after_dml_session(engine_or_session)
    else:
        raise TypeError(f"Unknown type: {type(engine_or_session)}")
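A minimal usage sketch for these polyfills, not part of the package: the connection URL, model, and column names are illustrative assumptions, and the import path follows the file layout listed above.

import sqlalchemy as sa
from sqlalchemy.event import listen
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy_cratedb.support.polyfill import check_uniqueness_factory, refresh_after_dml

Base = declarative_base()


class Experiment(Base):
    # Hypothetical model for illustration.
    __tablename__ = "experiment"
    id = sa.Column(sa.String, primary_key=True)
    name = sa.Column(sa.String)


# Illustrative connection URL; adjust host and port for your CrateDB instance.
engine = sa.create_engine("crate://localhost:4200")
session = Session(engine)

# Emulate a UNIQUE constraint on `name` by checking before each insert.
listen(Experiment, "before_insert", check_uniqueness_factory(Experiment, "name"))

# Emit `REFRESH TABLE` after each flush, so subsequent reads see the writes.
refresh_after_dml(session)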
sqlalchemy_cratedb/support/util.py
@@ -0,0 +1,41 @@
import itertools
import typing as t

import sqlalchemy as sa

if t.TYPE_CHECKING:
    try:
        from sqlalchemy.orm import DeclarativeBase
    except ImportError:
        pass


def refresh_table(connection, target: t.Union[str, "DeclarativeBase", "sa.sql.selectable.TableClause"]):
    """
    Invoke a `REFRESH TABLE` statement.
    """

    if isinstance(target, sa.sql.selectable.TableClause):
        full_table_name = f'"{target.name}"'
        if target.schema is not None:
            full_table_name = f'"{target.schema}".' + full_table_name
    elif hasattr(target, "__tablename__"):
        full_table_name = target.__tablename__
    else:
        full_table_name = target

    sql = f"REFRESH TABLE {full_table_name}"
    connection.execute(sa.text(sql))


def refresh_dirty(session, flush_context=None):
    """
    Invoke a `REFRESH TABLE` statement on each table entity flagged as "dirty".

    SQLAlchemy event handler for the 'after_flush' event,
    invoking `REFRESH TABLE` on each table which has been modified.
    """
    dirty_entities = itertools.chain(session.new, session.dirty, session.deleted)
    dirty_classes = {entity.__class__ for entity in dirty_entities}
    for class_ in dirty_classes:
        refresh_table(session, class_)
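A short sketch of calling `refresh_table` directly; the connection URL and table name are placeholders.

import sqlalchemy as sa
from sqlalchemy_cratedb.support.util import refresh_table

# Illustrative connection URL for a local CrateDB instance.
engine = sa.create_engine("crate://localhost:4200")
with engine.connect() as connection:
    # Accepts a table name string, an ORM entity, or a TableClause.
    refresh_table(connection, "my_table")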
sqlalchemy_cratedb/type/array.py
@@ -0,0 +1,144 @@
# -*- coding: utf-8; -*-
#
# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
# license agreements. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership. Crate licenses
# this file to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may
# obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# However, if you have executed another commercial license agreement
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.

import sqlalchemy.types as sqltypes
from sqlalchemy.sql import operators, expression
from sqlalchemy.sql import default_comparator
from sqlalchemy.ext.mutable import Mutable


class MutableList(Mutable, list):

    @classmethod
    def coerce(cls, key, value):
        """ Convert plain list to MutableList """
        if not isinstance(value, MutableList):
            if isinstance(value, list):
                return MutableList(value)
            elif value is None:
                return value
            else:
                return MutableList([value])
        else:
            return value

    def __init__(self, initval=None):
        list.__init__(self, initval or [])

    def __setitem__(self, key, value):
        list.__setitem__(self, key, value)
        self.changed()

    def __eq__(self, other):
        return list.__eq__(self, other)

    def append(self, item):
        list.append(self, item)
        self.changed()

    def insert(self, idx, item):
        list.insert(self, idx, item)
        self.changed()

    def extend(self, iterable):
        list.extend(self, iterable)
        self.changed()

    def pop(self, index=-1):
        list.pop(self, index)
        self.changed()

    def remove(self, item):
        list.remove(self, item)
        self.changed()


class Any(expression.ColumnElement):
    """Represent the clause ``left operator ANY (right)``. ``right`` must be
    an array expression.

    copied from postgresql dialect

    .. seealso::

        :class:`sqlalchemy.dialects.postgresql.ARRAY`

        :meth:`sqlalchemy.dialects.postgresql.ARRAY.Comparator.any`
            ARRAY-bound method

    """
    __visit_name__ = 'any'
    inherit_cache = True

    def __init__(self, left, right, operator=operators.eq):
        self.type = sqltypes.Boolean()
        self.left = expression.literal(left)
        self.right = right
        self.operator = operator


class _ObjectArray(sqltypes.UserDefinedType):
    cache_ok = True

    class Comparator(sqltypes.TypeEngine.Comparator):
        def __getitem__(self, key):
            return default_comparator._binary_operate(self.expr,
                                                      operators.getitem,
                                                      key)

        def any(self, other, operator=operators.eq):
            """Return ``other operator ANY (array)`` clause.

            Argument places are switched, because ANY requires array
            expression to be on the right hand-side.

            E.g.::

                from sqlalchemy.sql import operators

                conn.execute(
                    select([table.c.data]).where(
                        table.c.data.any(7, operator=operators.lt)
                    )
                )

            :param other: expression to be compared
            :param operator: an operator object from the
                :mod:`sqlalchemy.sql.operators`
                package, defaults to :func:`.operators.eq`.

            .. seealso::

                :class:`.postgresql.Any`

                :meth:`.postgresql.ARRAY.Comparator.all`

            """
            return Any(other, self.expr, operator=operator)

    type = MutableList
    comparator_factory = Comparator

    def get_col_spec(self, **kws):
        return "ARRAY(OBJECT)"


ObjectArray = MutableList.as_mutable(_ObjectArray)
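A brief usage sketch for `ObjectArray` and the `Any` comparator, mirroring the docstring example above; the table and column names are illustrative, and the import path follows the file layout listed above.

import sqlalchemy as sa
from sqlalchemy.sql import operators
from sqlalchemy_cratedb.type.array import ObjectArray

metadata = sa.MetaData()

# Hypothetical table with an ARRAY(OBJECT) column.
characters = sa.Table(
    "characters",
    metadata,
    sa.Column("name", sa.String),
    sa.Column("data", ObjectArray),
)

# Select rows where any element of the `data` array satisfies the comparison,
# i.e. `7 < ANY (data)`; the default operator is equality.
stmt = sa.select(characters.c.name).where(
    characters.c.data.any(7, operator=operators.lt)
)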
sqlalchemy_cratedb/type/geo.py
@@ -0,0 +1,48 @@
import geojson
from sqlalchemy import types as sqltypes
from sqlalchemy.sql import default_comparator, operators


class Geopoint(sqltypes.UserDefinedType):
    cache_ok = True

    class Comparator(sqltypes.TypeEngine.Comparator):

        def __getitem__(self, key):
            return default_comparator._binary_operate(self.expr,
                                                      operators.getitem,
                                                      key)

    def get_col_spec(self):
        return 'GEO_POINT'

    def bind_processor(self, dialect):
        def process(value):
            if isinstance(value, geojson.Point):
                return value.coordinates
            return value
        return process

    def result_processor(self, dialect, coltype):
        return tuple

    comparator_factory = Comparator


class Geoshape(sqltypes.UserDefinedType):
    cache_ok = True

    class Comparator(sqltypes.TypeEngine.Comparator):

        def __getitem__(self, key):
            return default_comparator._binary_operate(self.expr,
                                                      operators.getitem,
                                                      key)

    def get_col_spec(self):
        return 'GEO_SHAPE'

    def result_processor(self, dialect, coltype):
        return geojson.GeoJSON.to_instance

    comparator_factory = Comparator
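A usage sketch for the geospatial types; the table definition and the coordinate values are illustrative assumptions.

import sqlalchemy as sa
from sqlalchemy_cratedb.type.geo import Geopoint, Geoshape

metadata = sa.MetaData()

# Hypothetical table; names are illustrative.
cities = sa.Table(
    "cities",
    metadata,
    sa.Column("name", sa.String),
    sa.Column("location", Geopoint),  # DDL type: GEO_POINT
    sa.Column("area", Geoshape),      # DDL type: GEO_SHAPE
)

# Geopoint accepts coordinate pairs or `geojson.Point` values on the way in
# and returns tuples on the way out; Geoshape returns GeoJSON instances.
stmt = cities.insert().values(name="Vienna", location=(16.37, 48.21))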
sqlalchemy_cratedb/type/object.py
@@ -0,0 +1,92 @@
import warnings

from sqlalchemy import types as sqltypes
from sqlalchemy.ext.mutable import Mutable


class MutableDict(Mutable, dict):

    @classmethod
    def coerce(cls, key, value):
        "Convert plain dictionaries to MutableDict."

        if not isinstance(value, MutableDict):
            if isinstance(value, dict):
                return MutableDict(value)

            # this call will raise ValueError
            return Mutable.coerce(key, value)
        else:
            return value

    def __init__(self, initval=None, to_update=None, root_change_key=None):
        initval = initval or {}
        self._changed_keys = set()
        self._deleted_keys = set()
        self._overwrite_key = root_change_key
        self.to_update = self if to_update is None else to_update
        for k in initval:
            initval[k] = self._convert_dict(
                initval[k],
                overwrite_key=k if self._overwrite_key is None else self._overwrite_key,
            )
        dict.__init__(self, initval)

    def __setitem__(self, key, value):
        value = self._convert_dict(value, key if self._overwrite_key is None else self._overwrite_key)
        dict.__setitem__(self, key, value)
        self.to_update.on_key_changed(
            key if self._overwrite_key is None else self._overwrite_key
        )

    def __delitem__(self, key):
        dict.__delitem__(self, key)
        # add the key to the deleted keys if this is the root object
        # otherwise update on root object
        if self._overwrite_key is None:
            self._deleted_keys.add(key)
            self.changed()
        else:
            self.to_update.on_key_changed(self._overwrite_key)

    def on_key_changed(self, key):
        self._deleted_keys.discard(key)
        self._changed_keys.add(key)
        self.changed()

    def _convert_dict(self, value, overwrite_key):
        if isinstance(value, dict) and not isinstance(value, MutableDict):
            return MutableDict(value, self.to_update, overwrite_key)
        return value

    def __eq__(self, other):
        return dict.__eq__(self, other)


class ObjectTypeImpl(sqltypes.UserDefinedType, sqltypes.JSON):

    __visit_name__ = "OBJECT"

    cache_ok = False
    none_as_null = False


# Designated name to refer to. `Object` is too ambiguous.
ObjectType = MutableDict.as_mutable(ObjectTypeImpl)

# Backward-compatibility aliases.
_deprecated_Craty = ObjectType
_deprecated_Object = ObjectType

# https://www.lesinskis.com/deprecating-module-scope-variables.html
deprecated_names = ["Craty", "Object"]


def __getattr__(name):
    if name in deprecated_names:
        warnings.warn(f"{name} is deprecated and will be removed in future releases. "
                      f"Please use ObjectType instead.", DeprecationWarning)
        return globals()[f"_deprecated_{name}"]
    raise AttributeError(f"module {__name__} has no attribute {name}")


__all__ = deprecated_names + ["ObjectType"]
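A usage sketch for `ObjectType`; the model and attribute names are illustrative assumptions.

import sqlalchemy as sa
from sqlalchemy.orm import declarative_base
from sqlalchemy_cratedb.type.object import ObjectType

Base = declarative_base()


class Character(Base):
    # Hypothetical model for illustration.
    __tablename__ = "characters"
    name = sa.Column(sa.String, primary_key=True)
    details = sa.Column(ObjectType)


# Because ObjectType is wrapped in MutableDict, in-place changes to the
# dictionary are tracked and flagged for the next flush.
arthur = Character(name="Arthur", details={"species": "human"})
arthur.details["age"] = 42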
sqlalchemy_cratedb/type/vector.py
@@ -0,0 +1,173 @@
"""
## About
SQLAlchemy data type implementation for CrateDB's `FLOAT_VECTOR` type.

## References
- https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector
- https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match

## Details
The implementation is based on SQLAlchemy's `TypeDecorator`, and also
offers compiler support.

## Notes
CrateDB currently only supports the similarity function `VectorSimilarityFunction.EUCLIDEAN`.
-- https://github.com/crate/crate/blob/5.5.1/server/src/main/java/io/crate/types/FloatVectorType.java#L55

pgvector use a comparator to apply different similarity functions as operators,
see `pgvector.sqlalchemy.Vector.comparator_factory`.

<->: l2/euclidean_distance
<#>: max_inner_product
<=>: cosine_distance

## Backlog
- After dropping support for SQLAlchemy 1.3, use
  `class FloatVector(sa.TypeDecorator[t.Sequence[float]]):`

## Origin
This module is based on the corresponding pgvector implementation
by Andrew Kane. Thank you.

The MIT License (MIT)
Copyright (c) 2021-2023 Andrew Kane
https://github.com/pgvector/pgvector-python
"""
import typing as t

if t.TYPE_CHECKING:
    import numpy.typing as npt  # pragma: no cover

import sqlalchemy as sa
from sqlalchemy.sql.expression import ColumnElement, literal
from sqlalchemy.ext.compiler import compiles


__all__ = [
    "from_db",
    "knn_match",
    "to_db",
    "FloatVector",
]


def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]:
    import numpy as np

    # from `pgvector.utils`
    # could be ndarray if already cast by lower-level driver
    if value is None or isinstance(value, np.ndarray):
        return value

    return np.array(value, dtype=np.float32)


def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
    import numpy as np

    # from `pgvector.utils`
    if value is None:
        return value

    if isinstance(value, np.ndarray):
        if value.ndim != 1:
            raise ValueError("expected ndim to be 1")

        if not np.issubdtype(value.dtype, np.integer) and not np.issubdtype(value.dtype, np.floating):
            raise ValueError("dtype must be numeric")

        value = value.tolist()

    if dim is not None and len(value) != dim:
        raise ValueError("expected %d dimensions, not %d" % (dim, len(value)))

    return value


class FloatVector(sa.TypeDecorator):
    """
    SQLAlchemy `FloatVector` data type for CrateDB.
    """

    cache_ok = False

    __visit_name__ = "FLOAT_VECTOR"

    _is_array = True

    zero_indexes = False

    impl = sa.ARRAY

    def __init__(self, dimensions: int = None):
        super().__init__(sa.FLOAT, dimensions=dimensions)

    def as_generic(self, allow_nulltype=False):
        return sa.ARRAY(item_type=sa.FLOAT)

    @property
    def python_type(self):
        return list

    def bind_processor(self, dialect: sa.engine.Dialect) -> t.Callable:
        def process(value: t.Iterable) -> t.Optional[t.List]:
            return to_db(value, self.dimensions)

        return process

    def result_processor(self, dialect: sa.engine.Dialect, coltype: t.Any) -> t.Callable:
        def process(value: t.Any) -> t.Optional["npt.ArrayLike"]:
            return from_db(value)

        return process


class KnnMatch(ColumnElement):
    """
    Wrap CrateDB's `KNN_MATCH` function into an SQLAlchemy function.

    https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match
    """
    inherit_cache = True

    def __init__(self, column, term, k=None):
        super().__init__()
        self.column = column
        self.term = term
        self.k = k

    def compile_column(self, compiler):
        return compiler.process(self.column)

    def compile_term(self, compiler):
        return compiler.process(literal(self.term))

    def compile_k(self, compiler):
        return compiler.process(literal(self.k))


def knn_match(column, term, k):
    """
    Generate a match predicate for vector search.

    :param column: A reference to a column or an index.

    :param term: The term to match against. This is an array of floating point
    values, which is compared to other vectors using a HNSW index search.

    :param k: The `k` argument determines the number of nearest neighbours to
    search in the index.
    """
    return KnnMatch(column, term, k)


@compiles(KnnMatch)
def compile_knn_match(knn_match, compiler, **kwargs):
    """
    Clause compiler for `KNN_MATCH`.
    """
    return "KNN_MATCH(%s, %s, %s)" % (
        knn_match.compile_column(compiler),
        knn_match.compile_term(compiler),
        knn_match.compile_k(compiler),
    )
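A usage sketch for `FloatVector` and `knn_match`; the table, column names, dimensionality, and query vector are illustrative assumptions.

import sqlalchemy as sa
from sqlalchemy_cratedb.type.vector import FloatVector, knn_match

metadata = sa.MetaData()

# Hypothetical table with a 3-dimensional FLOAT_VECTOR column.
documents = sa.Table(
    "documents",
    metadata,
    sa.Column("title", sa.String),
    sa.Column("embedding", FloatVector(dimensions=3)),
)

# Select the titles of the nearest neighbours of the query vector, using the
# KNN_MATCH predicate compiled above; `k` is the number of neighbours to search.
stmt = sa.select(documents.c.title).where(
    knn_match(documents.c.embedding, [0.1, 0.2, 0.3], k=5)
)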