sqlalchemy-cratedb 0.41.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,130 @@
1
+ import typing as t
2
+
3
+ import sqlalchemy as sa
4
+ from sqlalchemy.event import listen
5
+
6
+ from sqlalchemy_cratedb.support.util import refresh_dirty, refresh_table
7
+
8
+
9
def patch_autoincrement_timestamp():
    """
    Monkeypatch `sqlalchemy.Column` to emulate `autoincrement` with a timestamp.

    Once this function has been called, each column constructed with an
    `autoincrement` keyword argument has that argument removed. Unless the
    caller also supplied an explicit `default`, the column receives
    `sa.func.now()` as its default instead.

    This is used by CrateDB's MLflow adapter.

    TODO: Maybe enable through a dialect parameter `crate_polyfill_autoincrement` or such.
    """
    import sqlalchemy.sql.schema as schema

    # Keep a handle on the pristine constructor, so the wrapper can delegate to it.
    init_dist = schema.Column.__init__

    def __init__(self, *args, **kwargs):
        # A sentinel distinguishes "keyword absent" from any value the caller passed.
        absent = object()
        if kwargs.pop("autoincrement", absent) is not absent:
            # Only fill in the timestamp default when no default was requested.
            kwargs.setdefault("default", sa.func.now())
        init_dist(self, *args, **kwargs)

    schema.Column.__init__ = __init__  # type: ignore[method-assign]
30
+
31
+
32
def check_uniqueness_factory(sa_entity, *attribute_names):
    """
    Build an event handler which emulates a UNIQUE constraint on `sa_entity`.

    CrateDB does not support the UNIQUE constraint on columns. The returned
    handler queries the mapped table for rows whose `attribute_names` columns
    equal the values on the target instance, and raises an `IntegrityError`
    when the result reports at least one row.

    https://github.com/crate/sqlalchemy-cratedb/issues/76

    This is used by CrateDB's MLflow adapter.

    TODO: Maybe enable through a dialect parameter `crate_polyfill_unique` or such.

    :param sa_entity: The mapped class to guard. Instances of other classes are ignored.
    :param attribute_names: Names of the attributes which must be unique in combination.
    :return: A callable accepting `(mapper, connection, target)`.
    """  # noqa: E501

    # Synthesize a canonical "name" for the constraint,
    # composed of all column names involved.
    constraint_name: str = "-".join(attribute_names)

    def check_uniqueness(mapper, connection, target):
        from sqlalchemy.exc import IntegrityError

        # Only instances of the guarded entity are subject to the check.
        if not isinstance(target, sa_entity):
            return

        # TODO: How to use `session.query(SqlExperiment)` here?
        query = mapper.selectable.select()
        for attribute_name in attribute_names:
            column = getattr(sa_entity, attribute_name)
            value = getattr(target, attribute_name)
            query = query.filter(column == value)

        stmt = query.compile(bind=connection.engine)
        results = connection.execute(stmt)
        if results.rowcount > 0:
            raise IntegrityError(
                statement=stmt,
                params=[],
                orig=Exception(
                    f"DuplicateKeyException in table '{target.__tablename__}' "
                    f"on constraint '{constraint_name}'"
                ),
            )

    return check_uniqueness
72
+
73
+
74
def refresh_after_dml_session(session: sa.orm.Session):
    """
    Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE).

    CrateDB is eventually consistent, i.e. write operations are not flushed to
    disk immediately, so readers may see stale data. In a traditional OLTP-like
    application, this is not applicable.

    This SQLAlchemy extension makes sure that data is synchronized after each
    operation manipulating data. It does so by registering `refresh_dirty` as
    a listener for the session's `after_flush` event.

    > `after_{insert,update,delete}` events only apply to the session flush operation
    > and do not apply to the ORM DML operations described at ORM-Enabled INSERT,
    > UPDATE, and DELETE statements. To intercept ORM DML events, use
    > `SessionEvents.do_orm_execute().`
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.MapperEvents.after_insert

    > Intercept statement executions that occur on behalf of an ORM Session object.
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.do_orm_execute

    > Execute after flush has completed, but before commit has been called.
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.after_flush

    This is used by CrateDB's LangChain adapter.

    TODO: Maybe enable through a dialect parameter `crate_dml_refresh` or such.
    """  # noqa: E501
    listen(
        session,
        "after_flush",
        refresh_dirty,
    )
