sqlalchemy-cratedb 0.38.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,125 @@
1
+ import sqlalchemy as sa
2
+ from sqlalchemy.event import listen
3
+ import typing as t
4
+
5
+ from sqlalchemy_cratedb.support.util import refresh_dirty, refresh_table
6
+
7
+
8
def patch_autoincrement_timestamp():
    """
    Configure SQLAlchemy model columns with an alternative to `autoincrement=True`.
    Use the current timestamp instead.

    This monkey-patches `sqlalchemy.sql.schema.Column.__init__` for the whole
    process. Whenever a column is declared with an `autoincrement` keyword
    argument, that argument is removed. Unless auto-generation was explicitly
    disabled (`autoincrement=False`) or the column already declares a
    `default`, the column receives `default=sa.func.now()`.

    This is used by CrateDB's MLflow adapter.

    TODO: Maybe enable through a dialect parameter `crate_polyfill_autoincrement` or such.
    """
    import sqlalchemy.sql.schema as schema

    init_dist = schema.Column.__init__

    def __init__(self, *args, **kwargs):
        if "autoincrement" in kwargs:
            # The keyword is always stripped before reaching SQLAlchemy.
            autoincrement = kwargs.pop("autoincrement")
            # `autoincrement=False` means "no generated value", so do not
            # substitute a timestamp default in that case.
            if autoincrement is not False and "default" not in kwargs:
                kwargs["default"] = sa.func.now()
        init_dist(self, *args, **kwargs)

    schema.Column.__init__ = __init__  # type: ignore[method-assign]
29
+
30
+
31
def check_uniqueness_factory(sa_entity, *attribute_names):
    """
    Run a manual column value uniqueness check on a table, and raise an IntegrityError if applicable.

    CrateDB does not support the UNIQUE constraint on columns. This attempts to emulate it.

    https://github.com/crate/sqlalchemy-cratedb/issues/76

    This is used by CrateDB's MLflow adapter.

    :param sa_entity: The mapped class whose instances are checked. Targets of
        other types are ignored by the returned handler.
    :param attribute_names: Names of the mapped attributes whose combined
        values must be unique.
    :return: A handler `check_uniqueness(mapper, connection, target)`, which
        raises `sqlalchemy.exc.IntegrityError` when a row with the same
        attribute values already exists.

    TODO: Maybe enable through a dialect parameter `crate_polyfill_unique` or such.
    """

    # Synthesize a canonical "name" for the constraint,
    # composed of all column names involved.
    constraint_name: str = "-".join(attribute_names)

    def check_uniqueness(mapper, connection, target):
        if not isinstance(target, sa_entity):
            return

        from sqlalchemy.exc import IntegrityError

        # TODO: How to use `session.query(SqlExperiment)` here?
        stmt = mapper.selectable.select()
        for attribute_name in attribute_names:
            stmt = stmt.filter(getattr(sa_entity, attribute_name) == getattr(target, attribute_name))
        # Only the existence of a matching row matters.
        stmt = stmt.limit(1)

        # `rowcount` is not reliable for SELECT statements, so probe for a row.
        if connection.execute(stmt).first() is not None:
            raise IntegrityError(
                statement=str(stmt.compile(bind=connection.engine)),
                params=[],
                orig=Exception(
                    f"DuplicateKeyException in table '{target.__tablename__}' " f"on constraint '{constraint_name}'"
                ),
            )

    return check_uniqueness
68
+
69
+
70
def refresh_after_dml_session(session: sa.orm.Session):
    """
    Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE).

    CrateDB is eventually consistent: data written by one statement is not
    necessarily visible to subsequent reads right away. Traditional OLTP-style
    applications usually expect to read their own writes immediately, which
    this behavior does not guarantee.

    This SQLAlchemy extension registers `refresh_dirty` as an `after_flush`
    listener on the given session. After every flush, the tables touched by
    that flush are refreshed.

    Background, from the SQLAlchemy documentation:

    > `after_{insert,update,delete}` events only apply to the session flush operation
    > and do not apply to the ORM DML operations described at ORM-Enabled INSERT,
    > UPDATE, and DELETE statements. To intercept ORM DML events, use
    > `SessionEvents.do_orm_execute()`.
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.MapperEvents.after_insert

    > Intercept statement executions that occur on behalf of an ORM Session object.
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.do_orm_execute

    > Execute after flush has completed, but before commit has been called.
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.after_flush

    This is used by CrateDB's LangChain adapter.

    TODO: Maybe enable through a dialect parameter `crate_dml_refresh` or such.
    """  # noqa: E501
    listen(target=session, identifier="after_flush", fn=refresh_dirty)
