pyobvector 0.2.16__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,459 @@
1
+ import logging
2
+ from typing import List, Optional, Dict, Union
3
+ from urllib.parse import quote
4
+
5
+ import sqlalchemy.sql.functions as func_mod
6
+ from sqlalchemy import (
7
+ create_engine,
8
+ MetaData,
9
+ Table,
10
+ Column,
11
+ Index,
12
+ select,
13
+ delete,
14
+ update,
15
+ insert,
16
+ text,
17
+ inspect,
18
+ and_,
19
+ )
20
+ from sqlalchemy.dialects import registry
21
+ from sqlalchemy.exc import NoSuchTableError
22
+
23
+ from .index_param import IndexParams
24
+ from .partitions import ObPartition
25
+ from ..schema import (
26
+ ObTable,
27
+ l2_distance,
28
+ cosine_distance,
29
+ inner_product,
30
+ negative_inner_product,
31
+ ST_GeomFromText,
32
+ st_distance,
33
+ st_dwithin,
34
+ st_astext,
35
+ ReplaceStmt,
36
+ )
37
+ from ..util import ObVersion
38
+
39
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Sets this logger's threshold to DEBUG as soon as the module is imported.
# NOTE(review): libraries usually leave the level to the application -- confirm this is intended.
logger.setLevel(logging.DEBUG)
41
+
42
+
43
+ class ObClient:
44
+ """The OceanBase Client"""
45
+
46
def __init__(
    self,
    uri: str = "127.0.0.1:2881",
    user: str = "root@test",
    password: str = "",
    db_name: str = "test",
    **kwargs,
):
    """Open an engine against OceanBase and cache its schema and version.

    Args:
        uri (str): server address placed after the ``@`` of the connection URL
        user (str): user name; percent-encoded before it is put in the URL
        password (str): password; percent-encoded before it is put in the URL
        db_name (str): database name used as the path of the connection URL
        **kwargs: forwarded unchanged to ``create_engine``
    """
    registry.register("mysql.oceanbase", "pyobvector.schema.dialect", "OceanBaseDialect")

    # Attach the vector / spatial function classes to sqlalchemy.sql.functions
    # under the given attribute names.
    extra_functions = {
        "l2_distance": l2_distance,
        "cosine_distance": cosine_distance,
        "inner_product": inner_product,
        "negative_inner_product": negative_inner_product,
        "ST_GeomFromText": ST_GeomFromText,
        "st_distance": st_distance,
        "st_dwithin": st_dwithin,
        "st_astext": st_astext,
    }
    for attr_name, function_cls in extra_functions.items():
        setattr(func_mod, attr_name, function_cls)

    # safe="" escapes every reserved character, including "/" and "@".
    user = quote(user, safe="")
    password = quote(password, safe="")
    connection_str = (
        f"mysql+oceanbase://{user}:{password}@{uri}/{db_name}?charset=utf8mb4"
    )

    self.engine = create_engine(connection_str, **kwargs)
    self.metadata_obj = MetaData()
    # Load the definitions of the tables that already exist in the database.
    self.metadata_obj.reflect(bind=self.engine)

    with self.engine.connect() as conn, conn.begin():
        version_rows = conn.execute(text("SELECT OB_VERSION() FROM DUAL"))
        # First column of the first returned row holds the version string.
        version = [row[0] for row in version_rows][0]
        self.ob_version = ObVersion.from_db_version_string(version)
80
+
81
def refresh_metadata(self, tables: Optional[list[str]] = None):
    """Re-read table definitions from the database into ``self.metadata_obj``.

    Args:
        tables (Optional[list[str]]): names of the tables to refresh.
            When None, the cached metadata is cleared and every table is
            reflected again.
    """
    if tables is None:
        # Full refresh: forget everything, then reflect the whole database.
        self.metadata_obj.clear()
        self.metadata_obj.reflect(bind=self.engine, extend_existing=True)
        return

    # Partial refresh: drop the cached definition of each requested table
    # (if one is cached) before reflecting just those tables.
    for name in tables:
        if name in self.metadata_obj.tables:
            self.metadata_obj.remove(Table(name, self.metadata_obj))
    self.metadata_obj.reflect(bind=self.engine, only=tables, extend_existing=True)
