ai-parrot 0.3.4__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-parrot might be problematic. Click here for more details.

Files changed (109)
  1. ai_parrot-0.3.4.dist-info/LICENSE +21 -0
  2. ai_parrot-0.3.4.dist-info/METADATA +319 -0
  3. ai_parrot-0.3.4.dist-info/RECORD +109 -0
  4. ai_parrot-0.3.4.dist-info/WHEEL +6 -0
  5. ai_parrot-0.3.4.dist-info/top_level.txt +3 -0
  6. parrot/__init__.py +21 -0
  7. parrot/chatbots/__init__.py +7 -0
  8. parrot/chatbots/abstract.py +728 -0
  9. parrot/chatbots/asktroc.py +16 -0
  10. parrot/chatbots/base.py +366 -0
  11. parrot/chatbots/basic.py +9 -0
  12. parrot/chatbots/bose.py +17 -0
  13. parrot/chatbots/cody.py +17 -0
  14. parrot/chatbots/copilot.py +83 -0
  15. parrot/chatbots/dataframe.py +103 -0
  16. parrot/chatbots/hragents.py +15 -0
  17. parrot/chatbots/odoo.py +17 -0
  18. parrot/chatbots/retrievals/__init__.py +578 -0
  19. parrot/chatbots/retrievals/constitutional.py +19 -0
  20. parrot/conf.py +110 -0
  21. parrot/crew/__init__.py +3 -0
  22. parrot/crew/tools/__init__.py +22 -0
  23. parrot/crew/tools/bing.py +13 -0
  24. parrot/crew/tools/config.py +43 -0
  25. parrot/crew/tools/duckgo.py +62 -0
  26. parrot/crew/tools/file.py +24 -0
  27. parrot/crew/tools/google.py +168 -0
  28. parrot/crew/tools/gtrends.py +16 -0
  29. parrot/crew/tools/md2pdf.py +25 -0
  30. parrot/crew/tools/rag.py +42 -0
  31. parrot/crew/tools/search.py +32 -0
  32. parrot/crew/tools/url.py +21 -0
  33. parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
  34. parrot/handlers/__init__.py +4 -0
  35. parrot/handlers/bots.py +196 -0
  36. parrot/handlers/chat.py +162 -0
  37. parrot/interfaces/__init__.py +6 -0
  38. parrot/interfaces/database.py +29 -0
  39. parrot/llms/__init__.py +137 -0
  40. parrot/llms/abstract.py +47 -0
  41. parrot/llms/anthropic.py +42 -0
  42. parrot/llms/google.py +42 -0
  43. parrot/llms/groq.py +45 -0
  44. parrot/llms/hf.py +45 -0
  45. parrot/llms/openai.py +59 -0
  46. parrot/llms/pipes.py +114 -0
  47. parrot/llms/vertex.py +78 -0
  48. parrot/loaders/__init__.py +20 -0
  49. parrot/loaders/abstract.py +456 -0
  50. parrot/loaders/audio.py +106 -0
  51. parrot/loaders/basepdf.py +102 -0
  52. parrot/loaders/basevideo.py +280 -0
  53. parrot/loaders/csv.py +42 -0
  54. parrot/loaders/dir.py +37 -0
  55. parrot/loaders/excel.py +349 -0
  56. parrot/loaders/github.py +65 -0
  57. parrot/loaders/handlers/__init__.py +5 -0
  58. parrot/loaders/handlers/data.py +213 -0
  59. parrot/loaders/image.py +119 -0
  60. parrot/loaders/json.py +52 -0
  61. parrot/loaders/pdf.py +437 -0
  62. parrot/loaders/pdfchapters.py +142 -0
  63. parrot/loaders/pdffn.py +112 -0
  64. parrot/loaders/pdfimages.py +207 -0
  65. parrot/loaders/pdfmark.py +88 -0
  66. parrot/loaders/pdftables.py +145 -0
  67. parrot/loaders/ppt.py +30 -0
  68. parrot/loaders/qa.py +81 -0
  69. parrot/loaders/repo.py +103 -0
  70. parrot/loaders/rtd.py +65 -0
  71. parrot/loaders/txt.py +92 -0
  72. parrot/loaders/utils/__init__.py +1 -0
  73. parrot/loaders/utils/models.py +25 -0
  74. parrot/loaders/video.py +96 -0
  75. parrot/loaders/videolocal.py +120 -0
  76. parrot/loaders/vimeo.py +106 -0
  77. parrot/loaders/web.py +216 -0
  78. parrot/loaders/web_base.py +112 -0
  79. parrot/loaders/word.py +125 -0
  80. parrot/loaders/youtube.py +192 -0
  81. parrot/manager.py +166 -0
  82. parrot/models.py +372 -0
  83. parrot/py.typed +0 -0
  84. parrot/stores/__init__.py +48 -0
  85. parrot/stores/abstract.py +171 -0
  86. parrot/stores/milvus.py +632 -0
  87. parrot/stores/qdrant.py +153 -0
  88. parrot/tools/__init__.py +12 -0
  89. parrot/tools/abstract.py +53 -0
  90. parrot/tools/asknews.py +32 -0
  91. parrot/tools/bing.py +13 -0
  92. parrot/tools/duck.py +62 -0
  93. parrot/tools/google.py +170 -0
  94. parrot/tools/stack.py +26 -0
  95. parrot/tools/weather.py +70 -0
  96. parrot/tools/wikipedia.py +59 -0
  97. parrot/tools/zipcode.py +179 -0
  98. parrot/utils/__init__.py +2 -0
  99. parrot/utils/parsers/__init__.py +5 -0
  100. parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
  101. parrot/utils/toml.py +11 -0
  102. parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
  103. parrot/utils/uv.py +11 -0
  104. parrot/version.py +10 -0
  105. resources/users/__init__.py +5 -0
  106. resources/users/handlers.py +13 -0
  107. resources/users/models.py +205 -0
  108. settings/__init__.py +0 -0
  109. settings/settings.py +51 -0