98
+
99
+
100
def refresh_after_dml_engine(engine: sa.engine.Engine):
    """
    Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE).

    Registers an `after_execute` listener on the engine. Whenever an
    INSERT, UPDATE or DELETE construct targeting a single table has been
    executed, that table is refreshed on the same connection.

    This is used by CrateDB's Singer/Meltano and `rdflib-sqlalchemy` adapters.
    """

    def receive_after_execute(
        conn: sa.engine.Connection, clauseelement, multiparams, params, execution_options, result
    ):
        # Only data-manipulating constructs are of interest.
        if not isinstance(clauseelement, (sa.sql.Insert, sa.sql.Update, sa.sql.Delete)):
            return
        modified = clauseelement.table
        # A join target does not name one single table to refresh.
        if isinstance(modified, sa.sql.Join):
            return
        refresh_table(conn, modified)

    sa.event.listen(engine, "after_execute", receive_after_execute)
114
+
115
+
116
def refresh_after_dml(engine_or_session: t.Union[sa.engine.Engine, sa.orm.Session]):
    """
    Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE).

    Dispatches to `refresh_after_dml_engine` for an `Engine`, and to
    `refresh_after_dml_session` for a `Session` or `scoped_session`.

    :raises TypeError: For any other kind of object.
    """
    target = engine_or_session
    if isinstance(target, sa.engine.Engine):
        refresh_after_dml_engine(target)
        return
    if isinstance(target, (sa.orm.Session, sa.orm.scoping.scoped_session)):
        refresh_after_dml_session(target)
        return
    raise TypeError(f"Unknown type: {type(target)}")
@@ -0,0 +1,41 @@
1
+ import itertools
2
+ import typing as t
3
+
4
+ import sqlalchemy as sa
5
+
6
+ if t.TYPE_CHECKING:
7
+ try:
8
+ from sqlalchemy.orm import DeclarativeBase
9
+ except ImportError:
10
+ pass
11
+
12
+
13
def refresh_table(connection, target: t.Union[str, "DeclarativeBase", "sa.sql.selectable.TableClause"]):
    """
    Invoke a `REFRESH TABLE` statement.

    :param connection: An object providing `execute()` for an `sa.text()`
        construct, for example a `Connection` or a `Session`.
    :param target: The table to refresh. It may be one of:
        - a `TableClause` (or `Table`), rendered as a quoted, optionally
          schema-qualified identifier;
        - an ORM-mapped class, resolved through its `__table__` so that
          schema and quoting match the table SQLAlchemy created, and falling
          back to the bare `__tablename__` when there is no `__table__`;
        - a string, which is used verbatim.
    """
    # Mapped classes carry their `Table` object. Using it keeps the schema
    # (e.g. from `__table_args__`) and applies proper identifier quoting.
    mapped_table = getattr(target, "__table__", None)
    if isinstance(mapped_table, sa.sql.selectable.TableClause):
        target = mapped_table

    if isinstance(target, sa.sql.selectable.TableClause):
        full_table_name = _quote_identifier(target.name)
        if target.schema is not None:
            full_table_name = f"{_quote_identifier(target.schema)}.{full_table_name}"
    elif hasattr(target, "__tablename__"):
        full_table_name = target.__tablename__
    else:
        full_table_name = target

    sql = f"REFRESH TABLE {full_table_name}"
    connection.execute(sa.text(sql))


def _quote_identifier(name: str) -> str:
    """Quote an SQL identifier, doubling any embedded double quotes."""
    escaped = name.replace('"', '""')
    return f'"{escaped}"'
