vector-index-embedding 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vector_index_embedding-0.1.0.dist-info/METADATA +28 -0
- vector_index_embedding-0.1.0.dist-info/RECORD +7 -0
- vector_index_embedding-0.1.0.dist-info/WHEEL +5 -0
- vector_index_embedding-0.1.0.dist-info/top_level.txt +1 -0
- vectorindex/__init__.py +2 -0
- vectorindex/flat_index.py +16 -0
- vectorindex/vector_index.py +102 -0
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: vector-index-embedding
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: HNSW based output embeddings for LLM's
|
|
5
|
+
Author-email: Martin Loretz <pypi@martinloretz.com>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Classifier: Programming Language :: Python :: 3
|
|
8
|
+
Classifier: Operating System :: OS Independent
|
|
9
|
+
Requires-Python: >=3.9
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
Requires-Dist: torch
|
|
12
|
+
Requires-Dist: hnswlib
|
|
13
|
+
Requires-Dist: huggingface_hub
|
|
14
|
+
|
|
15
|
+
# vector-index-embedding
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
## Faster hnswlib
|
|
21
|
+
|
|
22
|
+
Our faster implementation can be found here: https://github.com/martinloretzzz/hnswlib
|
|
23
|
+
|
|
24
|
+
Warning: This implementation might not work on all systems, as it was only tested on the one where we're running the benchmarks, and the SIMD implementation was only adapted for that architecture.
|
|
25
|
+
|
|
26
|
+
For benchmarking we use our own fork of hnswlib, that includes 2 improvements for fast inner product distances on high dimensional data:
|
|
27
|
+
- We calculate all the inner products in parallel, which cuts memory accesses in half (we load one element of the query and compare it to N other vectors at the same time).
|
|
28
|
+
- We removed a heuristic that restricted multi-threading, as our data is extremely high-dimensional and always benefits from using all cores.
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
vectorindex/__init__.py,sha256=QcAEBigbm7RZxX8u4F6hrB85tuS77IMCrbBgJ-6dOtE,118
|
|
2
|
+
vectorindex/flat_index.py,sha256=ouTpOBhuT28pOpQXLbyRzMYdAPu2SXy-h0IOfIATe4g,588
|
|
3
|
+
vectorindex/vector_index.py,sha256=e9ViG9N6sEbbLXkbS-BHxqrEy7d_pzcgzMksTnpa67c,4725
|
|
4
|
+
vector_index_embedding-0.1.0.dist-info/METADATA,sha256=sUGd7MZAoS3878UvHZMdTmP5j1QTWuA9tQbBCBwJICc,1215
|
|
5
|
+
vector_index_embedding-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
6
|
+
vector_index_embedding-0.1.0.dist-info/top_level.txt,sha256=3UTzrAMuIGn-k6iQ_K6X2x51t8eNAp6VKHD3-VyOCbM,12
|
|
7
|
+
vector_index_embedding-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
vectorindex
|
vectorindex/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
class FlatIndexEmbedding(nn.Module):
    """Exact top-k output layer.

    Runs the wrapped linear layer and keeps only the ``k`` largest logits
    along the last dimension; every other position is set to ``-inf``.

    Args:
        layer: Linear layer that produces the full logits.
        k: Number of logits to keep per row.
    """

    def __init__(self, layer: nn.Linear, k: int):
        super().__init__()
        self.layer, self.k = layer, k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the layer's logits with all non-top-k entries set to ``-inf``."""
        x = self.layer(x)
        return self.mask_non_topk(x, k=self.k, fill=float("-inf"))

    def mask_non_topk(self, x: torch.Tensor, k: int, fill: float) -> torch.Tensor:
        """Return a copy of ``x`` where everything outside the top ``k`` is ``fill``.

        Args:
            x: Tensor whose last dimension is ranked.
            k: Number of entries to keep per row. Values larger than the size
                of the last dimension keep every entry.
            fill: Value written to all positions outside the top ``k``.

        Returns:
            A new tensor with the same shape, dtype and device as ``x``.
        """
        # torch.topk raises when k exceeds the dimension size; asking for more
        # entries than exist simply means "keep everything".
        k = min(k, x.shape[-1])
        _, indices = torch.topk(x, k, dim=-1)
        # full_like already inherits dtype and device from x.
        mask = torch.full_like(x, fill)
        return mask.scatter(-1, indices, x.gather(-1, indices))
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import RemoteEntryNotFoundError
import torch
import torch.nn as nn
import hnswlib
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional
import json
|
|
9
|
+
|
|
10
|
+
@dataclass
class VectorIndexEmbeddingConfig:
    """Settings stored next to an index file as JSON.

    ``Optional[...]`` is used instead of ``X | None`` because dataclass field
    annotations are evaluated when the class is created, and the ``|`` form
    raises ``TypeError`` on Python 3.9, which the package declares as supported.
    """

    model_name: str  # used as the file stem when the index is saved
    k: int  # number of neighbours requested per query
    ef: int  # value passed to hnswlib's set_ef
    M: int  # value passed to hnswlib's init_index
    ef_construction: int  # value passed to hnswlib's init_index
    special_tokens: Optional[list[int]] = None  # token ids that are always scored exactly
    dim: int = -1  # set from weight in build_index
    vocab_size: int = -1  # set from weight in build_index
    model_id: Optional[str] = None