parrot/models.py ADDED
@@ -0,0 +1,372 @@
1
+ from typing import Union, Optional
2
+ import uuid
3
+ import time
4
+ from datetime import datetime
5
+ from pathlib import Path, PurePath
6
+ from enum import Enum
7
+ from langchain_core.agents import AgentAction
8
+
9
+ from datamodel import BaseModel, Field
10
+ from datamodel.types import Text # pylint: disable=no-name-in-module
11
+ from asyncdb.models import Model
12
+
13
def created_at(*args, **kwargs) -> int:
    """Return the current Unix time in milliseconds.

    Any positional/keyword arguments are accepted and ignored, which lets
    the function be passed directly as a ``default`` to model fields.

    The float seconds value is scaled *before* truncation so the
    millisecond part is preserved; truncating first
    (``int(time.time()) * 1000``) always yielded a multiple of 1000.
    """
    return int(time.time() * 1000)
15
+
16
+
17
class AgentResponse(BaseModel):
    """AgentResponse.

    Response from Chatbots that run an agent. Its fields mirror the agent
    output keys:
        dict_keys(
            ['input', 'chat_history', 'output', 'intermediate_steps']
        )
    """
    question: str = Field(required=False)
    input: str = Field(required=False)
    output: str = Field(required=False)
    response: str = Field(required=False)
    intermediate_steps: list = Field(default_factory=list)
    chat_history: list = Field(repr=True, default_factory=list)
    source_documents: list = Field(required=False, default_factory=list)

    def __post_init__(self) -> None:
        """Convert agent steps into plain dicts, then run base initialisation.

        Every ``(AgentAction, result)`` pair in ``intermediate_steps`` is
        replaced by a dict with ``tool``, ``tool_input``, ``result`` and
        ``log`` keys. Any other entry is kept unchanged, rather than being
        dropped (or crashing the pair unpacking), so no step is lost.
        """
        if self.intermediate_steps:
            steps = []
            for step in self.intermediate_steps:
                if (
                    isinstance(step, (tuple, list))
                    and len(step) == 2
                    and isinstance(step[0], AgentAction)
                ):
                    action, result = step
                    # convert into dictionary:
                    steps.append(
                        {
                            "tool": action.tool,
                            "tool_input": action.tool_input,
                            "result": result,
                            "log": str(action.log)
                        }
                    )
                else:
                    steps.append(step)
            self.intermediate_steps = steps
        # Like ChatResponse, let the base model complete its own
        # post-initialisation after the fields have been normalised.
        super().__post_init__()
