hjxdl 0.1.69__py3-none-any.whl → 0.1.70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hdl/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.1.69'
16
- __version_tuple__ = version_tuple = (0, 1, 69)
15
+ __version__ = version = '0.1.70'
16
+ __version_tuple__ = version_tuple = (0, 1, 70)
hdl/utils/llm/vis.py CHANGED
@@ -11,6 +11,8 @@ import open_clip
11
11
  import natsort
12
12
  from redis.commands.search.field import VectorField
13
13
  from redis.commands.search.indexDefinition import IndexDefinition, IndexType
14
+ from hdl.jupyfuncs.show.pbar import tqdm
15
+ from redis.commands.search.query import Query
14
16
 
15
17
  from ..database_tools.connect import conn_redis
16
18
 
@@ -75,6 +77,7 @@ class ImgHandler:
75
77
  self.db_port = db_port
76
78
  self._db_conn = None
77
79
  self.num_vec_dim = num_vec_dim
80
+ self.pic_idx_name = "idx:pic_idx"
78
81
  if load_model:
79
82
  self.load_model()
80
83
 
@@ -142,10 +145,26 @@ class ImgHandler:
142
145
  Returns:
143
146
  torch.Tensor or numpy.ndarray: Image features extracted from the model.
144
147
  """
148
+
149
+ images_fixed = []
150
+ for img in images:
151
+ if isinstance(img, str):
152
+ if Path(Path(img).is_file()):
153
+ images_fixed.append(Image.open(img))
154
+ elif img.startswith("data:image/jpeg;base64,"):
155
+ images_fixed.append(imgbase64_to_pilimg(img))
156
+ elif isinstance(img, Image.Image):
157
+ images_fixed.append(img)
158
+ else:
159
+ raise TypeError(
160
+ f"Not supported image type for {type(img)}"
161
+ )
162
+
163
+
145
164
  with torch.no_grad(), torch.amp.autocast("cuda"):
146
165
  imgs = torch.stack([
147
- self.preprocess_val(Image.open(image)).to(self.device)
148
- for image in images
166
+ self.preprocess_val(image).to(self.device)
167
+ for image in images_fixed
149
168
  ])
150
169
  img_features = self.model.encode_image(imgs, **kwargs)
151
170
  img_features /= img_features.norm(dim=-1, keepdim=True)
@@ -242,7 +261,11 @@ class ImgHandler:
242
261
  # > 0.9 很相似
243
262
  return sims
244
263
 
245
- def vec_pics_todb(self, images):
264
+ def vec_pics_todb(
265
+ self,
266
+ images: list[str],
267
+ print_idx_info: bool = False
268
+ ):
246
269
  """Save image features to a Redis database.
247
270
 
248
271
  Args:
@@ -257,7 +280,7 @@ class ImgHandler:
257
280
  sorted_imgs = natsort.natsorted(images)
258
281
  img_feats = self.get_img_features(sorted_imgs, to_numpy=True)
259
282
  pipeline = self.db_conn.pipeline()
260
- for img_file, emb in zip(sorted_imgs, img_feats):
283
+ for img_file, emb in tqdm(zip(sorted_imgs, img_feats)):
261
284
  # 初始化 Redis,先使用 img 文件名作为 Key 和 Value,后续再更新为图片特征向量
262
285
  # pipeline.json().set(img_file, "$", img_file)
263
286
  emb = emb.astype(np.float32).tolist()
@@ -281,31 +304,80 @@ class ImgHandler:
281
304
  as_name="vector", # 给这个字段定义一个别名,后续可以使用
282
305
  ),
283
306
  )
284
- vector_idx_name = "idx:pic_idx"
307
+ # vector_idx_name = "idx:pic_idx"
285
308
  definition = IndexDefinition(
286
309
  prefix=["pic-"],
287
310
  index_type=IndexType.JSON
288
311
  )
289
312
  res = self.db_conn.ft(
290
- vector_idx_name
313
+ self.pic_idx_name
291
314
  ).create_index(
292
315
  fields=schema,
293
316
  definition=definition
294
317
  )
295
318
  print("create_index:", res)
319
+ if print_idx_info:
320
+ print(self.pic_idx_info)
296
321
 
297
322
  @property
298
323
  def pic_idx_info(self):
299
324
  res = self.db_conn.ping()
