sqlalchemy-cratedb 0.41.0.dev0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- sqlalchemy_cratedb/__init__.py +65 -0
- sqlalchemy_cratedb/compat/__init__.py +0 -0
- sqlalchemy_cratedb/compat/api13.py +152 -0
- sqlalchemy_cratedb/compat/core10.py +253 -0
- sqlalchemy_cratedb/compat/core14.py +337 -0
- sqlalchemy_cratedb/compat/core20.py +423 -0
- sqlalchemy_cratedb/compiler.py +361 -0
- sqlalchemy_cratedb/dialect.py +414 -0
- sqlalchemy_cratedb/predicate.py +96 -0
- sqlalchemy_cratedb/sa_version.py +28 -0
- sqlalchemy_cratedb/support/__init__.py +18 -0
- sqlalchemy_cratedb/support/pandas.py +110 -0
- sqlalchemy_cratedb/support/polyfill.py +130 -0
- sqlalchemy_cratedb/support/util.py +82 -0
- sqlalchemy_cratedb/type/__init__.py +13 -0
- sqlalchemy_cratedb/type/array.py +143 -0
- sqlalchemy_cratedb/type/geo.py +43 -0
- sqlalchemy_cratedb/type/object.py +94 -0
- sqlalchemy_cratedb/type/vector.py +176 -0
- sqlalchemy_cratedb-0.41.0.dev0.dist-info/LICENSE +178 -0
- sqlalchemy_cratedb-0.41.0.dev0.dist-info/METADATA +143 -0
- sqlalchemy_cratedb-0.41.0.dev0.dist-info/NOTICE +24 -0
- sqlalchemy_cratedb-0.41.0.dev0.dist-info/RECORD +26 -0
- sqlalchemy_cratedb-0.41.0.dev0.dist-info/WHEEL +5 -0
- sqlalchemy_cratedb-0.41.0.dev0.dist-info/entry_points.txt +2 -0
- sqlalchemy_cratedb-0.41.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,130 @@
|
|
1
|
+
import typing as t
|
2
|
+
|
3
|
+
import sqlalchemy as sa
|
4
|
+
from sqlalchemy.event import listen
|
5
|
+
|
6
|
+
from sqlalchemy_cratedb.support.util import refresh_dirty, refresh_table
|
7
|
+
|
8
|
+
|
9
|
+
def patch_autoincrement_timestamp():
    """
    Configure SQLAlchemy model columns with an alternative to `autoincrement=True`.
    Use the current timestamp instead.

    This monkey-patches `sqlalchemy.sql.schema.Column.__init__` process-wide.
    Whenever a column is declared with an `autoincrement` keyword argument,
    that argument is removed before SQLAlchemy sees it. If auto-increment was
    actually requested (any value other than `False`) and no explicit
    `default` was given, the column gets `default=sa.func.now()` instead.

    Calling this function more than once stacks wrappers. This is harmless,
    because the outermost wrapper already removes `autoincrement`.

    This is used by CrateDB's MLflow adapter.

    TODO: Maybe enable through a dialect parameter `crate_polyfill_autoincrement` or such.
    """
    import sqlalchemy.sql.schema as schema

    init_dist = schema.Column.__init__

    def __init__(self, *args, **kwargs):
        if "autoincrement" in kwargs:
            # Always strip the flag, so it never reaches the original constructor.
            autoincrement = kwargs.pop("autoincrement")
            # `autoincrement=False` explicitly asks for no generated value, so it
            # must not silently acquire a timestamp default.
            if autoincrement is not False and "default" not in kwargs:
                kwargs["default"] = sa.func.now()
        init_dist(self, *args, **kwargs)

    schema.Column.__init__ = __init__  # type: ignore[method-assign]