49
+
50
+
51
class ChatResponse(BaseModel):
    """ChatResponse.

    Response from Chatbots backed by a conversational/retrieval chain.
    Its fields mirror the chain output keys:
        dict_keys(
            ['question', 'chat_history', 'answer', 'source_documents', 'generated_question']
        )

    ``result`` is treated as an alias of ``answer`` and ``question`` as a
    fallback for ``generated_question``.
    """
    query: str = Field(required=False)
    result: str = Field(required=False)
    question: str = Field(required=False)
    generated_question: str = Field(required=False)
    answer: str = Field(required=False)
    response: str = Field(required=False)
    chat_history: list = Field(repr=True, default_factory=list)
    source_documents: list = Field(required=False, default_factory=list)
    documents: dict = Field(required=False, default_factory=dict)
    # Session identifier; a fresh uuid4 unless one is supplied.
    sid: uuid.UUID = Field(primary_key=True, required=False, default=uuid.uuid4)
    # Creation time in epoch milliseconds (see the module's created_at()).
    at: int = Field(default=created_at)

    def __post_init__(self) -> None:
        """Back-fill missing aliases, then defer to the base model."""
        has_result = bool(self.result)
        if has_result and not self.answer:
            # ``answer`` is empty: reuse the ``result`` value.
            self.answer = self.result
        has_question = bool(self.question)
        if has_question and not self.generated_question:
            # No rewritten question: keep the user's original one.
            self.generated_question = self.question
        return super().__post_init__()