95
+
96
def _insert_partition_hint_for_query_sql(self, sql: str, partition_hint: str):
    """Insert `partition_hint` right after the first token that follows FROM.

    Args:
        sql (str): a SQL string containing an upper-case ``FROM`` keyword
        partition_hint (str): text to insert, e.g. ``PARTITION(p0, p1)``

    Returns:
        str: `sql` with a single space and `partition_hint` inserted directly
        after the token following the first standalone ``FROM``. If nothing
        follows that token, the hint is appended to the end.

    Raises:
        ValueError: if `sql` contains no standalone ``FROM`` keyword.
    """
    import re  # local import so the helper does not rely on file-level imports

    # \b keeps identifiers such as FROM_TS from being mistaken for the keyword;
    # \s* / \S* accept any whitespace (space, tab, newline) around the table
    # token instead of only a single literal space.
    from_match = re.search(r"\bFROM\b\s*(\S*)", sql)
    if from_match is None:
        # Explicit raise instead of `assert`, which is stripped under -O.
        raise ValueError("cannot insert partition hint: no FROM keyword in SQL")

    insert_at = from_match.end(1)
    return sql[:insert_at] + " " + partition_hint + sql[insert_at:]
108
+
109
def check_table_exists(self, table_name: str):
    """Report whether a table is present in the connected database.

    Args:
        table_name (string): name of the table to look for

    Returns:
        bool: True when the inspector reports the table, False otherwise
    """
    # A fresh inspector is created on every call.
    return inspect(self.engine).has_table(table_name)
120
+
121
def create_table(
    self,
    table_name: str,
    columns: List[Column],
    indexes: Optional[List[Index]] = None,
    partitions: Optional[ObPartition] = None,
    **kwargs,
):
    """Define a table in the cached metadata and create it if it is missing.

    Args:
        table_name (string): table name
        columns (List[Column]): column schema
        indexes (Optional[List[Index]]): optional index schema, passed to the
            table constructor right after the columns
        partitions (Optional[ObPartition]): optional partition strategy; when
            given, its compiled clause is applied with an ALTER TABLE
        **kwargs: additional keyword arguments for the table constructor
            (``extend_existing`` defaults to True)
    """
    kwargs.setdefault("extend_existing", True)
    with self.engine.connect() as conn, conn.begin():
        # Columns first, then indexes when the caller supplied any.
        schema_items = [*columns]
        if indexes is not None:
            schema_items.extend(indexes)
        table = ObTable(table_name, self.metadata_obj, *schema_items, **kwargs)
        table.create(self.engine, checkfirst=True)
        # do partition
        if partitions is not None:
            conn.execute(
                text(f"ALTER TABLE `{table_name}` {partitions.do_compile()}")
            )
162
+
163
@classmethod
def prepare_index_params(cls):
    """Return a newly constructed `IndexParams` for holding index configuration."""
    index_params = IndexParams()
    return index_params
167
+
168
def drop_table_if_exist(self, table_name: str):
    """Drop a table and forget its cached definition; do nothing if it is absent.

    Args:
        table_name (string): name of the table to drop
    """
    try:
        target = Table(table_name, self.metadata_obj, autoload_with=self.engine)
    except NoSuchTableError:
        # Table could not be loaded, so there is nothing to drop.
        return

    with self.engine.connect() as conn, conn.begin():
        target.drop(self.engine, checkfirst=True)
        # Keep the cached metadata in step with the database.
        self.metadata_obj.remove(target)
178
+
179
def drop_index(self, table_name: str, index_name: str):
    """Drop an index from the given table.

    If the index does not exist, SQL ERROR 1091 is raised.

    Args:
        table_name (string): table that owns the index
        index_name (string): name of the index to drop
    """
    drop_stmt = text(f"DROP INDEX `{index_name}` ON `{table_name}`")
    with self.engine.connect() as conn, conn.begin():
        conn.execute(drop_stmt)
187
+
188
def insert(
    self,
    table_name: str,
    data: Union[Dict, List[Dict]],
    partition_name: Optional[str] = "",
):
    """Insert one or more rows into a table.

    Args:
        table_name (string): table name
        data (Union[Dict, List[Dict]]): a single row or a list of rows; an
            empty list makes the call a no-op
        partition_name (Optional[str]): when non-empty, added to the
            statement as a ``PARTITION(...)`` hint
    """
    # Normalise a single row into a one-element list.
    rows = [data] if isinstance(data, Dict) else data
    if len(rows) == 0:
        return

    table = Table(table_name, self.metadata_obj, autoload_with=self.engine)

    with self.engine.connect() as conn, conn.begin():
        stmt = insert(table)
        if partition_name is not None and partition_name != "":
            stmt = stmt.with_hint(f"PARTITION({partition_name})")
        conn.execute(stmt.values(rows))
