pyobvector 0.2.22__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. pyobvector/__init__.py +6 -5
  2. pyobvector/client/__init__.py +5 -4
  3. pyobvector/client/collection_schema.py +5 -1
  4. pyobvector/client/enum.py +1 -1
  5. pyobvector/client/exceptions.py +9 -7
  6. pyobvector/client/fts_index_param.py +8 -4
  7. pyobvector/client/hybrid_search.py +10 -4
  8. pyobvector/client/index_param.py +56 -41
  9. pyobvector/client/milvus_like_client.py +71 -54
  10. pyobvector/client/ob_client.py +20 -16
  11. pyobvector/client/ob_vec_client.py +40 -39
  12. pyobvector/client/ob_vec_json_table_client.py +366 -274
  13. pyobvector/client/partitions.py +81 -39
  14. pyobvector/client/schema_type.py +3 -1
  15. pyobvector/json_table/__init__.py +4 -3
  16. pyobvector/json_table/json_value_returning_func.py +12 -10
  17. pyobvector/json_table/oceanbase_dialect.py +15 -8
  18. pyobvector/json_table/virtual_data_type.py +47 -28
  19. pyobvector/schema/__init__.py +7 -1
  20. pyobvector/schema/array.py +6 -2
  21. pyobvector/schema/dialect.py +4 -0
  22. pyobvector/schema/full_text_index.py +8 -3
  23. pyobvector/schema/geo_srid_point.py +5 -2
  24. pyobvector/schema/gis_func.py +23 -11
  25. pyobvector/schema/match_against_func.py +10 -5
  26. pyobvector/schema/ob_table.py +2 -0
  27. pyobvector/schema/reflection.py +25 -8
  28. pyobvector/schema/replace_stmt.py +4 -0
  29. pyobvector/schema/sparse_vector.py +7 -4
  30. pyobvector/schema/vec_dist_func.py +22 -9
  31. pyobvector/schema/vector.py +3 -1
  32. pyobvector/schema/vector_index.py +7 -3
  33. pyobvector/util/__init__.py +1 -0
  34. pyobvector/util/ob_version.py +2 -0
  35. pyobvector/util/sparse_vector.py +9 -6
  36. pyobvector/util/vector.py +2 -0
  37. {pyobvector-0.2.22.dist-info → pyobvector-0.2.23.dist-info}/METADATA +13 -14
  38. pyobvector-0.2.23.dist-info/RECORD +40 -0
  39. {pyobvector-0.2.22.dist-info → pyobvector-0.2.23.dist-info}/licenses/LICENSE +1 -1
  40. pyobvector-0.2.22.dist-info/RECORD +0 -40
  41. {pyobvector-0.2.22.dist-info → pyobvector-0.2.23.dist-info}/WHEEL +0 -0
@@ -1,4 +1,5 @@
1
1
  """OceanBase dialect."""
2
+
2
3
  from sqlalchemy import util
3
4
  from sqlalchemy.dialects.mysql import aiomysql, pymysql
4
5
 
@@ -7,10 +8,12 @@ from .vector import VECTOR
7
8
  from .sparse_vector import SPARSE_VECTOR
8
9
  from .geo_srid_point import POINT
9
10
 
11
+
10
12
  class OceanBaseDialect(pymysql.MySQLDialect_pymysql):
11
13
  # not change dialect name, since it is a subclass of pymysql.MySQLDialect_pymysql
12
14
  # name = "oceanbase"
13
15
  """Ocenbase dialect."""
16
+
14
17
  supports_statement_cache = True
15
18
 
16
19
  def __init__(self, **kwargs):
@@ -36,6 +39,7 @@ class OceanBaseDialect(pymysql.MySQLDialect_pymysql):
36
39
 
37
40
  class AsyncOceanBaseDialect(aiomysql.MySQLDialect_aiomysql):
38
41
  """OceanBase async dialect."""
42
+
39
43
  supports_statement_cache = True
40
44
 
41
45
  def __init__(self, **kwargs):
@@ -1,4 +1,5 @@
1
1
  """FullTextIndex: full text search index type"""
2
+
2
3
  from sqlalchemy import Index
3
4
  from sqlalchemy.schema import DDLElement
4
5
  from sqlalchemy.ext.compiler import compiles
@@ -7,16 +8,18 @@ from sqlalchemy.sql.ddl import SchemaGenerator
7
8
 