29
+
30
+
31
def refresh_dirty(session, flush_context=None):
    """
    Event handler for SQLAlchemy's `after_flush` session event.

    Collects the classes of all instances in the session's `new`, `dirty` and
    `deleted` collections. It then invokes `refresh_table()` once per class,
    using the session to execute the `REFRESH TABLE` statements.

    :param session: The session which has been flushed.
    :param flush_context: Not used; accepted to match the event signature.
    """
    touched_classes = set()
    for collection in (session.new, session.dirty, session.deleted):
        for instance in collection:
            touched_classes.add(instance.__class__)
    for mapped_class in touched_classes:
        refresh_table(session, mapped_class)
@@ -0,0 +1,4 @@
1
+ from .array import ObjectArray
2
+ from .geo import Geopoint, Geoshape
3
+ from .object import ObjectType
4
+ from .vector import FloatVector, knn_match
@@ -0,0 +1,144 @@
1
+ # -*- coding: utf-8; -*-
2
+ #
3
+ # Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
4
+ # license agreements. See the NOTICE file distributed with this work for
5
+ # additional information regarding copyright ownership. Crate licenses
6
+ # this file to you under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License. You may
8
+ # obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14
+ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15
+ # License for the specific language governing permissions and limitations
16
+ # under the License.
17
+ #
18
+ # However, if you have executed another commercial license agreement
19
+ # with Crate these terms will supersede the license and you may use the
20
+ # software solely pursuant to the terms of the relevant commercial agreement.
21
+
22
+ import sqlalchemy.types as sqltypes
23
+ from sqlalchemy.sql import operators, expression
24
+ from sqlalchemy.sql import default_comparator
25
+ from sqlalchemy.ext.mutable import Mutable
26
+
27
+
28
class MutableList(Mutable, list):
    """
    A `list` which reports in-place modifications to SQLAlchemy.

    Each mutating operation calls `Mutable.changed()`, so the owning ORM
    attribute is flagged as modified. Non-mutating list behavior is unchanged.
    """

    @classmethod
    def coerce(cls, key, value):
        """ Convert plain list to MutableList """
        if not isinstance(value, MutableList):
            if isinstance(value, list):
                return MutableList(value)
            elif value is None:
                return value
            else:
                # Scalars are wrapped into a one-element list.
                return MutableList([value])
        else:
            return value

    def __init__(self, initval=None):
        list.__init__(self, initval or [])

    def __setitem__(self, key, value):
        list.__setitem__(self, key, value)
        self.changed()

    def __delitem__(self, key):
        list.__delitem__(self, key)
        self.changed()

    def __iadd__(self, other):
        # `lst += other` mutates in place; the in-place protocol requires returning self.
        list.__iadd__(self, other)
        self.changed()
        return self

    def __eq__(self, other):
        return list.__eq__(self, other)

    def append(self, item):
        list.append(self, item)
        self.changed()

    def insert(self, idx, item):
        list.insert(self, idx, item)
        self.changed()

    def extend(self, iterable):
        list.extend(self, iterable)
        self.changed()

    def pop(self, index=-1):
        # Return the removed item, as `list.pop` does.
        item = list.pop(self, index)
        self.changed()
        return item

    def remove(self, item):
        list.remove(self, item)
        self.changed()

    def clear(self):
        list.clear(self)
        self.changed()

    def sort(self, *args, **kwargs):
        list.sort(self, *args, **kwargs)
        self.changed()

    def reverse(self):
        list.reverse(self)
        self.changed()
72
+
73
+
74
class Any(expression.ColumnElement):
    """
    Clause element rendering ``left operator ANY (right)``.

    The ``right`` side must be an array expression. The ``left`` side is
    wrapped into a literal bound value, and the clause is typed as a boolean.

    Adapted from the PostgreSQL dialect.

    .. seealso::

        :class:`sqlalchemy.dialects.postgresql.ARRAY`

        :meth:`sqlalchemy.dialects.postgresql.ARRAY.Comparator.any`
        ARRAY-bound method

    """
    __visit_name__ = 'any'
    inherit_cache = True

    def __init__(self, left, right, operator=operators.eq):
        self.operator = operator
        self.right = right
        self.left = expression.literal(left)
        self.type = sqltypes.Boolean()