219
+
220
def upsert(
    self,
    table_name: str,
    data: Union[Dict, List[Dict]],
    partition_name: Optional[str] = "",
):
    """Write rows through `ReplaceStmt`; a row with a duplicated primary key is replaced.

    Args:
        table_name (string): table name
        data (Union[Dict, List[Dict]]): a single row or a list of rows; an
            empty list makes the call a no-op
        partition_name (Optional[str]): when non-empty, added to the
            statement as a ``PARTITION(...)`` hint
    """
    # Normalise a single row into a one-element list.
    rows = [data] if isinstance(data, Dict) else data
    if len(rows) == 0:
        return

    table = Table(table_name, self.metadata_obj, autoload_with=self.engine)

    with self.engine.connect() as conn, conn.begin():
        stmt = ReplaceStmt(table)
        if partition_name is not None and partition_name != "":
            stmt = stmt.with_hint(f"PARTITION({partition_name})")
        conn.execute(stmt.values(rows))
250
+
251
def update(
    self,
    table_name: str,
    values_clause,
    where_clause=None,
    partition_name: Optional[str] = "",
):
    """Update rows of a table.

    Args:
        table_name (string): table name
        values_clause: sequence unpacked into the statement's ``values()``
        where_clause: optional sequence of filters unpacked into ``where()``;
            when None, no filter is applied
        partition_name (Optional[str]): when non-empty, added to the
            statement as a ``PARTITION(...)`` hint

    Example:
        .. code-block:: python

            data = [
                {"id": 112, "embedding": [1, 2, 3], "meta": {'doc':'hhh1'}},
                {"id": 190, "embedding": [0.13, 0.123, 1.213], "meta": {'doc':'hhh2'}},
            ]
            client.insert(table_name=test_collection_name, data=data)
            client.update(
                table_name=test_collection_name,
                values_clause=[{'meta':{'doc':'HHH'}}],
                where_clause=[text("id=112")]
            )
    """
    table = Table(table_name, self.metadata_obj, autoload_with=self.engine)

    with self.engine.connect() as conn, conn.begin():
        stmt = update(table)
        if partition_name is not None and partition_name != "":
            stmt = stmt.with_hint(f"PARTITION({partition_name})")
        if where_clause is not None:
            stmt = stmt.where(*where_clause)
        conn.execute(stmt.values(*values_clause))
296
+
297
def delete(
    self,
    table_name: str,
    ids: Optional[Union[list, str, int]] = None,
    where_clause=None,
    partition_name: Optional[str] = "",
):
    """Delete rows from a table.

    With neither `ids` nor `where_clause`, every row of the table (or of the
    hinted partition) is deleted.

    Args:
        table_name (string): table name
        ids (Optional[Union[list, str, int]]): primary-key value(s) of the
            rows to delete; requires a table with exactly one primary-key column
        where_clause: optional sequence of filters, combined with AND (and
            with the `ids` filter when both are given)
        partition_name (Optional[str]): when non-empty, added to the
            statement as a ``PARTITION(...)`` hint

    Raises:
        ValueError: if `ids` is given but the table does not have exactly one
            primary-key column.
        TypeError: if `ids` is not a list, str or int.
    """
    table = Table(table_name, self.metadata_obj, autoload_with=self.engine)

    conditions = []
    if ids is not None:
        pkey_names = [column.name for column in table.primary_key]
        if len(pkey_names) != 1:
            # Previously the id filter was silently dropped here, which turned
            # a targeted delete into a delete of every row.
            raise ValueError(
                "'ids' requires a table with exactly one primary key column, "
                f"but `{table_name}` has {len(pkey_names)}"
            )
        if isinstance(ids, list):
            id_values = ids
        elif isinstance(ids, (str, int)):
            id_values = [ids]
        else:
            raise TypeError("'ids' is not a list/str/int")
        conditions.append(table.c[pkey_names[0]].in_(id_values))
    if where_clause is not None:
        conditions.extend(where_clause)

    with self.engine.connect() as conn, conn.begin():
        stmt = delete(table)
        if partition_name is not None and partition_name != "":
            stmt = stmt.with_hint(f"PARTITION({partition_name})")
        if conditions:
            # Multiple criteria passed to where() are joined with AND.
            stmt = stmt.where(*conditions)
        conn.execute(stmt)