8
9
  class CreateFtsIndex(DDLElement):
9
10
  """A new statement clause to create fts index.
10
-
11
+
11
12
  Attributes:
12
13
  index : fts index schema
13
14
  """
15
+
14
16
  def __init__(self, index):
15
17
  self.index = index
16
18
 
17
19
 
18
20
  class ObFtsSchemaGenerator(SchemaGenerator):
19
21
  """A new schema generator to handle create fts index statement."""
22
+
20
23
  def visit_fts_index(self, index, create_ok=False):
21
24
  """Handle create fts index statement compiling.
22
25
 
@@ -29,8 +32,10 @@ class ObFtsSchemaGenerator(SchemaGenerator):
29
32
  with self.with_ddl_events(index):
30
33
  CreateFtsIndex(index)._invoke_with(self.connection)
31
34
 
35
+
32
36
  class FtsIndex(Index):
33
37
  """Fts Index schema."""
38
+
34
39
  __visit_name__ = "fts_index"
35
40
 
36
41
  def __init__(self, name, fts_parser: str, *column_names, **kw):
@@ -39,7 +44,7 @@ class FtsIndex(Index):
39
44
 
40
45
  def create(self, bind, checkfirst: bool = False) -> None:
41
46
  """Create fts index.
42
-
47
+
43
48
  Args:
44
49
  bind: SQL engine or connection.
45
50
  checkfirst: check the index exists or not.
@@ -48,7 +53,7 @@ class FtsIndex(Index):
48
53
 
49
54
 
50
55
  @compiles(CreateFtsIndex)
51
- def compile_create_fts_index(element, compiler, **kw): # pylint: disable=unused-argument
56
+ def compile_create_fts_index(element, compiler, **kw): # pylint: disable=unused-argument
52
57
  """A decorator function to compile create fts index statement."""
53
58
  index = element.index
54
59
  table_name = index.table.name
@@ -1,23 +1,26 @@
1
1
  """Point: OceanBase GIS data type for SQLAlchemy"""
2
+
2
3
  from typing import Optional
3
4
  from sqlalchemy.types import UserDefinedType, String
4
5
 
6
+
5
7
  class POINT(UserDefinedType):
6
8
  """Point data type definition."""
9
+
7
10
  cache_ok = True
8
11
  _string = String()
9
12
 
10
13
  def __init__(
11
14
  self,
12
15
  # lat_long: Tuple[float, float],
13
- srid: Optional[int] = None
16
+ srid: Optional[int] = None,
14
17
  ):
15
18
  """Init Latitude and Longitude."""
16
19
  super(UserDefinedType, self).__init__()
17
20
  # self.lat_long = lat_long
18
21
  self.srid = srid
19
22
 
20
- def get_col_spec(self, **kw): # pylint: disable=unused-argument
23
+ def get_col_spec(self, **kw): # pylint: disable=unused-argument
21
24
  """Parse to Point data type definition in text SQL."""
22
25
  if self.srid is None:
23
26
  return "POINT"
@@ -10,31 +10,34 @@ from .geo_srid_point import POINT
10
10
 
11
11
  logger = logging.getLogger(__name__)
12
12
 
13
+
13
14
  class ST_GeomFromText(FunctionElement):
14
15
  """ST_GeomFromText: parse text to geometry object.
15
-
16
+
16
17
  Attributes:
17
18
  type : result type
18
19
  """
20
+
19
21
  type = BINARY()
20
22
 
21
23
  def __init__(self, *args):
22
24
  super().__init__()
23
25
  self.args = args
24
26
 
27
+
25
28
  @compiles(ST_GeomFromText)
26
- def compile_ST_GeomFromText(element, compiler, **kwargs): # pylint: disable=unused-argument
29
+ def compile_ST_GeomFromText(element, compiler, **kwargs): # pylint: disable=unused-argument
27
30
  """Compile ST_GeomFromText function."""
28
31
  args = []
29
32
  for idx, arg in enumerate(element.args):
30
33
  if idx == 0:
31
34
  if (
32
- (not isinstance(arg, tuple)) or
33
- (len(arg) != 2) or
34
- (not all(isinstance(x, float) for x in arg))
35
+ (not isinstance(arg, tuple))
36
+ or (len(arg) != 2)
37
+ or (not all(isinstance(x, float) for x in arg))
35
38
  ):