77
+
78
+
79
+ # Chatbot Model:
80
class ChatbotModel(Model):
    """Chatbot.

    Chatbot configuration, persisted in ``navigator.chatbots``:

    CREATE TABLE IF NOT EXISTS navigator.chatbots (
        chatbot_id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
        name VARCHAR NOT NULL DEFAULT 'Nav',
        description VARCHAR,
        config_file VARCHAR DEFAULT 'config.toml',
        company_information JSONB DEFAULT '{}'::JSONB,
        contact_information VARCHAR,
        contact_form VARCHAR,
        contact_email VARCHAR,
        company_website VARCHAR,
        avatar TEXT,
        enabled BOOLEAN NOT NULL DEFAULT TRUE,
        timezone VARCHAR DEFAULT 'UTC',
        attributes JSONB DEFAULT '{}'::JSONB,
        role VARCHAR DEFAULT 'a Human Resources Assistant',
        goal VARCHAR NOT NULL DEFAULT 'Bring useful information to Users.',
        backstory VARCHAR NOT NULL DEFAULT 'I was created by a team of developers to assist with users tasks.',
        rationale VARCHAR NOT NULL DEFAULT 'Remember to maintain a professional tone. Please provide accurate and relevant information.',
        language VARCHAR DEFAULT 'en',
        template_prompt VARCHAR,
        pre_instructions JSONB DEFAULT '[]'::JSONB,
        llm VARCHAR DEFAULT 'VertexLLM',
        model_name VARCHAR DEFAULT 'gemini-pro',
        model_config JSONB DEFAULT '{}'::JSONB,
        embedding_name VARCHAR DEFAULT 'thenlper/gte-base',
        tokenizer VARCHAR DEFAULT 'thenlper/gte-base',
        summarize_model VARCHAR DEFAULT 'facebook/bart-large-cnn',
        classification_model VARCHAR DEFAULT 'facebook/bart-large-cnn',
        database JSONB DEFAULT '{"vector_database": "MilvusStore", "database": "TROC", "collection_name": "troc_information"}'::JSONB,
        created_at TIMESTAMPTZ DEFAULT NOW(),
        created_by INTEGER,
        updated_at TIMESTAMPTZ DEFAULT NOW()
    );

    NOTE(review): the DDL above and the fields below do not match
    one-to-one (the DDL has ``model_name``, ``model_config`` and
    ``contact_*`` columns; the model has ``llm_config``, ``custom_class``
    and ``bot_type``) -- confirm against the live table.
    """
    chatbot_id: uuid.UUID = Field(primary_key=True, required=False, default_factory=uuid.uuid4)
    name: str = Field(default='Nav', required=True)
    description: str = Field(default='Nav Chatbot', required=False)
    config_file: str = Field(default='config.toml', required=False)
    custom_class: str = Field(required=False)
    company_information: dict = Field(default_factory=dict, required=False)
    avatar: Text
    enabled: bool = Field(required=True, default=True)
    timezone: str = Field(required=False, max=75, default="UTC", repr=False)
    attributes: Optional[dict] = Field(required=False, default_factory=dict)
    # Chatbot Configuration
    role: str = Field(
        default="a Human Resources Assistant",
        required=False
    )
    goal: str = Field(
        default="Bring useful information to Users.",
        required=True
    )
    backstory: str = Field(
        default="I was created by a team of developers to assist with users tasks.",
        required=True
    )
    rationale: str = Field(
        default=(
            "Remember to maintain a professional tone."
            " Please provide accurate and relevant information."
        ),
        required=True
    )
    language: str = Field(default='en', required=False)
    template_prompt: Union[str, PurePath] = Field(
        default=None,
        required=False
    )
    pre_instructions: list = Field(
        default_factory=list,
        required=False
    )
    # Model Configuration:
    llm: str = Field(default='VertexLLM', required=False)
    llm_config: dict = Field(default_factory=dict, required=False)
    embedding_name: str = Field(default="thenlper/gte-base", required=False)
    tokenizer: str = Field(default='thenlper/gte-base', required=False)
    summarize_model: str = Field(default="facebook/bart-large-cnn", required=False)
    classification_model: str = Field(default="facebook/bart-large-cnn", required=False)
    # Database Configuration
    # Only a dict factory: a ``default='TROC'`` string conflicted both with
    # ``default_factory`` and with the field's ``dict`` type.
    database: dict = Field(required=False, default_factory=dict)
    # Bot/Agent type
    bot_type: str = Field(default='chatbot', required=False)
    # When created
    # Pass the callable (as the other models do): ``datetime.now()`` would
    # be evaluated once at import time and shared by every instance.
    created_at: datetime = Field(required=False, default=datetime.now)
    created_by: int = Field(required=False)
    updated_at: datetime = Field(required=False, default=datetime.now)

    def __post_init__(self) -> None:
        """Run base initialisation, then normalise ``config_file``.

        A string ``config_file`` is resolved to an absolute path (relative
        paths resolve against the current working directory) and stored
        back as a string. A ``PurePath`` value is only converted to ``str``.
        """
        super(ChatbotModel, self).__post_init__()
        if isinstance(self.config_file, str):
            self.config_file = Path(self.config_file).resolve()
        if isinstance(self.config_file, PurePath):
            self.config_file = str(self.config_file)

    class Meta:
        """Meta Chatbot."""
        driver = 'pg'
        name = "chatbots"
        schema = "navigator"
        strict = True
        frozen = False
185
+
186
+
187
class ChatbotUsage(Model):
    """ChatbotUsage.

    One row per chatbot interaction: who asked, from where, the question
    and the response.

    -- ScyllaDB CREATE TABLE Syntax --
    CREATE TABLE IF NOT EXISTS navigator.chatbots_usage (
        chatbot_id TEXT,
        user_id SMALLINT,
        sid TEXT,
        source_path TEXT,
        platform TEXT,
        origin inet,
        user_agent TEXT,
        question TEXT,
        response TEXT,
        used_at BIGINT,
        at TEXT,
        PRIMARY KEY ((chatbot_id, sid, at), used_at)
    ) WITH CLUSTERING ORDER BY (used_at DESC)
    AND default_time_to_live = 10368000;

    NOTE(review): the model marks chatbot_id, user_id, sid and _at as
    primary keys, which differs from the DDL key above; ``Meta.driver`` is
    'bigquery' -- confirm which definition is authoritative.
    """
    chatbot_id: uuid.UUID = Field(primary_key=True, required=False)
    user_id: int = Field(primary_key=True, required=False)
    sid: uuid.UUID = Field(primary_key=True, required=False, default=uuid.uuid4)
    source_path: str = Field(required=False, default='web')
    platform: str = Field(required=False, default='web')
    origin: str = Field(required=False)
    user_agent: str = Field(required=False)
    question: str = Field(required=False)
    response: str = Field(required=False)
    # Epoch milliseconds from the module-level created_at() helper.
    used_at: int = Field(required=False, default=created_at)
    event_timestamp: datetime = Field(required=False, default=datetime.now)
    # Composite key "<sid>:<used_at>", filled in by __post_init__ if empty.
    _at: str = Field(primary_key=True, required=False)

    class Meta:
        """Meta ChatbotUsage."""
        driver = 'bigquery'
        name = "chatbots_usage"
        schema = "navigator"
        ttl = 10368000  # 120 days in seconds
        strict = True
        frozen = False

    def __post_init__(self) -> None:
        """Derive the composite ``_at`` key, then defer to the base model."""
        if not self._at:
            # Not a session id: this combines the session id with the usage
            # timestamp. It is computed *before* base initialisation runs.
            self._at = f'{self.sid}:{self.used_at}'
        super().__post_init__()