|
30
|
+
|
31
|
+
|
32
|
+
def check_uniqueness_factory(sa_entity, *attribute_names):
    """
    Run a manual column value uniqueness check on a table, and raise an IntegrityError if applicable.

    CrateDB does not support the UNIQUE constraint on columns. This attempts to emulate it.

    https://github.com/crate/sqlalchemy-cratedb/issues/76

    Returns a function with the signature of a SQLAlchemy mapper event handler,
    `(mapper, connection, target)`. For every `target` which is an instance of
    `sa_entity`, it selects rows whose `attribute_names` columns all equal the
    target's values. It raises `sqlalchemy.exc.IntegrityError` when at least one
    such row exists.

    The check is a separate SELECT issued before the write, so it does not guard
    against concurrent writers.

    This is used by CrateDB's MLflow adapter.

    TODO: Maybe enable through a dialect parameter `crate_polyfill_unique` or such.
    """  # noqa: E501

    # Synthesize a canonical "name" for the constraint,
    # composed of all column names involved.
    constraint_name: str = "-".join(attribute_names)

    def check_uniqueness(mapper, connection, target):
        from sqlalchemy.exc import IntegrityError

        if not isinstance(target, sa_entity):
            return

        # TODO: How to use `session.query(SqlExperiment)` here?
        stmt = mapper.selectable.select()
        for attribute_name in attribute_names:
            # `where()` exists on all supported SQLAlchemy versions, while
            # `filter()` is only a 1.4+ synonym for it.
            stmt = stmt.where(
                getattr(sa_entity, attribute_name) == getattr(target, attribute_name)
            )
        stmt = stmt.compile(bind=connection.engine)
        results = connection.execute(stmt)
        # DB-API `rowcount` is not reliable for SELECT statements (PEP 249 allows
        # -1), so probe for an actual row. `first()` also closes the result.
        if results.first() is not None:
            raise IntegrityError(
                statement=stmt,
                params=[],
                orig=Exception(
                    f"DuplicateKeyException in table '{target.__tablename__}' "
                    f"on constraint '{constraint_name}'"
                ),
            )

    return check_uniqueness
|
72
|
+
|
73
|
+
|
74
|
+
def refresh_after_dml_session(session: sa.orm.Session):
    """
    Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE).

    CrateDB is eventually consistent: written data is not immediately visible
    to subsequent reads, so readers may observe stale results. Applications
    with traditional OLTP-like expectations usually cannot tolerate that.

    This helper registers `refresh_dirty` as an `after_flush` listener on the
    given session. Every table touched by a flush is then refreshed right after
    the flush has completed.

    Caveat: the quoted SQLAlchemy notes below suggest that ORM-enabled DML
    statements executed via `Session.execute()` are not observed by flush
    events. Intercepting those would presumably need
    `SessionEvents.do_orm_execute()`.

    > `after_{insert,update,delete}` events only apply to the session flush operation
    > and do not apply to the ORM DML operations described at ORM-Enabled INSERT,
    > UPDATE, and DELETE statements. To intercept ORM DML events, use
    > `SessionEvents.do_orm_execute().`
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.MapperEvents.after_insert

    > Intercept statement executions that occur on behalf of an ORM Session object.
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.do_orm_execute

    > Execute after flush has completed, but before commit has been called.
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.after_flush

    This is used by CrateDB's LangChain adapter.

    TODO: Maybe enable through a dialect parameter `crate_dml_refresh` or such.
    """  # noqa: E501
    event_name = "after_flush"
    listen(session, event_name, refresh_dirty)
|
102
|
+
|
103
|
+
|
104
|
+
def refresh_after_dml_engine(engine: sa.engine.Engine):
    """
    Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE).

    Registers an `after_execute` listener on `engine`. After every executed
    `Insert`, `Update` or `Delete` statement, the listener calls
    `refresh_table` for the statement's target table, on the same connection.
    Statements whose target is a `Join` are skipped.

    This is used by CrateDB's Singer/Meltano and `rdflib-sqlalchemy` adapters.
    """
    dml_statement_types = (sa.sql.Insert, sa.sql.Update, sa.sql.Delete)

    def receive_after_execute(
        conn: sa.engine.Connection, clauseelement, multiparams, params, execution_options, result
    ):
        # Only data-manipulating statements need a refresh.
        if not isinstance(clauseelement, dml_statement_types):
            return
        target_table = clauseelement.table
        # A join cannot be refreshed as a single table.
        if isinstance(target_table, sa.sql.Join):
            return
        refresh_table(conn, target_table)

    sa.event.listen(engine, "after_execute", receive_after_execute)