36
39
  raise ValueError(
37
- f"Tuple[float, float] is expected for Point literal," \
40
+ f"Tuple[float, float] is expected for Point literal,"
38
41
  f"while get {type(arg)}"
39
42
  )
40
43
  args.append(f"'{POINT.to_db(arg)}'")
@@ -44,12 +47,14 @@ def compile_ST_GeomFromText(element, compiler, **kwargs): # pylint: disable=unus
44
47
  # logger.info(f"{args_str}")
45
48
  return f"ST_GeomFromText({args_str})"
46
49
 
50
+
47
51
  class st_distance(FunctionElement):
48
52
  """st_distance: calculate distance between Points.
49
-
53
+
50
54
  Attributes:
51
55
  type : result type
52
56
  """
57
+
53
58
  type = Float()
54
59
  inherit_cache = True
55
60
 
@@ -57,19 +62,22 @@ class st_distance(FunctionElement):
57
62
  super().__init__()
58
63
  self.args = args
59
64
 
65
+
60
66
  @compiles(st_distance)
61
- def compile_st_distance(element, compiler, **kwargs): # pylint: disable=unused-argument
67
+ def compile_st_distance(element, compiler, **kwargs): # pylint: disable=unused-argument
62
68
  """Compile st_distance function."""
63
69
  args = ", ".join(compiler.process(arg) for arg in element.args)
64
70
  return f"st_distance({args})"
65
71
 
72
+
66
73
  class st_dwithin(FunctionElement):
67
74
  """st_dwithin: Checks if the distance between two points
68
75
  is less than a specified distance.
69
-
76
+
70
77
  Attributes:
71
78
  type : result type
72
79
  """
80
+
73
81
  type = Boolean()
74
82
  inherit_cache = True
75
83
 
@@ -77,8 +85,9 @@ class st_dwithin(FunctionElement):
77
85
  super().__init__()
78
86
  self.args = args
79
87
 
88
+
80
89
  @compiles(st_dwithin)
81
- def compile_st_dwithin(element, compiler, **kwargs): # pylint: disable=unused-argument
90
+ def compile_st_dwithin(element, compiler, **kwargs): # pylint: disable=unused-argument
82
91
  """Compile st_dwithin function."""
83
92
  args = []
84
93
  for idx, arg in enumerate(element.args):
@@ -89,12 +98,14 @@ def compile_st_dwithin(element, compiler, **kwargs): # pylint: disable=unused-ar
89
98
  args_str = ", ".join(args)
90
99
  return f"_st_dwithin({args_str})"
91
100
 
101
+
92
102
  class st_astext(FunctionElement):
93
103
  """st_astext: Returns a Point in human-readable format.
94
104
 
95
105
  Attributes:
96
106
  type : result type
97
107
  """
108
+
98
109
  type = Text()
99
110
  inherit_cache = True
100
111
 
@@ -102,8 +113,9 @@ class st_astext(FunctionElement):
102
113
  super().__init__()
103
114
  self.args = args
104
115
 
116
+
105
117
  @compiles(st_astext)
106
- def compile_st_astext(element, compiler, **kwargs): # pylint: disable=unused-argument
118
+ def compile_st_astext(element, compiler, **kwargs): # pylint: disable=unused-argument
107
119
  """Compile st_astext function."""
108
120
  args = ", ".join(compiler.process(arg) for arg in element.args)
109
121
  return f"st_astext({args})"
@@ -8,31 +8,36 @@ from sqlalchemy import literal, column
8
8
 
9
9
  logger = logging.getLogger(__name__)
10
10
 
11
+
11
12
  class MatchAgainst(FunctionElement):
12
13
  """MatchAgainst: match clause for full text search.
13
14
 
14
15
  Attributes:
15
16
  type : result type
16
17
  """
18
+
17
19
  inherit_cache = True
18
20
 
19
21
  def __init__(self, query, *columns):
20
22
  columns = [column(col) if isinstance(col, str) else col for col in columns]
21
23
  super().__init__(literal(query), *columns)
22
-
24
+
25
+
23
26
  @compiles(MatchAgainst)
24
- def complie_MatchAgainst(element, compiler, **kwargs): # pylint: disable=unused-argument
27
+ def complie_MatchAgainst(element, compiler, **kwargs): # pylint: disable=unused-argument
25
28
  """Compile MatchAgainst function."""
