hjxdl 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hdl/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.2.28'
-__version_tuple__ = version_tuple = (0, 2, 28)
+__version__ = version = '0.2.30'
+__version_tuple__ = version_tuple = (0, 2, 30)
hdl/utils/llm/vis.py CHANGED
@@ -10,8 +10,10 @@ import torch
 import numpy as np
 from PIL import Image
 # from transformers import ChineseCLIPProcessor, ChineseCLIPModel
+from transformers import AutoModel
+from transformers import AutoTokenizer
 import open_clip
-import natsort
+# import natsort
 from redis.commands.search.field import VectorField
 from redis.commands.search.indexDefinition import IndexDefinition, IndexType
 from hdl.jupyfuncs.show.pbar import tqdm
@@ -96,7 +98,7 @@ def imgbase64_to_pilimg(img_base64: str):
         BytesIO(
             base64.b64decode(img_base64.split(",")[-1])
         )
-    )
+    ).convert('RGB')
     return img_pil
 
 
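Worth flagging on the `imgbase64_to_pilimg` hunk above: the appended `.convert('RGB')` normalizes RGBA and palette payloads that CLIP-style preprocessors otherwise choke on. A minimal sketch of the failure mode it prevents (the payload below is fabricated for illustration):

```python
import base64
from io import BytesIO
from PIL import Image

# Fabricate a base64 data URL for an RGBA image.
buf = BytesIO()
Image.new("RGBA", (8, 8), (255, 0, 0, 128)).save(buf, format="PNG")
img_base64 = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

# The patched decoder: alpha is dropped before any preprocessing sees it.
img_pil = Image.open(
    BytesIO(base64.b64decode(img_base64.split(",")[-1]))
).convert("RGB")
assert img_pil.mode == "RGB"
```

Without the conversion, `img_pil.mode` would stay `RGBA` and a four-channel tensor would reach the three-channel image encoders downstream.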
@@ -124,6 +126,7 @@ class ImgHandler:
         model_path,
         conn=None,
         model_name: str = None,
+        model_type: str = "openclip",
         device: str = "cpu",
         num_vec_dim: int = None,
         load_model: bool = True,
@@ -143,9 +146,11 @@ class ImgHandler:
             None
         """
 
+        self.device_str = device
         self.device = torch.device(device)
         self.model_path = model_path
         self.model_name = model_name
+        self.model_type = model_type
 
         self.db_conn = conn
         self.num_vec_dim = num_vec_dim
@@ -163,32 +168,46 @@ class ImgHandler:
         Returns:
             None
         """
-        ckpt_file = (
-            Path(self.model_path) / Path("open_clip_pytorch_model.bin")
-        ).as_posix()
-        self.open_clip_cfg = json.load(
-            open(Path(self.model_path) / Path("open_clip_config.json"))
-        )
 
-        if self.model_name is None:
-            self.model_name = (
-                self.open_clip_cfg['model_cfg']['text_cfg']['hf_tokenizer_name']
-                .split('/')[-1]
+        if self.model_type == "cpm":
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.model_path,
+                trust_remote_code=True
+            )
+            self.model = AutoModel.from_pretrained(
+                self.model_path,
+                trust_remote_code=True
+            )
+            self.model.to(self.device)
+            self.num_vec_dim = 2304
+
+        elif self.model_type == "openclip":
+            ckpt_file = (
+                Path(self.model_path) / Path("open_clip_pytorch_model.bin")
+            ).as_posix()
+            self.open_clip_cfg = json.load(
+                open(Path(self.model_path) / Path("open_clip_config.json"))
             )
 
-        self.model, self.preprocess_train, self.preprocess_val = (
-            open_clip.create_model_and_transforms(
-                model_name=self.model_name,
-                pretrained=ckpt_file,
-                device=self.device,
-                # precision=precision
+            if self.model_name is None:
+                self.model_name = (
+                    self.open_clip_cfg['model_cfg']['text_cfg']['hf_tokenizer_name']
+                    .split('/')[-1]
+                )
+
+            self.model, self.preprocess_train, self.preprocess_val = (
+                open_clip.create_model_and_transforms(
+                    model_name=self.model_name,
+                    pretrained=ckpt_file,
+                    device=self.device,
+                    # precision=precision
+                )
+            )
+            if self.num_vec_dim is None:
+                self.num_vec_dim = self.open_clip_cfg["model_cfg"]["embed_dim"]
+            self.tokenizer = open_clip.get_tokenizer(
+                HF_HUB_PREFIX + self.model_path
             )
-        )
-        if self.num_vec_dim is None:
-            self.num_vec_dim = self.open_clip_cfg["model_cfg"]["embed_dim"]
-        self.tokenizer = open_clip.get_tokenizer(
-            HF_HUB_PREFIX + self.model_path
-        )
 
     def get_img_features(
         self,
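On the new `model_type` branch above: a hedged construction sketch, assuming local checkpoint directories (the paths are placeholders, not part of the package). `"openclip"` keeps the previous open_clip loading path; `"cpm"` loads a Hugging Face checkpoint through `AutoModel`/`AutoTokenizer` with `trust_remote_code=True` and pins `num_vec_dim` to 2304.

```python
from hdl.utils.llm.vis import ImgHandler

# Previous behavior, now selected explicitly: an open_clip checkpoint
# directory containing open_clip_pytorch_model.bin and open_clip_config.json.
clip_handler = ImgHandler(
    model_path="/models/open-clip-vit-b-16",   # placeholder path
    model_type="openclip",
    device="cpu",
)

# New branch: a trust_remote_code Hugging Face checkpoint whose forward
# returns pooled embeddings via .reps (2304-dimensional).
cpm_handler = ImgHandler(
    model_path="/models/cpm-visual-embedder",  # placeholder path
    model_type="cpm",
    device="cuda",
)
```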
@@ -221,16 +240,24 @@ class ImgHandler:
                 f"Not supported image type for {type(img)}"
             )
 
-
-        with torch.no_grad(), torch.amp.autocast("cuda"):
-            imgs = torch.stack([
-                self.preprocess_val(image).to(self.device)
-                for image in images_fixed
-            ])
-            img_features = self.model.encode_image(imgs, **kwargs)
-            img_features /= img_features.norm(dim=-1, keepdim=True)
-            if to_numpy:
-                img_features = img_features.cpu().numpy()
+        if self.model_type == "cpm":
+            with torch.no_grad():
+                img_features = self.model(
+                    text=[""] * len(images_fixed),
+                    image=images_fixed,
+                    tokenizer=self.tokenizer
+                ).reps
+
+        if self.model_type == "openclip":
+            with torch.no_grad(), torch.amp.autocast(self.device_str):
+                imgs = torch.stack([
+                    self.preprocess_val(image).to(self.device)
+                    for image in images_fixed
+                ])
+                img_features = self.model.encode_image(imgs, **kwargs)
+                img_features /= img_features.norm(dim=-1, keepdim=True)
+                if to_numpy:
+                    img_features = img_features.cpu().numpy()
         return img_features
 
     def get_text_features(
@@ -252,17 +279,25 @@ class ImgHandler:
         Example:
             get_text_features(["text1", "text2"], to_numpy=True)
         """
-        with torch.no_grad(), torch.amp.autocast("cuda"):
-            txts = self.tokenizer(
-                texts,
-                context_length=self.model.context_length
-            ).to(self.device)
-            txt_features = self.model.encode_text(txts, **kwargs)
-            txt_features /= txt_features.norm(dim=-1, keepdim=True)
-            if to_numpy:
-                txt_features = txt_features.cpu().numpy()
-        return txt_features
 
+        if self.model_type == "cpm":
+            with torch.no_grad():
+                txt_features = self.model(
+                    text=texts,
+                    image=[None] * len(texts),
+                    tokenizer=self.tokenizer
+                ).reps
+        elif self.model_type == "openclip":
+            with torch.no_grad(), torch.amp.autocast(self.device_str):
+                txts = self.tokenizer(
+                    texts,
+                    context_length=self.model.context_length
+                ).to(self.device)
+                txt_features = self.model.encode_text(txts, **kwargs)
+                txt_features /= txt_features.norm(dim=-1, keepdim=True)
+                if to_numpy:
+                    txt_features = txt_features.cpu().numpy()
+        return txt_features
 
     def get_text_img_probs(
         self,
@@ -284,7 +319,7 @@ class ImgHandler:
         Returns:
             torch.Tensor or numpy.ndarray: Text-image association probabilities.
         """
-        with torch.no_grad(), torch.amp.autocast("cuda"):
+        with torch.no_grad(), torch.amp.autocast(self.device_str):
             image_features = self.get_img_features(images, **kwargs)
             text_features = self.get_text_features(texts, **kwargs)
             text_probs = (100.0 * image_features @ text_features.T)
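The `autocast` edits here and in the next hunk swap the hard-coded `"cuda"` for the device string stored in `__init__`. `torch.amp.autocast` expects a device type such as `"cpu"` or `"cuda"`, so the old code raised on CPU-only hosts. A minimal standalone sketch of the corrected pattern:

```python
import torch

# The handler stores device_str ("cpu" or "cuda") and passes it through,
# instead of assuming CUDA is present.
for device_str in ("cpu", "cuda"):
    if device_str == "cuda" and not torch.cuda.is_available():
        continue  # skip the GPU branch on CPU-only machines
    with torch.no_grad(), torch.amp.autocast(device_str):
        x = torch.ones(4, 4, device=device_str)
        y = x @ x  # matmul runs in the backend's autocast dtype
    print(device_str, y.dtype)  # bfloat16 on cpu, float16 on cuda by default
```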
@@ -313,7 +348,7 @@ class ImgHandler:
         Returns:
             torch.Tensor or numpy.ndarray: Similarity scores between the two sets of images.
         """
-        with torch.no_grad(), torch.amp.autocast("cuda"):
+        with torch.no_grad(), torch.amp.autocast(self.device_str):
             img1_feats = self.get_img_features(images1, **kwargs)
             img2_feats = self.get_img_features(images2, **kwargs)
             sims = img1_feats @ img2_feats.T
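Putting the hunks together, a usage sketch against the updated helpers. It reuses `clip_handler` from the earlier sketch and assumes file paths are among the image inputs accepted by `get_img_features` (the type dispatch is only partially visible in this diff):

```python
# Hypothetical inputs; only the method names and to_numpy flag come from the diff.
img_feats = clip_handler.get_img_features(["/data/cat.jpg"], to_numpy=True)
txt_feats = clip_handler.get_text_features(
    ["a photo of a cat", "a photo of a dog"],
    to_numpy=True,
)
# The same scoring used inside get_text_img_probs.
print(100.0 * img_feats @ txt_feats.T)
```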
hjxdl-0.2.28.dist-info/METADATA → hjxdl-0.2.30.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: hjxdl
-Version: 0.2.28
+Version: 0.2.30
 Summary: A collection of functions for Jupyter notebooks
 Home-page: https://github.com/huluxiaohuowa/hdl
 Author: Jianxing Hu
hjxdl-0.2.28.dist-info/RECORD → hjxdl-0.2.30.dist-info/RECORD RENAMED
@@ -1,5 +1,5 @@
 hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
-hdl/_version.py,sha256=wEQ7TyF4Askd44xloaNNJbxFTrLO2VUUsYcGF-OOTok,413
+hdl/_version.py,sha256=YLA3hy-44LcRh6Hcq_A-SyVyfXwkdObMoIBSVcngmDI,413
 hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
 hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -132,13 +132,13 @@ hdl/utils/llm/chatgr.py,sha256=GO2G7g6YybduA5VCUuGjvEsJfC_6L7rycSnPeHMcxyM,2820
 hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
 hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
 hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
-hdl/utils/llm/vis.py,sha256=KCIsgGdIfrHX_snL2GBeBfUc8MNTyJ0G_VxDvdT-sp8,16223
+hdl/utils/llm/vis.py,sha256=RWeI6lSmzCDG2HJMq8-teuC7to4pPiR0ee2Hx1clbRw,17656
 hdl/utils/llm/visrag.py,sha256=_PuKtmQIXD5bnmXwDWhTLdzOhgC42JiqdMNb1uKA7n8,9190
 hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
 hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
-hjxdl-0.2.28.dist-info/METADATA,sha256=IlknmsmWEQ29ZkZVxTUIJhIyoWd5YjEpKYuXfXsePLE,836
-hjxdl-0.2.28.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-hjxdl-0.2.28.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
-hjxdl-0.2.28.dist-info/RECORD,,
+hjxdl-0.2.30.dist-info/METADATA,sha256=aKoW2m7jTjCWQSTm4Kzxi7n5YWyXXxjHlmqHjZjG8D4,836
+hjxdl-0.2.30.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+hjxdl-0.2.30.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
+hjxdl-0.2.30.dist-info/RECORD,,
hjxdl-0.2.28.dist-info/WHEEL → hjxdl-0.2.30.dist-info/WHEEL RENAMED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.2.0)
+Generator: setuptools (75.3.0)
 Root-Is-Purelib: true
 Tag: py3-none-any