342
+
343
def get(
    self,
    table_name: str,
    ids: Optional[Union[list, str, int]] = None,
    where_clause=None,
    output_column_name: Optional[List[str]] = None,
    partition_names: Optional[List[str]] = None,
    n_limits: Optional[int] = None,
):
    """Select rows, optionally filtered by primary-key values and extra filters.

    Args:
        table_name (string): table name
        ids (Optional[Union[list, str, int]]): primary-key value(s) to match.
            Applied only when the table has exactly one primary-key column;
            otherwise it is ignored.
        where_clause: optional sequence of filters, combined with AND
        output_column_name (Optional[List[str]]): columns to return; all
            columns when None
        partition_names (Optional[List[str]]): limit the query to these
            partitions; None or an empty list means no restriction
        n_limits (Optional[int]): limit the number of results

    Returns:
        Result object from SQLAlchemy execution

    Raises:
        TypeError: if `ids` is applied and is not a list, str or int.
    """
    table = Table(table_name, self.metadata_obj, autoload_with=self.engine)
    if output_column_name is not None:
        stmt = select(*[table.c[column_name] for column_name in output_column_name])
    else:
        stmt = select(table)

    pkey_names = [column.name for column in table.primary_key]
    conditions = []
    if ids is not None and len(pkey_names) == 1:
        if isinstance(ids, list):
            id_values = ids
        elif isinstance(ids, (str, int)):
            id_values = [ids]
        else:
            raise TypeError("'ids' is not a list/str/int")
        conditions.append(table.c[pkey_names[0]].in_(id_values))
    if where_clause is not None:
        conditions.extend(where_clause)
    if conditions:
        # Multiple criteria passed to where() are joined with AND.
        stmt = stmt.where(*conditions)

    if n_limits is not None:
        stmt = stmt.limit(n_limits)

    with self.engine.connect() as conn, conn.begin():
        if not partition_names:
            # An empty list would otherwise render the invalid "PARTITION()".
            return conn.execute(stmt)

        # The partition hint is spliced into the rendered SQL text, so the
        # statement is compiled with its parameters inlined first.
        stmt_str = str(stmt.compile(
            dialect=self.engine.dialect,
            compile_kwargs={"literal_binds": True}
        ))
        stmt_str = self._insert_partition_hint_for_query_sql(
            stmt_str, f"PARTITION({', '.join(partition_names)})"
        )
        # Use the module logger rather than the root logger.
        logger.debug(stmt_str)
        return conn.execute(text(stmt_str))
407
+
408
def perform_raw_text_sql(
    self,
    text_sql: str,
):
    """Run a raw SQL string inside its own transaction and return the result.

    Args:
        text_sql (str): SQL text, wrapped with ``text()`` before execution

    Returns:
        Result object from SQLAlchemy execution
    """
    with self.engine.connect() as conn, conn.begin():
        result = conn.execute(text(text_sql))
        return result
416
+
417
def add_columns(
    self,
    table_name: str,
    columns: list[Column],
):
    """Add several columns to an existing table with one ALTER TABLE.

    The cached metadata of the table is refreshed afterwards.

    Args:
        table_name (string): table name
        columns (list[Column]): SQLAlchemy Column objects describing the new columns
    """
    dialect = self.engine.dialect
    # The dialect's DDL compiler renders each column's specification.
    ddl_compiler = dialect.ddl_compiler(dialect, None)
    add_clauses = []
    for column in columns:
        spec = ddl_compiler.get_column_specification(column)
        add_clauses.append(f"ADD COLUMN {spec}")
    columns_ddl = ", ".join(add_clauses)

    with self.engine.connect() as conn, conn.begin():
        conn.execute(
            text(f"ALTER TABLE `{table_name}` {columns_ddl}")
        )

    self.refresh_metadata([table_name])
439
+
440
def drop_columns(
    self,
    table_name: str,
    column_names: list[str],
):
    """Drop several columns from an existing table with one ALTER TABLE.

    The cached metadata of the table is refreshed afterwards.

    Args:
        table_name (string): table name
        column_names (list[str]): names of the columns to drop
    """
    drop_clauses = [f"DROP COLUMN `{name}`" for name in column_names]
    columns_ddl = ", ".join(drop_clauses)

    with self.engine.connect() as conn, conn.begin():
        conn.execute(
            text(f"ALTER TABLE `{table_name}` {columns_ddl}")
        )

    self.refresh_metadata([table_name])