26
29
  clauses = list(element.clauses)
27
30
  if len(clauses) < 2:
28
31
  raise ValueError(
29
- f"MatchAgainst should take a string expression and " \
32
+ f"MatchAgainst should take a string expression and "
30
33
  f"at least one column name string as parameters."
31
34
  )
32
-
35
+
33
36
  query_expr = clauses[0]
34
37
  compiled_query = compiler.process(query_expr, **kwargs)
35
38
  column_exprs = clauses[1:]
36
- compiled_columns = [compiler.process(col, identifier_prepared=True) for col in column_exprs]
39
+ compiled_columns = [
40
+ compiler.process(col, identifier_prepared=True) for col in column_exprs
41
+ ]
37
42
  columns_str = ", ".join(compiled_columns)
38
43
  return f"MATCH ({columns_str}) AGAINST ({compiled_query} IN NATURAL LANGUAGE MODE)"
@@ -1,9 +1,11 @@
1
1
  """ObTable: extension to Table for creating table with vector index."""
2
+
2
3
  from sqlalchemy import Table
3
4
  from .vector_index import ObSchemaGenerator
4
5
 
5
6
 
6
7
  class ObTable(Table):
7
8
  """A class extends SQLAlchemy Table to do table creation with vector index."""
9
+
8
10
  def create(self, bind, checkfirst: bool = False) -> None:
9
11
  bind._run_ddl_visitor(ObSchemaGenerator, self, checkfirst=checkfirst)
@@ -1,14 +1,21 @@
1
1
  """OceanBase table definition reflection."""
2
+
2
3
  import re
3
4
  import logging
4
- from sqlalchemy.dialects.mysql.reflection import MySQLTableDefinitionParser, _re_compile, cleanup_text
5
+ from sqlalchemy.dialects.mysql.reflection import (
6
+ MySQLTableDefinitionParser,
7
+ _re_compile,
8
+ cleanup_text,
9
+ )
5
10
 
6
11
  from pyobvector.schema.array import ARRAY
7
12
 
8
13
  logger = logging.getLogger(__name__)
9
14
 
15
+
10
16
  class OceanBaseTableDefinitionParser(MySQLTableDefinitionParser):
11
17
  """OceanBase table definition parser."""
18
+
12
19
  def __init__(self, dialect, preparer, *, default_schema=None):
13
20
  MySQLTableDefinitionParser.__init__(self, dialect, preparer)
14
21
  self.default_schema = default_schema
@@ -82,15 +89,19 @@ class OceanBaseTableDefinitionParser(MySQLTableDefinitionParser):
82
89
  m = self._re_array_column.match(line)
83
90
  if m:
84
91
  spec = m.groupdict()
85
- name, coltype_with_args = spec["name"].strip(), spec["coltype_with_args"].strip()
92
+ name, coltype_with_args = (
93
+ spec["name"].strip(),
94
+ spec["coltype_with_args"].strip(),
95
+ )
86
96
 
87
97
  item_pattern = re.compile(
88
- r"^(?:array\s*\()*([\w]+)(?:\(([\d,]+)\))?\)*$",
89
- re.IGNORECASE
98
+ r"^(?:array\s*\()*([\w]+)(?:\(([\d,]+)\))?\)*$", re.IGNORECASE
90
99
  )
91
100
  item_m = item_pattern.match(coltype_with_args)
92
101
  if not item_m:
93
- raise ValueError(f"Failed to find inner type from array column definition: {line}")
102
+ raise ValueError(
103
+ f"Failed to find inner type from array column definition: {line}"
104
+ )
94
105
 
95
106
  item_type = self.dialect.ischema_names[item_m.group(1).lower()]
96
107
  item_type_arg = item_m.group(2)
@@ -99,9 +110,11 @@ class OceanBaseTableDefinitionParser(MySQLTableDefinitionParser):
99
110
  elif item_type_arg[0] == "'" and item_type_arg[-1] == "'":
100
111
  item_type_args = self._re_csv_str.findall(item_type_arg)
101
112
  else:
102
- item_type_args = [int(v) for v in self._re_csv_int.findall(item_type_arg)]
113
+ item_type_args = [
114
+ int(v) for v in self._re_csv_int.findall(item_type_arg)
115
+ ]
103
116
 