96
+
97
+
98
class _ObjectArray(sqltypes.UserDefinedType):
    """
    Column type rendered as CrateDB ``ARRAY(OBJECT)``.

    Supports item access (``column[key]``) and ``ANY`` comparisons through
    its comparator.
    """
    cache_ok = True

    class Comparator(sqltypes.TypeEngine.Comparator):
        def __getitem__(self, key):
            item_expression = default_comparator._binary_operate(
                self.expr, operators.getitem, key
            )
            return item_expression

        def any(self, other, operator=operators.eq):
            """Build an ``other operator ANY (array)`` clause.

            The operands are swapped, because ``ANY`` expects the array
            expression on its right-hand side.

            E.g.::

                from sqlalchemy.sql import operators

                conn.execute(
                    select(table.c.data).where(
                        table.c.data.any(7, operator=operators.lt)
                    )
                )

            :param other: expression to be compared
            :param operator: an operator object from the
             :mod:`sqlalchemy.sql.operators`
             package, defaults to :func:`.operators.eq`.

            .. seealso::

                :class:`.postgresql.Any`

                :meth:`.postgresql.ARRAY.Comparator.all`

            """
            clause = Any(other, self.expr, operator=operator)
            return clause

    # NOTE(review): the purpose of this attribute is not evident from this
    # file; it shadows the builtin `type` only within the class namespace.
    type = MutableList
    comparator_factory = Comparator

    def get_col_spec(self, **kws):
        spec = "ARRAY(OBJECT)"
        return spec
142
+
143
+
144
# Public `ARRAY(OBJECT)` column type: `_ObjectArray` associated with
# `MutableList`. In-place modifications made through MutableList's overridden
# list methods therefore flag the owning ORM attribute as changed.
ObjectArray = MutableList.as_mutable(_ObjectArray)
@@ -0,0 +1,48 @@
1
+ import geojson
2
+ from sqlalchemy import types as sqltypes
3
+ from sqlalchemy.sql import default_comparator, operators
4
+
5
+
6
class Geopoint(sqltypes.UserDefinedType):
    """
    CrateDB ``GEO_POINT`` column type.

    Bound `geojson.Point` values are sent as their coordinates. Other values
    are passed through unchanged. Result values are returned as tuples, or
    `None` for SQL NULL.
    """
    cache_ok = True

    class Comparator(sqltypes.TypeEngine.Comparator):

        def __getitem__(self, key):
            return default_comparator._binary_operate(self.expr,
                                                      operators.getitem,
                                                      key)

    def get_col_spec(self, **kw):
        return 'GEO_POINT'

    def bind_processor(self, dialect):
        def process(value):
            if isinstance(value, geojson.Point):
                return value.coordinates
            return value
        return process

    def result_processor(self, dialect, coltype):
        def process(value):
            # Result processors also receive NULLs; `tuple(None)` would raise TypeError.
            if value is None:
                return None
            return tuple(value)
        return process

    comparator_factory = Comparator
30
+
31
+
32
class Geoshape(sqltypes.UserDefinedType):
    """
    CrateDB ``GEO_SHAPE`` column type.

    Result values are converted into `geojson` objects via
    `geojson.GeoJSON.to_instance`, or returned as `None` for SQL NULL.
    """
    cache_ok = True

    class Comparator(sqltypes.TypeEngine.Comparator):

        def __getitem__(self, key):
            return default_comparator._binary_operate(self.expr,
                                                      operators.getitem,
                                                      key)

    def get_col_spec(self, **kw):
        return 'GEO_SHAPE'

    def result_processor(self, dialect, coltype):
        def process(value):
            # Result processors also receive NULLs; keep them as `None`
            # instead of handing them to the GeoJSON converter.
            if value is None:
                return None
            return geojson.GeoJSON.to_instance(value)
        return process

    comparator_factory = Comparator
