hjxdl 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hdl/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.2.7'
-__version_tuple__ = version_tuple = (0, 2, 7)
+__version__ = version = '0.2.9'
+__version_tuple__ = version_tuple = (0, 2, 9)
hdl/utils/llm/chatgr.py CHANGED
@@ -9,7 +9,8 @@ def chat_with_llm(user_input, chat_history=[]):
 
     bot_message = ""  # Initialize the bot message as an empty string
     resp = llm.stream(
-        "你的身份是芯途异构(ICTrek)的人工智能小助手,由芯途异构公司研发,请回答如下问题:\n"
+        "你的身份是VIVIBIT人工智能小助手,由芯途异构公司(ICTrek)研发,请回答如下问题,并保证回答所采用的语言与用户问题的语言保持一致。\n"
+        "Your identity is VIVIBIT AI Assistant, developed by ICTrek. Please answer the following question and ensure that the language used in the response matches the language of the user’s question.\n Question: "
         + user_input
     )  # Get the streaming response
 
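The hunk above shows only the prompt construction inside chat_with_llm; the rest of the handler streams the model's reply into the Gradio chat history. Below is a minimal, self-contained sketch of that streaming pattern, with a stub client standing in for the module's llm object; the chunk format and everything outside the shown lines are assumptions, not part of this diff:

# Sketch of a streaming Gradio chat handler. _StubLLM stands in for the
# real client used in chatgr.py; its chunk format is an assumption.
PROMPT = "Your identity is VIVIBIT AI Assistant...\n Question: "

class _StubLLM:
    def stream(self, text):
        # Yield the reply in small pieces, the way a token stream would.
        for piece in ["Hello", ", ", "world", "!"]:
            yield piece

llm = _StubLLM()

def chat_with_llm(user_input, chat_history=[]):
    bot_message = ""                                  # the reply accumulates here
    chat_history.append((user_input, bot_message))
    for chunk in llm.stream(PROMPT + user_input):
        bot_message += chunk                          # append each streamed piece
        chat_history[-1] = (user_input, bot_message)
        yield "", chat_history                        # Gradio repaints on every yield

Yielding on every chunk is what lets the chat window update incrementally instead of waiting for the complete answer.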
hdl/utils/llm/visrag.py ADDED
@@ -0,0 +1,200 @@
+import argparse
+from functools import partial
+from PIL import Image
+import hashlib
+import torch
+import fitz
+import gradio as gr
+import os
+import numpy as np
+import json
+from transformers import AutoModel, AutoTokenizer
+from hdl.utils.llm.chat import OpenAI_M
+
+def get_image_md5(img: Image.Image):
+    img_byte_array = img.tobytes()
+    hash_md5 = hashlib.md5()
+    hash_md5.update(img_byte_array)
+    hex_digest = hash_md5.hexdigest()
+    return hex_digest
+
+def calculate_md5_from_binary(binary_data):
+    hash_md5 = hashlib.md5()
+    hash_md5.update(binary_data)
+    return hash_md5.hexdigest()
+
+def add_pdf_gradio(pdf_file_binary, progress=gr.Progress(), cache_dir=None, model=None, tokenizer=None):
+    # Render each PDF page to an image, embed it with the vision model, and
+    # cache the images and embeddings under an MD5-derived knowledge-base ID.
+    model.eval()
+
+    knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
+
+    this_cache_dir = os.path.join(cache_dir, knowledge_base_name)
+    os.makedirs(this_cache_dir, exist_ok=True)
+
+    with open(os.path.join(this_cache_dir, "src.pdf"), 'wb') as file:
+        file.write(pdf_file_binary)
+
+    dpi = 200
+    doc = fitz.open("pdf", pdf_file_binary)
+
+    reps_list = []
+    images = []
+    image_md5s = []
+
+    for page in progress.tqdm(doc):
+        pix = page.get_pixmap(dpi=dpi)
+        image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+        image_md5 = get_image_md5(image)
+        image_md5s.append(image_md5)
+        with torch.no_grad():
+            reps = model(text=[''], image=[image], tokenizer=tokenizer).reps
+        reps_list.append(reps.squeeze(0).cpu().numpy())
+        images.append(image)
+
+    for idx in range(len(images)):
+        image = images[idx]
+        image_md5 = image_md5s[idx]
+        cache_image_path = os.path.join(this_cache_dir, f"{image_md5}.png")
+        image.save(cache_image_path)
+
+    np.save(os.path.join(this_cache_dir, "reps.npy"), reps_list)
+
+    with open(os.path.join(this_cache_dir, "md5s.txt"), 'w') as f:
+        for item in image_md5s:
+            f.write(item + '\n')
+
+    return knowledge_base_name
+
+def retrieve_gradio(knowledge_base, query, topk, cache_dir=None, model=None, tokenizer=None):
+    # Embed the query and rank the cached page embeddings by dot-product similarity.
+    model.eval()
+
+    target_cache_dir = os.path.join(cache_dir, knowledge_base)
+
+    if not os.path.exists(target_cache_dir):
+        return None
+
+    md5s = []
+    with open(os.path.join(target_cache_dir, "md5s.txt"), 'r') as f:
+        for line in f:
+            md5s.append(line.rstrip('\n'))
+
+    doc_reps = np.load(os.path.join(target_cache_dir, "reps.npy"))
+
+    query_with_instruction = "Represent this query for retrieving relevant document: " + query
+    with torch.no_grad():
+        query_rep = model(text=[query_with_instruction], image=[None], tokenizer=tokenizer).reps.squeeze(0).cpu()
+
+    query_md5 = hashlib.md5(query.encode()).hexdigest()
+
+    doc_reps_cat = torch.stack([torch.Tensor(i) for i in doc_reps], dim=0)
+
+    similarities = torch.matmul(query_rep, doc_reps_cat.T)
+
+    topk_values, topk_doc_ids = torch.topk(similarities, k=int(topk))
+
+    images_topk = [Image.open(os.path.join(target_cache_dir, f"{md5s[idx]}.png")) for idx in topk_doc_ids.cpu().numpy()]
+
+    with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'w') as f:
+        f.write(json.dumps(
+            {
+                "knowledge_base": knowledge_base,
+                "query": query,
+                "retrieved_docs": [os.path.join(target_cache_dir, f"{md5s[idx]}.png") for idx in topk_doc_ids.cpu().numpy()]
+            }, indent=4, ensure_ascii=False
+        ))
+
+    return images_topk
+
+def answer_question(images, question, gen_model):
+    images_ = [Image.open(image[0]).convert('RGB') for image in images]
+    answer = gen_model.chat(
+        prompt=question,
+        images=images_,
+        stream=False
+    )
+    return answer
+
+def upvote(knowledge_base, query, cache_dir):
+    target_cache_dir = os.path.join(cache_dir, knowledge_base)
+    query_md5 = hashlib.md5(query.encode()).hexdigest()
+
+    with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'r') as f:
+        data = json.loads(f.read())
+
+    data["user_preference"] = "upvote"
+
+    with open(os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json"), 'w') as f:
+        f.write(json.dumps(data, indent=4, ensure_ascii=False))
+
+def downvote(knowledge_base, query, cache_dir):
+    target_cache_dir = os.path.join(cache_dir, knowledge_base)
+    query_md5 = hashlib.md5(query.encode()).hexdigest()
+
+    with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'r') as f:
+        data = json.loads(f.read())
+
+    data["user_preference"] = "downvote"
+
+    with open(os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json"), 'w') as f:
+        f.write(json.dumps(data, indent=4, ensure_ascii=False))
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="RAG-PDFQA Script")
+    parser.add_argument('--cache_dir', type=str, required=True, help='Cache directory path')
+    parser.add_argument('--device', type=str, default='cuda:0', help='Device for model inference')
+    parser.add_argument('--model_path', type=str, required=True, help='Path to the embedding model')
+    parser.add_argument('--llm_host', type=str, default='127.0.0.1', help='LLM server IP address')
+    parser.add_argument('--llm_port', type=int, default=22299, help='LLM server port')
+    parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Gradio server name')
+    parser.add_argument('--server_port', type=int, default=10077, help='Gradio server port')
+
+    args = parser.parse_args()
+
+    print("Loading embedding model...")
+    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
+    model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
+    model.to(args.device)
+    model.eval()
+    print("Embedding model loaded!")
+
+    gen_model = OpenAI_M(
+        server_ip=args.llm_host,
+        server_port=args.llm_port
+    )
+
+    with gr.Blocks() as app:
+        gr.Markdown("# Vision Language Models Enable End-to-End RAG")
+
+        file_input = gr.File(type="binary", label="Step 1: Upload PDF")
+        file_result = gr.Text(label="Knowledge Base ID")
+        process_button = gr.Button("Process PDF")
+
+        # gr.Button.click() has no _kwargs parameter; bind the extra
+        # arguments with functools.partial instead.
+        process_button.click(partial(add_pdf_gradio, cache_dir=args.cache_dir, model=model, tokenizer=tokenizer),
+                             inputs=[file_input], outputs=file_result)
+
+        kb_id_input = gr.Text(label="Knowledge Base ID")
+        query_input = gr.Text(label="Your Question")
+        topk_input = gr.Number(value=5, minimum=1, maximum=10, step=1, label="Number of pages to retrieve")
+        retrieve_button = gr.Button("Retrieve Pages")
+        images_output = gr.Gallery(label="Retrieved Pages")
+
+        retrieve_button.click(partial(retrieve_gradio, cache_dir=args.cache_dir, model=model, tokenizer=tokenizer),
+                              inputs=[kb_id_input, query_input, topk_input], outputs=images_output)
+
+        button = gr.Button("Answer Question")
+        gen_model_response = gr.Textbox(label="Model's Answer")
+
+        button.click(partial(answer_question, gen_model=gen_model), inputs=[images_output, query_input], outputs=gen_model_response)
+
+        upvote_button = gr.Button("🤗 Upvote")
+        downvote_button = gr.Button("🤣 Downvote")
+
+        upvote_button.click(partial(upvote, cache_dir=args.cache_dir), inputs=[kb_id_input, query_input], outputs=None)
+        downvote_button.click(partial(downvote, cache_dir=args.cache_dir), inputs=[kb_id_input, query_input], outputs=None)
+
+    app.launch(server_name=args.server_name, server_port=args.server_port)
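The heart of retrieve_gradio above is a single dot-product ranking between the query embedding and the stacked page embeddings. The following self-contained sketch isolates that step, with random vectors standing in for model output; the embedding dimension (128) and page count (10) are illustrative, not the checkpoint's actual sizes:

import numpy as np
import torch

# Ten cached page embeddings of dimension 128 (illustrative shapes only).
doc_reps = [np.random.rand(128).astype(np.float32) for _ in range(10)]
query_rep = torch.rand(128)

# Same ranking logic as retrieve_gradio: stack, dot product, top-k.
doc_reps_cat = torch.stack([torch.Tensor(r) for r in doc_reps], dim=0)  # shape (10, 128)
similarities = torch.matmul(query_rep, doc_reps_cat.T)                  # shape (10,)
topk_values, topk_doc_ids = torch.topk(similarities, k=3)
print(topk_doc_ids.tolist())  # indices of the three best-matching pages

Note that this is a raw dot product, so it matches cosine similarity only if the model emits L2-normalized embeddings. The app itself is started from the command line, e.g. python -m hdl.utils.llm.visrag --cache_dir ./cache --model_path <embedding-model-path>; --cache_dir and --model_path are required, and the remaining flags use the defaults defined above.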
hjxdl-0.2.9.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: hjxdl
-Version: 0.2.7
+Version: 0.2.9
 Summary: A collection of functions for Jupyter notebooks
 Home-page: https://github.com/huluxiaohuowa/hdl
 Author: Jianxing Hu
hjxdl-0.2.9.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
 hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
-hdl/_version.py,sha256=mXPA9v4zB_3yyyQ0dNvl4mba1fMh4hU3qcyCEtD2HlE,411
+hdl/_version.py,sha256=3Wn5WgE4hQ0wW5cv50DvKHFoeoZaIrtStGD1O5cZq04,411
 hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
 hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -128,16 +128,17 @@ hdl/utils/general/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
 hdl/utils/general/glob.py,sha256=8-RCnt6L297wMIfn34ZAMCsGCZUjHG3MGglGZI1cX0g,491
 hdl/utils/llm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hdl/utils/llm/chat.py,sha256=OzyY9xACOOocx9zZigtq9YAPvHtDUo8v2fvf1Tyjg_U,14891
-hdl/utils/llm/chatgr.py,sha256=TFMYaJLNyg2eT3v1PXxZBnodLIBlUQCtTXl_2XRaGUs,2539
+hdl/utils/llm/chatgr.py,sha256=GO2G7g6YybduA5VCUuGjvEsJfC_6L7rycSnPeHMcxyM,2820
 hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
 hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
 hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
 hdl/utils/llm/vis.py,sha256=2pI0439GWi_BEVfQJtY29Y72FkUa8jEvBeqMlwy7xkc,15716
+hdl/utils/llm/visrag.py,sha256=BBWmYI8p9e9ZclWBP8nC3kuv7LzvYlg-7gyw_JTZ7K0,7556
 hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
 hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
-hjxdl-0.2.7.dist-info/METADATA,sha256=ojs28v3_8CbZmw4NexzDGNgA-uebZ-65GKX43e41yOo,835
-hjxdl-0.2.7.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-hjxdl-0.2.7.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
-hjxdl-0.2.7.dist-info/RECORD,,
+hjxdl-0.2.9.dist-info/METADATA,sha256=1T0y4ijcXd1dnb9E3LmS6dKK3Jr-GLenT2lB11h7l54,835
+hjxdl-0.2.9.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+hjxdl-0.2.9.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
+hjxdl-0.2.9.dist-info/RECORD,,