102
+
103
+
104
def refresh_after_dml_engine(engine: sa.engine.Engine):
    """
    Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE).

    Registers an `after_execute` listener on the given engine. The listener
    refreshes the target table of each executed INSERT, UPDATE, or DELETE
    construct, unless that target is a join.

    This is used by CrateDB's Singer/Meltano and `rdflib-sqlalchemy` adapters.
    """

    def receive_after_execute(
        conn: sa.engine.Connection, clauseelement, multiparams, params, execution_options, result
    ):
        # Anything which is not a DML construct leaves the table untouched.
        if not isinstance(clauseelement, (sa.sql.Insert, sa.sql.Update, sa.sql.Delete)):
            return
        # Joins are skipped; only plain table targets are refreshed.
        if isinstance(clauseelement.table, sa.sql.Join):
            return
        refresh_table(conn, clauseelement.table)

    sa.event.listen(engine, "after_execute", receive_after_execute)
119
+
120
+
121
def refresh_after_dml(engine_or_session: t.Union[sa.engine.Engine, sa.orm.Session]):
    """
    Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE).

    Dispatches to `refresh_after_dml_engine` for an engine, and to
    `refresh_after_dml_session` for a session or scoped session.

    :raises TypeError: When the argument is neither an engine nor a session.
    """
    if isinstance(engine_or_session, sa.engine.Engine):
        refresh_after_dml_engine(engine_or_session)
        return

    if isinstance(engine_or_session, (sa.orm.Session, sa.orm.scoping.scoped_session)):
        refresh_after_dml_session(engine_or_session)
        return

    raise TypeError(f"Unknown type: {type(engine_or_session)}")
@@ -0,0 +1,82 @@
1
+ import itertools
2
+ import typing as t
3
+
4
+ import sqlalchemy as sa
5
+
6
+ from sqlalchemy_cratedb.dialect import CrateDialect
7
+
8
+ if t.TYPE_CHECKING:
9
+ try:
10
+ from sqlalchemy.orm import DeclarativeBase
11
+ except ImportError:
12
+ pass
13
+
14
+
15
# An instance of the dialect used for quoting purposes.
# It is created once at import time and shared by `quote_relation_name`.
identifier_preparer = CrateDialect().identifier_preparer
17
+
18
+
19
def refresh_table(
    connection, target: t.Union[str, "DeclarativeBase", "sa.sql.selectable.TableClause"]
):
    """
    Invoke a `REFRESH TABLE` statement.

    :param connection: Any object providing `execute()`, used to run the statement.
    :param target: The table to refresh. A `TableClause` is rendered as a
        double-quoted name, prefixed with its double-quoted schema when it has
        one. An object carrying `__tablename__` contributes that value as-is,
        and anything else is interpolated verbatim.
    """

    if isinstance(target, sa.sql.selectable.TableClause):
        name_parts = [f'"{target.name}"']
        if target.schema is not None:
            name_parts.insert(0, f'"{target.schema}"')
        full_table_name = ".".join(name_parts)
    else:
        # Declarative classes expose `__tablename__`; plain strings fall through unchanged.
        full_table_name = getattr(target, "__tablename__", target)

    statement = sa.text(f"REFRESH TABLE {full_table_name}")
    connection.execute(statement)
37
+
38
+
39
def refresh_dirty(session, flush_context=None):
    """
    Invoke a `REFRESH TABLE` statement on each table entity flagged as "dirty".

    SQLAlchemy event handler for the 'after_flush' event. It collects the
    classes of all entities in `session.new`, `session.dirty`, and
    `session.deleted`, and invokes `refresh_table` once per distinct class.

    :param session: The session which has just been flushed.
    :param flush_context: Supplied by the event machinery; not used here.
    """
    touched_classes = set()
    for entity in itertools.chain(session.new, session.dirty, session.deleted):
        touched_classes.add(entity.__class__)

    for entity_class in touched_classes:
        refresh_table(session, entity_class)
