torch-rechub 0.0.5-py3-none-any.whl → 0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/layers.py +213 -150
- torch_rechub/basic/loss_func.py +62 -47
- torch_rechub/basic/tracking.py +198 -0
- torch_rechub/data/__init__.py +0 -0
- torch_rechub/data/convert.py +67 -0
- torch_rechub/data/dataset.py +107 -0
- torch_rechub/models/generative/hstu.py +48 -33
- torch_rechub/serving/__init__.py +50 -0
- torch_rechub/serving/annoy.py +133 -0
- torch_rechub/serving/base.py +107 -0
- torch_rechub/serving/faiss.py +154 -0
- torch_rechub/serving/milvus.py +215 -0
- torch_rechub/trainers/ctr_trainer.py +52 -3
- torch_rechub/trainers/match_trainer.py +52 -3
- torch_rechub/trainers/mtl_trainer.py +61 -3
- torch_rechub/trainers/seq_trainer.py +93 -17
- torch_rechub/types.py +5 -0
- torch_rechub/utils/data.py +167 -137
- torch_rechub/utils/hstu_utils.py +87 -76
- torch_rechub/utils/model_utils.py +10 -12
- torch_rechub/utils/onnx_export.py +98 -45
- torch_rechub/utils/quantization.py +128 -0
- torch_rechub/utils/visualization.py +4 -12
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/METADATA +20 -5
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/RECORD +27 -17
- torch_rechub/trainers/matching.md +0 -3
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/licenses/LICENSE +0 -0
torch_rechub/basic/tracking.py
@@ -0,0 +1,198 @@
+"""Experiment tracking utilities for Torch-RecHub.
+
+This module exposes lightweight adapters for common visualization and
+experiment tracking tools, namely Weights & Biases (wandb), SwanLab, and
+TensorBoardX.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union
+
+
+class BaseLogger(ABC):
+    """Base interface for experiment tracking backends.
+
+    Methods
+    -------
+    log_metrics(metrics, step=None)
+        Record scalar metrics at a given step.
+    log_hyperparams(params)
+        Store hyperparameters and run configuration.
+    finish()
+        Flush pending logs and release resources.
+    """
+
+    @abstractmethod
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        """Log metrics to the tracking backend.
+
+        Parameters
+        ----------
+        metrics : dict of str to Any
+            Metric name-value pairs to record.
+        step : int, optional
+            Explicit global step or epoch index. When ``None``, the backend
+            uses its own default step handling.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        """Log experiment hyperparameters.
+
+        Parameters
+        ----------
+        params : dict of str to Any
+            Hyperparameters or configuration values to persist with the run.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def finish(self) -> None:
+        """Finalize logging and free any backend resources."""
+        raise NotImplementedError
+
+
+class WandbLogger(BaseLogger):
+    """Weights & Biases logger implementation.
+
+    Parameters
+    ----------
+    project : str
+        Name of the wandb project to log to.
+    name : str, optional
+        Display name for the run.
+    config : dict, optional
+        Initial hyperparameter configuration to record.
+    tags : list of str, optional
+        Optional tags for grouping runs.
+    notes : str, optional
+        Long-form notes shown in the run overview.
+    dir : str, optional
+        Local directory for wandb artifacts and cache.
+    **kwargs : dict
+        Additional keyword arguments forwarded to ``wandb.init``.
+
+    Raises
+    ------
+    ImportError
+        If ``wandb`` is not installed in the current environment.
+    """
+
+    def __init__(self, project: str, name: Optional[str] = None, config: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, notes: Optional[str] = None, dir: Optional[str] = None, **kwargs):
+        try:
+            import wandb
+            self._wandb = wandb
+        except ImportError:
+            raise ImportError("wandb is not installed. Install it with: pip install wandb")
+
+        self.run = self._wandb.init(project=project, name=name, config=config, tags=tags, notes=notes, dir=dir, **kwargs)
+
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        if step is not None:
+            self._wandb.log(metrics, step=step)
+        else:
+            self._wandb.log(metrics)
+
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        if self.run is not None:
+            self.run.config.update(params)
+
+    def finish(self) -> None:
+        if self.run is not None:
+            self.run.finish()
+
+
+class SwanLabLogger(BaseLogger):
+    """SwanLab logger implementation.
+
+    Parameters
+    ----------
+    project : str, optional
+        Project identifier for grouping experiments.
+    experiment_name : str, optional
+        Display name for the experiment or run.
+    description : str, optional
+        Text description shown alongside the run.
+    config : dict, optional
+        Hyperparameters or configuration to log at startup.
+    logdir : str, optional
+        Directory where logs and artifacts are stored.
+    **kwargs : dict
+        Additional keyword arguments forwarded to ``swanlab.init``.
+
+    Raises
+    ------
+    ImportError
+        If ``swanlab`` is not installed in the current environment.
+    """
+
+    def __init__(self, project: Optional[str] = None, experiment_name: Optional[str] = None, description: Optional[str] = None, config: Optional[Dict[str, Any]] = None, logdir: Optional[str] = None, **kwargs):
+        try:
+            import swanlab
+            self._swanlab = swanlab
+        except ImportError:
+            raise ImportError("swanlab is not installed. Install it with: pip install swanlab")
+
+        self.run = self._swanlab.init(project=project, experiment_name=experiment_name, description=description, config=config, logdir=logdir, **kwargs)
+
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        if step is not None:
+            self._swanlab.log(metrics, step=step)
+        else:
+            self._swanlab.log(metrics)
+
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        if self.run is not None:
+            self.run.config.update(params)
+
+    def finish(self) -> None:
+        self._swanlab.finish()
+
+
+class TensorBoardXLogger(BaseLogger):
+    """TensorBoardX logger implementation.
+
+    Parameters
+    ----------
+    log_dir : str
+        Directory where event files will be written.
+    comment : str, default=""
+        Comment appended to the log directory name.
+    **kwargs : dict
+        Additional keyword arguments forwarded to
+        ``tensorboardX.SummaryWriter``.
+
+    Raises
+    ------
+    ImportError
+        If ``tensorboardX`` is not installed in the current environment.
+    """
+
+    def __init__(self, log_dir: str, comment: str = "", **kwargs):
+        try:
+            from tensorboardX import SummaryWriter
+            self._SummaryWriter = SummaryWriter
+        except ImportError:
+            raise ImportError("tensorboardX is not installed. Install it with: pip install tensorboardX")
+
+        self.writer = self._SummaryWriter(log_dir=log_dir, comment=comment, **kwargs)
+        self._step = 0
+
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        if step is None:
+            step = self._step
+            self._step += 1
+
+        for key, value in metrics.items():
+            if value is not None:
+                if isinstance(value, (int, float)):
+                    self.writer.add_scalar(key, value, step)
+
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        hparam_str = "\n".join([f"{k}: {v}" for k, v in params.items()])
+        self.writer.add_text("hyperparameters", hparam_str, 0)
+
+    def finish(self) -> None:
+        if self.writer is not None:
+            self.writer.close()
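For orientation, a minimal sketch of how one of these adapters could be driven from a training loop; the log directory, metric names, hyperparameters, and the two-epoch loop are illustrative, not part of the package:

    from torch_rechub.basic.tracking import TensorBoardXLogger

    logger = TensorBoardXLogger(log_dir="./runs/demo")   # hypothetical log directory
    logger.log_hyperparams({"lr": 1e-3, "batch_size": 256})
    for epoch in range(2):
        # Only int/float values are written; non-numeric values are skipped by this backend.
        logger.log_metrics({"train_loss": 0.5, "val_auc": 0.7}, step=epoch)
    logger.finish()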
torch_rechub/data/__init__.py: file without changes
torch_rechub/data/convert.py
@@ -0,0 +1,67 @@
+"""Utilities for converting array-like data structures into PyTorch tensors."""
+
+import numpy.typing as npt
+import pyarrow as pa
+import pyarrow.compute as pc
+import pyarrow.types as pt
+import torch
+
+
+def pa_array_to_tensor(arr: pa.Array) -> torch.Tensor:
+    """
+    Convert a PyArrow array to a PyTorch tensor.
+
+    Parameters
+    ----------
+    arr : pa.Array
+        The given PyArrow array.
+
+    Returns
+    -------
+    torch.Tensor: The result PyTorch tensor.
+
+    Raises
+    ------
+    TypeError
+        if the array type or the value type (when nested) is unsupported.
+    ValueError
+        if the nested array is ragged (unequal lengths of each row).
+    """
+    if _is_supported_scalar(arr.type):
+        arr = pc.cast(arr, pa.float32())
+        return torch.from_numpy(_to_writable_numpy(arr))
+
+    if not _is_supported_list(arr.type):
+        raise TypeError(f"Unsupported array type: {arr.type}")
+
+    if not _is_supported_scalar(val_type := arr.type.value_type):
+        raise TypeError(f"Unsupported value type in the nested array: {val_type}")
+
+    if len(pc.unique(pc.list_value_length(arr))) > 1:
+        raise ValueError("Cannot convert the ragged nested array.")
+
+    arr = pc.cast(arr, pa.list_(pa.float32()))
+    np_arr = _to_writable_numpy(arr.values)  # type: ignore[attr-defined]
+
+    # For empty list-of-lists, define output shape as (0, 0); otherwise infer width.
+    return torch.from_numpy(np_arr.reshape(len(arr), -1 if len(arr) > 0 else 0))
+
+
+# helper functions
+
+
+def _is_supported_list(t: pa.DataType) -> bool:
+    """Check if the given PyArrow data type is a supported list."""
+    return pt.is_fixed_size_list(t) or pt.is_large_list(t) or pt.is_list(t)
+
+
+def _is_supported_scalar(t: pa.DataType) -> bool:
+    """Check if the given PyArrow data type is a supported scalar type."""
+    return pt.is_boolean(t) or pt.is_floating(t) or pt.is_integer(t) or pt.is_null(t)
+
+
+def _to_writable_numpy(arr: pa.Array) -> npt.NDArray:
+    """Dump a PyArrow array into a writable NumPy array."""
+    # Force the NumPy array to be writable. PyArrow's to_numpy() often returns a
+    # read-only view for zero-copy, which PyTorch's from_numpy() does not support.
+    return arr.to_numpy(writable=True, zero_copy_only=False)
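A quick illustration of the conversion rules above; the sample arrays are made up for the example:

    import pyarrow as pa
    from torch_rechub.data.convert import pa_array_to_tensor

    # Scalar column: values are cast to float32 and returned as a 1-D tensor.
    pa_array_to_tensor(pa.array([1, 2, 3]))         # tensor([1., 2., 3.])

    # Equal-length nested column: flattened values are reshaped to (rows, width).
    pa_array_to_tensor(pa.array([[1, 2], [3, 4]]))  # shape (2, 2), dtype float32

    # Ragged nested column: rejected with ValueError.
    # pa_array_to_tensor(pa.array([[1], [2, 3]]))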
torch_rechub/data/dataset.py
@@ -0,0 +1,107 @@
+"""Dataset implementations providing streaming, batch-wise data access for PyTorch."""
+
+import typing as ty
+
+import pyarrow.dataset as pd
+import torch
+from torch.utils.data import IterableDataset, get_worker_info
+
+from torch_rechub.types import FilePath
+
+from .convert import pa_array_to_tensor
+
+# The default batch size when reading a Parquet dataset
+_DEFAULT_BATCH_SIZE = 1024
+
+
+class ParquetIterableDataset(IterableDataset):
+    """Stream Parquet data as PyTorch tensors.
+
+    Parameters
+    ----------
+    file_paths : list[FilePath]
+        Paths to Parquet files.
+    columns : list[str], optional
+        Columns to select; if ``None``, read all columns.
+    batch_size : int, default _DEFAULT_BATCH_SIZE
+        Rows per streamed batch.
+
+    Notes
+    -----
+    Reads lazily; no full Parquet load. Each worker gets a partition, builds its
+    own PyArrow Dataset/Scanner, and yields dicts of column tensors batch by batch.
+
+    Examples
+    --------
+    >>> ds = ParquetIterableDataset(
+    ...     ["/data/train1.parquet", "/data/train2.parquet"],
+    ...     columns=["x", "y", "label"],
+    ...     batch_size=1024,
+    ... )
+    >>> loader = DataLoader(ds, batch_size=None)
+    >>> for batch in loader:
+    ...     x, y, label = batch["x"], batch["y"], batch["label"]
+    ...     ...
+    """
+
+    def __init__(
+        self,
+        file_paths: ty.Sequence[FilePath],
+        /,
+        columns: ty.Optional[ty.Sequence[str]] = None,
+        batch_size: int = _DEFAULT_BATCH_SIZE,
+    ) -> None:
+        """Initialize this instance."""
+        self._file_paths = tuple(map(str, file_paths))
+        self._columns = None if columns is None else tuple(columns)
+        self._batch_size = batch_size
+
+    def __iter__(self) -> ty.Iterator[dict[str, torch.Tensor]]:
+        """Stream Parquet data as mapped PyTorch tensors.
+
+        Builds a PyArrow Dataset from the current worker's file partition, then
+        lazily scans selected columns. Each batch becomes a dict of Torch tensors.
+
+        Returns
+        -------
+        Iterator[dict[str, torch.Tensor]]
+            One converted batch at a time.
+        """
+        if not (partition := self._get_partition()):
+            return
+
+        # Build the dataset for the current worker.
+        ds = pd.dataset(partition, format="parquet")
+
+        # Create a scanner. This does not read data.
+        columns = None if self._columns is None else list(self._columns)
+        scanner = ds.scanner(columns=columns, batch_size=self._batch_size)
+
+        for batch in scanner.to_batches():
+            data_dict: dict[str, torch.Tensor] = {}
+            for name, array in zip(batch.column_names, batch.columns):
+                data_dict[name] = pa_array_to_tensor(array)
+            yield data_dict
+
+    # private interfaces
+
+    def _get_partition(self) -> tuple[str, ...]:
+        """Get file partition for the current worker.
+
+        Splits file paths into contiguous partitions by number of workers and worker ID.
+        In the main process (no worker info), returns all paths.
+
+        Returns
+        -------
+        tuple[str, ...]
+            Partition of file paths for this worker.
+        """
+        if (info := get_worker_info()) is None:
+            return self._file_paths
+
+        n = len(self._file_paths)
+        per_worker = (n + info.num_workers - 1) // info.num_workers
+
+        start = info.id * per_worker
+        end = n if (end := start + per_worker) > n else end
+        return self._file_paths[start:end]
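Because _get_partition splits the file list contiguously across DataLoader workers, each worker streams a disjoint subset of files; with two workers and three files, worker 0 would read the first two files and worker 1 the third. A sketch with hypothetical file names and columns:

    from torch.utils.data import DataLoader
    from torch_rechub.data.dataset import ParquetIterableDataset

    ds = ParquetIterableDataset(
        ["part-0.parquet", "part-1.parquet", "part-2.parquet"],  # hypothetical paths
        columns=["x", "label"],
        batch_size=512,
    )
    # batch_size=None: the dataset already yields ready-made batches (dicts of tensors).
    loader = DataLoader(ds, batch_size=None, num_workers=2)
    for batch in loader:
        x, label = batch["x"], batch["label"]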
torch_rechub/models/generative/hstu.py
@@ -10,39 +10,54 @@ from torch_rechub.utils.hstu_utils import RelPosBias
 
 
 class HSTUModel(nn.Module):
-    """HSTU: Hierarchical Sequential Transduction Units
-
-    Autoregressive generative
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """HSTU: Hierarchical Sequential Transduction Units.
+
+    Autoregressive generative recommender that stacks ``HSTUBlock`` layers to
+    capture long-range dependencies and predict the next item.
+
+    Parameters
+    ----------
+    vocab_size : int
+        Vocabulary size (items incl. PAD).
+    d_model : int, default=512
+        Hidden dimension.
+    n_heads : int, default=8
+        Attention heads.
+    n_layers : int, default=4
+        Number of stacked HSTU layers.
+    dqk : int, default=64
+        Query/key dim per head.
+    dv : int, default=64
+        Value dim per head.
+    max_seq_len : int, default=256
+        Maximum sequence length.
+    dropout : float, default=0.1
+        Dropout rate.
+    use_rel_pos_bias : bool, default=True
+        Use relative position bias.
+    use_time_embedding : bool, default=True
+        Use time-difference embeddings.
+    num_time_buckets : int, default=2048
+        Number of time buckets for time embeddings.
+    time_bucket_fn : {'sqrt', 'log'}, default='sqrt'
+        Bucketization function for time differences.
+
+    Shape
+    -----
+    Input
+        x : ``(batch_size, seq_len)``
+        time_diffs : ``(batch_size, seq_len)``, optional (seconds).
+    Output
+        logits : ``(batch_size, seq_len, vocab_size)``
+
+    Examples
+    --------
+    >>> model = HSTUModel(vocab_size=100000, d_model=512)
+    >>> x = torch.randint(0, 100000, (32, 256))
+    >>> time_diffs = torch.randint(0, 86400, (32, 256))
+    >>> logits = model(x, time_diffs)
+    >>> logits.shape
+    torch.Size([32, 256, 100000])
     """
 
     def __init__(self, vocab_size, d_model=512, n_heads=8, n_layers=4, dqk=64, dv=64, max_seq_len=256, dropout=0.1, use_rel_pos_bias=True, use_time_embedding=True, num_time_buckets=2048, time_bucket_fn='sqrt'):
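Since the model is autoregressive and emits logits of shape (batch_size, seq_len, vocab_size), next-item candidates for each sequence can be read off the last position; a minimal sketch continuing the docstring example, where the top-k cutoff of 10 is illustrative:

    logits = model(x, time_diffs)        # (batch_size, seq_len, vocab_size)
    next_scores = logits[:, -1, :]       # scores over the vocabulary for the next item
    top_scores, top_items = next_scores.topk(10, dim=-1)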
torch_rechub/serving/__init__.py
@@ -0,0 +1,50 @@
+import typing as ty
+
+from .annoy import AnnoyBuilder
+from .base import BaseBuilder
+from .faiss import FaissBuilder
+from .milvus import MilvusBuilder
+
+# Type for supported retrieval models.
+_RetrievalModel = ty.Literal["annoy", "faiss", "milvus"]
+
+
+def builder_factory(model: _RetrievalModel, **builder_config) -> BaseBuilder:
+    """
+    Factory function for creating a vector index builder.
+
+    This function instantiates and returns a concrete implementation of ``BaseBuilder``
+    based on the specified retrieval backend. The returned builder is responsible for
+    constructing or loading the underlying ANN index via its own ``from_embeddings`` or
+    ``from_index_file`` method.
+
+    Parameters
+    ----------
+    model : "annoy", "faiss", or "milvus"
+        The retrieval backend to use.
+    **builder_config
+        Keyword arguments passed directly to the selected builder constructor.
+
+    Returns
+    -------
+    BaseBuilder
+        A concrete builder instance corresponding to the specified retrieval backend.
+
+    Raises
+    ------
+    NotImplementedError
+        if the specified retrieval model is not supported.
+    """
+    if model == "annoy":
+        return AnnoyBuilder(**builder_config)
+
+    if model == "faiss":
+        return FaissBuilder(**builder_config)
+
+    if model == "milvus":
+        return MilvusBuilder(**builder_config)
+
+    raise NotImplementedError(f"{model=} is not implemented yet!")
+
+
+__all__ = ["builder_factory"]
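To show how the factory and the ANNOY backend fit together, a minimal sketch of building an index from item embeddings and querying it; the embedding tensors, dimensions, and top_k value are illustrative:

    import torch
    from torch_rechub.serving import builder_factory

    item_emb = torch.randn(1000, 64)   # hypothetical item-tower embeddings
    user_emb = torch.randn(8, 64)      # hypothetical query embeddings

    builder = builder_factory("annoy", d=64)  # extra kwargs reach AnnoyBuilder.__init__
    with builder.from_embeddings(item_emb) as indexer:
        ids, dists = indexer.query(user_emb, top_k=10)  # each of shape (8, 10)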
torch_rechub/serving/annoy.py
@@ -0,0 +1,133 @@
+"""ANNOY-based vector index implementation for the retrieval stage."""
+
+import contextlib
+import typing as ty
+
+import annoy
+import numpy as np
+import torch
+
+from torch_rechub.types import FilePath
+
+from .base import BaseBuilder, BaseIndexer
+
+# Type for distance metrics for the ANNOY index.
+_AnnoyMetric = ty.Literal["angular", "euclidean", "dot"]
+
+# Default distance metric used by ANNOY.
+_DEFAULT_METRIC: _AnnoyMetric = "angular"
+
+# Default number of trees to build in the ANNOY index.
+_DEFAULT_N_TREES = 10
+
+# Default number of worker threads for building the ANNOY index.
+_DEFAULT_THREADS = -1
+
+# Default number of nodes to inspect during an ANNOY search.
+_DEFAULT_SEARCHK = -1
+
+
+class AnnoyBuilder(BaseBuilder):
+    """ANNOY-based implementation of ``BaseBuilder``."""
+
+    def __init__(
+        self,
+        d: int,
+        metric: _AnnoyMetric = _DEFAULT_METRIC,
+        *,
+        n_trees: int = _DEFAULT_N_TREES,
+        threads: int = _DEFAULT_THREADS,
+        searchk: int = _DEFAULT_SEARCHK,
+    ) -> None:
+        """
+        Initialize a ANNOY builder.
+
+        Parameters
+        ----------
+        d : int
+            The dimension of embeddings.
+        metric : ``"angular"``, ``"euclidean"``, or ``"dot"``, optional
+            The indexing metric. Default to ``"angular"``.
+        n_trees : int, optional
+            Number of trees to build an ANNOY index.
+        threads : int, optional
+            Number of worker threads to build an ANNOY index.
+        searchk : int, optional
+            Number of nodes to inspect during an ANNOY search.
+        """
+        self._d = d
+        self._metric = metric
+
+        self._n_trees = n_trees
+        self._threads = threads
+        self._searchk = searchk
+
+    @contextlib.contextmanager
+    def from_embeddings(
+        self,
+        embeddings: torch.Tensor,
+    ) -> ty.Generator["AnnoyIndexer",
+                      None,
+                      None]:
+        """Adhere to ``BaseBuilder.from_embeddings``."""
+        index = annoy.AnnoyIndex(self._d, metric=self._metric)
+
+        for idx, emb in enumerate(embeddings):
+            index.add_item(idx, emb)
+
+        index.build(self._n_trees, n_jobs=self._threads)
+
+        try:
+            yield AnnoyIndexer(index, self._searchk)
+        finally:
+            index.unload()
+
+    @contextlib.contextmanager
+    def from_index_file(
+        self,
+        index_file: FilePath,
+    ) -> ty.Generator["AnnoyIndexer",
+                      None,
+                      None]:
+        """Adhere to ``BaseBuilder.from_index_file``."""
+        index = annoy.AnnoyIndex(self._d, metric=self._metric)
+        index.load(str(index_file))
+
+        try:
+            yield AnnoyIndexer(index, searchk=self._searchk)
+        finally:
+            index.unload()
+
+
+class AnnoyIndexer(BaseIndexer):
+    """ANNOY-based implementation of ``BaseIndexer``."""
+
+    def __init__(self, index: annoy.AnnoyIndex, searchk: int) -> None:
+        """Initialize a ANNOY indexer."""
+        self._index = index
+        self._searchk = searchk
+
+    def query(
+        self,
+        embeddings: torch.Tensor,
+        top_k: int,
+    ) -> tuple[torch.Tensor,
+               torch.Tensor]:
+        """Adhere to ``BaseIndexer.query``."""
+        n, _ = embeddings.shape
+        nn_ids = np.zeros((n, top_k), dtype=np.int64)
+        nn_distances = np.zeros((n, top_k), dtype=np.float32)
+
+        for idx, emb in enumerate(embeddings):
+            nn_ids[idx], nn_distances[idx] = self._index.get_nns_by_vector(
+                emb.cpu().numpy(),
+                top_k,
+                search_k=self._searchk,
+                include_distances=True,
+            )
+
+        return torch.from_numpy(nn_ids), torch.from_numpy(nn_distances)
+
+    def save(self, file_path: FilePath) -> None:
+        """Adhere to ``BaseIndexer.save``."""
+        self._index.save(str(file_path))
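Since the indexer exposes save and the builder exposes from_index_file, a built index can be persisted and reloaded without recomputing the trees; a sketch continuing the example above, where the file name is illustrative:

    with builder.from_embeddings(item_emb) as indexer:
        indexer.save("items.ann")                # write the ANNOY index to disk

    # Later: reload the prebuilt index instead of rebuilding it from embeddings.
    with builder.from_index_file("items.ann") as indexer:
        ids, dists = indexer.query(user_emb, top_k=10)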