hjxdl 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/_version.py +2 -2
- hdl/utils/llm/vis.py +81 -46
- {hjxdl-0.2.28.dist-info → hjxdl-0.2.30.dist-info}/METADATA +1 -1
- {hjxdl-0.2.28.dist-info → hjxdl-0.2.30.dist-info}/RECORD +6 -6
- {hjxdl-0.2.28.dist-info → hjxdl-0.2.30.dist-info}/WHEEL +1 -1
- {hjxdl-0.2.28.dist-info → hjxdl-0.2.30.dist-info}/top_level.txt +0 -0
hdl/_version.py
CHANGED
hdl/utils/llm/vis.py
CHANGED
@@ -10,8 +10,10 @@ import torch
 import numpy as np
 from PIL import Image
 # from transformers import ChineseCLIPProcessor, ChineseCLIPModel
+from transformers import AutoModel
+from transformers import AutoTokenizer
 import open_clip
-import natsort
+# import natsort
 from redis.commands.search.field import VectorField
 from redis.commands.search.indexDefinition import IndexDefinition, IndexType
 from hdl.jupyfuncs.show.pbar import tqdm
@@ -96,7 +98,7 @@ def imgbase64_to_pilimg(img_base64: str):
         BytesIO(
             base64.b64decode(img_base64.split(",")[-1])
         )
-    )
+    ).convert('RGB')
     return img_pil


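
Not part of the diff, but a short aside on the .convert('RGB') change above: a base64-encoded RGBA or palette-mode PNG otherwise decodes to a non-3-channel image, which CLIP-style preprocessing further down the pipeline typically rejects. A minimal, self-contained sketch (all names here are illustrative, not taken from the package):

    import base64
    from io import BytesIO
    from PIL import Image

    # Build an RGBA test image and wrap it in a data-URI-style base64 string.
    buf = BytesIO()
    Image.new("RGBA", (8, 8), (255, 0, 0, 128)).save(buf, format="PNG")
    img_b64 = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    # Decoding alone keeps the alpha channel ...
    decoded = Image.open(BytesIO(base64.b64decode(img_b64.split(",")[-1])))
    print(decoded.mode)  # RGBA

    # ... while the patched helper's .convert('RGB') normalizes to 3 channels.
    print(decoded.convert("RGB").mode)  # RGB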
|
@@ -124,6 +126,7 @@ class ImgHandler:
         model_path,
         conn=None,
         model_name: str = None,
+        model_type: str = "openclip",
         device: str = "cpu",
         num_vec_dim: int = None,
         load_model: bool = True,
@@ -143,9 +146,11 @@ class ImgHandler:
             None
         """

+        self.device_str = device
         self.device = torch.device(device)
         self.model_path = model_path
         self.model_name = model_name
+        self.model_type = model_type

         self.db_conn = conn
         self.num_vec_dim = num_vec_dim
@@ -163,32 +168,46 @@ class ImgHandler:
         Returns:
             None
         """
-        ckpt_file = (
-            Path(self.model_path) / Path("open_clip_pytorch_model.bin")
-        ).as_posix()
-        self.open_clip_cfg = json.load(
-            open(Path(self.model_path) / Path("open_clip_config.json"))
-        )

-        if self.
-            self.
-            self.
-
+        if self.model_type == "cpm":
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.model_path,
+                trust_remote_code=True
+            )
+            self.model = AutoModel.from_pretrained(
+                self.model_path,
+                trust_remote_code=True
+            )
+            self.model.to(self.device)
+            self.num_vec_dim = 2304
+
+        elif self.model_type == "openclip":
+            ckpt_file = (
+                Path(self.model_path) / Path("open_clip_pytorch_model.bin")
+            ).as_posix()
+            self.open_clip_cfg = json.load(
+                open(Path(self.model_path) / Path("open_clip_config.json"))
         )

-
-
-
-
-
-
+            if self.model_name is None:
+                self.model_name = (
+                    self.open_clip_cfg['model_cfg']['text_cfg']['hf_tokenizer_name']
+                    .split('/')[-1]
+                )
+
+            self.model, self.preprocess_train, self.preprocess_val = (
+                open_clip.create_model_and_transforms(
+                    model_name=self.model_name,
+                    pretrained=ckpt_file,
+                    device=self.device,
+                    # precision=precision
+                )
+            )
+            if self.num_vec_dim is None:
+                self.num_vec_dim = self.open_clip_cfg["model_cfg"]["embed_dim"]
+            self.tokenizer = open_clip.get_tokenizer(
+                HF_HUB_PREFIX + self.model_path
         )
-        )
-        if self.num_vec_dim is None:
-            self.num_vec_dim = self.open_clip_cfg["model_cfg"]["embed_dim"]
-        self.tokenizer = open_clip.get_tokenizer(
-            HF_HUB_PREFIX + self.model_path
-        )

     def get_img_features(
         self,
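
Not part of the diff: a hedged sketch of how the new model_type switch might be exercised, assuming ImgHandler is imported from hdl.utils.llm.vis and that the constructor loads the model when load_model=True (as the signature above suggests). Paths are placeholders.

    from hdl.utils.llm.vis import ImgHandler

    # OpenCLIP backend (the default): model_path is expected to contain
    # open_clip_pytorch_model.bin and open_clip_config.json.
    clip_handler = ImgHandler(
        model_path="/models/clip-vit-b-16",  # placeholder path
        model_type="openclip",
        device="cpu",
    )

    # "cpm" backend: loaded through transformers AutoModel/AutoTokenizer with
    # trust_remote_code=True; the diff hard-codes num_vec_dim to 2304, which is
    # consistent with a MiniCPM/VisRAG-style retrieval checkpoint.
    cpm_handler = ImgHandler(
        model_path="/models/visrag-ret",  # placeholder path
        model_type="cpm",
        device="cuda",
    )
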
@@ -221,16 +240,24 @@ class ImgHandler:
                 f"Not supported image type for {type(img)}"
             )

-
-
-
-
-
-
-
-
-
-
+        if self.model_type == "cpm":
+            with torch.no_grad():
+                img_features = self.model(
+                    text=[""] * len(images_fixed),
+                    image=images_fixed,
+                    tokenizer=self.tokenizer
+                ).reps
+
+        if self.model_type == "openclip":
+            with torch.no_grad(), torch.amp.autocast(self.device_str):
+                imgs = torch.stack([
+                    self.preprocess_val(image).to(self.device)
+                    for image in images_fixed
+                ])
+                img_features = self.model.encode_image(imgs, **kwargs)
+                img_features /= img_features.norm(dim=-1, keepdim=True)
+        if to_numpy:
+            img_features = img_features.cpu().numpy()
         return img_features

     def get_text_features(
@@ -252,17 +279,25 @@ class ImgHandler:
         Example:
             get_text_features(["text1", "text2"], to_numpy=True)
         """
-        with torch.no_grad(), torch.amp.autocast("cuda"):
-            txts = self.tokenizer(
-                texts,
-                context_length=self.model.context_length
-            ).to(self.device)
-            txt_features = self.model.encode_text(txts, **kwargs)
-            txt_features /= txt_features.norm(dim=-1, keepdim=True)
-            if to_numpy:
-                txt_features = txt_features.cpu().numpy()
-        return txt_features

+        if self.model_type == "cpm":
+            with torch.no_grad():
+                txt_features = self.model(
+                    text=texts,
+                    image=[None] * len(texts),
+                    tokenizer=self.tokenizer
+                ).reps
+        elif self.model_type == "openclip":
+            with torch.no_grad(), torch.amp.autocast(self.device_str):
+                txts = self.tokenizer(
+                    texts,
+                    context_length=self.model.context_length
+                ).to(self.device)
+                txt_features = self.model.encode_text(txts, **kwargs)
+                txt_features /= txt_features.norm(dim=-1, keepdim=True)
+        if to_numpy:
+            txt_features = txt_features.cpu().numpy()
+        return txt_features

     def get_text_img_probs(
         self,
@@ -284,7 +319,7 @@ class ImgHandler:
         Returns:
             torch.Tensor or numpy.ndarray: Text-image association probabilities.
         """
-        with torch.no_grad(), torch.amp.autocast(
+        with torch.no_grad(), torch.amp.autocast(self.device_str):
             image_features = self.get_img_features(images, **kwargs)
             text_features = self.get_text_features(texts, **kwargs)
             text_probs = (100.0 * image_features @ text_features.T)
@@ -313,7 +348,7 @@ class ImgHandler:
         Returns:
             torch.Tensor or numpy.ndarray: Similarity scores between the two sets of images.
         """
-        with torch.no_grad(), torch.amp.autocast(
+        with torch.no_grad(), torch.amp.autocast(self.device_str):
             img1_feats = self.get_img_features(images1, **kwargs)
             img2_feats = self.get_img_features(images2, **kwargs)
             sims = img1_feats @ img2_feats.T
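
The two hunks above finish the same fix seen earlier in get_img_features and get_text_features: torch.amp.autocast expects a device-type string such as "cpu" or "cuda", so the handler now stores self.device_str alongside torch.device(device) instead of hard-coding "cuda". A hedged end-to-end sketch for the OpenCLIP path, reusing clip_handler from the example above (file names are placeholders; per the error-message context line, get_img_features also accepts inputs other than PIL images):

    from PIL import Image

    imgs = [Image.open("cat.jpg"), Image.open("dog.jpg")]  # placeholder files
    texts = ["a photo of a cat", "a photo of a dog"]

    img_feats = clip_handler.get_img_features(imgs, to_numpy=True)
    txt_feats = clip_handler.get_text_features(texts, to_numpy=True)

    # Features are L2-normalized in the openclip branch, so a plain dot product
    # is a cosine similarity; get_text_img_probs builds on the same idea via
    # 100.0 * image_features @ text_features.T (unchanged context line above).
    sims = img_feats @ txt_feats.T
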
{hjxdl-0.2.28.dist-info → hjxdl-0.2.30.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
-hdl/_version.py,sha256=
+hdl/_version.py,sha256=YLA3hy-44LcRh6Hcq_A-SyVyfXwkdObMoIBSVcngmDI,413
 hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
 hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -132,13 +132,13 @@ hdl/utils/llm/chatgr.py,sha256=GO2G7g6YybduA5VCUuGjvEsJfC_6L7rycSnPeHMcxyM,2820
 hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
 hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
 hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
-hdl/utils/llm/vis.py,sha256=
+hdl/utils/llm/vis.py,sha256=RWeI6lSmzCDG2HJMq8-teuC7to4pPiR0ee2Hx1clbRw,17656
 hdl/utils/llm/visrag.py,sha256=_PuKtmQIXD5bnmXwDWhTLdzOhgC42JiqdMNb1uKA7n8,9190
 hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
 hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
-hjxdl-0.2.
-hjxdl-0.2.
-hjxdl-0.2.
-hjxdl-0.2.
+hjxdl-0.2.30.dist-info/METADATA,sha256=aKoW2m7jTjCWQSTm4Kzxi7n5YWyXXxjHlmqHjZjG8D4,836
+hjxdl-0.2.30.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+hjxdl-0.2.30.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
+hjxdl-0.2.30.dist-info/RECORD,,
{hjxdl-0.2.28.dist-info → hjxdl-0.2.30.dist-info}/top_level.txt
File without changes