torch-rechub 0.0.6__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,13 +4,24 @@ import torch.nn as nn
4
4
 
5
5
 
6
6
  class RegularizationLoss(nn.Module):
7
- """Unified L1/L2 Regularization Loss for embedding and dense parameters.
8
-
9
- Example:
10
- >>> reg_loss_fn = RegularizationLoss(embedding_l2=1e-5, dense_l2=1e-5)
11
- >>> # In model's forward or trainer
12
- >>> reg_loss = reg_loss_fn(model)
13
- >>> total_loss = task_loss + reg_loss
7
+ """Unified L1/L2 regularization for embedding and dense parameters.
8
+
9
+ Parameters
10
+ ----------
11
+ embedding_l1 : float, default=0.0
12
+ L1 coefficient for embedding parameters.
13
+ embedding_l2 : float, default=0.0
14
+ L2 coefficient for embedding parameters.
15
+ dense_l1 : float, default=0.0
16
+ L1 coefficient for dense (non-embedding) parameters.
17
+ dense_l2 : float, default=0.0
18
+ L2 coefficient for dense (non-embedding) parameters.
19
+
20
+ Examples
21
+ --------
22
+ >>> reg_loss_fn = RegularizationLoss(embedding_l2=1e-5, dense_l2=1e-5)
23
+ >>> reg_loss = reg_loss_fn(model)
24
+ >>> total_loss = task_loss + reg_loss
14
25
  """
15
26
 
16
27
  def __init__(self, embedding_l1=0.0, embedding_l2=0.0, dense_l1=0.0, dense_l2=0.0):
@@ -58,9 +69,11 @@ class RegularizationLoss(nn.Module):
58
69
 
59
70
 
60
71
  class HingeLoss(torch.nn.Module):
61
- """Hinge Loss for pairwise learning.
62
- reference: https://github.com/ustcml/RecStudio/blob/main/recstudio/model/loss_func.py
72
+ """Hinge loss for pairwise learning.
63
73
 
74
+ Notes
75
+ -----
76
+ Reference: https://github.com/ustcml/RecStudio/blob/main/recstudio/model/loss_func.py
64
77
  """
65
78
 
66
79
  def __init__(self, margin=2, num_items=None):
@@ -89,27 +102,28 @@ class BPRLoss(torch.nn.Module):
89
102
 
90
103
 
91
104
  class NCELoss(torch.nn.Module):
