hjxdl 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/_version.py +2 -2
- hdl/utils/llm/chatgr.py +1 -1
- hdl/utils/llm/visrag.py +194 -0
- {hjxdl-0.2.7.dist-info → hjxdl-0.2.8.dist-info}/METADATA +1 -1
- {hjxdl-0.2.7.dist-info → hjxdl-0.2.8.dist-info}/RECORD +7 -6
- {hjxdl-0.2.7.dist-info → hjxdl-0.2.8.dist-info}/WHEEL +0 -0
- {hjxdl-0.2.7.dist-info → hjxdl-0.2.8.dist-info}/top_level.txt +0 -0
hdl/_version.py
CHANGED
hdl/utils/llm/chatgr.py
CHANGED
hdl/utils/llm/visrag.py
ADDED
@@ -0,0 +1,194 @@
|
|
1
|
+
import argparse
|
2
|
+
from PIL import Image
|
3
|
+
import hashlib
|
4
|
+
import torch
|
5
|
+
import fitz
|
6
|
+
import gradio as gr
|
7
|
+
import os
|
8
|
+
import numpy as np
|
9
|
+
import json
|
10
|
+
from transformers import AutoModel, AutoTokenizer
|
11
|
+
from hdl.utils.llm.chat import OpenAI_M
|
12
|
+
|
13
|
+
def get_image_md5(img: Image.Image):
    """Return the hexadecimal MD5 digest of an image's raw pixel buffer.

    Only the bytes produced by ``img.tobytes()`` are hashed.

    Args:
        img: Image whose ``tobytes()`` output is fed to MD5.

    Returns:
        The digest as a hexadecimal string.
    """
    return hashlib.md5(img.tobytes()).hexdigest()
|
19
|
+
|
20
|
+
def calculate_md5_from_binary(binary_data):
    """Return the hexadecimal MD5 digest of a bytes-like object.

    Args:
        binary_data: The bytes to hash.

    Returns:
        The digest as a hexadecimal string.
    """
    return hashlib.md5(binary_data).hexdigest()
|
24
|
+
|
25
|
+
def add_pdf_gradio(pdf_file_binary, progress=gr.Progress(), cache_dir=None, model=None, tokenizer=None):
    """Index a PDF as a knowledge base of page images and page embeddings.

    Every page is rendered to an RGB image, embedded with ``model`` and
    stored under ``<cache_dir>/<md5 of the PDF bytes>/``:

    * ``src.pdf``   - the uploaded PDF, unchanged
    * ``<md5>.png`` - one rendered page per file, named by the MD5 of its pixels
    * ``reps.npy``  - the page embeddings, in page order
    * ``md5s.txt``  - the page image MD5s, one per line, in page order

    Args:
        pdf_file_binary: Raw bytes of the PDF document.
        progress: Gradio progress tracker; ``progress.tqdm`` wraps the page loop.
        cache_dir: Root directory that holds all knowledge bases.
        model: Embedding model, called as
            ``model(text=[...], image=[...], tokenizer=tokenizer)``; its
            ``.reps`` attribute is used as the page embedding.
        tokenizer: Tokenizer forwarded to ``model``.

    Returns:
        The knowledge base name (the MD5 hex digest of ``pdf_file_binary``).

    Raises:
        TypeError: If ``cache_dir`` is None.
    """
    model.eval()

    if cache_dir is None:
        # os.path.join(None, ...) would raise the same exception type with a
        # message that does not mention which argument is missing.
        raise TypeError("cache_dir must be a directory path, not None")

    knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)

    this_cache_dir = os.path.join(cache_dir, knowledge_base_name)
    os.makedirs(this_cache_dir, exist_ok=True)

    with open(os.path.join(this_cache_dir, "src.pdf"), 'wb') as file:
        file.write(pdf_file_binary)

    dpi = 200
    reps_list = []
    image_md5s = []

    # The context manager closes the document even if rendering or embedding fails.
    with fitz.open("pdf", pdf_file_binary) as doc:
        for page in progress.tqdm(doc):
            pix = page.get_pixmap(dpi=dpi)
            image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            image_md5 = get_image_md5(image)

            with torch.no_grad():
                reps = model(text=[''], image=[image], tokenizer=tokenizer).reps
            reps_list.append(reps.squeeze(0).cpu().numpy())
            image_md5s.append(image_md5)

            # Save each page right away so that only one rendered page is
            # held in memory at a time.
            image.save(os.path.join(this_cache_dir, f"{image_md5}.png"))

    np.save(os.path.join(this_cache_dir, "reps.npy"), reps_list)

    with open(os.path.join(this_cache_dir, "md5s.txt"), 'w') as f:
        f.writelines(f"{item}\n" for item in image_md5s)

    return knowledge_base_name
