hjxdl 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hdl/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
 
- __version__ = version = '0.2.29'
- __version_tuple__ = version_tuple = (0, 2, 29)
+ __version__ = version = '0.2.31'
+ __version_tuple__ = version_tuple = (0, 2, 31)
hdl/utils/llm/chatgr.py CHANGED
@@ -4,6 +4,19 @@ from .chat import OpenAI_M
 
  # Generator function for streaming output
  def chat_with_llm(user_input, chat_history=[]):
+     """
+     Facilitates a chat interaction with a language model (LLM).
+     This function takes user input and maintains a chat history. It streams the response from the LLM and updates the chat history in real time.
+     Args:
+         user_input (str): The input message from the user.
+         chat_history (list, optional): A list of tuples representing the chat history. Each tuple contains two strings: the user's message and the bot's response. Defaults to an empty list.
+     Yields:
+         tuple: A tuple containing three elements:
+             - An empty string (for compatibility with certain frameworks).
+             - The updated chat history including the latest user message and the bot's response.
+             - The same updated chat history.
+     """
+ 
      chat_history.append(("User: " + user_input, "Bot: "))  # Append the user message first
      yield "", chat_history, chat_history  # Yield the user message
 
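The new docstring describes `chat_with_llm` as a streaming generator that repeatedly yields `(cleared_input, chat_history, chat_history)` while the bot reply grows. A minimal sketch of driving it outside a UI, assuming the LLM client configured elsewhere in chatgr.py is reachable; the prompt text is a placeholder and not from the diff:

```python
# Sketch only: each yield re-emits the growing chat_history; the last tuple
# holds the in-progress bot message.
from hdl.utils.llm.chatgr import chat_with_llm

history = []
for _cleared_box, history, _state in chat_with_llm("Hello there", history):
    user_msg, bot_msg = history[-1]
    print(bot_msg)  # prints "Bot: ...", growing as tokens stream in
```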
hdl/utils/llm/vis.py CHANGED
@@ -10,8 +10,10 @@ import torch
  import numpy as np
  from PIL import Image
  # from transformers import ChineseCLIPProcessor, ChineseCLIPModel
+ from transformers import AutoModel
+ from transformers import AutoTokenizer
  import open_clip
- import natsort
+ # import natsort
  from redis.commands.search.field import VectorField
  from redis.commands.search.indexDefinition import IndexDefinition, IndexType
  from hdl.jupyfuncs.show.pbar import tqdm
@@ -96,7 +98,7 @@ def imgbase64_to_pilimg(img_base64: str):
          BytesIO(
              base64.b64decode(img_base64.split(",")[-1])
          )
-     )
+     ).convert('RGB')
      return img_pil
 
 
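The only change to `imgbase64_to_pilimg` is the trailing `.convert('RGB')`, which presumably normalizes RGBA, paletted, or grayscale inputs to three channels before they reach the CLIP-style preprocessors. A self-contained sketch of the effect (the red test square is illustrative data, not from the package):

```python
import base64
from io import BytesIO

from PIL import Image

# Build an RGBA test image and round-trip it through a data-URL-style base64 string.
rgba = Image.new("RGBA", (8, 8), (255, 0, 0, 128))
buf = BytesIO()
rgba.save(buf, format="PNG")
img_base64 = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

# Same decode path as the patched function: strip the prefix, decode, force RGB.
img_pil = Image.open(BytesIO(base64.b64decode(img_base64.split(",")[-1]))).convert("RGB")
print(img_pil.mode)  # "RGB", regardless of the original mode
```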
@@ -119,34 +121,72 @@ def pilimg_to_base64(pilimg):
 
 
  class ImgHandler:
+     """
+     ImgHandler is a class for handling image processing tasks using pretrained models.
+     Attributes:
+         device_str (str): The device string (e.g., "cpu" or "cuda").
+         device (torch.device): The device to run the model on.
+         model_path (str): The path to the pretrained model.
+         model_name (str): The name of the model.
+         model_type (str): The type of the model (e.g., "openclip" or "cpm").
+         db_conn: The database connection object.
+         num_vec_dim (int): The number of vector dimensions.
+         pic_idx_name (str): The name of the picture index.
+         open_clip_cfg (dict): The configuration for the OpenCLIP model.
+         model: The pretrained model.
+         preprocess_train: The preprocessing function for training data.
+         preprocess_val: The preprocessing function for validation data.
+         tokenizer: The tokenizer for the model.
+     Methods:
+         __init__(self, model_path, conn=None, model_name: str = None, model_type: str = "openclip", device: str = "cpu", num_vec_dim: int = None, load_model: bool = True) -> None:
+             Initializes the ImgHandler class with the specified parameters.
+         load_model(self):
+             Loads the pretrained model and related configurations.
+         get_img_features(self, images, to_numpy=False, **kwargs):
+             Gets image features using a pretrained model.
+         get_text_features(self, texts, to_numpy=False, **kwargs):
+             Gets text features from the input texts.
+         get_text_img_probs(self, texts, images, probs=False, to_numpy=False, **kwargs):
+             Gets the probabilities of text-image associations.
+         get_pics_sims(self, images1, images2, to_numpy=False, **kwargs):
+             Calculates similarity scores between two sets of images.
+         vec_pics_todb(self, images: list[str], conn=None, print_idx_info=False):
+             Saves image features to a Redis database, avoiding duplicates.
+         get_pic_idx_info(self, conn=None):
+             Gets information about the picture index in the Redis database.
+         emb_search(self, emb_query, num_max=3, extra_params=None, conn=None):
+             Searches for similar embeddings in the database.
+         img_search(self, img, num_max=3, extra_params=None, conn=None):
+             Searches for similar images in the database based on the input image.
+     """
      def __init__(
          self,
          model_path,
          conn=None,
          model_name: str = None,
+         model_type: str = "openclip",
          device: str = "cpu",
          num_vec_dim: int = None,
          load_model: bool = True,
      ) -> None:
-         """Initializes the class with the provided parameters.
- 
+         """
+         Initializes the visualization utility.
          Args:
-             model_path (str): Path to the model file.
-             db_host (str): Hostname of the database.
-             db_port (int): Port number of the database.
+             model_path (str): Path to the model.
+             conn (optional): Database connection object. Defaults to None.
              model_name (str, optional): Name of the model. Defaults to None.
+             model_type (str, optional): Type of the model. Defaults to "openclip".
              device (str, optional): Device to run the model on. Defaults to "cpu".
              num_vec_dim (int, optional): Number of vector dimensions. Defaults to None.
-             load_model (bool, optional): Whether to load the model. Defaults to True.
- 
+             load_model (bool, optional): Flag to load the model immediately. Defaults to True.
          Returns:
              None
          """
- 
          self.device_str = device
          self.device = torch.device(device)
          self.model_path = model_path
          self.model_name = model_name
+         self.model_type = model_type
 
          self.db_conn = conn
          self.num_vec_dim = num_vec_dim
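Given the expanded `__init__` signature, constructing a handler might look like the sketch below; the model directory and the Redis connection are placeholders, not values shipped with the package:

```python
import redis

from hdl.utils.llm.vis import ImgHandler

conn = redis.Redis(host="localhost", port=6379)  # placeholder connection
handler = ImgHandler(
    model_path="/models/CLIP-ViT-B-16",  # placeholder local checkpoint directory
    conn=conn,
    model_type="openclip",               # or "cpm" for the newly added branch
    device="cpu",
    load_model=True,
)
```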
@@ -155,41 +195,72 @@ class ImgHandler:
          self.load_model()
 
      def load_model(self):
-         """Load the OpenCLIP model and related configurations.
- 
-         This function loads the OpenCLIP model from the specified model path
-         and initializes the necessary components such as the model,
-         preprocessors for training and validation data, tokenizer, etc.
- 
-         Returns:
-             None
          """
-         ckpt_file = (
-             Path(self.model_path) / Path("open_clip_pytorch_model.bin")
-         ).as_posix()
-         self.open_clip_cfg = json.load(
-             open(Path(self.model_path) / Path("open_clip_config.json"))
-         )
+         Loads the model and tokenizer based on the specified model type.
+         This method supports loading two types of models: "cpm" and "openclip".
+         For "cpm":
+             - Loads the tokenizer and model using `AutoTokenizer` and `AutoModel` from the Hugging Face library.
+             - Sets the model to the specified device.
+             - Sets the number of vector dimensions to 2304.
+         For "openclip":
+             - Loads the model checkpoint and configuration from the specified path.
+             - Sets the model name if not already specified.
+             - Creates the model and preprocessing transforms using `open_clip.create_model_and_transforms`.
+             - Sets the number of vector dimensions based on the configuration if not already specified.
+             - Loads the tokenizer using `open_clip.get_tokenizer`.
+         Attributes:
+             model_type (str): The type of the model to load ("cpm" or "openclip").
+             model_path (str): The path to the model files.
+             device (str): The device to load the model onto (e.g., "cpu" or "cuda").
+             model_name (str, optional): The name of the model (used for "openclip" type).
+             num_vec_dim (int, optional): The number of vector dimensions (used for "openclip" type).
+             tokenizer: The tokenizer for the model.
+             model: The loaded model.
+             preprocess_train: The preprocessing transform for training (used for "openclip" type).
+             preprocess_val: The preprocessing transform for validation (used for "openclip" type).
+             open_clip_cfg (dict): The configuration for the "openclip" model.
+         """
+ 
 
-         if self.model_name is None:
-             self.model_name = (
-                 self.open_clip_cfg['model_cfg']['text_cfg']['hf_tokenizer_name']
-                 .split('/')[-1]
-             )
+         if self.model_type == "cpm":
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 self.model_path,
+                 trust_remote_code=True
+             )
+             self.model = AutoModel.from_pretrained(
+                 self.model_path,
+                 trust_remote_code=True
+             )
+             self.model.to(self.device)
+             self.num_vec_dim = 2304
+ 
+         elif self.model_type == "openclip":
+             ckpt_file = (
+                 Path(self.model_path) / Path("open_clip_pytorch_model.bin")
+             ).as_posix()
+             self.open_clip_cfg = json.load(
+                 open(Path(self.model_path) / Path("open_clip_config.json"))
+             )
+ 
+             if self.model_name is None:
+                 self.model_name = (
+                     self.open_clip_cfg['model_cfg']['text_cfg']['hf_tokenizer_name']
+                     .split('/')[-1]
+                 )
 
-         self.model, self.preprocess_train, self.preprocess_val = (
-             open_clip.create_model_and_transforms(
-                 model_name=self.model_name,
-                 pretrained=ckpt_file,
-                 device=self.device,
-                 # precision=precision
-             )
-         )
-         if self.num_vec_dim is None:
-             self.num_vec_dim = self.open_clip_cfg["model_cfg"]["embed_dim"]
-         self.tokenizer = open_clip.get_tokenizer(
-             HF_HUB_PREFIX + self.model_path
-         )
+             self.model, self.preprocess_train, self.preprocess_val = (
+                 open_clip.create_model_and_transforms(
+                     model_name=self.model_name,
+                     pretrained=ckpt_file,
+                     device=self.device,
+                     # precision=precision
+                 )
+             )
+             if self.num_vec_dim is None:
+                 self.num_vec_dim = self.open_clip_cfg["model_cfg"]["embed_dim"]
+             self.tokenizer = open_clip.get_tokenizer(
+                 HF_HUB_PREFIX + self.model_path
+             )
 
      def get_img_features(
          self,
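For the `openclip` branch, `load_model` reads two files from `model_path`: `open_clip_pytorch_model.bin` (the checkpoint handed to `open_clip.create_model_and_transforms`) and `open_clip_config.json` (the source of the default `model_name` and `num_vec_dim`). The `cpm` branch skips this and pins `num_vec_dim` to 2304. A hedged sketch of inspecting that config yourself, with a placeholder directory:

```python
import json
from pathlib import Path

model_path = "/models/CLIP-ViT-B-16"  # placeholder; must contain the two files above
cfg = json.load(open(Path(model_path) / "open_clip_config.json"))

print(cfg["model_cfg"]["embed_dim"])                      # becomes num_vec_dim when unset
print(cfg["model_cfg"]["text_cfg"]["hf_tokenizer_name"])  # model_name is its last path segment
```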
@@ -222,16 +293,24 @@ class ImgHandler:
                  f"Not supported image type for {type(img)}"
              )
 
-
-         with torch.no_grad(), torch.amp.autocast(self.device_str):
-             imgs = torch.stack([
-                 self.preprocess_val(image).to(self.device)
-                 for image in images_fixed
-             ])
-             img_features = self.model.encode_image(imgs, **kwargs)
-             img_features /= img_features.norm(dim=-1, keepdim=True)
-             if to_numpy:
-                 img_features = img_features.cpu().numpy()
+         if self.model_type == "cpm":
+             with torch.no_grad():
+                 img_features = self.model(
+                     text=[""] * len(images_fixed),
+                     image=images_fixed,
+                     tokenizer=self.tokenizer
+                 ).reps
+ 
+         if self.model_type == "openclip":
+             with torch.no_grad(), torch.amp.autocast(self.device_str):
+                 imgs = torch.stack([
+                     self.preprocess_val(image).to(self.device)
+                     for image in images_fixed
+                 ])
+                 img_features = self.model.encode_image(imgs, **kwargs)
+                 img_features /= img_features.norm(dim=-1, keepdim=True)
+                 if to_numpy:
+                     img_features = img_features.cpu().numpy()
          return img_features
 
      def get_text_features(
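With the branching above, image embedding usage could look like this sketch; paths and file names are placeholders, PIL images are passed since both branches consume the already-normalized image list, and only the `openclip` branch L2-normalizes and honors `to_numpy` in this hunk:

```python
from PIL import Image

from hdl.utils.llm.vis import ImgHandler

handler = ImgHandler("/models/CLIP-ViT-B-16", model_type="openclip")  # placeholder path
imgs = [Image.open(p).convert("RGB") for p in ("cat.jpg", "dog.jpg")]  # placeholder files
feats = handler.get_img_features(imgs, to_numpy=True)
print(feats.shape)  # expected: (2, handler.num_vec_dim), rows L2-normalized
```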
@@ -253,17 +332,25 @@ class ImgHandler:
          Example:
              get_text_features(["text1", "text2"], to_numpy=True)
          """
-         with torch.no_grad(), torch.amp.autocast(self.device_str):
-             txts = self.tokenizer(
-                 texts,
-                 context_length=self.model.context_length
-             ).to(self.device)
-             txt_features = self.model.encode_text(txts, **kwargs)
-             txt_features /= txt_features.norm(dim=-1, keepdim=True)
-             if to_numpy:
-                 txt_features = txt_features.cpu().numpy()
-         return txt_features
 
+         if self.model_type == "cpm":
+             with torch.no_grad():
+                 txt_features = self.model(
+                     text=texts,
+                     image=[None] * len(texts),
+                     tokenizer=self.tokenizer
+                 ).reps
+         elif self.model_type == "openclip":
+             with torch.no_grad(), torch.amp.autocast(self.device_str):
+                 txts = self.tokenizer(
+                     texts,
+                     context_length=self.model.context_length
+                 ).to(self.device)
+                 txt_features = self.model.encode_text(txts, **kwargs)
+                 txt_features /= txt_features.norm(dim=-1, keepdim=True)
+                 if to_numpy:
+                     txt_features = txt_features.cpu().numpy()
+         return txt_features
 
      def get_text_img_probs(
          self,
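Because the `openclip` branch L2-normalizes both image and text features, cosine similarity reduces to a plain matrix product. A hedged end-to-end sketch with placeholder paths and files; note the class also exposes `get_text_img_probs` for this purpose:

```python
from PIL import Image

from hdl.utils.llm.vis import ImgHandler

handler = ImgHandler("/models/CLIP-ViT-B-16", model_type="openclip")  # placeholder path
img_feats = handler.get_img_features(
    [Image.open("cat.jpg").convert("RGB")],  # placeholder image
    to_numpy=True,
)
txt_feats = handler.get_text_features(
    ["a photo of a cat", "a photo of a dog"],
    to_numpy=True,
)
sims = txt_feats @ img_feats.T  # shape (2, 1); higher means more similar
print(sims)
```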
hjxdl-0.2.31.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: hjxdl
- Version: 0.2.29
+ Version: 0.2.31
  Summary: A collection of functions for Jupyter notebooks
  Home-page: https://github.com/huluxiaohuowa/hdl
  Author: Jianxing Hu
hjxdl-0.2.31.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
  hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
- hdl/_version.py,sha256=AtwvoTC96AXOg97Emp1_Wmo7L-xUzOR7aFTUto-9JfA,413
+ hdl/_version.py,sha256=DftfWlt0Q-ezfk5YMbK9lW9QTADqDhJKW_mOmhx7dwY,413
  hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
  hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -128,17 +128,17 @@ hdl/utils/general/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
  hdl/utils/general/glob.py,sha256=8-RCnt6L297wMIfn34ZAMCsGCZUjHG3MGglGZI1cX0g,491
  hdl/utils/llm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  hdl/utils/llm/chat.py,sha256=OzyY9xACOOocx9zZigtq9YAPvHtDUo8v2fvf1Tyjg_U,14891
- hdl/utils/llm/chatgr.py,sha256=GO2G7g6YybduA5VCUuGjvEsJfC_6L7rycSnPeHMcxyM,2820
+ hdl/utils/llm/chatgr.py,sha256=8L7RLpi3tU_9HfP1qSiqH1BQDhBilSEj6Rn93lZOdDc,3584
  hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
  hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
  hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
- hdl/utils/llm/vis.py,sha256=Kixrhc3eByHjdxiAcB2BnthsS31dHMavp6qc2JX46Dc,16292
+ hdl/utils/llm/vis.py,sha256=BsGAfy5X8sMFnX5A3vHpTPDRe_-IDdhs6YVQ-efvyQ0,21424
  hdl/utils/llm/visrag.py,sha256=_PuKtmQIXD5bnmXwDWhTLdzOhgC42JiqdMNb1uKA7n8,9190
  hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
  hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
- hjxdl-0.2.29.dist-info/METADATA,sha256=7qiDTPY06pouX5FgbV4ekWvdDdXKae-Q9m1iGrIO1GA,836
- hjxdl-0.2.29.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
- hjxdl-0.2.29.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
- hjxdl-0.2.29.dist-info/RECORD,,
+ hjxdl-0.2.31.dist-info/METADATA,sha256=T9COT4TfRlNcqpIOdr4tbU2bC_JhNIbRpbOW9855RKo,836
+ hjxdl-0.2.31.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+ hjxdl-0.2.31.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
+ hjxdl-0.2.31.dist-info/RECORD,,
hjxdl-0.2.31.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.2.0)
+ Generator: setuptools (75.3.0)
  Root-Is-Purelib: true
  Tag: py3-none-any