ws-bom-robot-app 0.0.63__py3-none-any.whl → 0.0.103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ws_bom_robot_app/config.py +30 -8
- ws_bom_robot_app/cron_manager.py +13 -12
- ws_bom_robot_app/llm/agent_context.py +1 -1
- ws_bom_robot_app/llm/agent_handler.py +11 -12
- ws_bom_robot_app/llm/agent_lcel.py +80 -18
- ws_bom_robot_app/llm/api.py +69 -7
- ws_bom_robot_app/llm/evaluator.py +319 -0
- ws_bom_robot_app/llm/main.py +51 -28
- ws_bom_robot_app/llm/models/api.py +40 -6
- ws_bom_robot_app/llm/nebuly_handler.py +18 -15
- ws_bom_robot_app/llm/providers/llm_manager.py +233 -75
- ws_bom_robot_app/llm/tools/tool_builder.py +4 -1
- ws_bom_robot_app/llm/tools/tool_manager.py +48 -22
- ws_bom_robot_app/llm/utils/chunker.py +6 -1
- ws_bom_robot_app/llm/utils/cleanup.py +81 -0
- ws_bom_robot_app/llm/utils/cms.py +60 -14
- ws_bom_robot_app/llm/utils/download.py +112 -8
- ws_bom_robot_app/llm/vector_store/db/base.py +50 -0
- ws_bom_robot_app/llm/vector_store/db/chroma.py +28 -8
- ws_bom_robot_app/llm/vector_store/db/faiss.py +35 -8
- ws_bom_robot_app/llm/vector_store/db/qdrant.py +29 -14
- ws_bom_robot_app/llm/vector_store/integration/api.py +216 -0
- ws_bom_robot_app/llm/vector_store/integration/azure.py +1 -1
- ws_bom_robot_app/llm/vector_store/integration/base.py +58 -15
- ws_bom_robot_app/llm/vector_store/integration/confluence.py +33 -5
- ws_bom_robot_app/llm/vector_store/integration/dropbox.py +1 -1
- ws_bom_robot_app/llm/vector_store/integration/gcs.py +1 -1
- ws_bom_robot_app/llm/vector_store/integration/github.py +22 -22
- ws_bom_robot_app/llm/vector_store/integration/googledrive.py +46 -17
- ws_bom_robot_app/llm/vector_store/integration/jira.py +93 -60
- ws_bom_robot_app/llm/vector_store/integration/manager.py +6 -2
- ws_bom_robot_app/llm/vector_store/integration/s3.py +1 -1
- ws_bom_robot_app/llm/vector_store/integration/sftp.py +1 -1
- ws_bom_robot_app/llm/vector_store/integration/sharepoint.py +7 -14
- ws_bom_robot_app/llm/vector_store/integration/shopify.py +143 -0
- ws_bom_robot_app/llm/vector_store/integration/sitemap.py +6 -1
- ws_bom_robot_app/llm/vector_store/integration/slack.py +3 -2
- ws_bom_robot_app/llm/vector_store/integration/thron.py +236 -0
- ws_bom_robot_app/llm/vector_store/loader/base.py +52 -8
- ws_bom_robot_app/llm/vector_store/loader/docling.py +71 -33
- ws_bom_robot_app/main.py +148 -146
- ws_bom_robot_app/subprocess_runner.py +106 -0
- ws_bom_robot_app/task_manager.py +204 -53
- ws_bom_robot_app/util.py +6 -0
- {ws_bom_robot_app-0.0.63.dist-info → ws_bom_robot_app-0.0.103.dist-info}/METADATA +158 -75
- ws_bom_robot_app-0.0.103.dist-info/RECORD +76 -0
- ws_bom_robot_app/llm/settings.py +0 -4
- ws_bom_robot_app/llm/utils/kb.py +0 -34
- ws_bom_robot_app-0.0.63.dist-info/RECORD +0 -72
- {ws_bom_robot_app-0.0.63.dist-info → ws_bom_robot_app-0.0.103.dist-info}/WHEEL +0 -0
- {ws_bom_robot_app-0.0.63.dist-info → ws_bom_robot_app-0.0.103.dist-info}/top_level.txt +0 -0
ws_bom_robot_app/llm/utils/cms.py

```diff
@@ -1,8 +1,8 @@
 import logging, aiohttp
-from typing import List, Optional
-
+from typing import Any, List, Optional
 from pydantic import AliasChoices, BaseModel, ConfigDict, Field
-from ws_bom_robot_app.llm.models.api import LlmAppTool
+from ws_bom_robot_app.llm.models.api import LlmAppTool, LlmRules, StreamRequest
+from ws_bom_robot_app.llm.models.kb import LlmKbEndpoint, LlmKbIntegration
 from ws_bom_robot_app.util import cache_with_ttl

 class CmsAppCredential(BaseModel):
```
```diff
@@ -12,13 +12,16 @@ class CmsAppCredential(BaseModel):
 class CmsApp(BaseModel):
   id: str = Field(..., description="Unique identifier for the app")
   name: str = Field(..., description="Name of the app")
+  mode: str
+  prompt_samples: Optional[List[str]]
   credentials: CmsAppCredential = None
-
+  rq: StreamRequest
+  kb: Optional[Any] = None
   model_config = ConfigDict(extra='ignore')

 @cache_with_ttl(600) # Cache for 10 minutes
 async def get_apps() -> list[CmsApp]:
-  import json
+  import json
   from ws_bom_robot_app.config import config
   class DictObject(object):
     def __init__(self, dict_):
```
```diff
@@ -34,9 +37,17 @@ async def get_apps() -> list[CmsApp]:
       if obj is None:
         break
     return obj
+  def __to_dict(obj):
+    """Converts DictObject to dict recursively"""
+    if isinstance(obj, DictObject):
+      return {k: __to_dict(v) for k, v in obj.__dict__.items()}
+    elif isinstance(obj, list):
+      return [__to_dict(item) for item in obj]
+    else:
+      return obj
   host = config.robot_cms_host
   if host:
-    url = f"{host}/api/llmApp?depth=1&pagination=false"
+    url = f"{host}/api/llmApp?depth=1&pagination=false&locale=it"
     auth = config.robot_cms_auth
     headers = {"Authorization": auth} if auth else {}
     async with aiohttp.ClientSession() as session:
```
```diff
@@ -47,15 +58,49 @@ async def get_apps() -> list[CmsApp]:
       for cms_app in cms_apps:
         if __attr(cms_app,"isActive",default=True) == True:
           _cms_app_dict = DictObject.from_dict(cms_app)
-
-
-
-
-
-
-
-
+          try:
+            _app: CmsApp = CmsApp(
+              id=_cms_app_dict.id,
+              name=_cms_app_dict.name,
+              mode=_cms_app_dict.mode,
+              prompt_samples=[__attr(sample,'sampleInputText') or f"{sample.__dict__}" for sample in _cms_app_dict.contents.sampleInputTexts],
+              credentials=CmsAppCredential(app_key=_cms_app_dict.settings.credentials.appKey,api_key=_cms_app_dict.settings.credentials.apiKey),
+              rq=StreamRequest(
+                #thread_id=str(uuid.uuid1()),
+                messages=[],
+                secrets={
+                  "apiKey": __attr(_cms_app_dict.settings,'llmConfig','secrets','apiKey', default=''),
+                  "langChainApiKey": __attr(_cms_app_dict.settings,'llmConfig','secrets','langChainApiKey', default=''),
+                  "nebulyApiKey": __attr(_cms_app_dict.settings,'llmConfig','secrets','nebulyApiKey', default=''),
+                },
+                system_message=__attr(_cms_app_dict.settings,'llmConfig','prompt','prompt','systemMessage') if __attr(_cms_app_dict.settings,'llmConfig','prompt','prompt','systemMessage') else __attr(_cms_app_dict.settings,'llmConfig','prompt','systemMessage'),
+                provider= __attr(_cms_app_dict.settings,'llmConfig','provider') or 'openai',
+                model= __attr(_cms_app_dict.settings,'llmConfig','model') or 'gpt-4o',
+                temperature=_cms_app_dict.settings.llmConfig.temperature or 0,
+                app_tools=[LlmAppTool(**tool) for tool in cms_app.get('settings').get('appTools',[])],
+                rules=LlmRules(
+                  vector_type=__attr(_cms_app_dict.settings,'rules','vectorDbType', default='faiss'),
+                  vector_db=__attr(_cms_app_dict.settings,'rules','vectorDbFile','filename'),
+                  threshold=__attr(_cms_app_dict.settings,'rules','threshold', default=0.7)
+                ) if __attr(_cms_app_dict.settings,'rules','vectorDbFile','filename') else None,
+                #fine_tuned_model=__attr(_cms_app_dict.settings,'llmConfig','fineTunedModel'),
+                lang_chain_tracing= __attr(_cms_app_dict.settings,'llmConfig','langChainTracing', default=False),
+                lang_chain_project= __attr(_cms_app_dict.settings,'llmConfig','langChainProject', default=''),
+                output_structure= __to_dict(__attr(_cms_app_dict.settings,'llmConfig','outputStructure')) if __attr(_cms_app_dict.settings,'llmConfig','outputStructure') else None
+              ))
+          except Exception as e:
+            import traceback
+            ex = traceback.format_exc()
+            logging.error(f"Error creating CmsApp {_cms_app_dict.name} from dict: {e}\n{ex}")
+            continue
+          if _app.rq.app_tools:
+            for tool in _app.rq.app_tools:
               _knowledgeBase = tool.knowledgeBase
+              tool.integrations = [LlmKbIntegration(**item) for item in _knowledgeBase.get('integrations')] if _knowledgeBase.get('integrations') else []
+              try:
+                tool.endpoints = [LlmKbEndpoint(**item) for item in _knowledgeBase.get('externalEndpoints')] if _knowledgeBase.get('externalEndpoints') else []
+              except Exception as e:
+                logging.error(f"Error parsing endpoints for app {_cms_app_dict.name} tool {tool.name}: {e}")
               tool.vector_db = _knowledgeBase.get('vectorDbFile').get('filename') if _knowledgeBase.get('vectorDbFile') else None
               tool.vector_type = _knowledgeBase.get('vectorDbType') if _knowledgeBase.get('vectorDbType') else 'faiss'
               del tool.knowledgeBase

@@ -67,6 +112,7 @@ async def get_apps() -> list[CmsApp]:
     logging.error("robot_cms_host environment variable is not set.")
     return []

+
 async def get_app_by_id(app_id: str) -> CmsApp | None:
   apps = await get_apps()
   app = next((a for a in apps if a.id == app_id), None)
```
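The rebuilt `get_apps` leans on two tiny helpers: `DictObject`, which wraps the CMS JSON so nested keys read as attributes, and a safe walker (`__attr` in the diff) that returns a default instead of raising when a field is missing. A minimal standalone sketch of that pattern, with hypothetical names (`attr` here; the real helpers are nested inside `get_apps`):

```python
# Sketch only: mirrors the DictObject/__attr idea from cms.py, not the package's exact code.
from typing import Any

class DictObject:
    """Wrap a dict so nested keys become attributes (lists are left as-is in this sketch)."""
    def __init__(self, dict_: dict):
        for key, value in dict_.items():
            setattr(self, key, DictObject(value) if isinstance(value, dict) else value)

    @classmethod
    def from_dict(cls, dict_: dict) -> "DictObject":
        return cls(dict_)

def attr(obj: Any, *path: str, default: Any = None) -> Any:
    """Walk a chain of attributes, returning `default` as soon as one is missing."""
    for name in path:
        obj = getattr(obj, name, None)
        if obj is None:
            return default
    return obj

cms_app = DictObject.from_dict({"settings": {"llmConfig": {"model": "gpt-4o"}}})
assert attr(cms_app, "settings", "llmConfig", "model") == "gpt-4o"
assert attr(cms_app, "settings", "rules", "threshold", default=0.7) == 0.7  # missing -> default
```

This is why a sparsely filled CMS entry degrades to defaults such as `provider or 'openai'` rather than failing; entries malformed beyond that are skipped by the surrounding `try/except` with a logged traceback.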
ws_bom_robot_app/llm/utils/download.py

```diff
@@ -1,6 +1,13 @@
+import httpx
 from typing import List,Optional
-import os, logging, aiohttp, asyncio
+import os, logging, aiohttp, asyncio, hashlib, json
+import uuid
+from pydantic import BaseModel
+import base64, requests, mimetypes
+from urllib.parse import urlparse
 from tqdm.asyncio import tqdm
+from ws_bom_robot_app.config import config
+import aiofiles

 async def download_files(urls: List[str], destination_folder: str, authorization: str = None):
   tasks = [download_file(file, os.path.join(destination_folder, os.path.basename(file)), authorization=authorization) for file in urls]
```
```diff
@@ -28,14 +35,13 @@ async def download_file(url: str, destination: str, chunk_size: int = 8192, auth
     # Ensure the destination directory exists
     os.makedirs(os.path.dirname(os.path.abspath(destination)), exist_ok=True)

-    async with aiohttp.ClientSession() as session:
+    async with httpx.AsyncClient(timeout=30.0) as client:
       if authorization:
         headers = {'Authorization': authorization}
-
-      async with session.get(url) as response:
+      async with client.stream("GET", url, headers=headers) as response:
         # Check if the request was successful
-        if response.status != 200:
-          logging.error(f"Failed to download file. Status code: {response.status}")
+        if response.status_code != 200:
+          logging.error(f"Failed to download file. Status code: {response.status_code}")
           return None

         # Get the total file size if available
```
```diff
@@ -49,7 +55,7 @@ async def download_file(url: str, destination: str, chunk_size: int = 8192, auth
           unit_scale=True,
           unit_divisor=1024
         ) as pbar:
-          async for chunk in response.content.iter_chunked(chunk_size):
+          async for chunk in response.aiter_bytes(chunk_size):
             if chunk:
               f.write(chunk)
               pbar.update(len(chunk))
```
```diff
@@ -57,7 +63,7 @@ async def download_file(url: str, destination: str, chunk_size: int = 8192, auth
     logging.info(f"File downloaded successfully to {destination}")
     return destination

-  except aiohttp.ClientError as e:
+  except httpx.RequestError as e:
     logging.error(f"Network error occurred: {str(e)}")
     return None
   except asyncio.TimeoutError:
```
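`download_file` now streams through `httpx` (`client.stream` plus `aiter_bytes`) with an explicit 30-second timeout instead of an `aiohttp` session. A usage sketch, assuming only the signatures visible in this diff; URLs and paths are placeholders:

```python
import asyncio
from ws_bom_robot_app.llm.utils.download import download_file, download_files

async def main():
    # Single download: returns the destination path on success, None on any failure
    # (non-200 status, httpx.RequestError, or asyncio.TimeoutError).
    path = await download_file(
        "https://example.com/kb/export.zip",
        "/tmp/kb/export.zip",
        authorization="Bearer <token>",
    )
    if path is None:
        print("download failed; see logs")

    # Fan out several files concurrently into one folder.
    await download_files(
        ["https://example.com/a.pdf", "https://example.com/b.pdf"],
        "/tmp/kb",
    )

asyncio.run(main())
```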
```diff
@@ -77,3 +83,101 @@ async def download_file(url: str, destination: str, chunk_size: int = 8192, auth
         logging.info(f"Cleaned up incomplete download: {destination}")
       except OSError:
         pass
+
+class Base64File(BaseModel):
+  """Base64 encoded file representation"""
+  url: str
+  base64_url: str
+  base64_content: str
+  name: str
+  extension: str
+  mime_type: str
+
+  @staticmethod
+  def _is_base64_data_uri(url: str) -> bool:
+    """Check if URL is already a base64 data URI"""
+    return (isinstance(url, str) and
+            url.startswith('data:') and
+            ';base64,' in url and
+            len(url.split(',')) == 2)
+
+  async def from_url(url: str) -> "Base64File":
+    """Download file and return as base64 data URI"""
+    def _cache_file(url: str) -> str:
+      _hash = hashlib.md5(url.encode()).hexdigest()
+      return os.path.join(config.robot_data_folder, config.robot_data_attachment_folder, f"{_hash}.json")
+    async def from_cache(url: str) -> "Base64File":
+      """Check if file is already downloaded and return data"""
+      _file = _cache_file(url)
+      if os.path.exists(_file):
+        try:
+          async with aiofiles.open(_file, 'rb') as f:
+            content = await f.read()
+            return Base64File(**json.loads(content))
+        except Exception as e:
+          logging.error(f"Error reading cache file {_file}: {e}")
+          return None
+      return None
+    async def to_cache(file: "Base64File", url: str) -> None:
+      """Save file to cache"""
+      _file = _cache_file(url)
+      try:
+        async with aiofiles.open(_file, 'wb') as f:
+          await f.write(file.model_dump_json().encode('utf-8'))
+      except Exception as e:
+        logging.error(f"Error writing cache file {_file}: {e}")
+
+    # special case: base64 data URI
+    if Base64File._is_base64_data_uri(url):
+      mime_type = url.split(';')[0].replace('data:', '')
+      base64_content = url.split(',')[1]
+      extension=mime_type.split('/')[-1]
+      name = f"file-{uuid.uuid4()}.{extension}"
+      return Base64File(
+        url=url,
+        base64_url=url,
+        base64_content=base64_content,
+        name=name,
+        extension=extension,
+        mime_type=mime_type
+      )
+
+    # default download
+    _error = None
+    try:
+      if _content := await from_cache(url):
+        return _content
+      async with httpx.AsyncClient(timeout=30.0) as client:
+        response = await client.get(url, headers={"User-Agent": "Mozilla/5.0"})
+        logging.info(f"Downloading {url} - Status: {response.status_code}")
+        response.raise_for_status()
+        content = response.read()
+        # mime type detection
+        mime_type = response.headers.get('content-type', '').split(';')[0]
+        if not mime_type:
+          mime_type, _ = mimetypes.guess_type(urlparse(url).path)
+          if not mime_type:
+            mime_type = 'application/octet-stream'
+        # to base64
+        base64_content = base64.b64encode(content).decode('utf-8')
+        name = url.split('/')[-1]
+        extension = name.split('.')[-1]
+    except Exception as e:
+      _error = f"Failed to download file from {url}: {e}"
+      logging.error(_error)
+      base64_content = base64.b64encode(_error.encode('utf-8')).decode('utf-8')
+      name = "download_error.txt"
+      mime_type = "text/plain"
+      extension = "txt"
+
+    _file = Base64File(
+      url=url,
+      base64_url= f"data:{mime_type};base64,{base64_content}",
+      base64_content=base64_content,
+      name=name,
+      extension=extension,
+      mime_type=mime_type
+    )
+    if not _error:
+      await to_cache(_file, url)
+    return _file
```
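`Base64File.from_url` normalizes three cases: a `data:` URI is parsed in place, a new URL is downloaded and cached as `<md5(url)>.json` under the configured attachment folder, and a repeated URL is answered from that cache. Notably, a failed download is not raised; the error message itself is encoded as a `text/plain` payload. A usage sketch under those assumptions (the URL case needs network access and a writable `config.robot_data_folder`):

```python
import asyncio
from ws_bom_robot_app.llm.utils.download import Base64File

async def main():
    # Plain URL: fetched once, then served from the md5-keyed JSON cache on later calls.
    file = await Base64File.from_url("https://example.com/logo.png")
    print(file.name, file.mime_type, file.base64_url[:40])

    # data: URI passthrough: no network call; a synthetic name is generated and the
    # extension is derived from the mime type (here 'plain', from text/plain).
    inline = await Base64File.from_url("data:text/plain;base64,aGVsbG8=")
    print(inline.name, inline.extension)

asyncio.run(main())
```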
ws_bom_robot_app/llm/vector_store/db/base.py

```diff
@@ -7,6 +7,7 @@ from langchain_core.language_models import BaseChatModel
 from langchain_core.vectorstores.base import VectorStoreRetriever, VectorStore
 from langchain.retrievers import SelfQueryRetriever
 from langchain.chains.query_constructor.schema import AttributeInfo
+import tiktoken

 class VectorDBStrategy(ABC):
```
```diff
@@ -49,6 +50,52 @@ class VectorDBStrategy(ABC):
   Asynchronously invokes multiple retrievers in parallel, then merges
   their results while removing duplicates.
   """
+  MAX_TOKENS_PER_BATCH = 300_000 * 0.8
+  def __init__(self):
+    try:
+      self.encoding = tiktoken.get_encoding("cl100k_base") # text-embedding-3-small, text-embedding-3-large: https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken
+    except Exception:
+      self.encoding = None
+
+  def _count_tokens(self, text: str) -> int:
+    """Count tokens in text using tiktoken or fallback estimation"""
+    if self.encoding:
+      try:
+        return len(self.encoding.encode(text))
+      except Exception:
+        pass
+    # fallback: rough estimation (1 token ≈ 4 characters)
+    return len(text) // 4
+
+  def _batch_documents_by_tokens(self, documents: list[Document]) -> list[list[Document]]:
+    """Split documents into batches based on token count"""
+    if not documents:
+      return []
+    batches = []
+    current_batch = []
+    current_token_count = 0
+
+    for doc in documents:
+      doc_tokens = self._count_tokens(doc.page_content)
+      # check if adding this document exceeds the limit
+      if current_token_count + doc_tokens > VectorDBStrategy.MAX_TOKENS_PER_BATCH:
+        # start new batch if current batch is not empty
+        if current_batch:
+          batches.append(current_batch)
+        # reset current batch
+        current_batch = [doc]
+        current_token_count = doc_tokens # reset to current doc's tokens
+      else:
+        # add to current batch
+        current_batch.append(doc)
+        current_token_count += doc_tokens
+
+    # add final batch if not empty
+    if current_batch:
+      batches.append(current_batch)
+
+    return batches
+
   _CACHE: dict[str, VectorStore] = {}
   def _clear_cache(self, key: str):
     if key in self._CACHE:
```
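The batching math: tokens are counted with `tiktoken`'s `cl100k_base` encoding (the one used by `text-embedding-3-small`/`-large`), falling back to a rough 4-characters-per-token estimate, and documents are packed greedily into batches capped at `300_000 * 0.8 = 240_000` tokens, i.e. 20% headroom under a 300k-token embedding request limit. A standalone sketch of the packing logic with hypothetical free functions (the real versions are the methods above):

```python
from dataclasses import dataclass

MAX_TOKENS_PER_BATCH = int(300_000 * 0.8)  # 20% headroom under a 300k-token cap

@dataclass
class Doc:
    page_content: str

def count_tokens(text: str) -> int:
    # Fallback estimation from the diff: ~1 token per 4 characters.
    return len(text) // 4

def batch_by_tokens(docs: list[Doc]) -> list[list[Doc]]:
    """Greedy packing: flush the current batch just before it would cross the cap."""
    batches: list[list[Doc]] = []
    current: list[Doc] = []
    tokens = 0
    for doc in docs:
        t = count_tokens(doc.page_content)
        if tokens + t > MAX_TOKENS_PER_BATCH and current:
            batches.append(current)
            current, tokens = [], 0
        current.append(doc)
        tokens += t
    if current:
        batches.append(current)
    return batches

docs = [Doc("x" * 500_000), Doc("y" * 500_000), Doc("z" * 100)]
print([len(b) for b in batch_by_tokens(docs)])  # [1, 2]: docs 2 and 3 fit together
```

One consequence, shared with the method above: a single document larger than the cap still goes out as its own over-limit batch, so upstream chunking has to keep individual chunks under the cap.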
```diff
@@ -131,6 +178,9 @@
     return await retriever.ainvoke(query, config={"source": kwargs.get("source", "retriever")})

   @staticmethod
+  def _remove_empty_documents(docs: List[Document]) -> List[Document]:
+    return [doc for doc in docs if doc.page_content and doc.page_content.strip()]
+  @staticmethod
   def _remove_duplicates(docs: List[Document]) -> List[Document]:
     seen = set()
     return [doc for doc in docs if not (doc.page_content in seen or seen.add(doc.page_content))]
```
ws_bom_robot_app/llm/vector_store/db/chroma.py

```diff
@@ -38,6 +38,9 @@ class Chroma(VectorDBStrategy):
     Returns:
       CHROMA: The retrieved or newly created Chroma instance.
     """
+  def __init__(self):
+    super().__init__()
+
   async def create(
     self,
     embeddings: Embeddings,
```
```diff
@@ -46,20 +49,37 @@ class Chroma(VectorDBStrategy):
     **kwargs
   ) -> Optional[str]:
     try:
+      documents = self._remove_empty_documents(documents)
       chunked_docs = DocumentChunker.chunk(documents)
-
-
-
-
-
-
-
+      batches = self._batch_documents_by_tokens(chunked_docs)
+      logging.info(f"documents: {len(documents)}, after chunking: {len(chunked_docs)}, processing batches: {len(batches)}")
+      _instance: CHROMA = None
+      for i, batch in enumerate(batches):
+        batch_tokens = sum(self._count_tokens(doc.page_content) for doc in batch)
+        logging.info(f"processing batch {i+1}/{len(batches)} with {len(batch)} docs ({batch_tokens:,} tokens)")
+        # create instance from first batch
+        if _instance is None:
+          _instance = await asyncio.to_thread(
+            CHROMA.from_documents,
+            documents=batch,
+            embedding=embeddings,
+            persist_directory=storage_id
+          )
+        else:
+          # merge to existing instance
+          await _instance.aadd_documents(batch)
+        # add a small delay to avoid rate limiting
+        if i < len(batches) - 1: # except last batch
+          await asyncio.sleep(1)
+      if _instance:
+        self._clear_cache(storage_id)
+        logging.info(f"Successfully created {Chroma.__name__} index with {len(chunked_docs)} total documents")
       return storage_id
     except Exception as e:
       logging.error(f"{Chroma.__name__} create error: {e}")
       raise e
     finally:
-      del documents
+      del documents, chunked_docs, _instance
       gc.collect()

   def get_loader(
```
ws_bom_robot_app/llm/vector_store/db/faiss.py

```diff
@@ -22,6 +22,9 @@ class Faiss(VectorDBStrategy):
     was previously loaded and cached, it returns the cached instance; otherwise,
     it loads the index from local storage and caches it for subsequent use.
     """
+  def __init__(self):
+    super().__init__()
+
   async def create(
     self,
     embeddings: Embeddings,
```
```diff
@@ -30,20 +33,44 @@ class Faiss(VectorDBStrategy):
     **kwargs
   ) -> Optional[str]:
     try:
+      documents = self._remove_empty_documents(documents)
       chunked_docs = DocumentChunker.chunk(documents)
-
-
-
-
-
-
-
+      batches = self._batch_documents_by_tokens(chunked_docs)
+      logging.info(f"documents: {len(documents)}, after chunking: {len(chunked_docs)}, processing batches: {len(batches)}")
+      _instance: FAISS = None
+      for i, batch in enumerate(batches):
+        batch_tokens = sum(self._count_tokens(doc.page_content) for doc in batch)
+        logging.info(f"processing batch {i+1}/{len(batches)} with {len(batch)} docs ({batch_tokens:,} tokens)")
+        # init
+        _batch_instance = await asyncio.to_thread(
+          FAISS.from_documents,
+          batch,
+          embeddings
+        )
+        # create instance from first batch
+        if _instance is None:
+          _instance = _batch_instance
+        else:
+          # merge to existing instance
+          await asyncio.to_thread(
+            _instance.merge_from,
+            _batch_instance
+          )
+          del _batch_instance
+          gc.collect()
+        # add a small delay to avoid rate limiting
+        if i < len(batches) - 1: # except last batch
+          await asyncio.sleep(1)
+      if _instance:
+        await asyncio.to_thread(_instance.save_local, storage_id)
+        self._clear_cache(storage_id)
+        logging.info(f"Successfully created {Faiss.__name__} index with {len(chunked_docs)} total documents")
       return storage_id
     except Exception as e:
       logging.error(f"{Faiss.__name__} create error: {e}")
       raise e
     finally:
-      del documents, _instance
+      del documents, chunked_docs, _instance
       gc.collect()

   def get_loader(
```
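Faiss's `create` differs from the Chroma and Qdrant variants: each batch is embedded into its own temporary index on a worker thread and folded into the first via `merge_from`, instead of `aadd_documents`. A runnable sketch of that pattern with a stand-in embedding model (the real code passes the configured provider embeddings and the package's own storage paths):

```python
import asyncio
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

async def build_index(batches: list[list[Document]], path: str) -> None:
    embeddings = FakeEmbeddings(size=32)  # stand-in for the real embedding model
    index = None
    for batch in batches:
        # Embed each batch into its own small index off the event loop...
        part = await asyncio.to_thread(FAISS.from_documents, batch, embeddings)
        if index is None:
            index = part
        else:
            # ...then fold it into the first, keeping peak memory per step bounded.
            await asyncio.to_thread(index.merge_from, part)
    if index:
        await asyncio.to_thread(index.save_local, path)

batches = [[Document(page_content=f"doc {i}")] for i in range(3)]
asyncio.run(build_index(batches, "/tmp/faiss_demo"))
```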
ws_bom_robot_app/llm/vector_store/db/qdrant.py

```diff
@@ -17,28 +17,43 @@ class Qdrant(VectorDBStrategy):
     **kwargs
   ) -> Optional[str]:
     try:
+      documents = self._remove_empty_documents(documents)
       chunked_docs = DocumentChunker.chunk(documents)
+      batches = self._batch_documents_by_tokens(chunked_docs)
+      logging.info(f"documents: {len(documents)}, after chunking: {len(chunked_docs)}, processing batches: {len(batches)}")
+      _instance: QDRANT = None
       if not os.path.exists(storage_id):
         os.makedirs(storage_id)

-
-
-
-
-
-
-
-
-
-
-
-
-
+      for i, batch in enumerate(batches):
+        batch_tokens = sum(self._count_tokens(doc.page_content) for doc in batch)
+        logging.info(f"processing batch {i+1}/{len(batches)} with {len(batch)} docs ({batch_tokens:,} tokens)")
+        # create instance from first batch
+        if _instance is None:
+          _instance = await asyncio.to_thread(
+            QDRANT.from_documents,
+            documents=batch,
+            embedding=embeddings,
+            sparse_embedding=kwargs['sparse_embedding'] if 'sparse_embedding' in kwargs else FastEmbedSparse(),
+            collection_name="default",
+            path=storage_id,
+            retrieval_mode=RetrievalMode.HYBRID
+          )
+        else:
+          # merge to existing instance
+          await _instance.aadd_documents(batch)
+        # add a small delay to avoid rate limiting
+        if i < len(batches) - 1: # except last batch
+          await asyncio.sleep(1)
+      if _instance:
+        self._clear_cache(storage_id)
+        logging.info(f"Successfully created {Qdrant.__name__} index with {len(chunked_docs)} total documents")
+      return storage_id
     except Exception as e:
       logging.error(f"{Qdrant.__name__} create error: {e}")
       raise e
     finally:
-      del documents
+      del documents, chunked_docs, _instance
       gc.collect()

   def get_loader(
```