hjxdl 0.2.29__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/_version.py +2 -2
- hdl/utils/llm/vis.py +78 -44
- {hjxdl-0.2.29.dist-info → hjxdl-0.2.30.dist-info}/METADATA +1 -1
- {hjxdl-0.2.29.dist-info → hjxdl-0.2.30.dist-info}/RECORD +6 -6
- {hjxdl-0.2.29.dist-info → hjxdl-0.2.30.dist-info}/WHEEL +1 -1
- {hjxdl-0.2.29.dist-info → hjxdl-0.2.30.dist-info}/top_level.txt +0 -0
hdl/_version.py
CHANGED
hdl/utils/llm/vis.py
CHANGED
@@ -10,8 +10,10 @@ import torch
|
|
10
10
|
import numpy as np
|
11
11
|
from PIL import Image
|
12
12
|
# from transformers import ChineseCLIPProcessor, ChineseCLIPModel
|
13
|
+
from transformers import AutoModel
|
14
|
+
from transformers import AutoTokenizer
|
13
15
|
import open_clip
|
14
|
-
import natsort
|
16
|
+
# import natsort
|
15
17
|
from redis.commands.search.field import VectorField
|
16
18
|
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
17
19
|
from hdl.jupyfuncs.show.pbar import tqdm
|
@@ -96,7 +98,7 @@ def imgbase64_to_pilimg(img_base64: str):
|
|
96
98
|
BytesIO(
|
97
99
|
base64.b64decode(img_base64.split(",")[-1])
|
98
100
|
)
|
99
|
-
)
|
101
|
+
).convert('RGB')
|
100
102
|
return img_pil
|
101
103
|
|
102
104
|
|
@@ -124,6 +126,7 @@ class ImgHandler:
|
|
124
126
|
model_path,
|
125
127
|
conn=None,
|
126
128
|
model_name: str = None,
|
129
|
+
model_type: str = "openclip",
|
127
130
|
device: str = "cpu",
|
128
131
|
num_vec_dim: int = None,
|
129
132
|
load_model: bool = True,
|
@@ -147,6 +150,7 @@ class ImgHandler:
|
|
147
150
|
self.device = torch.device(device)
|
148
151
|
self.model_path = model_path
|
149
152
|
self.model_name = model_name
|
153
|
+
self.model_type = model_type
|
150
154
|
|
151
155
|
self.db_conn = conn
|
152
156
|
self.num_vec_dim = num_vec_dim
|
@@ -164,32 +168,46 @@ class ImgHandler:
|
|
164
168
|
Returns:
|
165
169
|
None
|
166
170
|
"""
|
167
|
-
ckpt_file = (
|
168
|
-
Path(self.model_path) / Path("open_clip_pytorch_model.bin")
|
169
|
-
).as_posix()
|
170
|
-
self.open_clip_cfg = json.load(
|
171
|
-
open(Path(self.model_path) / Path("open_clip_config.json"))
|
172
|
-
)
|
173
171
|
|
174
|
-
if self.
|
175
|
-
self.
|
176
|
-
self.
|
177
|
-
|
172
|
+
if self.model_type == "cpm":
|
173
|
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
174
|
+
self.model_path,
|
175
|
+
trust_remote_code=True
|
176
|
+
)
|
177
|
+
self.model = AutoModel.from_pretrained(
|
178
|
+
self.model_path,
|
179
|
+
trust_remote_code=True
|
178
180
|
)
|
181
|
+
self.model.to(self.device)
|
182
|
+
self.num_vec_dim = 2304
|
183
|
+
|
184
|
+
elif self.model_type == "openclip":
|
185
|
+
ckpt_file = (
|
186
|
+
Path(self.model_path) / Path("open_clip_pytorch_model.bin")
|
187
|
+
).as_posix()
|
188
|
+
self.open_clip_cfg = json.load(
|
189
|
+
open(Path(self.model_path) / Path("open_clip_config.json"))
|
190
|
+
)
|
191
|
+
|
192
|
+
if self.model_name is None:
|
193
|
+
self.model_name = (
|
194
|
+
self.open_clip_cfg['model_cfg']['text_cfg']['hf_tokenizer_name']
|
195
|
+
.split('/')[-1]
|
196
|
+
)
|
179
197
|
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
198
|
+
self.model, self.preprocess_train, self.preprocess_val = (
|
199
|
+
open_clip.create_model_and_transforms(
|
200
|
+
model_name=self.model_name,
|
201
|
+
pretrained=ckpt_file,
|
202
|
+
device=self.device,
|
203
|
+
# precision=precision
|
204
|
+
)
|
205
|
+
)
|
206
|
+
if self.num_vec_dim is None:
|
207
|
+
self.num_vec_dim = self.open_clip_cfg["model_cfg"]["embed_dim"]
|
208
|
+
self.tokenizer = open_clip.get_tokenizer(
|
209
|
+
HF_HUB_PREFIX + self.model_path
|
186
210
|
)
|
187
|
-
)
|
188
|
-
if self.num_vec_dim is None:
|
189
|
-
self.num_vec_dim = self.open_clip_cfg["model_cfg"]["embed_dim"]
|
190
|
-
self.tokenizer = open_clip.get_tokenizer(
|
191
|
-
HF_HUB_PREFIX + self.model_path
|
192
|
-
)
|
193
211
|
|
194
212
|
def get_img_features(
|
195
213
|
self,
|
@@ -222,16 +240,24 @@ class ImgHandler:
|
|
222
240
|
f"Not supported image type for {type(img)}"
|
223
241
|
)
|
224
242
|
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
243
|
+
if self.model_type == "cpm":
|
244
|
+
with torch.no_grad():
|
245
|
+
img_features = self.model(
|
246
|
+
text=[""] * len(images_fixed),
|
247
|
+
image=images_fixed,
|
248
|
+
tokenizer=self.tokenizer
|
249
|
+
).reps
|
250
|
+
|
251
|
+
if self.model_type == "openclip":
|
252
|
+
with torch.no_grad(), torch.amp.autocast(self.device_str):
|
253
|
+
imgs = torch.stack([
|
254
|
+
self.preprocess_val(image).to(self.device)
|
255
|
+
for image in images_fixed
|
256
|
+
])
|
257
|
+
img_features = self.model.encode_image(imgs, **kwargs)
|
258
|
+
img_features /= img_features.norm(dim=-1, keepdim=True)
|
259
|
+
if to_numpy:
|
260
|
+
img_features = img_features.cpu().numpy()
|
235
261
|
return img_features
|
236
262
|
|
237
263
|
def get_text_features(
|
@@ -253,17 +279,25 @@ class ImgHandler:
|
|
253
279
|
Example:
|
254
280
|
get_text_features(["text1", "text2"], to_numpy=True)
|
255
281
|
"""
|
256
|
-
with torch.no_grad(), torch.amp.autocast(self.device_str):
|
257
|
-
txts = self.tokenizer(
|
258
|
-
texts,
|
259
|
-
context_length=self.model.context_length
|
260
|
-
).to(self.device)
|
261
|
-
txt_features = self.model.encode_text(txts, **kwargs)
|
262
|
-
txt_features /= txt_features.norm(dim=-1, keepdim=True)
|
263
|
-
if to_numpy:
|
264
|
-
txt_features = txt_features.cpu().numpy()
|
265
|
-
return txt_features
|
266
282
|
|
283
|
+
if self.model_type == "cpm":
|
284
|
+
with torch.no_grad():
|
285
|
+
txt_features = self.model(
|
286
|
+
text=texts,
|
287
|
+
image=[None] * len(texts),
|
288
|
+
tokenizer=self.tokenizer
|
289
|
+
).reps
|
290
|
+
elif self.model_type == "openclip":
|
291
|
+
with torch.no_grad(), torch.amp.autocast(self.device_str):
|
292
|
+
txts = self.tokenizer(
|
293
|
+
texts,
|
294
|
+
context_length=self.model.context_length
|
295
|
+
).to(self.device)
|
296
|
+
txt_features = self.model.encode_text(txts, **kwargs)
|
297
|
+
txt_features /= txt_features.norm(dim=-1, keepdim=True)
|
298
|
+
if to_numpy:
|
299
|
+
txt_features = txt_features.cpu().numpy()
|
300
|
+
return txt_features
|
267
301
|
|
268
302
|
def get_text_img_probs(
|
269
303
|
self,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
|
2
|
-
hdl/_version.py,sha256=
|
2
|
+
hdl/_version.py,sha256=YLA3hy-44LcRh6Hcq_A-SyVyfXwkdObMoIBSVcngmDI,413
|
3
3
|
hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
|
5
5
|
hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -132,13 +132,13 @@ hdl/utils/llm/chatgr.py,sha256=GO2G7g6YybduA5VCUuGjvEsJfC_6L7rycSnPeHMcxyM,2820
|
|
132
132
|
hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
|
133
133
|
hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
|
134
134
|
hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
|
135
|
-
hdl/utils/llm/vis.py,sha256=
|
135
|
+
hdl/utils/llm/vis.py,sha256=RWeI6lSmzCDG2HJMq8-teuC7to4pPiR0ee2Hx1clbRw,17656
|
136
136
|
hdl/utils/llm/visrag.py,sha256=_PuKtmQIXD5bnmXwDWhTLdzOhgC42JiqdMNb1uKA7n8,9190
|
137
137
|
hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
138
138
|
hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
|
139
139
|
hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
140
140
|
hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
|
141
|
-
hjxdl-0.2.
|
142
|
-
hjxdl-0.2.
|
143
|
-
hjxdl-0.2.
|
144
|
-
hjxdl-0.2.
|
141
|
+
hjxdl-0.2.30.dist-info/METADATA,sha256=aKoW2m7jTjCWQSTm4Kzxi7n5YWyXXxjHlmqHjZjG8D4,836
|
142
|
+
hjxdl-0.2.30.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
143
|
+
hjxdl-0.2.30.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
|
144
|
+
hjxdl-0.2.30.dist-info/RECORD,,
|
File without changes
|