hjxdl 0.1.69__py3-none-any.whl → 0.1.71__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/_version.py +2 -2
- hdl/utils/llm/vis.py +106 -11
- {hjxdl-0.1.69.dist-info → hjxdl-0.1.71.dist-info}/METADATA +1 -1
- {hjxdl-0.1.69.dist-info → hjxdl-0.1.71.dist-info}/RECORD +6 -6
- {hjxdl-0.1.69.dist-info → hjxdl-0.1.71.dist-info}/WHEEL +0 -0
- {hjxdl-0.1.69.dist-info → hjxdl-0.1.71.dist-info}/top_level.txt +0 -0
    
        hdl/_version.py
    CHANGED
    
    
    
        hdl/utils/llm/vis.py
    CHANGED
    
    | @@ -11,6 +11,8 @@ import open_clip | |
| 11 11 | 
             
            import natsort
         | 
| 12 12 | 
             
            from redis.commands.search.field import VectorField
         | 
| 13 13 | 
             
            from redis.commands.search.indexDefinition import IndexDefinition, IndexType
         | 
| 14 | 
            +
            from hdl.jupyfuncs.show.pbar import tqdm
         | 
| 15 | 
            +
            from redis.commands.search.query import Query
         | 
| 14 16 |  | 
| 15 17 | 
             
            from ..database_tools.connect import conn_redis
         | 
| 16 18 |  | 
| @@ -75,6 +77,7 @@ class ImgHandler: | |
| 75 77 | 
             
                    self.db_port = db_port
         | 
| 76 78 | 
             
                    self._db_conn = None
         | 
| 77 79 | 
             
                    self.num_vec_dim = num_vec_dim
         | 
| 80 | 
            +
                    self.pic_idx_name = "idx:pic_idx"
         | 
| 78 81 | 
             
                    if load_model:
         | 
| 79 82 | 
             
                        self.load_model()
         | 
| 80 83 |  | 
| @@ -142,10 +145,26 @@ class ImgHandler: | |
| 142 145 | 
             
                    Returns:
         | 
| 143 146 | 
             
                        torch.Tensor or numpy.ndarray: Image features extracted from the model.
         | 
| 144 147 | 
             
                    """
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    images_fixed = []
         | 
| 150 | 
            +
                    for img in images:
         | 
| 151 | 
            +
                        if isinstance(img, str):
         | 
| 152 | 
            +
                            if Path(Path(img).is_file()):
         | 
| 153 | 
            +
                                images_fixed.append(Image.open(img))
         | 
| 154 | 
            +
                            elif img.startswith("data:image/jpeg;base64,"):
         | 
| 155 | 
            +
                                images_fixed.append(imgbase64_to_pilimg(img))
         | 
| 156 | 
            +
                        elif isinstance(img, Image.Image):
         | 
| 157 | 
            +
                            images_fixed.append(img)
         | 
| 158 | 
            +
                        else:
         | 
| 159 | 
            +
                            raise TypeError(
         | 
| 160 | 
            +
                                f"Not supported image type for {type(img)}"
         | 
| 161 | 
            +
                            )
         | 
| 162 | 
            +
             | 
| 163 | 
            +
             | 
| 145 164 | 
             
                    with torch.no_grad(), torch.amp.autocast("cuda"):
         | 
| 146 165 | 
             
                        imgs = torch.stack([
         | 
| 147 | 
            -
                            self.preprocess_val( | 
| 148 | 
            -
                            for image in  | 
| 166 | 
            +
                            self.preprocess_val(image).to(self.device)
         | 
| 167 | 
            +
                            for image in images_fixed
         | 
| 149 168 | 
             
                        ])
         | 
| 150 169 | 
             
                        img_features = self.model.encode_image(imgs, **kwargs)
         | 
| 151 170 | 
             
                        img_features /= img_features.norm(dim=-1, keepdim=True)
         | 
| @@ -242,7 +261,11 @@ class ImgHandler: | |
| 242 261 | 
             
                        # > 0.9 很相似
         | 
| 243 262 | 
             
                        return sims
         | 
| 244 263 |  | 
| 245 | 
            -
                def vec_pics_todb( | 
| 264 | 
            +
                def vec_pics_todb(
         | 
| 265 | 
            +
                    self,
         | 
| 266 | 
            +
                    images: list[str],
         | 
| 267 | 
            +
                    print_idx_info: bool = False
         | 
| 268 | 
            +
                ):
         | 
| 246 269 | 
             
                    """Save image features to a Redis database.
         | 
| 247 270 |  | 
| 248 271 | 
             
                    Args:
         | 
| @@ -257,7 +280,7 @@ class ImgHandler: | |
| 257 280 | 
             
                    sorted_imgs = natsort.natsorted(images)
         | 
| 258 281 | 
             
                    img_feats = self.get_img_features(sorted_imgs, to_numpy=True)
         | 
| 259 282 | 
             
                    pipeline = self.db_conn.pipeline()
         | 
| 260 | 
            -
                    for img_file, emb in zip(sorted_imgs, img_feats):
         | 