92
- """Noise Contrastive Estimation (NCE) Loss for recommendation systems.
93
-
94
- NCE Loss is more efficient than CrossEntropyLoss for large-scale recommendation
95
- scenarios. It uses in-batch negatives to reduce computational complexity.
96
-
97
- Reference:
98
- - Noise-contrastive estimation: A new estimation principle for unnormalized
99
- statistical models (Gutmann & Hyvärinen, 2010)
100
- - HLLM: Hierarchical Large Language Model for Recommendation
101
-
102
- Args:
103
- temperature (float): Temperature parameter for scaling logits. Default: 1.0
104
- ignore_index (int): Index to ignore in loss computation. Default: 0
105
- reduction (str): Specifies the reduction to apply to the output.
106
- Options: 'mean', 'sum', 'none'. Default: 'mean'
107
-
108
- Example:
109
- >>> nce_loss = NCELoss(temperature=0.1)
110
- >>> logits = torch.randn(32, 1000) # (batch_size, vocab_size)
111
- >>> targets = torch.randint(0, 1000, (32,))
112
- >>> loss = nce_loss(logits, targets)
105
+ """Noise Contrastive Estimation (NCE) loss for recommender systems.
106
+
107
+ Parameters
108
+ ----------
109
+ temperature : float, default=1.0
110
+ Temperature for scaling logits.
111
+ ignore_index : int, default=0
112
+ Target index to ignore.
113
+ reduction : {'mean', 'sum', 'none'}, default='mean'
114
+ Reduction applied to the output.
115
+
116
+ Notes
117
+ -----
118
+ - Gutmann & Hyvärinen (2010), Noise-contrastive estimation.
119
+ - HLLM: Hierarchical Large Language Model for Recommendation.
120
+
121
+ Examples
122
+ --------
123
+ >>> nce_loss = NCELoss(temperature=0.1)
124
+ >>> logits = torch.randn(32, 1000)
125
+ >>> targets = torch.randint(0, 1000, (32,))
126
+ >>> loss = nce_loss(logits, targets)
113
127
  """
114
128
 
115
129
  def __init__(self, temperature=1.0, ignore_index=0, reduction='mean'):
@@ -158,23 +172,24 @@ class NCELoss(torch.nn.Module):
158
172
 
159
173
 
160
174
  class InBatchNCELoss(torch.nn.Module):
161
- """In-Batch NCE Loss with explicit negative sampling.
162
-
163
- This loss function uses other samples in the batch as negative samples,
164
- which is more efficient than sampling random negatives.
165
-
166
- Args:
167
- temperature (float): Temperature parameter for scaling logits. Default: 0.1
168
- ignore_index (int): Index to ignore in loss computation. Default: 0
169
- reduction (str): Specifies the reduction to apply to the output.
170
- Options: 'mean', 'sum', 'none'. Default: 'mean'
171
-
172
- Example:
173
- >>> loss_fn = InBatchNCELoss(temperature=0.1)
174
- >>> embeddings = torch.randn(32, 256) # (batch_size, embedding_dim)
175
- >>> item_embeddings = torch.randn(1000, 256) # (vocab_size, embedding_dim)
176
- >>> targets = torch.randint(0, 1000, (32,))
177
- >>> loss = loss_fn(embeddings, item_embeddings, targets)
175
+ """In-batch NCE loss with explicit negatives.
176
+
177
+ Parameters
178
+ ----------
179
+ temperature : float, default=0.1
180
+ Temperature for scaling logits.
181
+ ignore_index : int, default=0
182
+ Target index to ignore.
183
+ reduction : {'mean', 'sum', 'none'}, default='mean'
184
+ Reduction applied to the output.
185
+
186
+ Examples
187
+ --------
188
+ >>> loss_fn = InBatchNCELoss(temperature=0.1)
189
+ >>> embeddings = torch.randn(32, 256)
190
+ >>> item_embeddings = torch.randn(1000, 256)
191
+ >>> targets = torch.randint(0, 1000, (32,))
192
+ >>> loss = loss_fn(embeddings, item_embeddings, targets)
178
193
  """
179
194
 
180
195
  def __init__(self, temperature=0.1, ignore_index=0, reduction='mean'):
@@ -1,40 +1,35 @@
1
1
  """Dataset implementations providing streaming, batch-wise data access for PyTorch."""
2
2
 
3
- import os
4
3
  import typing as ty
5
4
 
6
5
  import pyarrow.dataset as pd
7
6
  import torch
8
7
  from torch.utils.data import IterableDataset, get_worker_info
9
8
 
10
- from .convert import pa_array_to_tensor
9
+ from torch_rechub.types import FilePath
11
10
 
12
- # Type for path to a file
13
- _FilePath = ty.Union[str, os.PathLike]
11
+ from .convert import pa_array_to_tensor
14
12
 
15
13
  # The default batch size when reading a Parquet dataset
16
14
  _DEFAULT_BATCH_SIZE = 1024
17
15
 
18
16
 
19
17
  class ParquetIterableDataset(IterableDataset):
20
- """
21
- IterableDataset that streams data from one or more Parquet files.
18
+ """Stream Parquet data as PyTorch tensors.
22
19
 
23
20
  Parameters
24
21
  ----------
25
- file_paths : list[_FilePath]
22
+ file_paths : list[FilePath]
26
23
  Paths to Parquet files.
27
24
  columns : list[str], optional