50
+
51
+
52
def quote_relation_name(ident: str) -> str:
    """
    Quote a simple or fully-qualified table/relation name, when needed.

    Simple: <table>
    Fully-qualified: <schema>.<table>

    Happy path examples:

    foo => foo
    Foo => "Foo"
    "Foo" => "Foo"
    foo.bar => foo.bar
    foo-bar.baz_qux => "foo-bar".baz_qux

    Such input strings will not be modified:

    "foo.bar" => "foo.bar"

    :raises ValueError: When the name consists of more than three dot-separated parts.
    """

    # A quote at either end of the input string is taken as a signal
    # that the relation name has been quoted already.
    if '"' in (ident[:1], ident[-1:]):
        return ident

    # A dot separates the components of a qualified identifier like
    # <schema>.<table>. Each component is quoted on its own.
    parts = ident.split(".")
    if len(parts) > 3:
        raise ValueError(f"Invalid relation name, too many parts: {ident}")
    return ".".join(identifier_preparer.quote(part) for part in parts)
@@ -0,0 +1,13 @@
1
+ from .array import ObjectArray
2
+ from .geo import Geopoint, Geoshape
3
+ from .object import ObjectType
4
+ from .vector import FloatVector, knn_match
5
+
6
# Public API of the package. The entries of `__all__` must be strings naming
# the exported objects; listing the objects themselves makes
# `from <package> import *` fail with a `TypeError`.
__all__ = [
    "Geopoint",
    "Geoshape",
    "ObjectArray",
    "ObjectType",
    "FloatVector",
    "knn_match",
]
@@ -0,0 +1,143 @@
1
+ # -*- coding: utf-8; -*-
2
+ #
3
+ # Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
4
+ # license agreements. See the NOTICE file distributed with this work for
5
+ # additional information regarding copyright ownership. Crate licenses
6
+ # this file to you under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License. You may
8
+ # obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14
+ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15
+ # License for the specific language governing permissions and limitations
16
+ # under the License.
17
+ #
18
+ # However, if you have executed another commercial license agreement
19
+ # with Crate these terms will supersede the license and you may use the
20
+ # software solely pursuant to the terms of the relevant commercial agreement.
21
+
22
+ # ruff: noqa: A005 # Module `array` shadows a Python standard-library module
23
+
24
+ import sqlalchemy.types as sqltypes
25
+ from sqlalchemy.ext.mutable import Mutable
26
+ from sqlalchemy.sql import default_comparator, expression, operators
27
+
28
+
29
class MutableList(Mutable, list):
    """
    A `list` which notifies SQLAlchemy's mutation tracking about in-place changes.

    Every mutating operation delegates to the built-in `list` implementation
    and then calls `changed()`.
    """

    @classmethod
    def coerce(cls, key, value):
        """Convert plain list to MutableList"""
        if isinstance(value, MutableList):
            return value
        if isinstance(value, list):
            return MutableList(value)
        if value is None:
            return value
        # Any other value is wrapped into a single-element list.
        return MutableList([value])

    def __init__(self, initval=None):
        list.__init__(self, initval or [])

    def __setitem__(self, key, value):
        list.__setitem__(self, key, value)
        self.changed()

    def __delitem__(self, key):
        list.__delitem__(self, key)
        self.changed()

    def __iadd__(self, other):
        # Route `+=` through `extend`, so the change is reported as well.
        self.extend(other)
        return self

    def __eq__(self, other):
        return list.__eq__(self, other)

    def append(self, item):
        list.append(self, item)
        self.changed()

    def insert(self, idx, item):
        list.insert(self, idx, item)
        self.changed()

    def extend(self, iterable):
        list.extend(self, iterable)
        self.changed()

    def pop(self, index=-1):
        # Hand the removed element back to the caller, like `list.pop` does.
        item = list.pop(self, index)
        self.changed()
        return item

    def remove(self, item):
        list.remove(self, item)
        self.changed()

    def clear(self):
        list.clear(self)
        self.changed()

    def sort(self, **kwargs):
        list.sort(self, **kwargs)
        self.changed()

    def reverse(self):
        list.reverse(self)
        self.changed()