@@ -0,0 +1,92 @@
1
+ import warnings
2
+
3
+ from sqlalchemy import types as sqltypes
4
+ from sqlalchemy.ext.mutable import Mutable
5
+
6
+
7
class MutableDict(Mutable, dict):
    """
    A `dict` with SQLAlchemy change tracking, used for CrateDB `OBJECT` values.

    Nested plain dictionaries are converted into nested `MutableDict`
    instances. Those report modifications to the root object (`to_update`)
    under the top-level key they live in (`root_change_key`). The root
    records changed top-level keys in `_changed_keys` and deleted ones in
    `_deleted_keys`, and flags itself as changed.
    """

    @classmethod
    def coerce(cls, key, value):
        "Convert plain dictionaries to MutableDict."

        if not isinstance(value, MutableDict):
            if isinstance(value, dict):
                return MutableDict(value)

            # Passes `None` through, raises ValueError for anything else.
            return Mutable.coerce(key, value)
        else:
            return value

    def __init__(self, initval=None, to_update=None, root_change_key=None):
        initval = initval or {}
        self._changed_keys = set()
        self._deleted_keys = set()
        self._overwrite_key = root_change_key
        self.to_update = self if to_update is None else to_update
        # Build a converted copy instead of rewriting the caller's dict in place.
        converted = {
            k: self._convert_dict(v, k if self._overwrite_key is None else self._overwrite_key)
            for k, v in initval.items()
        }
        dict.__init__(self, converted)

    def __setitem__(self, key, value):
        value = self._convert_dict(value, key if self._overwrite_key is None else self._overwrite_key)
        dict.__setitem__(self, key, value)
        self.to_update.on_key_changed(
            key if self._overwrite_key is None else self._overwrite_key
        )

    def __delitem__(self, key):
        dict.__delitem__(self, key)
        self._record_deletion(key)

    def _record_deletion(self, key):
        # Record the deletion on the root object directly, or report a change
        # of the enclosing top-level key when this is a nested object.
        if self._overwrite_key is None:
            self._deleted_keys.add(key)
            self.changed()
        else:
            self.to_update.on_key_changed(self._overwrite_key)

    def update(self, *args, **kwargs):
        # Route through `__setitem__` so nested dicts are converted and tracked.
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def setdefault(self, key, default=None):
        if key not in self:
            self[key] = default
        return dict.__getitem__(self, key)

    def pop(self, key, *default):
        if key in self:
            value = dict.__getitem__(self, key)
            del self[key]
            return value
        # Missing key: return the default, or raise KeyError like `dict.pop`.
        return dict.pop(self, key, *default)

    def popitem(self):
        key, value = dict.popitem(self)
        self._record_deletion(key)
        return key, value

    def clear(self):
        for key in list(self):
            del self[key]

    def on_key_changed(self, key):
        self._deleted_keys.discard(key)
        self._changed_keys.add(key)
        self.changed()

    def _convert_dict(self, value, overwrite_key):
        if isinstance(value, dict) and not isinstance(value, MutableDict):
            return MutableDict(value, self.to_update, overwrite_key)
        return value

    def __eq__(self, other):
        return dict.__eq__(self, other)
63
+
64
+
65
class ObjectTypeImpl(sqltypes.UserDefinedType, sqltypes.JSON):
    """
    Implementation type for CrateDB ``OBJECT`` columns, building on SQLAlchemy's `JSON`.
    """

    # Statements involving this type opt out of SQLAlchemy's statement cache.
    cache_ok = False
    # Mirrors SQLAlchemy's `JSON.none_as_null` flag. NOTE(review): the effect
    # on CrateDB bind handling depends on the dialect — confirm there.
    none_as_null = False

    __visit_name__ = "OBJECT"
71
+
72
+
73
# Designated name to refer to. `Object` is too ambiguous.
# `as_mutable()` associates `ObjectTypeImpl` with `MutableDict` change tracking.
ObjectType = MutableDict.as_mutable(ObjectTypeImpl)