28
- Column names to select. If ``None``, all columns are read.
29
- batch_size : int, default DEFAULT_BATCH_SIZE
30
- Number of rows per streamed batch.
25
+ Columns to select; if ``None``, read all columns.
26
+ batch_size : int, default _DEFAULT_BATCH_SIZE
27
+ Rows per streamed batch.
31
28
 
32
29
  Notes
33
30
  -----
34
- This dataset reads data lazily and never loads the entire Parquet dataset to memory.
35
- The current worker receives a partition of ``file_paths`` and builds its own PyArrow
36
- Dataset and Scanner. Iteration yields dictionaries mapping column names to PyTorch
37
- tensors created via NumPy, one batch at a time.
31
+ Reads lazily; no full Parquet load. Each worker gets a partition, builds its
32
+ own PyArrow Dataset/Scanner, and yields dicts of column tensors batch by batch.
38
33
 
39
34
  Examples
40
35
  --------
@@ -44,16 +39,14 @@ class ParquetIterableDataset(IterableDataset):
44
39
  ... batch_size=1024,
45
40
  ... )
46
41
  >>> loader = DataLoader(ds, batch_size=None)
47
- >>> # Now iterate over batches.
48
42
  >>> for batch in loader:
49
43
  ... x, y, label = batch["x"], batch["y"], batch["label"]
50
- ... # Do some work.
51
44
  ... ...
52
45
  """
53
46
 
54
47
  def __init__(
55
48
  self,
56
- file_paths: ty.Sequence[_FilePath],
49
+ file_paths: ty.Sequence[FilePath],
57
50
  /,
58
51
  columns: ty.Optional[ty.Sequence[str]] = None,
59
52
  batch_size: int = _DEFAULT_BATCH_SIZE,
@@ -64,17 +57,15 @@ class ParquetIterableDataset(IterableDataset):
64
57
  self._batch_size = batch_size
65
58
 
66
59
  def __iter__(self) -> ty.Iterator[dict[str, torch.Tensor]]:
67
- """
68
- Stream Parquet data as mapped PyTorch tensors.
60
+ """Stream Parquet data as mapped PyTorch tensors.
69
61
 
70
- Build a PyArrow Dataset from the current worker's assigned file partition, then
71
- create a Scanner to lazily read batches of the selected columns. Each batch is
72
- converted to a dict mapping column names to PyTorch tensors (via NumPy).
62
+ Builds a PyArrow Dataset from the current worker's file partition, then
63
+ lazily scans selected columns. Each batch becomes a dict of Torch tensors.
73
64
 
74
65
  Returns
75
66
  -------
76
67
  Iterator[dict[str, torch.Tensor]]
77
- An iterator that yields one converted batch at a time.
68
+ One converted batch at a time.
78
69
  """
79
70
  if not (partition := self._get_partition()):
80
71
  return
@@ -95,19 +86,15 @@ class ParquetIterableDataset(IterableDataset):
95
86
  # private interfaces
96
87
 
97
88
  def _get_partition(self) -> tuple[str, ...]:
98
- """
99
- Get the partition of file paths for the current worker.
100
-
101
- This method splits the full list of file paths into contiguous partitions with
102
- a nearly equal size by the total number of workers and the current worker ID.
89
+ """Get file partition for the current worker.
103
90
 
104
- If running in the main process (i.e., no worker information is available), the
105
- entire list of file paths is returned.
91
+ Splits file paths into contiguous partitions by number of workers and worker ID.
92
+ In the main process (no worker info), returns all paths.
106
93
 
107
94
  Returns
108
95
  -------
109
96
  tuple[str, ...]
