hjxdl 0.1.60__py3-none-any.whl → 0.1.62__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
hdl/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.1.60'
16
- __version_tuple__ = version_tuple = (0, 1, 60)
15
+ __version__ = version = '0.1.62'
16
+ __version_tuple__ = version_tuple = (0, 1, 62)
hdl/utils/llm/vis.py CHANGED
@@ -22,6 +22,7 @@ class ImgHandler:
22
22
  model_path,
23
23
  redis_host,
24
24
  redis_port,
25
+ model_name: str = None,
25
26
  device: str = None
26
27
  ) -> None:
27
28
  if device is None:
@@ -37,10 +38,13 @@ class ImgHandler:
37
38
  self.open_clip_cfg = json.load(
38
39
  open(Path(model_path) / Path("open_clip_config.json"))
39
40
  )
40
- self.model_name = (
41
- self.open_clip_cfg['model_cfg']['text_cfg']['hf_tokenizer_name']
42
- .split('/')[-1]
43
- )
41
+ if model_name is not None:
42
+ self.model_name = model_name
43
+ else:
44
+ self.model_name = (
45
+ self.open_clip_cfg['model_cfg']['text_cfg']['hf_tokenizer_name']
46
+ .split('/')[-1]
47
+ )
44
48
 
45
49
  self.model, self.preprocess_train, self.preprocess_val = (
46
50
  open_clip.create_model_and_transforms(
@@ -72,14 +76,15 @@ class ImgHandler:
72
76
  to_numpy = False,
73
77
  **kwargs
74
78
  ):
75
- imgs = torch.stack([
76
- self.preprocess_val(Image.open(image)).to(self.device)
77
- for image in images
78
- ])
79
- img_features = self.model.encode_image(imgs, **kwargs)
80
- img_features /= img_features.norm(dim=-1, keepdim=True)
81
- if to_numpy:
82
- img_features = img_features.cpu().numpy()
79
+ with torch.no_grad(), torch.cuda.amp.autocast():
80
+ imgs = torch.stack([
81
+ self.preprocess_val(Image.open(image)).to(self.device)
82
+ for image in images
83
+ ])
84
+ img_features = self.model.encode_image(imgs, **kwargs)
85
+ img_features /= img_features.norm(dim=-1, keepdim=True)
86
+ if to_numpy:
87
+ img_features = img_features.cpu().numpy()
83
88
  return img_features
84
89
 
85
90
  def get_text_features(
@@ -88,14 +93,15 @@ class ImgHandler:
88
93
  to_numpy = False,
89
94
  **kwargs
90
95
  ):
91
- txts = self.tokenizer(
92
- texts,
93
- context_length=self.model.context_length
94
- ).to(self.device)
95
- txt_features = self.model.encode_text(txts, **kwargs)
96
- txt_features /= txt_features.norm(dim=-1, keepdim=True)
97
- if to_numpy:
98
- txt_features = txt_features.cpu().numpy()
96
+ with torch.no_grad(), torch.cuda.amp.autocast():
97
+ txts = self.tokenizer(
98
+ texts,
99
+ context_length=self.model.context_length
100
+ ).to(self.device)
101
+ txt_features = self.model.encode_text(txts, **kwargs)
102
+ txt_features /= txt_features.norm(dim=-1, keepdim=True)
103
+ if to_numpy:
104
+ txt_features = txt_features.cpu().numpy()
99
105
  return txt_features
100
106
 
101
107
 
@@ -103,11 +109,18 @@ class ImgHandler:
103
109
  self,
104
110
  texts,
105
111
  images,
112
+ probs: bool = False,
113
+ to_numpy: bool = False,
106
114
  **kwargs
107
115
  ):
108
- image_features = self.get_img_features(images, **kwargs)
109
- text_features = self.get_text_features(texts, **kwargs)
110
- text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
116
+ with torch.no_grad(), torch.cuda.amp.autocast():
117
+ image_features = self.get_img_features(images, **kwargs)
118
+ text_features = self.get_text_features(texts, **kwargs)
119
+ text_probs = (100.0 * image_features @ text_features.T)
120
+ if probs:
121
+ text_probs = text_probs.softmax(dim=-1)
122
+ if to_numpy:
123
+ text_probs = text_probs.cpu().numpy()
111
124
  return text_probs
112
125
 
113
126
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: hjxdl
3
- Version: 0.1.60
3
+ Version: 0.1.62
4
4
  Summary: A collection of functions for Jupyter notebooks
5
5
  Home-page: https://github.com/huluxiaohuowa/hdl
6
6
  Author: Jianxing Hu
@@ -1,5 +1,5 @@
1
1
  hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
2
- hdl/_version.py,sha256=I8yYicE18LeNuqpxw5_Q3MjNjqUav4r_xspI3E7RXfQ,413
2
+ hdl/_version.py,sha256=rG3DSNYFAHXdVeQ2WjD0Z0Cpa8cTqknlt8QLq3_8uSk,413
3
3
  hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
5
5
  hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -131,12 +131,12 @@ hdl/utils/llm/chat.py,sha256=sk7Lw5Oa30k-l2fnJknkMmTc5zkBeEKsR981aeFhH5s,11907
131
131
  hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
132
132
  hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
133
133
  hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
134
- hdl/utils/llm/vis.py,sha256=ckxGl21MZchCChLJRMX13j3xgIHpW4cfO7Xl5ySGAHw,3147
134
+ hdl/utils/llm/vis.py,sha256=mmitBc5zRwT98oHIKmkhPahc8LI5YhAJlMtvCsRpT4c,3734
135
135
  hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
136
136
  hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
137
137
  hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
138
138
  hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
139
- hjxdl-0.1.60.dist-info/METADATA,sha256=nc4_lLnBG1ulMSJ7JAa8FF2v1ot5FQohFQKTqdtESic,880
140
- hjxdl-0.1.60.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
141
- hjxdl-0.1.60.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
142
- hjxdl-0.1.60.dist-info/RECORD,,
139
+ hjxdl-0.1.62.dist-info/METADATA,sha256=3Gb5-bSC047umAiQbseqxL3e5yvt82RNt1Usor3cGkg,880
140
+ hjxdl-0.1.62.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
141
+ hjxdl-0.1.62.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
142
+ hjxdl-0.1.62.dist-info/RECORD,,
File without changes