mb-rag 1.1.22__py3-none-any.whl → 1.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mb-rag might be problematic; see the registry's page for this release for more details.
- mb_rag/chatbot/basic.py +21 -13
- mb_rag/version.py +1 -1
- {mb_rag-1.1.22.dist-info → mb_rag-1.1.24.dist-info}/METADATA +1 -1
- {mb_rag-1.1.22.dist-info → mb_rag-1.1.24.dist-info}/RECORD +6 -6
- {mb_rag-1.1.22.dist-info → mb_rag-1.1.24.dist-info}/WHEEL +0 -0
- {mb_rag-1.1.22.dist-info → mb_rag-1.1.24.dist-info}/top_level.txt +0 -0
mb_rag/chatbot/basic.py
CHANGED
|
@@ -209,7 +209,7 @@ class ModelFactory:
|
|
|
209
209
|
Create and load hugging face model.
|
|
210
210
|
Args:
|
|
211
211
|
model_name (str): Name of the model
|
|
212
|
-
model_function (str): model function
|
|
212
|
+
model_function (str): model function. Default is image-text-to-text.
|
|
213
213
|
device (str): Device to use. Default is cpu
|
|
214
214
|
**kwargs: Additional arguments
|
|
215
215
|
Returns:
|
|
@@ -218,12 +218,12 @@ class ModelFactory:
|
|
|
218
218
|
if not check_package("transformers"):
|
|
219
219
|
raise ImportError("Transformers package not found. Please install it using: pip install transformers")
|
|
220
220
|
if not check_package("langchain_huggingface"):
|
|
221
|
-
raise ImportError("
|
|
221
|
+
raise ImportError("langchain_huggingface package not found. Please install it using: pip install langchain_huggingface")
|
|
222
222
|
if not check_package("torch"):
|
|
223
223
|
raise ImportError("Torch package not found. Please install it using: pip install torch")
|
|
224
224
|
|
|
225
225
|
from langchain_huggingface import HuggingFacePipeline
|
|
226
|
-
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
|
226
|
+
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForImageTextToText,AutoProcessor
|
|
227
227
|
import torch
|
|
228
228
|
|
|
229
229
|
device = torch.device(device) if torch.cuda.is_available() else torch.device("cpu")
|
|
@@ -231,14 +231,23 @@ class ModelFactory:
|
|
|
231
231
|
temperature = kwargs.pop("temperature", 0.7)
|
|
232
232
|
max_length = kwargs.pop("max_length", 1024)
|
|
233
233
|
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
234
|
+
if model_function == "image-text-to-text":
|
|
235
|
+
tokenizer = AutoProcessor.from_pretrained(model_name,trust_remote_code=True)
|
|
236
|
+
model = AutoModelForImageTextToText.from_pretrained(
|
|
237
|
+
model_name,
|
|
238
|
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
|
239
|
+
device_map=device,
|
|
240
|
+
trust_remote_code=True,
|
|
241
|
+
**kwargs
|
|
242
|
+
)
|
|
243
|
+
else:
|
|
244
|
+
tokenizer = AutoTokenizer.from_pretrained(model_name,trust_remote_code=True)
|
|
245
|
+
model = AutoModelForCausalLM.from_pretrained(
|
|
246
|
+
model_name,
|
|
247
|
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
|
248
|
+
device_map=device,
|
|
249
|
+
trust_remote_code=True,
|
|
250
|
+
**kwargs)
|
|
242
251
|
|
|
243
252
|
# Create pipeline
|
|
244
253
|
pipe = pipeline(
|
|
@@ -246,8 +255,7 @@ class ModelFactory:
|
|
|
246
255
|
model=model,
|
|
247
256
|
tokenizer=tokenizer,
|
|
248
257
|
max_length=max_length,
|
|
249
|
-
temperature=temperature
|
|
250
|
-
device=device
|
|
258
|
+
temperature=temperature
|
|
251
259
|
)
|
|
252
260
|
|
|
253
261
|
# Create and return LangChain HuggingFacePipeline
|
mb_rag/version.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
mb_rag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
mb_rag/version.py,sha256=
|
|
2
|
+
mb_rag/version.py,sha256=vdrfmhTVWpawOFVwGZBQMQ63LTdqBCR1S5yLSyF-vIY,207
|
|
3
3
|
mb_rag/chatbot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
|
-
mb_rag/chatbot/basic.py,sha256=
|
|
4
|
+
mb_rag/chatbot/basic.py,sha256=_D2U3QzRGs7BPT691ii3TM49Gj3d1mKXnlFdniz2Wy0,23104
|
|
5
5
|
mb_rag/chatbot/chains.py,sha256=vDbLX5R29sWN1pcFqJ5fyxJEgMCM81JAikunAEvMC9A,7223
|
|
6
6
|
mb_rag/chatbot/prompts.py,sha256=n1PyiLbU-5fkslRv6aVOzt0dDlwya_cEdQ7kRnRhMuY,1749
|
|
7
7
|
mb_rag/rag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -9,7 +9,7 @@ mb_rag/rag/embeddings.py,sha256=KjBdekFDb5M3dRMco4r3dDMXMsG5dxdzKImuVIipsd0,2709
|
|
|
9
9
|
mb_rag/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
10
|
mb_rag/utils/bounding_box.py,sha256=G0hdDam8QmYtD9lfwMeDHGm-TTo6KZg-yK5ESFL9zaM,8366
|
|
11
11
|
mb_rag/utils/extra.py,sha256=spbFrGgdruNyYQ5PzgvpSIa6Nm0rn9bb4qc8W9g582o,2492
|
|
12
|
-
mb_rag-1.1.
|
|
13
|
-
mb_rag-1.1.
|
|
14
|
-
mb_rag-1.1.
|
|
15
|
-
mb_rag-1.1.
|
|
12
|
+
mb_rag-1.1.24.dist-info/METADATA,sha256=AbJF5--ihwpl7dGVLRqVwiDHrqbnUL5_6-QXBPUKi3k,234
|
|
13
|
+
mb_rag-1.1.24.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
|
|
14
|
+
mb_rag-1.1.24.dist-info/top_level.txt,sha256=FIK1eAa5uYnurgXZquBG-s3PIy-HDTC5yJBW4lTH_pM,7
|
|
15
|
+
mb_rag-1.1.24.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|