104
- nested_level = coltype_with_args.lower().count('array')
117
+ nested_level = coltype_with_args.lower().count("array")
105
118
  type_instance = item_type(*item_type_args)
106
119
  for _ in range(nested_level):
107
120
  type_instance = ARRAY(type_instance)
@@ -144,7 +157,11 @@ class OceanBaseTableDefinitionParser(MySQLTableDefinitionParser):
144
157
 
145
158
  if tp == "fk_constraint":
146
159
  table = spec.get("table", [])
147
- if isinstance(table, list) and len(table) == 2 and table[0] == self.default_schema:
160
+ if (
161
+ isinstance(table, list)
162
+ and len(table) == 2
163
+ and table[0] == self.default_schema
164
+ ):
148
165
  spec["table"] = table[1:]
149
166
 
150
167
  for action in ["onupdate", "ondelete"]:
@@ -1,11 +1,15 @@
1
1
  """ReplaceStmt: replace into statement compilation."""
2
+
2
3
  from sqlalchemy.ext.compiler import compiles
3
4
  from sqlalchemy.sql.expression import Insert
4
5
 
6
+
5
7
  class ReplaceStmt(Insert):
6
8
  """Replace into statement."""
9
+
7
10
  inherit_cache = True
8
11
 
12
+
9
13
  @compiles(ReplaceStmt)
10
14
  def compile_replace_stmt(insert, compiler, **kw):
11
15
  """Compile replace into statement.
@@ -1,25 +1,28 @@
1
1
  """SPARSE_VECTOR: An extended data type for SQLAlchemy"""
2
+
2
3
  from sqlalchemy.types import UserDefinedType, String
3
4
  from ..util import SparseVector
4
5
 
6
+
5
7
  class SPARSE_VECTOR(UserDefinedType):
6
8
  """SPARSE_VECTOR data type definition."""
9
+
7
10
  cache_ok = True
8
11
  _string = String()
9
12
 
10
13
  def __init__(self):
11
14
  super(UserDefinedType, self).__init__()
12
15
 
13
- def get_col_spec(self, **kw): # pylint: disable=unused-argument
16
+ def get_col_spec(self, **kw): # pylint: disable=unused-argument
14
17
  """Parse to sparse vector data type definition in text SQL."""
15
18
  return "SPARSEVECTOR"
16
-
19
+
17
20
  def bind_processor(self, dialect):
18
21
  def process(value):
19
22
  return SparseVector._to_db(value)
20
23
 
21
24
  return process
22
-
25
+
23
26
  def literal_processor(self, dialect):
24
27
  string_literal_processor = self._string._cached_literal_processor(dialect)
25
28
 
@@ -32,4 +35,4 @@ class SPARSE_VECTOR(UserDefinedType):
32
35
  def process(value):
33
36
  return SparseVector._from_db(value)
34
37
 
35
- return process
38
+ return process
@@ -11,6 +11,7 @@ from sqlalchemy import Float
11
11
 
12
12
  logger = logging.getLogger(__name__)
13
13
 
14
+
14
15
  def parse_vec_distance_func_args(element, compiler, **kwargs):
15
16
  args = []
16
17
  for arg in element.args:
@@ -21,20 +22,23 @@ def parse_vec_distance_func_args(element, compiler, **kwargs):
21
22
  args = ", ".join(args)
22
23
  return args
23
24
 
25
+
24
26
  class l2_distance(FunctionElement):
25
27
  """Vector distance function: l2_distance.
26
-
28
+
27
29
  Attributes:
28
30
  type : result type
29
31
  """
32
+
30
33
  type = Float()
31
34
 
32
35
  def __init__(self, *args):
33
36
  super().__init__()
34
37
  self.args = args
35
38
 
39
+
36
40
  @compiles(l2_distance)
37
- def compile_l2_distance(element, compiler, **kwargs): # pylint: disable=unused-argument
41
+ def compile_l2_distance(element, compiler, **kwargs): # pylint: disable=unused-argument
38
42
  """Compile l2_distance function.
39
43
 
40
44
  Args:
@@ -46,41 +50,47 @@ def compile_l2_distance(element, compiler, **kwargs): # pylint: disable=unused-a
46
50
 
47
51
  class cosine_distance(FunctionElement):