237
+
238
+
239
class FeedbackType(Enum):
    """FeedbackType.

    Reasons a user can pick when rating a chatbot answer. Member names are
    prefixed ``GOOD_`` (positive) or ``BAD_`` (negative).
    """
    # --- positive feedback ---
    GOOD_COMPLETE = "Completeness"
    GOOD_CORRECT = "Correct"
    GOOD_FOLLOW = "Follow the instructions"
    GOOD_UNDERSTAND = "Understandable"
    GOOD_USEFUL = "very useful"
    GOOD_OTHER = "Please Explain"
    # --- negative feedback ---
    BAD_DONTLIKE = "Don't like the style"
    BAD_INCORRECT = "Incorrect"
    BAD_NOTFOLLOW = "Didn't follow the instructions"
    BAD_LAZY = "Being lazy"
    BAD_NOTUSEFUL = "Not useful"
    BAD_UNSAFE = "Unsafe or problematic"
    BAD_OTHER = "Other"

    @classmethod
    def list_feedback(cls, feedback_category):
        """Return the members belonging to *feedback_category*.

        Args:
            feedback_category: category name such as "good" or "bad";
                matching is case-insensitive.

        Returns:
            list: members whose name starts with ``<CATEGORY>_``, in
            definition order (empty for an unknown category).
        """
        wanted_prefix = f"{feedback_category.upper()}_"
        matches = []
        for member in cls:
            if member.name.startswith(wanted_prefix):
                matches.append(member)
        return matches
262
+
263
class ChatbotFeedback(Model):
    """ChatbotFeedback.

    User feedback on a chatbot answer: a numeric rating, like/dislike
    flags, a ``FeedbackType`` reason and free-text comments.

    -- ScyllaDB CREATE TABLE Syntax --
    CREATE TABLE IF NOT EXISTS navigator.chatbots_feedback (
        chatbot_id UUID,
        user_id INT,
        sid UUID,
        at TEXT,
        rating TINYINT,
        like BOOLEAN,
        dislike BOOLEAN,
        feedback_type TEXT,
        feedback TEXT,
        created_at BIGINT,
        PRIMARY KEY ((chatbot_id, user_id, sid), created_at)
    ) WITH CLUSTERING ORDER BY (created_at DESC)
    AND default_time_to_live = 7776000;

    NOTE(review): the model's primary keys (chatbot_id, sid, _at) and the
    ``_like``/``_dislike``/``_at`` field names differ from the DDL above;
    ``Meta.driver`` is 'bigquery' -- confirm which is authoritative.
    """
    chatbot_id: uuid.UUID = Field(primary_key=True, required=False)
    user_id: int = Field(required=False)
    sid: uuid.UUID = Field(primary_key=True, required=False)
    # Composite key "<sid>:<created_at>", filled in by __post_init__ if empty.
    _at: str = Field(primary_key=True, required=False)
    # feedback information:
    rating: int = Field(required=False, default=0)
    _like: bool = Field(required=False, default=False)
    _dislike: bool = Field(required=False, default=False)
    feedback_type: FeedbackType = Field(required=False)
    feedback: str = Field(required=False)
    # Epoch milliseconds from the module-level created_at() helper.
    created_at: int = Field(required=False, default=created_at)
    expiration_timestamp: datetime = Field(required=False, default=datetime.now)

    class Meta:
        """Meta ChatbotFeedback."""
        driver = 'bigquery'
        name = "chatbots_feedback"
        schema = "navigator"
        ttl = 7776000  # 3 months in seconds
        strict = True
        frozen = False

    def __post_init__(self) -> None:
        """Derive the composite ``_at`` key, then defer to the base model."""
        if not self._at:
            # The key embeds the creation time, so stamp one first if the
            # caller left ``created_at`` empty.
            if not self.created_at:
                self.created_at = created_at()
            self._at = f'{self.sid}:{self.created_at}'
        super().__post_init__()