|
119
|
+
|
120
|
+
|
121
|
+
def refresh_after_dml(engine_or_session: t.Union[sa.engine.Engine, sa.orm.Session]):
    """
    Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE).

    Dispatches on the type of `engine_or_session`:

    - `sa.engine.Engine`: hook into the engine's `after_execute` event.
    - `sa.orm.Session`, `scoped_session`, or `sessionmaker`: hook into the
      `after_flush` session event. For a `sessionmaker`, the listener applies
      to every session it creates.

    :raises TypeError: for any other type.
    """
    # SQLAlchemy's SessionEvents accept Session instances, `scoped_session`
    # and `sessionmaker` objects alike as listener targets.
    session_types = (sa.orm.Session, sa.orm.scoping.scoped_session, sa.orm.sessionmaker)
    if isinstance(engine_or_session, sa.engine.Engine):
        refresh_after_dml_engine(engine_or_session)
    elif isinstance(engine_or_session, session_types):
        refresh_after_dml_session(engine_or_session)
    else:
        raise TypeError(f"Unknown type: {type(engine_or_session)}")
|
@@ -0,0 +1,82 @@
|
|
1
|
+
import itertools
|
2
|
+
import typing as t
|
3
|
+
|
4
|
+
import sqlalchemy as sa
|
5
|
+
|
6
|
+
from sqlalchemy_cratedb.dialect import CrateDialect
|
7
|
+
|
8
|
+
if t.TYPE_CHECKING:
|
9
|
+
try:
|
10
|
+
from sqlalchemy.orm import DeclarativeBase
|
11
|
+
except ImportError:
|
12
|
+
pass
|
13
|
+
|
14
|
+
|
15
|
+
# A module-level CrateDB dialect instance, created once at import time. Only its
# identifier preparer is kept (used e.g. by `quote_relation_name`), so that
# names are quoted by the CrateDB dialect's quoting rules.
identifier_preparer = CrateDialect().identifier_preparer
|
17
|
+
|
18
|
+
|
19
|
+
def refresh_table(
    connection, target: t.Union[str, "DeclarativeBase", "sa.sql.selectable.TableClause"]
):
    """
    Invoke a `REFRESH TABLE` statement.

    :param connection: Anything offering `.execute(sa.text(...))`, e.g. a
        `Connection` or a `Session`.
    :param target: The table to refresh. One of:

        - a `Table`/`TableClause`: name and schema are quoted and escaped.
        - an ORM-mapped class exposing `__table__`: refreshed through that
          table, so a configured schema and mixed-case names are honored.
        - a class exposing only `__tablename__`: the name is used verbatim.
        - a string: used verbatim, so it must already be valid SQL.
    """
    if not isinstance(target, sa.sql.selectable.TableClause):
        # ORM-mapped classes carry their `Table` object. Prefer it over the bare
        # `__tablename__`, which is unquoted and ignores the schema.
        mapped_table = getattr(target, "__table__", None)
        if isinstance(mapped_table, sa.sql.selectable.TableClause):
            target = mapped_table

    if isinstance(target, sa.sql.selectable.TableClause):
        # `quote_identifier` always quotes, and escapes embedded quote characters.
        full_table_name = identifier_preparer.quote_identifier(target.name)
        if target.schema is not None:
            full_table_name = (
                identifier_preparer.quote_identifier(target.schema) + "." + full_table_name
            )
    elif hasattr(target, "__tablename__"):
        full_table_name = target.__tablename__
    else:
        full_table_name = target

    sql = f"REFRESH TABLE {full_table_name}"
    connection.execute(sa.text(sql))
|
37
|
+
|
38
|
+
|
39
|
+
def refresh_dirty(session, flush_context=None):
    """
    Refresh every table that received changes in the current flush.

    Intended as a handler for SQLAlchemy's `after_flush` session event. It
    collects the classes of all new, dirty, and deleted instances tracked by
    `session`, and calls `refresh_table` once per distinct class.

    :param session: The session that has just been flushed.
    :param flush_context: Supplied by the event system; not used.
    """
    touched_classes = set()
    for bucket in (session.new, session.dirty, session.deleted):
        touched_classes.update(instance.__class__ for instance in bucket)
    for mapped_class in touched_classes:
        refresh_table(session, mapped_class)