48
52
  """Vector distance function: cosine_distance.
49
-
53
+
50
54
  Attributes:
51
55
  type : result type
52
56
  """
57
+
53
58
  type = Float()
54
59
 
55
60
  def __init__(self, *args):
56
61
  super().__init__()
57
62
  self.args = args
58
63
 
64
+
59
65
  @compiles(cosine_distance)
60
- def compile_cosine_distance(element, compiler, **kwargs): # pylint: disable=unused-argument
66
+ def compile_cosine_distance(element, compiler, **kwargs): # pylint: disable=unused-argument
61
67
  """Compile cosine_distance function.
62
68
 
63
69
  Args:
64
70
  element: cosine_distance arguments
65
71
  compiler: SQL compiler
66
72
  """
67
- return f"cosine_distance({parse_vec_distance_func_args(element, compiler, **kwargs)})"
73
+ return (
74
+ f"cosine_distance({parse_vec_distance_func_args(element, compiler, **kwargs)})"
75
+ )
68
76
 
69
77
 
70
78
  class inner_product(FunctionElement):
71
79
  """Vector distance function: inner_product.
72
-
80
+
73
81
  Attributes:
74
82
  type : result type
75
83
  """
84
+
76
85
  type = Float()
77
86
 
78
87
  def __init__(self, *args):
79
88
  super().__init__()
80
89
  self.args = args
81
90
 
91
+
82
92
  @compiles(inner_product)
83
- def compile_inner_product(element, compiler, **kwargs): # pylint: disable=unused-argument
93
+ def compile_inner_product(element, compiler, **kwargs): # pylint: disable=unused-argument
84
94
  """Compile inner_product function.
85
95
 
86
96
  Args:
@@ -89,20 +99,23 @@ def compile_inner_product(element, compiler, **kwargs): # pylint: disable=unused
89
99
  """
90
100
  return f"inner_product({parse_vec_distance_func_args(element, compiler, **kwargs)})"
91
101
 
102
+
92
103
  class negative_inner_product(FunctionElement):
93
104
  """Vector distance function: negative_inner_product.
94
-
105
+
95
106
  Attributes:
96
107
  type : result type
97
108
  """
109
+
98
110
  type = Float()
99
111
 
100
112
  def __init__(self, *args):
101
113
  super().__init__()
102
114
  self.args = args
103
115
 
116
+
104
117
  @compiles(negative_inner_product)
105
- def compile_negative_inner_product(element, compiler, **kwargs): # pylint: disable=unused-argument
118
+ def compile_negative_inner_product(element, compiler, **kwargs): # pylint: disable=unused-argument
106
119
  """Compile negative_inner_product function.
107
120
 
108
121
  Args:
@@ -1,10 +1,12 @@
1
1
  """VECTOR: An extended data type for SQLAlchemy"""
2
+
2
3
  from sqlalchemy.types import UserDefinedType, String
3
4
  from ..util import Vector
4
5
 
5
6
 
6
7
  class VECTOR(UserDefinedType):
7
8
  """VECTOR data type definition."""
9
+
8
10
  cache_ok = True
9
11
  _string = String()
10
12
 
@@ -12,7 +14,7 @@ class VECTOR(UserDefinedType):
12
14
  super(UserDefinedType, self).__init__()
13
15
  self.dim = dim
14
16
 
15
- def get_col_spec(self, **kw): # pylint: disable=unused-argument
17
+ def get_col_spec(self, **kw): # pylint: disable=unused-argument
16
18
  """Parse to vector data type definition in text SQL."""
17
19
  if self.dim is None:
18
20
  return "VECTOR"
@@ -1,4 +1,5 @@
1
1
  """VectorIndex: An extended index type for SQLAlchemy"""
2
+
2
3
  from sqlalchemy import Index
3
4
  from sqlalchemy.schema import DDLElement
4
5
  from sqlalchemy.ext.compiler import compiles
@@ -7,16 +8,18 @@ from sqlalchemy.sql.ddl import SchemaGenerator
7
8
 
8
9
  class CreateVectorIndex(DDLElement):
9
10
  """A new statement clause to create vector index.
10
-
11
+
11
12
  Attributes:
12
13
  index: vector index schema