|
66
|
+
|
67
|
+
def retrieve_gradio(knowledge_base, query, topk, cache_dir=None, model=None, tokenizer=None):
    """Retrieve the page images of a knowledge base that best match a query.

    The query is embedded with ``model`` and scored against the stored page
    embeddings (``reps.npy``) by dot product. The retrieval is also logged to
    ``q-<md5 of query>.json`` inside the knowledge base directory.

    Args:
        knowledge_base: Knowledge base name (sub-directory of ``cache_dir``).
        query: The user's question.
        topk: Number of pages to return. It is converted to ``int`` and
            capped at the number of indexed pages.
        cache_dir: Root directory that holds all knowledge bases.
        model: Embedding model, called as
            ``model(text=[...], image=[...], tokenizer=tokenizer)``.
        tokenizer: Tokenizer forwarded to ``model``.

    Returns:
        A list of opened page images, best match first, or ``None`` when the
        knowledge base directory does not exist.
    """
    model.eval()

    target_cache_dir = os.path.join(cache_dir, knowledge_base)
    if not os.path.exists(target_cache_dir):
        return None

    with open(os.path.join(target_cache_dir, "md5s.txt"), 'r') as f:
        md5s = [line.rstrip('\n') for line in f]

    doc_reps = np.load(os.path.join(target_cache_dir, "reps.npy"))

    query_with_instruction = "Represent this query for retrieving relevant document: " + query
    with torch.no_grad():
        query_rep = model(text=[query_with_instruction], image=[None], tokenizer=tokenizer).reps.squeeze(0).cpu()

    query_md5 = hashlib.md5(query.encode()).hexdigest()

    # A Gradio number input may deliver a float, and torch.topk rejects both
    # a non-integer k and a k larger than the number of candidates.
    k = max(0, min(int(topk), len(md5s), len(doc_reps)))

    if k:
        # float32 on both sides keeps the matmul operands' dtypes aligned.
        doc_reps_cat = torch.as_tensor(doc_reps, dtype=torch.float32)
        similarities = torch.matmul(query_rep.float(), doc_reps_cat.T)
        _, topk_doc_ids = torch.topk(similarities, k=k)
        top_ids = topk_doc_ids.tolist()
    else:
        # Nothing indexed (or nothing requested): skip scoring entirely.
        top_ids = []

    retrieved_paths = [os.path.join(target_cache_dir, f"{md5s[idx]}.png") for idx in top_ids]
    images_topk = [Image.open(path) for path in retrieved_paths]

    with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'w') as f:
        f.write(json.dumps(
            {
                "knowledge_base": knowledge_base,
                "query": query,
                "retrieved_docs": retrieved_paths
            }, indent=4, ensure_ascii=False
        ))

    return images_topk
