torch-rechub 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the published contents of the two package versions as they appear in their public registry. It is provided for informational purposes only.
@@ -0,0 +1,198 @@
+"""Experiment tracking utilities for Torch-RecHub.
+
+This module exposes lightweight adapters for common visualization and
+experiment tracking tools, namely Weights & Biases (wandb), SwanLab, and
+TensorBoardX.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union
+
+
+class BaseLogger(ABC):
+    """Base interface for experiment tracking backends.
+
+    Methods
+    -------
+    log_metrics(metrics, step=None)
+        Record scalar metrics at a given step.
+    log_hyperparams(params)
+        Store hyperparameters and run configuration.
+    finish()
+        Flush pending logs and release resources.
+    """
+
+    @abstractmethod
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        """Log metrics to the tracking backend.
+
+        Parameters
+        ----------
+        metrics : dict of str to Any
+            Metric name-value pairs to record.
+        step : int, optional
+            Explicit global step or epoch index. When ``None``, the backend
+            uses its own default step handling.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        """Log experiment hyperparameters.
+
+        Parameters
+        ----------
+        params : dict of str to Any
+            Hyperparameters or configuration values to persist with the run.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def finish(self) -> None:
+        """Finalize logging and free any backend resources."""
+        raise NotImplementedError
+
+
+class WandbLogger(BaseLogger):
+    """Weights & Biases logger implementation.
+
+    Parameters
+    ----------
+    project : str
+        Name of the wandb project to log to.
+    name : str, optional
+        Display name for the run.
+    config : dict, optional
+        Initial hyperparameter configuration to record.
+    tags : list of str, optional
+        Optional tags for grouping runs.
+    notes : str, optional
+        Long-form notes shown in the run overview.
+    dir : str, optional
+        Local directory for wandb artifacts and cache.
+    **kwargs : dict
+        Additional keyword arguments forwarded to ``wandb.init``.
+
+    Raises
+    ------
+    ImportError
+        If ``wandb`` is not installed in the current environment.
+    """
+
+    def __init__(self, project: str, name: Optional[str] = None, config: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, notes: Optional[str] = None, dir: Optional[str] = None, **kwargs):
+        try:
+            import wandb
+            self._wandb = wandb
+        except ImportError:
+            raise ImportError("wandb is not installed. Install it with: pip install wandb")
+
+        self.run = self._wandb.init(project=project, name=name, config=config, tags=tags, notes=notes, dir=dir, **kwargs)
+
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        if step is not None:
+            self._wandb.log(metrics, step=step)
+        else:
+            self._wandb.log(metrics)
+
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        if self.run is not None:
+            self.run.config.update(params)
+
+    def finish(self) -> None:
+        if self.run is not None:
+            self.run.finish()
+
+
+class SwanLabLogger(BaseLogger):
+    """SwanLab logger implementation.
+
+    Parameters
+    ----------
+    project : str, optional
+        Project identifier for grouping experiments.
+    experiment_name : str, optional
+        Display name for the experiment or run.
+    description : str, optional
+        Text description shown alongside the run.
+    config : dict, optional
+        Hyperparameters or configuration to log at startup.
+    logdir : str, optional
+        Directory where logs and artifacts are stored.
+    **kwargs : dict
+        Additional keyword arguments forwarded to ``swanlab.init``.
+
+    Raises
+    ------
+    ImportError
+        If ``swanlab`` is not installed in the current environment.
+    """
+
+    def __init__(self, project: Optional[str] = None, experiment_name: Optional[str] = None, description: Optional[str] = None, config: Optional[Dict[str, Any]] = None, logdir: Optional[str] = None, **kwargs):
+        try:
+            import swanlab
+            self._swanlab = swanlab
+        except ImportError:
+            raise ImportError("swanlab is not installed. Install it with: pip install swanlab")
+
+        self.run = self._swanlab.init(project=project, experiment_name=experiment_name, description=description, config=config, logdir=logdir, **kwargs)
+
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        if step is not None:
+            self._swanlab.log(metrics, step=step)
+        else:
+            self._swanlab.log(metrics)
+
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        if self.run is not None:
+            self.run.config.update(params)
+
+    def finish(self) -> None:
+        self._swanlab.finish()
+
+
+class TensorBoardXLogger(BaseLogger):
+    """TensorBoardX logger implementation.
+
+    Parameters
+    ----------
+    log_dir : str
+        Directory where event files will be written.
+    comment : str, default=""
+        Comment appended to the log directory name.
+    **kwargs : dict
+        Additional keyword arguments forwarded to
+        ``tensorboardX.SummaryWriter``.
+
+    Raises
+    ------
+    ImportError
+        If ``tensorboardX`` is not installed in the current environment.
+    """
+
+    def __init__(self, log_dir: str, comment: str = "", **kwargs):
+        try:
+            from tensorboardX import SummaryWriter
+            self._SummaryWriter = SummaryWriter
+        except ImportError:
+            raise ImportError("tensorboardX is not installed. Install it with: pip install tensorboardX")
+
+        self.writer = self._SummaryWriter(log_dir=log_dir, comment=comment, **kwargs)
+        self._step = 0
+
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        if step is None:
+            step = self._step
+            self._step += 1
+
+        for key, value in metrics.items():
+            if value is not None:
+                if isinstance(value, (int, float)):
+                    self.writer.add_scalar(key, value, step)
+
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        hparam_str = "\n".join([f"{k}: {v}" for k, v in params.items()])
+        self.writer.add_text("hyperparameters", hparam_str, 0)
+
+    def finish(self) -> None:
+        if self.writer is not None:
+            self.writer.close()
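
The adapters above share the three-call contract defined by `BaseLogger`. A minimal usage sketch follows; the import path is an assumption, since this diff does not show the target file name:

```python
# Minimal sketch of the logger contract above; the module path is assumed, adjust as needed.
# from torch_rechub.utils.tracking import TensorBoardXLogger

logger = TensorBoardXLogger(log_dir="./runs", comment="hstu_baseline")
logger.log_hyperparams({"lr": 1e-3, "batch_size": 1024, "model": "HSTU"})

for epoch in range(3):
    # Only int/float values are written; TensorBoardXLogger skips non-numeric entries.
    logger.log_metrics({"train/loss": 0.52 - 0.10 * epoch, "val/auc": 0.71 + 0.01 * epoch}, step=epoch)

logger.finish()  # closes the underlying SummaryWriter
```

Swapping in `WandbLogger(project=...)` or `SwanLabLogger(project=...)` requires no change to the loop.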
@@ -0,0 +1,67 @@
+"""Utilities for converting array-like data structures into PyTorch tensors."""
+
+import numpy.typing as npt
+import pyarrow as pa
+import pyarrow.compute as pc
+import pyarrow.types as pt
+import torch
+
+
+def pa_array_to_tensor(arr: pa.Array) -> torch.Tensor:
+    """
+    Convert a PyArrow array to a PyTorch tensor.
+
+    Parameters
+    ----------
+    arr : pa.Array
+        The given PyArrow array.
+
+    Returns
+    -------
+    torch.Tensor
+        The resulting PyTorch tensor.
+
+    Raises
+    ------
+    TypeError
+        If the array type or the value type (when nested) is unsupported.
+    ValueError
+        If the nested array is ragged (rows have unequal lengths).
+    """
+    if _is_supported_scalar(arr.type):
+        arr = pc.cast(arr, pa.float32())
+        return torch.from_numpy(_to_writable_numpy(arr))
+
+    if not _is_supported_list(arr.type):
+        raise TypeError(f"Unsupported array type: {arr.type}")
+
+    if not _is_supported_scalar(val_type := arr.type.value_type):
+        raise TypeError(f"Unsupported value type in the nested array: {val_type}")
+
+    if len(pc.unique(pc.list_value_length(arr))) > 1:
+        raise ValueError("Cannot convert a ragged nested array.")
+
+    arr = pc.cast(arr, pa.list_(pa.float32()))
+    np_arr = _to_writable_numpy(arr.values)  # type: ignore[attr-defined]
+
+    # For an empty list-of-lists, define the output shape as (0, 0); otherwise infer the width.
+    return torch.from_numpy(np_arr.reshape(len(arr), -1 if len(arr) > 0 else 0))
+
+
+# helper functions
+
+
+def _is_supported_list(t: pa.DataType) -> bool:
+    """Check if the given PyArrow data type is a supported list type."""
+    return pt.is_fixed_size_list(t) or pt.is_large_list(t) or pt.is_list(t)
+
+
+def _is_supported_scalar(t: pa.DataType) -> bool:
+    """Check if the given PyArrow data type is a supported scalar type."""
+    return pt.is_boolean(t) or pt.is_floating(t) or pt.is_integer(t) or pt.is_null(t)
+
+
+def _to_writable_numpy(arr: pa.Array) -> npt.NDArray:
+    """Dump a PyArrow array into a writable NumPy array."""
+    # Force the NumPy array to be writable. PyArrow's to_numpy() often returns a
+    # read-only view for zero-copy, which PyTorch's from_numpy() does not support.
+    return arr.to_numpy(writable=True, zero_copy_only=False)
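
A hedged sketch of the conversion rules implemented by `pa_array_to_tensor`; the input arrays are illustrative:

```python
import pyarrow as pa

scalars = pa.array([1, 2, 3])                # int64 -> cast to float32, shape (3,)
nested = pa.array([[1.0, 2.0], [3.0, 4.0]])  # equal-length lists -> shape (2, 2)

print(pa_array_to_tensor(scalars))           # tensor([1., 2., 3.])
print(pa_array_to_tensor(nested).shape)      # torch.Size([2, 2])

ragged = pa.array([[1.0], [2.0, 3.0]])
# pa_array_to_tensor(ragged) raises ValueError: ragged nested arrays are rejected.
# pa.array(["a", "b"]) would raise TypeError: strings are not a supported scalar type.
```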
@@ -0,0 +1,107 @@
+"""Dataset implementations providing streaming, batch-wise data access for PyTorch."""
+
+import typing as ty
+
+import pyarrow.dataset as pd
+import torch
+from torch.utils.data import IterableDataset, get_worker_info
+
+from torch_rechub.types import FilePath
+
+from .convert import pa_array_to_tensor
+
+# The default batch size when reading a Parquet dataset
+_DEFAULT_BATCH_SIZE = 1024
+
+
+class ParquetIterableDataset(IterableDataset):
+    """Stream Parquet data as PyTorch tensors.
+
+    Parameters
+    ----------
+    file_paths : list[FilePath]
+        Paths to Parquet files.
+    columns : list[str], optional
+        Columns to select; if ``None``, read all columns.
+    batch_size : int, default _DEFAULT_BATCH_SIZE
+        Rows per streamed batch.
+
+    Notes
+    -----
+    Reads lazily; the Parquet files are never loaded in full. Each worker gets a
+    partition of the files, builds its own PyArrow Dataset/Scanner, and yields
+    dicts of column tensors batch by batch.
+
+    Examples
+    --------
+    >>> ds = ParquetIterableDataset(
+    ...     ["/data/train1.parquet", "/data/train2.parquet"],
+    ...     columns=["x", "y", "label"],
+    ...     batch_size=1024,
+    ... )
+    >>> loader = DataLoader(ds, batch_size=None)
+    >>> for batch in loader:
+    ...     x, y, label = batch["x"], batch["y"], batch["label"]
+    ...     ...
+    """
+
+    def __init__(
+        self,
+        file_paths: ty.Sequence[FilePath],
+        /,
+        columns: ty.Optional[ty.Sequence[str]] = None,
+        batch_size: int = _DEFAULT_BATCH_SIZE,
+    ) -> None:
+        """Initialize this instance."""
+        self._file_paths = tuple(map(str, file_paths))
+        self._columns = None if columns is None else tuple(columns)
+        self._batch_size = batch_size
+
+    def __iter__(self) -> ty.Iterator[dict[str, torch.Tensor]]:
+        """Stream Parquet data as mapped PyTorch tensors.
+
+        Builds a PyArrow Dataset from the current worker's file partition, then
+        lazily scans the selected columns. Each batch becomes a dict of Torch tensors.
+
+        Returns
+        -------
+        Iterator[dict[str, torch.Tensor]]
+            One converted batch at a time.
+        """
+        if not (partition := self._get_partition()):
+            return
+
+        # Build the dataset for the current worker.
+        ds = pd.dataset(partition, format="parquet")
+
+        # Create a scanner. This does not read any data yet.
+        columns = None if self._columns is None else list(self._columns)
+        scanner = ds.scanner(columns=columns, batch_size=self._batch_size)
+
+        for batch in scanner.to_batches():
+            data_dict: dict[str, torch.Tensor] = {}
+            for name, array in zip(batch.column_names, batch.columns):
+                data_dict[name] = pa_array_to_tensor(array)
+            yield data_dict
+
+    # private interfaces
+
+    def _get_partition(self) -> tuple[str, ...]:
+        """Get the file partition for the current worker.
+
+        Splits the file paths into contiguous partitions by the number of workers
+        and the worker ID. In the main process (no worker info), returns all paths.
+
+        Returns
+        -------
+        tuple[str, ...]
+            Partition of file paths for this worker.
+        """
+        if (info := get_worker_info()) is None:
+            return self._file_paths
+
+        n = len(self._file_paths)
+        per_worker = (n + info.num_workers - 1) // info.num_workers
+
+        start = info.id * per_worker
+        end = min(start + per_worker, n)
+        return self._file_paths[start:end]
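
A hedged sketch of multi-worker streaming with the dataset above; the file paths and column names are hypothetical:

```python
from torch.utils.data import DataLoader

files = [f"/data/part-{i}.parquet" for i in range(4)]  # hypothetical file layout
ds = ParquetIterableDataset(files, columns=["user_id", "item_seq", "label"], batch_size=2048)

# batch_size=None because the dataset already yields ready-made batches;
# with num_workers=2, _get_partition() gives worker 0 files 0-1 and worker 1 files 2-3.
loader = DataLoader(ds, batch_size=None, num_workers=2)

for batch in loader:
    labels = batch["label"]      # float32 tensor of shape (rows_in_batch,)
    seqs = batch["item_seq"]     # (rows_in_batch, list_len) for fixed-length list columns
    break
```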
@@ -10,39 +10,54 @@ from torch_rechub.utils.hstu_utils import RelPosBias
 
 
 class HSTUModel(nn.Module):
-    """HSTU: Hierarchical Sequential Transduction Units model.
-
-    Autoregressive generative recommendation model for sequential data.
-    This module stacks multiple ``HSTUBlock`` layers to capture long-range
-    dependencies in user interaction sequences and predicts the next item.
-
-    Args:
-        vocab_size (int): Vocabulary size (number of distinct items, including PAD).
-        d_model (int): Hidden dimension of the model. Default: 512.
-        n_heads (int): Number of attention heads. Default: 8.
-        n_layers (int): Number of stacked HSTU layers. Default: 4.
-        dqk (int): Dimension of query/key vectors per head. Default: 64.
-        dv (int): Dimension of value vectors per head. Default: 64.
-        max_seq_len (int): Maximum sequence length. Default: 256.
-        dropout (float): Dropout rate applied in the model. Default: 0.1.
-        use_rel_pos_bias (bool): Whether to use relative position bias. Default: True.
-        use_time_embedding (bool): Whether to use time-difference embeddings. Default: True.
-        num_time_buckets (int): Number of time buckets for time embeddings. Default: 2048.
-        time_bucket_fn (str): Function used to bucketize time differences, ``"sqrt"``
-            or ``"log"``. Default: ``"sqrt"``.
-
-    Shape:
-        - Input: ``x`` of shape ``(batch_size, seq_len)``; optional ``time_diffs``
-          of shape ``(batch_size, seq_len)`` representing time differences in seconds.
-        - Output: Logits of shape ``(batch_size, seq_len, vocab_size)``.
-
-    Example:
-        >>> model = HSTUModel(vocab_size=100000, d_model=512)
-        >>> x = torch.randint(0, 100000, (32, 256))
-        >>> time_diffs = torch.randint(0, 86400, (32, 256))
-        >>> logits = model(x, time_diffs)
-        >>> logits.shape
-        torch.Size([32, 256, 100000])
+    """HSTU: Hierarchical Sequential Transduction Units.
+
+    Autoregressive generative recommender that stacks ``HSTUBlock`` layers to
+    capture long-range dependencies and predict the next item.
+
+    Parameters
+    ----------
+    vocab_size : int
+        Vocabulary size (items incl. PAD).
+    d_model : int, default=512
+        Hidden dimension.
+    n_heads : int, default=8
+        Attention heads.
+    n_layers : int, default=4
+        Number of stacked HSTU layers.
+    dqk : int, default=64
+        Query/key dim per head.
+    dv : int, default=64
+        Value dim per head.
+    max_seq_len : int, default=256
+        Maximum sequence length.
+    dropout : float, default=0.1
+        Dropout rate.
+    use_rel_pos_bias : bool, default=True
+        Use relative position bias.
+    use_time_embedding : bool, default=True
+        Use time-difference embeddings.
+    num_time_buckets : int, default=2048
+        Number of time buckets for time embeddings.
+    time_bucket_fn : {'sqrt', 'log'}, default='sqrt'
+        Bucketization function for time differences.
+
+    Shape
+    -----
+    Input
+        x : ``(batch_size, seq_len)``
+        time_diffs : ``(batch_size, seq_len)``, optional (seconds).
+    Output
+        logits : ``(batch_size, seq_len, vocab_size)``
+
+    Examples
+    --------
+    >>> model = HSTUModel(vocab_size=100000, d_model=512)
+    >>> x = torch.randint(0, 100000, (32, 256))
+    >>> time_diffs = torch.randint(0, 86400, (32, 256))
+    >>> logits = model(x, time_diffs)
+    >>> logits.shape
+    torch.Size([32, 256, 100000])
     """
 
     def __init__(self, vocab_size, d_model=512, n_heads=8, n_layers=4, dqk=64, dv=64, max_seq_len=256, dropout=0.1, use_rel_pos_bias=True, use_time_embedding=True, num_time_buckets=2048, time_bucket_fn='sqrt'):
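
The docstring rewrite above leaves the forward contract unchanged: token ids in, per-position logits out. As a hedged sketch (not part of this diff), the autoregressive next-item loss would typically shift targets by one position; the PAD id of 0 used for `ignore_index` is an assumption:

```python
import torch
import torch.nn.functional as F

model = HSTUModel(vocab_size=100000, d_model=512)
x = torch.randint(1, 100000, (32, 256))            # interaction sequences (0 assumed to be PAD)
time_diffs = torch.randint(0, 86400, (32, 256))    # seconds since the previous interaction

logits = model(x, time_diffs)                      # (32, 256, 100000)

# Position t predicts the item at position t + 1.
loss = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, logits.size(-1)),
    x[:, 1:].reshape(-1),
    ignore_index=0,                                # assumed PAD id
)
loss.backward()
```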
@@ -0,0 +1,50 @@
+import typing as ty
+
+from .annoy import AnnoyBuilder
+from .base import BaseBuilder
+from .faiss import FaissBuilder
+from .milvus import MilvusBuilder
+
+# Type for supported retrieval models.
+_RetrievalModel = ty.Literal["annoy", "faiss", "milvus"]
+
+
+def builder_factory(model: _RetrievalModel, **builder_config) -> BaseBuilder:
+    """
+    Factory function for creating a vector index builder.
+
+    This function instantiates and returns a concrete implementation of ``BaseBuilder``
+    based on the specified retrieval backend. The returned builder is responsible for
+    constructing or loading the underlying ANN index via its own ``from_embeddings`` or
+    ``from_index_file`` method.
+
+    Parameters
+    ----------
+    model : {"annoy", "faiss", "milvus"}
+        The retrieval backend to use.
+    **builder_config
+        Keyword arguments passed directly to the selected builder's constructor.
+
+    Returns
+    -------
+    BaseBuilder
+        A concrete builder instance corresponding to the specified retrieval backend.
+
+    Raises
+    ------
+    NotImplementedError
+        If the specified retrieval model is not supported.
+    """
+    if model == "annoy":
+        return AnnoyBuilder(**builder_config)
+
+    if model == "faiss":
+        return FaissBuilder(**builder_config)
+
+    if model == "milvus":
+        return MilvusBuilder(**builder_config)
+
+    raise NotImplementedError(f"{model=} is not implemented yet!")
+
+
+__all__ = ["builder_factory"]
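
A short sketch of the factory in use; the keyword arguments are forwarded verbatim to the chosen builder (here `AnnoyBuilder`, defined in the following hunk):

```python
# Builder selection by backend name; kwargs go straight to AnnoyBuilder.__init__.
builder = builder_factory("annoy", d=64, metric="angular", n_trees=20)

# Unknown backends fail fast:
# builder_factory("hnsw", d=64)  ->  NotImplementedError: model='hnsw' is not implemented yet!
```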
@@ -0,0 +1,133 @@
+"""ANNOY-based vector index implementation for the retrieval stage."""
+
+import contextlib
+import typing as ty
+
+import annoy
+import numpy as np
+import torch
+
+from torch_rechub.types import FilePath
+
+from .base import BaseBuilder, BaseIndexer
+
+# Type for distance metrics for the ANNOY index.
+_AnnoyMetric = ty.Literal["angular", "euclidean", "dot"]
+
+# Default distance metric used by ANNOY.
+_DEFAULT_METRIC: _AnnoyMetric = "angular"
+
+# Default number of trees to build in the ANNOY index.
+_DEFAULT_N_TREES = 10
+
+# Default number of worker threads for building the ANNOY index.
+_DEFAULT_THREADS = -1
+
+# Default number of nodes to inspect during an ANNOY search.
+_DEFAULT_SEARCHK = -1
+
+
+class AnnoyBuilder(BaseBuilder):
+    """ANNOY-based implementation of ``BaseBuilder``."""
+
+    def __init__(
+        self,
+        d: int,
+        metric: _AnnoyMetric = _DEFAULT_METRIC,
+        *,
+        n_trees: int = _DEFAULT_N_TREES,
+        threads: int = _DEFAULT_THREADS,
+        searchk: int = _DEFAULT_SEARCHK,
+    ) -> None:
+        """
+        Initialize an ANNOY builder.
+
+        Parameters
+        ----------
+        d : int
+            The dimension of the embeddings.
+        metric : ``"angular"``, ``"euclidean"``, or ``"dot"``, optional
+            The indexing metric. Defaults to ``"angular"``.
+        n_trees : int, optional
+            Number of trees used to build the ANNOY index.
+        threads : int, optional
+            Number of worker threads used to build the ANNOY index.
+        searchk : int, optional
+            Number of nodes to inspect during an ANNOY search.
+        """
+        self._d = d
+        self._metric = metric
+
+        self._n_trees = n_trees
+        self._threads = threads
+        self._searchk = searchk
+
+    @contextlib.contextmanager
+    def from_embeddings(self, embeddings: torch.Tensor) -> ty.Generator["AnnoyIndexer", None, None]:
+        """Adhere to ``BaseBuilder.from_embeddings``."""
+        index = annoy.AnnoyIndex(self._d, metric=self._metric)
+
+        for idx, emb in enumerate(embeddings):
+            index.add_item(idx, emb)
+
+        index.build(self._n_trees, n_jobs=self._threads)
+
+        try:
+            yield AnnoyIndexer(index, self._searchk)
+        finally:
+            index.unload()
+
+    @contextlib.contextmanager
+    def from_index_file(self, index_file: FilePath) -> ty.Generator["AnnoyIndexer", None, None]:
+        """Adhere to ``BaseBuilder.from_index_file``."""
+        index = annoy.AnnoyIndex(self._d, metric=self._metric)
+        index.load(str(index_file))
+
+        try:
+            yield AnnoyIndexer(index, searchk=self._searchk)
+        finally:
+            index.unload()
+
+
+class AnnoyIndexer(BaseIndexer):
+    """ANNOY-based implementation of ``BaseIndexer``."""
+
+    def __init__(self, index: annoy.AnnoyIndex, searchk: int) -> None:
+        """Initialize an ANNOY indexer."""
+        self._index = index
+        self._searchk = searchk
+
+    def query(self, embeddings: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
+        """Adhere to ``BaseIndexer.query``."""
+        n, _ = embeddings.shape
+        nn_ids = np.zeros((n, top_k), dtype=np.int64)
+        nn_distances = np.zeros((n, top_k), dtype=np.float32)
+
+        for idx, emb in enumerate(embeddings):
+            nn_ids[idx], nn_distances[idx] = self._index.get_nns_by_vector(
+                emb.cpu().numpy(),
+                top_k,
+                search_k=self._searchk,
+                include_distances=True,
+            )
+
+        return torch.from_numpy(nn_ids), torch.from_numpy(nn_distances)
+
+    def save(self, file_path: FilePath) -> None:
+        """Adhere to ``BaseIndexer.save``."""
+        self._index.save(str(file_path))
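
A hedged round-trip sketch for the ANNOY builder and indexer above; the embeddings are random and the index file name is illustrative:

```python
import torch

item_embeddings = torch.randn(1000, 64)    # 1000 items, d = 64
user_embeddings = torch.randn(8, 64)       # 8 query vectors

builder = AnnoyBuilder(d=64, metric="angular", n_trees=20)

with builder.from_embeddings(item_embeddings) as indexer:
    ids, distances = indexer.query(user_embeddings, top_k=10)
    print(ids.shape, distances.shape)      # torch.Size([8, 10]) torch.Size([8, 10])
    indexer.save("items.ann")              # illustrative path; reusable via from_index_file

with builder.from_index_file("items.ann") as indexer:
    ids, _ = indexer.query(user_embeddings, top_k=10)
```

Both context managers call `index.unload()` on exit, so the indexer should not be used outside the `with` block.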