|
50
|
+
|
51
|
+
|
52
|
+
def quote_relation_name(ident: str) -> str:
    """
    Quote a simple or full-qualified table/relation name, when needed.

    Simple: <table>
    Full-qualified: <schema>.<table>
    At most three dot-separated parts are accepted.

    Happy path examples:

    foo => foo
    Foo => "Foo"
    "Foo" => "Foo"
    foo.bar => foo.bar
    foo-bar.baz_qux => "foo-bar".baz_qux

    Such input strings will not be modified:

    "foo.bar" => "foo.bar"

    :raises ValueError: when the name has more than three parts, or when any
        part is empty (e.g. ``""``, ``"foo."``, ``"foo..bar"``).
    """

    # If a quote exists at the beginning or the end of the input string,
    # let's consider that the relation name has been quoted already.
    if ident.startswith('"') or ident.endswith('"'):
        return ident

    # If a dot is included, it's a full-qualified identifier like <schema>.<table>.
    # It needs to be split, in order to apply identifier quoting properly.
    parts = ident.split(".")
    if len(parts) > 3:
        raise ValueError(f"Invalid relation name, too many parts: {ident}")
    # An empty part cannot form a valid identifier. Reject it here instead of
    # passing it on to the identifier preparer.
    if not all(parts):
        raise ValueError(f"Invalid relation name, empty part: {ident}")
    return ".".join(map(identifier_preparer.quote, parts))
|
@@ -0,0 +1,143 @@
|
|
1
|
+
# -*- coding: utf-8; -*-
|
2
|
+
#
|
3
|
+
# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
|
4
|
+
# license agreements. See the NOTICE file distributed with this work for
|
5
|
+
# additional information regarding copyright ownership. Crate licenses
|
6
|
+
# this file to you under the Apache License, Version 2.0 (the "License");
|
7
|
+
# you may not use this file except in compliance with the License. You may
|
8
|
+
# obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
14
|
+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
15
|
+
# License for the specific language governing permissions and limitations
|
16
|
+
# under the License.
|
17
|
+
#
|
18
|
+
# However, if you have executed another commercial license agreement
|
19
|
+
# with Crate these terms will supersede the license and you may use the
|
20
|
+
# software solely pursuant to the terms of the relevant commercial agreement.
|
21
|
+
|
22
|
+
# ruff: noqa: A005 # Module `array` shadows a Python standard-library module
|
23
|
+
|
24
|
+
import sqlalchemy.types as sqltypes
|
25
|
+
from sqlalchemy.ext.mutable import Mutable
|
26
|
+
from sqlalchemy.sql import default_comparator, expression, operators
|
27
|
+
|
28
|
+
|
29
|
+
class MutableList(Mutable, list):
    """
    A `list` subclass which reports in-place modifications to SQLAlchemy.

    Every mutating method delegates to the plain `list` implementation and then
    calls `Mutable.changed()`, so modifications of an attribute value in place
    are noticed by the ORM.
    """

    @classmethod
    def coerce(cls, key, value):
        """Convert plain list to MutableList"""
        if not isinstance(value, MutableList):
            if isinstance(value, list):
                return MutableList(value)
            elif value is None:
                return value
            else:
                # Scalars are wrapped into a single-element list.
                return MutableList([value])
        else:
            return value

    def __init__(self, initval=None):
        list.__init__(self, initval or [])

    def __setitem__(self, key, value):
        list.__setitem__(self, key, value)
        self.changed()

    def __delitem__(self, key):
        list.__delitem__(self, key)
        self.changed()

    def __iadd__(self, other):
        # `list.__iadd__` would mutate without notification; route via `extend`.
        self.extend(other)
        return self

    def __eq__(self, other):
        return list.__eq__(self, other)

    def append(self, item):
        list.append(self, item)
        self.changed()

    def insert(self, idx, item):
        list.insert(self, idx, item)
        self.changed()

    def extend(self, iterable):
        list.extend(self, iterable)
        self.changed()

    def pop(self, index=-1):
        # Return the removed item, like `list.pop` does.
        item = list.pop(self, index)
        self.changed()
        return item

    def remove(self, item):
        list.remove(self, item)
        self.changed()

    def clear(self):
        list.clear(self)
        self.changed()

    def sort(self, **kwargs):
        list.sort(self, **kwargs)
        self.changed()

    def reverse(self):
        list.reverse(self)
        self.changed()