72
+
73
+
74
class Any(expression.ColumnElement):
    """Represent the clause ``left operator ANY (right)``. ``right`` must be
    an array expression.

    Copied from the PostgreSQL dialect.

    .. seealso::

        :class:`sqlalchemy.dialects.postgresql.ARRAY`

        :meth:`sqlalchemy.dialects.postgresql.ARRAY.Comparator.any`
            ARRAY-bound method

    """

    inherit_cache = True
    __visit_name__ = "any"

    def __init__(self, left, right, operator=operators.eq):
        # The left-hand side is turned into a literal; the right-hand side
        # and the operator are stored as given.
        self.left = expression.literal(left)
        self.right = right
        self.operator = operator
        # The clause as a whole evaluates to a boolean.
        self.type = sqltypes.Boolean()
97
+
98
+
99
class _ObjectArray(sqltypes.UserDefinedType):
    """
    Column type rendered as `ARRAY(OBJECT)` in DDL.

    Its comparator supports subscript access and the `any()` predicate.
    """

    cache_ok = True

    class Comparator(sqltypes.TypeEngine.Comparator):
        def __getitem__(self, key):
            return default_comparator._binary_operate(self.expr, operators.getitem, key)

        def any(self, other, operator=operators.eq):
            """Return ``other operator ANY (array)`` clause.

            Argument places are switched, because ANY requires array
            expression to be on the right hand-side.

            E.g.::

                from sqlalchemy.sql import operators

                conn.execute(
                    select([table.c.data]).where(
                            table.c.data.any(7, operator=operators.lt)
                        )
                )

            :param other: expression to be compared
            :param operator: an operator object from the
             :mod:`sqlalchemy.sql.operators`
             package, defaults to :func:`.operators.eq`.

            .. seealso::

                :class:`.postgresql.Any`

                :meth:`.postgresql.ARRAY.Comparator.all`

            """
            array_expression = self.expr
            return Any(other, array_expression, operator=operator)

    def get_col_spec(self, **kws):
        return "ARRAY(OBJECT)"

    comparator_factory = Comparator
    type = MutableList
141
+
142
+
143
# Public column type: `_ObjectArray` associated with `MutableList`
# through the `as_mutable()` hook of SQLAlchemy's mutation tracking.
ObjectArray = MutableList.as_mutable(_ObjectArray)
@@ -0,0 +1,43 @@
1
+ import geojson
2
+ from sqlalchemy import types as sqltypes
3
+ from sqlalchemy.sql import default_comparator, operators
4
+
5
+
6
class Geopoint(sqltypes.UserDefinedType):
    """
    Column type rendered as `GEO_POINT` in DDL.

    On the way in, `geojson.Point` values are reduced to their coordinates;
    other values are passed through. On the way out, values are converted
    to tuples.
    """

    cache_ok = True

    class Comparator(sqltypes.TypeEngine.Comparator):
        def __getitem__(self, key):
            return default_comparator._binary_operate(self.expr, operators.getitem, key)

    def get_col_spec(self):
        return "GEO_POINT"

    def bind_processor(self, dialect):
        def process(value):
            if isinstance(value, geojson.Point):
                return value.coordinates
            return value

        return process

    def result_processor(self, dialect, coltype):
        def process(value):
            # NULL column values arrive as `None`; `tuple(None)` would raise
            # a `TypeError`, so they are passed through unchanged.
            if value is None:
                return None
            return tuple(value)

        return process

    comparator_factory = Comparator