314
+
315
+
316
+ ## Prompt Library:
317
+
318
class PromptCategory(Enum):
    """Prompt Category.

    Closed set of labels used to classify prompts in the prompt library.
    Each member's value is the lowercase label that is stored.
    """
    TECH = "tech"                        # technical prompt
    TECH_OR_EXPLAIN = "tech-or-explain"  # technical, or an explanation
    IDEA = "idea"
    EXPLAIN = "explain"
    ACTION = "action"
    COMMAND = "command"
    OTHER = "other"                      # catch-all
331
+
332
class PromptLibrary(Model):
    """PromptLibrary.

    Saving information about Prompt Library (reusable prompts attached to
    a chatbot).

    -- PostgreSQL CREATE TABLE Syntax --
    CREATE TABLE IF NOT EXISTS navigator.prompt_library (
        prompt_id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
        chatbot_id UUID,
        title varchar,
        query varchar,
        description TEXT,
        prompt_category varchar,
        prompt_tags varchar[],
        created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
        created_by INTEGER,
        updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
    );
    """
    prompt_id: uuid.UUID = Field(primary_key=True, required=False, default_factory=uuid.uuid4)
    chatbot_id: uuid.UUID = Field(required=True)
    title: str = Field(required=True)
    query: str = Field(required=True)
    description: str = Field(required=False)
    # Default to the enum's *value*: the field is typed ``str`` and the
    # column is ``varchar``, while the Enum member itself is not a ``str``.
    prompt_category: str = Field(required=False, default=PromptCategory.OTHER.value)
    prompt_tags: list = Field(required=False, default_factory=list)
    created_at: datetime = Field(required=False, default=datetime.now)
    created_by: int = Field(required=False)
    updated_at: datetime = Field(required=False, default=datetime.now)

    class Meta:
        """Meta Prompt Library."""
        driver = 'pg'
        name = "prompt_library"
        schema = "navigator"
        strict = True
        frozen = False

    def __post_init__(self) -> None:
        """Defer entirely to the base model's post-initialisation."""
        super(PromptLibrary, self).__post_init__()
parrot/py.typed ADDED
File without changes
parrot/stores/__init__.py ADDED
@@ -0,0 +1,48 @@
1
+ from typing import Any
2
+ from .abstract import AbstractStore
3
+ from ..exceptions import ConfigError # pylint: disable=E0611
4
+ try:
5
+ from .qdrant import QdrantStore
6
+ QDRANT_ENABLED = True
7
+ except (ModuleNotFoundError, ImportError):
8
+ QDRANT_ENABLED = False
9
+
10
+ try:
11
+ from .milvus import MilvusStore
12
+ MILVUS_ENABLED = True
13
+ except (ModuleNotFoundError, ImportError):
14
+ MILVUS_ENABLED = False
15
+
16
+
17
def get_vectordb(vector_db: str, embeddings: Any, **kwargs) -> AbstractStore:
    """Instantiate the vector-store backend named by *vector_db*.

    Args:
        vector_db: backend name. ``"qdrant"``/``"QdrantStore"`` and
            ``"milvus"``/``"MilvusStore"`` are recognised; matching ignores
            case and surrounding whitespace.
        embeddings: embedding model (or model name) forwarded to the store.
        **kwargs: extra keyword arguments forwarded to the store constructor.

    Returns:
        A ``QdrantStore`` or ``MilvusStore`` instance.

    Raises:
        ConfigError: the requested backend's optional package is not installed.
        ValueError: *vector_db* is not a supported backend name.
    """
    # Non-string values fall through to the ValueError below, as before.
    key = vector_db.strip().lower() if isinstance(vector_db, str) else None
    if key in ('qdrantstore', 'qdrant'):
        if not QDRANT_ENABLED:
            raise ConfigError(
                (
                    "Qdrant is enabled but not installed, "
                    "Hint: Please install with pip install -e .[qdrant]"
                )
            )
        ## TODO: support pluggable vector store
        return QdrantStore(  # pylint: disable=E0110
            embeddings=embeddings,
            **kwargs
        )
    if key in ('milvusstore', 'milvus'):
        if not MILVUS_ENABLED:
            raise ConfigError(
                (
                    "Milvus is enabled but not installed, "
                    "Hint: Please install with pip install -e .[milvus]"
                )
            )
        return MilvusStore(
            embeddings=embeddings,
            **kwargs
        )
    raise ValueError(
        f"Vector Database {vector_db} not supported"
    )