110
- The partition of file paths for the current worker.
97
+ Partition of file paths for this worker.
111
98
  """
112
99
  if (info := get_worker_info()) is None:
113
100
  return self._file_paths
@@ -10,39 +10,54 @@ from torch_rechub.utils.hstu_utils import RelPosBias
10
10
 
11
11
 
12
12
  class HSTUModel(nn.Module):
13
- """HSTU: Hierarchical Sequential Transduction Units model.
14
-
15
- Autoregressive generative recommendation model for sequential data.
16
- This module stacks multiple ``HSTUBlock`` layers to capture long-range
17
- dependencies in user interaction sequences and predicts the next item.
18
-
19
- Args:
20
- vocab_size (int): Vocabulary size (number of distinct items, including PAD).
21
- d_model (int): Hidden dimension of the model. Default: 512.
22
- n_heads (int): Number of attention heads. Default: 8.
23
- n_layers (int): Number of stacked HSTU layers. Default: 4.
24
- dqk (int): Dimension of query/key vectors per head. Default: 64.
25
- dv (int): Dimension of value vectors per head. Default: 64.
26
- max_seq_len (int): Maximum sequence length. Default: 256.
27
- dropout (float): Dropout rate applied in the model. Default: 0.1.
28
- use_rel_pos_bias (bool): Whether to use relative position bias. Default: True.
29
- use_time_embedding (bool): Whether to use time-difference embeddings. Default: True.
30
- num_time_buckets (int): Number of time buckets for time embeddings. Default: 2048.
31
- time_bucket_fn (str): Function used to bucketize time differences, ``"sqrt"``
32
- or ``"log"``. Default: ``"sqrt"``.
33
-
34
- Shape:
35
- - Input: ``x`` of shape ``(batch_size, seq_len)``; optional ``time_diffs``
36
- of shape ``(batch_size, seq_len)`` representing time differences in seconds.
37
- - Output: Logits of shape ``(batch_size, seq_len, vocab_size)``.
38
-
39
- Example:
40
- >>> model = HSTUModel(vocab_size=100000, d_model=512)
41
- >>> x = torch.randint(0, 100000, (32, 256))
42
- >>> time_diffs = torch.randint(0, 86400, (32, 256))
43
- >>> logits = model(x, time_diffs)
44
- >>> logits.shape
45
- torch.Size([32, 256, 100000])
13
+ """HSTU: Hierarchical Sequential Transduction Units.
14
+
15
+ Autoregressive generative recommender that stacks ``HSTUBlock`` layers to
16
+ capture long-range dependencies and predict the next item.
17
+
18
+ Parameters
19
+ ----------
20
+ vocab_size : int
21
+ Vocabulary size (items incl. PAD).
22
+ d_model : int, default=512
23
+ Hidden dimension.
24
+ n_heads : int, default=8
25
+ Attention heads.
26
+ n_layers : int, default=4
27
+ Number of stacked HSTU layers.
28
+ dqk : int, default=64
29
+ Query/key dim per head.
30
+ dv : int, default=64
31
+ Value dim per head.
32
+ max_seq_len : int, default=256
33
+ Maximum sequence length.
34
+ dropout : float, default=0.1
35
+ Dropout rate.
36
+ use_rel_pos_bias : bool, default=True
37
+ Use relative position bias.
38
+ use_time_embedding : bool, default=True
39
+ Use time-difference embeddings.
40
+ num_time_buckets : int, default=2048
41
+ Number of time buckets for time embeddings.
42
+ time_bucket_fn : {'sqrt', 'log'}, default='sqrt'
43
+ Bucketization function for time differences.
44
+
45
+ Shape
46
+ -----
47
+ Input
48
+ x : ``(batch_size, seq_len)``
49
+ time_diffs : ``(batch_size, seq_len)``, optional (seconds).
50
+ Output
51
+ logits : ``(batch_size, seq_len, vocab_size)``
52
+
53
+ Examples
54
+ --------
55
+ >>> model = HSTUModel(vocab_size=100000, d_model=512)
56
+ >>> x = torch.randint(0, 100000, (32, 256))
57
+ >>> time_diffs = torch.randint(0, 86400, (32, 256))
58
+ >>> logits = model(x, time_diffs)
59
+ >>> logits.shape
60
+ torch.Size([32, 256, 100000])
46
61
  """