# Backward-compatibility aliases. They are not exposed under their plain
# names; the module-level `__getattr__` looks them up via `_deprecated_<name>`.
_deprecated_Craty = ObjectType
_deprecated_Object = ObjectType

# https://www.lesinskis.com/deprecating-module-scope-variables.html
# Names which emit a DeprecationWarning when accessed on this module.
deprecated_names = ["Craty", "Object"]
82
+
83
+
84
def __getattr__(name):
    """
    Module attribute hook (PEP 562) serving deprecated aliases.

    For a name listed in `deprecated_names`, emit a `DeprecationWarning` and
    return the module global `_deprecated_<name>`.

    :raises AttributeError: For any other unknown name.
    """
    if name in deprecated_names:
        # `stacklevel=2` attributes the warning to the code accessing the
        # alias, not to this module.
        warnings.warn(f"{name} is deprecated and will be removed in future releases. "
                      f"Please use ObjectType instead.", DeprecationWarning, stacklevel=2)
        return globals()[f"_deprecated_{name}"]
    raise AttributeError(f"module {__name__} has no attribute {name}")
90
+
91
+
92
# Names exported by `from ... import *`. NOTE(review): this includes the
# deprecated aliases, so a star-import presumably resolves them through
# `__getattr__` and emits their DeprecationWarning — confirm this is intended.
__all__ = deprecated_names + ["ObjectType"]
@@ -0,0 +1,173 @@
1
+ """
2
+ ## About
3
+ SQLAlchemy data type implementation for CrateDB's `FLOAT_VECTOR` type.
4
+
5
+ ## References
6
+ - https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector
7
+ - https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match
8
+
9
+ ## Details
10
+ The implementation is based on SQLAlchemy's `TypeDecorator`, and also
11
+ offers compiler support.
12
+
13
+ ## Notes
14
+ CrateDB currently only supports the similarity function `VectorSimilarityFunction.EUCLIDEAN`.
15
+ -- https://github.com/crate/crate/blob/5.5.1/server/src/main/java/io/crate/types/FloatVectorType.java#L55
16
+
17
+ pgvector uses a comparator to apply different similarity functions as operators,
18
+ see `pgvector.sqlalchemy.Vector.comparator_factory`.
19
+
20
+ <->: l2/euclidean_distance
21
+ <#>: max_inner_product
22
+ <=>: cosine_distance
23
+
24
+ ## Backlog
25
+ - After dropping support for SQLAlchemy 1.3, use
26
+ `class FloatVector(sa.TypeDecorator[t.Sequence[float]]):`
27
+
28
+ ## Origin
29
+ This module is based on the corresponding pgvector implementation
30
+ by Andrew Kane. Thank you.
31
+
32
+ The MIT License (MIT)
33
+ Copyright (c) 2021-2023 Andrew Kane
34
+ https://github.com/pgvector/pgvector-python
35
+ """
36
+ import typing as t
37
+
38
+ if t.TYPE_CHECKING:
39
+ import numpy.typing as npt # pragma: no cover
40
+
41
+ import sqlalchemy as sa
42
+ from sqlalchemy.sql.expression import ColumnElement, literal
43
+ from sqlalchemy.ext.compiler import compiles
44
+
45
+
46
# Public API of this module (exported by `from ... import *`).
__all__ = [
    "from_db",
    "knn_match",
    "to_db",
    "FloatVector",
]
52
+
53
+
54
def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]:
    """
    Convert a vector value read from the database into a NumPy array.

    Adapted from `pgvector.utils`.

    :param value: The raw column value.
    :return: `None` for `None`, the value itself when it already is an
        `ndarray` (e.g. converted by a lower-level driver), otherwise a new
        `float32` array.
    """
    import numpy as np

    passthrough = value is None or isinstance(value, np.ndarray)
    return value if passthrough else np.array(value, dtype=np.float32)