28
+
29
+
30
class Geoshape(sqltypes.UserDefinedType):
    """
    Column type rendered as `GEO_SHAPE` in DDL.

    Result values are converted using `geojson.GeoJSON.to_instance`.
    """

    cache_ok = True

    class Comparator(sqltypes.TypeEngine.Comparator):
        def __getitem__(self, key):
            return default_comparator._binary_operate(self.expr, operators.getitem, key)

    comparator_factory = Comparator

    def get_col_spec(self):
        return "GEO_SHAPE"

    def result_processor(self, dialect, coltype):
        # The geojson converter itself serves as the per-value processor.
        converter = geojson.GeoJSON.to_instance
        return converter
@@ -0,0 +1,94 @@
1
+ import warnings
2
+
3
+ from sqlalchemy import types as sqltypes
4
+ from sqlalchemy.ext.mutable import Mutable
5
+
6
+
7
class MutableDict(Mutable, dict):
    """
    A `dict` which tracks changed and deleted keys for SQLAlchemy's mutation tracking.

    Nested plain dictionaries are converted to `MutableDict` instances which
    report their changes to the root object (`to_update`), using the
    top-level key they live under (`_overwrite_key`).
    """

    @classmethod
    def coerce(cls, key, value):
        "Convert plain dictionaries to MutableDict."

        if isinstance(value, MutableDict):
            return value
        if isinstance(value, dict):
            return MutableDict(value)

        # Anything else is handed to the base implementation,
        # which raises ValueError for unsupported values.
        return Mutable.coerce(key, value)

    def __init__(self, initval=None, to_update=None, root_change_key=None):
        """
        :param initval: Mapping with the initial content. It is not modified.
        :param to_update: The root object to notify about changes. Defaults to `self`.
        :param root_change_key: The top-level key this (nested) dictionary
            belongs to. `None` designates the root object.
        """
        initval = initval or {}
        self._changed_keys = set()
        self._deleted_keys = set()
        self._overwrite_key = root_change_key
        self.to_update = self if to_update is None else to_update
        # Build the converted content in a fresh dictionary, in order
        # not to rewrite the nested values of the caller's mapping.
        converted = {}
        for k in initval:
            converted[k] = self._convert_dict(
                initval[k], overwrite_key=k if self._overwrite_key is None else self._overwrite_key
            )
        dict.__init__(self, converted)

    def __setitem__(self, key, value):
        # On the root object, the changed key is the key itself. On a nested
        # object, it is the top-level key the nested object lives under.
        change_key = key if self._overwrite_key is None else self._overwrite_key
        value = self._convert_dict(value, change_key)
        dict.__setitem__(self, key, value)
        self.to_update.on_key_changed(change_key)

    def __delitem__(self, key):
        dict.__delitem__(self, key)
        # add the key to the deleted keys if this is the root object
        # otherwise update on root object
        if self._overwrite_key is None:
            self._deleted_keys.add(key)
            self.changed()
        else:
            self.to_update.on_key_changed(self._overwrite_key)

    def on_key_changed(self, key):
        # A key which is assigned again is no longer considered deleted.
        self._deleted_keys.discard(key)
        self._changed_keys.add(key)
        self.changed()

    def _convert_dict(self, value, overwrite_key):
        # Wrap plain dictionaries, so that nested changes are reported as well.
        if isinstance(value, dict) and not isinstance(value, MutableDict):
            return MutableDict(value, self.to_update, overwrite_key)
        return value

    def __eq__(self, other):
        return dict.__eq__(self, other)
62
+
63
+
64
class ObjectTypeImpl(sqltypes.UserDefinedType, sqltypes.JSON):
    """
    Type implementation combining `UserDefinedType` with SQLAlchemy's `JSON` type.

    It is wrapped with `MutableDict.as_mutable()` to form the public `ObjectType`.
    """

    # Name used for compiler dispatch — presumably resolved to a
    # `visit_OBJECT` method by the type compiler; verify against the dialect.
    __visit_name__ = "OBJECT"

    # Statement caching is explicitly declared as not safe for this type.
    cache_ok = False
    # Attribute inherited from `sqltypes.JSON`, set to `False` here.
    none_as_null = False