13
14
  """
15
+
14
16
  def __init__(self, index):
15
17
  self.index = index
16
18
 
17
19
 
18
20
  class ObSchemaGenerator(SchemaGenerator):
19
21
  """A new schema generator to handle create vector index statement."""
22
+
20
23
  def visit_vector_index(self, index, create_ok=False):
21
24
  """Handle create vector index statement compiling.
22
25
 
@@ -32,6 +35,7 @@ class ObSchemaGenerator(SchemaGenerator):
32
35
 
33
36
  class VectorIndex(Index):
34
37
  """Vector Index schema."""
38
+
35
39
  __visit_name__ = "vector_index"
36
40
 
37
41
  def __init__(self, name, *column_names, params: str = None, **kw):
@@ -44,7 +48,7 @@ class VectorIndex(Index):
44
48
 
45
49
  def create(self, bind, checkfirst: bool = False) -> None:
46
50
  """Create vector index.
47
-
51
+
48
52
  Args:
49
53
  bind: SQL engine or connection.
50
54
  checkfirst: check the index exists or not.
@@ -53,7 +57,7 @@ class VectorIndex(Index):
53
57
 
54
58
 
55
59
  @compiles(CreateVectorIndex)
56
- def compile_create_vector_index(element, compiler, **kw): # pylint: disable=unused-argument
60
+ def compile_create_vector_index(element, compiler, **kw): # pylint: disable=unused-argument
57
61
  """A decorator function to compile create vector index statement."""
58
62
  index = element.index
59
63
  table_name = index.table.name
@@ -4,6 +4,7 @@
4
4
  * SparseVector A utility class for the extended data type class 'SPARSE_VECTOR'
5
5
  * ObVersion OceanBase cluster version class
6
6
  """
7
+
7
8
  from .vector import Vector
8
9
  from .sparse_vector import SparseVector
9
10
  from .ob_version import ObVersion
@@ -1,4 +1,5 @@
1
1
  """OceanBase cluster version module."""
2
+
2
3
  import copy
3
4
 
4
5
 
@@ -8,6 +9,7 @@ class ObVersion:
8
9
  Attributes:
9
10
  version_nums (List[int]): version number of OceanBase cluster. For example, '4.3.3.0'
10
11
  """
12
+
11
13
  def __init__(self, version_nums: list[int]):
12
14
  self.version_nums = copy.deepcopy(version_nums)
13
15
 
@@ -1,24 +1,27 @@
1
1
  """A utility module for the extended data type class 'SPARSE_VECTOR'."""
2
+
2
3
  import ast
3
4
 
5
+
4
6
  class SparseVector:
5
7
  """A transformer class between python dict and OceanBase SPARSE_VECTOR.
6
8
 
7
9
  Attributes:
8
10
  _value (Dict) : a python dict
9
11
  """
12
+
10
13
  def __init__(self, value):
11
14
  if not isinstance(value, dict):
12
15
  raise ValueError("Sparse Vector should be a dict in python")
13
-
16
+
14
17
  self._value = value
15
-
18
+
16
19
  def __repr__(self):
17
20
  return f"{self._value}"
18
-
21
+
19
22
  def to_text(self):
20
23
  return f"{self._value}"
21
-
24
+
22
25
  @classmethod
23
26
  def from_text(cls, value: str):
24
27
  """Construct Sparse Vector class with dict in string format.
@@ -27,7 +30,7 @@ class SparseVector:
27
30
  value: For example, '{1:1.1, 2:2.2}'
28
31
  """
29
32
  return cls(ast.literal_eval(value))
30
-
33
+
31
34
  @classmethod
32
35
  def _to_db(cls, value):
33
36
  if value is None:
@@ -45,4 +48,4 @@ class SparseVector:
45
48
 
46
49
  if isinstance(value, str):
47
50
  return cls.from_text(value)._value
48
- raise ValueError(f"unexpected sparse vector type: {type(value)}")
51
+ raise ValueError(f"unexpected sparse vector type: {type(value)}")
pyobvector/util/vector.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """A utility module for the extended data type class 'VECTOR'."""
2
+
2
3
  import json
3
4
  import numpy as np
4
5
 
@@ -9,6 +10,7 @@ class Vector:
9
10
  Attributes:
10
11
  _value (numpy.array): a numpy array
11
12
  """
13
+
12
14
  def __init__(self, value):
13
15
  # big-endian float32
14
16
  if not isinstance(value, np.ndarray) or value.dtype != ">f4":