ws-bom-robot-app 0.0.63__py3-none-any.whl → 0.0.103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. ws_bom_robot_app/config.py +30 -8
  2. ws_bom_robot_app/cron_manager.py +13 -12
  3. ws_bom_robot_app/llm/agent_context.py +1 -1
  4. ws_bom_robot_app/llm/agent_handler.py +11 -12
  5. ws_bom_robot_app/llm/agent_lcel.py +80 -18
  6. ws_bom_robot_app/llm/api.py +69 -7
  7. ws_bom_robot_app/llm/evaluator.py +319 -0
  8. ws_bom_robot_app/llm/main.py +51 -28
  9. ws_bom_robot_app/llm/models/api.py +40 -6
  10. ws_bom_robot_app/llm/nebuly_handler.py +18 -15
  11. ws_bom_robot_app/llm/providers/llm_manager.py +233 -75
  12. ws_bom_robot_app/llm/tools/tool_builder.py +4 -1
  13. ws_bom_robot_app/llm/tools/tool_manager.py +48 -22
  14. ws_bom_robot_app/llm/utils/chunker.py +6 -1
  15. ws_bom_robot_app/llm/utils/cleanup.py +81 -0
  16. ws_bom_robot_app/llm/utils/cms.py +60 -14
  17. ws_bom_robot_app/llm/utils/download.py +112 -8
  18. ws_bom_robot_app/llm/vector_store/db/base.py +50 -0
  19. ws_bom_robot_app/llm/vector_store/db/chroma.py +28 -8
  20. ws_bom_robot_app/llm/vector_store/db/faiss.py +35 -8
  21. ws_bom_robot_app/llm/vector_store/db/qdrant.py +29 -14
  22. ws_bom_robot_app/llm/vector_store/integration/api.py +216 -0
  23. ws_bom_robot_app/llm/vector_store/integration/azure.py +1 -1
  24. ws_bom_robot_app/llm/vector_store/integration/base.py +58 -15
  25. ws_bom_robot_app/llm/vector_store/integration/confluence.py +33 -5
  26. ws_bom_robot_app/llm/vector_store/integration/dropbox.py +1 -1
  27. ws_bom_robot_app/llm/vector_store/integration/gcs.py +1 -1
  28. ws_bom_robot_app/llm/vector_store/integration/github.py +22 -22
  29. ws_bom_robot_app/llm/vector_store/integration/googledrive.py +46 -17
  30. ws_bom_robot_app/llm/vector_store/integration/jira.py +93 -60
  31. ws_bom_robot_app/llm/vector_store/integration/manager.py +6 -2
  32. ws_bom_robot_app/llm/vector_store/integration/s3.py +1 -1
  33. ws_bom_robot_app/llm/vector_store/integration/sftp.py +1 -1
  34. ws_bom_robot_app/llm/vector_store/integration/sharepoint.py +7 -14
  35. ws_bom_robot_app/llm/vector_store/integration/shopify.py +143 -0
  36. ws_bom_robot_app/llm/vector_store/integration/sitemap.py +6 -1
  37. ws_bom_robot_app/llm/vector_store/integration/slack.py +3 -2
  38. ws_bom_robot_app/llm/vector_store/integration/thron.py +236 -0
  39. ws_bom_robot_app/llm/vector_store/loader/base.py +52 -8
  40. ws_bom_robot_app/llm/vector_store/loader/docling.py +71 -33
  41. ws_bom_robot_app/main.py +148 -146
  42. ws_bom_robot_app/subprocess_runner.py +106 -0
  43. ws_bom_robot_app/task_manager.py +204 -53
  44. ws_bom_robot_app/util.py +6 -0
  45. {ws_bom_robot_app-0.0.63.dist-info → ws_bom_robot_app-0.0.103.dist-info}/METADATA +158 -75
  46. ws_bom_robot_app-0.0.103.dist-info/RECORD +76 -0
  47. ws_bom_robot_app/llm/settings.py +0 -4
  48. ws_bom_robot_app/llm/utils/kb.py +0 -34
  49. ws_bom_robot_app-0.0.63.dist-info/RECORD +0 -72
  50. {ws_bom_robot_app-0.0.63.dist-info → ws_bom_robot_app-0.0.103.dist-info}/WHEEL +0 -0
  51. {ws_bom_robot_app-0.0.63.dist-info → ws_bom_robot_app-0.0.103.dist-info}/top_level.txt +0 -0
ws_bom_robot_app/llm/utils/cms.py
@@ -1,8 +1,8 @@
 import logging, aiohttp
-from typing import List, Optional
-
+from typing import Any, List, Optional
 from pydantic import AliasChoices, BaseModel, ConfigDict, Field
-from ws_bom_robot_app.llm.models.api import LlmAppTool
+from ws_bom_robot_app.llm.models.api import LlmAppTool, LlmRules, StreamRequest
+from ws_bom_robot_app.llm.models.kb import LlmKbEndpoint, LlmKbIntegration
 from ws_bom_robot_app.util import cache_with_ttl
 
 class CmsAppCredential(BaseModel):
@@ -12,13 +12,16 @@ class CmsAppCredential(BaseModel):
 class CmsApp(BaseModel):
   id: str = Field(..., description="Unique identifier for the app")
   name: str = Field(..., description="Name of the app")
+  mode: str
+  prompt_samples: Optional[List[str]]
   credentials: CmsAppCredential = None
-  app_tools: Optional[List[LlmAppTool]] = Field([], validation_alias=AliasChoices("appTools","app_tools"))
+  rq: StreamRequest
+  kb: Optional[Any] = None
   model_config = ConfigDict(extra='ignore')
 
 @cache_with_ttl(600) # Cache for 10 minutes
 async def get_apps() -> list[CmsApp]:
-  import json, os
+  import json
   from ws_bom_robot_app.config import config
   class DictObject(object):
     def __init__(self, dict_):
@@ -34,9 +37,17 @@ async def get_apps() -> list[CmsApp]:
       if obj is None:
         break
     return obj
+  def __to_dict(obj):
+    """Converts DictObject to dict recursively"""
+    if isinstance(obj, DictObject):
+      return {k: __to_dict(v) for k, v in obj.__dict__.items()}
+    elif isinstance(obj, list):
+      return [__to_dict(item) for item in obj]
+    else:
+      return obj
   host = config.robot_cms_host
   if host:
-    url = f"{host}/api/llmApp?depth=1&pagination=false"
+    url = f"{host}/api/llmApp?depth=1&pagination=false&locale=it"
     auth = config.robot_cms_auth
     headers = {"Authorization": auth} if auth else {}
     async with aiohttp.ClientSession() as session:
@@ -47,15 +58,49 @@ async def get_apps() -> list[CmsApp]:
         for cms_app in cms_apps:
           if __attr(cms_app,"isActive",default=True) == True:
             _cms_app_dict = DictObject.from_dict(cms_app)
-            _app: CmsApp = CmsApp(
-              id=_cms_app_dict.id,
-              name=_cms_app_dict.name,
-              credentials=CmsAppCredential(app_key=_cms_app_dict.settings.credentials.appKey,api_key=_cms_app_dict.settings.credentials.apiKey),
-              app_tools=[LlmAppTool(**tool) for tool in cms_app.get('settings').get('appTools',[])]
-            )
-            if _app.app_tools:
-              for tool in _app.app_tools:
+            try:
+              _app: CmsApp = CmsApp(
+                id=_cms_app_dict.id,
+                name=_cms_app_dict.name,
+                mode=_cms_app_dict.mode,
+                prompt_samples=[__attr(sample,'sampleInputText') or f"{sample.__dict__}" for sample in _cms_app_dict.contents.sampleInputTexts],
+                credentials=CmsAppCredential(app_key=_cms_app_dict.settings.credentials.appKey,api_key=_cms_app_dict.settings.credentials.apiKey),
+                rq=StreamRequest(
+                  #thread_id=str(uuid.uuid1()),
+                  messages=[],
+                  secrets={
+                    "apiKey": __attr(_cms_app_dict.settings,'llmConfig','secrets','apiKey', default=''),
+                    "langChainApiKey": __attr(_cms_app_dict.settings,'llmConfig','secrets','langChainApiKey', default=''),
+                    "nebulyApiKey": __attr(_cms_app_dict.settings,'llmConfig','secrets','nebulyApiKey', default=''),
+                  },
+                  system_message=__attr(_cms_app_dict.settings,'llmConfig','prompt','prompt','systemMessage') if __attr(_cms_app_dict.settings,'llmConfig','prompt','prompt','systemMessage') else __attr(_cms_app_dict.settings,'llmConfig','prompt','systemMessage'),
+                  provider= __attr(_cms_app_dict.settings,'llmConfig','provider') or 'openai',
+                  model= __attr(_cms_app_dict.settings,'llmConfig','model') or 'gpt-4o',
+                  temperature=_cms_app_dict.settings.llmConfig.temperature or 0,
+                  app_tools=[LlmAppTool(**tool) for tool in cms_app.get('settings').get('appTools',[])],
+                  rules=LlmRules(
+                    vector_type=__attr(_cms_app_dict.settings,'rules','vectorDbType', default='faiss'),
+                    vector_db=__attr(_cms_app_dict.settings,'rules','vectorDbFile','filename'),
+                    threshold=__attr(_cms_app_dict.settings,'rules','threshold', default=0.7)
+                  ) if __attr(_cms_app_dict.settings,'rules','vectorDbFile','filename') else None,
+                  #fine_tuned_model=__attr(_cms_app_dict.settings,'llmConfig','fineTunedModel'),
+                  lang_chain_tracing= __attr(_cms_app_dict.settings,'llmConfig','langChainTracing', default=False),
+                  lang_chain_project= __attr(_cms_app_dict.settings,'llmConfig','langChainProject', default=''),
+                  output_structure= __to_dict(__attr(_cms_app_dict.settings,'llmConfig','outputStructure')) if __attr(_cms_app_dict.settings,'llmConfig','outputStructure') else None
+              ))
+            except Exception as e:
+              import traceback
+              ex = traceback.format_exc()
+              logging.error(f"Error creating CmsApp {_cms_app_dict.name} from dict: {e}\n{ex}")
+              continue
+            if _app.rq.app_tools:
+              for tool in _app.rq.app_tools:
                 _knowledgeBase = tool.knowledgeBase
+                tool.integrations = [LlmKbIntegration(**item) for item in _knowledgeBase.get('integrations')] if _knowledgeBase.get('integrations') else []
+                try:
+                  tool.endpoints = [LlmKbEndpoint(**item) for item in _knowledgeBase.get('externalEndpoints')] if _knowledgeBase.get('externalEndpoints') else []
+                except Exception as e:
+                  logging.error(f"Error parsing endpoints for app {_cms_app_dict.name} tool {tool.name}: {e}")
                 tool.vector_db = _knowledgeBase.get('vectorDbFile').get('filename') if _knowledgeBase.get('vectorDbFile') else None
                 tool.vector_type = _knowledgeBase.get('vectorDbType') if _knowledgeBase.get('vectorDbType') else 'faiss'
                 del tool.knowledgeBase
@@ -67,6 +112,7 @@ async def get_apps() -> list[CmsApp]:
     logging.error("robot_cms_host environment variable is not set.")
     return []
 
+
 async def get_app_by_id(app_id: str) -> CmsApp | None:
   apps = await get_apps()
   app = next((a for a in apps if a.id == app_id), None)
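For orientation, a minimal consumption sketch of the reshaped CmsApp model (assuming these helpers live in ws_bom_robot_app/llm/utils/cms.py as the file list suggests; the app id is a placeholder, and rq is the per-app StreamRequest the loader now pre-populates):

import asyncio
from ws_bom_robot_app.llm.utils.cms import get_app_by_id

async def main():
    # "my-app-id" is a hypothetical identifier; real ids come from the CMS
    app = await get_app_by_id("my-app-id")
    if app:
        # provider/model defaults now travel on the pre-built StreamRequest
        print(app.name, app.rq.provider, app.rq.model)

asyncio.run(main())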
ws_bom_robot_app/llm/utils/download.py
@@ -1,6 +1,13 @@
+import httpx
 from typing import List,Optional
-import os, logging, aiohttp, asyncio
+import os, logging, aiohttp, asyncio, hashlib, json
+import uuid
+from pydantic import BaseModel
+import base64, requests, mimetypes
+from urllib.parse import urlparse
 from tqdm.asyncio import tqdm
+from ws_bom_robot_app.config import config
+import aiofiles
 
 async def download_files(urls: List[str], destination_folder: str, authorization: str = None):
   tasks = [download_file(file, os.path.join(destination_folder, os.path.basename(file)), authorization=authorization) for file in urls]
@@ -28,14 +35,13 @@ async def download_file(url: str, destination: str, chunk_size: int = 8192, auth
     # Ensure the destination directory exists
     os.makedirs(os.path.dirname(os.path.abspath(destination)), exist_ok=True)
 
-    async with aiohttp.ClientSession() as session:
+    async with httpx.AsyncClient(timeout=30.0) as client:
       if authorization:
         headers = {'Authorization': authorization}
-        session.headers.update(headers)
-      async with session.get(url) as response:
+      async with client.stream("GET", url, headers=headers) as response:
         # Check if the request was successful
-        if response.status != 200:
-          logging.error(f"Failed to download file. Status code: {response.status}")
+        if response.status_code != 200:
+          logging.error(f"Failed to download file. Status code: {response.status_code}")
           return None
 
         # Get the total file size if available
@@ -49,7 +55,7 @@ async def download_file(url: str, destination: str, chunk_size: int = 8192, auth
             unit_scale=True,
            unit_divisor=1024
           ) as pbar:
-            async for chunk in response.content.iter_chunked(chunk_size):
+            async for chunk in response.aiter_bytes(chunk_size):
              if chunk:
                f.write(chunk)
                pbar.update(len(chunk))
@@ -57,7 +63,7 @@ async def download_file(url: str, destination: str, chunk_size: int = 8192, auth
     logging.info(f"File downloaded successfully to {destination}")
     return destination
 
-  except aiohttp.ClientError as e:
+  except httpx.RequestError as e:
     logging.error(f"Network error occurred: {str(e)}")
     return None
   except asyncio.TimeoutError:
@@ -77,3 +83,101 @@ async def download_file(url: str, destination: str, chunk_size: int = 8192, auth
       logging.info(f"Cleaned up incomplete download: {destination}")
     except OSError:
       pass
+
+class Base64File(BaseModel):
+  """Base64 encoded file representation"""
+  url: str
+  base64_url: str
+  base64_content: str
+  name: str
+  extension: str
+  mime_type: str
+
+  @staticmethod
+  def _is_base64_data_uri(url: str) -> bool:
+    """Check if URL is already a base64 data URI"""
+    return (isinstance(url, str) and
+            url.startswith('data:') and
+            ';base64,' in url and
+            len(url.split(',')) == 2)
+
+  async def from_url(url: str) -> "Base64File":
+    """Download file and return as base64 data URI"""
+    def _cache_file(url: str) -> str:
+      _hash = hashlib.md5(url.encode()).hexdigest()
+      return os.path.join(config.robot_data_folder, config.robot_data_attachment_folder, f"{_hash}.json")
+    async def from_cache(url: str) -> "Base64File":
+      """Check if file is already downloaded and return data"""
+      _file = _cache_file(url)
+      if os.path.exists(_file):
+        try:
+          async with aiofiles.open(_file, 'rb') as f:
+            content = await f.read()
+            return Base64File(**json.loads(content))
+        except Exception as e:
+          logging.error(f"Error reading cache file {_file}: {e}")
+          return None
+      return None
+    async def to_cache(file: "Base64File", url: str) -> None:
+      """Save file to cache"""
+      _file = _cache_file(url)
+      try:
+        async with aiofiles.open(_file, 'wb') as f:
+          await f.write(file.model_dump_json().encode('utf-8'))
+      except Exception as e:
+        logging.error(f"Error writing cache file {_file}: {e}")
+
+    # special case: base64 data URI
+    if Base64File._is_base64_data_uri(url):
+      mime_type = url.split(';')[0].replace('data:', '')
+      base64_content = url.split(',')[1]
+      extension=mime_type.split('/')[-1]
+      name = f"file-{uuid.uuid4()}.{extension}"
+      return Base64File(
+        url=url,
+        base64_url=url,
+        base64_content=base64_content,
+        name=name,
+        extension=extension,
+        mime_type=mime_type
+      )
+
+    # default download
+    _error = None
+    try:
+      if _content := await from_cache(url):
+        return _content
+      async with httpx.AsyncClient(timeout=30.0) as client:
+        response = await client.get(url, headers={"User-Agent": "Mozilla/5.0"})
+        logging.info(f"Downloading {url} - Status: {response.status_code}")
+        response.raise_for_status()
+        content = response.read()
+      # mime type detection
+      mime_type = response.headers.get('content-type', '').split(';')[0]
+      if not mime_type:
+        mime_type, _ = mimetypes.guess_type(urlparse(url).path)
+      if not mime_type:
+        mime_type = 'application/octet-stream'
+      # to base64
+      base64_content = base64.b64encode(content).decode('utf-8')
+      name = url.split('/')[-1]
+      extension = name.split('.')[-1]
+    except Exception as e:
+      _error = f"Failed to download file from {url}: {e}"
+      logging.error(_error)
+      base64_content = base64.b64encode(_error.encode('utf-8')).decode('utf-8')
+      name = "download_error.txt"
+      mime_type = "text/plain"
+      extension = "txt"
+
+    _file = Base64File(
+      url=url,
+      base64_url= f"data:{mime_type};base64,{base64_content}",
+      base64_content=base64_content,
+      name=name,
+      extension=extension,
+      mime_type=mime_type
+    )
+    if not _error:
+      await to_cache(_file, url)
+    return _file
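A rough usage sketch of the new Base64File helper (assuming the module path ws_bom_robot_app/llm/utils/download.py from the file list; the URL is a placeholder):

import asyncio
from ws_bom_robot_app.llm.utils.download import Base64File

async def main():
    # placeholder URL; on success the result is also cached under the configured attachment folder
    f = await Base64File.from_url("https://example.com/report.pdf")
    # base64_url is a data URI, convenient for inlining attachments into prompts
    print(f.name, f.mime_type, f.base64_url[:40])

asyncio.run(main())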
ws_bom_robot_app/llm/vector_store/db/base.py
@@ -7,6 +7,7 @@ from langchain_core.language_models import BaseChatModel
 from langchain_core.vectorstores.base import VectorStoreRetriever, VectorStore
 from langchain.retrievers import SelfQueryRetriever
 from langchain.chains.query_constructor.schema import AttributeInfo
+import tiktoken
 
 class VectorDBStrategy(ABC):
   class VectorDBStrategy:
@@ -49,6 +50,52 @@ class VectorDBStrategy(ABC):
   Asynchronously invokes multiple retrievers in parallel, then merges
   their results while removing duplicates.
   """
+  MAX_TOKENS_PER_BATCH = 300_000 * 0.8
+  def __init__(self):
+    try:
+      self.encoding = tiktoken.get_encoding("cl100k_base") # text-embedding-3-small, text-embedding-3-large: https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken
+    except Exception:
+      self.encoding = None
+
+  def _count_tokens(self, text: str) -> int:
+    """Count tokens in text using tiktoken or fallback estimation"""
+    if self.encoding:
+      try:
+        return len(self.encoding.encode(text))
+      except Exception:
+        pass
+    # fallback: rough estimation (1 token ≈ 4 characters)
+    return len(text) // 4
+
+  def _batch_documents_by_tokens(self, documents: list[Document]) -> list[list[Document]]:
+    """Split documents into batches based on token count"""
+    if not documents:
+      return []
+    batches = []
+    current_batch = []
+    current_token_count = 0
+
+    for doc in documents:
+      doc_tokens = self._count_tokens(doc.page_content)
+      # check if adding this document exceeds the limit
+      if current_token_count + doc_tokens > VectorDBStrategy.MAX_TOKENS_PER_BATCH:
+        # start new batch if current batch is not empty
+        if current_batch:
+          batches.append(current_batch)
+        # reset current batch
+        current_batch = [doc]
+        current_token_count = doc_tokens # reset to current doc's tokens
+      else:
+        # add to current batch
+        current_batch.append(doc)
+        current_token_count += doc_tokens
+
+    # add final batch if not empty
+    if current_batch:
+      batches.append(current_batch)
+
+    return batches
+
   _CACHE: dict[str, VectorStore] = {}
   def _clear_cache(self, key: str):
     if key in self._CACHE:
@@ -131,6 +178,9 @@ class VectorDBStrategy(ABC):
     return await retriever.ainvoke(query, config={"source": kwargs.get("source", "retriever")})
 
   @staticmethod
+  def _remove_empty_documents(docs: List[Document]) -> List[Document]:
+    return [doc for doc in docs if doc.page_content and doc.page_content.strip()]
+  @staticmethod
   def _remove_duplicates(docs: List[Document]) -> List[Document]:
     seen = set()
     return [doc for doc in docs if not (doc.page_content in seen or seen.add(doc.page_content))]
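The same token-budget batching idea as a standalone sketch, independent of the class above (the 300k-token budget scaled by 0.8 and the 4-characters-per-token fallback mirror the diff; tiktoken is optional):

def count_tokens(text: str) -> int:
    try:
        import tiktoken
        return len(tiktoken.get_encoding("cl100k_base").encode(text))
    except Exception:
        return len(text) // 4  # rough fallback: ~4 characters per token

def batch_by_tokens(texts: list[str], max_tokens: int = int(300_000 * 0.8)) -> list[list[str]]:
    batches: list[list[str]] = []
    current: list[str] = []
    used = 0
    for text in texts:
        tokens = count_tokens(text)
        # flush the current batch when adding this text would exceed the budget
        if current and used + tokens > max_tokens:
            batches.append(current)
            current, used = [], 0
        current.append(text)
        used += tokens
    if current:
        batches.append(current)
    return batches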
ws_bom_robot_app/llm/vector_store/db/chroma.py
@@ -38,6 +38,9 @@ class Chroma(VectorDBStrategy):
   Returns:
     CHROMA: The retrieved or newly created Chroma instance.
   """
+  def __init__(self):
+    super().__init__()
+
   async def create(
     self,
     embeddings: Embeddings,
@@ -46,20 +49,37 @@ class Chroma(VectorDBStrategy):
     **kwargs
   ) -> Optional[str]:
     try:
+      documents = self._remove_empty_documents(documents)
       chunked_docs = DocumentChunker.chunk(documents)
-      await asyncio.to_thread(
-        CHROMA.from_documents,
-        documents=chunked_docs,
-        embedding=embeddings,
-        persist_directory=storage_id
-      )
-      self._clear_cache(storage_id)
+      batches = self._batch_documents_by_tokens(chunked_docs)
+      logging.info(f"documents: {len(documents)}, after chunking: {len(chunked_docs)}, processing batches: {len(batches)}")
+      _instance: CHROMA = None
+      for i, batch in enumerate(batches):
+        batch_tokens = sum(self._count_tokens(doc.page_content) for doc in batch)
+        logging.info(f"processing batch {i+1}/{len(batches)} with {len(batch)} docs ({batch_tokens:,} tokens)")
+        # create instance from first batch
+        if _instance is None:
+          _instance = await asyncio.to_thread(
+            CHROMA.from_documents,
+            documents=batch,
+            embedding=embeddings,
+            persist_directory=storage_id
+          )
+        else:
+          # merge to existing instance
+          await _instance.aadd_documents(batch)
+        # add a small delay to avoid rate limiting
+        if i < len(batches) - 1: # except last batch
+          await asyncio.sleep(1)
+      if _instance:
+        self._clear_cache(storage_id)
+        logging.info(f"Successfully created {Chroma.__name__} index with {len(chunked_docs)} total documents")
       return storage_id
     except Exception as e:
       logging.error(f"{Chroma.__name__} create error: {e}")
       raise e
     finally:
-      del documents
+      del documents, chunked_docs, _instance
       gc.collect()
 
   def get_loader(
ws_bom_robot_app/llm/vector_store/db/faiss.py
@@ -22,6 +22,9 @@ class Faiss(VectorDBStrategy):
   was previously loaded and cached, it returns the cached instance; otherwise,
   it loads the index from local storage and caches it for subsequent use.
   """
+  def __init__(self):
+    super().__init__()
+
   async def create(
     self,
     embeddings: Embeddings,
@@ -30,20 +33,44 @@ class Faiss(VectorDBStrategy):
     **kwargs
   ) -> Optional[str]:
     try:
+      documents = self._remove_empty_documents(documents)
       chunked_docs = DocumentChunker.chunk(documents)
-      _instance = await asyncio.to_thread(
-        FAISS.from_documents,
-        chunked_docs,
-        embeddings
-      )
-      await asyncio.to_thread(_instance.save_local, storage_id)
-      self._clear_cache(storage_id)
+      batches = self._batch_documents_by_tokens(chunked_docs)
+      logging.info(f"documents: {len(documents)}, after chunking: {len(chunked_docs)}, processing batches: {len(batches)}")
+      _instance: FAISS = None
+      for i, batch in enumerate(batches):
+        batch_tokens = sum(self._count_tokens(doc.page_content) for doc in batch)
+        logging.info(f"processing batch {i+1}/{len(batches)} with {len(batch)} docs ({batch_tokens:,} tokens)")
+        # init
+        _batch_instance = await asyncio.to_thread(
+          FAISS.from_documents,
+          batch,
+          embeddings
+        )
+        # create instance from first batch
+        if _instance is None:
+          _instance = _batch_instance
+        else:
+          # merge to existing instance
+          await asyncio.to_thread(
+            _instance.merge_from,
+            _batch_instance
+          )
+          del _batch_instance
+          gc.collect()
+        # add a small delay to avoid rate limiting
+        if i < len(batches) - 1: # except last batch
+          await asyncio.sleep(1)
+      if _instance:
+        await asyncio.to_thread(_instance.save_local, storage_id)
+        self._clear_cache(storage_id)
+        logging.info(f"Successfully created {Faiss.__name__} index with {len(chunked_docs)} total documents")
       return storage_id
     except Exception as e:
       logging.error(f"{Faiss.__name__} create error: {e}")
       raise e
     finally:
-      del documents, _instance
+      del documents, chunked_docs, _instance
       gc.collect()
 
   def get_loader(
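For context, a minimal sketch of the batch-and-merge pattern the Faiss strategy now follows, isolated from the class (FAISS.from_documents and merge_from are standard langchain_community APIs; the embeddings object and batches are placeholders):

from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

def build_index(batches: list[list[Document]], embeddings) -> FAISS | None:
    index = None
    for batch in batches:
        partial = FAISS.from_documents(batch, embeddings)  # embed one batch at a time
        if index is None:
            index = partial
        else:
            index.merge_from(partial)  # fold the partial index into the running one
    return index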
ws_bom_robot_app/llm/vector_store/db/qdrant.py
@@ -17,28 +17,43 @@ class Qdrant(VectorDBStrategy):
     **kwargs
   ) -> Optional[str]:
     try:
+      documents = self._remove_empty_documents(documents)
       chunked_docs = DocumentChunker.chunk(documents)
+      batches = self._batch_documents_by_tokens(chunked_docs)
+      logging.info(f"documents: {len(documents)}, after chunking: {len(chunked_docs)}, processing batches: {len(batches)}")
+      _instance: QDRANT = None
       if not os.path.exists(storage_id):
         os.makedirs(storage_id)
 
-      def _create():
-        QDRANT.from_documents(
-          documents=chunked_docs,
-          embedding=embeddings,
-          sparse_embedding=kwargs['sparse_embedding'] if 'sparse_embedding' in kwargs else FastEmbedSparse(),
-          collection_name="default",
-          path=storage_id,
-          retrieval_mode=RetrievalMode.HYBRID
-        )
-
-      await asyncio.to_thread(_create)
-      self._clear_cache(storage_id)
-      return storage_id
+      for i, batch in enumerate(batches):
+        batch_tokens = sum(self._count_tokens(doc.page_content) for doc in batch)
+        logging.info(f"processing batch {i+1}/{len(batches)} with {len(batch)} docs ({batch_tokens:,} tokens)")
+        # create instance from first batch
+        if _instance is None:
+          _instance = await asyncio.to_thread(
+            QDRANT.from_documents,
+            documents=batch,
+            embedding=embeddings,
+            sparse_embedding=kwargs['sparse_embedding'] if 'sparse_embedding' in kwargs else FastEmbedSparse(),
+            collection_name="default",
+            path=storage_id,
+            retrieval_mode=RetrievalMode.HYBRID
+          )
+        else:
+          # merge to existing instance
+          await _instance.aadd_documents(batch)
+        # add a small delay to avoid rate limiting
+        if i < len(batches) - 1: # except last batch
+          await asyncio.sleep(1)
+      if _instance:
+        self._clear_cache(storage_id)
+        logging.info(f"Successfully created {Qdrant.__name__} index with {len(chunked_docs)} total documents")
+      return storage_id
     except Exception as e:
       logging.error(f"{Qdrant.__name__} create error: {e}")
       raise e
     finally:
-      del documents
+      del documents, chunked_docs, _instance
       gc.collect()
 
   def get_loader(