ddi-fw 0.0.241__py3-none-any.whl → 0.0.242__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddi_fw/datasets/core.py +70 -65
- ddi_fw/langchain/__init__.py +1 -1
- ddi_fw/langchain/chroma_storage.py +134 -22
- ddi_fw/langchain/faiss_storage.py +25 -9
- ddi_fw/pipeline/multi_pipeline.py +5 -1
- ddi_fw/pipeline/pipeline.py +16 -0
- {ddi_fw-0.0.241.dist-info → ddi_fw-0.0.242.dist-info}/METADATA +1 -1
- {ddi_fw-0.0.241.dist-info → ddi_fw-0.0.242.dist-info}/RECORD +10 -10
- {ddi_fw-0.0.241.dist-info → ddi_fw-0.0.242.dist-info}/WHEEL +0 -0
- {ddi_fw-0.0.241.dist-info → ddi_fw-0.0.242.dist-info}/top_level.txt +0 -0
ddi_fw/datasets/core.py
CHANGED
@@ -3,12 +3,13 @@ from collections import defaultdict
 import glob
 import logging
 from typing import Any, Dict, List, Optional, Type
-import chromadb
+# import chromadb
 # from chromadb.api.types import IncludeEnum
 import numpy as np
 import pandas as pd
 from pydantic import BaseModel, Field, computed_field
 from ddi_fw.datasets.dataset_splitter import DatasetSplitter
+from ddi_fw.langchain.faiss_storage import BaseVectorStoreManager
 from ddi_fw.utils.utils import create_folder_if_not_exists
 
 
@@ -280,6 +281,8 @@ class TextDatasetMixin(BaseModel):
         default_factory=dict, description="Dictionary for embeddings")
     pooling_strategy: PoolingStrategy | None = None
     column_embedding_configs: Optional[List] = None
+    vector_store_manager: BaseVectorStoreManager| None = None # <-- NEW
+
     vector_db_persist_directory: Optional[str] = None
     vector_db_collection_name: Optional[str] = None
     _embedding_size: int
@@ -292,70 +295,70 @@ class TextDatasetMixin(BaseModel):
     class Config:
         arbitrary_types_allowed = True
 
-    def __create_or_update_embeddings__(self, embedding_dict, vector_db_persist_directory, vector_db_collection_name, column=None):
-        """
-        Fetch embeddings and metadata from a persistent Chroma vector database and update the provided embedding_dict.
-
-        Args:
-        - vector_db_persist_directory (str): The path to the directory where the Chroma vector database is stored.
-        - vector_db_collection_name (str): The name of the collection to query.
-        - embedding_dict (dict): The existing dictionary to update with embeddings.
-
-        """
-        if vector_db_persist_directory:
-            # Initialize the Chroma client and get the collection
-            vector_db = chromadb.PersistentClient(
-                path=vector_db_persist_directory)
-            collection = vector_db.get_collection(vector_db_collection_name)
-            # include = [IncludeEnum.embeddings, IncludeEnum.metadatas]
-            include: chromadb.Include = ["embeddings","metadatas"]
-            dictionary: chromadb.GetResult
-            # Fetch the embeddings and metadata
-            if column == None:
-                dictionary = collection.get(
-                    include=include
-                    # include=['embeddings', 'metadatas']
-                )
-                print(
-                    f"Embeddings are calculated from {vector_db_collection_name}")
-            else:
-                dictionary = collection.get(
-                    include=include,
-                    # include=['embeddings', 'metadatas'],
-                    where={
-                        "type": {"$eq": f"{column}"}})
-                print(
-                    f"Embeddings of {column} are calculated from {vector_db_collection_name}")
-
-            # Populate the embedding dictionary with embeddings from the vector database
-            metadatas = dictionary["metadatas"]
-            embeddings = dictionary["embeddings"]
-            if metadatas is None or embeddings is None:
-                raise ValueError(
-                    "The collection does not contain embeddings or metadatas.")
-            for metadata, embedding in zip(metadatas, embeddings):
-                embedding_dict[metadata["type"]
-                               ][metadata["id"]].append(embedding)
-
-        else:
-            raise ValueError(
-                "Persistent directory for the vector DB is not specified.")
+    # def __create_or_update_embeddings__(self, embedding_dict, vector_db_persist_directory, vector_db_collection_name, column=None):
+    #     """
+    #     Fetch embeddings and metadata from a persistent Chroma vector database and update the provided embedding_dict.
+
+    #     Args:
+    #     - vector_db_persist_directory (str): The path to the directory where the Chroma vector database is stored.
+    #     - vector_db_collection_name (str): The name of the collection to query.
+    #     - embedding_dict (dict): The existing dictionary to update with embeddings.
+
+    #     """
+    #     if vector_db_persist_directory:
+    #         # Initialize the Chroma client and get the collection
+    #         vector_db = chromadb.PersistentClient(
+    #             path=vector_db_persist_directory)
+    #         collection = vector_db.get_collection(vector_db_collection_name)
+    #         # include = [IncludeEnum.embeddings, IncludeEnum.metadatas]
+    #         include: chromadb.Include = ["embeddings","metadatas"]
+    #         dictionary: chromadb.GetResult
+    #         # Fetch the embeddings and metadata
+    #         if column == None:
+    #             dictionary = collection.get(
+    #                 include=include
+    #                 # include=['embeddings', 'metadatas']
+    #             )
+    #             print(
+    #                 f"Embeddings are calculated from {vector_db_collection_name}")
+    #         else:
+    #             dictionary = collection.get(
+    #                 include=include,
+    #                 # include=['embeddings', 'metadatas'],
+    #                 where={
+    #                     "type": {"$eq": f"{column}"}})
+    #             print(
+    #                 f"Embeddings of {column} are calculated from {vector_db_collection_name}")
+
+    #         # Populate the embedding dictionary with embeddings from the vector database
+    #         metadatas = dictionary["metadatas"]
+    #         embeddings = dictionary["embeddings"]
+    #         if metadatas is None or embeddings is None:
+    #             raise ValueError(
+    #                 "The collection does not contain embeddings or metadatas.")
+    #         for metadata, embedding in zip(metadatas, embeddings):
+    #             embedding_dict[metadata["type"]
+    #                            ][metadata["id"]].append(embedding)
+
+    #     else:
+    #         raise ValueError(
+    #             "Persistent directory for the vector DB is not specified.")
 
-    def __initialize_embedding_dict(self):
-        embedding_dict = defaultdict(lambda: defaultdict(list))
-        if self.column_embedding_configs:
-            for item in self.column_embedding_configs:
-                col = item["column"]
-                col_db_dir = item["vector_db_persist_directory"]
-                col_db_collection = item["vector_db_collection_name"]
-                self.__create_or_update_embeddings__(embedding_dict, col_db_dir, col_db_collection, col)
-        elif self.vector_db_persist_directory:
-            self.__create_or_update_embeddings__(embedding_dict, self.vector_db_persist_directory, self.vector_db_collection_name)
-        else:
-            logging.warning("There is no configuration of Embeddings")
-            raise ValueError(
-                "There is no configuration of Embeddings. Please provide a vector database directory and collection name.")
-        return embedding_dict
+    # def __initialize_embedding_dict(self):
+    #     embedding_dict = defaultdict(lambda: defaultdict(list))
+    #     if self.column_embedding_configs:
+    #         for item in self.column_embedding_configs:
+    #             col = item["column"]
+    #             col_db_dir = item["vector_db_persist_directory"]
+    #             col_db_collection = item["vector_db_collection_name"]
+    #             self.__create_or_update_embeddings__(embedding_dict, col_db_dir, col_db_collection, col)
+    #     elif self.vector_db_persist_directory:
+    #         self.__create_or_update_embeddings__(embedding_dict, self.vector_db_persist_directory, self.vector_db_collection_name)
+    #     else:
+    #         logging.warning("There is no configuration of Embeddings")
+    #         raise ValueError(
+    #             "There is no configuration of Embeddings. Please provide a vector database directory and collection name.")
+    #     return embedding_dict
 
     def __calculate_embedding_size(self):
         if self.embedding_dict is None:
@@ -373,7 +376,9 @@ class TextDatasetMixin(BaseModel):
         # for k, v in self.ner_threshold.items():
         #     kwargs[k] = v
         if self.embedding_dict is None:
-            self.embedding_dict = self.__initialize_embedding_dict()
+            if self.vector_store_manager is not None:
+                self.embedding_dict = self.vector_store_manager.initialize_embedding_dict()
+            # self.embedding_dict = self.__initialize_embedding_dict()
         self.__calculate_embedding_size()
 
 
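The net effect of the core.py changes is an injection point: TextDatasetMixin no longer opens Chroma itself but asks a BaseVectorStoreManager for its embeddings. A minimal sketch of how the new field can be used, assuming a Chroma collection already persisted on disk; the dataset subclass, embedding backend, collection name, and path below are placeholders, not part of this release:

    from langchain_community.embeddings import FakeEmbeddings  # placeholder backend
    from ddi_fw.datasets.core import TextDatasetMixin
    from ddi_fw.langchain.chroma_storage import ChromaVectorStoreManager

    class DemoTextDataset(TextDatasetMixin):  # hypothetical minimal subclass
        pass

    dataset = DemoTextDataset(
        vector_store_manager=ChromaVectorStoreManager(
            embeddings=FakeEmbeddings(size=384),
            collection_name="ddi_texts",      # placeholder
            persist_directory="./chroma_db",  # placeholder
        ),
    )
    # Per the last hunk above, processing now delegates:
    #     self.embedding_dict = self.vector_store_manager.initialize_embedding_dict()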
ddi_fw/langchain/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 from ..langchain.embeddings import PoolingStrategy,SumPoolingStrategy,MeanPoolingStrategy,SentenceTransformerDecorator,PretrainedEmbeddings,SBertEmbeddings
 from .sentence_splitter import SentenceSplitter
 from .storage import DataFrameToVectorDB, generate_embeddings
-from .faiss_storage import BaseVectorStoreManager,
+from .faiss_storage import BaseVectorStoreManager, FaissVectorStoreManager
 from .chroma_storage import ChromaVectorStoreManager
ddi_fw/langchain/chroma_storage.py
CHANGED
@@ -1,3 +1,5 @@
+from collections import defaultdict
+import logging
 import pandas as pd
 from langchain.vectorstores import Chroma
 from langchain_core.embeddings import Embeddings
@@ -5,10 +7,11 @@ from langchain_core.documents import Document
 from langchain.text_splitter import TextSplitter
 from typing import Callable, Optional, Dict, Any, List
 import numpy as np
+from pydantic import Field
 
 from ddi_fw.langchain.faiss_storage import BaseVectorStoreManager
 from langchain.document_loaders import DataFrameLoader
-
+import chromadb
 
 def split_dataframe(df, min_size=512):
     total_size = len(df)
@@ -82,19 +85,13 @@ def split_dataframe_indices(df, min_size=512):
 
 
 class ChromaVectorStoreManager(BaseVectorStoreManager):
-    def __init__(
-        self,
-        embeddings,
-        collection_name,
-        persist_directory,
-        text_splitter,
-        batch_size,
-    ):
-        super().__init__(embeddings)
-        self.collection_name = collection_name
-        self.persist_directory = persist_directory
-        self.text_splitter = text_splitter
-        self.batch_size = batch_size
+    collection_name: str = Field(default="default")
+    persist_directory: str = Field(default="./chroma_db")
+    text_splitter: Optional[TextSplitter] = None
+    batch_size: int = Field(default=1024)
+
+    class Config:
+        arbitrary_types_allowed = True
 
 
 
@@ -176,6 +173,76 @@ class ChromaVectorStoreManager(BaseVectorStoreManager):
         # Chroma persists automatically, but you can copy files if needed
         print("ChromaDB persists automatically. No explicit save needed.")
 
+
+    def __create_or_update_embeddings__(self, embedding_dict, vector_db_persist_directory, vector_db_collection_name, column=None):
+        """
+        Fetch embeddings and metadata from a persistent Chroma vector database and update the provided embedding_dict.
+
+        Args:
+        - vector_db_persist_directory (str): The path to the directory where the Chroma vector database is stored.
+        - vector_db_collection_name (str): The name of the collection to query.
+        - embedding_dict (dict): The existing dictionary to update with embeddings.
+
+        """
+        if vector_db_persist_directory:
+            # Initialize the Chroma client and get the collection
+            vector_db = chromadb.PersistentClient(
+                path=vector_db_persist_directory)
+            collection = vector_db.get_collection(vector_db_collection_name)
+            # include = [IncludeEnum.embeddings, IncludeEnum.metadatas]
+            include: chromadb.Include = ["embeddings","metadatas"]
+            dictionary: chromadb.GetResult
+            # Fetch the embeddings and metadata
+            if column == None:
+                dictionary = collection.get(
+                    include=include
+                    # include=['embeddings', 'metadatas']
+                )
+                print(
+                    f"Embeddings are calculated from {vector_db_collection_name}")
+            else:
+                dictionary = collection.get(
+                    include=include,
+                    # include=['embeddings', 'metadatas'],
+                    where={
+                        "type": {"$eq": f"{column}"}})
+                print(
+                    f"Embeddings of {column} are calculated from {vector_db_collection_name}")
+
+            # Populate the embedding dictionary with embeddings from the vector database
+            metadatas = dictionary["metadatas"]
+            embeddings = dictionary["embeddings"]
+            if metadatas is None or embeddings is None:
+                raise ValueError(
+                    "The collection does not contain embeddings or metadatas.")
+            for metadata, embedding in zip(metadatas, embeddings):
+                embedding_dict[metadata["type"]
+                               ][metadata["id"]].append(embedding)
+
+        else:
+            raise ValueError(
+                "Persistent directory for the vector DB is not specified.")
+
+    def initialize_embedding_dict(self, **kwargs):
+        column_embedding_configs = kwargs.get("column_embedding_configs")
+        vector_db_persist_directory = kwargs.get("vector_db_persist_directory")
+        vector_db_collection_name = kwargs.get("vector_db_collection_name")
+        embedding_dict = defaultdict(lambda: defaultdict(list))
+        if column_embedding_configs:
+            for item in column_embedding_configs:
+                col = item["column"]
+                col_db_dir = item["vector_db_persist_directory"]
+                col_db_collection = item["vector_db_collection_name"]
+                self.__create_or_update_embeddings__(embedding_dict, col_db_dir, col_db_collection, col)
+        elif vector_db_persist_directory:
+            self.__create_or_update_embeddings__(embedding_dict, vector_db_persist_directory, vector_db_collection_name)
+        else:
+            logging.warning("There is no configuration of Embeddings")
+            raise ValueError(
+                "There is no configuration of Embeddings. Please provide a vector database directory and collection name.")
+        return embedding_dict
+
+
     def load(self, path):
         self.vector_store = Chroma(
             collection_name=self.collection_name,
@@ -187,21 +254,66 @@ class ChromaVectorStoreManager(BaseVectorStoreManager):
         self,
         formatter_fn: Optional[Callable[[Document, np.ndarray], Dict[str, Any]]] = None
     ) -> pd.DataFrame:
-        # Chroma does not expose direct vector access, so we fetch all docs and embeddings
-        results = self.vector_store.get()
-        docs = results['documents']
-        metadatas = results['metadatas']
-        embeddings = results['embeddings']
+        """
+        Retrieve all documents and their embeddings from the Chroma vector store
+        and return them as a pandas DataFrame.
+        """
+        # Retrieve all data from the collection
+        # include=['embeddings', 'metadatas', 'documents']
+        results = self.vector_store._collection.get(include=['embeddings', 'metadatas', 'documents'])
+
+        # Ensure all lists are not None and have the same length
+        docs = results.get('documents', []) or []
+        metadatas = results.get('metadatas', []) or []
+        embeddings = results.get('embeddings', []) or []
+
+        # Check if all lists have the same length
+        if not (len(docs) == len(metadatas) == len(embeddings)):
+            # This should not happen if Chroma returns consistent results, but as a safeguard
+            raise ValueError(
+                "Inconsistent lengths of documents, metadatas, and embeddings. ")
+            # print("Warning: Inconsistent lengths of documents, metadatas, and embeddings.")
+            # # Find the minimum length to avoid errors
+            # min_len = min(len(docs), len(metadatas), len(embeddings))
+            # docs = docs[:min_len]
+            # metadatas = metadatas[:min_len]
+            # embeddings = embeddings[:min_len]
+
+
         items = []
         for doc, meta, emb in zip(docs, metadatas, embeddings):
             document = Document(page_content=doc, metadata=meta)
             if formatter_fn:
-                item = formatter_fn(document, np.array(emb))
+                formatted_doc = formatter_fn(document, np.array(emb))
             else:
-                item = {"embedding": emb, **meta}
-            items.append(item)
+                formatted_doc = document
+            items.append({
+                'document': formatted_doc,
+                'metadata': meta,
+                'embedding': emb
+            })
+
         return pd.DataFrame(items)
 
+    # def as_dataframe(
+    #     self,
+    #     formatter_fn: Optional[Callable[[Document, np.ndarray], Dict[str, Any]]] = None
+    # ) -> pd.DataFrame:
+    #     # Chroma does not expose direct vector access, so we fetch all docs and embeddings
+    #     results = self.vector_store.get()
+    #     docs = results['documents']
+    #     metadatas = results['metadatas']
+    #     embeddings = results['embeddings']
+    #     items = []
+    #     for doc, meta, emb in zip(docs, metadatas, embeddings):
+    #         document = Document(page_content=doc, metadata=meta)
+    #         if formatter_fn:
+    #             item = formatter_fn(document, np.array(emb))
+    #         else:
+    #             item = {"embedding": emb, **meta}
+    #         items.append(item)
+    #     return pd.DataFrame(items)
+
     def get_data(self, id):
         # Chroma does not use integer IDs, but document IDs (UUIDs)
         results = self.vector_store.get(ids=[id])
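Since initialize_embedding_dict now lives on the manager, the Chroma-backed lookup can be exercised on its own. A short usage sketch under the assumption that a populated collection exists at the given path; all names and paths are placeholders:

    from langchain_community.embeddings import FakeEmbeddings  # placeholder backend
    from ddi_fw.langchain.chroma_storage import ChromaVectorStoreManager

    manager = ChromaVectorStoreManager(
        embeddings=FakeEmbeddings(size=384),
        collection_name="ddi_texts",
        persist_directory="./chroma_db",
    )

    # The method reads its configuration from kwargs, mirroring the arguments
    # of the helper that was removed from TextDatasetMixin:
    embedding_dict = manager.initialize_embedding_dict(
        vector_db_persist_directory="./chroma_db",
        vector_db_collection_name="ddi_texts",
    )
    # embedding_dict[<type>][<id>] -> list of embedding vectors for that entity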
ddi_fw/langchain/faiss_storage.py
CHANGED
@@ -8,9 +8,17 @@ from langchain_core.documents import Document
 import numpy as np # optional, if you're using NumPy vectors
 from langchain_core.embeddings import Embeddings
 
-class BaseVectorStoreManager:
-    def __init__(self, embeddings):
-        self.embeddings = embeddings
+from pydantic import BaseModel, Field
+from langchain_core.embeddings import Embeddings
+
+class BaseVectorStoreManager(BaseModel):
+    embeddings: Embeddings
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def initialize_embedding_dict(self, **kwargs):
+        raise NotImplementedError("This method should be implemented by subclasses.")
 
     def generate_vector_store(self, docs):
         raise NotImplementedError("This method should be implemented by subclasses.")
@@ -24,12 +32,12 @@ class BaseVectorStoreManager:
     def as_dataframe(self, formatter_fn: Optional[Callable[[Document, np.ndarray], Dict[str, Any]]] = None) -> pd.DataFrame:
         raise NotImplementedError("This method should be implemented by subclasses.")
 
-class VectorStoreManager:
-    def __init__(self, embeddings):
-        self.embeddings = embeddings
-        self.index = None
-        self.vector_store = None
+class FaissVectorStoreManager(BaseVectorStoreManager):
+    index: Any = None
+    vector_store: Any = None
 
+    class Config:
+        arbitrary_types_allowed = True
     # def generate_vector_store(self, docs):
     #     dimension = len(self.embeddings.embed_query("hello world"))
    #     self.index = faiss.IndexFlatL2(dimension)
@@ -45,6 +53,14 @@ class VectorStoreManager:
     #     uuids = [str(uuid4()) for _ in range(len(docs))]
     #     self.vector_store.add_documents(documents=docs, ids=uuids)
 
+    def initialize_embedding_dict(self):
+        df = self.as_dataframe(formatter_fn=custom_formatter )
+        type_dict = (
+            df.groupby('type')
+            .apply(lambda group: dict(zip(group['id'], group['embedding'])))
+            .to_dict()
+        )
+        return type_dict
 
     def generate_vector_store(self, docs, handle_empty='zero'):
         """
@@ -217,7 +233,7 @@ class VectorStoreManager:
 
 def custom_formatter(document: Document, vector: np.ndarray) -> Dict[str, Any]:
     return {
-        "
+        "id": document.metadata.get("drugbank_id", None),
         "type": document.metadata.get("type", None),
         "embedding": vector
     }
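With BaseVectorStoreManager now a pydantic model, a new backend only has to declare its fields and override the hooks. A sketch of a custom subclass, entirely hypothetical, to illustrate the contract:

    from typing import Any, Dict
    from ddi_fw.langchain.faiss_storage import BaseVectorStoreManager

    class InMemoryVectorStoreManager(BaseVectorStoreManager):
        # Hypothetical manager that serves a precomputed {type: {id: embedding}} mapping.
        store: Dict[str, Dict[str, Any]] = {}

        class Config:
            arbitrary_types_allowed = True

        def initialize_embedding_dict(self, **kwargs):
            # Must return the {type: {id: embedding(s)}} shape the datasets expect.
            return self.store

Note that the two shipped implementations differ slightly in that shape: FaissVectorStoreManager.initialize_embedding_dict maps each id to a single embedding through the groupby, while the Chroma version appends embeddings to a list per id.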
ddi_fw/pipeline/multi_pipeline.py
CHANGED
@@ -1,5 +1,6 @@
 import json
-from typing import Optional
+from typing import Optional, Type
+from ddi_fw.langchain.faiss_storage import BaseVectorStoreManager
 from ddi_fw.pipeline.pipeline import Pipeline
 from ddi_fw.pipeline.ner_pipeline import NerParameterSearch
 import importlib
@@ -128,11 +129,13 @@ class MultiPipeline():
 
         # Vector database configuration
         vector_database = config.get("vector_databases", {})
+        vector_store_manager_type:Type[BaseVectorStoreManager]|None = None
         vector_db_persist_directory = None
         vector_db_collection_name = None
         embedding_pooling_strategy = None
         column_embedding_configs = None
         if vector_database:
+            vector_store_manager_type = get_import(vector_database.get("db_type"))
             vector_db_persist_directory = vector_database.get("vector_db_persist_directory")
             vector_db_collection_name = vector_database.get("vector_db_collection_name")
             embedding_pooling_strategy = get_import(vector_database.get("embedding_pooling_strategy"))
@@ -181,6 +184,7 @@ class MultiPipeline():
             dataset_additional_config=additional_config,
             dataset_splitter_type=dataset_splitter_type,
             columns=columns,
+            vector_store_manager_type=vector_store_manager_type,
             column_embedding_configs=column_embedding_configs,
             vector_db_persist_directory=vector_db_persist_directory,
             vector_db_collection_name=vector_db_collection_name,
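MultiPipeline resolves the manager class from the configuration with the same get_import mechanism it already uses for the pooling strategy. A sketch of the corresponding vector_databases entry, assuming get_import accepts a fully qualified class path as it does for embedding_pooling_strategy; all values are placeholders:

    # Fragment of a MultiPipeline config, shown as a Python dict.
    config = {
        "vector_databases": {
            # New in this release: resolved into a BaseVectorStoreManager subclass.
            "db_type": "ddi_fw.langchain.chroma_storage.ChromaVectorStoreManager",
            "vector_db_persist_directory": "./chroma_db",
            "vector_db_collection_name": "ddi_texts",
            "embedding_pooling_strategy": "ddi_fw.langchain.embeddings.MeanPoolingStrategy",
        }
    }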
ddi_fw/pipeline/pipeline.py
CHANGED
@@ -3,6 +3,7 @@ from ddi_fw.datasets.dataset_splitter import DatasetSplitter
 
 from pydantic import BaseModel
 from ddi_fw.datasets.core import TextDatasetMixin
+from ddi_fw.langchain.faiss_storage import BaseVectorStoreManager
 from ddi_fw.ml.tracking_service import TrackingService
 from ddi_fw.langchain.embeddings import PoolingStrategy
 from ddi_fw.datasets import BaseDataset
@@ -26,6 +27,7 @@ class Pipeline(BaseModel):
     vector_db_persist_directory: Optional[str] = None
     vector_db_collection_name: Optional[str] = None
     embedding_pooling_strategy_type: Type[PoolingStrategy] | None = None
+    vector_store_manager_type: Type[BaseVectorStoreManager] | None = None
     combinations: Optional[List[tuple]] = None
     model: Optional[Any] = None
     default_model: Optional[Any] = None
@@ -85,8 +87,22 @@ class Pipeline(BaseModel):
         dataset_splitter = self.dataset_splitter_type()
         pooling_strategy = self.embedding_pooling_strategy_type(
         ) if self.embedding_pooling_strategy_type else None
+
+        params = {}
+
+        if self.embedding_dict is not None:
+            params["embedding_dict"] = self.embedding_dict
+        if self.vector_db_persist_directory is not None:
+            params["persist_directory"] = self.vector_db_persist_directory
+        if self.vector_db_collection_name is not None:
+            params["collection_name"] = self.vector_db_collection_name
+
+
+        vector_store_manager = self.vector_store_manager_type(**params) if self.vector_store_manager_type else None
         if issubclass(self.dataset_type, TextDatasetMixin):
+
             dataset = self.dataset_type(
+                vector_store_manager = vector_store_manager,
                 embedding_dict=self.embedding_dict,
                 pooling_strategy=pooling_strategy,
                 column_embedding_configs=self.column_embedding_configs,
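On the Pipeline side, build() assembles the manager from whichever optional fields are set (embedding_dict, plus vector_db_persist_directory and vector_db_collection_name, renamed to persist_directory and collection_name) and injects it into the dataset. A wiring sketch; the dataset subclass is hypothetical:

    from ddi_fw.datasets.core import TextDatasetMixin
    from ddi_fw.langchain.chroma_storage import ChromaVectorStoreManager
    from ddi_fw.pipeline.pipeline import Pipeline

    class DemoTextDataset(TextDatasetMixin):  # hypothetical dataset type
        pass

    pipeline = Pipeline(
        dataset_type=DemoTextDataset,
        vector_store_manager_type=ChromaVectorStoreManager,
        vector_db_persist_directory="./chroma_db",  # forwarded as persist_directory
        vector_db_collection_name="ddi_texts",      # forwarded as collection_name
    )
    # Per the hunk above, build() then runs, in effect:
    #     vector_store_manager = self.vector_store_manager_type(**params)
    #     dataset = self.dataset_type(vector_store_manager=vector_store_manager, ...)

One caveat visible in the diff: BaseVectorStoreManager declares embeddings as a required pydantic field, and build() never puts it into params, so a manager type constructed this way presumably needs embeddings to be optional or defaulted in practice.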
{ddi_fw-0.0.241.dist-info → ddi_fw-0.0.242.dist-info}/RECORD
CHANGED
@@ -1,12 +1,12 @@
 ddi_fw/datasets/__init__.py,sha256=NozQvXPYIS01U0srZmcKhiqJgRDkD-C-VXHL6sKrFSw,166
-ddi_fw/datasets/core.py,sha256=
+ddi_fw/datasets/core.py,sha256=FGa_OfM6oHGPYt5TmZczepkZ9F6sNxJPpVoMYa1FiB8,17421
 ddi_fw/datasets/dataset_splitter.py,sha256=8H8uZTAf8N9LUZeSeHOMawtJFJhnDgUUqFcnl7dquBQ,1672
 ddi_fw/datasets/db_utils.py,sha256=xRj28U_uXTRPHcz3yIICczFUHXUPiAOZtAj5BM6kH44,6465
 ddi_fw/datasets/setup_._py,sha256=khYVJuW5PlOY_i_A16F3UbSZ6s6o_ljw33Byw3C-A8E,1047
-ddi_fw/langchain/__init__.py,sha256=
-ddi_fw/langchain/chroma_storage.py,sha256=
+ddi_fw/langchain/__init__.py,sha256=xGNaTEZCUxyc_aT1zvzVWGRfsj-9VXqMvPKtV_G7ChA,399
+ddi_fw/langchain/chroma_storage.py,sha256=I8xoqlc2K4gJdOUn5b33mGGMPFKYG3UiptY2HeM34_c,15483
 ddi_fw/langchain/embeddings.py,sha256=eEWy4okcjdhUJHi4N48Wd8XauPXyeaQVLUdNWEvtEcY,6754
-ddi_fw/langchain/faiss_storage.py,sha256=
+ddi_fw/langchain/faiss_storage.py,sha256=H--yYOmHX7nr34THNojqP_qhGXd-kMkhzzWDbMMeoqo,8923
 ddi_fw/langchain/sentence_splitter.py,sha256=h_bYElx4Ud1mwDNJfL7mUwvgadwKX3GKlSzu5L2PXzg,280
 ddi_fw/langchain/storage.py,sha256=OizKyWm74Js7T6Q9kez-ulUoBGzIMFo4R46h4kjUyIM,11200
 ddi_fw/ml/__init__.py,sha256=FteYEawCkVQOaK-cTv2VrHZ2ZnfeFr31BD6VucO7_DQ,268
@@ -21,10 +21,10 @@ ddi_fw/ner/mmlrestclient.py,sha256=NZta7m2Qm6I_qtVguMZhqtAUjVBmmXn0-TMnsNp0jpg,6
 ddi_fw/ner/ner.py,sha256=FHyyX53Xwpdw8Hec261dyN88yD7Z9LmJua2mIrQLguI,17967
 ddi_fw/pipeline/__init__.py,sha256=tKDM_rW4vPjlYTeOkNgi9PujDzb4e9O3LK1w5wqnebw,212
 ddi_fw/pipeline/multi_modal_combination_strategy.py,sha256=JSyuP71b1I1yuk0s2ecCJZTtCED85jBtkpwTUxibJvI,1706
-ddi_fw/pipeline/multi_pipeline.py,sha256=
+ddi_fw/pipeline/multi_pipeline.py,sha256=jHjSfQmRQ-zEwh_5ZPdG4MBVYMrRRzlqYgFAMbDZN0g,10206
 ddi_fw/pipeline/multi_pipeline_org.py,sha256=AbErwu05-3YIPnCcXRsj-jxPJG8HG2H7cMZlGjzaYa8,9037
 ddi_fw/pipeline/ner_pipeline.py,sha256=1gBk81LeZlU1rhjJ1qBgHbFt_HqOeJ5WLnJ4AkYku4s,8188
-ddi_fw/pipeline/pipeline.py,sha256=
+ddi_fw/pipeline/pipeline.py,sha256=m6pZrhoBK2lUr7PwpmJl6-WEpYcPGGc9N9C1LNJ78NQ,6974
 ddi_fw/utils/__init__.py,sha256=WNxkQXk-694roG50D355TGLXstfdWVb_tUyr-PM-8rg,537
 ddi_fw/utils/categorical_data_encoding_checker.py,sha256=T1X70Rh4atucAuqyUZmz-iFULllY9dY0NRyV9-jTjJ0,3438
 ddi_fw/utils/enums.py,sha256=19eJ3fX5eRK_xPvkYcukmug144jXPH4X9zQqtsFBj5A,671
@@ -38,7 +38,7 @@ ddi_fw/utils/zip_helper.py,sha256=YRZA4tKZVBJwGQM0_WK6L-y5MoqkKoC-nXuuHK6CU9I,55
 ddi_fw/vectorization/__init__.py,sha256=LcJOpLVoLvHPDw9phGFlUQGeNcST_zKV-Oi1Pm5h_nE,110
 ddi_fw/vectorization/feature_vector_generation.py,sha256=QQQGhCti653BdU343Ag1bH_g1fzi2hlic7dgNy7otjE,7694
 ddi_fw/vectorization/idf_helper.py,sha256=_Gd1dtDSLaw8o-o0JugzSKMt9FpeXewTh4wGEaUd4VQ,2571
-ddi_fw-0.0.
-ddi_fw-0.0.
-ddi_fw-0.0.
-ddi_fw-0.0.
+ddi_fw-0.0.242.dist-info/METADATA,sha256=jq8Op7HG_u5PE0DjELixnPMKwEl6mUkNtPTyQ5uBWU8,2632
+ddi_fw-0.0.242.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ddi_fw-0.0.242.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
+ddi_fw-0.0.242.dist-info/RECORD,,
{ddi_fw-0.0.241.dist-info → ddi_fw-0.0.242.dist-info}/WHEEL
File without changes

{ddi_fw-0.0.241.dist-info → ddi_fw-0.0.242.dist-info}/top_level.txt
File without changes