300
325
  print("redis connected:", res)
301
-
302
- vector_idx_name = "idx:pic_idx"
303
-
326
+ # vector_idx_name = "idx:pic_idx"
304
327
  # 从 Redis 数据库中读取索引状态
305
- info = self.db_conn.ft(vector_idx_name).info()
328
+ info = self.db_conn.ft(self.pic_idx_name).info()
306
329
  # 获取索引状态中的 num_docs 和 hash_indexing_failures
307
330
  num_docs = info["num_docs"]
308
331
  indexing_failures = info["hash_indexing_failures"]
309
- print(f"{num_docs} documents indexed with {indexing_failures} failures")
332
+ return (
333
+ f"{num_docs} documents indexed with {indexing_failures} failures"
334
+ )
335
+
336
+ def img_search(
337
+ self,
338
+ img,
339
+ num_max: int = 3,
340
+ extra_params: dict = None
341
+ ):
342
+ """Search for similar images in the database based on the input image.
343
+
344
+ Args:
345
+ img: Input image to search for similar images.
346
+ num_max: Maximum number of similar images to return (default is 3).
347
+ extra_params: Additional parameters to include in the search query (default is None).
348
+
349
+ Returns:
350
+ List of tuples containing the ID and JSON data of similar images found in the database.
351
+ """
352
+ emb_query = self.get_img_features(
353
+ [img], to_numpy=True
354
+ ).astype(np.float32)[0].tobytes()
355
+ query = (
356
+ Query(
357
+ f"(*)=>[KNN {str(num_max)} @vector $query_vector AS vector_score]"
358
+ )
359
+ .sort_by("vector_score")
360
+ .return_fields("$")
361
+ .dialect(2)
362
+ )
363
+ if extra_params is None:
364
+ extra_params = {}
365
+ result_docs = (
366
+ self.db_conn.ft("idx:pic_idx")
367
+ .search(
368
+ query,
369
+ {
370
+ "query_vector": emb_query
371
+ }
372
+ | extra_params,
373
+ )
374
+ .docs
375
+ )
376
+ results = [
377
+ (result_doc.id, json.loads(result_doc.json))
378
+ for result_doc in result_docs
379
+ ]
380
+ return results
381
+
310
382
 
311
383
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: hjxdl
3
- Version: 0.1.69
3
+ Version: 0.1.70
4
4
  Summary: A collection of functions for Jupyter notebooks
5
5
  Home-page: https://github.com/huluxiaohuowa/hdl
6
6
  Author: Jianxing Hu
@@ -1,5 +1,5 @@
1
1
  hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
2
- hdl/_version.py,sha256=lVXaDORAcShTMS0Ix_T7zlcXAGst1lAMus_P8nq4GD4,413
2
+ hdl/_version.py,sha256=AYQNyn783xY6psbpqi7TYW2XK3FmuvjMGND0tFVkkvI,413
3
3
  hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
5
5
  hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -131,12 +131,12 @@ hdl/utils/llm/chat.py,sha256=sk7Lw5Oa30k-l2fnJknkMmTc5zkBeEKsR981aeFhH5s,11907
131
131
  hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
132
132
  hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
133
133
  hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
134
- hdl/utils/llm/vis.py,sha256=BygpjORSwbe_7SmThzt_OvzsrEym7FwCk6fubGyiDDc,10097
134
+ hdl/utils/llm/vis.py,sha256=8rzJMMmVpMibUvfvGfjwFwWE5xrhZM1pjkwbe5fgmNQ,12352
135
135
  hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
136
136
  hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
137
137
  hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
138
138
  hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
139
- hjxdl-0.1.69.dist-info/METADATA,sha256=LZ8Y6-VsUCMYG1AQ5n-GBzff1lwfa-hOev1wT79OULc,903
140
- hjxdl-0.1.69.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
141
- hjxdl-0.1.69.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
142
- hjxdl-0.1.69.dist-info/RECORD,,
139
+ hjxdl-0.1.70.dist-info/METADATA,sha256=A7kOeOkBy4YNzlalYPnyGE3yCkKoy8xcautLlDF0pkk,903
140
+ hjxdl-0.1.70.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
141
+ hjxdl-0.1.70.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
142
+ hjxdl-0.1.70.dist-info/RECORD,,
File without changes