hjxdl 0.2.29__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hdl/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
 
- __version__ = version = '0.2.29'
- __version_tuple__ = version_tuple = (0, 2, 29)
+ __version__ = version = '0.2.30'
+ __version_tuple__ = version_tuple = (0, 2, 30)
hdl/utils/llm/vis.py CHANGED
@@ -10,8 +10,10 @@ import torch
  import numpy as np
  from PIL import Image
  # from transformers import ChineseCLIPProcessor, ChineseCLIPModel
+ from transformers import AutoModel
+ from transformers import AutoTokenizer
  import open_clip
- import natsort
+ # import natsort
  from redis.commands.search.field import VectorField
  from redis.commands.search.indexDefinition import IndexDefinition, IndexType
  from hdl.jupyfuncs.show.pbar import tqdm
@@ -96,7 +98,7 @@ def imgbase64_to_pilimg(img_base64: str):
  BytesIO(
  base64.b64decode(img_base64.split(",")[-1])
  )
- )
+ ).convert('RGB')
  return img_pil
 
 
@@ -124,6 +126,7 @@ class ImgHandler:
  model_path,
  conn=None,
  model_name: str = None,
+ model_type: str = "openclip",
  device: str = "cpu",
  num_vec_dim: int = None,
  load_model: bool = True,
@@ -147,6 +150,7 @@ class ImgHandler:
  self.device = torch.device(device)
  self.model_path = model_path
  self.model_name = model_name
+ self.model_type = model_type
 
  self.db_conn = conn
  self.num_vec_dim = num_vec_dim
@@ -164,32 +168,46 @@ class ImgHandler:
  Returns:
  None
  """
- ckpt_file = (
- Path(self.model_path) / Path("open_clip_pytorch_model.bin")
- ).as_posix()
- self.open_clip_cfg = json.load(
- open(Path(self.model_path) / Path("open_clip_config.json"))
- )
 
- if self.model_name is None:
- self.model_name = (
- self.open_clip_cfg['model_cfg']['text_cfg']['hf_tokenizer_name']
- .split('/')[-1]
+ if self.model_type == "cpm":
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ self.model_path,
+ trust_remote_code=True
+ )
+ self.model = AutoModel.from_pretrained(
+ self.model_path,
+ trust_remote_code=True
  )
+ self.model.to(self.device)
+ self.num_vec_dim = 2304
+
+ elif self.model_type == "openclip":
+ ckpt_file = (
+ Path(self.model_path) / Path("open_clip_pytorch_model.bin")
+ ).as_posix()
+ self.open_clip_cfg = json.load(
+ open(Path(self.model_path) / Path("open_clip_config.json"))
+ )
+
+ if self.model_name is None:
+ self.model_name = (
+ self.open_clip_cfg['model_cfg']['text_cfg']['hf_tokenizer_name']
+ .split('/')[-1]
+ )
 
- self.model, self.preprocess_train, self.preprocess_val = (
- open_clip.create_model_and_transforms(
- model_name=self.model_name,
- pretrained=ckpt_file,
- device=self.device,
- # precision=precision
+ self.model, self.preprocess_train, self.preprocess_val = (
+ open_clip.create_model_and_transforms(
+ model_name=self.model_name,
+ pretrained=ckpt_file,
+ device=self.device,
+ # precision=precision
+ )
+ )
+ if self.num_vec_dim is None:
+ self.num_vec_dim = self.open_clip_cfg["model_cfg"]["embed_dim"]
+ self.tokenizer = open_clip.get_tokenizer(
+ HF_HUB_PREFIX + self.model_path
  )
- )
- if self.num_vec_dim is None:
- self.num_vec_dim = self.open_clip_cfg["model_cfg"]["embed_dim"]
- self.tokenizer = open_clip.get_tokenizer(
- HF_HUB_PREFIX + self.model_path
- )
 
  def get_img_features(
  self,
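
For orientation, a minimal usage sketch of the new model_type switch in load_model (the paths below are placeholders and the keyword-argument call style is an assumption; only the parameter names and defaults come from this diff):

# Hypothetical construction of the handler for both backends.
from hdl.utils.llm.vis import ImgHandler

# "openclip" keeps the previous behaviour: weights and config are read from
# open_clip_pytorch_model.bin / open_clip_config.json under model_path.
clip_handler = ImgHandler(
    model_path="/models/openclip-ckpt",  # placeholder path
    model_type="openclip",
    device="cpu",
)

# "cpm" loads the checkpoint through transformers AutoModel/AutoTokenizer with
# trust_remote_code=True and fixes num_vec_dim at 2304.
cpm_handler = ImgHandler(
    model_path="/models/cpm-ckpt",  # placeholder path
    model_type="cpm",
    device="cpu",
)
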
@@ -222,16 +240,24 @@ class ImgHandler:
  f"Not supported image type for {type(img)}"
  )
 
-
- with torch.no_grad(), torch.amp.autocast(self.device_str):
- imgs = torch.stack([
- self.preprocess_val(image).to(self.device)
- for image in images_fixed
- ])
- img_features = self.model.encode_image(imgs, **kwargs)
- img_features /= img_features.norm(dim=-1, keepdim=True)
- if to_numpy:
- img_features = img_features.cpu().numpy()
+ if self.model_type == "cpm":
+ with torch.no_grad():
+ img_features = self.model(
+ text=[""] * len(images_fixed),
+ image=images_fixed,
+ tokenizer=self.tokenizer
+ ).reps
+
+ if self.model_type == "openclip":
+ with torch.no_grad(), torch.amp.autocast(self.device_str):
+ imgs = torch.stack([
+ self.preprocess_val(image).to(self.device)
+ for image in images_fixed
+ ])
+ img_features = self.model.encode_image(imgs, **kwargs)
+ img_features /= img_features.norm(dim=-1, keepdim=True)
+ if to_numpy:
+ img_features = img_features.cpu().numpy()
  return img_features
 
  def get_text_features(
@@ -253,17 +279,25 @@ class ImgHandler:
  Example:
  get_text_features(["text1", "text2"], to_numpy=True)
  """
- with torch.no_grad(), torch.amp.autocast(self.device_str):
- txts = self.tokenizer(
- texts,
- context_length=self.model.context_length
- ).to(self.device)
- txt_features = self.model.encode_text(txts, **kwargs)
- txt_features /= txt_features.norm(dim=-1, keepdim=True)
- if to_numpy:
- txt_features = txt_features.cpu().numpy()
- return txt_features
 
+ if self.model_type == "cpm":
+ with torch.no_grad():
+ txt_features = self.model(
+ text=texts,
+ image=[None] * len(texts),
+ tokenizer=self.tokenizer
+ ).reps
+ elif self.model_type == "openclip":
+ with torch.no_grad(), torch.amp.autocast(self.device_str):
+ txts = self.tokenizer(
+ texts,
+ context_length=self.model.context_length
+ ).to(self.device)
+ txt_features = self.model.encode_text(txts, **kwargs)
+ txt_features /= txt_features.norm(dim=-1, keepdim=True)
+ if to_numpy:
+ txt_features = txt_features.cpu().numpy()
+ return txt_features
 
  def get_text_img_probs(
  self,
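
Continuing the sketch above, the two refactored feature methods might be exercised like this (the PIL-image input, the tensor type of .reps, and the manual normalisation for the cpm branch are assumptions; the text call mirrors the docstring example in this hunk):

from PIL import Image

imgs = [Image.open("a.jpg").convert("RGB")]  # placeholder image

# openclip branch: features are L2-normalised and can be returned as numpy arrays.
img_vecs = clip_handler.get_img_features(imgs, to_numpy=True)
txt_vecs = clip_handler.get_text_features(["text1", "text2"], to_numpy=True)

# cpm branch: embeddings come from self.model(...).reps; as written in this diff the
# normalisation and to_numpy conversion sit only in the openclip branch, so normalise
# manually before computing cosine similarity.
cpm_img = cpm_handler.get_img_features(imgs)
cpm_txt = cpm_handler.get_text_features(["text1", "text2"])
cpm_img = cpm_img / cpm_img.norm(dim=-1, keepdim=True)
cpm_txt = cpm_txt / cpm_txt.norm(dim=-1, keepdim=True)
scores = cpm_txt @ cpm_img.T  # cosine similarity between texts and images
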
hjxdl-0.2.29.dist-info/METADATA → hjxdl-0.2.30.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: hjxdl
- Version: 0.2.29
+ Version: 0.2.30
  Summary: A collection of functions for Jupyter notebooks
  Home-page: https://github.com/huluxiaohuowa/hdl
  Author: Jianxing Hu
hjxdl-0.2.29.dist-info/RECORD → hjxdl-0.2.30.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
  hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
- hdl/_version.py,sha256=AtwvoTC96AXOg97Emp1_Wmo7L-xUzOR7aFTUto-9JfA,413
+ hdl/_version.py,sha256=YLA3hy-44LcRh6Hcq_A-SyVyfXwkdObMoIBSVcngmDI,413
  hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
  hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -132,13 +132,13 @@ hdl/utils/llm/chatgr.py,sha256=GO2G7g6YybduA5VCUuGjvEsJfC_6L7rycSnPeHMcxyM,2820
  hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
  hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
  hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
- hdl/utils/llm/vis.py,sha256=Kixrhc3eByHjdxiAcB2BnthsS31dHMavp6qc2JX46Dc,16292
+ hdl/utils/llm/vis.py,sha256=RWeI6lSmzCDG2HJMq8-teuC7to4pPiR0ee2Hx1clbRw,17656
  hdl/utils/llm/visrag.py,sha256=_PuKtmQIXD5bnmXwDWhTLdzOhgC42JiqdMNb1uKA7n8,9190
  hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
  hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
- hjxdl-0.2.29.dist-info/METADATA,sha256=7qiDTPY06pouX5FgbV4ekWvdDdXKae-Q9m1iGrIO1GA,836
- hjxdl-0.2.29.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
- hjxdl-0.2.29.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
- hjxdl-0.2.29.dist-info/RECORD,,
+ hjxdl-0.2.30.dist-info/METADATA,sha256=aKoW2m7jTjCWQSTm4Kzxi7n5YWyXXxjHlmqHjZjG8D4,836
+ hjxdl-0.2.30.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+ hjxdl-0.2.30.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
+ hjxdl-0.2.30.dist-info/RECORD,,
hjxdl-0.2.29.dist-info/WHEEL → hjxdl-0.2.30.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.2.0)
+ Generator: setuptools (75.3.0)
  Root-Is-Purelib: true
  Tag: py3-none-any