clap-agents 0.1.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clap/__init__.py +13 -42
- clap/embedding/__init__.py +21 -0
- clap/embedding/base_embedding.py +28 -0
- clap/embedding/fastembed_embedding.py +75 -0
- clap/embedding/ollama_embedding.py +76 -0
- clap/embedding/sentence_transformer_embedding.py +44 -0
- clap/llm_services/__init__.py +15 -0
- clap/llm_services/base.py +3 -6
- clap/llm_services/google_openai_compat_service.py +1 -5
- clap/llm_services/groq_service.py +5 -13
- clap/llm_services/ollama_service.py +101 -0
- clap/mcp_client/client.py +13 -25
- clap/multiagent_pattern/agent.py +107 -34
- clap/multiagent_pattern/team.py +54 -29
- clap/react_pattern/react_agent.py +339 -126
- clap/tool_pattern/tool.py +94 -165
- clap/tool_pattern/tool_agent.py +171 -171
- clap/tools/__init__.py +1 -1
- clap/tools/email_tools.py +16 -19
- clap/tools/web_crawler.py +26 -18
- clap/utils/completions.py +35 -37
- clap/utils/extraction.py +3 -3
- clap/utils/rag_utils.py +183 -0
- clap/vector_stores/__init__.py +16 -0
- clap/vector_stores/base.py +85 -0
- clap/vector_stores/chroma_store.py +142 -0
- clap/vector_stores/qdrant_store.py +155 -0
- {clap_agents-0.1.1.dist-info → clap_agents-0.2.2.dist-info}/METADATA +201 -23
- clap_agents-0.2.2.dist-info/RECORD +38 -0
- clap_agents-0.1.1.dist-info/RECORD +0 -27
- {clap_agents-0.1.1.dist-info → clap_agents-0.2.2.dist-info}/WHEEL +0 -0
- {clap_agents-0.1.1.dist-info → clap_agents-0.2.2.dist-info}/licenses/LICENSE +0 -0
clap/tools/web_crawler.py
CHANGED
@@ -1,17 +1,22 @@
-
 import asyncio
 import json
 import os
 from dotenv import load_dotenv
+from typing import Any
+
+from clap.tool_pattern.tool import tool
 
-
+_CRAWL4AI_AVAILABLE = False
+_AsyncWebCrawler_Placeholder_Type = Any
 
 try:
-    from crawl4ai import AsyncWebCrawler
+    from crawl4ai import AsyncWebCrawler as ImportedAsyncWebCrawler
+    _AsyncWebCrawler_Placeholder_Type = ImportedAsyncWebCrawler
+    _CRAWL4AI_AVAILABLE = True
 except ImportError:
-
+    pass
 
-load_dotenv()
+load_dotenv()
 
 @tool
 async def scrape_url(url: str) -> str:
@@ -24,9 +29,17 @@ async def scrape_url(url: str) -> str:
     Returns:
         The webpage content in markdown format or an error message.
     """
+    if not _CRAWL4AI_AVAILABLE:
+        raise ImportError("The 'crawl4ai' library is required for scrape_url. Install with 'pip install \"clap-agents[standard_tools]\"' or 'pip install crawl4ai'.")
+
     try:
-
-
+        crawler: _AsyncWebCrawler_Placeholder_Type = _AsyncWebCrawler_Placeholder_Type() # type: ignore
+
+        if not hasattr(crawler, 'arun') or not hasattr(crawler, 'close'): # Basic check
+            raise RuntimeError("AsyncWebCrawler from crawl4ai is not correctly initialized (likely due to missing dependency).")
+
+        async with crawler: # type: ignore
+            result = await crawler.arun(url=url) # type: ignore
             return result.markdown.raw_markdown if result.markdown else "No content found"
     except Exception as e:
         return f"Error scraping URL '{url}': {str(e)}"
@@ -44,32 +57,29 @@ async def extract_text_by_query(url: str, query: str, context_size: int = 300) -> str:
     Returns:
         Relevant text snippets containing the query or a message indicating no matches/content.
     """
+    if not _CRAWL4AI_AVAILABLE:
+        raise ImportError("The 'crawl4ai' library is required for extract_text_by_query. Install with 'pip install \"clap-agents[standard_tools]\"' or 'pip install crawl4ai'.")
+
     try:
-        markdown_content = await scrape_url
+        markdown_content = await scrape_url(url=url)
 
         if not markdown_content or markdown_content == "No content found" or markdown_content.startswith("Error"):
-
-            return markdown_content if markdown_content.startswith("Error") else f"Could not retrieve content from URL: {url}"
+            return markdown_content
 
         lower_query = query.lower()
         lower_content = markdown_content.lower()
         matches = []
         start_index = 0
-
         while len(matches) < 5: # Limit matches
             pos = lower_content.find(lower_query, start_index)
-            if pos == -1:
-                break
-
+            if pos == -1: break
             start = max(0, pos - context_size)
             end = min(len(markdown_content), pos + len(lower_query) + context_size)
             context_snippet = markdown_content[start:end]
             prefix = "..." if start > 0 else ""
             suffix = "..." if end < len(markdown_content) else ""
             matches.append(f"{prefix}{context_snippet}{suffix}")
-
             start_index = pos + len(lower_query)
-
         if matches:
             result_text = "\n\n---\n\n".join([f"Match {i+1}:\n{match}" for i, match in enumerate(matches)])
             return f"Found {len(matches)} matches for '{query}' on the page:\n\n{result_text}"
@@ -77,6 +87,4 @@ async def extract_text_by_query(url: str, query: str, context_size: int = 300) -> str:
         return f"No matches found for '{query}' on the page."
 
     except Exception as e:
-        # Catch potential errors during the find/string manipulation logic
         return f"Error processing content from '{url}' for query '{query}': {str(e)}"
-
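The new module-level guard makes crawl4ai an optional dependency: import failure is recorded in _CRAWL4AI_AVAILABLE and surfaced as an ImportError with install instructions only when a tool is actually invoked. A minimal usage sketch, assuming the "standard_tools" extra is installed and relying on the decorated functions staying directly awaitable (as the module's own `await scrape_url(url=url)` call shows):

import asyncio
from clap.tools.web_crawler import scrape_url, extract_text_by_query

async def main():
    # Raises ImportError with install instructions if crawl4ai is missing.
    page = await scrape_url(url="https://example.com")
    print(page[:200])

    snippets = await extract_text_by_query(url="https://example.com", query="domain")
    print(snippets)

asyncio.run(main())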
clap/utils/completions.py
CHANGED
@@ -2,10 +2,9 @@
 
 import asyncio
 from typing import Optional, List, Dict, Any
-#
-# from groq import
-# from groq.types.chat.
-# from groq.types.chat.chat_completion_message import ChatCompletionMessage # Example type hint
+# from groq import Groq
+# from groq.types.chat.chat_completion import ChatCompletion
+# from groq.types.chat.chat_completion_message import ChatCompletionMessage
 from groq import AsyncGroq
 
 
@@ -15,11 +14,11 @@ ChatCompletionMessage = Any
 
 async def completions_create(
     client: AsyncGroq,
-    messages: List[Dict[str, Any]],
+    messages: List[Dict[str, Any]],
     model: str,
-    tools: Optional[List[Dict[str, Any]]] = None,
-    tool_choice: str = "auto"
-) -> ChatCompletionMessage:
+    tools: Optional[List[Dict[str, Any]]] = None,
+    tool_choice: str = "auto"
+) -> ChatCompletionMessage:
     """
     Sends an asynchronous request to the client's completions endpoint, supporting tool use.
 
@@ -34,7 +33,7 @@ async def completions_create(
         The message object from the API response, which might contain content or tool calls.
     """
     try:
-
+
         api_kwargs = {
             "messages": messages,
             "model": model,
@@ -43,15 +42,15 @@ async def completions_create(
            api_kwargs["tools"] = tools
            api_kwargs["tool_choice"] = tool_choice
 
-
+
        response = await client.chat.completions.create(**api_kwargs)
-
+
        return response.choices[0].message
    except Exception as e:
-
+
        print(f"Error calling LLM API asynchronously: {e}")
-
-
+
+
        class ErrorMessage:
            content = f"Error communicating with LLM: {e}"
            tool_calls = None
@@ -61,10 +60,10 @@ async def completions_create(
 
 def build_prompt_structure(
     role: str,
-    content: Optional[str] = None,
+    content: Optional[str] = None,
     tag: str = "",
-    tool_calls: Optional[List[Dict[str, Any]]] = None,
-    tool_call_id: Optional[str] = None
+    tool_calls: Optional[List[Dict[str, Any]]] = None,
+    tool_call_id: Optional[str] = None
 ) -> dict:
     """
     Builds a structured message dictionary for the chat API.
@@ -85,17 +84,17 @@ def build_prompt_structure(
         content = f"<{tag}>{content}</{tag}>"
     message["content"] = content
 
-
+
     if role == "assistant" and tool_calls:
         message["tool_calls"] = tool_calls
 
-
+
     if role == "tool" and tool_call_id:
         message["tool_call_id"] = tool_call_id
-        if content is None:
+        if content is None:
             raise ValueError("Content is required for role 'tool'.")
 
-
+
     if role == "tool" and not tool_call_id:
         raise ValueError("tool_call_id is required for role 'tool'.")
     if role != "assistant" and tool_calls:
@@ -106,7 +105,7 @@ def build_prompt_structure(
 
 def update_chat_history(
     history: list,
-    message: ChatCompletionMessage | Dict[str, Any]
+    message: ChatCompletionMessage | Dict[str, Any]
 ):
     """
     Updates the chat history by appending a message object or a manually created message dict.
@@ -115,37 +114,37 @@ def update_chat_history(
         history (list): The list representing the current chat history.
         message: The message object from the API response or a dict created by build_prompt_structure.
     """
-
+
     if hasattr(message, "role"): # Basic check if it looks like an API message object
         msg_dict = {"role": message.role}
         if hasattr(message, "content") and message.content is not None:
             msg_dict["content"] = message.content
         if hasattr(message, "tool_calls") and message.tool_calls:
-
+
             msg_dict["tool_calls"] = message.tool_calls
-
+
         history.append(msg_dict)
     elif isinstance(message, dict) and "role" in message:
-
+
         history.append(message)
     else:
         raise TypeError("Invalid message type provided to update_chat_history.")
 
 
 class ChatHistory(list):
-    def __init__(self, messages: Optional[List[Dict[str, Any]]] = None, total_length: int = -1):
+    def __init__(self, messages: Optional[List[Dict[str, Any]]] = None, total_length: int = -1):
         if messages is None:
             messages = []
         super().__init__(messages)
-        self.total_length = total_length
+        self.total_length = total_length
 
-    def append(self, msg: Dict[str, Any]):
+    def append(self, msg: Dict[str, Any]):
         if not isinstance(msg, dict) or "role" not in msg:
             raise TypeError("ChatHistory can only append message dictionaries with a 'role'.")
 
-
+
         if self.total_length > 0 and len(self) == self.total_length:
-            self.pop(0)
+            self.pop(0)
         super().append(msg)
 
 
@@ -157,17 +156,16 @@ class FixedFirstChatHistory(ChatHistory):
         if not isinstance(msg, dict) or "role" not in msg:
             raise TypeError("ChatHistory can only append message dictionaries with a 'role'.")
 
-
+
         if self.total_length > 0 and len(self) == self.total_length:
-            if len(self) > 1:
-                self.pop(1)
+            if len(self) > 1:
+                self.pop(1)
             else:
-
+
                 print("Warning: Cannot append to FixedFirstChatHistory of size 1.")
                 return
-
+
         if self.total_length <= 0 or len(self) < self.total_length:
             super().append(msg)
 
 
-# --- END OF ASYNC MODIFIED completions.py ---
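The changes to this module are mostly whitespace; the behavior of build_prompt_structure and the bounded histories is unchanged. A short sketch of how they compose, with names taken from this module and eviction behavior following the code above:

from clap.utils.completions import FixedFirstChatHistory, build_prompt_structure

history = FixedFirstChatHistory(
    [build_prompt_structure(role="system", content="You are a helpful assistant.")],
    total_length=3,
)
history.append(build_prompt_structure(role="user", content="First question"))
history.append(build_prompt_structure(role="assistant", content="First answer"))

# At capacity: the next append evicts index 1 (the oldest non-system message),
# so the system prompt at index 0 is always preserved.
history.append(build_prompt_structure(role="user", content="Second question"))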
clap/utils/extraction.py
CHANGED
@@ -29,13 +29,13 @@ def extract_tag_content(text: str, tag: str) -> TagContentResult:
             - 'content' (list): A list of strings containing the content found between the specified tags.
             - 'found' (bool): A flag indicating whether any content was found for the given tag.
     """
-
+
     tag_pattern = rf"<{tag}>(.*?)</{tag}>"
 
-
+
     matched_contents = re.findall(tag_pattern, text, re.DOTALL)
 
-
+
     return TagContentResult(
         content=[content.strip() for content in matched_contents],
         found=bool(matched_contents),
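For reference, extract_tag_content drives the tag-based parsing used elsewhere in the package. A small sketch, assuming TagContentResult exposes content and found as attributes per the docstring above:

from clap.utils.extraction import extract_tag_content

result = extract_tag_content("<thought>Check the docs first.</thought>", "thought")
# Expected: result.content == ["Check the docs first."] and result.found is True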
clap/utils/rag_utils.py
ADDED
@@ -0,0 +1,183 @@
+
+import csv
+from typing import List, Dict, Any, Tuple, Optional, Union
+
+
+try:
+    import pypdf
+except ImportError:
+    raise ImportError(
+        "pypdf not found. Please install it for PDF loading: pip install pypdf"
+    )
+
+
+
+def load_text_file(file_path: str) -> str:
+    """Loads text content from a file."""
+    try:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            return f.read()
+    except Exception as e:
+        print(f"Error loading text file {file_path}: {e}")
+        return ""
+
+def load_pdf_file(file_path: str) -> str:
+    """Loads text content from a PDF file."""
+    text = ""
+    try:
+        with open(file_path, 'rb') as f:
+            reader = pypdf.PdfReader(f)
+            print(f"Loading PDF '{file_path}' with {len(reader.pages)} pages...")
+            for i, page in enumerate(reader.pages):
+                page_text = page.extract_text()
+                if page_text:
+                    text += page_text + "\n"
+                else:
+                    print(f"Warning: No text extracted from page {i+1} of {file_path}")
+        print(f"Finished loading PDF '{file_path}'.")
+        return text.strip()
+    except FileNotFoundError:
+        print(f"Error: PDF file not found at {file_path}")
+        return ""
+    except Exception as e:
+        print(f"Error loading PDF file {file_path}: {e}")
+        return ""
+
+
+def load_csv_file(
+    file_path: str,
+    content_column: Union[str, int],
+    metadata_columns: Optional[List[Union[str, int]]] = None,
+    delimiter: str = ',',
+    encoding: str = 'utf-8'
+) -> List[Tuple[str, Dict[str, Any]]]:
+    """
+    Loads data from a CSV file, extracting content and metadata.
+
+    Each row is treated as a potential document/chunk.
+
+    Args:
+        file_path: Path to the CSV file.
+        content_column: The name (string) or index (int) of the column containing the main text content.
+        metadata_columns: Optional list of column names (string) or indices (int)
+                          to include as metadata for each row.
+        delimiter: CSV delimiter (default ',').
+        encoding: File encoding (default 'utf-8').
+
+    Returns:
+        A list of tuples, where each tuple contains:
+        (document_text: str, metadata: dict)
+    """
+    data = []
+    metadata_columns = metadata_columns or []
+    try:
+        with open(file_path, mode='r', encoding=encoding, newline='') as f:
+
+            has_header = isinstance(content_column, str) or any(isinstance(mc, str) for mc in metadata_columns)
+
+            if has_header:
+                reader = csv.DictReader(f, delimiter=delimiter)
+                headers = reader.fieldnames
+                if headers is None:
+                    print(f"Error: Could not read headers from CSV {file_path}")
+                    return []
+
+
+                if isinstance(content_column, str) and content_column not in headers:
+                    raise ValueError(f"Content column '{content_column}' not found in CSV headers: {headers}")
+                for mc in metadata_columns:
+                    if isinstance(mc, str) and mc not in headers:
+                        raise ValueError(f"Metadata column '{mc}' not found in CSV headers: {headers}")
+
+                content_key = content_column
+                meta_keys = [mc for mc in metadata_columns if isinstance(mc, str)]
+
+            else:
+
+                reader = csv.reader(f, delimiter=delimiter)
+
+                content_key = int(content_column)
+                meta_keys = [int(mc) for mc in metadata_columns]
+
+
+            print(f"Loading CSV '{file_path}'...")
+            for i, row in enumerate(reader):
+                try:
+                    if has_header:
+
+                        doc_text = row.get(content_key, "").strip()
+                        metadata = {key: row.get(key, "") for key in meta_keys}
+                    else:
+
+                        if content_key >= len(row): continue
+                        doc_text = row[content_key].strip()
+                        metadata = {}
+                        for key_index in meta_keys:
+                            if key_index < len(row):
+
+                                metadata[f"column_{key_index}"] = row[key_index]
+
+
+                    if doc_text:
+                        metadata["source_row"] = i + (1 if has_header else 0)
+                        data.append((doc_text, metadata))
+                except IndexError:
+                    print(f"Warning: Skipping row {i} due to index out of bounds (check column indices).")
+                except Exception as row_e:
+                    print(f"Warning: Skipping row {i} due to error: {row_e}")
+
+
+        print(f"Finished loading CSV '{file_path}', processed {len(data)} rows with content.")
+        return data
+
+    except FileNotFoundError:
+        print(f"Error: CSV file not found at {file_path}")
+        return []
+    except ValueError as ve:
+        print(f"Error processing CSV header/indices for {file_path}: {ve}")
+        return []
+    except Exception as e:
+        print(f"Error loading CSV file {file_path}: {e}")
+        return []
+
+
+
+
+def chunk_text_by_fixed_size(
+    text: str, chunk_size: int, chunk_overlap: int = 0
+) -> List[str]:
+    """Chunks text into fixed size blocks with optional overlap."""
+    if not isinstance(text, str):
+        print(f"Warning: chunk_text_by_fixed_size expected string, got {type(text)}. Skipping.")
+        return []
+    if chunk_overlap >= chunk_size:
+
+        raise ValueError("chunk_overlap must be smaller than chunk_size")
+    if chunk_size <= 0:
+        raise ValueError("chunk_size must be positive")
+
+    chunks = []
+    start = 0
+    while start < len(text):
+        end = start + chunk_size
+        chunks.append(text[start:end])
+
+        step = chunk_size - chunk_overlap
+        if step <= 0:
+
+            step = 1
+
+        start += step
+
+    return [chunk for chunk in chunks if chunk.strip()]
+
+
+def chunk_text_by_separator(text: str, separator: str = "\n\n") -> List[str]:
+    """Chunks text based on a specified separator."""
+    if not isinstance(text, str):
+        print(f"Warning: chunk_text_by_separator expected string, got {type(text)}. Skipping.")
+        return []
+    chunks = text.split(separator)
+    return [chunk for chunk in chunks if chunk.strip()]
+
+
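The fixed-size chunker advances by chunk_size - chunk_overlap characters per step and drops whitespace-only chunks. A worked example against the code above:

from clap.utils.rag_utils import chunk_text_by_fixed_size, chunk_text_by_separator

text = "A" * 250
chunks = chunk_text_by_fixed_size(text, chunk_size=100, chunk_overlap=20)
# Step is 100 - 20 = 80, so chunks start at offsets 0, 80, 160, 240:
assert [len(c) for c in chunks] == [100, 100, 90, 10]

paragraphs = chunk_text_by_separator("Intro.\n\nBody.\n\n\n")
assert paragraphs == ["Intro.", "Body."]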
clap/vector_stores/__init__.py
ADDED
@@ -0,0 +1,16 @@
+from .base import VectorStoreInterface, QueryResult, Document, Embedding, ID, Metadata
+
+__all__ = ["VectorStoreInterface", "QueryResult", "Document", "Embedding", "ID", "Metadata"]
+
+try:
+    from .chroma_store import ChromaStore
+    __all__.append("ChromaStore")
+except ImportError:
+    pass
+
+try:
+    from .qdrant_store import QdrantStore
+    __all__.append("QdrantStore")
+except ImportError:
+    pass
+
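Both store backends are optional extras, so availability can be feature-detected through __all__ at runtime; for example:

import clap.vector_stores as vector_stores

# Each optional backend registers itself in __all__ only if its
# client library imported successfully.
print("ChromaStore available:", "ChromaStore" in vector_stores.__all__)
print("QdrantStore available:", "QdrantStore" in vector_stores.__all__)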
clap/vector_stores/base.py
ADDED
@@ -0,0 +1,85 @@
+
+import abc
+from typing import Any, Dict, List, Optional, TypedDict, Union
+
+Document = str
+Embedding = List[float]
+ID = str
+Metadata = Dict[str, Any]
+
+class QueryResult(TypedDict):
+    ids: List[List[ID]]
+    embeddings: Optional[List[List[Embedding]]]
+    documents: Optional[List[List[Document]]]
+    metadatas: Optional[List[List[Metadata]]]
+    distances: Optional[List[List[float]]]
+
+class VectorStoreInterface(abc.ABC):
+    """Abstract Base Class for Vector Store interactions."""
+
+    @abc.abstractmethod
+    async def add_documents(
+        self,
+        documents: List[Document],
+        ids: List[ID],
+        metadatas: Optional[List[Metadata]] = None,
+        embeddings: Optional[List[Embedding]] = None,
+    ) -> None:
+        """
+        Add documents and their embeddings to the store.
+        If embeddings are not provided, the implementation should handle embedding generation.
+
+        Args:
+            documents: List of document texts.
+            ids: List of unique IDs for each document.
+            metadatas: Optional list of metadata dictionaries for each document.
+            embeddings: Optional list of pre-computed embeddings.
+        """
+        pass
+
+    @abc.abstractmethod
+    async def aquery(
+        self,
+        query_texts: Optional[List[Document]] = None,
+        query_embeddings: Optional[List[Embedding]] = None,
+        n_results: int = 5,
+        where: Optional[Dict[str, Any]] = None,
+        where_document: Optional[Dict[str, Any]] = None,
+        include: List[str] = ["metadatas", "documents", "distances"],
+    ) -> QueryResult:
+        """
+        Query the vector store for similar documents.
+        Provide either query_texts or query_embeddings.
+
+        Args:
+            query_texts: List of query texts. Embeddings will be generated.
+            query_embeddings: List of query embeddings.
+            n_results: Number of results to return for each query.
+            where: Optional metadata filter (syntax depends on implementation).
+            where_document: Optional document content filter (syntax depends on implementation).
+            include: List of fields to include in the results (e.g., "documents", "metadatas", "distances", "embeddings").
+
+        Returns:
+            A QueryResult dictionary containing the search results.
+        """
+        pass
+
+    @abc.abstractmethod
+    async def adelete(
+        self,
+        ids: Optional[List[ID]] = None,
+        where: Optional[Dict[str, Any]] = None,
+        where_document: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Delete documents from the store by ID or filter.
+
+        Args:
+            ids: Optional list of IDs to delete.
+            where: Optional metadata filter for deletion.
+            where_document: Optional document content filter for deletion.
+        """
+        pass
+
+
+
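To illustrate the contract, here is a toy in-memory implementation of the interface above (a sketch for this diff only; the package's real backends are ChromaStore and QdrantStore, and a production store would also generate embeddings and honor the where filters):

import math
from clap.vector_stores.base import VectorStoreInterface, QueryResult

class InMemoryStore(VectorStoreInterface):
    def __init__(self):
        self._docs = {}  # id -> (document, metadata, embedding)

    async def add_documents(self, documents, ids, metadatas=None, embeddings=None):
        # This toy store does not generate embeddings itself.
        if embeddings is None:
            raise ValueError("InMemoryStore requires pre-computed embeddings.")
        metadatas = metadatas or [{} for _ in ids]
        for doc, id_, meta, emb in zip(documents, ids, metadatas, embeddings):
            self._docs[id_] = (doc, meta, emb)

    async def aquery(self, query_texts=None, query_embeddings=None, n_results=5,
                     where=None, where_document=None,
                     include=["metadatas", "documents", "distances"]) -> QueryResult:
        ids, docs, metas, dists = [], [], [], []
        for q in query_embeddings or []:
            # Rank stored documents by Euclidean distance to the query embedding.
            ranked = sorted(self._docs.items(),
                            key=lambda kv: math.dist(kv[1][2], q))[:n_results]
            ids.append([k for k, _ in ranked])
            docs.append([v[0] for _, v in ranked])
            metas.append([v[1] for _, v in ranked])
            dists.append([math.dist(v[2], q) for _, v in ranked])
        return QueryResult(ids=ids, embeddings=None, documents=docs,
                           metadatas=metas, distances=dists)

    async def adelete(self, ids=None, where=None, where_document=None):
        for id_ in ids or []:
            self._docs.pop(id_, None)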