pyobvector 0.2.22__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyobvector/__init__.py +6 -5
- pyobvector/client/__init__.py +5 -4
- pyobvector/client/collection_schema.py +5 -1
- pyobvector/client/enum.py +1 -1
- pyobvector/client/exceptions.py +9 -7
- pyobvector/client/fts_index_param.py +8 -4
- pyobvector/client/hybrid_search.py +10 -4
- pyobvector/client/index_param.py +56 -41
- pyobvector/client/milvus_like_client.py +71 -54
- pyobvector/client/ob_client.py +20 -16
- pyobvector/client/ob_vec_client.py +40 -39
- pyobvector/client/ob_vec_json_table_client.py +366 -274
- pyobvector/client/partitions.py +81 -39
- pyobvector/client/schema_type.py +3 -1
- pyobvector/json_table/__init__.py +4 -3
- pyobvector/json_table/json_value_returning_func.py +12 -10
- pyobvector/json_table/oceanbase_dialect.py +15 -8
- pyobvector/json_table/virtual_data_type.py +47 -28
- pyobvector/schema/__init__.py +7 -1
- pyobvector/schema/array.py +6 -2
- pyobvector/schema/dialect.py +4 -0
- pyobvector/schema/full_text_index.py +8 -3
- pyobvector/schema/geo_srid_point.py +5 -2
- pyobvector/schema/gis_func.py +23 -11
- pyobvector/schema/match_against_func.py +10 -5
- pyobvector/schema/ob_table.py +2 -0
- pyobvector/schema/reflection.py +25 -8
- pyobvector/schema/replace_stmt.py +4 -0
- pyobvector/schema/sparse_vector.py +7 -4
- pyobvector/schema/vec_dist_func.py +22 -9
- pyobvector/schema/vector.py +3 -1
- pyobvector/schema/vector_index.py +7 -3
- pyobvector/util/__init__.py +1 -0
- pyobvector/util/ob_version.py +2 -0
- pyobvector/util/sparse_vector.py +9 -6
- pyobvector/util/vector.py +2 -0
- {pyobvector-0.2.22.dist-info → pyobvector-0.2.23.dist-info}/METADATA +13 -14
- pyobvector-0.2.23.dist-info/RECORD +40 -0
- {pyobvector-0.2.22.dist-info → pyobvector-0.2.23.dist-info}/licenses/LICENSE +1 -1
- pyobvector-0.2.22.dist-info/RECORD +0 -40
- {pyobvector-0.2.22.dist-info → pyobvector-0.2.23.dist-info}/WHEEL +0 -0
pyobvector/schema/dialect.py
CHANGED
@@ -1,4 +1,5 @@
 """OceanBase dialect."""
+
 from sqlalchemy import util
 from sqlalchemy.dialects.mysql import aiomysql, pymysql
 
@@ -7,10 +8,12 @@ from .vector import VECTOR
 from .sparse_vector import SPARSE_VECTOR
 from .geo_srid_point import POINT
 
+
 class OceanBaseDialect(pymysql.MySQLDialect_pymysql):
     # not change dialect name, since it is a subclass of pymysql.MySQLDialect_pymysql
     # name = "oceanbase"
     """Ocenbase dialect."""
+
     supports_statement_cache = True
 
     def __init__(self, **kwargs):
@@ -36,6 +39,7 @@ class OceanBaseDialect(pymysql.MySQLDialect_pymysql):
 
 class AsyncOceanBaseDialect(aiomysql.MySQLDialect_aiomysql):
     """OceanBase async dialect."""
+
     supports_statement_cache = True
 
     def __init__(self, **kwargs):
pyobvector/schema/full_text_index.py
CHANGED
@@ -1,4 +1,5 @@
 """FullTextIndex: full text search index type"""
+
 from sqlalchemy import Index
 from sqlalchemy.schema import DDLElement
 from sqlalchemy.ext.compiler import compiles
@@ -7,16 +8,18 @@ from sqlalchemy.sql.ddl import SchemaGenerator
 
 class CreateFtsIndex(DDLElement):
     """A new statement clause to create fts index.
-
+
     Attributes:
         index : fts index schema
     """
+
     def __init__(self, index):
         self.index = index
 
 
 class ObFtsSchemaGenerator(SchemaGenerator):
     """A new schema generator to handle create fts index statement."""
+
     def visit_fts_index(self, index, create_ok=False):
         """Handle create fts index statement compiling.
 
@@ -29,8 +32,10 @@ class ObFtsSchemaGenerator(SchemaGenerator):
         with self.with_ddl_events(index):
             CreateFtsIndex(index)._invoke_with(self.connection)
 
+
 class FtsIndex(Index):
     """Fts Index schema."""
+
     __visit_name__ = "fts_index"
 
     def __init__(self, name, fts_parser: str, *column_names, **kw):
@@ -39,7 +44,7 @@ class FtsIndex(Index):
 
     def create(self, bind, checkfirst: bool = False) -> None:
         """Create fts index.
-
+
         Args:
             bind: SQL engine or connection.
             checkfirst: check the index exists or not.
@@ -48,7 +53,7 @@ class FtsIndex(Index):
 
 
 @compiles(CreateFtsIndex)
-def compile_create_fts_index(element, compiler, **kw):
+def compile_create_fts_index(element, compiler, **kw): # pylint: disable=unused-argument
    """A decorator function to compile create fts index statement."""
    index = element.index
    table_name = index.table.name
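For orientation, a minimal usage sketch of the FtsIndex class touched above; the engine URL, table layout, and parser name are assumptions, and only the (name, fts_parser, *columns) constructor and create(bind, checkfirst) signature come from the diff:

    from sqlalchemy import Column, Integer, MetaData, Table, Text, create_engine
    from pyobvector.schema.full_text_index import FtsIndex

    engine = create_engine("mysql+pymysql://user:***@127.0.0.1:2881/test")  # assumed DSN
    metadata = MetaData()
    docs = Table(
        "docs",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("body", Text),
    )
    metadata.create_all(engine)

    # FtsIndex(name, fts_parser, *columns); create() routes the DDL through
    # ObFtsSchemaGenerator / CreateFtsIndex as shown in the diff above.
    idx = FtsIndex("idx_docs_body_fts", "ngram", docs.c.body)
    idx.create(engine, checkfirst=True)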
pyobvector/schema/geo_srid_point.py
CHANGED
@@ -1,23 +1,26 @@
 """Point: OceanBase GIS data type for SQLAlchemy"""
+
 from typing import Optional
 from sqlalchemy.types import UserDefinedType, String
 
+
 class POINT(UserDefinedType):
     """Point data type definition."""
+
     cache_ok = True
     _string = String()
 
     def __init__(
         self,
         # lat_long: Tuple[float, float],
-        srid: Optional[int] = None
+        srid: Optional[int] = None,
     ):
         """Init Latitude and Longitude."""
         super(UserDefinedType, self).__init__()
         # self.lat_long = lat_long
         self.srid = srid
 
-    def get_col_spec(self, **kw):
+    def get_col_spec(self, **kw): # pylint: disable=unused-argument
        """Parse to Point data type definition in text SQL."""
        if self.srid is None:
            return "POINT"
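A short sketch of the POINT type shown above; the table is hypothetical, while the srid keyword and the get_col_spec() result for srid=None come directly from the code:

    from sqlalchemy import Column, Integer, MetaData, Table
    from pyobvector.schema.geo_srid_point import POINT

    metadata = MetaData()
    places = Table(
        "places",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("location", POINT(srid=4326)),
    )
    print(POINT().get_col_spec())  # -> "POINT" when no SRID is given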
pyobvector/schema/gis_func.py
CHANGED
@@ -10,31 +10,34 @@ from .geo_srid_point import POINT
 
 logger = logging.getLogger(__name__)
 
+
 class ST_GeomFromText(FunctionElement):
     """ST_GeomFromText: parse text to geometry object.
-
+
     Attributes:
         type : result type
     """
+
     type = BINARY()
 
     def __init__(self, *args):
         super().__init__()
         self.args = args
 
+
 @compiles(ST_GeomFromText)
-def compile_ST_GeomFromText(element, compiler, **kwargs):
+def compile_ST_GeomFromText(element, compiler, **kwargs): # pylint: disable=unused-argument
     """Compile ST_GeomFromText function."""
     args = []
     for idx, arg in enumerate(element.args):
         if idx == 0:
             if (
-                (not isinstance(arg, tuple))
-                (len(arg) != 2)
-                (not all(isinstance(x, float) for x in arg))
+                (not isinstance(arg, tuple))
+                or (len(arg) != 2)
+                or (not all(isinstance(x, float) for x in arg))
             ):
                 raise ValueError(
-                    f"Tuple[float, float] is expected for Point literal,"
+                    f"Tuple[float, float] is expected for Point literal,"
                     f"while get {type(arg)}"
                 )
             args.append(f"'{POINT.to_db(arg)}'")
@@ -44,12 +47,14 @@ def compile_ST_GeomFromText(element, compiler, **kwargs): # pylint: disable=unus
     # logger.info(f"{args_str}")
     return f"ST_GeomFromText({args_str})"
 
+
 class st_distance(FunctionElement):
     """st_distance: calculate distance between Points.
-
+
     Attributes:
         type : result type
     """
+
     type = Float()
     inherit_cache = True
 
@@ -57,19 +62,22 @@ class st_distance(FunctionElement):
         super().__init__()
         self.args = args
 
+
 @compiles(st_distance)
-def compile_st_distance(element, compiler, **kwargs):
+def compile_st_distance(element, compiler, **kwargs): # pylint: disable=unused-argument
     """Compile st_distance function."""
     args = ", ".join(compiler.process(arg) for arg in element.args)
     return f"st_distance({args})"
 
+
 class st_dwithin(FunctionElement):
     """st_dwithin: Checks if the distance between two points
     is less than a specified distance.
-
+
     Attributes:
         type : result type
     """
+
     type = Boolean()
     inherit_cache = True
 
@@ -77,8 +85,9 @@ class st_dwithin(FunctionElement):
         super().__init__()
         self.args = args
 
+
 @compiles(st_dwithin)
-def compile_st_dwithin(element, compiler, **kwargs):
+def compile_st_dwithin(element, compiler, **kwargs): # pylint: disable=unused-argument
     """Compile st_dwithin function."""
     args = []
     for idx, arg in enumerate(element.args):
@@ -89,12 +98,14 @@ def compile_st_dwithin(element, compiler, **kwargs): # pylint: disable=unused-ar
     args_str = ", ".join(args)
     return f"_st_dwithin({args_str})"
 
+
 class st_astext(FunctionElement):
     """st_astext: Returns a Point in human-readable format.
 
     Attributes:
         type : result type
     """
+
     type = Text()
     inherit_cache = True
 
@@ -102,8 +113,9 @@ class st_astext(FunctionElement):
         super().__init__()
         self.args = args
 
+
 @compiles(st_astext)
-def compile_st_astext(element, compiler, **kwargs):
+def compile_st_astext(element, compiler, **kwargs): # pylint: disable=unused-argument
     """Compile st_astext function."""
     args = ", ".join(compiler.process(arg) for arg in element.args)
     return f"st_astext({args})"
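A hedged sketch of how these GIS function elements could be combined in a query; the table is hypothetical, and ST_GeomFromText takes a (float, float) tuple as its first argument, per the validation in compile_ST_GeomFromText above:

    from sqlalchemy import Column, Integer, MetaData, Table, select
    from pyobvector.schema.geo_srid_point import POINT
    from pyobvector.schema.gis_func import ST_GeomFromText, st_astext, st_distance

    metadata = MetaData()
    places = Table(
        "places",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("location", POINT()),
    )

    # Order rows by distance from a literal point; st_astext renders the stored
    # POINT in human-readable form, as described in its docstring.
    origin = ST_GeomFromText((118.78, 32.04))
    stmt = (
        select(places.c.id, st_astext(places.c.location))
        .order_by(st_distance(places.c.location, origin))
        .limit(10)
    )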
pyobvector/schema/match_against_func.py
CHANGED
@@ -8,31 +8,36 @@ from sqlalchemy import literal, column
 
 logger = logging.getLogger(__name__)
 
+
 class MatchAgainst(FunctionElement):
     """MatchAgainst: match clause for full text search.
 
     Attributes:
         type : result type
     """
+
     inherit_cache = True
 
     def __init__(self, query, *columns):
         columns = [column(col) if isinstance(col, str) else col for col in columns]
         super().__init__(literal(query), *columns)
-
+
+
 @compiles(MatchAgainst)
-def complie_MatchAgainst(element, compiler, **kwargs):
+def complie_MatchAgainst(element, compiler, **kwargs): # pylint: disable=unused-argument
     """Compile MatchAgainst function."""
     clauses = list(element.clauses)
     if len(clauses) < 2:
         raise ValueError(
-            f"MatchAgainst should take a string expression and "
+            f"MatchAgainst should take a string expression and "
            f"at least one column name string as parameters."
        )
-
+
    query_expr = clauses[0]
    compiled_query = compiler.process(query_expr, **kwargs)
    column_exprs = clauses[1:]
-    compiled_columns = [
+    compiled_columns = [
+        compiler.process(col, identifier_prepared=True) for col in column_exprs
+    ]
    columns_str = ", ".join(compiled_columns)
    return f"MATCH ({columns_str}) AGAINST ({compiled_query} IN NATURAL LANGUAGE MODE)"
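A small sketch of the MatchAgainst element changed above; the table is made up, while the string-or-column handling and the MATCH ... AGAINST (... IN NATURAL LANGUAGE MODE) output shape come from the compile function:

    from sqlalchemy import Column, Integer, MetaData, Table, Text, select
    from pyobvector.schema.match_against_func import MatchAgainst

    metadata = MetaData()
    docs = Table(
        "docs",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("body", Text),
    )

    # String column names are wrapped via column() in MatchAgainst.__init__;
    # the clause compiles roughly to:
    #   MATCH (body) AGAINST ('open source database' IN NATURAL LANGUAGE MODE)
    stmt = select(docs.c.id).where(MatchAgainst("open source database", "body"))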
pyobvector/schema/ob_table.py
CHANGED
@@ -1,9 +1,11 @@
 """ObTable: extension to Table for creating table with vector index."""
+
 from sqlalchemy import Table
 from .vector_index import ObSchemaGenerator
 
 
 class ObTable(Table):
     """A class extends SQLAlchemy Table to do table creation with vector index."""
+
     def create(self, bind, checkfirst: bool = False) -> None:
         bind._run_ddl_visitor(ObSchemaGenerator, self, checkfirst=checkfirst)
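A minimal sketch of ObTable.create, which walks the table with ObSchemaGenerator so that attached vector indexes are created together with it; the DSN and columns are assumptions:

    from sqlalchemy import Column, Integer, MetaData, create_engine
    from pyobvector.schema.ob_table import ObTable
    from pyobvector.schema.vector import VECTOR

    engine = create_engine("mysql+pymysql://user:***@127.0.0.1:2881/test")  # assumed DSN
    metadata = MetaData()
    items = ObTable(
        "items",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("embedding", VECTOR(3)),
    )
    # create() runs ObSchemaGenerator over the table definition.
    items.create(engine, checkfirst=True)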
pyobvector/schema/reflection.py
CHANGED
@@ -1,14 +1,21 @@
 """OceanBase table definition reflection."""
+
 import re
 import logging
-from sqlalchemy.dialects.mysql.reflection import
+from sqlalchemy.dialects.mysql.reflection import (
+    MySQLTableDefinitionParser,
+    _re_compile,
+    cleanup_text,
+)
 
 from pyobvector.schema.array import ARRAY
 
 logger = logging.getLogger(__name__)
 
+
 class OceanBaseTableDefinitionParser(MySQLTableDefinitionParser):
     """OceanBase table definition parser."""
+
     def __init__(self, dialect, preparer, *, default_schema=None):
         MySQLTableDefinitionParser.__init__(self, dialect, preparer)
         self.default_schema = default_schema
@@ -82,15 +89,19 @@ class OceanBaseTableDefinitionParser(MySQLTableDefinitionParser):
         m = self._re_array_column.match(line)
         if m:
             spec = m.groupdict()
-            name, coltype_with_args =
+            name, coltype_with_args = (
+                spec["name"].strip(),
+                spec["coltype_with_args"].strip(),
+            )
 
             item_pattern = re.compile(
-                r"^(?:array\s*\()*([\w]+)(?:\(([\d,]+)\))?\)*$",
-                re.IGNORECASE
+                r"^(?:array\s*\()*([\w]+)(?:\(([\d,]+)\))?\)*$", re.IGNORECASE
             )
             item_m = item_pattern.match(coltype_with_args)
             if not item_m:
-                raise ValueError(
+                raise ValueError(
+                    f"Failed to find inner type from array column definition: {line}"
+                )
 
             item_type = self.dialect.ischema_names[item_m.group(1).lower()]
             item_type_arg = item_m.group(2)
@@ -99,9 +110,11 @@ class OceanBaseTableDefinitionParser(MySQLTableDefinitionParser):
             elif item_type_arg[0] == "'" and item_type_arg[-1] == "'":
                 item_type_args = self._re_csv_str.findall(item_type_arg)
             else:
-                item_type_args = [
+                item_type_args = [
+                    int(v) for v in self._re_csv_int.findall(item_type_arg)
+                ]
 
-            nested_level = coltype_with_args.lower().count(
+            nested_level = coltype_with_args.lower().count("array")
             type_instance = item_type(*item_type_args)
             for _ in range(nested_level):
                 type_instance = ARRAY(type_instance)
@@ -144,7 +157,11 @@ class OceanBaseTableDefinitionParser(MySQLTableDefinitionParser):
 
         if tp == "fk_constraint":
             table = spec.get("table", [])
-            if
+            if (
+                isinstance(table, list)
+                and len(table) == 2
+                and table[0] == self.default_schema
+            ):
                 spec["table"] = table[1:]
 
         for action in ["onupdate", "ondelete"]:
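The array-column regex added above can be checked in isolation; this snippet only exercises the pattern literally shown in the diff:

    import re

    item_pattern = re.compile(
        r"^(?:array\s*\()*([\w]+)(?:\(([\d,]+)\))?\)*$", re.IGNORECASE
    )
    m = item_pattern.match("array(array(int(11)))")
    print(m.group(1), m.group(2))                   # -> int 11
    print("array(array(int(11)))".count("array"))   # -> 2, i.e. ARRAY nested twice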
pyobvector/schema/replace_stmt.py
CHANGED
@@ -1,11 +1,15 @@
 """ReplaceStmt: replace into statement compilation."""
+
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql.expression import Insert
 
+
 class ReplaceStmt(Insert):
     """Replace into statement."""
+
     inherit_cache = True
 
+
 @compiles(ReplaceStmt)
 def compile_replace_stmt(insert, compiler, **kw):
     """Compile replace into statement.
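Since ReplaceStmt subclasses Insert, typical use mirrors the Insert API; the table here is hypothetical:

    from sqlalchemy import Column, Integer, MetaData, String, Table
    from pyobvector.schema.replace_stmt import ReplaceStmt

    metadata = MetaData()
    kv = Table(
        "kv",
        metadata,
        Column("k", Integer, primary_key=True),
        Column("v", String(64)),
    )

    # The @compiles(ReplaceStmt) hook above renders this as a REPLACE statement
    # instead of a plain INSERT.
    stmt = ReplaceStmt(kv).values(k=1, v="hello")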
pyobvector/schema/sparse_vector.py
CHANGED
@@ -1,25 +1,28 @@
 """SPARSE_VECTOR: An extended data type for SQLAlchemy"""
+
 from sqlalchemy.types import UserDefinedType, String
 from ..util import SparseVector
 
+
 class SPARSE_VECTOR(UserDefinedType):
     """SPARSE_VECTOR data type definition."""
+
     cache_ok = True
     _string = String()
 
     def __init__(self):
         super(UserDefinedType, self).__init__()
 
-    def get_col_spec(self, **kw):
+    def get_col_spec(self, **kw): # pylint: disable=unused-argument
         """Parse to sparse vector data type definition in text SQL."""
         return "SPARSEVECTOR"
-
+
     def bind_processor(self, dialect):
         def process(value):
             return SparseVector._to_db(value)
 
         return process
-
+
     def literal_processor(self, dialect):
         string_literal_processor = self._string._cached_literal_processor(dialect)
 
@@ -32,4 +35,4 @@ class SPARSE_VECTOR(UserDefinedType):
         def process(value):
             return SparseVector._from_db(value)
 
-        return process
+        return process
pyobvector/schema/vec_dist_func.py
CHANGED
@@ -11,6 +11,7 @@ from sqlalchemy import Float
 
 logger = logging.getLogger(__name__)
 
+
 def parse_vec_distance_func_args(element, compiler, **kwargs):
     args = []
     for arg in element.args:
@@ -21,20 +22,23 @@ def parse_vec_distance_func_args(element, compiler, **kwargs):
     args = ", ".join(args)
     return args
 
+
 class l2_distance(FunctionElement):
     """Vector distance function: l2_distance.
-
+
     Attributes:
         type : result type
     """
+
     type = Float()
 
     def __init__(self, *args):
         super().__init__()
         self.args = args
 
+
 @compiles(l2_distance)
-def compile_l2_distance(element, compiler, **kwargs):
+def compile_l2_distance(element, compiler, **kwargs): # pylint: disable=unused-argument
     """Compile l2_distance function.
 
     Args:
@@ -46,41 +50,47 @@ def compile_l2_distance(element, compiler, **kwargs): # pylint: disable=unused-a
 
 class cosine_distance(FunctionElement):
     """Vector distance function: cosine_distance.
-
+
     Attributes:
         type : result type
     """
+
     type = Float()
 
     def __init__(self, *args):
         super().__init__()
         self.args = args
 
+
 @compiles(cosine_distance)
-def compile_cosine_distance(element, compiler, **kwargs):
+def compile_cosine_distance(element, compiler, **kwargs): # pylint: disable=unused-argument
     """Compile cosine_distance function.
 
     Args:
         element: cosine_distance arguments
         compiler: SQL compiler
     """
-    return
+    return (
+        f"cosine_distance({parse_vec_distance_func_args(element, compiler, **kwargs)})"
+    )
 
 
 class inner_product(FunctionElement):
     """Vector distance function: inner_product.
-
+
     Attributes:
         type : result type
     """
+
     type = Float()
 
     def __init__(self, *args):
         super().__init__()
         self.args = args
 
+
 @compiles(inner_product)
-def compile_inner_product(element, compiler, **kwargs):
+def compile_inner_product(element, compiler, **kwargs): # pylint: disable=unused-argument
     """Compile inner_product function.
 
     Args:
@@ -89,20 +99,23 @@ def compile_inner_product(element, compiler, **kwargs): # pylint: disable=unused
     """
     return f"inner_product({parse_vec_distance_func_args(element, compiler, **kwargs)})"
 
+
 class negative_inner_product(FunctionElement):
     """Vector distance function: negative_inner_product.
-
+
     Attributes:
         type : result type
     """
+
     type = Float()
 
     def __init__(self, *args):
         super().__init__()
         self.args = args
 
+
 @compiles(negative_inner_product)
-def compile_negative_inner_product(element, compiler, **kwargs):
+def compile_negative_inner_product(element, compiler, **kwargs): # pylint: disable=unused-argument
     """Compile negative_inner_product function.
 
     Args:
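A hedged sketch of a nearest-neighbour query built from these distance functions; the table and the vector-literal format are assumptions, and only the l2_distance(...) call shape comes from the diff:

    from sqlalchemy import Column, Integer, MetaData, Table, select
    from pyobvector.schema.vec_dist_func import l2_distance
    from pyobvector.schema.vector import VECTOR

    metadata = MetaData()
    items = Table(
        "items",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("embedding", VECTOR(3)),
    )

    # Top-5 rows ordered by L2 distance to a query vector.
    query_vec = "[0.1, 0.2, 0.3]"
    stmt = (
        select(items.c.id, l2_distance(items.c.embedding, query_vec))
        .order_by(l2_distance(items.c.embedding, query_vec))
        .limit(5)
    )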
pyobvector/schema/vector.py
CHANGED
@@ -1,10 +1,12 @@
 """VECTOR: An extended data type for SQLAlchemy"""
+
 from sqlalchemy.types import UserDefinedType, String
 from ..util import Vector
 
 
 class VECTOR(UserDefinedType):
     """VECTOR data type definition."""
+
     cache_ok = True
     _string = String()
 
@@ -12,7 +14,7 @@ class VECTOR(UserDefinedType):
         super(UserDefinedType, self).__init__()
         self.dim = dim
 
-    def get_col_spec(self, **kw):
+    def get_col_spec(self, **kw): # pylint: disable=unused-argument
         """Parse to vector data type definition in text SQL."""
         if self.dim is None:
             return "VECTOR"
pyobvector/schema/vector_index.py
CHANGED
@@ -1,4 +1,5 @@
 """VectorIndex: An extended index type for SQLAlchemy"""
+
 from sqlalchemy import Index
 from sqlalchemy.schema import DDLElement
 from sqlalchemy.ext.compiler import compiles
@@ -7,16 +8,18 @@ from sqlalchemy.sql.ddl import SchemaGenerator
 
 class CreateVectorIndex(DDLElement):
     """A new statement clause to create vector index.
-
+
     Attributes:
         index: vector index schema
     """
+
     def __init__(self, index):
         self.index = index
 
 
 class ObSchemaGenerator(SchemaGenerator):
     """A new schema generator to handle create vector index statement."""
+
     def visit_vector_index(self, index, create_ok=False):
         """Handle create vector index statement compiling.
 
@@ -32,6 +35,7 @@ class ObSchemaGenerator(SchemaGenerator):
 
 class VectorIndex(Index):
     """Vector Index schema."""
+
     __visit_name__ = "vector_index"
 
     def __init__(self, name, *column_names, params: str = None, **kw):
@@ -44,7 +48,7 @@ class VectorIndex(Index):
 
     def create(self, bind, checkfirst: bool = False) -> None:
         """Create vector index.
-
+
         Args:
             bind: SQL engine or connection.
             checkfirst: check the index exists or not.
@@ -53,7 +57,7 @@ class VectorIndex(Index):
 
 
 @compiles(CreateVectorIndex)
-def compile_create_vector_index(element, compiler, **kw):
+def compile_create_vector_index(element, compiler, **kw): # pylint: disable=unused-argument
     """A decorator function to compile create vector index statement."""
     index = element.index
     table_name = index.table.name
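A sketch of VectorIndex against the same kind of hypothetical table; the params string is only an illustrative guess, while the (name, *columns, params=...) constructor and create(bind, checkfirst) come from the class above:

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine
    from pyobvector.schema.vector import VECTOR
    from pyobvector.schema.vector_index import VectorIndex

    engine = create_engine("mysql+pymysql://user:***@127.0.0.1:2881/test")  # assumed DSN
    metadata = MetaData()
    items = Table(
        "items",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("embedding", VECTOR(3)),
    )
    metadata.create_all(engine)

    vidx = VectorIndex(
        "vidx_items_embedding",
        items.c.embedding,
        params="distance=l2, type=hnsw",  # hypothetical parameter string
    )
    vidx.create(engine, checkfirst=True)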
pyobvector/util/__init__.py
CHANGED
pyobvector/util/ob_version.py
CHANGED
@@ -1,4 +1,5 @@
 """OceanBase cluster version module."""
+
 import copy
 
 
@@ -8,6 +9,7 @@ class ObVersion:
     Attributes:
         version_nums (List[int]): version number of OceanBase cluster. For example, '4.3.3.0'
     """
+
     def __init__(self, version_nums: list[int]):
         self.version_nums = copy.deepcopy(version_nums)
 
pyobvector/util/sparse_vector.py
CHANGED
@@ -1,24 +1,27 @@
 """A utility module for the extended data type class 'SPARSE_VECTOR'."""
+
 import ast
 
+
 class SparseVector:
     """A transformer class between python dict and OceanBase SPARSE_VECTOR.
 
     Attributes:
         _value (Dict) : a python dict
     """
+
     def __init__(self, value):
         if not isinstance(value, dict):
             raise ValueError("Sparse Vector should be a dict in python")
-
+
         self._value = value
-
+
     def __repr__(self):
         return f"{self._value}"
-
+
     def to_text(self):
         return f"{self._value}"
-
+
     @classmethod
     def from_text(cls, value: str):
         """Construct Sparse Vector class with dict in string format.
@@ -27,7 +30,7 @@ class SparseVector:
         value: For example, '{1:1.1, 2:2.2}'
         """
         return cls(ast.literal_eval(value))
-
+
     @classmethod
     def _to_db(cls, value):
         if value is None:
@@ -45,4 +48,4 @@ class SparseVector:
 
         if isinstance(value, str):
             return cls.from_text(value)._value
-        raise ValueError(f"unexpected sparse vector type: {type(value)}")
+        raise ValueError(f"unexpected sparse vector type: {type(value)}")
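The SparseVector helper above round-trips between Python dicts and the textual form stored in OceanBase; a short sketch grounded in the methods shown:

    from pyobvector.util import SparseVector

    sv = SparseVector({1: 1.1, 2: 2.2})                # constructor requires a dict
    print(sv.to_text())                                # -> "{1: 1.1, 2: 2.2}"
    sv2 = SparseVector.from_text("{1: 1.1, 2: 2.2}")   # parsed via ast.literal_eval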
pyobvector/util/vector.py
CHANGED
@@ -1,4 +1,5 @@
 """A utility module for the extended data type class 'VECTOR'."""
+
 import json
 import numpy as np
 
@@ -9,6 +10,7 @@ class Vector:
     Attributes:
         _value (numpy.array): a numpy array
     """
+
     def __init__(self, value):
         # big-endian float32
         if not isinstance(value, np.ndarray) or value.dtype != ">f4":