|
72
|
+
|
73
|
+
|
74
|
+
class Any(expression.ColumnElement):
    """Represent the clause ``left operator ANY (right)``.

    ``right`` must be an array expression. ``left`` is wrapped through
    ``expression.literal()``, and the clause has a Boolean result type.

    Copied from the PostgreSQL dialect.

    .. seealso::

        :class:`sqlalchemy.dialects.postgresql.ARRAY`

        :meth:`sqlalchemy.dialects.postgresql.ARRAY.Comparator.any`
            ARRAY-bound method
    """

    inherit_cache = True
    __visit_name__ = "any"

    def __init__(self, left, right, operator=operators.eq):
        self.operator = operator
        self.right = right
        self.left = expression.literal(left)
        self.type = sqltypes.Boolean()
|
97
|
+
|
98
|
+
|
99
|
+
class _ObjectArray(sqltypes.UserDefinedType):
    """
    Column type for CrateDB's ``ARRAY(OBJECT)``.

    Supports item access (``column[index]``) and ``ANY`` comparisons through
    its comparator.
    """

    cache_ok = True
    type = MutableList

    class Comparator(sqltypes.TypeEngine.Comparator):
        def __getitem__(self, key):
            return default_comparator._binary_operate(self.expr, operators.getitem, key)

        def any(self, other, operator=operators.eq):
            """Return an ``other operator ANY (array)`` clause.

            The arguments are swapped, because ``ANY`` requires the array
            expression on the right-hand side.

            E.g.::

                from sqlalchemy.sql import operators

                stmt = select(table.c.data).where(
                    table.c.data.any(7, operator=operators.lt)
                )

            :param other: expression to be compared
            :param operator: an operator object from the
             :mod:`sqlalchemy.sql.operators` package, defaults to
             :func:`.operators.eq`.

            .. seealso::

                :class:`.postgresql.Any`
            """
            array_expression = self.expr
            return Any(other, array_expression, operator=operator)

    comparator_factory = Comparator

    def get_col_spec(self, **kws):
        return "ARRAY(OBJECT)"
|
141
|
+
|
142
|
+
|
143
|
+
# Public ``ARRAY(OBJECT)`` column type: `_ObjectArray` associated with
# `MutableList`, so that in-place list modifications are tracked.
ObjectArray = MutableList.as_mutable(_ObjectArray)
|
@@ -0,0 +1,43 @@
|
|
1
|
+
import geojson
|
2
|
+
from sqlalchemy import types as sqltypes
|
3
|
+
from sqlalchemy.sql import default_comparator, operators
|
4
|
+
|
5
|
+
|
6
|
+
class Geopoint(sqltypes.UserDefinedType):
    """
    Column type for CrateDB's ``GEO_POINT``.

    Bind values which are `geojson.Point` instances are sent as their
    `coordinates`. All other values are passed through unchanged. Result values
    are converted to tuples, and SQL ``NULL`` is returned as ``None``.
    """

    cache_ok = True

    class Comparator(sqltypes.TypeEngine.Comparator):
        def __getitem__(self, key):
            return default_comparator._binary_operate(self.expr, operators.getitem, key)

    def get_col_spec(self):
        return "GEO_POINT"

    def bind_processor(self, dialect):
        def process(value):
            if isinstance(value, geojson.Point):
                return value.coordinates
            return value

        return process

    def result_processor(self, dialect, coltype):
        def process(value):
            # Result processors are also invoked for NULL values, where
            # `tuple(None)` would raise a TypeError.
            if value is None:
                return None
            return tuple(value)

        return process

    comparator_factory = Comparator
|
28
|
+
|
29
|
+
|
30
|
+
class Geoshape(sqltypes.UserDefinedType):
    """
    Column type for CrateDB's ``GEO_SHAPE``.

    Result values are converted with `geojson.GeoJSON.to_instance`. Bind
    values use no custom processing.
    """

    cache_ok = True

    class Comparator(sqltypes.TypeEngine.Comparator):
        def __getitem__(self, key):
            return default_comparator._binary_operate(self.expr, operators.getitem, key)

    comparator_factory = Comparator

    def get_col_spec(self):
        return "GEO_SHAPE"

    def result_processor(self, dialect, coltype):
        converter = geojson.GeoJSON.to_instance
        return converter
