pyobvector 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. pyobvector/__init__.py +6 -5
  2. pyobvector/client/__init__.py +5 -4
  3. pyobvector/client/collection_schema.py +5 -1
  4. pyobvector/client/enum.py +1 -1
  5. pyobvector/client/exceptions.py +9 -7
  6. pyobvector/client/fts_index_param.py +8 -4
  7. pyobvector/client/hybrid_search.py +10 -4
  8. pyobvector/client/index_param.py +56 -41
  9. pyobvector/client/milvus_like_client.py +71 -54
  10. pyobvector/client/ob_client.py +20 -16
  11. pyobvector/client/ob_vec_client.py +47 -40
  12. pyobvector/client/ob_vec_json_table_client.py +366 -274
  13. pyobvector/client/partitions.py +81 -39
  14. pyobvector/client/schema_type.py +3 -1
  15. pyobvector/json_table/__init__.py +4 -3
  16. pyobvector/json_table/json_value_returning_func.py +12 -10
  17. pyobvector/json_table/oceanbase_dialect.py +15 -8
  18. pyobvector/json_table/virtual_data_type.py +47 -28
  19. pyobvector/schema/__init__.py +7 -1
  20. pyobvector/schema/array.py +6 -2
  21. pyobvector/schema/dialect.py +4 -0
  22. pyobvector/schema/full_text_index.py +8 -3
  23. pyobvector/schema/geo_srid_point.py +5 -2
  24. pyobvector/schema/gis_func.py +23 -11
  25. pyobvector/schema/match_against_func.py +10 -5
  26. pyobvector/schema/ob_table.py +2 -0
  27. pyobvector/schema/reflection.py +25 -8
  28. pyobvector/schema/replace_stmt.py +4 -0
  29. pyobvector/schema/sparse_vector.py +7 -4
  30. pyobvector/schema/vec_dist_func.py +22 -9
  31. pyobvector/schema/vector.py +3 -1
  32. pyobvector/schema/vector_index.py +7 -3
  33. pyobvector/util/__init__.py +1 -0
  34. pyobvector/util/ob_version.py +2 -0
  35. pyobvector/util/sparse_vector.py +9 -6
  36. pyobvector/util/vector.py +2 -0
  37. {pyobvector-0.2.22.dist-info → pyobvector-0.2.24.dist-info}/METADATA +13 -14
  38. pyobvector-0.2.24.dist-info/RECORD +40 -0
  39. {pyobvector-0.2.22.dist-info → pyobvector-0.2.24.dist-info}/licenses/LICENSE +1 -1
  40. pyobvector-0.2.22.dist-info/RECORD +0 -40
  41. {pyobvector-0.2.22.dist-info → pyobvector-0.2.24.dist-info}/WHEEL +0 -0
@@ -1,4 +1,5 @@
1
1
  """A module to do compilation of OceanBase Parition Clause."""
2
+
2
3
  from typing import Optional, Union
3
4
  import logging
4
5
  from dataclasses import dataclass
@@ -11,6 +12,7 @@ logger.setLevel(logging.DEBUG)
11
12
 
12
13
  class PartType(IntEnum):
13
14
  """Partition type of table or collection for both ObVecClient and MilvusLikeClient"""
15
+
14
16
  Range = 0
15
17
  Hash = 1
16
18
  Key = 2
@@ -21,12 +23,13 @@ class PartType(IntEnum):
21
23
 
22
24
  class ObPartition:
23
25
  """Base class of all kind of Partition strategy
24
-
26
+
25
27
  Attributes:
26
28
  part_type (PartType) : type of partition strategy
27
29
  sub_partition (ObPartition) : subpartition strategy
28
30
  is_sub (bool) : this partition strategy is a subpartition or not
29
31
  """
32
+
30
33
  def __init__(self, part_type: PartType):
31
34
  self.part_type = part_type
32
35
  self.sub_partition = None
@@ -38,7 +41,7 @@ class ObPartition:
38
41
 
39
42
  def add_subpartition(self, sub_part):
40
43
  """Add subpartition strategy to current partition.
41
-
44
+
42
45
  Args:
43
46
  sub_part (ObPartition) : subpartition strategy
44
47
  """
@@ -60,14 +63,15 @@ class ObPartition:
60
63
  @dataclass