69
+
70
+
71
# Designated name to refer to. `Object` is too ambiguous.
ObjectType = MutableDict.as_mutable(ObjectTypeImpl)

# Backward-compatibility aliases.
# They are stored under a `_deprecated_` prefix, so that access to the public
# names `Craty` and `Object` goes through the module-level `__getattr__`.
_deprecated_Craty = ObjectType
_deprecated_Object = ObjectType

# https://www.lesinskis.com/deprecating-module-scope-variables.html
deprecated_names = ["Craty", "Object"]
80
+
81
+
82
def __getattr__(name):
    """
    Resolve deprecated module attributes, emitting a `DeprecationWarning`.

    Names listed in `deprecated_names` are looked up in the module globals
    under their `_deprecated_` prefix. Any other name raises `AttributeError`.
    """
    if name not in deprecated_names:
        raise AttributeError(f"module {__name__} has no attribute {name}")

    warnings.warn(
        f"{name} is deprecated and will be removed in future releases. "
        f"Please use ObjectType instead.",
        category=DeprecationWarning,
        stacklevel=2,
    )
    return globals()[f"_deprecated_{name}"]
92
+
93
+
94
# Public API: the deprecated alias names plus `ObjectType`.
__all__ = deprecated_names + ["ObjectType"]
@@ -0,0 +1,176 @@
1
+ """
2
+ ## About
3
+ SQLAlchemy data type implementation for CrateDB's `FLOAT_VECTOR` type.
4
+
5
+ ## References
6
+ - https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector
7
+ - https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match
8
+
9
+ ## Details
10
+ The implementation is based on SQLAlchemy's `TypeDecorator`, and also
11
+ offers compiler support.
12
+
13
+ ## Notes
14
+ CrateDB currently only supports the similarity function `VectorSimilarityFunction.EUCLIDEAN`.
15
+ -- https://github.com/crate/crate/blob/5.5.1/server/src/main/java/io/crate/types/FloatVectorType.java#L55
16
+
17
+ pgvector uses a comparator to apply different similarity functions as operators,
18
+ see `pgvector.sqlalchemy.Vector.comparator_factory`.
19
+
20
+ <->: l2/euclidean_distance
21
+ <#>: max_inner_product
22
+ <=>: cosine_distance
23
+
24
+ ## Backlog
25
+ - After dropping support for SQLAlchemy 1.3, use
26
+ `class FloatVector(sa.TypeDecorator[t.Sequence[float]]):`
27
+
28
+ ## Origin
29
+ This module is based on the corresponding pgvector implementation
30
+ by Andrew Kane. Thank you.
31
+
32
+ The MIT License (MIT)
33
+ Copyright (c) 2021-2023 Andrew Kane
34
+ https://github.com/pgvector/pgvector-python
35
+ """
36
+
37
+ import typing as t
38
+
39
+ if t.TYPE_CHECKING:
40
+ import numpy.typing as npt # pragma: no cover
41
+
42
+ import sqlalchemy as sa
43
+ from sqlalchemy.ext.compiler import compiles
44
+ from sqlalchemy.sql.expression import ColumnElement, literal
45
+
46
# Public API of this module: the conversion helpers, the
# `knn_match` predicate factory, and the `FloatVector` type.
__all__ = [
    "from_db",
    "knn_match",
    "to_db",
    "FloatVector",
]
52
+
53
+
54
def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]:
    """
    Convert a value received from the database into a NumPy `float32` array.

    `None` and values which already are `numpy.ndarray` instances are
    returned unchanged.
    """
    import numpy as np

    # from `pgvector.utils`
    if value is None:
        return value

    # could be ndarray if already cast by lower-level driver
    if isinstance(value, np.ndarray):
        return value

    return np.array(value, dtype=np.float32)
