torch-rechub 0.0.6__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/layers.py +228 -159
- torch_rechub/basic/loss_func.py +62 -47
- torch_rechub/data/dataset.py +18 -31
- torch_rechub/models/generative/hstu.py +48 -33
- torch_rechub/serving/__init__.py +50 -0
- torch_rechub/serving/annoy.py +133 -0
- torch_rechub/serving/base.py +107 -0
- torch_rechub/serving/faiss.py +154 -0
- torch_rechub/serving/milvus.py +215 -0
- torch_rechub/trainers/ctr_trainer.py +12 -2
- torch_rechub/trainers/match_trainer.py +13 -2
- torch_rechub/trainers/mtl_trainer.py +12 -2
- torch_rechub/trainers/seq_trainer.py +34 -15
- torch_rechub/types.py +5 -0
- torch_rechub/utils/data.py +191 -145
- torch_rechub/utils/hstu_utils.py +87 -76
- torch_rechub/utils/model_utils.py +10 -12
- torch_rechub/utils/onnx_export.py +98 -45
- torch_rechub/utils/quantization.py +128 -0
- torch_rechub/utils/visualization.py +4 -12
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/METADATA +34 -18
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/RECORD +24 -18
- torch_rechub/trainers/matching.md +0 -3
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/licenses/LICENSE +0 -0
torch_rechub/basic/loss_func.py
CHANGED
|
@@ -4,13 +4,24 @@ import torch.nn as nn
|
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class RegularizationLoss(nn.Module):
|
|
7
|
-
"""Unified L1/L2
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
7
|
+
"""Unified L1/L2 regularization for embedding and dense parameters.
|
|
8
|
+
|
|
9
|
+
Parameters
|
|
10
|
+
----------
|
|
11
|
+
embedding_l1 : float, default=0.0
|
|
12
|
+
L1 coefficient for embedding parameters.
|
|
13
|
+
embedding_l2 : float, default=0.0
|
|
14
|
+
L2 coefficient for embedding parameters.
|
|
15
|
+
dense_l1 : float, default=0.0
|
|
16
|
+
L1 coefficient for dense (non-embedding) parameters.
|
|
17
|
+
dense_l2 : float, default=0.0
|
|
18
|
+
L2 coefficient for dense (non-embedding) parameters.
|
|
19
|
+
|
|
20
|
+
Examples
|
|
21
|
+
--------
|
|
22
|
+
>>> reg_loss_fn = RegularizationLoss(embedding_l2=1e-5, dense_l2=1e-5)
|
|
23
|
+
>>> reg_loss = reg_loss_fn(model)
|
|
24
|
+
>>> total_loss = task_loss + reg_loss
|
|
14
25
|
"""
|
|
15
26
|
|
|
16
27
|
def __init__(self, embedding_l1=0.0, embedding_l2=0.0, dense_l1=0.0, dense_l2=0.0):
|
|
@@ -58,9 +69,11 @@ class RegularizationLoss(nn.Module):
|
|
|
58
69
|
|
|
59
70
|
|
|
60
71
|
class HingeLoss(torch.nn.Module):
|
|
61
|
-
"""Hinge
|
|
62
|
-
reference: https://github.com/ustcml/RecStudio/blob/main/recstudio/model/loss_func.py
|
|
72
|
+
"""Hinge loss for pairwise learning.
|
|
63
73
|
|
|
74
|
+
Notes
|
|
75
|
+
-----
|
|
76
|
+
Reference: https://github.com/ustcml/RecStudio/blob/main/recstudio/model/loss_func.py
|
|
64
77
|
"""
|
|
65
78
|
|
|
66
79
|
def __init__(self, margin=2, num_items=None):
|
|
@@ -89,27 +102,28 @@ class BPRLoss(torch.nn.Module):
|
|
|
89
102
|
|
|
90
103
|
|
|
91
104
|
class NCELoss(torch.nn.Module):
|
|
92
|
-
"""Noise Contrastive Estimation (NCE)
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
105
|
+
"""Noise Contrastive Estimation (NCE) loss for recommender systems.
|
|
106
|
+
|
|
107
|
+
Parameters
|
|
108
|
+
----------
|
|
109
|
+
temperature : float, default=1.0
|
|
110
|
+
Temperature for scaling logits.
|
|
111
|
+
ignore_index : int, default=0
|
|
112
|
+
Target index to ignore.
|
|
113
|
+
reduction : {'mean', 'sum', 'none'}, default='mean'
|
|
114
|
+
Reduction applied to the output.
|
|
115
|
+
|
|
116
|
+
Notes
|
|
117
|
+
-----
|
|
118
|
+
- Gutmann & Hyvärinen (2010), Noise-contrastive estimation.
|
|
119
|
+
- HLLM: Hierarchical Large Language Model for Recommendation.
|
|
120
|
+
|
|
121
|
+
Examples
|
|
122
|
+
--------
|
|
123
|
+
>>> nce_loss = NCELoss(temperature=0.1)
|
|
124
|
+
>>> logits = torch.randn(32, 1000)
|
|
125
|
+
>>> targets = torch.randint(0, 1000, (32,))
|
|
126
|
+
>>> loss = nce_loss(logits, targets)
|
|
113
127
|
"""
|
|
114
128
|
|
|
115
129
|
def __init__(self, temperature=1.0, ignore_index=0, reduction='mean'):
|
|
@@ -158,23 +172,24 @@ class NCELoss(torch.nn.Module):
|
|
|
158
172
|
|
|
159
173
|
|
|
160
174
|
class InBatchNCELoss(torch.nn.Module):
|
|
161
|
-
"""In-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
175
|
+
"""In-batch NCE loss with explicit negatives.
|
|
176
|
+
|
|
177
|
+
Parameters
|
|
178
|
+
----------
|
|
179
|
+
temperature : float, default=0.1
|
|
180
|
+
Temperature for scaling logits.
|
|
181
|
+
ignore_index : int, default=0
|
|
182
|
+
Target index to ignore.
|
|
183
|
+
reduction : {'mean', 'sum', 'none'}, default='mean'
|
|
184
|
+
Reduction applied to the output.
|
|
185
|
+
|
|
186
|
+
Examples
|
|
187
|
+
--------
|
|
188
|
+
>>> loss_fn = InBatchNCELoss(temperature=0.1)
|
|
189
|
+
>>> embeddings = torch.randn(32, 256)
|
|
190
|
+
>>> item_embeddings = torch.randn(1000, 256)
|
|
191
|
+
>>> targets = torch.randint(0, 1000, (32,))
|
|
192
|
+
>>> loss = loss_fn(embeddings, item_embeddings, targets)
|
|
178
193
|
"""
|
|
179
194
|
|
|
180
195
|
def __init__(self, temperature=0.1, ignore_index=0, reduction='mean'):
|
torch_rechub/data/dataset.py
CHANGED
|
@@ -1,40 +1,35 @@
|
|
|
1
1
|
"""Dataset implementations providing streaming, batch-wise data access for PyTorch."""
|
|
2
2
|
|
|
3
|
-
import os
|
|
4
3
|
import typing as ty
|
|
5
4
|
|
|
6
5
|
import pyarrow.dataset as pd
|
|
7
6
|
import torch
|
|
8
7
|
from torch.utils.data import IterableDataset, get_worker_info
|
|
9
8
|
|
|
10
|
-
from .
|
|
9
|
+
from torch_rechub.types import FilePath
|
|
11
10
|
|
|
12
|
-
|
|
13
|
-
_FilePath = ty.Union[str, os.PathLike]
|
|
11
|
+
from .convert import pa_array_to_tensor
|
|
14
12
|
|
|
15
13
|
# The default batch size when reading a Parquet dataset
|
|
16
14
|
_DEFAULT_BATCH_SIZE = 1024
|
|
17
15
|
|
|
18
16
|
|
|
19
17
|
class ParquetIterableDataset(IterableDataset):
|
|
20
|
-
"""
|
|
21
|
-
IterableDataset that streams data from one or more Parquet files.
|
|
18
|
+
"""Stream Parquet data as PyTorch tensors.
|
|
22
19
|
|
|
23
20
|
Parameters
|
|
24
21
|
----------
|
|
25
|
-
file_paths : list[
|
|
22
|
+
file_paths : list[FilePath]
|
|
26
23
|
Paths to Parquet files.
|
|
27
24
|
columns : list[str], optional
|
|
28
|
-
|
|
29
|
-
batch_size : int, default
|
|
30
|
-
|
|
25
|
+
Columns to select; if ``None``, read all columns.
|
|
26
|
+
batch_size : int, default _DEFAULT_BATCH_SIZE
|
|
27
|
+
Rows per streamed batch.
|
|
31
28
|
|
|
32
29
|
Notes
|
|
33
30
|
-----
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
Dataset and Scanner. Iteration yields dictionaries mapping column names to PyTorch
|
|
37
|
-
tensors created via NumPy, one batch at a time.
|
|
31
|
+
Reads lazily; no full Parquet load. Each worker gets a partition, builds its
|
|
32
|
+
own PyArrow Dataset/Scanner, and yields dicts of column tensors batch by batch.
|
|
38
33
|
|
|
39
34
|
Examples
|
|
40
35
|
--------
|
|
@@ -44,16 +39,14 @@ class ParquetIterableDataset(IterableDataset):
|
|
|
44
39
|
... batch_size=1024,
|
|
45
40
|
... )
|
|
46
41
|
>>> loader = DataLoader(ds, batch_size=None)
|
|
47
|
-
>>> # Now iterate over batches.
|
|
48
42
|
>>> for batch in loader:
|
|
49
43
|
... x, y, label = batch["x"], batch["y"], batch["label"]
|
|
50
|
-
... # Do some work.
|
|
51
44
|
... ...
|
|
52
45
|
"""
|
|
53
46
|
|
|
54
47
|
def __init__(
|
|
55
48
|
self,
|
|
56
|
-
file_paths: ty.Sequence[
|
|
49
|
+
file_paths: ty.Sequence[FilePath],
|
|
57
50
|
/,
|
|
58
51
|
columns: ty.Optional[ty.Sequence[str]] = None,
|
|
59
52
|
batch_size: int = _DEFAULT_BATCH_SIZE,
|
|
@@ -64,17 +57,15 @@ class ParquetIterableDataset(IterableDataset):
|
|
|
64
57
|
self._batch_size = batch_size
|
|
65
58
|
|
|
66
59
|
def __iter__(self) -> ty.Iterator[dict[str, torch.Tensor]]:
|
|
67
|
-
"""
|
|
68
|
-
Stream Parquet data as mapped PyTorch tensors.
|
|
60
|
+
"""Stream Parquet data as mapped PyTorch tensors.
|
|
69
61
|
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
converted to a dict mapping column names to PyTorch tensors (via NumPy).
|
|
62
|
+
Builds a PyArrow Dataset from the current worker's file partition, then
|
|
63
|
+
lazily scans selected columns. Each batch becomes a dict of Torch tensors.
|
|
73
64
|
|
|
74
65
|
Returns
|
|
75
66
|
-------
|
|
76
67
|
Iterator[dict[str, torch.Tensor]]
|
|
77
|
-
|
|
68
|
+
One converted batch at a time.
|
|
78
69
|
"""
|
|
79
70
|
if not (partition := self._get_partition()):
|
|
80
71
|
return
|
|
@@ -95,19 +86,15 @@ class ParquetIterableDataset(IterableDataset):
|
|
|
95
86
|
# private interfaces
|
|
96
87
|
|
|
97
88
|
def _get_partition(self) -> tuple[str, ...]:
|
|
98
|
-
"""
|
|
99
|
-
Get the partition of file paths for the current worker.
|
|
100
|
-
|
|
101
|
-
This method splits the full list of file paths into contiguous partitions with
|
|
102
|
-
a nearly equal size by the total number of workers and the current worker ID.
|
|
89
|
+
"""Get file partition for the current worker.
|
|
103
90
|
|
|
104
|
-
|
|
105
|
-
|
|
91
|
+
Splits file paths into contiguous partitions by number of workers and worker ID.
|
|
92
|
+
In the main process (no worker info), returns all paths.
|
|
106
93
|
|
|
107
94
|
Returns
|
|
108
95
|
-------
|
|
109
96
|
tuple[str, ...]
|
|
110
|
-
|
|
97
|
+
Partition of file paths for this worker.
|
|
111
98
|
"""
|
|
112
99
|
if (info := get_worker_info()) is None:
|
|
113
100
|
return self._file_paths
|
|
torch_rechub/models/generative/hstu.py
CHANGED
|
@@ -10,39 +10,54 @@ from torch_rechub.utils.hstu_utils import RelPosBias
|
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class HSTUModel(nn.Module):
|
|
13
|
-
"""HSTU: Hierarchical Sequential Transduction Units
|
|
14
|
-
|
|
15
|
-
Autoregressive generative
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
13
|
+
"""HSTU: Hierarchical Sequential Transduction Units.
|
|
14
|
+
|
|
15
|
+
Autoregressive generative recommender that stacks ``HSTUBlock`` layers to
|
|
16
|
+
capture long-range dependencies and predict the next item.
|
|
17
|
+
|
|
18
|
+
Parameters
|
|
19
|
+
----------
|
|
20
|
+
vocab_size : int
|
|
21
|
+
Vocabulary size (items incl. PAD).
|
|
22
|
+
d_model : int, default=512
|
|
23
|
+
Hidden dimension.
|
|
24
|
+
n_heads : int, default=8
|
|
25
|
+
Attention heads.
|
|
26
|
+
n_layers : int, default=4
|
|
27
|
+
Number of stacked HSTU layers.
|
|
28
|
+
dqk : int, default=64
|
|
29
|
+
Query/key dim per head.
|
|
30
|
+
dv : int, default=64
|
|
31
|
+
Value dim per head.
|
|
32
|
+
max_seq_len : int, default=256
|
|
33
|
+
Maximum sequence length.
|
|
34
|
+
dropout : float, default=0.1
|
|
35
|
+
Dropout rate.
|
|
36
|
+
use_rel_pos_bias : bool, default=True
|
|
37
|
+
Use relative position bias.
|
|
38
|
+
use_time_embedding : bool, default=True
|
|
39
|
+
Use time-difference embeddings.
|
|
40
|
+
num_time_buckets : int, default=2048
|
|
41
|
+
Number of time buckets for time embeddings.
|
|
42
|
+
time_bucket_fn : {'sqrt', 'log'}, default='sqrt'
|
|
43
|
+
Bucketization function for time differences.
|
|
44
|
+
|
|
45
|
+
Shape
|
|
46
|
+
-----
|
|
47
|
+
Input
|
|
48
|
+
x : ``(batch_size, seq_len)``
|
|
49
|
+
time_diffs : ``(batch_size, seq_len)``, optional (seconds).
|
|
50
|
+
Output
|
|
51
|
+
logits : ``(batch_size, seq_len, vocab_size)``
|
|
52
|
+
|
|
53
|
+
Examples
|
|
54
|
+
--------
|
|
55
|
+
>>> model = HSTUModel(vocab_size=100000, d_model=512)
|
|
56
|
+
>>> x = torch.randint(0, 100000, (32, 256))
|
|
57
|
+
>>> time_diffs = torch.randint(0, 86400, (32, 256))
|
|
58
|
+
>>> logits = model(x, time_diffs)
|
|
59
|
+
>>> logits.shape
|
|
60
|
+
torch.Size([32, 256, 100000])
|
|
46
61
|
"""
|
|
47
62
|
|
|
48
63
|
def __init__(self, vocab_size, d_model=512, n_heads=8, n_layers=4, dqk=64, dv=64, max_seq_len=256, dropout=0.1, use_rel_pos_bias=True, use_time_embedding=True, num_time_buckets=2048, time_bucket_fn='sqrt'):
|
|
torch_rechub/serving/__init__.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import typing as ty
|
|
2
|
+
|
|
3
|
+
from .annoy import AnnoyBuilder
|
|
4
|
+
from .base import BaseBuilder
|
|
5
|
+
from .faiss import FaissBuilder
|
|
6
|
+
from .milvus import MilvusBuilder
|
|
7
|
+
|
|
8
|
+
# Type for supported retrieval models.
|
|
9
|
+
_RetrievalModel = ty.Literal["annoy", "faiss", "milvus"]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def builder_factory(model: _RetrievalModel, **builder_config) -> BaseBuilder:
|
|
13
|
+
"""
|
|
14
|
+
Factory function for creating a vector index builder.
|
|
15
|
+
|
|
16
|
+
This function instantiates and returns a concrete implementation of ``BaseBuilder``
|
|
17
|
+
based on the specified retrieval backend. The returned builder is responsible for
|
|
18
|
+
constructing or loading the underlying ANN index via its own ``from_embeddings`` or
|
|
19
|
+
``from_index_file`` method.
|
|
20
|
+
|
|
21
|
+
Parameters
|
|
22
|
+
----------
|
|
23
|
+
model : "annoy", "faiss", or "milvus"
|
|
24
|
+
The retrieval backend to use.
|
|
25
|
+
**builder_config
|
|
26
|
+
Keyword arguments passed directly to the selected builder constructor.
|
|
27
|
+
|
|
28
|
+
Returns
|
|
29
|
+
-------
|
|
30
|
+
BaseBuilder
|
|
31
|
+
A concrete builder instance corresponding to the specified retrieval backend.
|
|
32
|
+
|
|
33
|
+
Raises
|
|
34
|
+
------
|
|
35
|
+
NotImplementedError
|
|
36
|
+
if the specified retrieval model is not supported.
|
|
37
|
+
"""
|
|
38
|
+
if model == "annoy":
|
|
39
|
+
return AnnoyBuilder(**builder_config)
|
|
40
|
+
|
|
41
|
+
if model == "faiss":
|
|
42
|
+
return FaissBuilder(**builder_config)
|
|
43
|
+
|
|
44
|
+
if model == "milvus":
|
|
45
|
+
return MilvusBuilder(**builder_config)
|
|
46
|
+
|
|
47
|
+
raise NotImplementedError(f"{model=} is not implemented yet!")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
__all__ = ["builder_factory"]
|
|
torch_rechub/serving/annoy.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""ANNOY-based vector index implementation for the retrieval stage."""
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import typing as ty
|
|
5
|
+
|
|
6
|
+
import annoy
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from torch_rechub.types import FilePath
|
|
11
|
+
|
|
12
|
+
from .base import BaseBuilder, BaseIndexer
|
|
13
|
+
|
|
14
|
+
# Type for distance metrics for the ANNOY index.
|
|
15
|
+
_AnnoyMetric = ty.Literal["angular", "euclidean", "dot"]
|
|
16
|
+
|
|
17
|
+
# Default distance metric used by ANNOY.
|
|
18
|
+
_DEFAULT_METRIC: _AnnoyMetric = "angular"
|
|
19
|
+
|
|
20
|
+
# Default number of trees to build in the ANNOY index.
|
|
21
|
+
_DEFAULT_N_TREES = 10
|
|
22
|
+
|
|
23
|
+
# Default number of worker threads for building the ANNOY index.
|
|
24
|
+
_DEFAULT_THREADS = -1
|
|
25
|
+
|
|
26
|
+
# Default number of nodes to inspect during an ANNOY search.
|
|
27
|
+
_DEFAULT_SEARCHK = -1
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class AnnoyBuilder(BaseBuilder):
|
|
31
|
+
"""ANNOY-based implementation of ``BaseBuilder``."""
|
|
32
|
+
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
d: int,
|
|
36
|
+
metric: _AnnoyMetric = _DEFAULT_METRIC,
|
|
37
|
+
*,
|
|
38
|
+
n_trees: int = _DEFAULT_N_TREES,
|
|
39
|
+
threads: int = _DEFAULT_THREADS,
|
|
40
|
+
searchk: int = _DEFAULT_SEARCHK,
|
|
41
|
+
) -> None:
|
|
42
|
+
"""
|
|
43
|
+
Initialize a ANNOY builder.
|
|
44
|
+
|
|
45
|
+
Parameters
|
|
46
|
+
----------
|
|
47
|
+
d : int
|
|
48
|
+
The dimension of embeddings.
|
|
49
|
+
metric : ``"angular"``, ``"euclidean"``, or ``"dot"``, optional
|
|
50
|
+
The indexing metric. Default to ``"angular"``.
|
|
51
|
+
n_trees : int, optional
|
|
52
|
+
Number of trees to build an ANNOY index.
|
|
53
|
+
threads : int, optional
|
|
54
|
+
Number of worker threads to build an ANNOY index.
|
|
55
|
+
searchk : int, optional
|
|
56
|
+
Number of nodes to inspect during an ANNOY search.
|
|
57
|
+
"""
|
|
58
|
+
self._d = d
|
|
59
|
+
self._metric = metric
|
|
60
|
+
|
|
61
|
+
self._n_trees = n_trees
|
|
62
|
+
self._threads = threads
|
|
63
|
+
self._searchk = searchk
|
|
64
|
+
|
|
65
|
+
@contextlib.contextmanager
|
|
66
|
+
def from_embeddings(
|
|
67
|
+
self,
|
|
68
|
+
embeddings: torch.Tensor,
|
|
69
|
+
) -> ty.Generator["AnnoyIndexer",
|
|
70
|
+
None,
|
|
71
|
+
None]:
|
|
72
|
+
"""Adhere to ``BaseBuilder.from_embeddings``."""
|
|
73
|
+
index = annoy.AnnoyIndex(self._d, metric=self._metric)
|
|
74
|
+
|
|
75
|
+
for idx, emb in enumerate(embeddings):
|
|
76
|
+
index.add_item(idx, emb)
|
|
77
|
+
|
|
78
|
+
index.build(self._n_trees, n_jobs=self._threads)
|
|
79
|
+
|
|
80
|
+
try:
|
|
81
|
+
yield AnnoyIndexer(index, self._searchk)
|
|
82
|
+
finally:
|
|
83
|
+
index.unload()
|
|
84
|
+
|
|
85
|
+
@contextlib.contextmanager
|
|
86
|
+
def from_index_file(
|
|
87
|
+
self,
|
|
88
|
+
index_file: FilePath,
|
|
89
|
+
) -> ty.Generator["AnnoyIndexer",
|
|
90
|
+
None,
|
|
91
|
+
None]:
|
|
92
|
+
"""Adhere to ``BaseBuilder.from_index_file``."""
|
|
93
|
+
index = annoy.AnnoyIndex(self._d, metric=self._metric)
|
|
94
|
+
index.load(str(index_file))
|
|
95
|
+
|
|
96
|
+
try:
|
|
97
|
+
yield AnnoyIndexer(index, searchk=self._searchk)
|
|
98
|
+
finally:
|
|
99
|
+
index.unload()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class AnnoyIndexer(BaseIndexer):
|
|
103
|
+
"""ANNOY-based implementation of ``BaseIndexer``."""
|
|
104
|
+
|
|
105
|
+
def __init__(self, index: annoy.AnnoyIndex, searchk: int) -> None:
|
|
106
|
+
"""Initialize a ANNOY indexer."""
|
|
107
|
+
self._index = index
|
|
108
|
+
self._searchk = searchk
|
|
109
|
+
|
|
110
|
+
def query(
|
|
111
|
+
self,
|
|
112
|
+
embeddings: torch.Tensor,
|
|
113
|
+
top_k: int,
|
|
114
|
+
) -> tuple[torch.Tensor,
|
|
115
|
+
torch.Tensor]:
|
|
116
|
+
"""Adhere to ``BaseIndexer.query``."""
|
|
117
|
+
n, _ = embeddings.shape
|
|
118
|
+
nn_ids = np.zeros((n, top_k), dtype=np.int64)
|
|
119
|
+
nn_distances = np.zeros((n, top_k), dtype=np.float32)
|
|
120
|
+
|
|
121
|
+
for idx, emb in enumerate(embeddings):
|
|
122
|
+
nn_ids[idx], nn_distances[idx] = self._index.get_nns_by_vector(
|
|
123
|
+
emb.cpu().numpy(),
|
|
124
|
+
top_k,
|
|
125
|
+
search_k=self._searchk,
|
|
126
|
+
include_distances=True,
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
return torch.from_numpy(nn_ids), torch.from_numpy(nn_distances)
|
|
130
|
+
|
|
131
|
+
def save(self, file_path: FilePath) -> None:
|
|
132
|
+
"""Adhere to ``BaseIndexer.save``."""
|
|
133
|
+
self._index.save(str(file_path))
|
|
torch_rechub/serving/base.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""Base abstraction for vector indexers used in the retrieval stage."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
import typing as ty
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from torch_rechub.types import FilePath
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BaseBuilder(abc.ABC):
|
|
12
|
+
"""
|
|
13
|
+
Abstract base class for vector index construction.
|
|
14
|
+
|
|
15
|
+
A builder owns all build-time configuration and produces a ``BaseIndexer`` through a
|
|
16
|
+
context-managed build operation.
|
|
17
|
+
|
|
18
|
+
Examples
|
|
19
|
+
--------
|
|
20
|
+
>>> builder = BaseBuilder(...)
|
|
21
|
+
>>> embeddings = torch.randn(1000, 128)
|
|
22
|
+
>>> with builder.from_embeddings(embeddings) as indexer:
|
|
23
|
+
... ids, scores = indexer.query(embeddings[:2], top_k=5)
|
|
24
|
+
... indexer.save("index.bin")
|
|
25
|
+
>>> with builder.from_index_file("index.bin") as indexer:
|
|
26
|
+
... ids, scores = indexer.query(embeddings[:2], top_k=5)
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
@abc.abstractmethod
|
|
30
|
+
def from_embeddings(
|
|
31
|
+
self,
|
|
32
|
+
embeddings: torch.Tensor,
|
|
33
|
+
) -> ty.ContextManager["BaseIndexer"]:
|
|
34
|
+
"""
|
|
35
|
+
Build a vector index from the embeddings.
|
|
36
|
+
|
|
37
|
+
Parameters
|
|
38
|
+
----------
|
|
39
|
+
embeddings : torch.Tensor
|
|
40
|
+
A 2D tensor (n, d) containing embedding vectors to build a new index.
|
|
41
|
+
|
|
42
|
+
Returns
|
|
43
|
+
-------
|
|
44
|
+
ContextManager[BaseIndexer]
|
|
45
|
+
A context manager that yields a fully initialized ``BaseIndexer``.
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
@abc.abstractmethod
|
|
49
|
+
def from_index_file(
|
|
50
|
+
self,
|
|
51
|
+
index_file: FilePath,
|
|
52
|
+
) -> ty.ContextManager["BaseIndexer"]:
|
|
53
|
+
"""
|
|
54
|
+
Build a vector index from the index file.
|
|
55
|
+
|
|
56
|
+
Parameters
|
|
57
|
+
----------
|
|
58
|
+
index_file : FilePath
|
|
59
|
+
Path to a serialized index on disk to be loaded.
|
|
60
|
+
|
|
61
|
+
Returns
|
|
62
|
+
-------
|
|
63
|
+
ContextManager[BaseIndexer]
|
|
64
|
+
A context manager that yields a fully initialized ``BaseIndexer``.
|
|
65
|
+
"""
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class BaseIndexer(abc.ABC):
|
|
69
|
+
"""Abstract base class for vector indexers in the retrieval stage."""
|
|
70
|
+
|
|
71
|
+
@abc.abstractmethod
|
|
72
|
+
def query(
|
|
73
|
+
self,
|
|
74
|
+
embeddings: torch.Tensor,
|
|
75
|
+
top_k: int,
|
|
76
|
+
) -> tuple[torch.Tensor,
|
|
77
|
+
torch.Tensor]:
|
|
78
|
+
"""
|
|
79
|
+
Query the vector index.
|
|
80
|
+
|
|
81
|
+
Parameters
|
|
82
|
+
----------
|
|
83
|
+
embeddings : torch.Tensor
|
|
84
|
+
A 2D tensor (n, d) containing embedding vectors to query the index.
|
|
85
|
+
top_k : int
|
|
86
|
+
The number of nearest items to retrieve for each vector.
|
|
87
|
+
|
|
88
|
+
Returns
|
|
89
|
+
-------
|
|
90
|
+
torch.Tensor
|
|
91
|
+
A 2D tensor of shape (n, top_k), containing the retrieved nearest neighbor
|
|
92
|
+
IDs for each vector, ordered by descending relevance.
|
|
93
|
+
torch.Tensor
|
|
94
|
+
A 2D tensor of shape (n, top_k), containing the relevance distances of the
|
|
95
|
+
nearest neighbors for each vector.
|
|
96
|
+
"""
|
|
97
|
+
|
|
98
|
+
@abc.abstractmethod
|
|
99
|
+
def save(self, file_path: FilePath) -> None:
|
|
100
|
+
"""
|
|
101
|
+
Persist the index to local disk.
|
|
102
|
+
|
|
103
|
+
Parameters
|
|
104
|
+
----------
|
|
105
|
+
file_path : FilePath
|
|
106
|
+
Destination path where the index will be saved.
|
|
107
|
+
"""
|