hjxdl 0.1.57__py3-none-any.whl → 0.1.58__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
hdl/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.1.57'
16
- __version_tuple__ = version_tuple = (0, 1, 57)
15
+ __version__ = version = '0.1.58'
16
+ __version_tuple__ = version_tuple = (0, 1, 58)
hdl/utils/llm/vis.py CHANGED
@@ -1,19 +1,20 @@
1
- import requests
1
+ from pathlib import Path
2
+ import json
2
3
 
3
4
  import torch
4
5
  import numpy as np
5
6
  from PIL import Image
6
- from transformers import ChineseCLIPProcessor, ChineseCLIPModel
7
+ # from transformers import ChineseCLIPProcessor, ChineseCLIPModel
8
+ import open_clip
7
9
 
8
10
  from ..database_tools.connect import conn_redis
9
11
 
10
12
 
11
- # url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
12
- # image = Image.open(requests.get(url, stream=True).raw)
13
13
  __all__ = [
14
14
  "ImgHandler"
15
15
  ]
16
16
 
17
+ HF_HUB_PREFIX = "hf-hub:"
17
18
 
18
19
  class ImgHandler:
19
20
  def __init__(
@@ -29,9 +30,31 @@ class ImgHandler:
29
30
  else torch.device("cpu")
30
31
  else:
31
32
  self.device = device
33
+ ckpt_file = (
34
+ Path(model_path) / Path("open_clip_pytorch_model.bin")
35
+ ).as_posix()
32
36
 
33
- self.model = ChineseCLIPModel.from_pretrained(model_path).to(self.device)
34
- self.processor = ChineseCLIPProcessor.from_pretrained(model_path)
37
+ self.open_clip_cfg = json.load(
38
+ open(Path(model_path) / Path("open_clip_config.json"))
39
+ )
40
+ self.model_name = (
41
+ self.open_clip_cfg['model_cfg']['text_cfg']['hf_tokenizer_name']
42
+ .split('/')[-1]
43
+ )
44
+
45
+ self.model, self.preprocess_train, self.preprocess_val = (
46
+ open_clip.create_model_and_transforms(
47
+ model_name=self.model_name,
48
+ pretrained=ckpt_file,
49
+ device=self.device,
50
+ # precision=precision
51
+ )
52
+ )
53
+ self.tokenizer = open_clip.get_tokenizer(
54
+ HF_HUB_PREFIX + model_path
55
+ )
56
+ # self.model = ChineseCLIPModel.from_pretrained(model_path).to(self.device)
57
+ # self.processor = ChineseCLIPProcessor.from_pretrained(model_path)
35
58
  self.redis_host = redis_host
36
59
  self.redis_port = redis_port
37
60
  self._redis_conn = None
@@ -43,48 +66,45 @@ class ImgHandler:
43
66
  self._redis_conn = conn_redis(self.redis_host, self.redis_port)
44
67
  return self._redis_conn
45
68
 
46
- def get_img_features(self, images, **kwargs):
47
- inputs = self.processor(
48
- images=images,
49
- return_tensors="pt",
50
- **kwargs
51
- ).to(self.device)
52
- image_features = self.model.get_image_features(**inputs)
53
- image_features = image_features / \
54
- image_features.norm(p=2, dim=-1, keepdim=True)
55
- return image_features
69
+ def get_img_features(
70
+ self,
71
+ images,
72
+ to_numpy = False,
73
+ **kwargs
74
+ ):
75
+ imgs = [
76
+ self.preprocess_val(Image.open(image)).unsqueeze(0).to(self.device)
77
+ for image in images
78
+ ]
79
+ img_features = self.model.encode_image(imgs, **kwargs)
80
+ img_features /= img_features.norm(dim=-1, keepdim=True)
81
+ if to_numpy:
82
+ img_features = img_features.cpu().numpy()
83
+ return img_features
56
84
 
57
85
  def get_text_features(
58
86
  self,
59
87
  texts,
88
+ to_numpy = False,
60
89
  **kwargs
61
90
  ):
62
- inputs = self.processor(
63
- text=texts,
64
- padding=True,
65
- return_tensors="pt",
66
- **kwargs
67
- ).to(self.device)
68
- text_features = self.model.get_text_features(**inputs)
69
- text_features = text_features / \
70
- text_features.norm(p=2, dim=-1, keepdim=True)
71
- return text_features
72
-
73
- def get_text_img_sims(
91
+ txts = self.tokenizer(texts).to(self.device)
92
+ txt_features = self.model.encode_text(txts, **kwargs)
93
+ txt_features /= txt_features.norm(dim=-1, keepdim=True)
94
+ if to_numpy:
95
+ txt_features = txt_features.cpu().numpy()
96
+ return txt_features
97
+
98
+
99
+ def get_text_img_probs(
74
100
  self,
75
101
  texts,
76
102
  images,
77
103
  **kwargs
78
104
  ):
79
- inputs = self.processor(
80
- text=texts,
81
- images=images,
82
- return_tensors="pt",
83
- padding=True,
84
- **kwargs
85
- ).to(self.device)
86
- outputs = self.model(**inputs)
87
- logits_per_image = outputs.logits_per_image # this is the image-text similarity score
88
- probs = logits_per_image.softmax(dim=1)
89
- return probs
105
+ image_features = self.get_img_features(images, **kwargs)
106
+ text_features = self.get_text_features(texts, **kwargs)
107
+ text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
108
+ return text_probs
109
+
90
110
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: hjxdl
3
- Version: 0.1.57
3
+ Version: 0.1.58
4
4
  Summary: A collection of functions for Jupyter notebooks
5
5
  Home-page: https://github.com/huluxiaohuowa/hdl
6
6
  Author: Jianxing Hu
@@ -22,6 +22,7 @@ Requires-Dist: opencv-python
22
22
  Requires-Dist: redis[hiredis]
23
23
  Requires-Dist: psycopg[binary]
24
24
  Requires-Dist: Pillow
25
+ Requires-Dist: open-clip-torch
25
26
 
26
27
  # DL framework by Jianxing
27
28
 
@@ -1,5 +1,5 @@
1
1
  hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
2
- hdl/_version.py,sha256=1L2CAEYH8rRBQ01naMSlAdbRLOkKJXOpIY9i96QXS-s,413
2
+ hdl/_version.py,sha256=0coHK1MSWGlBW3NiGGuClrVbkHMpMYKfQkg34r9FVSU,413
3
3
  hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
5
5
  hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -131,12 +131,12 @@ hdl/utils/llm/chat.py,sha256=sk7Lw5Oa30k-l2fnJknkMmTc5zkBeEKsR981aeFhH5s,11907
131
131
  hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
132
132
  hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
133
133
  hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
134
- hdl/utils/llm/vis.py,sha256=dzXpv9xtm9qxZSj1zPTIwq2sskzMPsPgh30_LjAcDgU,2480
134
+ hdl/utils/llm/vis.py,sha256=mbtSG76h8PjWCZ4Pp6k5rlfTONM-K8el6f3D8kF0U0c,3071
135
135
  hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
136
136
  hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
137
137
  hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
138
138
  hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
139
- hjxdl-0.1.57.dist-info/METADATA,sha256=cLXP4zGr3OY_Xcx0mjMLyj6jyRJBH3vp2PxmkmS-hYA,849
140
- hjxdl-0.1.57.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
141
- hjxdl-0.1.57.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
142
- hjxdl-0.1.57.dist-info/RECORD,,
139
+ hjxdl-0.1.58.dist-info/METADATA,sha256=NP2vFvS46Yv4PIoXOLassCxLHwq0L3WkuguoY7aiLcQ,880
140
+ hjxdl-0.1.58.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
141
+ hjxdl-0.1.58.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
142
+ hjxdl-0.1.58.dist-info/RECORD,,
File without changes