47
62
 
48
63
  def __init__(self, vocab_size, d_model=512, n_heads=8, n_layers=4, dqk=64, dv=64, max_seq_len=256, dropout=0.1, use_rel_pos_bias=True, use_time_embedding=True, num_time_buckets=2048, time_bucket_fn='sqrt'):
@@ -0,0 +1,50 @@
1
+ import typing as ty
2
+
3
+ from .annoy import AnnoyBuilder
4
+ from .base import BaseBuilder
5
+ from .faiss import FaissBuilder
6
+ from .milvus import MilvusBuilder
7
+
8
# Names of the retrieval backends that ``builder_factory`` knows how to construct.
_RetrievalModel = ty.Literal["annoy", "faiss", "milvus"]


def builder_factory(model: _RetrievalModel, **builder_config) -> BaseBuilder:
    """
    Create the vector index builder for the requested retrieval backend.

    The returned object is a concrete ``BaseBuilder``; use its ``from_embeddings`` or
    ``from_index_file`` context manager to obtain a queryable indexer.

    Parameters
    ----------
    model : {"annoy", "faiss", "milvus"}
        Name of the retrieval backend.
    **builder_config
        Keyword arguments forwarded unchanged to the chosen builder's constructor.

    Returns
    -------
    BaseBuilder
        An ``AnnoyBuilder``, ``FaissBuilder`` or ``MilvusBuilder`` instance.

    Raises
    ------
    NotImplementedError
        If ``model`` does not name a supported backend.
    """
    # Ordered (backend name, builder class) pairs. Matching uses ``==`` in this
    # order, exactly as a chain of ``if model == ...`` checks would.
    registry = (
        ("annoy", AnnoyBuilder),
        ("faiss", FaissBuilder),
        ("milvus", MilvusBuilder),
    )

    for backend_name, builder_cls in registry:
        if model == backend_name:
            return builder_cls(**builder_config)

    raise NotImplementedError(f"{model=} is not implemented yet!")


__all__ = ["builder_factory"]
@@ -0,0 +1,133 @@
1
+ """ANNOY-based vector index implementation for the retrieval stage."""
2
+
3
+ import contextlib
4
+ import typing as ty
5
+
6
+ import annoy
7
+ import numpy as np
8
+ import torch
9
+
10
+ from torch_rechub.types import FilePath
11
+
12
+ from .base import BaseBuilder, BaseIndexer
13
+
14
# Distance metrics this module passes to ``annoy.AnnoyIndex(..., metric=...)``.
_AnnoyMetric = ty.Literal["angular", "euclidean", "dot"]