63
+
64
+
65
def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
    """
    Convert a vector into a value suitable for sending to the database.

    NumPy arrays are validated to be one-dimensional and numeric, and are
    converted to plain lists. Other values are passed through as given.

    :param value: The vector to convert; `None` is returned unchanged.
    :param dim: Expected number of elements, or `None` to skip the length check.
    :raises ValueError: On a multi-dimensional or non-numeric array, or when
        the number of elements differs from `dim`.
    """
    import numpy as np

    # from `pgvector.utils`
    if value is None:
        return value

    if isinstance(value, np.ndarray):
        if value.ndim != 1:
            raise ValueError("expected ndim to be 1")

        is_integer = np.issubdtype(value.dtype, np.integer)
        if not (is_integer or np.issubdtype(value.dtype, np.floating)):
            raise ValueError("dtype must be numeric")

        value = value.tolist()

    if dim is None:
        return value

    actual = len(value)
    if actual != dim:
        raise ValueError("expected %d dimensions, not %d" % (dim, actual))

    return value
87
+
88
+
89
class FloatVector(sa.TypeDecorator):
    """
    SQLAlchemy `FloatVector` data type for CrateDB.

    The type decorates `sa.ARRAY` with `sa.FLOAT` items. Bound values are
    converted with `to_db`, result values with `from_db`.
    """

    impl = sa.ARRAY

    cache_ok = False

    zero_indexes = False

    _is_array = True

    __visit_name__ = "FLOAT_VECTOR"

    def __init__(self, dimensions: int = None):
        super().__init__(sa.FLOAT, dimensions=dimensions)

    @property
    def python_type(self):
        return list

    def as_generic(self, allow_nulltype=False):
        return sa.ARRAY(item_type=sa.FLOAT)

    def bind_processor(self, dialect: sa.engine.Dialect) -> t.Callable:
        def convert_outbound(value: t.Iterable) -> t.Optional[t.List]:
            # The dimension count is looked up on each call.
            return to_db(value, self.dimensions)

        return convert_outbound

    def result_processor(self, dialect: sa.engine.Dialect, coltype: t.Any) -> t.Callable:
        def convert_inbound(value: t.Any) -> t.Optional["npt.ArrayLike"]:
            return from_db(value)

        return convert_inbound
125
+
126
+
127
class KnnMatch(ColumnElement):
    """
    Wrap CrateDB's `KNN_MATCH` function into an SQLAlchemy function.

    The element stores the column, the search term, and `k`, and offers one
    helper per argument to render it through a given compiler.

    https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match
    """

    inherit_cache = True

    def __init__(self, column, term, k=None):
        super().__init__()
        self.column, self.term, self.k = column, term, k

    def compile_column(self, compiler):
        return compiler.process(self.column)

    def compile_term(self, compiler):
        # The term is wrapped into a literal before rendering.
        term_literal = literal(self.term)
        return compiler.process(term_literal)

    def compile_k(self, compiler):
        # Like the term, `k` is wrapped into a literal before rendering.
        k_literal = literal(self.k)
        return compiler.process(k_literal)
150
+
151
+
152
def knn_match(column, term, k):
    """
    Generate a match predicate for vector search.

    :param column: A reference to a column or an index.

    :param term: The term to match against. This is an array of floating point
                 values, which is compared to other vectors using a HNSW index search.

    :param k: The `k` argument determines the number of nearest neighbours to
              search in the index.

    :return: A `KnnMatch` element wrapping the three arguments.
    """
    predicate = KnnMatch(column, term, k)
    return predicate
165
+
166
+
167
@compiles(KnnMatch)
def compile_knn_match(knn_match, compiler, **kwargs):
    """
    Clause compiler for `KNN_MATCH`.

    Renders the column, the term, and `k`, in this order, and joins them
    into a `KNN_MATCH(...)` function call.
    """
    rendered_arguments = (
        knn_match.compile_column(compiler),
        knn_match.compile_term(compiler),
        knn_match.compile_k(compiler),
    )
    return "KNN_MATCH(%s, %s, %s)" % rendered_arguments