|
@@ -0,0 +1,94 @@
|
|
1
|
+
import warnings
|
2
|
+
|
3
|
+
from sqlalchemy import types as sqltypes
|
4
|
+
from sqlalchemy.ext.mutable import Mutable
|
5
|
+
|
6
|
+
|
7
|
+
class MutableDict(Mutable, dict):
    """
    A `dict` subclass which reports in-place modifications to SQLAlchemy.

    Besides calling `Mutable.changed()`, the root dictionary records which
    top-level keys were changed (`_changed_keys`) or deleted (`_deleted_keys`).
    Nested plain dictionaries are wrapped into `MutableDict` instances. These
    forward their changes to the root (`to_update`), under the top-level key
    they live in (`_overwrite_key`).
    """

    @classmethod
    def coerce(cls, key, value):
        "Convert plain dictionaries to MutableDict."

        if not isinstance(value, MutableDict):
            if isinstance(value, dict):
                return MutableDict(value)

            # Returns None for None, raises ValueError for anything else.
            return Mutable.coerce(key, value)
        else:
            return value

    def __init__(self, initval=None, to_update=None, root_change_key=None):
        initval = initval or {}
        self._changed_keys = set()
        self._deleted_keys = set()
        # Top-level key under which this nested dictionary reports changes;
        # `None` for the root dictionary itself.
        self._overwrite_key = root_change_key
        # The root dictionary collecting the change records.
        self.to_update = self if to_update is None else to_update
        # Build a converted copy instead of rewriting `initval` in place. This
        # leaves the caller's dictionary untouched, so its nested dictionaries
        # are not bound to this instance's change tracking.
        converted = {
            k: self._convert_dict(
                v, overwrite_key=k if self._overwrite_key is None else self._overwrite_key
            )
            for k, v in initval.items()
        }
        dict.__init__(self, converted)

    def __setitem__(self, key, value):
        value = self._convert_dict(
            value, key if self._overwrite_key is None else self._overwrite_key
        )
        dict.__setitem__(self, key, value)
        self.to_update.on_key_changed(key if self._overwrite_key is None else self._overwrite_key)

    def __delitem__(self, key):
        dict.__delitem__(self, key)
        # add the key to the deleted keys if this is the root object
        # otherwise update on root object
        if self._overwrite_key is None:
            self._deleted_keys.add(key)
            self.changed()
        else:
            self.to_update.on_key_changed(self._overwrite_key)

    def on_key_changed(self, key):
        self._deleted_keys.discard(key)
        self._changed_keys.add(key)
        self.changed()

    def _convert_dict(self, value, overwrite_key):
        if isinstance(value, dict) and not isinstance(value, MutableDict):
            return MutableDict(value, self.to_update, overwrite_key)
        return value

    def __eq__(self, other):
        return dict.__eq__(self, other)

    # The plain `dict` implementations of the following methods bypass
    # `__setitem__`/`__delitem__`, so their modifications would not be tracked.
    # Route them through the tracked primitives instead.

    def update(self, *args, **kwargs):
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def __ior__(self, other):
        self.update(other)
        return self

    def setdefault(self, key, default=None):
        if key not in self:
            self[key] = default
        return dict.__getitem__(self, key)

    def pop(self, key, *default):
        if len(default) > 1:
            raise TypeError(f"pop expected at most 2 arguments, got {1 + len(default)}")
        if key in self:
            value = dict.__getitem__(self, key)
            del self[key]
            return value
        if default:
            return default[0]
        raise KeyError(key)

    def popitem(self):
        if not self:
            raise KeyError("popitem(): dictionary is empty")
        # Mirror `dict.popitem`, which removes the most recently inserted item.
        key = list(self)[-1]
        value = dict.__getitem__(self, key)
        del self[key]
        return key, value

    def clear(self):
        for key in list(self):
            del self[key]
|
62
|
+
|
63
|
+
|
64
|
+
class ObjectTypeImpl(sqltypes.UserDefinedType, sqltypes.JSON):
    """
    Column type implementation for CrateDB's ``OBJECT`` data type.

    Compiled through the ``OBJECT`` visit name. Statement caching is disabled
    (``cache_ok = False``), and ``none_as_null = False`` is declared at class
    level.
    """

    none_as_null = False
    cache_ok = False

    __visit_name__ = "OBJECT"