| 283 | 
            +
                    for img_file, emb in tqdm(zip(sorted_imgs, img_feats)):
         | 
| 261 284 | 
             
                        # 初始化 Redis,先使用 img 文件名作为 Key 和 Value,后续再更新为图片特征向量
         | 
| 262 285 | 
             
                        # pipeline.json().set(img_file, "$", img_file)
         | 
| 263 286 | 
             
                        emb = emb.astype(np.float32).tolist()
         | 
| @@ -281,31 +304,103 @@ class ImgHandler: | |
| 281 304 | 
             
                            as_name="vector",  # 给这个字段定义一个别名,后续可以使用
         | 
| 282 305 | 
             
                        ),
         | 
| 283 306 | 
             
                    )
         | 
| 284 | 
            -
                    vector_idx_name = "idx:pic_idx"
         | 
| 307 | 
            +
                    # vector_idx_name = "idx:pic_idx"
         | 
| 285 308 | 
             
                    definition = IndexDefinition(
         | 
| 286 309 | 
             
                        prefix=["pic-"],
         | 
| 287 310 | 
             
                        index_type=IndexType.JSON
         | 
| 288 311 | 
             
                    )
         | 
| 289 312 | 
             
                    res = self.db_conn.ft(
         | 
| 290 | 
            -
                         | 
| 313 | 
            +
                        self.pic_idx_name
         | 
| 291 314 | 
             
                    ).create_index(
         | 
| 292 315 | 
             
                        fields=schema,
         | 
| 293 316 | 
             
                        definition=definition
         | 
| 294 317 | 
             
                    )
         | 
| 295 318 | 
             
                    print("create_index:", res)
         | 
| 319 | 
            +
                    if print_idx_info:
         | 
| 320 | 
            +
                        print(self.pic_idx_info)
         | 
| 296 321 |  | 
| 297 322 | 
             
                @property
         | 
| 298 323 | 
             
                def pic_idx_info(self):
         | 
| 299 324 | 
             
                    res = self.db_conn.ping()
         | 
| 300 325 | 
             
                    print("redis connected:", res)
         | 
| 301 | 
            -
             | 
| 302 | 
            -
                    vector_idx_name = "idx:pic_idx"
         | 
| 303 | 
            -
             | 
| 326 | 
            +
                    # vector_idx_name = "idx:pic_idx"
         | 
| 304 327 | 
             
                    # 从 Redis 数据库中读取索引状态
         | 