# ---- Defaults used by ``AnnoyBuilder`` ---------------------------------------
# Metric used when the caller does not choose one.
_DEFAULT_METRIC: _AnnoyMetric = "angular"
# ``n_trees`` for ``AnnoyIndex.build``. Per the ANNOY README, more trees give
# more accurate results at the cost of a larger index.
_DEFAULT_N_TREES = 10
# ``n_jobs`` for ``AnnoyIndex.build``; per the ANNOY README, -1 uses all CPU cores.
_DEFAULT_THREADS = -1
# ``search_k`` for ``get_nns_by_vector``; per the ANNOY README, -1 falls back to
# ANNOY's own default (n_trees * number of requested neighbours).
_DEFAULT_SEARCHK = -1
28
+
29
+
30
class AnnoyBuilder(BaseBuilder):
    """ANNOY-based implementation of ``BaseBuilder``.

    All build and search settings are captured at construction time. Each call to
    ``from_embeddings`` or ``from_index_file`` yields an ``AnnoyIndexer`` and
    unloads the underlying ``annoy.AnnoyIndex`` when the context exits.
    """

    def __init__(
        self,
        d: int,
        metric: _AnnoyMetric = _DEFAULT_METRIC,
        *,
        n_trees: int = _DEFAULT_N_TREES,
        threads: int = _DEFAULT_THREADS,
        searchk: int = _DEFAULT_SEARCHK,
    ) -> None:
        """
        Initialize an ANNOY builder.

        Parameters
        ----------
        d : int
            The dimension of embeddings.
        metric : ``"angular"``, ``"euclidean"``, or ``"dot"``, optional
            The indexing metric. Defaults to ``"angular"``.
        n_trees : int, optional
            Number of trees to build an ANNOY index (``n_trees`` of
            ``AnnoyIndex.build``). Defaults to 10.
        threads : int, optional
            Number of worker threads to build an ANNOY index (``n_jobs`` of
            ``AnnoyIndex.build``). Defaults to -1.
        searchk : int, optional
            Number of nodes to inspect during an ANNOY search (``search_k`` of
            ``AnnoyIndex.get_nns_by_vector``). Defaults to -1.
        """
        self._d = d
        self._metric = metric

        self._n_trees = n_trees
        self._threads = threads
        self._searchk = searchk

    @contextlib.contextmanager
    def from_embeddings(
        self,
        embeddings: torch.Tensor,
    ) -> ty.Generator["AnnoyIndexer", None, None]:
        """Adhere to ``BaseBuilder.from_embeddings``.

        Row ``i`` of ``embeddings`` is stored under ANNOY item id ``i``.
        """
        index = annoy.AnnoyIndex(self._d, metric=self._metric)

        # Convert the whole matrix once: detached from autograd, on the host, as
        # float32 (ANNOY's Python binding stores float32 vectors). Previously every
        # element was pulled through ``Tensor.__float__`` inside ``add_item``,
        # i.e. n * d tensor-level calls, plus a device sync each for GPU input.
        vectors = embeddings.detach().to(device="cpu", dtype=torch.float32).numpy()

        for idx, vector in enumerate(vectors):
            index.add_item(idx, vector.tolist())

        index.build(self._n_trees, n_jobs=self._threads)

        try:
            yield AnnoyIndexer(index, self._searchk)
        finally:
            # Release the index memory as soon as the caller leaves the context.
            index.unload()

    @contextlib.contextmanager
    def from_index_file(
        self,
        index_file: FilePath,
    ) -> ty.Generator["AnnoyIndexer", None, None]:
        """Adhere to ``BaseBuilder.from_index_file``.

        The file must have been written with the same dimension ``d`` and metric
        this builder was configured with.
        """
        index = annoy.AnnoyIndex(self._d, metric=self._metric)
        index.load(str(index_file))

        try:
            yield AnnoyIndexer(index, searchk=self._searchk)
        finally:
            index.unload()