|
69
|
+
|
70
|
+
|
71
|
+
# Public name of the CrateDB OBJECT column type, with in-place mutation
# tracking through `MutableDict`. The shorter `Object` would be too ambiguous.
ObjectType = MutableDict.as_mutable(ObjectTypeImpl)

# Legacy names, kept for backward compatibility. Accessing them goes through
# the module-level `__getattr__`, which looks up the `_deprecated_*` aliases.
# See https://www.lesinskis.com/deprecating-module-scope-variables.html
deprecated_names = ["Craty", "Object"]
_deprecated_Craty = _deprecated_Object = ObjectType
|
80
|
+
|
81
|
+
|
82
|
+
def __getattr__(name):
    """
    Resolve deprecated module attributes lazily (PEP 562).

    Names listed in `deprecated_names` emit a `DeprecationWarning` and resolve
    to the matching `_deprecated_<name>` global. Any other name raises
    `AttributeError`.
    """
    if name not in deprecated_names:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    warnings.warn(
        f"{name} is deprecated and will be removed in future releases. "
        "Please use ObjectType instead.",
        category=DeprecationWarning,
        stacklevel=2,
    )
    return globals()["_deprecated_" + name]
|
92
|
+
|
93
|
+
|
94
|
+
# Public API of this module. The deprecated aliases are listed too, so that
# star-imports keep providing them; they are resolved via `__getattr__`.
__all__ = deprecated_names + ["ObjectType"]
|
@@ -0,0 +1,176 @@
|
|
1
|
+
"""
|
2
|
+
## About
|
3
|
+
SQLAlchemy data type implementation for CrateDB's `FLOAT_VECTOR` type.
|
4
|
+
|
5
|
+
## References
|
6
|
+
- https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector
|
7
|
+
- https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match
|
8
|
+
|
9
|
+
## Details
|
10
|
+
The implementation is based on SQLAlchemy's `TypeDecorator`, and also
|
11
|
+
offers compiler support.
|
12
|
+
|
13
|
+
## Notes
|
14
|
+
CrateDB currently only supports the similarity function `VectorSimilarityFunction.EUCLIDEAN`.
|
15
|
+
-- https://github.com/crate/crate/blob/5.5.1/server/src/main/java/io/crate/types/FloatVectorType.java#L55
|
16
|
+
|
17
|
+
pgvector uses a comparator to apply different similarity functions as operators,
|
18
|
+
see `pgvector.sqlalchemy.Vector.comparator_factory`.
|
19
|
+
|
20
|
+
<->: l2/euclidean_distance
|
21
|
+
<#>: max_inner_product
|
22
|
+
<=>: cosine_distance
|
23
|
+
|
24
|
+
## Backlog
|
25
|
+
- After dropping support for SQLAlchemy 1.3, use
|
26
|
+
`class FloatVector(sa.TypeDecorator[t.Sequence[float]]):`
|
27
|
+
|
28
|
+
## Origin
|
29
|
+
This module is based on the corresponding pgvector implementation
|
30
|
+
by Andrew Kane. Thank you.
|
31
|
+
|
32
|
+
The MIT License (MIT)
|
33
|
+
Copyright (c) 2021-2023 Andrew Kane
|
34
|
+
https://github.com/pgvector/pgvector-python
|
35
|
+
"""
|
36
|
+
|
37
|
+
import typing as t
|
38
|
+
|
39
|
+
if t.TYPE_CHECKING:
|
40
|
+
import numpy.typing as npt # pragma: no cover
|
41
|
+
|
42
|
+
import sqlalchemy as sa
|
43
|
+
from sqlalchemy.ext.compiler import compiles
|
44
|
+
from sqlalchemy.sql.expression import ColumnElement, literal
|
45
|
+
|
46
|
+
# Public API of this module.
__all__ = [
    "from_db",
    "knn_match",
    "to_db",
    "FloatVector",
]
|
52
|
+
|
53
|
+
|
54
|
+
def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]:
    """
    Convert a vector value received from the database into a NumPy array.

    ``None`` and existing ``numpy.ndarray`` instances (for example, when the
    driver already produced one) are returned unchanged. Anything else is
    converted with ``numpy.array(..., dtype=numpy.float32)``.

    Adapted from `pgvector.utils`.
    """
    import numpy

    passthrough = value is None or isinstance(value, numpy.ndarray)
    return value if passthrough else numpy.array(value, dtype=numpy.float32)
