hjxdl 0.1.68__py3-none-any.whl → 0.1.70__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/_version.py +2 -2
- hdl/utils/llm/vis.py +119 -6
- {hjxdl-0.1.68.dist-info → hjxdl-0.1.70.dist-info}/METADATA +1 -1
- {hjxdl-0.1.68.dist-info → hjxdl-0.1.70.dist-info}/RECORD +6 -6
- {hjxdl-0.1.68.dist-info → hjxdl-0.1.70.dist-info}/WHEEL +0 -0
- {hjxdl-0.1.68.dist-info → hjxdl-0.1.70.dist-info}/top_level.txt +0 -0
hdl/_version.py
CHANGED
hdl/utils/llm/vis.py
CHANGED
@@ -9,6 +9,10 @@ from PIL import Image
|
|
9
9
|
# from transformers import ChineseCLIPProcessor, ChineseCLIPModel
|
10
10
|
import open_clip
|
11
11
|
import natsort
|
12
|
+
from redis.commands.search.field import VectorField
|
13
|
+
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
14
|
+
from hdl.jupyfuncs.show.pbar import tqdm
|
15
|
+
from redis.commands.search.query import Query
|
12
16
|
|
13
17
|
from ..database_tools.connect import conn_redis
|
14
18
|
|
@@ -73,6 +77,7 @@ class ImgHandler:
|
|
73
77
|
self.db_port = db_port
|
74
78
|
self._db_conn = None
|
75
79
|
self.num_vec_dim = num_vec_dim
|
80
|
+
self.pic_idx_name = "idx:pic_idx"
|
76
81
|
if load_model:
|
77
82
|
self.load_model()
|
78
83
|
|
@@ -140,10 +145,26 @@ class ImgHandler:
|
|
140
145
|
Returns:
|
141
146
|
torch.Tensor or numpy.ndarray: Image features extracted from the model.
|
142
147
|
"""
|
148
|
+
|
149
|
+
images_fixed = []
|
150
|
+
for img in images:
|
151
|
+
if isinstance(img, str):
|
152
|
+
if Path(Path(img).is_file()):
|
153
|
+
images_fixed.append(Image.open(img))
|
154
|
+
elif img.startswith("data:image/jpeg;base64,"):
|
155
|
+
images_fixed.append(imgbase64_to_pilimg(img))
|
156
|
+
elif isinstance(img, Image.Image):
|
157
|
+
images_fixed.append(img)
|
158
|
+
else:
|
159
|
+
raise TypeError(
|
160
|
+
f"Not supported image type for {type(img)}"
|
161
|
+
)
|
162
|
+
|
163
|
+
|
143
164
|
with torch.no_grad(), torch.amp.autocast("cuda"):
|
144
165
|
imgs = torch.stack([
|
145
|
-
self.preprocess_val(
|
146
|
-
for image in
|
166
|
+
self.preprocess_val(image).to(self.device)
|
167
|
+
for image in images_fixed
|
147
168
|
])
|
148
169
|
img_features = self.model.encode_image(imgs, **kwargs)
|
149
170
|
img_features /= img_features.norm(dim=-1, keepdim=True)
|
@@ -240,7 +261,11 @@ class ImgHandler:
|
|
240
261
|
# > 0.9 很相似
|
241
262
|
return sims
|
242
263
|
|
243
|
-
def vec_pics_todb(
|
264
|
+
def vec_pics_todb(
|
265
|
+
self,
|
266
|
+
images: list[str],
|
267
|
+
print_idx_info: bool = False
|
268
|
+
):
|
244
269
|
"""Save image features to a Redis database.
|
245
270
|
|
246
271
|
Args:
|
@@ -255,7 +280,7 @@ class ImgHandler:
|
|
255
280
|
sorted_imgs = natsort.natsorted(images)
|
256
281
|
img_feats = self.get_img_features(sorted_imgs, to_numpy=True)
|
257
282
|
pipeline = self.db_conn.pipeline()
|
258
|
-
for img_file, emb in zip(sorted_imgs, img_feats):
|
283
|
+
for img_file, emb in tqdm(zip(sorted_imgs, img_feats)):
|
259
284
|
# 初始化 Redis,先使用 img 文件名作为 Key 和 Value,后续再更新为图片特征向量
|
260
285
|
# pipeline.json().set(img_file, "$", img_file)
|
261
286
|
emb = emb.astype(np.float32).tolist()
|
@@ -263,8 +288,96 @@ class ImgHandler:
|
|
263
288
|
"emb": emb,
|
264
289
|
"data": imgfile_to_base64(img_file)
|
265
290
|
}
|
266
|
-
pipeline.json().set(img_file, "$", emb_json)
|
291
|
+
pipeline.json().set(f"pic-{img_file}", "$", emb_json)
|
267
292
|
res = pipeline.execute()
|
268
|
-
print('redis set:', res)
|
293
|
+
# print('redis set:', res)
|
294
|
+
|
295
|
+
schema = (
|
296
|
+
VectorField(
|
297
|
+
"$.emb", # 这是 JSON 中存储向量的路径
|
298
|
+
"FLAT", # 使用 FLAT 索引类型
|
299
|
+
{
|
300
|
+
"TYPE": "FLOAT32", # 向量类型
|
301
|
+
"DIM": self.num_vec_dim, # 向量维度,必须与实际数据的维度一致
|
302
|
+
"DISTANCE_METRIC": "COSINE", # 余弦相似度作为距离度量
|
303
|
+
},
|
304
|
+
as_name="vector", # 给这个字段定义一个别名,后续可以使用
|
305
|
+
),
|
306
|
+
)
|
307
|
+
# vector_idx_name = "idx:pic_idx"
|
308
|
+
definition = IndexDefinition(
|
309
|
+
prefix=["pic-"],
|
310
|
+
index_type=IndexType.JSON
|
311
|
+
)
|
312
|
+
res = self.db_conn.ft(
|
313
|
+
self.pic_idx_name
|
314
|
+
).create_index(
|
315
|
+
fields=schema,
|
316
|
+
definition=definition
|
317
|
+
)
|
318
|
+
print("create_index:", res)
|
319
|
+
if print_idx_info:
|
320
|
+
print(self.pic_idx_info)
|
321
|
+
|
322
|
+
@property
def pic_idx_info(self):
    """Summarise the state of the image vector index as a sentence.

    First pings the Redis connection and prints the ping result, then
    reads the index info for ``self.pic_idx_name``.

    Returns:
        str: ``"<num_docs> documents indexed with <failures> failures"``,
        built from the ``num_docs`` and ``hash_indexing_failures`` entries
        of the index info.
    """
    print("redis connected:", self.db_conn.ping())
    stats = self.db_conn.ft(self.pic_idx_name).info()
    return "{} documents indexed with {} failures".format(
        stats["num_docs"],
        stats["hash_indexing_failures"],
    )
