pyobvector 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those published versions.
- pyobvector/__init__.py +6 -5
- pyobvector/client/__init__.py +5 -4
- pyobvector/client/collection_schema.py +5 -1
- pyobvector/client/enum.py +1 -1
- pyobvector/client/exceptions.py +9 -7
- pyobvector/client/fts_index_param.py +8 -4
- pyobvector/client/hybrid_search.py +10 -4
- pyobvector/client/index_param.py +56 -41
- pyobvector/client/milvus_like_client.py +71 -54
- pyobvector/client/ob_client.py +20 -16
- pyobvector/client/ob_vec_client.py +47 -40
- pyobvector/client/ob_vec_json_table_client.py +366 -274
- pyobvector/client/partitions.py +81 -39
- pyobvector/client/schema_type.py +3 -1
- pyobvector/json_table/__init__.py +4 -3
- pyobvector/json_table/json_value_returning_func.py +12 -10
- pyobvector/json_table/oceanbase_dialect.py +15 -8
- pyobvector/json_table/virtual_data_type.py +47 -28
- pyobvector/schema/__init__.py +7 -1
- pyobvector/schema/array.py +6 -2
- pyobvector/schema/dialect.py +4 -0
- pyobvector/schema/full_text_index.py +8 -3
- pyobvector/schema/geo_srid_point.py +5 -2
- pyobvector/schema/gis_func.py +23 -11
- pyobvector/schema/match_against_func.py +10 -5
- pyobvector/schema/ob_table.py +2 -0
- pyobvector/schema/reflection.py +25 -8
- pyobvector/schema/replace_stmt.py +4 -0
- pyobvector/schema/sparse_vector.py +7 -4
- pyobvector/schema/vec_dist_func.py +22 -9
- pyobvector/schema/vector.py +3 -1
- pyobvector/schema/vector_index.py +7 -3
- pyobvector/util/__init__.py +1 -0
- pyobvector/util/ob_version.py +2 -0
- pyobvector/util/sparse_vector.py +9 -6
- pyobvector/util/vector.py +2 -0
- {pyobvector-0.2.22.dist-info → pyobvector-0.2.24.dist-info}/METADATA +13 -14
- pyobvector-0.2.24.dist-info/RECORD +40 -0
- {pyobvector-0.2.22.dist-info → pyobvector-0.2.24.dist-info}/licenses/LICENSE +1 -1
- pyobvector-0.2.22.dist-info/RECORD +0 -40
- {pyobvector-0.2.22.dist-info → pyobvector-0.2.24.dist-info}/WHEEL +0 -0
pyobvector/client/partitions.py
CHANGED
@@ -1,4 +1,5 @@
 """A module to do compilation of OceanBase Parition Clause."""
+
 from typing import Optional, Union
 import logging
 from dataclasses import dataclass
@@ -11,6 +12,7 @@ logger.setLevel(logging.DEBUG)
 
 class PartType(IntEnum):
     """Partition type of table or collection for both ObVecClient and MilvusLikeClient"""
+
     Range = 0
     Hash = 1
     Key = 2
@@ -21,12 +23,13 @@ class PartType(IntEnum):
 
 class ObPartition:
     """Base class of all kind of Partition strategy
-
+
     Attributes:
         part_type (PartType) : type of partition strategy
         sub_partition (ObPartition) : subpartition strategy
         is_sub (bool) : this partition strategy is a subpartition or not
     """
+
     def __init__(self, part_type: PartType):
         self.part_type = part_type
         self.sub_partition = None
@@ -38,7 +41,7 @@ class ObPartition:
 
     def add_subpartition(self, sub_part):
         """Add subpartition strategy to current partition.
-
+
         Args:
             sub_part (ObPartition) : subpartition strategy
         """
@@ -60,14 +63,15 @@ class ObPartition:
 @dataclass
 class RangeListPartInfo:
     """Range/RangeColumns/List/ListColumns partition info for each partition.
-
+
     Attributes:
         part_name (string) : partition name
-    part_upper_bound_expr (Union[List, str, int]) :
-    For example, using `[1,2]`/`'DEFAULT'` as default case/`7` when create
+        part_upper_bound_expr (Union[List, str, int]) :
+        For example, using `[1,2]`/`'DEFAULT'` as default case/`7` when create
         List/ListColumns partition.
         Using 100 / `MAXVALUE` when create Range/RangeColumns partition.
     """
+
     part_name: str
     part_upper_bound_expr: Union[list, str, int]
 
@@ -84,6 +88,7 @@ class RangeListPartInfo:
 
 class ObRangePartition(ObPartition):
     """Range/RangeColumns partition strategy."""
+
     def __init__(
         self,
         is_range_columns: bool,
@@ -117,18 +122,24 @@ class ObRangePartition(ObPartition):
             assert self.range_expr is not None
             if self.sub_partition is None:
                 return f"RANGE ({self.range_expr}) ({self._parse_range_part_list()})"
-            return f"RANGE ({self.range_expr}) {self.sub_partition.do_compile()} " \
-                f"({self._parse_range_part_list()})"
+            return (
+                f"RANGE ({self.range_expr}) {self.sub_partition.do_compile()} "
+                f"({self._parse_range_part_list()})"
+            )
         assert self.col_name_list is not None
         if self.sub_partition is None:
-            return f"RANGE COLUMNS ({','.join(self.col_name_list)}) " \
-                f"({self._parse_range_part_list()})"
-        return f"RANGE COLUMNS ({','.join(self.col_name_list)}) " \
-            f"{self.sub_partition.do_compile()} ({self._parse_range_part_list()})"
+            return (
+                f"RANGE COLUMNS ({','.join(self.col_name_list)}) "
+                f"({self._parse_range_part_list()})"
+            )
+        return (
+            f"RANGE COLUMNS ({','.join(self.col_name_list)}) "
+            f"{self.sub_partition.do_compile()} ({self._parse_range_part_list()})"
+        )
 
     def _parse_range_part_list(self) -> str:
         range_partitions_complied = [
-        f"PARTITION {range_part_info.part_name} VALUES LESS THAN "
+            f"PARTITION {range_part_info.part_name} VALUES LESS THAN "
             f"({range_part_info.get_part_expr_str()})"
             for range_part_info in self.range_part_infos
         ]
@@ -137,6 +148,7 @@ class ObRangePartition(ObPartition):
 
 class ObSubRangePartition(ObRangePartition):
     """Range/RangeColumns subpartition strategy."""
+
     def __init__(
         self,
         is_range_columns: bool,
@@ -155,16 +167,20 @@ class ObSubRangePartition(ObRangePartition):
         if self.part_type == PartType.Range:
             assert self.range_expr is not None
             assert self.sub_partition is None
-            return f"RANGE ({self.range_expr}) SUBPARTITION TEMPLATE " \
-                f"({self._parse_range_part_list()})"
+            return (
+                f"RANGE ({self.range_expr}) SUBPARTITION TEMPLATE "
+                f"({self._parse_range_part_list()})"
+            )
         assert self.col_name_list is not None
         assert self.sub_partition is None
-        return f"RANGE COLUMNS ({','.join(self.col_name_list)}) SUBPARTITION TEMPLATE " \
-            f"({self._parse_range_part_list()})"
+        return (
+            f"RANGE COLUMNS ({','.join(self.col_name_list)}) SUBPARTITION TEMPLATE "
+            f"({self._parse_range_part_list()})"
+        )
 
     def _parse_range_part_list(self) -> str:
         range_partitions_complied = [
-        f"SUBPARTITION {range_part_info.part_name} VALUES LESS THAN "
+            f"SUBPARTITION {range_part_info.part_name} VALUES LESS THAN "
             f"({range_part_info.get_part_expr_str()})"
             for range_part_info in self.range_part_infos
         ]
@@ -173,6 +189,7 @@ class ObSubRangePartition(ObRangePartition):
 
 class ObListPartition(ObPartition):
     """List/ListColumns partition strategy."""
+
     def __init__(
         self,
         is_list_columns: bool,
@@ -206,14 +223,20 @@ class ObListPartition(ObPartition):
             assert self.list_expr is not None
             if self.sub_partition is None:
                 return f"LIST ({self.list_expr}) ({self._parse_list_part_list()})"
-            return f"LIST ({self.list_expr}) {self.sub_partition.do_compile()} " \
-                f"({self._parse_list_part_list()})"
+            return (
+                f"LIST ({self.list_expr}) {self.sub_partition.do_compile()} "
+                f"({self._parse_list_part_list()})"
+            )
         assert self.col_name_list is not None
         if self.sub_partition is None:
-            return f"LIST COLUMNS ({','.join(self.col_name_list)}) " \
-                f"({self._parse_list_part_list()})"
-        return f"LIST COLUMNS ({','.join(self.col_name_list)}) " \
-            f"{self.sub_partition.do_compile()} ({self._parse_list_part_list()})"
+            return (
+                f"LIST COLUMNS ({','.join(self.col_name_list)}) "
+                f"({self._parse_list_part_list()})"
+            )
+        return (
+            f"LIST COLUMNS ({','.join(self.col_name_list)}) "
+            f"{self.sub_partition.do_compile()} ({self._parse_list_part_list()})"
+        )
 
     def _parse_list_part_list(self) -> str:
         list_partitions_complied = [
@@ -225,6 +248,7 @@ class ObListPartition(ObPartition):
 
 class ObSubListPartition(ObListPartition):
     """List/ListColumns subpartition strategy."""
+
     def __init__(
         self,
         is_list_columns: bool,
@@ -246,12 +270,14 @@ class ObSubListPartition(ObListPartition):
             return f"LIST ({self.list_expr}) SUBPARTITION TEMPLATE ({self._parse_list_part_list()})"
         assert self.col_name_list is not None
         assert self.sub_partition is None
-        return f"LIST COLUMNS ({','.join(self.col_name_list)}) SUBPARTITION TEMPLATE " \
-            f"({self._parse_list_part_list()})"
+        return (
+            f"LIST COLUMNS ({','.join(self.col_name_list)}) SUBPARTITION TEMPLATE "
+            f"({self._parse_list_part_list()})"
+        )
 
     def _parse_list_part_list(self) -> str:
         list_partitions_complied = [
-        f"SUBPARTITION {list_part_info.part_name} VALUES IN "
+            f"SUBPARTITION {list_part_info.part_name} VALUES IN "
             f"({list_part_info.get_part_expr_str()})"
             for list_part_info in self.list_part_infos
         ]
@@ -260,6 +286,7 @@ class ObSubListPartition(ObListPartition):
 
 class ObHashPartition(ObPartition):
     """Hash partition strategy."""
+
     def __init__(
         self,
         hash_expr: str,
@@ -279,7 +306,7 @@ class ObHashPartition(ObPartition):
 
         if self.part_count is not None and self.hash_part_name_list is not None:
             logging.warning(
-            "part_count & hash_part_name_list are both set, "
+                "part_count & hash_part_name_list are both set, "
                 "hash_part_name_list will be override by part_count"
             )
 
@@ -291,13 +318,17 @@ class ObHashPartition(ObPartition):
         if self.part_count is not None:
             if self.sub_partition is None:
                 return f"HASH ({self.hash_expr}) PARTITIONS {self.part_count}"
-            return f"HASH ({self.hash_expr}) {self.sub_partition.do_compile()} " \
-                f"PARTITIONS {self.part_count}"
+            return (
+                f"HASH ({self.hash_expr}) {self.sub_partition.do_compile()} "
+                f"PARTITIONS {self.part_count}"
+            )
         assert self.hash_part_name_list is not None
         if self.sub_partition is None:
             return f"HASH ({self.hash_expr}) ({self._parse_hash_part_list()})"
-        return f"HASH ({self.hash_expr}) {self.sub_partition.do_compile()} " \
-            f"({self._parse_hash_part_list()})"
+        return (
+            f"HASH ({self.hash_expr}) {self.sub_partition.do_compile()} "
+            f"({self._parse_hash_part_list()})"
+        )
 
     def _parse_hash_part_list(self):
         return ",".join([f"PARTITION {name}" for name in self.hash_part_name_list])
@@ -305,6 +336,7 @@ class ObHashPartition(ObPartition):
 
 class ObSubHashPartition(ObHashPartition):
     """Hash subpartition strategy."""
+
     def __init__(
         self,
         hash_expr: str,
@@ -332,6 +364,7 @@ class ObSubHashPartition(ObHashPartition):
 
 class ObKeyPartition(ObPartition):
     """Key partition strategy."""
+
     def __init__(
         self,
         col_name_list: list[str],
@@ -351,7 +384,7 @@ class ObKeyPartition(ObPartition):
 
         if self.part_count is not None and self.key_part_name_list is not None:
             logging.warning(
-            "part_count & key_part_name_list are both set, "
+                "part_count & key_part_name_list are both set, "
                 "key_part_name_list will be override by part_count"
             )
 
@@ -365,13 +398,19 @@ class ObKeyPartition(ObPartition):
                 return (
                     f"KEY ({','.join(self.col_name_list)}) PARTITIONS {self.part_count}"
                 )
-            return f"KEY ({','.join(self.col_name_list)}) {self.sub_partition.do_compile()} " \
-                f"PARTITIONS {self.part_count}"
+            return (
+                f"KEY ({','.join(self.col_name_list)}) {self.sub_partition.do_compile()} "
+                f"PARTITIONS {self.part_count}"
+            )
         assert self.key_part_name_list is not None
         if self.sub_partition is None:
-            return f"KEY ({','.join(self.col_name_list)}) ({self._parse_key_part_list()})"
-        return f"KEY ({','.join(self.col_name_list)}) {self.sub_partition.do_compile()} " \
-            f"({self._parse_key_part_list()})"
+            return (
+                f"KEY ({','.join(self.col_name_list)}) ({self._parse_key_part_list()})"
+            )
+        return (
+            f"KEY ({','.join(self.col_name_list)}) {self.sub_partition.do_compile()} "
+            f"({self._parse_key_part_list()})"
+        )
 
     def _parse_key_part_list(self):
         return ",".join([f"PARTITION {name}" for name in self.key_part_name_list])
@@ -379,6 +418,7 @@ class ObKeyPartition(ObPartition):
 
 class ObSubKeyPartition(ObKeyPartition):
     """Key subpartition strategy."""
+
     def __init__(
         self,
         col_name_list: list[str],
@@ -400,8 +440,10 @@ class ObSubKeyPartition(ObKeyPartition):
             )
         assert self.key_part_name_list is not None
         assert self.sub_partition is None
-        return f"KEY ({','.join(self.col_name_list)}) SUBPARTITION TEMPLATE " \
-            f"({self._parse_key_part_list()})"
+        return (
+            f"KEY ({','.join(self.col_name_list)}) SUBPARTITION TEMPLATE "
+            f"({self._parse_key_part_list()})"
+        )
 
     def _parse_key_part_list(self):
         return ",".join([f"SUBPARTITION {name}" for name in self.key_part_name_list])
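Every hunk in this file is cosmetic: blank lines after docstrings and long return expressions wrapped in parentheses, with the compiled partition clauses unchanged. As orientation, a minimal usage sketch of ObRangePartition.do_compile follows; the argument names are inferred from the attributes the hunks reference, not from a documented signature.

from pyobvector.client.partitions import ObRangePartition, RangeListPartInfo

# Hypothetical constructor shape: is_range_columns plus the attributes
# used in do_compile() above (range_expr, range_part_infos).
part = ObRangePartition(
    False,  # plain RANGE rather than RANGE COLUMNS
    range_part_infos=[
        RangeListPartInfo("p0", 100),
        RangeListPartInfo("p1", "MAXVALUE"),
    ],
    range_expr="user_id",
)

# With no subpartition attached, the first branch above renders roughly:
#   RANGE (user_id) (PARTITION p0 VALUES LESS THAN (100),
#                    PARTITION p1 VALUES LESS THAN (MAXVALUE))
print(part.do_compile())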
pyobvector/client/schema_type.py
CHANGED
@@ -1,4 +1,5 @@
 """Data type module that compatible with Milvus."""
+
 from sqlalchemy import (
     Boolean,
     SmallInteger,
@@ -16,6 +17,7 @@ from ..schema import ARRAY, SPARSE_VECTOR, VECTOR
 
 class DataType(IntEnum):
     """Data type definition that compatible with Milvus."""
+
     # NONE = 0
     BOOL = 1
     INT8 = 2
@@ -40,7 +42,7 @@ class DataType(IntEnum):
 
 def convert_datatype_to_sqltype(datatype: DataType):
     """Convert Milvus data type to SQL type.
-
+
     Args:
         datatype (DataType) : Milvus data type.
     """
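Only spacing changed in this file. For context, a hedged sketch of convert_datatype_to_sqltype: the BOOL and INT8 values appear in the hunk above, while the returned SQLAlchemy type is an assumption based on the module's imports.

from pyobvector.client.schema_type import DataType, convert_datatype_to_sqltype

# Enum values as shown in the hunk above.
assert DataType.BOOL == 1 and DataType.INT8 == 2

# Presumably returns a SQLAlchemy type such as Boolean, per the imports above.
sql_type = convert_datatype_to_sqltype(DataType.BOOL)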
pyobvector/json_table/__init__.py
CHANGED
@@ -12,7 +12,8 @@ from .virtual_data_type import (
 from .json_value_returning_func import json_value
 
 __all__ = [
-    "OceanBase",
+    "OceanBase",
+    "ChangeColumn",
     "JType",
     "JsonTableDataType",
     "JsonTableBool",
@@ -21,5 +22,5 @@ __all__ = [
     "JsonTableDecimalFactory",
     "JsonTableInt",
     "val2json",
-    "json_value"
-]
+    "json_value",
+]
pyobvector/json_table/json_value_returning_func.py
CHANGED
@@ -7,6 +7,7 @@ from sqlalchemy import Text
 
 logger = logging.getLogger(__name__)
 
+
 class json_value(FunctionElement):
     type = Text()
     inherit_cache = True
@@ -15,6 +16,7 @@ class json_value(FunctionElement):
         super().__init__()
         self.args = args
 
+
 @compiles(json_value)
 def compile_json_value(element, compiler, **kwargs):
     args = []
@@ -23,25 +25,25 @@ def compile_json_value(element, compiler, **kwargs):
     args.append(compiler.process(element.args[0]))
     if not (isinstance(element.args[1], str) and isinstance(element.args[2], str)):
         raise ValueError("Invalid args for json_value")
-
-    if element.args[2].startswith('TINYINT'):
+
+    if element.args[2].startswith("TINYINT"):
         returning_type = "SIGNED"
-    elif element.args[2].startswith('TIMESTAMP'):
+    elif element.args[2].startswith("TIMESTAMP"):
         returning_type = "DATETIME"
-    elif element.args[2].startswith('INT'):
+    elif element.args[2].startswith("INT"):
         returning_type = "SIGNED"
-    elif element.args[2].startswith('VARCHAR'):
-        if element.args[2] == 'VARCHAR':
+    elif element.args[2].startswith("VARCHAR"):
+        if element.args[2] == "VARCHAR":
             returning_type = "CHAR(255)"
         else:
-            varchar_pattern = r'VARCHAR\((\d+)\)'
+            varchar_pattern = r"VARCHAR\((\d+)\)"
             varchar_matches = re.findall(varchar_pattern, element.args[2])
             returning_type = f"CHAR({int(varchar_matches[0])})"
-    elif element.args[2].startswith('DECIMAL'):
-        if element.args[2] == 'DECIMAL':
+    elif element.args[2].startswith("DECIMAL"):
+        if element.args[2] == "DECIMAL":
             returning_type = "DECIMAL(10, 0)"
         else:
-            decimal_pattern = r'DECIMAL\((\d+),\s*(\d+)\)'
+            decimal_pattern = r"DECIMAL\((\d+),\s*(\d+)\)"
             decimal_matches = re.findall(decimal_pattern, element.args[2])
             x, y = decimal_matches[0]
             returning_type = f"DECIMAL({x}, {y})"
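The quoting changes above leave the RETURNING-type mapping intact: TINYINT and INT map to SIGNED, TIMESTAMP to DATETIME, VARCHAR(n) to CHAR(n) (bare VARCHAR to CHAR(255)), DECIMAL(x, y) passes through, and bare DECIMAL becomes DECIMAL(10, 0). A hedged sketch of calling the construct from SQLAlchemy; the table and JSON path are invented for illustration, and the final SQL assembly happens past the end of this hunk.

from sqlalchemy import JSON, Column, Integer, MetaData, Table, select
from pyobvector.json_table.json_value_returning_func import json_value

meta = MetaData()
docs = Table("docs", meta, Column("id", Integer), Column("body", JSON))

# args are (column, JSON path, declared type); compile_json_value rewrites
# the declared type into the RETURNING type per the mapping above.
stmt = select(json_value(docs.c.body, "$.price", "DECIMAL(10, 2)"))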
pyobvector/json_table/oceanbase_dialect.py
CHANGED
@@ -3,6 +3,7 @@ from sqlglot import parser, exp, Expression
 from sqlglot.dialects.mysql import MySQL
 from sqlglot.tokens import TokenType
 
+
 class ChangeColumn(Expression):
     arg_types = {
         "this": True,
@@ -14,12 +15,13 @@ class ChangeColumn(Expression):
     def origin_col_name(self) -> str:
         origin_col_name = self.args.get("origin_col_name")
         return origin_col_name
-
+
     @property
     def dtype(self) -> Expression:
         dtype = self.args.get("dtype")
         return dtype
 
+
 class OceanBase(MySQL):
     class Parser(MySQL.Parser):
         ALTER_PARSERS = {
@@ -27,7 +29,7 @@ class OceanBase(MySQL):
             "MODIFY": lambda self: self._parse_alter_table_alter(),
             "CHANGE": lambda self: self._parse_change_table_column(),
         }
-
+
         def _parse_alter_table_alter(self) -> t.Optional[exp.Expression]:
             if self._match_texts(self.ALTER_ALTER_PARSERS):
                 return self.ALTER_ALTER_PARSERS[self._prev.text.upper()](self)
@@ -38,9 +40,13 @@ class OceanBase(MySQL):
             if self._match_pair(TokenType.DROP, TokenType.DEFAULT):
                 return self.expression(exp.AlterColumn, this=column, drop=True)
             if self._match_pair(TokenType.SET, TokenType.DEFAULT):
-                return self.expression(exp.AlterColumn, this=column, default=self._parse_assignment())
+                return self.expression(
+                    exp.AlterColumn, this=column, default=self._parse_assignment()
+                )
             if self._match(TokenType.COMMENT):
-                return self.expression(exp.AlterColumn, this=column, comment=self._parse_string())
+                return self.expression(
+                    exp.AlterColumn, this=column, comment=self._parse_string()
+                )
             if self._match_text_seq("DROP", "NOT", "NULL"):
                 return self.expression(
                     exp.AlterColumn,
@@ -63,7 +69,7 @@ class OceanBase(MySQL):
                 collate=self._match(TokenType.COLLATE) and self._parse_term(),
                 using=self._match(TokenType.USING) and self._parse_assignment(),
             )
-
+
         def _parse_drop(self, exists: bool = False) -> t.Union[exp.Drop, exp.Command]:
             temporary = self._match(TokenType.TEMPORARY)
             materialized = self._match_text_seq("MATERIALIZED")
@@ -79,7 +85,8 @@ class OceanBase(MySQL):
                 this = self._parse_column()
             else:
                 this = self._parse_table_parts(
-                    schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA
+                    schema=True,
+                    is_db_reference=self._prev.token_type == TokenType.SCHEMA,
                 )
 
             cluster = self._parse_on_property() if self._match(TokenType.ON) else None
@@ -103,7 +110,7 @@ class OceanBase(MySQL):
                 cluster=cluster,
                 concurrently=concurrently,
             )
-
+
         def _parse_change_table_column(self) -> t.Optional[exp.Expression]:
             self._match(TokenType.COLUMN)
             origin_col = self._parse_field(any_token=True)
@@ -113,4 +120,4 @@ class OceanBase(MySQL):
                 this=column,
                 origin_col_name=origin_col,
                 dtype=self._parse_types(),
-            )
+            )
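This dialect extends sqlglot's MySQL parser so that ALTER TABLE ... CHANGE is captured as a first-class ChangeColumn expression rather than a generic command. A minimal sketch, assuming the dialect class can be passed directly as sqlglot's read argument (sqlglot accepts Dialect subclasses there):

import sqlglot
from pyobvector.json_table.oceanbase_dialect import ChangeColumn, OceanBase

tree = sqlglot.parse_one(
    "ALTER TABLE t1 CHANGE COLUMN old_name new_name VARCHAR(32)",
    read=OceanBase,
)

# _parse_change_table_column above records the original column name and the
# parsed target type on the ChangeColumn node.
change = tree.find(ChangeColumn)
print(change.origin_col_name, change.dtype)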
pyobvector/json_table/virtual_data_type.py
CHANGED
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, AfterValidator, create_model
 class IntEnum(int, Enum):
     """Int type enumerate definition."""
 
+
 class JType(IntEnum):
     J_BOOL = 1
     J_TIMESTAMP = 2
@@ -17,27 +18,32 @@ class JType(IntEnum):
     J_DECIMAL = 4
     J_INT = 5
 
+
 class JsonTableDataType(BaseModel):
     type: JType
 
+
 class JsonTableBool(JsonTableDataType):
     type: JType = Field(default=JType.J_BOOL)
     val: Optional[bool]
 
+
 class JsonTableTimestamp(JsonTableDataType):
     type: JType = Field(default=JType.J_TIMESTAMP)
     val: Optional[datetime]
 
+
 def check_varchar_len_with_length(length: int):
     def check_varchar_len(x: Optional[str]):
         if x is None:
             return None
         if len(x) > length:
-            raise ValueError(f'{x} is longer than {length}')
+            raise ValueError(f"{x} is longer than {length}")
         return x
-
+
     return check_varchar_len
 
+
 class JsonTableVarcharFactory:
     def __init__(self, length: int):
         self.length = length
@@ -45,14 +51,17 @@ class JsonTableVarcharFactory:
     def get_json_table_varchar_type(self):
         model_name = f"JsonTableVarchar{self.length}"
         fields = {
-            "type": (JType, JType.J_VARCHAR),
-            "val": (Annotated[Optional[str], AfterValidator(check_varchar_len_with_length(self.length))], ...),
+            "type": (JType, JType.J_VARCHAR),
+            "val": (
+                Annotated[
+                    Optional[str],
+                    AfterValidator(check_varchar_len_with_length(self.length)),
+                ],
+                ...,
+            ),
         }
-        return create_model(
-            model_name,
-            __base__=JsonTableDataType,
-            **fields
-        )
+        return create_model(model_name, __base__=JsonTableDataType, **fields)
+
 
 def check_and_parse_decimal(x: int, y: int):
     def check_float(v):
@@ -62,47 +71,57 @@ def check_and_parse_decimal(x: int, y: int):
             decimal_value = Decimal(v)
         except InvalidOperation:
             raise ValueError(f"Value {v} cannot be converted to Decimal.")
-
+
         decimal_str = str(decimal_value).strip()
-
-        if '.' in decimal_str:
-            integer_part, decimal_part = decimal_str.split('.')
+
+        if "." in decimal_str:
+            integer_part, decimal_part = decimal_str.split(".")
         else:
-            integer_part, decimal_part = decimal_str, ''
-
-        integer_count = len(integer_part.lstrip('-'))
+            integer_part, decimal_part = decimal_str, ""
+
+        integer_count = len(integer_part.lstrip("-"))  # length without the minus sign
         decimal_count = len(decimal_part)
 
         if integer_count + min(decimal_count, y) > x:
             raise ValueError(f"'{v}' Range out of Decimal({x}, {y})")
-
+
         if decimal_count > y:
-            quantize_str = '1.' + '0' * y
-            decimal_value = decimal_value.quantize(Decimal(quantize_str), rounding=ROUND_DOWN)
+            quantize_str = "1." + "0" * y
+            decimal_value = decimal_value.quantize(
+                Decimal(quantize_str), rounding=ROUND_DOWN
+            )
         return decimal_value
+
     return check_float
 
+
 class JsonTableDecimalFactory:
     def __init__(self, ndigits: int, decimal_p: int):
         self.ndigits = ndigits
         self.decimal_p = decimal_p
-
+
     def get_json_table_decimal_type(self):
         model_name = f"JsonTableDecimal_{self.ndigits}_{self.decimal_p}"
         fields = {
-            "type": (JType, JType.J_DECIMAL),
-            "val": (Annotated[Optional[float], AfterValidator(check_and_parse_decimal(self.ndigits, self.decimal_p))], ...),
+            "type": (JType, JType.J_DECIMAL),
+            "val": (
+                Annotated[
+                    Optional[float],
+                    AfterValidator(
+                        check_and_parse_decimal(self.ndigits, self.decimal_p)
+                    ),
+                ],
+                ...,
+            ),
         }
-        return create_model(
-            model_name,
-            __base__=JsonTableDataType,
-            **fields
-        )
+        return create_model(model_name, __base__=JsonTableDataType, **fields)
+
 
 class JsonTableInt(JsonTableDataType):
     type: JType = Field(default=JType.J_INT)
     val: Optional[int]
 
+
 def val2json(val):
     if val is None:
         return None
@@ -111,4 +130,4 @@ def val2json(val):
     if isinstance(val, datetime):
         return val.isoformat()
     if isinstance(val, Decimal):
-        return float(val)
+        return float(val)
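The factories above build pydantic models on the fly, attaching the length and precision validators through Annotated[..., AfterValidator(...)]. A short sketch of the behavior these hunks preserve:

from pydantic import ValidationError
from pyobvector.json_table.virtual_data_type import JsonTableVarcharFactory

# get_json_table_varchar_type() creates a model whose `val` field is checked
# by check_varchar_len_with_length(16) via AfterValidator.
Varchar16 = JsonTableVarcharFactory(16).get_json_table_varchar_type()

Varchar16(val="short enough")  # passes the length check
try:
    Varchar16(val="x" * 32)  # longer than 16
except ValidationError as exc:
    print(exc)  # wraps ValueError: "...is longer than 16"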
pyobvector/schema/__init__.py
CHANGED
@@ -20,13 +20,19 @@
 * CreateFtsIndex Full Text Search Index Creation statement clause
 * MatchAgainst Full Text Search clause
 """
+
 from .array import ARRAY
 from .vector import VECTOR
 from .sparse_vector import SPARSE_VECTOR
 from .geo_srid_point import POINT
 from .vector_index import VectorIndex, CreateVectorIndex
 from .ob_table import ObTable
-from .vec_dist_func import l2_distance, cosine_distance, inner_product, negative_inner_product
+from .vec_dist_func import (
+    l2_distance,
+    cosine_distance,
+    inner_product,
+    negative_inner_product,
+)
 from .gis_func import ST_GeomFromText, st_distance, st_dwithin, st_astext
 from .replace_stmt import ReplaceStmt
 from .dialect import OceanBaseDialect, AsyncOceanBaseDialect
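The distance helpers re-exported here are SQLAlchemy function elements. A hedged sketch of a typical call site; the VECTOR(3) column and the string form of the probe vector are assumptions not shown in this diff:

from sqlalchemy import Column, Integer, MetaData, Table, select
from pyobvector.schema import VECTOR, l2_distance

meta = MetaData()
items = Table(
    "items",
    meta,
    Column("id", Integer, primary_key=True),
    Column("embedding", VECTOR(3)),
)

# Hypothetical nearest-neighbor style query ordered by L2 distance.
stmt = (
    select(items.c.id)
    .order_by(l2_distance(items.c.embedding, "[0.1,0.2,0.3]"))
    .limit(5)
)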
pyobvector/schema/array.py
CHANGED
@@ -1,4 +1,5 @@
 """ARRAY: An extended data type for SQLAlchemy"""
+
 import json
 from typing import Any, Optional, Union
 from collections.abc import Sequence
@@ -9,6 +10,7 @@ from sqlalchemy.types import UserDefinedType, String
 
 class ARRAY(UserDefinedType):
     """ARRAY data type definition with support for up to 6 levels of nesting."""
+
     cache_ok = True
     _string = String()
 
@@ -32,7 +34,7 @@ class ARRAY(UserDefinedType):
 
     def get_col_spec(self, **kw):  # pylint: disable=unused-argument
         """Parse to array data type definition in text SQL."""
-        if hasattr(self.item_type, 'get_col_spec'):
+        if hasattr(self.item_type, "get_col_spec"):
             base_type = self.item_type.get_col_spec(**kw)
         else:
             base_type = str(self.item_type)
@@ -50,7 +52,9 @@ class ARRAY(UserDefinedType):
 
     def _validate_dimension(self, value: list[Any]):
         arr_depth = self._get_list_depth(value)
-        assert arr_depth == self.dim, f"Array dimension mismatch, expected {self.dim}, got {arr_depth}"
+        assert arr_depth == self.dim, (
+            f"Array dimension mismatch, expected {self.dim}, got {arr_depth}"
+        )
 
     def bind_processor(self, dialect):
         item_type = self.item_type