|
106
|
+
|
107
|
+
def answer_question(images, question, gen_model):
    """Ask the generation model a question about the retrieved page images.

    Args:
        images: Gallery items; the first element of each item is the path of
            an image file. ``None`` is treated as "no images".
        question: The prompt passed to the model.
        gen_model: Object exposing ``chat(prompt=..., images=..., stream=...)``.

    Returns:
        Whatever ``gen_model.chat`` returns.
    """
    pages = []
    for item in images or []:
        # convert() returns an independent in-memory copy, so the file handle
        # can be released as soon as the conversion is done.
        with Image.open(item[0]) as page_file:
            pages.append(page_file.convert('RGB'))

    return gen_model.chat(
        prompt=question,
        images=pages,
        stream=False
    )
|
115
|
+
|
116
|
+
def _save_preference(knowledge_base, query, cache_dir, preference):
    """Copy a stored query record to ``-withpref.json`` with the vote added.

    Reads ``q-<md5 of query>.json`` from the knowledge base directory (the
    file ``retrieve_gradio`` writes), sets ``"user_preference"`` to
    ``preference`` and writes the result to ``q-<md5 of query>-withpref.json``.
    The source file is left untouched.

    Raises:
        FileNotFoundError: If no record exists for this query.
    """
    target_cache_dir = os.path.join(cache_dir, knowledge_base)
    query_md5 = hashlib.md5(query.encode()).hexdigest()

    with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'r') as f:
        data = json.load(f)

    data["user_preference"] = preference

    with open(os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json"), 'w') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)


def upvote(knowledge_base, query, cache_dir):
    """Record an "upvote" for the stored retrieval result of ``query``.

    Args:
        knowledge_base: Knowledge base name (sub-directory of ``cache_dir``).
        query: The query text the vote refers to.
        cache_dir: Root directory that holds all knowledge bases.
    """
    _save_preference(knowledge_base, query, cache_dir, "upvote")


def downvote(knowledge_base, query, cache_dir):
    """Record a "downvote" for the stored retrieval result of ``query``.

    Args:
        knowledge_base: Knowledge base name (sub-directory of ``cache_dir``).
        query: The query text the vote refers to.
        cache_dir: Root directory that holds all knowledge bases.
    """
    _save_preference(knowledge_base, query, cache_dir, "downvote")
|
139
|
+
|
140
|
+
if __name__ == '__main__':
    # Script entry point: parse the command line, load the embedding model,
    # connect to the LLM server and serve the Gradio demo.
    parser = argparse.ArgumentParser(description="RAG-PDFQA Script")
    parser.add_argument('--cache_dir', type=str, required=True, help='Cache directory path')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device for model inference')
    parser.add_argument('--model_path', type=str, required=True, help='Path to the embedding model')
    # 127.0.0.1 is the loopback host; the previous default 127.0.0.0 is a
    # network address, not a host address.
    parser.add_argument('--llm_host', type=str, default='127.0.0.1', help='LLM server IP address')
    parser.add_argument('--llm_port', type=int, default=22299, help='LLM server port')
    parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Gradio server name')
    parser.add_argument('--server_port', type=int, default=10077, help='Gradio server port')

    args = parser.parse_args()

    print("Loading embedding model...")
    # NOTE: trust_remote_code=True runs Python code shipped with the
    # checkpoint; only point --model_path at models you trust.
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
    model.to(args.device)
    model.eval()
    print("Embedding model loaded!")

    gen_model = OpenAI_M(
        server_ip=args.llm_host,
        server_port=args.llm_port
    )

    # Gradio event listeners have no `_kwargs` parameter, so the shared
    # objects (cache dir, models) are bound through these closures instead.
    def process_pdf(pdf_file_binary, progress=gr.Progress()):
        return add_pdf_gradio(
            pdf_file_binary,
            progress=progress,
            cache_dir=args.cache_dir,
            model=model,
            tokenizer=tokenizer,
        )

    def retrieve_pages(knowledge_base, query, topk):
        return retrieve_gradio(
            knowledge_base,
            query,
            topk,
            cache_dir=args.cache_dir,
            model=model,
            tokenizer=tokenizer,
        )

    def answer(images, question):
        return answer_question(images, question, gen_model=gen_model)

    def record_upvote(knowledge_base, query):
        upvote(knowledge_base, query, cache_dir=args.cache_dir)

    def record_downvote(knowledge_base, query):
        downvote(knowledge_base, query, cache_dir=args.cache_dir)

    with gr.Blocks() as app:
        gr.Markdown("# Vision Language Models Enable End-to-End RAG")

        file_input = gr.File(type="binary", label="Step 1: Upload PDF")
        file_result = gr.Text(label="Knowledge Base ID")
        process_button = gr.Button("Process PDF")

        process_button.click(process_pdf, inputs=[file_input], outputs=file_result)

        kb_id_input = gr.Text(label="Knowledge Base ID")
        query_input = gr.Text(label="Your Question")
        # precision=0 makes the component hand an int to the handler.
        topk_input = gr.Number(value=5, minimum=1, maximum=10, step=1, precision=0, label="Number of pages to retrieve")
        retrieve_button = gr.Button("Retrieve Pages")
        images_output = gr.Gallery(label="Retrieved Pages")

        retrieve_button.click(retrieve_pages, inputs=[kb_id_input, query_input, topk_input], outputs=images_output)

        button = gr.Button("Answer Question")
        gen_model_response = gr.Textbox(label="Model's Answer")

        button.click(answer, inputs=[images_output, query_input], outputs=gen_model_response)

        upvote_button = gr.Button("🤗 Upvote")
        downvote_button = gr.Button("🤣 Downvote")

        upvote_button.click(record_upvote, inputs=[kb_id_input, query_input], outputs=None)
        downvote_button.click(record_downvote, inputs=[kb_id_input, query_input], outputs=None)

    app.launch(server_name=args.server_name, server_port=args.server_port)
|
@@ -1,5 +1,5 @@
|
|
1
1
|
hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
|
2
|
-
hdl/_version.py,sha256=
|
2
|
+
hdl/_version.py,sha256=fYxHoWMqnWqlH_xgKiqQVeWMgvVIUtmXuDqSFMUT1bU,411
|
3
3
|
hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
|
5
5
|
hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -128,16 +128,17 @@ hdl/utils/general/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
|
|
128
128
|
hdl/utils/general/glob.py,sha256=8-RCnt6L297wMIfn34ZAMCsGCZUjHG3MGglGZI1cX0g,491
|
129
129
|
hdl/utils/llm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
130
130
|
hdl/utils/llm/chat.py,sha256=OzyY9xACOOocx9zZigtq9YAPvHtDUo8v2fvf1Tyjg_U,14891
|
131
|
-
hdl/utils/llm/chatgr.py,sha256
|
131
|
+
hdl/utils/llm/chatgr.py,sha256=-1-c2GUYCqSiky_QFf54RjJnmFgdgvPjeRHYMrwjFT0,2595
|
132
132
|
hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
|
133
133
|
hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
|
134
134
|
hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
|
135
135
|
hdl/utils/llm/vis.py,sha256=2pI0439GWi_BEVfQJtY29Y72FkUa8jEvBeqMlwy7xkc,15716
|
136
|
+
hdl/utils/llm/visrag.py,sha256=BBWmYI8p9e9ZclWBP8nC3kuv7LzvYlg-7gyw_JTZ7K0,7556
|
136
137
|
hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
137
138
|
hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
|
138
139
|
hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
139
140
|
hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
|
140
|
-
hjxdl-0.2.
|
141
|
-
hjxdl-0.2.
|
142
|
-
hjxdl-0.2.
|
143
|
-
hjxdl-0.2.
|
141
|
+
hjxdl-0.2.8.dist-info/METADATA,sha256=Hil-hE2MpSnoSW2KS290_JN9_3sXvpuMZY__tkvNCps,835
|
142
|
+
hjxdl-0.2.8.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
|
143
|
+
hjxdl-0.2.8.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
|
144
|
+
hjxdl-0.2.8.dist-info/RECORD,,
|
File without changes
|
File without changes
|