hjxdl 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hdl/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.2.14'
- __version_tuple__ = version_tuple = (0, 2, 14)
+ __version__ = version = '0.2.16'
+ __version_tuple__ = version_tuple = (0, 2, 16)
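
For a quick sanity check that the bump landed, the version can be read from the module or from the wheel's metadata; a minimal sketch, assuming the 0.2.16 wheel is installed:

    from importlib.metadata import version
    from hdl._version import __version__, version_tuple

    print(__version__)        # '0.2.16', from the module above
    print(version_tuple)      # (0, 2, 16)
    print(version("hjxdl"))   # same string, read from the installed metadata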
hdl/utils/llm/ocrrag.py ADDED
@@ -0,0 +1,224 @@
+ import argparse
+ from PIL import Image
+ import hashlib
+ import torch
+ import fitz
+ import gradio as gr
+ import os
+ import numpy as np
+ import json
+ from transformers import AutoModel, AutoTokenizer
+
+ from .chat import OpenAI_M
+ from .vis import pilimg_to_base64
+
+ def get_image_md5(img: Image.Image):
+     img_byte_array = img.tobytes()
+     hash_md5 = hashlib.md5()
+     hash_md5.update(img_byte_array)
+     hex_digest = hash_md5.hexdigest()
+     return hex_digest
+
+ def calculate_md5_from_binary(binary_data):
+     hash_md5 = hashlib.md5()
+     hash_md5.update(binary_data)
+     return hash_md5.hexdigest()
+
+ def add_pdf_gradio(pdf_file_binary, progress=gr.Progress(), cache_dir=None, model=None, tokenizer=None):
+     model.eval()
+
+     knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
+
+     this_cache_dir = os.path.join(cache_dir, knowledge_base_name)
+     os.makedirs(this_cache_dir, exist_ok=True)
+
+     with open(os.path.join(this_cache_dir, "src.pdf"), 'wb') as file:
+         file.write(pdf_file_binary)
+
+     dpi = 200
+     doc = fitz.open("pdf", pdf_file_binary)
+
+     reps_list = []
+     images = []
+     image_md5s = []
+
+     for page in progress.tqdm(doc):
+         pix = page.get_pixmap(dpi=dpi)
+         image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+         image_md5 = get_image_md5(image)
+         image_md5s.append(image_md5)
+         with torch.no_grad():
+             reps = model(text=[''], image=[image], tokenizer=tokenizer).reps
+         reps_list.append(reps.squeeze(0).cpu().numpy())
+         images.append(image)
+
+     for idx in range(len(images)):
+         image = images[idx]
+         image_md5 = image_md5s[idx]
+         cache_image_path = os.path.join(this_cache_dir, f"{image_md5}.png")
+         image.save(cache_image_path)
+
+     np.save(os.path.join(this_cache_dir, "reps.npy"), reps_list)
+
+     with open(os.path.join(this_cache_dir, "md5s.txt"), 'w') as f:
+         for item in image_md5s:
+             f.write(item + '\n')
+
+     return knowledge_base_name
+
+ def retrieve_gradio(knowledge_base, query, topk, cache_dir=None, model=None, tokenizer=None):
+     model.eval()
+
+     target_cache_dir = os.path.join(cache_dir, knowledge_base)
+
+     if not os.path.exists(target_cache_dir):
+         return None
+
+     md5s = []
+     with open(os.path.join(target_cache_dir, "md5s.txt"), 'r') as f:
+         for line in f:
+             md5s.append(line.rstrip('\n'))
+
+     doc_reps = np.load(os.path.join(target_cache_dir, "reps.npy"))
+
+     query_with_instruction = "Represent this query for retrieving relevant document: " + query
+     with torch.no_grad():
+         query_rep = model(text=[query_with_instruction], image=[None], tokenizer=tokenizer).reps.squeeze(0).cpu()
+
+     query_md5 = hashlib.md5(query.encode()).hexdigest()
+
+     doc_reps_cat = torch.stack([torch.Tensor(i) for i in doc_reps], dim=0)
+
+     similarities = torch.matmul(query_rep, doc_reps_cat.T)
+
+     # gr.Number delivers a float; torch.topk requires an integer k
+     topk_values, topk_doc_ids = torch.topk(similarities, k=int(topk))
+
+     images_topk = [Image.open(os.path.join(target_cache_dir, f"{md5s[idx]}.png")) for idx in topk_doc_ids.cpu().numpy()]
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'w') as f:
+         f.write(json.dumps(
+             {
+                 "knowledge_base": knowledge_base,
+                 "query": query,
+                 "retrieved_docs": [os.path.join(target_cache_dir, f"{md5s[idx]}.png") for idx in topk_doc_ids.cpu().numpy()]
+             }, indent=4, ensure_ascii=False
+         ))
+
+     return images_topk
+
+ def answer_question_stream(images, question, gen_model):
+     # Load images from the gallery items; each item is a (path, caption) tuple
+     pil_images = [Image.open(image[0]).convert('RGB') for image in images]
+
+     # Calculate the total size of the new image (for vertical concatenation)
+     widths, heights = zip(*(img.size for img in pil_images))
+
+     # Vertical concatenation: width is the max width, height is the sum of heights
+     total_width = max(widths)
+     total_height = sum(heights)
+
+     # Create a new blank image with the total width and height
+     new_image = Image.new('RGB', (total_width, total_height))
+
+     # Paste each image into the new image
+     y_offset = 0
+     for img in pil_images:
+         new_image.paste(img, (0, y_offset))
+         y_offset += img.height  # Move the offset down by the height of the image
+
+     # Convert the concatenated image to base64
+     new_image_base64 = pilimg_to_base64(new_image)
+
+     # Call the model with the base64-encoded concatenated image and stream=True
+     for partial_answer in gen_model.chat(
+         prompt=question,
+         images=[new_image_base64],  # Use the concatenated image
+         stream=True  # Enable streaming
+     ):
+         # Yield the partial answer as it comes in
+         yield partial_answer  # Stream the output to Gradio
+
+ def upvote(knowledge_base, query, cache_dir):
+     target_cache_dir = os.path.join(cache_dir, knowledge_base)
+     query_md5 = hashlib.md5(query.encode()).hexdigest()
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'r') as f:
+         data = json.loads(f.read())
+
+     data["user_preference"] = "upvote"
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json"), 'w') as f:
+         f.write(json.dumps(data, indent=4, ensure_ascii=False))
+
+ def downvote(knowledge_base, query, cache_dir):
+     target_cache_dir = os.path.join(cache_dir, knowledge_base)
+     query_md5 = hashlib.md5(query.encode()).hexdigest()
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'r') as f:
+         data = json.loads(f.read())
+
+     data["user_preference"] = "downvote"
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json"), 'w') as f:
+         f.write(json.dumps(data, indent=4, ensure_ascii=False))
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description="MiniCPMV-RAG-PDFQA Script")
+     parser.add_argument('--cache-dir', dest='cache_dir', type=str, required=True, help='Cache directory path')
+     parser.add_argument('--device', dest='device', type=str, default='cuda:0', help='Device for model inference')
+     parser.add_argument('--model-path', dest='model_path', type=str, required=True, help='Path to the embedding model')
+     parser.add_argument('--llm-host', dest='llm_host', type=str, default='127.0.0.1', help='LLM server IP address')
+     parser.add_argument('--llm-port', dest='llm_port', type=int, default=22299, help='LLM server port')
+     parser.add_argument('--server-name', dest='server_name', type=str, default='0.0.0.0', help='Gradio server name')
+     parser.add_argument('--server-port', dest='server_port', type=int, default=10077, help='Gradio server port')
+
+     args = parser.parse_args()
+
+     print("Loading embedding model...")
+     tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
+     model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
+     model.to(args.device)
+     model.eval()
+     print("Embedding model loaded!")
+
+     gen_model = OpenAI_M(
+         server_ip=args.llm_host,
+         server_port=args.llm_port
+     )
+
+     with gr.Blocks() as app:
+         gr.Markdown("# MiniCPMV-RAG-PDFQA: Two Vision Language Models Enable End-to-End RAG")
+
+         file_input = gr.File(type="binary", label="Step 1: Upload PDF")
+         file_result = gr.Text(label="Knowledge Base ID")
+         process_button = gr.Button("Process PDF")
+
+         process_button.click(lambda pdf: add_pdf_gradio(pdf, cache_dir=args.cache_dir, model=model, tokenizer=tokenizer),
+                              inputs=file_input, outputs=file_result)
+
+         kb_id_input = gr.Text(label="Knowledge Base ID")
+         query_input = gr.Text(label="Your Question")
+         topk_input = gr.Number(value=5, minimum=1, maximum=10, step=1, label="Number of pages to retrieve")
+         retrieve_button = gr.Button("Retrieve Pages")
+         images_output = gr.Gallery(label="Retrieved Pages")
+
+         retrieve_button.click(lambda kb, query, topk: retrieve_gradio(kb, query, topk, cache_dir=args.cache_dir, model=model, tokenizer=tokenizer),
+                               inputs=[kb_id_input, query_input, topk_input], outputs=images_output)
+
+         button = gr.Button("Answer Question")
+         gen_model_response = gr.Textbox(label="MiniCPM-V-2.6's Answer", lines=10)
+
+         # Bind gen_model in a local generator function so Gradio still sees a
+         # generator and streams the output; a plain lambda would hide it
+         def stream_answer(images, question):
+             yield from answer_question_stream(images, question, gen_model)
+
+         button.click(stream_answer,
+                      inputs=[images_output, query_input],
+                      outputs=gen_model_response)
+
+         upvote_button = gr.Button("🤗 Upvote")
+         downvote_button = gr.Button("🤣 Downvote")
+
+         upvote_button.click(lambda kb, query: upvote(kb, query, cache_dir=args.cache_dir),
+                             inputs=[kb_id_input, query_input], outputs=None)
+         downvote_button.click(lambda kb, query: downvote(kb, query, cache_dir=args.cache_dir),
+                               inputs=[kb_id_input, query_input], outputs=None)
+
+     app.launch(server_name=args.server_name, server_port=args.server_port)
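
Two details of the new ocrrag.py are worth calling out: the knowledge-base ID is nothing more than the MD5 of the uploaded PDF bytes, and retrieval is a single dot product between the query embedding and the stack of cached page embeddings, followed by top-k. The sketch below replays that math on toy data (the byte string and the 2304-dim embedding width are made up for illustration, not the model's actual size):

    import hashlib
    import torch

    pdf_bytes = b"%PDF-1.4 ..."                    # stand-in for the uploaded binary
    kb_id = hashlib.md5(pdf_bytes).hexdigest()     # same bytes -> same cache dir

    query_rep = torch.randn(2304)                  # one query embedding
    doc_reps_cat = torch.randn(12, 2304)           # 12 cached page embeddings
    similarities = torch.matmul(query_rep, doc_reps_cat.T)   # shape (12,)
    values, doc_ids = torch.topk(similarities, k=3)          # indices into md5s / page PNGs

Because the file uses relative imports, it must be launched as a module, e.g. python -m hdl.utils.llm.ocrrag --cache-dir ./cache --model-path <embedding-model> (both paths here are placeholders).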
hdl/utils/llm/vis.py CHANGED
@@ -15,6 +15,7 @@ from redis.commands.search.indexDefinition import IndexDefinition, IndexType
  from hdl.jupyfuncs.show.pbar import tqdm
  from redis.commands.search.query import Query

+
  from ..database_tools.connect import conn_redis


@@ -89,6 +90,7 @@ def imgfile_to_base64(img_dir: str):

      return img_base64

+
  def imgbase64_to_pilimg(img_base64: str):
      """Converts a base64 encoded image to a PIL image.

@@ -107,6 +109,24 @@ def imgbase64_to_pilimg(img_base64: str):
      return img_pil


+ def pilimg_to_base64(pilimg):
+     """Converts a PIL image to base64 format.
+
+     Args:
+         pilimg (PIL.Image): The PIL image to be converted.
+
+     Returns:
+         str: Base64 encoded image string.
+     """
+     buffered = BytesIO()
+     pilimg.save(buffered, format="PNG")
+     image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+     img_format = 'png'
+     mime_type = f"image/{img_format}"
+     img_base64 = f"data:{mime_type};base64,{image_base64}"
+     return img_base64
+
+
  class ImgHandler:
      def __init__(
          self,
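
The new pilimg_to_base64 starts from an in-memory PIL image and returns a full data URI, which is the form the chat code above passes in its images list. A quick sanity check of the output shape (a sketch, assuming the helper is importable from hdl.utils.llm.vis):

    from PIL import Image
    from hdl.utils.llm.vis import pilimg_to_base64

    img = Image.new('RGB', (64, 64), 'red')   # toy image
    b64 = pilimg_to_base64(img)
    assert b64.startswith("data:image/png;base64,")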
hdl/utils/llm/visrag.py CHANGED
@@ -12,7 +12,7 @@ import json
  from transformers import AutoModel, AutoTokenizer

  from .chat import OpenAI_M
- from .vis import imgfile_to_base64
+ from .vis import pilimg_to_base64

  def get_image_md5(img: Image.Image):
      img_byte_array = img.tobytes()
@@ -116,14 +116,35 @@ def retrieve_gradio(knowledge_base, query, topk, cache_dir=None, model=None, tok
      # return image_base64

  def answer_question(images, question, gen_model):
-     # Convert images to base64
-     # images_base64 = [convert_image_to_base64(Image.open(image[0]).convert('RGB')) for image in images]
-     images_base64 = [imgfile_to_base64(image[0]) for image in images]
+     # Load images from the gallery items; each item is a (path, caption) tuple
+     pil_images = [Image.open(image[0]).convert('RGB') for image in images]

-     # Pass base64-encoded images to gen_model.chat
+     # Calculate the total size of the new image (for vertical concatenation)
+     widths, heights = zip(*(img.size for img in pil_images))
+
+     # Vertical concatenation: width is the max width, height is the sum of heights
+     total_width = max(widths)
+     total_height = sum(heights)
+
+     # Create a new blank image with the total width and height
+     new_image = Image.new('RGB', (total_width, total_height))
+
+     # Paste each image into the new image
+     y_offset = 0
+     for img in pil_images:
+         new_image.paste(img, (0, y_offset))
+         y_offset += img.height  # Move the offset down by the height of the image
+
+     # Optionally save the final concatenated image (for debugging)
+     # new_image.save('concatenated_image.png')
+
+     # Convert the concatenated image to base64
+     new_image_base64 = pilimg_to_base64(new_image)
+
+     # Call the model with the base64-encoded concatenated image
      answer = gen_model.chat(
          prompt=question,
-         images=images_base64,  # Use the base64 images
+         images=[new_image_base64],  # Use the concatenated image
          stream=False
      )
      return answer
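
This concatenation block in answer_question is now duplicated verbatim in the new ocrrag.py; if a third copy ever appears, a shared helper next to pilimg_to_base64 would be the natural home. A possible extraction (hypothetical, not part of this release):

    from PIL import Image

    def vconcat_images(pil_images):
        """Stack PIL images top-to-bottom on one canvas, left-aligned."""
        canvas = Image.new('RGB', (max(i.width for i in pil_images),
                                   sum(i.height for i in pil_images)))
        y = 0
        for img in pil_images:
            canvas.paste(img, (0, y))
            y += img.height
        return canvas

Note that Image.new defaults to a black fill, so narrower pages leave a black right margin; passing 'white' as the fill may read better for document pages.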
{hjxdl-0.2.14.dist-info → hjxdl-0.2.16.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: hjxdl
- Version: 0.2.14
+ Version: 0.2.16
  Summary: A collection of functions for Jupyter notebooks
  Home-page: https://github.com/huluxiaohuowa/hdl
  Author: Jianxing Hu
{hjxdl-0.2.14.dist-info → hjxdl-0.2.16.dist-info}/RECORD CHANGED
@@ -1,5 +1,5 @@
  hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
- hdl/_version.py,sha256=KYrSahOPivF0LOfn4qq6iTibWNx1Db_9urh-NXAGe9E,413
+ hdl/_version.py,sha256=24Q7k0pOfSN3Vkvs8-MWQxeJqcvAZ3JvN_YWtflaEyU,413
  hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
  hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -132,13 +132,14 @@ hdl/utils/llm/chatgr.py,sha256=GO2G7g6YybduA5VCUuGjvEsJfC_6L7rycSnPeHMcxyM,2820
  hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
  hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
  hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
- hdl/utils/llm/vis.py,sha256=2pI0439GWi_BEVfQJtY29Y72FkUa8jEvBeqMlwy7xkc,15716
- hdl/utils/llm/visrag.py,sha256=8IsY4e3AlzmyfR1bTQhHQq-Z5uxLHiN9kPu-b_byTKw,8411
+ hdl/utils/llm/ocrrag.py,sha256=AxzoSZ9AHBJihTwxllprlukVYb0JI83GgvQDKHcJl-4,8982
+ hdl/utils/llm/vis.py,sha256=-6QvxSVzKqxLh_l0aYg2wN2G5HOiQvCpfp-jn9twXw0,16210
+ hdl/utils/llm/visrag.py,sha256=vNj4cHsvfC_Vc0eDPKZc-yflLUMGApZGpggjAqAlwS8,9215
  hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
  hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
- hjxdl-0.2.14.dist-info/METADATA,sha256=_IIRb9CIkJLiuQ42cVgk9pJOevHNnqOoWEetESOYX2I,836
- hjxdl-0.2.14.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
- hjxdl-0.2.14.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
- hjxdl-0.2.14.dist-info/RECORD,,
+ hjxdl-0.2.16.dist-info/METADATA,sha256=5jUlljbBjcD-EWi_2s4qYK3G2r7uoCSL4t7WinXNwmE,836
+ hjxdl-0.2.16.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+ hjxdl-0.2.16.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
+ hjxdl-0.2.16.dist-info/RECORD,,