hjxdl 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hdl/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.30'
16
- __version_tuple__ = version_tuple = (0, 2, 30)
15
+ __version__ = version = '0.2.32'
16
+ __version_tuple__ = version_tuple = (0, 2, 32)
hdl/utils/llm/chatgr.py CHANGED
@@ -4,6 +4,19 @@ from .chat import OpenAI_M
4
4
 
5
5
  # 定义流式输出的生成函数
6
6
  def chat_with_llm(user_input, chat_history=[]):
7
+ """
8
+ Facilitates a chat interaction with a language model (LLM).
9
+ This function takes user input and maintains a chat history. It streams the response from the LLM and updates the chat history in real-time.
10
+ Args:
11
+ user_input (str): The input message from the user.
12
+ chat_history (list, optional): A list of tuples representing the chat history. Each tuple contains two strings: the user's message and the bot's response. Defaults to an empty list.
13
+ Yields:
14
+ tuple: A tuple containing three elements:
15
+ - An empty string (for compatibility with certain frameworks).
16
+ - The updated chat history including the latest user message and the bot's response.
17
+ - The same updated chat history.
18
+ """
19
+
7
20
  chat_history.append(("User: " + user_input, "Bot: ")) # 初始先追加用户消息
8
21
  yield "", chat_history, chat_history # 返回用户消息
9
22
 
@@ -21,6 +34,17 @@ def chat_with_llm(user_input, chat_history=[]):
21
34
 
22
35
  # 构建 Gradio 界面
23
36
  def create_demo():
37
+ """
38
+ Creates a Gradio demo interface for a chatbot application.
39
+ The interface consists of:
40
+ - A chat history display at the top of the page.
41
+ - A user input textbox and a send button at the bottom of the page.
42
+ The send button and the enter key are both bound to the `chat_with_llm` function,
43
+ which handles sending the user's message and updating the chat history.
44
+ Returns:
45
+ gr.Blocks: The Gradio Blocks object representing the demo interface.
46
+ """
47
+
24
48
  with gr.Blocks() as demo:
25
49
  chat_history = gr.State([]) # 存储聊天历史
26
50
  output = gr.Chatbot(label="Chat History") # 聊天记录在页面顶端
@@ -59,4 +83,6 @@ if __name__ == "__main__":
59
83
 
60
84
  # 启动 Gradio 应用
61
85
  demo = create_demo()
62
- demo.launch(server_name=args.host, server_port=args.port)
86
+ demo.launch(server_name=args.host, server_port=args.port)
87
+
88
+
hdl/utils/llm/vis.py CHANGED
@@ -121,6 +121,44 @@ def pilimg_to_base64(pilimg):
121
121
 
122
122
 
123
123
  class ImgHandler:
124
+ """
125
+ ImgHandler is a class for handling image processing tasks using pretrained models.
126
+ Attributes:
127
+ device_str (str): The device string (e.g., "cpu" or "cuda").
128
+ device (torch.device): The device to run the model on.
129
+ model_path (str): The path to the pretrained model.
130
+ model_name (str): The name of the model.
131
+ model_type (str): The type of the model (e.g., "openclip" or "cpm").
132
+ db_conn: The database connection object.
133
+ num_vec_dim (int): The number of vector dimensions.
134
+ pic_idx_name (str): The name of the picture index.
135
+ open_clip_cfg (dict): The configuration for the OpenCLIP model.
136
+ model: The pretrained model.
137
+ preprocess_train: The preprocessing function for training data.
138
+ preprocess_val: The preprocessing function for validation data.
139
+ tokenizer: The tokenizer for the model.
140
+ Methods:
141
+ __init__(self, model_path, conn=None, model_name: str = None, model_type: str = "openclip", device: str = "cpu", num_vec_dim: int = None, load_model: bool = True) -> None:
142
+ Initializes the ImgHandler class with the specified parameters.
143
+ load_model(self):
144
+ Loads the pretrained model and related configurations.
145
+ get_img_features(self, images, to_numpy=False, **kwargs):
146
+ Gets image features using a pretrained model.
147
+ get_text_features(self, texts, to_numpy=False, **kwargs):
148
+ Gets text features from the input texts.
149
+ get_text_img_probs(self, texts, images, probs=False, to_numpy=False, **kwargs):
150
+ Gets the probabilities of text-image associations.
151
+ get_pics_sims(self, images1, images2, to_numpy=False, **kwargs):
152
+ Calculates similarity scores between two sets of images.
153
+ vec_pics_todb(self, images: list[str], conn=None, print_idx_info=False):
154
+ Saves image features to a Redis database, avoiding duplicates.
155
+ get_pic_idx_info(self, conn=None):
156
+ Gets information about the picture index in the Redis database.
157
+ emb_search(self, emb_query, num_max=3, extra_params=None, conn=None):
158
+ Searches for similar embeddings in the database.
159
+ img_search(self, img, num_max=3, extra_params=None, conn=None):
160
+ Searches for similar images in the database based on the input image.
161
+ """
124
162
  def __init__(
125
163
  self,
126
164
  model_path,
@@ -131,21 +169,19 @@ class ImgHandler:
131
169
  num_vec_dim: int = None,
132
170
  load_model: bool = True,
133
171
  ) -> None:
134
- """Initializes the class with the provided parameters.
135
-
172
+ """
173
+ Initializes the visualization utility.
136
174
  Args:
137
- model_path (str): Path to the model file.
138
- db_host (str): Hostname of the database.
139
- db_port (int): Port number of the database.
175
+ model_path (str): Path to the model.
176
+ conn (optional): Database connection object. Defaults to None.
140
177
  model_name (str, optional): Name of the model. Defaults to None.