100
+
101
+
102
class AnnoyIndexer(BaseIndexer):
    """ANNOY-based implementation of ``BaseIndexer``."""

    def __init__(self, index: annoy.AnnoyIndex, searchk: int) -> None:
        """
        Initialize an ANNOY indexer.

        Parameters
        ----------
        index : annoy.AnnoyIndex
            A built or loaded ANNOY index to query and save.
        searchk : int
            ``search_k`` forwarded to ``AnnoyIndex.get_nns_by_vector``.
        """
        self._index = index
        self._searchk = searchk

    def query(
        self,
        embeddings: torch.Tensor,
        top_k: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Adhere to ``BaseIndexer.query``.

        ANNOY may return fewer than ``top_k`` neighbours, e.g. when the index holds
        fewer than ``top_k`` items or ``search_k`` is small. Missing slots are padded
        with id ``-1`` and distance ``NaN`` so both outputs always have shape
        ``(n, top_k)``.
        """
        # Convert once, outside the loop. Detaching matters: ``Tensor.numpy()``
        # raises for tensors that require grad (e.g. raw model outputs), and a single
        # host transfer replaces one ``.cpu()`` per query row.
        queries = embeddings.detach().to(device="cpu", dtype=torch.float32).numpy()
        n, _ = queries.shape

        # Pre-fill with padding so short ANNOY results leave well-defined values.
        nn_ids = np.full((n, top_k), -1, dtype=np.int64)
        nn_distances = np.full((n, top_k), np.nan, dtype=np.float32)

        for row, query in enumerate(queries):
            ids, distances = self._index.get_nns_by_vector(
                query.tolist(),
                top_k,
                search_k=self._searchk,
                include_distances=True,
            )
            # Assign only what ANNOY returned; a full-row assignment would fail to
            # broadcast when fewer than ``top_k`` neighbours come back.
            found = len(ids)
            nn_ids[row, :found] = ids
            nn_distances[row, :found] = distances

        return torch.from_numpy(nn_ids), torch.from_numpy(nn_distances)

    def save(self, file_path: FilePath) -> None:
        """Adhere to ``BaseIndexer.save``."""
        self._index.save(str(file_path))
@@ -0,0 +1,107 @@
1
+ """Base abstraction for vector indexers used in the retrieval stage."""
2
+
3
+ import abc
4
+ import typing as ty
5
+
6
+ import torch
7
+
8
+ from torch_rechub.types import FilePath
9
+
10
+
11
class BaseBuilder(abc.ABC):
    """
    Abstract base class for vector index construction.

    A builder owns all build-time configuration and produces a ``BaseIndexer`` through a
    context-managed build or load operation. ``BaseBuilder`` itself is abstract and
    cannot be instantiated; use a concrete subclass.

    Examples
    --------
    >>> builder = AnnoyBuilder(d=128)  # any concrete ``BaseBuilder`` subclass
    >>> embeddings = torch.randn(1000, 128)
    >>> with builder.from_embeddings(embeddings) as indexer:
    ...     ids, scores = indexer.query(embeddings[:2], top_k=5)
    ...     indexer.save("index.bin")
    >>> with builder.from_index_file("index.bin") as indexer:
    ...     ids, scores = indexer.query(embeddings[:2], top_k=5)
    """

    @abc.abstractmethod
    def from_embeddings(
        self,
        embeddings: torch.Tensor,
    ) -> ty.ContextManager["BaseIndexer"]:
        """
        Build a new vector index from the embeddings.

        Parameters
        ----------
        embeddings : torch.Tensor
            A 2D tensor (n, d) containing embedding vectors to build a new index.

        Returns
        -------
        ContextManager[BaseIndexer]
            A context manager that yields a fully initialized ``BaseIndexer``.
        """

    @abc.abstractmethod
    def from_index_file(
        self,
        index_file: FilePath,
    ) -> ty.ContextManager["BaseIndexer"]:
        """
        Load a vector index from a serialized index file.

        Parameters
        ----------
        index_file : FilePath
            Path to a serialized index on disk to be loaded.

        Returns
        -------
        ContextManager[BaseIndexer]
            A context manager that yields a fully initialized ``BaseIndexer``.
        """
66
+
67
+
68
class BaseIndexer(abc.ABC):
    """Interface for a queryable, persistable vector index used during retrieval."""

    @abc.abstractmethod
    def query(self, embeddings: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Look up the ``top_k`` nearest indexed items for every query vector.

        Parameters
        ----------
        embeddings : torch.Tensor
            Query vectors, as a 2D tensor of shape ``(n, d)``.
        top_k : int
            How many neighbours to retrieve per query vector.

        Returns
        -------
        ids : torch.Tensor
            Shape ``(n, top_k)``: neighbour IDs for each query, most relevant first.
        distances : torch.Tensor
            Shape ``(n, top_k)``: the index's distance/relevance value for each
            neighbour listed in ``ids``.
        """

    @abc.abstractmethod
    def save(self, file_path: FilePath) -> None:
        """
        Write the index to local disk.

        Parameters
        ----------
        file_path : FilePath
            Destination the serialized index is written to.
        """