|
63
|
+
|
64
|
+
|
65
|
+
def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
    """
    Prepare a vector value for sending it to the database.

    - ``None`` is returned unchanged.
    - A ``numpy.ndarray`` must be one-dimensional and of integer or floating
      dtype. It is converted to a Python list via ``tolist()``.
    - Any other value is passed through as-is.

    When `dim` is given, the value's length must equal `dim`.

    :raises ValueError: on a multi-dimensional or non-numeric array, or on a
        length mismatch against `dim`.

    Adapted from `pgvector.utils`.
    """
    import numpy

    if value is None:
        return value

    if isinstance(value, numpy.ndarray):
        if value.ndim != 1:
            raise ValueError("expected ndim to be 1")
        is_numeric = numpy.issubdtype(value.dtype, numpy.integer) or numpy.issubdtype(
            value.dtype, numpy.floating
        )
        if not is_numeric:
            raise ValueError("dtype must be numeric")
        value = value.tolist()

    if dim is not None:
        length = len(value)
        if length != dim:
            raise ValueError("expected %d dimensions, not %d" % (dim, length))

    return value
|
87
|
+
|
88
|
+
|
89
|
+
class FloatVector(sa.TypeDecorator):
    """
    SQLAlchemy `FloatVector` data type for CrateDB.

    Decorates ``sa.ARRAY(sa.FLOAT)`` and is compiled via the ``FLOAT_VECTOR``
    visit name. The `dimensions` argument is forwarded to the ``ARRAY``
    implementation. It is read back as ``self.dimensions``, presumably proxied
    from that implementation, to validate bound values. Bound values go through
    `to_db`, result values through `from_db`.
    """

    impl = sa.ARRAY
    __visit_name__ = "FLOAT_VECTOR"

    cache_ok = False
    _is_array = True
    zero_indexes = False

    def __init__(self, dimensions: int = None):
        super().__init__(sa.FLOAT, dimensions=dimensions)

    def as_generic(self, allow_nulltype=False):
        return sa.ARRAY(item_type=sa.FLOAT)

    @property
    def python_type(self):
        return list

    def bind_processor(self, dialect: sa.engine.Dialect) -> t.Callable:
        vector_type = self

        def process(value: t.Iterable) -> t.Optional[t.List]:
            return to_db(value, vector_type.dimensions)

        return process

    def result_processor(self, dialect: sa.engine.Dialect, coltype: t.Any) -> t.Callable:
        # `from_db` already has the per-value processor signature.
        return from_db
|
125
|
+
|
126
|
+
|
127
|
+
class KnnMatch(ColumnElement):
    """
    Wrap CrateDB's `KNN_MATCH` function into an SQLAlchemy function.

    https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match

    Rendered by `compile_knn_match` as ``KNN_MATCH(<column>, <term>, <k>)``.
    `term` and `k` are wrapped with `literal()`.
    """

    inherit_cache = True

    def __init__(self, column, term, k=None):
        super().__init__()
        self.column = column
        self.term = term
        self.k = k

    # The optional `**kw` carry the compiler state (e.g. `literal_binds`) into
    # the nested elements. Without them, nested bind parameters would ignore it.

    def compile_column(self, compiler, **kw):
        return compiler.process(self.column, **kw)

    def compile_term(self, compiler, **kw):
        return compiler.process(literal(self.term), **kw)

    def compile_k(self, compiler, **kw):
        return compiler.process(literal(self.k), **kw)


def knn_match(column, term, k):
    """
    Generate a match predicate for vector search.

    :param column: A reference to a column or an index.

    :param term: The term to match against. This is an array of floating point
        values, which is compared to other vectors using a HNSW index search.

    :param k: The `k` argument determines the number of nearest neighbours to
        search in the index.
    """
    return KnnMatch(column, term, k)


@compiles(KnnMatch)
def compile_knn_match(knn_match, compiler, **kwargs):
    """
    Clause compiler for `KNN_MATCH`.

    Compiler keyword arguments are forwarded to the nested elements, as is
    customary for custom constructs.
    """
    return "KNN_MATCH(%s, %s, %s)" % (
        knn_match.compile_column(compiler, **kwargs),
        knn_match.compile_term(compiler, **kwargs),
        knn_match.compile_k(compiler, **kwargs),
    )
|