61
64
  class RangeListPartInfo:
62
65
  """Range/RangeColumns/List/ListColumns partition info for each partition.
63
-
66
+
64
67
  Attributes:
65
68
  part_name (string) : partition name
66
- part_upper_bound_expr (Union[List, str, int]) :
67
- For example, using `[1,2]`/`'DEFAULT'` as default case/`7` when create
69
+ part_upper_bound_expr (Union[List, str, int]) :
70
+ For example, using `[1,2]`/`'DEFAULT'` as default case/`7` when create
68
71
  List/ListColumns partition.
69
72
  Using 100 / `MAXVALUE` when create Range/RangeColumns partition.
70
73
  """
74
+
71
75
  part_name: str
72
76
  part_upper_bound_expr: Union[list, str, int]
73
77
 
@@ -84,6 +88,7 @@ class RangeListPartInfo:
84
88
 
85
89
  class ObRangePartition(ObPartition):
86
90
  """Range/RangeColumns partition strategy."""
91
+
87
92
  def __init__(
88
93
  self,
89
94
  is_range_columns: bool,
@@ -117,18 +122,24 @@ class ObRangePartition(ObPartition):
117
122
  assert self.range_expr is not None
118
123
  if self.sub_partition is None:
119
124
  return f"RANGE ({self.range_expr}) ({self._parse_range_part_list()})"
120
- return f"RANGE ({self.range_expr}) {self.sub_partition.do_compile()} " \
121
- f"({self._parse_range_part_list()})"
125
+ return (
126
+ f"RANGE ({self.range_expr}) {self.sub_partition.do_compile()} "
127
+ f"({self._parse_range_part_list()})"
128
+ )
122
129
  assert self.col_name_list is not None
123
130
  if self.sub_partition is None:
124
- return f"RANGE COLUMNS ({','.join(self.col_name_list)}) " \
125
- f"({self._parse_range_part_list()})"
126
- return f"RANGE COLUMNS ({','.join(self.col_name_list)}) " \
127
- f"{self.sub_partition.do_compile()} ({self._parse_range_part_list()})"
131
+ return (
132
+ f"RANGE COLUMNS ({','.join(self.col_name_list)}) "
133
+ f"({self._parse_range_part_list()})"
134
+ )
135
+ return (
136
+ f"RANGE COLUMNS ({','.join(self.col_name_list)}) "
137
+ f"{self.sub_partition.do_compile()} ({self._parse_range_part_list()})"
138
+ )
128
139
 
129
140
  def _parse_range_part_list(self) -> str:
130
141
  range_partitions_complied = [
131
- f"PARTITION {range_part_info.part_name} VALUES LESS THAN " \
142
+ f"PARTITION {range_part_info.part_name} VALUES LESS THAN "
132
143
  f"({range_part_info.get_part_expr_str()})"
133
144
  for range_part_info in self.range_part_infos
134
145
  ]
@@ -137,6 +148,7 @@ class ObRangePartition(ObPartition):
137
148
 
138
149
  class ObSubRangePartition(ObRangePartition):
139
150
  """Range/RangeColumns subpartition strategy."""
151
+
140
152
  def __init__(
141
153
  self,
142
154
  is_range_columns: bool,
@@ -155,16 +167,20 @@ class ObSubRangePartition(ObRangePartition):
155
167
  if self.part_type == PartType.Range:
156
168
  assert self.range_expr is not None
157
169
  assert self.sub_partition is None
158
- return f"RANGE ({self.range_expr}) SUBPARTITION TEMPLATE " \
159
- f"({self._parse_range_part_list()})"
170
+ return (
171
+ f"RANGE ({self.range_expr}) SUBPARTITION TEMPLATE "
172
+ f"({self._parse_range_part_list()})"
173
+ )
160
174
  assert self.col_name_list is not None
161
175
  assert self.sub_partition is None
162
- return f"RANGE COLUMNS ({','.join(self.col_name_list)}) SUBPARTITION TEMPLATE " \
163
- f"({self._parse_range_part_list()})"
176
+ return (
177
+ f"RANGE COLUMNS ({','.join(self.col_name_list)}) SUBPARTITION TEMPLATE "
178
+ f"({self._parse_range_part_list()})"
179
+ )
164
180
 
165
181
  def _parse_range_part_list(self) -> str:
166
182
  range_partitions_complied = [
167
- f"SUBPARTITION {range_part_info.part_name} VALUES LESS THAN " \
183
+ f"SUBPARTITION {range_part_info.part_name} VALUES LESS THAN "
168
184
  f"({range_part_info.get_part_expr_str()})"
169
185
  for range_part_info in self.range_part_infos
170
186
  ]
@@ -173,6 +189,7 @@ class ObSubRangePartition(ObRangePartition):
173
189
 
174
190
  class ObListPartition(ObPartition):
175
191
  """List/ListColumns partition strategy."""
192
+
176
193
  def __init__(
177
194
  self,
178
195
  is_list_columns: bool,
@@ -206,14 +223,20 @@ class ObListPartition(ObPartition):
206
223
  assert self.list_expr is not None
207
224
  if self.sub_partition is None:
208
225
  return f"LIST ({self.list_expr}) ({self._parse_list_part_list()})"
209
- return f"LIST ({self.list_expr}) {self.sub_partition.do_compile()} " \
210
- f"({self._parse_list_part_list()})"
226
+ return (
227
+ f"LIST ({self.list_expr}) {self.sub_partition.do_compile()} "
228
+ f"({self._parse_list_part_list()})"
229
+ )
211
230
  assert self.col_name_list is not None
212
231
  if self.sub_partition is None:
213
- return f"LIST COLUMNS ({','.join(self.col_name_list)}) " \
214
- f"({self._parse_list_part_list()})"
215
- return f"LIST COLUMNS ({','.join(self.col_name_list)}) " \
216
- f"{self.sub_partition.do_compile()} ({self._parse_list_part_list()})"
232
+ return (
233
+ f"LIST COLUMNS ({','.join(self.col_name_list)}) "
234
+ f"({self._parse_list_part_list()})"
235
+ )
236
+ return (
237
+ f"LIST COLUMNS ({','.join(self.col_name_list)}) "
238
+ f"{self.sub_partition.do_compile()} ({self._parse_list_part_list()})"
239
+ )
217
240
 
218
241
  def _parse_list_part_list(self) -> str:
219
242
  list_partitions_complied = [
@@ -225,6 +248,7 @@ class ObListPartition(ObPartition):
225
248
 
226
249
  class ObSubListPartition(ObListPartition):
227
250
  """List/ListColumns subpartition strategy."""
251
+
228
252
  def __init__(
229
253
  self,
230
254
  is_list_columns: bool,
@@ -246,12 +270,14 @@ class ObSubListPartition(ObListPartition):
246
270
  return f"LIST ({self.list_expr}) SUBPARTITION TEMPLATE ({self._parse_list_part_list()})"
247
271
  assert self.col_name_list is not None
248
272
  assert self.sub_partition is None
249
- return f"LIST COLUMNS ({','.join(self.col_name_list)}) SUBPARTITION TEMPLATE " \
250
- f"({self._parse_list_part_list()})"
273
+ return (
274
+ f"LIST COLUMNS ({','.join(self.col_name_list)}) SUBPARTITION TEMPLATE "
275
+ f"({self._parse_list_part_list()})"
276
+ )
251
277
 
252
278
  def _parse_list_part_list(self) -> str:
253
279
  list_partitions_complied = [
254
- f"SUBPARTITION {list_part_info.part_name} VALUES IN " \
280
+ f"SUBPARTITION {list_part_info.part_name} VALUES IN "
255
281
  f"({list_part_info.get_part_expr_str()})"
256
282
  for list_part_info in self.list_part_infos
257
283
  ]
@@ -260,6 +286,7 @@ class ObSubListPartition(ObListPartition):
260
286
 
261
287
  class ObHashPartition(ObPartition):
262
288
  """Hash partition strategy."""
289
+
263
290
  def __init__(
264
291
  self,
265
292
  hash_expr: str,
@@ -279,7 +306,7 @@ class ObHashPartition(ObPartition):
279
306
 
280
307
  if self.part_count is not None and self.hash_part_name_list is not None:
281
308
  logging.warning(
282
- "part_count & hash_part_name_list are both set, " \
309
+ "part_count & hash_part_name_list are both set, "
283
310
  "hash_part_name_list will be override by part_count"
284
311
  )
285
312
 
@@ -291,13 +318,17 @@ class ObHashPartition(ObPartition):
291
318
  if self.part_count is not None:
292
319
  if self.sub_partition is None:
293
320
  return f"HASH ({self.hash_expr}) PARTITIONS {self.part_count}"
294
- return f"HASH ({self.hash_expr}) {self.sub_partition.do_compile()} " \
295
- f"PARTITIONS {self.part_count}"
321
+ return (
322
+ f"HASH ({self.hash_expr}) {self.sub_partition.do_compile()} "
323
+ f"PARTITIONS {self.part_count}"
324
+ )
296
325
  assert self.hash_part_name_list is not None
297
326
  if self.sub_partition is None:
298
327
  return f"HASH ({self.hash_expr}) ({self._parse_hash_part_list()})"
299
- return f"HASH ({self.hash_expr}) {self.sub_partition.do_compile()} " \
300
- f"({self._parse_hash_part_list()})"
328
+ return (
329
+ f"HASH ({self.hash_expr}) {self.sub_partition.do_compile()} "
330
+ f"({self._parse_hash_part_list()})"
331
+ )
301
332
 
302
333
  def _parse_hash_part_list(self):
303
334
  return ",".join([f"PARTITION {name}" for name in self.hash_part_name_list])
@@ -305,6 +336,7 @@ class ObHashPartition(ObPartition):
305
336
 
306
337
  class ObSubHashPartition(ObHashPartition):
307
338
  """Hash subpartition strategy."""
339
+
308
340
  def __init__(
309
341
  self,
310
342
  hash_expr: str,
@@ -332,6 +364,7 @@ class ObSubHashPartition(ObHashPartition):
332
364
 
333
365
  class ObKeyPartition(ObPartition):
334
366
  """Key partition strategy."""
367
+
335
368
  def __init__(
336
369
  self,
337
370
  col_name_list: list[str],
@@ -351,7 +384,7 @@ class ObKeyPartition(ObPartition):
351
384
 
352
385
  if self.part_count is not None and self.key_part_name_list is not None:
353
386
  logging.warning(
354
- "part_count & key_part_name_list are both set, " \
387
+ "part_count & key_part_name_list are both set, "
355
388
  "key_part_name_list will be override by part_count"
356
389
  )
357
390
 
@@ -365,13 +398,19 @@ class ObKeyPartition(ObPartition):
365
398
  return (
366
399
  f"KEY ({','.join(self.col_name_list)}) PARTITIONS {self.part_count}"
367
400
  )
368
- return f"KEY ({','.join(self.col_name_list)}) {self.sub_partition.do_compile()} " \
369
- f"PARTITIONS {self.part_count}"
401
+ return (
402
+ f"KEY ({','.join(self.col_name_list)}) {self.sub_partition.do_compile()} "
403
+ f"PARTITIONS {self.part_count}"
404
+ )
370
405
  assert self.key_part_name_list is not None
371
406
  if self.sub_partition is None:
372
- return f"KEY ({','.join(self.col_name_list)}) ({self._parse_key_part_list()})"
373
- return f"KEY ({','.join(self.col_name_list)}) {self.sub_partition.do_compile()} " \
374
- f"({self._parse_key_part_list()})"
407
+ return (
408
+ f"KEY ({','.join(self.col_name_list)}) ({self._parse_key_part_list()})"
409
+ )
410
+ return (
411
+ f"KEY ({','.join(self.col_name_list)}) {self.sub_partition.do_compile()} "
412
+ f"({self._parse_key_part_list()})"
413
+ )
375
414
 
376
415
  def _parse_key_part_list(self):
377
416
  return ",".join([f"PARTITION {name}" for name in self.key_part_name_list])
@@ -379,6 +418,7 @@ class ObKeyPartition(ObPartition):
379
418
 
380
419
  class ObSubKeyPartition(ObKeyPartition):
381
420
  """Key subpartition strategy."""
421
+
382
422
  def __init__(
383
423
  self,
384
424
  col_name_list: list[str],
@@ -400,8 +440,10 @@ class ObSubKeyPartition(ObKeyPartition):
400
440
  )
401
441
  assert self.key_part_name_list is not None
402
442
  assert self.sub_partition is None
403
- return f"KEY ({','.join(self.col_name_list)}) SUBPARTITION TEMPLATE " \
404
- f"({self._parse_key_part_list()})"
443
+ return (
444
+ f"KEY ({','.join(self.col_name_list)}) SUBPARTITION TEMPLATE "
445
+ f"({self._parse_key_part_list()})"
446
+ )
405
447
 
406
448
  def _parse_key_part_list(self):
407
449
  return ",".join([f"SUBPARTITION {name}" for name in self.key_part_name_list])
@@ -1,4 +1,5 @@
1
1
  """Data type module that compatible with Milvus."""
2
+
2
3
  from sqlalchemy import (
3
4
  Boolean,
4
5
  SmallInteger,
@@ -16,6 +17,7 @@ from ..schema import ARRAY, SPARSE_VECTOR, VECTOR
16
17
 
17
18
  class DataType(IntEnum):
18
19
  """Data type definition that compatible with Milvus."""
20
+
19
21
  # NONE = 0
20
22
  BOOL = 1
21
23
  INT8 = 2
@@ -40,7 +42,7 @@ class DataType(IntEnum):
40
42
 
41
43
  def convert_datatype_to_sqltype(datatype: DataType):
42
44
  """Convert Milvus data type to SQL type.
43
-
45
+
44
46
  Args:
45
47
  datatype (DataType) : Milvus data type.
46
48
  """
@@ -12,7 +12,8 @@ from .virtual_data_type import (
12
12
  from .json_value_returning_func import json_value
13
13
 
14
14
  __all__ = [
15
- "OceanBase", "ChangeColumn",
15
+ "OceanBase",
16
+ "ChangeColumn",
16
17
  "JType",
17
18
  "JsonTableDataType",
18
19
  "JsonTableBool",
@@ -21,5 +22,5 @@ __all__ = [
21
22
  "JsonTableDecimalFactory",
22
23
  "JsonTableInt",
23
24
  "val2json",
24
- "json_value"
25
- ]
25
+ "json_value",
26
+ ]
@@ -7,6 +7,7 @@ from sqlalchemy import Text
7
7
 
8
8
  logger = logging.getLogger(__name__)
9
9
 
10
+
10
11
  class json_value(FunctionElement):
11
12
  type = Text()
12
13
  inherit_cache = True
@@ -15,6 +16,7 @@ class json_value(FunctionElement):
15
16
  super().__init__()
16
17
  self.args = args
17
18
 
19
+
18
20
  @compiles(json_value)
19
21
  def compile_json_value(element, compiler, **kwargs):
20
22
  args = []
@@ -23,25 +25,25 @@ def compile_json_value(element, compiler, **kwargs):
23
25
  args.append(compiler.process(element.args[0]))
24
26
  if not (isinstance(element.args[1], str) and isinstance(element.args[2], str)):
25
27
  raise ValueError("Invalid args for json_value")
26
-
27
- if element.args[2].startswith('TINYINT'):
28
+
29
+ if element.args[2].startswith("TINYINT"):
28
30
  returning_type = "SIGNED"
29
- elif element.args[2].startswith('TIMESTAMP'):
31
+ elif element.args[2].startswith("TIMESTAMP"):
30
32
  returning_type = "DATETIME"
31
- elif element.args[2].startswith('INT'):
33
+ elif element.args[2].startswith("INT"):
32
34
  returning_type = "SIGNED"
33
- elif element.args[2].startswith('VARCHAR'):
34
- if element.args[2] == 'VARCHAR':
35
+ elif element.args[2].startswith("VARCHAR"):
36
+ if element.args[2] == "VARCHAR":
35
37
  returning_type = "CHAR(255)"
36
38
  else:
37
- varchar_pattern = r'VARCHAR\((\d+)\)'
39
+ varchar_pattern = r"VARCHAR\((\d+)\)"
38
40
  varchar_matches = re.findall(varchar_pattern, element.args[2])
39
41
  returning_type = f"CHAR({int(varchar_matches[0])})"
40
- elif element.args[2].startswith('DECIMAL'):
41
- if element.args[2] == 'DECIMAL':
42
+ elif element.args[2].startswith("DECIMAL"):
43
+ if element.args[2] == "DECIMAL":
42
44
  returning_type = "DECIMAL(10, 0)"
43
45
  else:
44
- decimal_pattern = r'DECIMAL\((\d+),\s*(\d+)\)'
46
+ decimal_pattern = r"DECIMAL\((\d+),\s*(\d+)\)"
45
47
  decimal_matches = re.findall(decimal_pattern, element.args[2])
46
48
  x, y = decimal_matches[0]
47
49
  returning_type = f"DECIMAL({x}, {y})"
@@ -3,6 +3,7 @@ from sqlglot import parser, exp, Expression
3
3
  from sqlglot.dialects.mysql import MySQL
4
4
  from sqlglot.tokens import TokenType
5
5
 
6
+
6
7
  class ChangeColumn(Expression):
7
8
  arg_types = {
8
9
  "this": True,
@@ -14,12 +15,13 @@ class ChangeColumn(Expression):
14
15
  def origin_col_name(self) -> str:
15
16
  origin_col_name = self.args.get("origin_col_name")
16
17
  return origin_col_name
17
-
18
+
18
19
  @property
19
20
  def dtype(self) -> Expression:
20
21
  dtype = self.args.get("dtype")
21
22
  return dtype
22
23
 
24
+
23
25
  class OceanBase(MySQL):
24
26
  class Parser(MySQL.Parser):
25
27
  ALTER_PARSERS = {
@@ -27,7 +29,7 @@ class OceanBase(MySQL):
27
29
  "MODIFY": lambda self: self._parse_alter_table_alter(),
28
30
  "CHANGE": lambda self: self._parse_change_table_column(),
29
31
  }
30
-
32
+
31
33
  def _parse_alter_table_alter(self) -> t.Optional[exp.Expression]:
32
34
  if self._match_texts(self.ALTER_ALTER_PARSERS):
33
35
  return self.ALTER_ALTER_PARSERS[self._prev.text.upper()](self)
@@ -38,9 +40,13 @@ class OceanBase(MySQL):
38
40
  if self._match_pair(TokenType.DROP, TokenType.DEFAULT):
39
41
  return self.expression(exp.AlterColumn, this=column, drop=True)
40
42
  if self._match_pair(TokenType.SET, TokenType.DEFAULT):
41
- return self.expression(exp.AlterColumn, this=column, default=self._parse_assignment())
43
+ return self.expression(
44
+ exp.AlterColumn, this=column, default=self._parse_assignment()
45
+ )
42
46
  if self._match(TokenType.COMMENT):
43
- return self.expression(exp.AlterColumn, this=column, comment=self._parse_string())
47
+ return self.expression(
48
+ exp.AlterColumn, this=column, comment=self._parse_string()
49
+ )
44
50
  if self._match_text_seq("DROP", "NOT", "NULL"):
45
51
  return self.expression(
46
52
  exp.AlterColumn,
@@ -63,7 +69,7 @@ class OceanBase(MySQL):
63
69
  collate=self._match(TokenType.COLLATE) and self._parse_term(),
64
70
  using=self._match(TokenType.USING) and self._parse_assignment(),
65
71
  )
66
-
72
+
67
73
  def _parse_drop(self, exists: bool = False) -> t.Union[exp.Drop, exp.Command]:
68
74
  temporary = self._match(TokenType.TEMPORARY)
69
75
  materialized = self._match_text_seq("MATERIALIZED")
@@ -79,7 +85,8 @@ class OceanBase(MySQL):
79
85
  this = self._parse_column()
80
86
  else:
81
87
  this = self._parse_table_parts(
82
- schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA
88
+ schema=True,
89
+ is_db_reference=self._prev.token_type == TokenType.SCHEMA,
83
90
  )
84
91
 
85
92
  cluster = self._parse_on_property() if self._match(TokenType.ON) else None
@@ -103,7 +110,7 @@ class OceanBase(MySQL):
103
110
  cluster=cluster,
104
111
  concurrently=concurrently,
105
112
  )
106
-
113
+
107
114
  def _parse_change_table_column(self) -> t.Optional[exp.Expression]:
108
115
  self._match(TokenType.COLUMN)
109
116
  origin_col = self._parse_field(any_token=True)
@@ -113,4 +120,4 @@ class OceanBase(MySQL):
113
120
  this=column,
114
121
  origin_col_name=origin_col,
115
122
  dtype=self._parse_types(),
116
- )
123
+ )
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, AfterValidator, create_model
10
10
  class IntEnum(int, Enum):
11
11
  """Int type enumerate definition."""
12
12
 
13
+
13
14
  class JType(IntEnum):
14
15
  J_BOOL = 1
15
16
  J_TIMESTAMP = 2
@@ -17,27 +18,32 @@ class JType(IntEnum):
17
18
  J_DECIMAL = 4
18
19
  J_INT = 5
19
20
 
21
+
20
22
  class JsonTableDataType(BaseModel):
21
23
  type: JType
22
24
 
25
+
23
26
  class JsonTableBool(JsonTableDataType):
24
27
  type: JType = Field(default=JType.J_BOOL)
25
28
  val: Optional[bool]
26
29
 
30
+
27
31
  class JsonTableTimestamp(JsonTableDataType):
28
32
  type: JType = Field(default=JType.J_TIMESTAMP)
29
33
  val: Optional[datetime]
30
34
 
35
+
31
36
  def check_varchar_len_with_length(length: int):
32
37
  def check_varchar_len(x: Optional[str]):
33
38
  if x is None:
34
39
  return None
35
40
  if len(x) > length:
36
- raise ValueError(f'{x} is longer than {length}')
41
+ raise ValueError(f"{x} is longer than {length}")
37
42
  return x
38
-
43
+
39
44
  return check_varchar_len
40
45
 
46
+
41
47
  class JsonTableVarcharFactory:
42
48
  def __init__(self, length: int):
43
49
  self.length = length
@@ -45,14 +51,17 @@ class JsonTableVarcharFactory:
45
51
  def get_json_table_varchar_type(self):
46
52
  model_name = f"JsonTableVarchar{self.length}"
47
53
  fields = {
48
- 'type': (JType, JType.J_VARCHAR),
49
- 'val': (Annotated[Optional[str], AfterValidator(check_varchar_len_with_length(self.length))], ...)
54
+ "type": (JType, JType.J_VARCHAR),
55
+ "val": (
56
+ Annotated[
57
+ Optional[str],
58
+ AfterValidator(check_varchar_len_with_length(self.length)),
59
+ ],
60
+ ...,
61
+ ),
50
62
  }
51
- return create_model(
52
- model_name,
53
- __base__=JsonTableDataType,
54
- **fields
55
- )
63
+ return create_model(model_name, __base__=JsonTableDataType, **fields)
64
+
56
65
 
57
66
  def check_and_parse_decimal(x: int, y: int):
58
67
  def check_float(v):
@@ -62,47 +71,57 @@ def check_and_parse_decimal(x: int, y: int):
62
71
  decimal_value = Decimal(v)
63
72
  except InvalidOperation:
64
73
  raise ValueError(f"Value {v} cannot be converted to Decimal.")
65
-
74
+
66
75
  decimal_str = str(decimal_value).strip()
67
-
68
- if '.' in decimal_str:
69
- integer_part, decimal_part = decimal_str.split('.')
76
+
77
+ if "." in decimal_str:
78
+ integer_part, decimal_part = decimal_str.split(".")
70
79
  else:
71
- integer_part, decimal_part = decimal_str, ''
72
-
73
- integer_count = len(integer_part.lstrip('-')) # 去掉负号的长度
80
+ integer_part, decimal_part = decimal_str, ""
81
+
82
+ integer_count = len(integer_part.lstrip("-")) # 去掉负号的长度
74
83
  decimal_count = len(decimal_part)
75
84
 
76
85
  if integer_count + min(decimal_count, y) > x:
77
86
  raise ValueError(f"'{v}' Range out of Decimal({x}, {y})")
78
-
87
+
79
88
  if decimal_count > y:
80
- quantize_str = '1.' + '0' * y
81
- decimal_value = decimal_value.quantize(Decimal(quantize_str), rounding=ROUND_DOWN)
89
+ quantize_str = "1." + "0" * y
90
+ decimal_value = decimal_value.quantize(
91
+ Decimal(quantize_str), rounding=ROUND_DOWN
92
+ )
82
93
  return decimal_value
94
+
83
95
  return check_float
84
96
 
97
+
85
98
  class JsonTableDecimalFactory:
86
99
  def __init__(self, ndigits: int, decimal_p: int):
87
100
  self.ndigits = ndigits
88
101
  self.decimal_p = decimal_p
89
-
102
+
90
103
  def get_json_table_decimal_type(self):
91
104
  model_name = f"JsonTableDecimal_{self.ndigits}_{self.decimal_p}"
92
105
  fields = {
93
- 'type': (JType, JType.J_DECIMAL),
94
- 'val': (Annotated[Optional[float], AfterValidator(check_and_parse_decimal(self.ndigits, self.decimal_p))], ...)
106
+ "type": (JType, JType.J_DECIMAL),
107
+ "val": (
108
+ Annotated[
109
+ Optional[float],
110
+ AfterValidator(
111
+ check_and_parse_decimal(self.ndigits, self.decimal_p)
112
+ ),
113
+ ],
114
+ ...,
115
+ ),
95
116
  }
96
- return create_model(
97
- model_name,
98
- __base__=JsonTableDataType,
99
- **fields
100
- )
117
+ return create_model(model_name, __base__=JsonTableDataType, **fields)
118
+
101
119
 
102
120
  class JsonTableInt(JsonTableDataType):
103
121
  type: JType = Field(default=JType.J_INT)
104
122
  val: Optional[int]
105
123
 
124
+
106
125
  def val2json(val):
107
126
  if val is None:
108
127
  return None
@@ -111,4 +130,4 @@ def val2json(val):
111
130
  if isinstance(val, datetime):
112
131
  return val.isoformat()
113
132
  if isinstance(val, Decimal):
114
- return float(val)
133
+ return float(val)
@@ -20,13 +20,19 @@
20
20
  * CreateFtsIndex Full Text Search Index Creation statement clause
21
21
  * MatchAgainst Full Text Search clause
22
22
  """
23
+
23
24
  from .array import ARRAY
24
25
  from .vector import VECTOR
25
26
  from .sparse_vector import SPARSE_VECTOR
26
27
  from .geo_srid_point import POINT
27
28
  from .vector_index import VectorIndex, CreateVectorIndex
28
29
  from .ob_table import ObTable
29
- from .vec_dist_func import l2_distance, cosine_distance, inner_product, negative_inner_product
30
+ from .vec_dist_func import (
31
+ l2_distance,
32
+ cosine_distance,
33
+ inner_product,
34
+ negative_inner_product,
35
+ )
30
36
  from .gis_func import ST_GeomFromText, st_distance, st_dwithin, st_astext
31
37
  from .replace_stmt import ReplaceStmt
32
38
  from .dialect import OceanBaseDialect, AsyncOceanBaseDialect
@@ -1,4 +1,5 @@
1
1
  """ARRAY: An extended data type for SQLAlchemy"""
2
+
2
3
  import json
3
4
  from typing import Any, Optional, Union
4
5
  from collections.abc import Sequence
@@ -9,6 +10,7 @@ from sqlalchemy.types import UserDefinedType, String
9
10
 
10
11
  class ARRAY(UserDefinedType):
11
12
  """ARRAY data type definition with support for up to 6 levels of nesting."""
13
+
12
14
  cache_ok = True
13
15
  _string = String()
14
16
 
@@ -32,7 +34,7 @@ class ARRAY(UserDefinedType):
32
34
 
33
35
  def get_col_spec(self, **kw): # pylint: disable=unused-argument
34
36
  """Parse to array data type definition in text SQL."""
35
- if hasattr(self.item_type, 'get_col_spec'):
37
+ if hasattr(self.item_type, "get_col_spec"):
36
38
  base_type = self.item_type.get_col_spec(**kw)
37
39
  else:
38
40
  base_type = str(self.item_type)
@@ -50,7 +52,9 @@ class ARRAY(UserDefinedType):
50
52
 
51
53
  def _validate_dimension(self, value: list[Any]):
52
54
  arr_depth = self._get_list_depth(value)
53
- assert arr_depth == self.dim, f"Array dimension mismatch, expected {self.dim}, got {arr_depth}"
55
+ assert arr_depth == self.dim, (
56
+ f"Array dimension mismatch, expected {self.dim}, got {arr_depth}"
57
+ )
54
58
 
55
59
  def bind_processor(self, dialect):
56
60
  item_type = self.item_type