|
|
21
|
+
|
|
22
|
+
class VectorIndexEmbedding(nn.Module):
|
|
23
|
+
def __init__(self, config: VectorIndexEmbeddingConfig, index_path: str):
|
|
24
|
+
super().__init__()
|
|
25
|
+
self.config = config
|
|
26
|
+
self.index = hnswlib.Index(space='ip', dim=int(config.dim))
|
|
27
|
+
self.index.load_index(index_path)
|
|
28
|
+
self.index.set_ef(config.ef)
|
|
29
|
+
self.num_threads = -1
|
|
30
|
+
|
|
31
|
+
if config.special_tokens is not None:
|
|
32
|
+
self.special_token_indices = torch.tensor(config.special_tokens, dtype=torch.long)
|
|
33
|
+
self.special_token_weight = torch.from_numpy(self.index.get_items(config.special_tokens, return_type="numpy"))
|
|
34
|
+
|
|
35
|
+
@torch.compiler.disable
|
|
36
|
+
def topk(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
37
|
+
indices, distances = self.index.knn_query(x.detach().float().numpy(), k=self.config.k, num_threads=self.num_threads)
|
|
38
|
+
return 1.0 - torch.from_numpy(distances), torch.from_numpy(indices).to(torch.int64)
|
|
39
|
+
|
|
40
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
41
|
+
# return torch.full((x.shape[0], x.shape[1], self.config.vocab_size), 0, dtype=x.dtype, device=x.device)
|
|
42
|
+
|
|
43
|
+
x_flat = x.view(-1, x.shape[-1]).float()
|
|
44
|
+
distances, indices = self.topk(x_flat.cpu())
|
|
45
|
+
|
|
46
|
+
logits = torch.full((x_flat.shape[0], self.config.vocab_size), float("-inf"), dtype=x.dtype, device=x.device)
|
|
47
|
+
logits.scatter_(-1, indices.to(x.device), distances.to(x.device).to(x.dtype))
|
|
48
|
+
|
|
49
|
+
if self.config.special_tokens is not None:
|
|
50
|
+
special_token_distances = torch.matmul(x_flat, self.special_token_weight.to(x.device).T).to(x.dtype)
|
|
51
|
+
special_token_indices = self.special_token_indices.to(x.device).unsqueeze(0).expand(x_flat.shape[0], -1)
|
|
52
|
+
logits.scatter_(-1, special_token_indices, special_token_distances)
|
|
53
|
+
|
|
54
|
+
return logits.view((x.shape[0], x.shape[1], self.config.vocab_size))
|
|
55
|
+
|
|
56
|
+
def set_ef(self, ef: int):
|
|
57
|
+
self.index.set_ef(ef)
|
|
58
|
+
|
|
59
|
+
@staticmethod
|
|
60
|
+
def from_pretrained(model_id: str, ef = None, k = None, repo_id = "martinloretzzz/vector-index-embedding") -> "VectorIndexEmbedding":
|
|
61
|
+
try:
|
|
62
|
+
index_name = VectorIndexEmbedding.get_index_name(model_id)
|
|
63
|
+
local_path = hf_hub_download(repo_id=repo_id, filename=f"{index_name}.index")
|
|
64
|
+
local_path_config = hf_hub_download(repo_id=repo_id, filename=f"{index_name}.json")
|
|
65
|
+
return VectorIndexEmbedding.from_file(local_path, ef=ef, k=k, config_path=local_path_config)
|
|
66
|
+
except RemoteEntryNotFoundError:
|
|
67
|
+
raise Exception(f"No prebuilt vector index for model '{model_id}' was found. To build your own index, please follow the guide at: https://github.com/martinloretzzz/vector-index-embedding")
|
|
68
|
+
|
|
69
|
+
@staticmethod
|
|
70
|
+
def from_file(path: str, ef = None, k = None, config_path: str | None = None) -> "VectorIndexEmbedding":
|
|
71
|
+
index_path = Path(path)
|
|
72
|
+
if config_path is None:
|
|
73
|
+
config_path = index_path.resolve().with_suffix(".json")
|
|
74
|
+
with open(config_path, "r") as f:
|
|
75
|
+
config = VectorIndexEmbeddingConfig(**json.load(f))
|
|
76
|
+
|
|
77
|
+
if k is not None: config.k = k
|
|
78
|
+
if ef is not None: config.ef = ef
|
|
79
|
+
|
|
80
|
+
return VectorIndexEmbedding(config, index_path=str(index_path.absolute()))
|
|
81
|
+
|
|
82
|
+
@staticmethod
|
|
83
|
+
def build_index(weight: torch.Tensor, config: VectorIndexEmbeddingConfig, save_path: str = "data", seed=42):
|
|
84
|
+
config.vocab_size, config.dim = weight.shape
|
|
85
|
+
index_file = Path(save_path) / Path(f"{config.model_name}.index")
|
|
86
|
+
index_file.parent.mkdir(parents=True, exist_ok=True)
|
|
87
|
+
|
|
88
|
+
index = hnswlib.Index(space='ip', dim=config.dim)
|
|
89
|
+
index.init_index(max_elements=config.vocab_size, M=config.M, ef_construction=config.ef_construction, random_seed=seed)
|
|
90
|
+
index.set_ef(config.ef)
|
|
91
|
+
index.add_items(weight.cpu().numpy())
|
|
92
|
+
index.save_index(str(index_file))
|
|
93
|
+
|
|
94
|
+
with open(index_file.resolve().with_suffix(".json"), "w") as f:
|
|
95
|
+
json.dump(asdict(config), f, indent=4)
|
|
96
|
+
|
|
97
|
+
print(f"Index saved to {index_file}")
|
|
98
|
+
return str(index_file)
|
|
99
|
+
|
|
100
|
+
@staticmethod
|
|
101
|
+
def get_index_name(hf_model_id: str):
|
|
102
|
+
return hf_model_id.lower().replace("/", "-").replace(".", "-")
|