|
335
|
+
|
336
|
+
def img_search(
    self,
    img,
    num_max: int = 3,
    extra_params: dict = None
):
    """Search the Redis vector index for images similar to ``img``.

    The query image is embedded with ``get_img_features`` and a KNN query
    is run against the field aliased ``vector`` in the index named
    ``self.pic_idx_name``, which is the same index that ``vec_pics_todb``
    creates. Results are sorted by ``vector_score``.

    Args:
        img: Query image, in any form accepted by ``get_img_features``.
        num_max: Number of nearest neighbours (the K of the KNN clause)
            to return. Defaults to 3.
        extra_params: Extra query parameters merged into the parameters
            passed to ``search``. On a key collision these override
            ``query_vector``. Defaults to None (no extra parameters).

    Returns:
        list[tuple]: ``(key, document)`` pairs in ascending
        ``vector_score`` order. ``key`` is the Redis key of the match and
        ``document`` is the decoded JSON stored under it.
    """
    # The index declares FLOAT32 vectors, so send float32 bytes.
    emb_query = self.get_img_features(
        [img], to_numpy=True
    ).astype(np.float32)[0].tobytes()

    query = (
        Query(f"(*)=>[KNN {num_max} @vector $query_vector AS vector_score]")
        .sort_by("vector_score")
        .return_fields("$")
        .dialect(2)  # KNN vector syntax needs query dialect 2
    )

    if extra_params is None:
        extra_params = {}
    params = {"query_vector": emb_query} | extra_params

    # Use the configured index name rather than a hard-coded literal, so
    # searches always target the index this instance creates and inspects.
    result_docs = (
        self.db_conn.ft(self.pic_idx_name)
        .search(query, params)
        .docs
    )

    # The "$" return field arrives as JSON text on ``doc.json``; decode it.
    return [
        (result_doc.id, json.loads(result_doc.json))
        for result_doc in result_docs
    ]
|
381
|
+
|
269
382
|
|
270
383
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
|
2
|
-
hdl/_version.py,sha256=
|
2
|
+
hdl/_version.py,sha256=AYQNyn783xY6psbpqi7TYW2XK3FmuvjMGND0tFVkkvI,413
|
3
3
|
hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
|
5
5
|
hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -131,12 +131,12 @@ hdl/utils/llm/chat.py,sha256=sk7Lw5Oa30k-l2fnJknkMmTc5zkBeEKsR981aeFhH5s,11907
|
|
131
131
|
hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
|
132
132
|
hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
|
133
133
|
hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
|
134
|
-
hdl/utils/llm/vis.py,sha256=
|
134
|
+
hdl/utils/llm/vis.py,sha256=8rzJMMmVpMibUvfvGfjwFwWE5xrhZM1pjkwbe5fgmNQ,12352
|
135
135
|
hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
136
136
|
hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
|
137
137
|
hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
138
138
|
hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
|
139
|
-
hjxdl-0.1.
|
140
|
-
hjxdl-0.1.
|
141
|
-
hjxdl-0.1.
|
142
|
-
hjxdl-0.1.
|
139
|
+
hjxdl-0.1.70.dist-info/METADATA,sha256=A7kOeOkBy4YNzlalYPnyGE3yCkKoy8xcautLlDF0pkk,903
|
140
|
+
hjxdl-0.1.70.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
141
|
+
hjxdl-0.1.70.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
|
142
|
+
hjxdl-0.1.70.dist-info/RECORD,,
|
File without changes
|
File without changes
|