| 305 | 
            -
                    info = self.db_conn.ft( | 
| 328 | 
            +
                    info = self.db_conn.ft(self.pic_idx_name).info()
         | 
| 306 329 | 
             
                    # 获取索引状态中的 num_docs 和 hash_indexing_failures
         | 
| 307 330 | 
             
                    num_docs = info["num_docs"]
         | 
| 308 331 | 
             
                    indexing_failures = info["hash_indexing_failures"]
         | 
| 309 | 
            -
                     | 
| 332 | 
            +
                    return (
         | 
| 333 | 
            +
                        f"{num_docs} documents indexed with {indexing_failures} failures"
         | 
| 334 | 
            +
                    )
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                def emb_search(
         | 
| 337 | 
            +
                    self,
         | 
| 338 | 
            +
                    emb_query,
         | 
| 339 | 
            +
                    num_max: int = 3,
         | 
| 340 | 
            +
                    extra_params: dict = None,
         | 
| 341 | 
            +
                ):
         | 
| 342 | 
            +
                    """Search for similar embeddings in the database.
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                    Args:
         | 
| 345 | 
            +
                        emb_query (str): The embedding query to search for.
         | 
| 346 | 
            +
                        num_max (int, optional): The maximum number of results to return. Defaults to 3.
         | 
| 347 | 
            +
                        extra_params (dict, optional): Extra parameters to include in the search query. Defaults to None.
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                    Returns:
         | 
| 350 | 
            +
                        list: A list of tuples containing the document ID and JSON data for each result.
         | 
| 351 | 
            +
                    """
         | 
| 352 | 
            +
                    query = (
         | 
| 353 | 
            +
                        Query(
         | 
| 354 | 
            +
                            f"(*)=>[KNN {str(num_max)} @vector $query_vector AS vector_score]"
         | 
| 355 | 
            +
                        )
         | 
| 356 | 
            +
                        .sort_by("vector_score")
         | 
| 357 | 
            +
                        .return_fields("$")
         | 
| 358 | 
            +
                        .dialect(2)
         | 
| 359 | 
            +
                    )
         | 
| 360 | 
            +
                    if extra_params is None:
         | 
| 361 | 
            +
                        extra_params = {}
         | 
| 362 | 
            +
                    result_docs = (
         | 
| 363 | 
            +
                        self.db_conn.ft("idx:pic_idx")
         | 
| 364 | 
            +
                        .search(
         | 
| 365 | 
            +
                            query,
         | 
| 366 | 
            +
                            {
         | 
| 367 | 
            +
                                "query_vector": emb_query
         | 
| 368 | 
            +
                            }
         | 
| 369 | 
            +
                            | extra_params,
         | 
| 370 | 
            +
                        )
         | 
| 371 | 
            +
                        .docs
         | 
| 372 | 
            +
                    )
         | 
| 373 | 
            +
                    results = [
         | 
| 374 | 
            +
                        (result_doc.id, json.loads(result_doc.json))
         | 
| 375 | 
            +
                        for result_doc in result_docs
         | 
| 376 | 
            +
                    ]
         | 
| 377 | 
            +
                    return results
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                def img_search(
         | 
| 380 | 
            +
                    self,
         | 
| 381 | 
            +
                    img,
         | 
| 382 | 
            +
                    num_max: int = 3,
         | 
| 383 | 
            +
                    extra_params: dict = None
         | 
| 384 | 
            +
                ):
         | 
| 385 | 
            +
                    """Search for similar images in the database based on the input image.
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                    Args:
         | 
| 388 | 
            +
                        img: Input image to search for similar images.
         | 
| 389 | 
            +
                        num_max: Maximum number of similar images to return (default is 3).
         | 
| 390 | 
            +
                        extra_params: Additional parameters to include in the search query (default is None).
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                    Returns:
         | 
| 393 | 
            +
                        List of tuples containing the ID and JSON data of similar images found in the database.
         | 
| 394 | 
            +
                    """
         | 
| 395 | 
            +
                    emb_query = self.get_img_features(
         | 
| 396 | 
            +
                        [img], to_numpy=True
         | 
| 397 | 
            +
                    ).astype(np.float32)[0].tobytes()
         | 
| 398 | 
            +
                    results = self.emb_search(
         | 
| 399 | 
            +
                        emb_query=emb_query,
         | 
| 400 | 
            +
                        num_max=num_max,
         | 
| 401 | 
            +
                        extra_params=extra_params
         | 
| 402 | 
            +
                    )
         | 
| 403 | 
            +
                    return results
         | 
| 404 | 
            +
             | 
| 310 405 |  | 
| 311 406 |  | 
| @@ -1,5 +1,5 @@ | |
| 1 1 | 
             
            hdl/__init__.py,sha256=GffnD0jLJdhkd-vo989v40N90sQbofkayRBwxc6TVhQ,72
         | 
| 2 | 
            -
            hdl/_version.py,sha256= | 
| 2 | 
            +
            hdl/_version.py,sha256=Ov_2Aea_x3wREaTTkbdMMjDZcnW666PWCOUw0Yjj1mA,413
         | 
| 3 3 | 
             
            hdl/args/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 4 4 | 
             
            hdl/args/loss_args.py,sha256=s7YzSdd7IjD24rZvvOrxLLFqMZQb9YylxKeyelSdrTk,70
         | 
| 5 5 | 
             
            hdl/controllers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| @@ -131,12 +131,12 @@ hdl/utils/llm/chat.py,sha256=sk7Lw5Oa30k-l2fnJknkMmTc5zkBeEKsR981aeFhH5s,11907 | |
| 131 131 | 
             
            hdl/utils/llm/embs.py,sha256=Tf0FOYrOFZp7qQpEPiSCXzlgyHH0X9HVTUtsup74a9E,7174
         | 
| 132 132 | 
             
            hdl/utils/llm/extract.py,sha256=2sK_WJzmYIc8iuWaM9DA6Nw3_6q1O4lJ5pKpcZo-bBA,6512
         | 
| 133 133 | 
             
            hdl/utils/llm/llama_chat.py,sha256=watcHGOaz-bv3x-yDucYlGk5f8FiqfFhwWogrl334fk,4387
         | 
| 134 | 
            -
            hdl/utils/llm/vis.py,sha256= | 
| 134 | 
            +
            hdl/utils/llm/vis.py,sha256=_xHEZ2_sRp7UiOWaI_6v96rBZwiQ7JkUF-DX5OHC7HE,13105
         | 
| 135 135 | 
             
            hdl/utils/schedulers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 136 136 | 
             
            hdl/utils/schedulers/norm_lr.py,sha256=bDwCmdEK-WkgxQMFBiMuchv8Mm7C0-GZJ6usm-PQk14,4461
         | 
| 137 137 | 
             
            hdl/utils/weather/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 138 138 | 
             
            hdl/utils/weather/weather.py,sha256=k11o6wM15kF8b9NMlEfrg68ak-SfSYLN3nOOflFUv-I,4381
         | 
| 139 | 
            -
            hjxdl-0.1. | 
| 140 | 
            -
            hjxdl-0.1. | 
| 141 | 
            -
            hjxdl-0.1. | 
| 142 | 
            -
            hjxdl-0.1. | 
| 139 | 
            +
            hjxdl-0.1.71.dist-info/METADATA,sha256=jAEXEO3CVhotWgJjClsaKFlL59YBd2r6iotdUxo2o3g,903
         | 
| 140 | 
            +
            hjxdl-0.1.71.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
         | 
| 141 | 
            +
            hjxdl-0.1.71.dist-info/top_level.txt,sha256=-kxwTM5JPhylp06z3zAVO3w6_h7wtBfBo2zgM6YZoTk,4
         | 
| 142 | 
            +
            hjxdl-0.1.71.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         |