camel-ai 0.2.46__py3-none-any.whl → 0.2.48__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +5 -5
- camel/datasets/few_shot_generator.py +19 -3
- camel/datasets/models.py +1 -1
- camel/loaders/__init__.py +2 -0
- camel/loaders/scrapegraph_reader.py +96 -0
- camel/models/openai_model.py +3 -1
- camel/storages/__init__.py +2 -0
- camel/storages/vectordb_storages/__init__.py +2 -0
- camel/storages/vectordb_storages/oceanbase.py +458 -0
- camel/toolkits/__init__.py +2 -0
- camel/toolkits/browser_toolkit.py +4 -7
- camel/toolkits/dalle_toolkit.py +20 -6
- camel/toolkits/jina_reranker_toolkit.py +231 -0
- camel/toolkits/search_toolkit.py +167 -0
- camel/types/enums.py +6 -0
- camel/utils/token_counting.py +7 -3
- {camel_ai-0.2.46.dist-info → camel_ai-0.2.48.dist-info}/METADATA +13 -2
- {camel_ai-0.2.46.dist-info → camel_ai-0.2.48.dist-info}/RECORD +21 -18
- {camel_ai-0.2.46.dist-info → camel_ai-0.2.48.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.46.dist-info → camel_ai-0.2.48.dist-info}/licenses/LICENSE +0 -0
camel/__init__.py
CHANGED
camel/agents/chat_agent.py
CHANGED
|
@@ -715,11 +715,11 @@ class ChatAgent(BaseAgent):
|
|
|
715
715
|
if external_tool_call_requests is None:
|
|
716
716
|
external_tool_call_requests = []
|
|
717
717
|
external_tool_call_requests.append(tool_call_request)
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
718
|
+
else:
|
|
719
|
+
tool_call_record = await self._aexecute_tool(
|
|
720
|
+
tool_call_request
|
|
721
|
+
)
|
|
722
|
+
tool_call_records.append(tool_call_record)
|
|
723
723
|
|
|
724
724
|
# If we found an external tool call, break the loop
|
|
725
725
|
if external_tool_call_requests:
|
|
@@ -16,7 +16,7 @@ import asyncio
|
|
|
16
16
|
from datetime import datetime
|
|
17
17
|
from typing import List
|
|
18
18
|
|
|
19
|
-
from pydantic import ValidationError
|
|
19
|
+
from pydantic import BaseModel, Field, ValidationError
|
|
20
20
|
|
|
21
21
|
from camel.agents import ChatAgent
|
|
22
22
|
from camel.logger import get_logger
|
|
@@ -176,14 +176,30 @@ class FewShotGenerator(BaseGenerator):
|
|
|
176
176
|
]
|
|
177
177
|
prompt = self._construct_prompt(examples)
|
|
178
178
|
|
|
179
|
+
# Create a simplified version of DataPoint that omits metadata
|
|
180
|
+
# because agent.step's response_format parameter doesn't
|
|
181
|
+
# support type Dict[str, Any]
|
|
182
|
+
class DataPointSimplified(BaseModel):
|
|
183
|
+
question: str = Field(
|
|
184
|
+
description="The primary question or issue to "
|
|
185
|
+
"be addressed."
|
|
186
|
+
)
|
|
187
|
+
final_answer: str = Field(description="The final answer.")
|
|
188
|
+
rationale: str = Field(
|
|
189
|
+
description="Logical reasoning or explanation "
|
|
190
|
+
"behind the answer."
|
|
191
|
+
)
|
|
192
|
+
|
|
179
193
|
try:
|
|
180
194
|
agent_output = (
|
|
181
|
-
self.agent.step(
|
|
195
|
+
self.agent.step(
|
|
196
|
+
prompt, response_format=DataPointSimplified
|
|
197
|
+
)
|
|
182
198
|
.msgs[0]
|
|
183
199
|
.parsed
|
|
184
200
|
)
|
|
185
201
|
|
|
186
|
-
assert isinstance(agent_output,
|
|
202
|
+
assert isinstance(agent_output, DataPointSimplified)
|
|
187
203
|
|
|
188
204
|
self.agent.reset()
|
|
189
205
|
|
camel/datasets/models.py
CHANGED
|
@@ -24,7 +24,7 @@ class DataPoint(BaseModel):
|
|
|
24
24
|
final_answer (str): The final answer.
|
|
25
25
|
rationale (Optional[str]): Logical reasoning or explanation behind the
|
|
26
26
|
answer. (default: :obj:`None`)
|
|
27
|
-
metadata Optional[Dict[str, Any]]: Additional metadata about the data
|
|
27
|
+
metadata (Optional[Dict[str, Any]]): Additional metadata about the data
|
|
28
28
|
point. (default: :obj:`None`)
|
|
29
29
|
"""
|
|
30
30
|
|
camel/loaders/__init__.py
CHANGED
|
@@ -20,6 +20,7 @@ from .firecrawl_reader import Firecrawl
|
|
|
20
20
|
from .jina_url_reader import JinaURLReader
|
|
21
21
|
from .mineru_extractor import MinerU
|
|
22
22
|
from .pandas_reader import PandasReader
|
|
23
|
+
from .scrapegraph_reader import ScrapeGraphAI
|
|
23
24
|
from .unstructured_io import UnstructuredIO
|
|
24
25
|
|
|
25
26
|
__all__ = [
|
|
@@ -34,4 +35,5 @@ __all__ = [
|
|
|
34
35
|
'PandasReader',
|
|
35
36
|
'MinerU',
|
|
36
37
|
'Crawl4AI',
|
|
38
|
+
'ScrapeGraphAI',
|
|
37
39
|
]
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
from typing import Any, Dict, Optional
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ScrapeGraphAI:
    r"""ScrapeGraphAI allows you to perform AI-powered web scraping and
    searching.

    Args:
        api_key (Optional[str]): API key for authenticating with the
            ScrapeGraphAI API. When omitted, the value of the
            ``SCRAPEGRAPH_API_KEY`` environment variable is used.
            (default: :obj:`None`)

    References:
        https://scrapegraph.ai/
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
    ) -> None:
        # Imported lazily so the optional dependency is only needed when
        # the reader is actually instantiated.
        from scrapegraph_py import Client
        from scrapegraph_py.logger import sgai_logger

        self._api_key = api_key or os.environ.get("SCRAPEGRAPH_API_KEY")
        sgai_logger.set_logging(level="INFO")
        self.client = Client(api_key=self._api_key)

    def search(
        self,
        user_prompt: str,
    ) -> Dict[str, Any]:
        r"""Perform an AI-powered web search using ScrapeGraphAI.

        Args:
            user_prompt (str): The search query or instructions.

        Returns:
            Dict[str, Any]: The search results including answer and reference
                URLs.

        Raises:
            RuntimeError: If the search process fails. The original
                exception is attached as ``__cause__``.
        """
        try:
            return self.client.searchscraper(user_prompt=user_prompt)
        except Exception as e:
            # Chain explicitly so the underlying client error stays visible
            # in tracebacks and is reachable via ``__cause__``.
            raise RuntimeError(f"Failed to perform search: {e}") from e

    def scrape(
        self,
        website_url: str,
        user_prompt: str,
        website_html: Optional[str] = None,
    ) -> Dict[str, Any]:
        r"""Perform AI-powered web scraping using ScrapeGraphAI.

        Args:
            website_url (str): The URL to scrape.
            user_prompt (str): Instructions for what data to extract.
            website_html (Optional[str]): Optional HTML content to use instead
                of fetching from the URL. (default: :obj:`None`)

        Returns:
            Dict[str, Any]: The scraped data including request ID and result.

        Raises:
            RuntimeError: If the scrape process fails. The original
                exception is attached as ``__cause__``.
        """
        try:
            return self.client.smartscraper(
                website_url=website_url,
                user_prompt=user_prompt,
                website_html=website_html,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to perform scrape: {e}") from e

    def close(self) -> None:
        r"""Close the ScrapeGraphAI client connection."""
        self.client.close()
|
camel/models/openai_model.py
CHANGED
|
@@ -111,9 +111,11 @@ class OpenAIModel(BaseModelBackend):
|
|
|
111
111
|
ModelType.O1_MINI,
|
|
112
112
|
ModelType.O1_PREVIEW,
|
|
113
113
|
ModelType.O3_MINI,
|
|
114
|
+
ModelType.O3,
|
|
115
|
+
ModelType.O4_MINI,
|
|
114
116
|
]:
|
|
115
117
|
warnings.warn(
|
|
116
|
-
"Warning: You are using an reasoning model (
|
|
118
|
+
"Warning: You are using an reasoning model (O series), "
|
|
117
119
|
"which has certain limitations, reference: "
|
|
118
120
|
"`https://platform.openai.com/docs/guides/reasoning`.",
|
|
119
121
|
UserWarning,
|
camel/storages/__init__.py
CHANGED
|
@@ -27,6 +27,7 @@ from .vectordb_storages.base import (
|
|
|
27
27
|
VectorRecord,
|
|
28
28
|
)
|
|
29
29
|
from .vectordb_storages.milvus import MilvusStorage
|
|
30
|
+
from .vectordb_storages.oceanbase import OceanBaseStorage
|
|
30
31
|
from .vectordb_storages.qdrant import QdrantStorage
|
|
31
32
|
from .vectordb_storages.tidb import TiDBStorage
|
|
32
33
|
|
|
@@ -46,4 +47,5 @@ __all__ = [
|
|
|
46
47
|
'Neo4jGraph',
|
|
47
48
|
'NebulaGraph',
|
|
48
49
|
'Mem0Storage',
|
|
50
|
+
'OceanBaseStorage',
|
|
49
51
|
]
|
|
@@ -20,6 +20,7 @@ from .base import (
|
|
|
20
20
|
VectorRecord,
|
|
21
21
|
)
|
|
22
22
|
from .milvus import MilvusStorage
|
|
23
|
+
from .oceanbase import OceanBaseStorage
|
|
23
24
|
from .qdrant import QdrantStorage
|
|
24
25
|
from .tidb import TiDBStorage
|
|
25
26
|
|
|
@@ -30,6 +31,7 @@ __all__ = [
|
|
|
30
31
|
'QdrantStorage',
|
|
31
32
|
'MilvusStorage',
|
|
32
33
|
"TiDBStorage",
|
|
34
|
+
'OceanBaseStorage',
|
|
33
35
|
'VectorRecord',
|
|
34
36
|
'VectorDBStatus',
|
|
35
37
|
]
|
|
@@ -0,0 +1,458 @@
|
|
|
1
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import logging
|
|
17
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
|
18
|
+
|
|
19
|
+
from sqlalchemy import JSON, Column, Integer
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from pyobvector.client import ObVecClient
|
|
23
|
+
|
|
24
|
+
from camel.storages.vectordb_storages import (
|
|
25
|
+
BaseVectorStorage,
|
|
26
|
+
VectorDBQuery,
|
|
27
|
+
VectorDBQueryResult,
|
|
28
|
+
VectorDBStatus,
|
|
29
|
+
VectorRecord,
|
|
30
|
+
)
|
|
31
|
+
from camel.utils import dependencies_required
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class OceanBaseStorage(BaseVectorStorage):
    r"""An implementation of the `BaseVectorStorage` for interacting with
    OceanBase Vector Database.

    Args:
        vector_dim (int): The dimension of storing vectors.
        table_name (str): Name for the table in OceanBase.
        uri (str): Connection URI for OceanBase (host:port).
            (default: :obj:`"127.0.0.1:2881"`)
        user (str): Username for connecting to OceanBase.
            (default: :obj:`"root@test"`)
        password (str): Password for the user. (default: :obj:`""`)
        db_name (str): Database name in OceanBase.
            (default: :obj:`"test"`)
        distance (Literal["l2", "cosine"], optional): The distance metric for
            vector comparison. Options: "l2", "cosine". (default: :obj:`"l2"`)
        delete_table_on_del (bool, optional): Flag to determine if the
            table should be deleted upon object destruction.
            (default: :obj:`False`)
        **kwargs (Any): Additional keyword arguments for initializing
            `ObVecClient`.

    Raises:
        ImportError: If `pyobvector` package is not installed.
    """

    @dependencies_required('pyobvector')
    def __init__(
        self,
        vector_dim: int,
        table_name: str,
        uri: str = "127.0.0.1:2881",
        user: str = "root@test",
        password: str = "",
        db_name: str = "test",
        distance: Literal["l2", "cosine"] = "l2",
        delete_table_on_del: bool = False,
        **kwargs: Any,
    ) -> None:
        from pyobvector.client import (
            ObVecClient,
        )
        from pyobvector.client.index_param import (
            IndexParam,
            IndexParams,
        )
        from pyobvector.schema import VECTOR

        self.vector_dim: int = vector_dim
        self.table_name: str = table_name
        self.distance: Literal["l2", "cosine"] = distance
        self.delete_table_on_del: bool = delete_table_on_del

        # Create client
        self._client: ObVecClient = ObVecClient(
            uri=uri, user=user, password=password, db_name=db_name, **kwargs
        )

        # Map distance to distance function in OceanBase
        self._distance_func_map: Dict[str, str] = {
            "cosine": "cosine_distance",
            "l2": "l2_distance",
        }

        # Check or create table with vector index
        if not self._client.check_table_exists(self.table_name):
            # Define table schema
            columns: List[Column] = [
                Column("id", Integer, primary_key=True, autoincrement=True),
                Column("embedding", VECTOR(vector_dim)),
                Column("metadata", JSON),
            ]

            # Create table
            self._client.create_table(
                table_name=self.table_name, columns=columns
            )

            # Create vector index
            index_params: IndexParams = IndexParams()
            index_params.add_index_param(
                IndexParam(
                    index_name="embedding_idx",
                    field_name="embedding",
                    distance=self.distance,
                    type="hnsw",
                    m=16,
                    ef_construction=256,
                )
            )

            self._client.create_vidx_with_vec_index_param(
                table_name=self.table_name, vidx_param=index_params.params[0]
            )

            logger.info(f"Created table {self.table_name} with vector index")
        else:
            logger.info(f"Using existing table {self.table_name}")

    def __del__(self):
        r"""Deletes the table if :obj:`delete_table_on_del` is set to
        :obj:`True`.
        """
        # ``getattr`` guards against a partially constructed instance whose
        # ``__init__`` failed before the flag was assigned.
        if getattr(self, "delete_table_on_del", False):
            try:
                self._client.drop_table_if_exist(self.table_name)
                logger.info(f"Deleted table {self.table_name}")
            except Exception as e:
                # Best effort only: never raise from a finalizer.
                logger.error(f"Failed to delete table {self.table_name}: {e}")

    def add(
        self,
        records: List[VectorRecord],
        batch_size: int = 100,
        **kwargs: Any,
    ) -> None:
        r"""Saves a list of vector records to the storage.

        Records whose ``id`` parses as an integer are stored under that
        primary key; any other non-empty ``id`` is kept in the metadata
        under the ``"_id"`` key.

        Args:
            records (List[VectorRecord]): List of vector records to be saved.
            batch_size (int): Number of records to insert each batch.
                Larger batches are more efficient but use more memory.
                (default: :obj:`100`)
            **kwargs (Any): Additional keyword arguments.

        Raises:
            RuntimeError: If there is an error during the saving process.
            ValueError: If any vector dimension doesn't match vector_dim.
        """

        if not records:
            return

        try:
            # Convert records to OceanBase format
            data: List[Dict[str, Any]] = []
            for i, record in enumerate(records):
                # Validate vector dimensions
                if len(record.vector) != self.vector_dim:
                    raise ValueError(
                        f"Vector at index {i} has dimension "
                        f"{len(record.vector)}, expected {self.vector_dim}"
                    )

                # Work on a shallow copy so that stashing ``_id`` below never
                # mutates the payload dict owned by the caller's record.
                metadata: Dict[str, Any] = (
                    dict(record.payload) if record.payload else {}
                )
                item: Dict[str, Any] = {
                    "embedding": record.vector,
                    "metadata": metadata,
                }
                # If id is specified, use it
                if record.id:
                    try:
                        # If id is numeric, use it directly
                        item["id"] = int(record.id)
                    except ValueError:
                        # If id is not numeric, store it in payload
                        metadata["_id"] = record.id

                data.append(item)

                # Batch insert when reaching batch_size
                if len(data) >= batch_size:
                    self._client.insert(self.table_name, data=data)
                    data = []

            # Insert any remaining records
            if data:
                self._client.insert(self.table_name, data=data)

        except ValueError:
            # Propagate ValueError (e.g. dimension mismatch) unchanged
            raise
        except Exception as e:
            error_msg = f"Failed to add records to OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def delete(
        self,
        ids: List[str],
        **kwargs: Any,
    ) -> None:
        r"""Deletes a list of vectors identified by their IDs from the storage.

        IDs that parse as integers are deleted by primary key; all other IDs
        are matched against the ``_id`` value stored in the metadata column.

        Args:
            ids (List[str]): List of unique identifiers for the vectors to
                be deleted.
            **kwargs (Any): Additional keyword arguments.

        Raises:
            RuntimeError: If there is an error during the deletion process.
        """
        if not ids:
            return

        try:
            numeric_ids: List[int] = []
            non_numeric_ids: List[str] = []

            # Separate numeric and non-numeric IDs
            for id_val in ids:
                try:
                    numeric_ids.append(int(id_val))
                except ValueError:
                    non_numeric_ids.append(id_val)

            # Delete records with numeric IDs
            if numeric_ids:
                self._client.delete(self.table_name, ids=numeric_ids)

            # Delete records with non-numeric IDs stored in metadata
            if non_numeric_ids:
                from sqlalchemy import text

                for id_val in non_numeric_ids:
                    # Bind the id as a parameter rather than interpolating
                    # it into the SQL text, so quotes in an id can neither
                    # break the statement nor inject SQL.
                    # NOTE(review): the JSON path ``$.._id`` is kept as-is;
                    # ``add`` stores ``_id`` at the top level, so ``$._id``
                    # may be what is intended - confirm against OceanBase.
                    condition = text(
                        "metadata->>'$.._id' = :id_val"
                    ).bindparams(id_val=id_val)
                    self._client.delete(
                        self.table_name,
                        where_clause=[condition],
                    )
        except Exception as e:
            error_msg = f"Failed to delete records from OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def status(self) -> VectorDBStatus:
        r"""Returns status of the vector database.

        Returns:
            VectorDBStatus: The vector database status.

        Raises:
            RuntimeError: If the record count cannot be retrieved.
        """
        try:
            # Get count of records
            result = self._client.perform_raw_text_sql(
                f"SELECT COUNT(*) FROM {self.table_name}"
            )
            count: int = result.fetchone()[0]

            return VectorDBStatus(
                vector_dim=self.vector_dim, vector_count=count
            )
        except Exception as e:
            error_msg = f"Failed to get status from OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def query(
        self,
        query: VectorDBQuery,
        **kwargs: Any,
    ) -> List[VectorDBQueryResult]:
        r"""Searches for similar vectors in the storage based on the
        provided query.

        Rows that cannot be converted into a result are logged and skipped
        rather than failing the whole query.

        Args:
            query (VectorDBQuery): The query object containing the search
                vector and the number of top similar vectors to retrieve.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            List[VectorDBQueryResult]: A list of vectors retrieved from the
                storage based on similarity to the query vector.

        Raises:
            RuntimeError: If there is an error during the query process.
            ValueError: If the query vector dimension does not match the
                storage dimension.
        """
        from sqlalchemy import func

        # Validate before entering the ``try`` below so the documented
        # ValueError reaches the caller instead of being re-wrapped as a
        # RuntimeError (consistent with ``add``).
        if len(query.query_vector) != self.vector_dim:
            raise ValueError(
                f"Query vector dimension {len(query.query_vector)} "
                f"does not match storage dimension {self.vector_dim}"
            )

        try:
            # Get distance function name
            distance_func_name: str = self._distance_func_map.get(
                self.distance, "l2_distance"
            )
            distance_func = getattr(func, distance_func_name)

            results = self._client.ann_search(
                table_name=self.table_name,
                vec_data=query.query_vector,
                vec_column_name="embedding",
                distance_func=distance_func,
                with_dist=True,
                topk=query.top_k,
                output_column_names=["id", "embedding", "metadata"],
            )

            # Convert results to VectorDBQueryResult format
            query_results: List[VectorDBQueryResult] = []
            for row in results:
                try:
                    query_results.append(
                        self._row_to_result(
                            dict(row._mapping), distance_func_name
                        )
                    )
                except Exception as e:
                    logger.warning(f"Failed to process result row: {e}")
                    continue

            return query_results
        except Exception as e:
            error_msg = f"Failed to query OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _row_to_result(
        self, result_dict: Dict[str, Any], distance_func_name: str
    ) -> VectorDBQueryResult:
        r"""Builds a `VectorDBQueryResult` from one result-row mapping."""
        id_val: str = str(result_dict["id"])
        vector = self._coerce_vector(result_dict.get("embedding"))
        metadata = self._coerce_metadata(result_dict.get("metadata", {}))

        # The distance column is located by name: the first key containing
        # the distance function name wins.
        distance_value: Optional[float] = None
        for key in result_dict:
            if distance_func_name in key:
                distance_value = float(result_dict[key])
                break

        if distance_value is None:
            # If we can't find the distance, use a default value
            logger.warning(
                "Could not find distance value in query results, "
                "using default"
            )
            distance_value = 0.0

        similarity: float = self._convert_distance_to_similarity(
            distance_value
        )

        # A non-numeric record id is stored in metadata by ``add``; restore
        # it as the result id and remove it from the returned payload.
        if isinstance(metadata, dict) and "_id" in metadata:
            id_val = metadata.pop("_id")

        return VectorDBQueryResult.create(
            similarity=similarity,
            vector=vector,
            id=id_val,
            payload=metadata,
        )

    def _coerce_vector(self, raw: Any) -> List[float]:
        r"""Returns the embedding as a list, or zeros if it is unusable."""
        vector: Any = raw
        if isinstance(vector, str):
            # If vector is a string such as "[1.0, 2.0]", try to parse it
            try:
                if vector.startswith('[') and vector.endswith(']'):
                    # Remove brackets and split by commas
                    vector = [
                        float(x.strip()) for x in vector[1:-1].split(',')
                    ]
            except (ValueError, TypeError) as e:
                logger.warning(f"Failed to parse vector string: {e}")

        # Ensure we have a proper vector
        if not isinstance(vector, list) or len(vector) != self.vector_dim:
            logger.warning(f"Invalid vector format, using zeros: {vector}")
            vector = [0.0] * self.vector_dim
        return vector

    @staticmethod
    def _coerce_metadata(raw: Any) -> Any:
        r"""Normalizes the metadata column value for use as a payload.

        Dicts are returned unchanged, strings are decoded as JSON, and any
        other value (or an undecodable string) is wrapped as
        ``{"value": ...}``.
        """
        if isinstance(raw, dict):
            return raw
        try:
            if isinstance(raw, str):
                return json.loads(raw)
            return {"value": raw}
        except Exception:
            return {"value": str(raw)}

    def _convert_distance_to_similarity(self, distance: float) -> float:
        r"""Converts distance to similarity score based on distance metric.

        Args:
            distance (float): Distance reported by the database; negative
                values are clamped to ``0.0``.

        Returns:
            float: ``1 - distance`` clamped to ``[0, 1]`` for cosine (and
                for any unrecognized metric), ``exp(-distance)`` for l2.
        """
        # Ensure distance is non-negative
        distance = max(0.0, distance)

        if self.distance == "cosine":
            # Cosine distance = 1 - cosine similarity
            # Ensure similarity is between 0 and 1
            return max(0.0, min(1.0, 1.0 - distance))
        elif self.distance == "l2":
            import math

            # Exponential decay function for L2 distance
            return math.exp(-distance)
        else:
            # Default normalization, ensure result is between 0 and 1
            return max(0.0, min(1.0, 1.0 - min(1.0, distance)))

    def clear(self) -> None:
        r"""Remove all vectors from the storage.

        Raises:
            RuntimeError: If the records cannot be deleted.
        """
        try:
            self._client.delete(self.table_name)
            logger.info(f"Cleared all records from table {self.table_name}")
        except Exception as e:
            error_msg = f"Failed to clear records from OceanBase: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def load(self) -> None:
        r"""Load the collection hosted on cloud service."""
        # OceanBase doesn't require explicit loading
        pass

    @property
    def client(self) -> "ObVecClient":
        r"""Provides access to underlying OceanBase vector database client."""
        return self._client
|