63
+
64
+
65
def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
    """
    Prepare a vector value for being sent to the database.

    Adapted from `pgvector.utils`.

    :param value: `None`, a one-dimensional numeric `ndarray`, or a sequence.
    :param dim: When given, the required number of elements.
    :return: `None` for `None`. An `ndarray` is converted to a list; any other
        value is returned unchanged.
    :raises ValueError: For an array which is not one-dimensional or not
        numeric, or when the length does not match `dim`.
    """
    import numpy as np

    if value is None:
        return None

    if isinstance(value, np.ndarray):
        if value.ndim != 1:
            raise ValueError("expected ndim to be 1")
        is_numeric = np.issubdtype(value.dtype, np.integer) or np.issubdtype(value.dtype, np.floating)
        if not is_numeric:
            raise ValueError("dtype must be numeric")
        value = value.tolist()

    if dim is not None:
        actual = len(value)
        if actual != dim:
            raise ValueError("expected %d dimensions, not %d" % (dim, actual))

    return value
85
+
86
+
87
class FloatVector(sa.TypeDecorator):
    """
    SQLAlchemy `FloatVector` data type for CrateDB.

    Built on `sa.ARRAY(sa.FLOAT)`. Bound values are converted with `to_db()`,
    which validates the length against `dimensions` when given. Result values
    are converted with `from_db()` into NumPy `float32` arrays.
    """

    cache_ok = False

    __visit_name__ = "FLOAT_VECTOR"

    _is_array = True

    zero_indexes = False

    impl = sa.ARRAY

    def __init__(self, dimensions: t.Optional[int] = None):
        """
        :param dimensions: The vector length. When given, bound values must
            have exactly this many elements. NOTE(review): it is stored as the
            underlying `ARRAY`'s `dimensions` argument and read back via
            `self.dimensions`.
        """
        super().__init__(sa.FLOAT, dimensions=dimensions)

    def as_generic(self, allow_nulltype=False):
        return sa.ARRAY(item_type=sa.FLOAT)

    @property
    def python_type(self):
        return list

    def bind_processor(self, dialect: sa.engine.Dialect) -> t.Callable:
        def process(value: t.Iterable) -> t.Optional[t.List]:
            return to_db(value, self.dimensions)

        return process

    def result_processor(self, dialect: sa.engine.Dialect, coltype: t.Any) -> t.Callable:
        def process(value: t.Any) -> t.Optional["npt.ArrayLike"]:
            return from_db(value)

        return process
123
+
124
+
125
class KnnMatch(ColumnElement):
    """
    Wrap CrateDB's `KNN_MATCH` function into an SQLAlchemy function.

    The rendering is done by the `compile_knn_match` clause compiler, which
    uses the `compile_*` helpers below.

    https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match
    """
    inherit_cache = True

    def __init__(self, column, term, k=None):
        super().__init__()
        self.column, self.term, self.k = column, term, k

    def compile_column(self, compiler):
        """Render the column or index argument."""
        return compiler.process(self.column)

    def compile_term(self, compiler):
        """Render the search vector as a literal value."""
        term_clause = literal(self.term)
        return compiler.process(term_clause)

    def compile_k(self, compiler):
        """Render the neighbour count as a literal value."""
        k_clause = literal(self.k)
        return compiler.process(k_clause)
147
+
148
+
149
def knn_match(column, term, k):
    """
    Build a `KNN_MATCH` predicate for vector search.

    :param column: A reference to a column or an index.

    :param term: The term to match against. This is an array of floating point
                 values, which is compared to other vectors using a HNSW index search.

    :param k: The `k` argument determines the number of nearest neighbours to
              search in the index.

    :return: A `KnnMatch` clause element.
    """
    predicate = KnnMatch(column=column, term=term, k=k)
    return predicate
162
+
163
+
164
@compiles(KnnMatch)
def compile_knn_match(knn_match, compiler, **kwargs):
    """
    Clause compiler for `KNN_MATCH`.

    Renders ``KNN_MATCH(<column>, <term>, <k>)`` using the element's
    `compile_column`, `compile_term` and `compile_k` helpers, in that order.
    """
    arguments = (
        knn_match.compile_column(compiler),
        knn_match.compile_term(compiler),
        knn_match.compile_k(compiler),
    )
    return "KNN_MATCH({})".format(", ".join(arguments))
+ )