parrot/stores/abstract.py ADDED
@@ -0,0 +1,171 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Union, Any
3
+ from collections.abc import Callable
4
+ import torch
5
+ from langchain_huggingface import (
6
+ HuggingFaceEmbeddings
7
+ )
8
+ from langchain_community.embeddings import (
9
+ HuggingFaceBgeEmbeddings
10
+ )
11
+ from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
12
+ from navconfig.logging import logging
13
+ from ..conf import (
14
+ EMBEDDING_DEVICE,
15
+ EMBEDDING_DEFAULT_MODEL,
16
+ CUDA_DEFAULT_DEVICE,
17
+ MAX_BATCH_SIZE
18
+ )
19
+
20
+
21
class AbstractStore(ABC):
    """AbstractStore class.

    Base class for vector-store backends. It holds the embedding model and
    the store options, and declares the interface concrete stores implement.

    Args:
        embeddings (str | Callable): name of an embedding model, or an
            already-built embedding object. When omitted, the model named by
            ``embedding_name`` is built on context entry.
        **kwargs: store options -- ``use_bge`` (False), ``use_fastembed``
            (False), ``embedding_name`` (``EMBEDDING_DEFAULT_MODEL``),
            ``dimension`` (768), ``metric_type`` ("COSINE"), ``index_type``
            ("IVF_FLAT"), ``database`` (""), ``collection_name``
            ("my_collection") and ``index_name`` ("my_index"). Other keys
            are ignored by this base class.
    """

    def __init__(self, embeddings: Union[str, Callable] = None, **kwargs):
        self.client: Callable = None
        self.vector: Callable = None
        self._embed_: Callable = None
        self._connected: bool = False
        self.use_bge: bool = kwargs.pop("use_bge", False)
        self.fastembed: bool = kwargs.pop("use_fastembed", False)
        self.embedding_name: str = kwargs.pop('embedding_name', EMBEDDING_DEFAULT_MODEL)
        self.dimension: int = kwargs.pop("dimension", 768)
        self._metric_type: str = kwargs.pop("metric_type", 'COSINE')
        self._index_type: str = kwargs.pop("index_type", 'IVF_FLAT')
        self.database: str = kwargs.pop('database', '')
        self.collection = kwargs.pop("collection_name", "my_collection")
        self.index_name = kwargs.pop("index_name", "my_index")
        if embeddings is not None:
            if isinstance(embeddings, str):
                # A string only names the model; it is built lazily.
                self.embedding_name = embeddings
            else:
                self._embed_ = embeddings
        self.logger = logging.getLogger(f"Store.{__name__}")
        # client
        self._client = None
        self._client_id = None

    @property
    def connected(self) -> bool:
        """Value of the ``_connected`` flag (maintained by subclasses)."""
        return self._connected

    async def __aenter__(self):
        """Build the embedding model if needed and connect to the backend.

        Returns:
            The store itself.
        """
        # NOTE(review): the purpose of this 1000x1000 CUDA tensor is not
        # visible here (presumably to initialise the CUDA context early).
        # Check availability first: on CPU-only PyTorch builds ``.cuda()``
        # raises AssertionError, which the RuntimeError handler would miss.
        self.tensor = None
        if torch.cuda.is_available():
            try:
                self.tensor = torch.randn(1000, 1000).cuda()
            except RuntimeError:
                # e.g. driver or out-of-memory errors: continue without it.
                self.tensor = None
        if self._embed_ is None:
            self._embed_ = self.create_embedding(
                model_name=self.embedding_name
            )
        # connect() is expected to return a (client, client_id) pair.
        self._client, self._client_id = self.connect()
        return self

    @abstractmethod
    def connect(self):
        """Open the backend connection.

        Must return a ``(client, client_id)`` pair; ``__aenter__`` unpacks
        it into ``self._client`` and ``self._client_id``.
        """

    def __enter__(self):
        """Build the embedding model if needed; does not call ``connect()``."""
        if self._embed_ is None:
            self._embed_ = self.create_embedding(
                model_name=self.embedding_name
            )
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Drop the embedding model and warm-up tensor; exceptions propagate."""
        # NOTE(review): this also discards an embedding object passed to
        # __init__; re-entering rebuilds one from ``embedding_name`` instead
        # -- confirm this is intended.
        self._embed_ = None
        # Rebind rather than ``del`` so a repeated exit (or an exit without
        # a prior __aenter__) cannot raise AttributeError.
        self.tensor = None
        try:
            torch.cuda.empty_cache()
        except RuntimeError:
            pass

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the embedding model and release cached GPU memory."""
        self._embed_ = None
        try:
            torch.cuda.empty_cache()
        except RuntimeError:
            pass

    @abstractmethod
    def get_vector(self):
        pass

    @abstractmethod
    async def load_documents(
        self,
        documents: list,
        collection: str = None
    ):
        pass

    @abstractmethod
    def upsert(self, payload: dict, collection_name: str = None) -> None:
        pass

    @abstractmethod
    def search(self, payload: dict, collection_name: str = None) -> dict:
        pass

    @abstractmethod
    async def delete_collection(self, collection_name: str = None) -> dict:
        pass

    @abstractmethod
    async def create_collection(self, collection_name: str, document: Any) -> dict:
        pass

    def create_embedding(
        self,
        model_name: str = None
    ):
        """Build an embedding model for *model_name*.

        Device preference: Apple MPS, then CUDA device
        ``CUDA_DEFAULT_DEVICE`` (also forced when ``EMBEDDING_DEVICE`` is
        'cuda'), otherwise ``EMBEDDING_DEVICE``. The backend is
        HuggingFaceBgeEmbeddings when ``use_bge`` is set, FastEmbedEmbeddings
        when ``use_fastembed`` is set (no device is passed to it), and
        HuggingFaceEmbeddings otherwise.

        Args:
            model_name: model identifier; ``EMBEDDING_DEFAULT_MODEL`` if None.
        """
        encode_kwargs: dict = {
            'normalize_embeddings': True,
            "batch_size": MAX_BATCH_SIZE
        }
        if torch.backends.mps.is_available():
            # Apple Metal Performance Shaders (macOS GPU backend).
            device = torch.device("mps")
        elif torch.cuda.is_available() or EMBEDDING_DEVICE == 'cuda':
            device = torch.device(
                f'cuda:{CUDA_DEFAULT_DEVICE}'
            )
        else:
            device = torch.device(EMBEDDING_DEVICE)
        model_kwargs: dict = {'device': device}
        if model_name is None:
            model_name = EMBEDDING_DEFAULT_MODEL
        if self.use_bge is True:
            return HuggingFaceBgeEmbeddings(
                model_name=model_name,
                model_kwargs=model_kwargs,
                encode_kwargs=encode_kwargs
            )
        if self.fastembed is True:
            return FastEmbedEmbeddings(
                model_name=model_name,
                max_length=1024,
                threads=4
            )
        return HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs
        )

    def get_default_embedding(
        self,
        model_name: str = EMBEDDING_DEFAULT_MODEL
    ):
        """Shortcut for ``create_embedding(model_name=model_name)``."""
        return self.create_embedding(model_name=model_name)