178
+ model_type (str, optional): Type of the model. Defaults to "openclip".
141
179
  device (str, optional): Device to run the model on. Defaults to "cpu".
142
180
  num_vec_dim (int, optional): Number of vector dimensions. Defaults to None.
143
- load_model (bool, optional): Whether to load the model. Defaults to True.
144
-
181
+ load_model (bool, optional): Flag to load the model immediately. Defaults to True.
145
182
  Returns:
146
183
  None
147
184
  """
148
-
149
185
  self.device_str = device
150
186
  self.device = torch.device(device)
151
187
  self.model_path = model_path
@@ -159,15 +195,32 @@ class ImgHandler:
159
195
  self.load_model()
160
196
 
161
197
  def load_model(self):
162
- """Load the OpenCLIP model and related configurations.
163
-
164
- This function loads the OpenCLIP model from the specified model path
165
- and initializes the necessary components such as the model,
166
- preprocessors for training and validation data, tokenizer, etc.
167
-
168
- Returns:
169
- None
170
198
  """
199
+ Loads the model and tokenizer based on the specified model type.
200
+ This method supports loading two types of models: "cpm" and "openclip".
201
+ For "cpm":
202
+ - Loads the tokenizer and model using `AutoTokenizer` and `AutoModel` from the Hugging Face library.
203
+ - Sets the model to the specified device.
204
+ - Sets the number of vector dimensions to 2304.
205
+ For "openclip":
206
+ - Loads the model checkpoint and configuration from the specified path.
207
+ - Sets the model name if not already specified.
208
+ - Creates the model and preprocessing transforms using `open_clip.create_model_and_transforms`.
209
+ - Sets the number of vector dimensions based on the configuration if not already specified.
210
+ - Loads the tokenizer using `open_clip.get_tokenizer`.
211
+ Attributes:
212
+ model_type (str): The type of the model to load ("cpm" or "openclip").
213
+ model_path (str): The path to the model files.
214
+ device (str): The device to load the model onto (e.g., "cpu" or "cuda").
215
+ model_name (str, optional): The name of the model (used for "openclip" type).
216
+ num_vec_dim (int, optional): The number of vector dimensions (used for "openclip" type).
217
+ tokenizer: The tokenizer for the model.
218
+ model: The loaded model.
219
+ preprocess_train: The preprocessing transform for training (used for "openclip" type).
220
+ preprocess_val: The preprocessing transform for validation (used for "openclip" type).
221
+ open_clip_cfg (dict): The configuration for the "openclip" model.
222
+ """
223
+
171
224
 
172
225
  if self.model_type == "cpm":
173
226
  self.tokenizer = AutoTokenizer.from_pretrained(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: hjxdl
3
- Version: 0.2.30
3
+ Version: 0.2.32
4
4
  Summary: A collection of functions for Jupyter notebooks
5
5
  Home-page: https://github.com/huluxiaohuowa/hdl
6
6
  Author: Jianxing Hu
@@ -1,5 +1,5 @@
1
1
  hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
2
- hdl/_version.py,sha256=YLA3hy-44LcRh6Hcq_A-SyVyfXwkdObMoIBSVcngmDI,413
2
+ hdl/_version.py,sha256=jSnpQjxfX4ZsfrMFFxLeBiF7_cxq8o44X1E_cpi4UPE,413
3
3
  hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
5
5
  hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -128,17 +128,17 @@ hdl/utils/general/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
128
128
  hdl/utils/general/glob.py,sha256=8-RCnt6L297wMIfn34ZAMCsGCZUjHG3MGglGZI1cX0g,491
129
129
  hdl/utils/llm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
130
130
  hdl/utils/llm/chat.py,sha256=OzyY9xACOOocx9zZigtq9YAPvHtDUo8v2fvf1Tyjg_U,14891
131
- hdl/utils/llm/chatgr.py,sha256=GO2G7g6YybduA5VCUuGjvEsJfC_6L7rycSnPeHMcxyM,2820
131
+ hdl/utils/llm/chatgr.py,sha256=lKaDeimYz-Bw32QXYtde51xjf7gbfS2CEYt5R0ODunM,4075
132
132
  hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
133
133
  hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
134
134
  hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
135
- hdl/utils/llm/vis.py,sha256=RWeI6lSmzCDG2HJMq8-teuC7to4pPiR0ee2Hx1clbRw,17656
135
+ hdl/utils/llm/vis.py,sha256=BsGAfy5X8sMFnX5A3vHpTPDRe_-IDdhs6YVQ-efvyQ0,21424
136
136
  hdl/utils/llm/visrag.py,sha256=_PuKtmQIXD5bnmXwDWhTLdzOhgC42JiqdMNb1uKA7n8,9190
137
137
  hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
138
138
  hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
139
139
  hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
140
140
  hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
141
- hjxdl-0.2.30.dist-info/METADATA,sha256=aKoW2m7jTjCWQSTm4Kzxi7n5YWyXXxjHlmqHjZjG8D4,836
142
- hjxdl-0.2.30.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
143
- hjxdl-0.2.30.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
144
- hjxdl-0.2.30.dist-info/RECORD,,
141
+ hjxdl-0.2.32.dist-info/METADATA,sha256=DLlf5Bv3kIdC3Zs-a9iKvjQsCzN5e7QL_BXvVa9IywA,836
142
+ hjxdl-0.2.32.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
143
+ hjxdl-0.2.32.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
144
+ hjxdl-0.2.32.dist-info/RECORD,,
File without changes