hjxdl 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their public registries.
- hdl/_version.py +2 -2
- hdl/utils/llm/visrag.py +25 -10
- {hjxdl-0.2.10.dist-info → hjxdl-0.2.12.dist-info}/METADATA +1 -1
- {hjxdl-0.2.10.dist-info → hjxdl-0.2.12.dist-info}/RECORD +6 -6
- {hjxdl-0.2.10.dist-info → hjxdl-0.2.12.dist-info}/WHEEL +0 -0
- {hjxdl-0.2.10.dist-info → hjxdl-0.2.12.dist-info}/top_level.txt +0 -0
hdl/_version.py
CHANGED
hdl/utils/llm/visrag.py
CHANGED
@@ -7,6 +7,8 @@ import gradio as gr
 import os
 import numpy as np
 import json
+import base64
+import io
 from transformers import AutoModel, AutoTokenizer
 from hdl.utils.llm.chat import OpenAI_M
 
@@ -104,11 +106,21 @@ def retrieve_gradio(knowledge_base, query, topk, cache_dir=None, model=None, tok
 
     return images_topk
 
+def convert_image_to_base64(image):
+    """Convert a PIL Image to a base64 encoded string."""
+    buffered = io.BytesIO()
+    image.save(buffered, format="PNG")
+    image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+    return image_base64
+
 def answer_question(images, question, gen_model):
-
+    # Convert images to base64
+    images_base64 = [convert_image_to_base64(Image.open(image[0]).convert('RGB')) for image in images]
+
+    # Pass base64-encoded images to gen_model.chat
     answer = gen_model.chat(
         prompt=question,
-        images=
+        images=images_base64,  # Use the base64 images
         stream=False
     )
     return answer
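Note: the new helper serializes each PIL page image to PNG bytes and base64-encodes them before they are handed to the chat client. For reference, a minimal, self-contained sketch of the same PNG-plus-base64 round trip; the function names `to_base64_png` / `from_base64_png` and the dummy image are illustrative and not part of the package:

# Illustrative sketch (not from the package): PNG + base64 round trip,
# mirroring what convert_image_to_base64 does. Assumes Pillow is installed.
import base64
import io

from PIL import Image

def to_base64_png(image: Image.Image) -> str:
    """Serialize a PIL image to PNG bytes and return them base64-encoded."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

def from_base64_png(data: str) -> Image.Image:
    """Inverse operation: decode the base64 string back into a PIL image."""
    return Image.open(io.BytesIO(base64.b64decode(data)))

if __name__ == "__main__":
    img = Image.new("RGB", (4, 4), color=(255, 0, 0))  # tiny dummy image
    encoded = to_base64_png(img)
    assert from_base64_png(encoded).size == (4, 4)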
@@ -142,7 +154,7 @@ if __name__ == '__main__':
     parser.add_argument('--cache-dir', dest='cache_dir', type=str, required=True, help='Cache directory path')
     parser.add_argument('--device', dest='device', type=str, default='cuda:0', help='Device for model inference')
     parser.add_argument('--model-path', dest='model_path', type=str, required=True, help='Path to the embedding model')
-    parser.add_argument('--llm-host', dest='llm_host', type=str, default='127.0.0.
+    parser.add_argument('--llm-host', dest='llm_host', type=str, default='127.0.0.1', help='LLM server IP address')
     parser.add_argument('--llm-port', dest='llm_port', type=int, default=22299, help='LLM server port')
     parser.add_argument('--server-name', dest='server_name', type=str, default='0.0.0.0', help='Gradio server name')
     parser.add_argument('--server-port', dest='server_port', type=int, default=10077, help='Gradio server port')
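The only change in this hunk is the completed `--llm-host` option. A small hypothetical sketch of how these defaults resolve when the script is started without the flags; the argument definitions are copied from the diff, while the empty argv list is only for illustration:

# Hypothetical usage sketch: --llm-host / --llm-port defaults with no flags given.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--llm-host', dest='llm_host', type=str, default='127.0.0.1', help='LLM server IP address')
parser.add_argument('--llm-port', dest='llm_port', type=int, default=22299, help='LLM server port')

args = parser.parse_args([])          # empty argv: both options fall back to their defaults
print(args.llm_host, args.llm_port)   # -> 127.0.0.1 22299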
@@ -168,8 +180,8 @@ if __name__ == '__main__':
     file_result = gr.Text(label="Knowledge Base ID")
     process_button = gr.Button("Process PDF")
 
-    process_button.click(add_pdf_gradio,
-
+    process_button.click(lambda pdf: add_pdf_gradio(pdf, cache_dir=args.cache_dir, model=model, tokenizer=tokenizer),
+                         inputs=file_input, outputs=file_result)
 
     kb_id_input = gr.Text(label="Knowledge Base ID")
     query_input = gr.Text(label="Your Question")
@@ -177,18 +189,21 @@ if __name__ == '__main__':
     retrieve_button = gr.Button("Retrieve Pages")
     images_output = gr.Gallery(label="Retrieved Pages")
 
-    retrieve_button.click(retrieve_gradio,
-
+    retrieve_button.click(lambda kb, query, topk: retrieve_gradio(kb, query, topk, cache_dir=args.cache_dir, model=model, tokenizer=tokenizer),
+                          inputs=[kb_id_input, query_input, topk_input], outputs=images_output)
 
     button = gr.Button("Answer Question")
     gen_model_response = gr.Textbox(label="MiniCPM-V-2.6's Answer")
 
-    button.click(
+    button.click(lambda images, question: answer_question(images, question, gen_model),
+                 inputs=[images_output, query_input], outputs=gen_model_response)
 
     upvote_button = gr.Button("🤗 Upvote")
    downvote_button = gr.Button("🤣 Downvote")
 
-    upvote_button.click(
-
+    upvote_button.click(lambda kb, query: upvote(kb, query, cache_dir=args.cache_dir),
+                        inputs=[kb_id_input, query_input], outputs=None)
+    downvote_button.click(lambda kb, query: downvote(kb, query, cache_dir=args.cache_dir),
+                          inputs=[kb_id_input, query_input], outputs=None)
 
     app.launch(server_name=args.server_name, server_port=args.server_port)
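All of the `.click(...)` wirings now wrap the handlers in lambdas so that fixed arguments (`cache_dir`, `model`, `tokenizer`, `gen_model`) are captured from the enclosing scope while Gradio supplies only the component values. A minimal standalone sketch of the same pattern; the component and function names (`greet`, `name_box`, etc.) are invented for the example, and `functools.partial` would work equally well:

# Minimal sketch of binding extra arguments to a Gradio click handler via a
# lambda closure. Everything here is illustrative, not code from the package.
import gradio as gr

def greet(name, greeting):
    return f"{greeting}, {name}!"

fixed_greeting = "Hello"  # analogous to cache_dir / model captured from the outer scope

with gr.Blocks() as demo:
    name_box = gr.Text(label="Name")
    out_box = gr.Text(label="Result")
    btn = gr.Button("Greet")
    # Gradio passes only the component values; the lambda supplies the rest.
    btn.click(lambda name: greet(name, fixed_greeting),
              inputs=name_box, outputs=out_box)

# demo.launch()  # uncomment to serve the demo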
{hjxdl-0.2.10.dist-info → hjxdl-0.2.12.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
-hdl/_version.py,sha256=
+hdl/_version.py,sha256=svcGBcBtXuMqAz4h0v_Mst9pNApgzncWOHA1dv0p748,413
 hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
 hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -133,12 +133,12 @@ hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
 hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
 hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
 hdl/utils/llm/vis.py,sha256=2pI0439GWi_BEVfQJtY29Y72FkUa8jEvBeqMlwy7xkc,15716
-hdl/utils/llm/visrag.py,sha256=
+hdl/utils/llm/visrag.py,sha256=jZgo1awEVRq3z0IEXs-soa1scbeHSBk_IUq9c7rS5ZA,8300
 hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
 hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
-hjxdl-0.2.
-hjxdl-0.2.
-hjxdl-0.2.
-hjxdl-0.2.
+hjxdl-0.2.12.dist-info/METADATA,sha256=J68Ist70QStGrBWqWYK2bzoBFV40DI_hYOhOyCu9150,836
+hjxdl-0.2.12.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+hjxdl-0.2.12.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
+hjxdl-0.2.12.dist-info/RECORD,,
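The RECORD entries follow the standard wheel convention of `path,sha256=<URL-safe base64 digest, padding stripped>,size`, with the RECORD file itself listed with empty hash and size fields. A small sketch of computing such an entry for one file; the path passed in the example is only illustrative:

# Sketch: computing a wheel RECORD line (path,hash,size) for a single file.
import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    data = Path(path).read_bytes()
    digest = hashlib.sha256(data).digest()
    # RECORD uses URL-safe base64 with the trailing '=' padding stripped.
    b64 = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
    return f"{path},sha256={b64},{len(data)}"

if __name__ == "__main__":
    print(record_entry("hdl/_version.py"))  # example path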
{hjxdl-0.2.10.dist-info → hjxdl-0.2.12.dist-info}/WHEEL
File without changes
{hjxdl-0.2.10.dist-info → hjxdl-0.2.12.dist-info}/top_level.txt
File without changes