torch-rechub 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,107 @@
1
+ """Base abstraction for vector indexers used in the retrieval stage."""
2
+
3
+ import abc
4
+ import typing as ty
5
+
6
+ import torch
7
+
8
+ from torch_rechub.types import FilePath
9
+
10
+
11
class BaseBuilder(abc.ABC):
    """
    Abstract base class for vector index construction.

    A concrete builder carries every build-time setting and hands out a fully
    initialized ``BaseIndexer`` through a context-managed build step, so any
    backend resources can be released when the ``with`` block exits.

    Examples
    --------
    >>> builder = BaseBuilder(...)
    >>> embeddings = torch.randn(1000, 128)
    >>> with builder.from_embeddings(embeddings) as indexer:
    ...     ids, scores = indexer.query(embeddings[:2], top_k=5)
    ...     indexer.save("index.bin")
    >>> with builder.from_index_file("index.bin") as indexer:
    ...     ids, scores = indexer.query(embeddings[:2], top_k=5)
    """

    @abc.abstractmethod
    def from_embeddings(
        self,
        embeddings: torch.Tensor,
    ) -> ty.ContextManager["BaseIndexer"]:
        """
        Build a fresh vector index from raw embeddings.

        Parameters
        ----------
        embeddings : torch.Tensor
            A 2D tensor (n, d) holding the vectors to be indexed.

        Returns
        -------
        ContextManager[BaseIndexer]
            A context manager yielding a ready-to-query ``BaseIndexer``.
        """

    @abc.abstractmethod
    def from_index_file(
        self,
        index_file: FilePath,
    ) -> ty.ContextManager["BaseIndexer"]:
        """
        Load a previously serialized vector index from disk.

        Parameters
        ----------
        index_file : FilePath
            Path to a serialized index file to be loaded.

        Returns
        -------
        ContextManager[BaseIndexer]
            A context manager yielding a ready-to-query ``BaseIndexer``.
        """
66
+
67
+
68
class BaseIndexer(abc.ABC):
    """Abstract base class for vector indexers in the retrieval stage."""

    @abc.abstractmethod
    def query(
        self,
        embeddings: torch.Tensor,
        top_k: int,
    ) -> tuple[torch.Tensor,
               torch.Tensor]:
        """
        Search the index for the nearest neighbors of each query vector.

        Parameters
        ----------
        embeddings : torch.Tensor
            A 2D tensor (n, d) of query vectors.
        top_k : int
            How many nearest items to return per query vector.

        Returns
        -------
        torch.Tensor
            Shape (n, top_k); IDs of the retrieved nearest neighbors for each
            query, ordered by descending relevance.
        torch.Tensor
            Shape (n, top_k); relevance distances of those neighbors.
        """

    @abc.abstractmethod
    def save(self, file_path: FilePath) -> None:
        """
        Persist the index to local disk.

        Parameters
        ----------
        file_path : FilePath
            Destination path where the index will be saved.
        """
@@ -0,0 +1,154 @@
1
+ """FAISS-based vector index implementation for the retrieval stage."""
2
+
3
+ import contextlib
4
+ import typing as ty
5
+
6
+ import faiss
7
+ import torch
8
+
9
+ from torch_rechub.types import FilePath
10
+
11
+ from .base import BaseBuilder, BaseIndexer
12
+
13
# Supported indexing methods (faiss index-factory names).
_FaissIndexType = ty.Literal["Flat", "HNSW", "IVF"]

# Supported indexing metrics ("IP" = inner product, "L2" = Euclidean).
_FaissMetric = ty.Literal["IP", "L2"]

# Default indexing method (exact, brute-force search).
_DEFAULT_FAISS_INDEX_TYPE: _FaissIndexType = "Flat"

# Default indexing metric.
_DEFAULT_FAISS_METRIC: _FaissMetric = "L2"

# Default number of clusters to build an IVF index.
_DEFAULT_N_LISTS = 100

# Default max number of neighbors to build an HNSW index.
_DEFAULT_M = 32
30
+
31
+
32
class FaissBuilder(BaseBuilder):
    """Implement ``BaseBuilder`` for FAISS vector index construction."""

    def __init__(
        self,
        index_type: _FaissIndexType = _DEFAULT_FAISS_INDEX_TYPE,
        metric: _FaissMetric = _DEFAULT_FAISS_METRIC,
        *,
        m: int = _DEFAULT_M,
        nlists: int = _DEFAULT_N_LISTS,
        efSearch: ty.Optional[int] = None,
        nprobe: ty.Optional[int] = None,
    ) -> None:
        """
        Initialize a FAISS builder.

        Parameters
        ----------
        index_type : ``"Flat"``, ``"HNSW"``, or ``"IVF"``, optional
            The indexing method. Default to ``"Flat"``.
        metric : ``"IP"``, ``"L2"``, optional
            The indexing metric. Default to ``"L2"``.
        m : int, optional
            Max number of neighbors to build an HNSW index.
        nlists : int, optional
            Number of clusters to build an IVF index.
        efSearch : int or None, optional
            Number of candidate nodes during an HNSW search.
        nprobe : int or None, optional
            Number of clusters during an IVF search.
        """
        self._index_type_dsl = _build_index_type_dsl(index_type, m=m, nlists=nlists)
        self._metric = _resolve_metric_type(metric)

        # Search-time knobs, applied only when the built index supports them.
        self._efSearch = efSearch
        self._nprobe = nprobe

    @contextlib.contextmanager
    def from_embeddings(
        self,
        embeddings: torch.Tensor,
    ) -> ty.Generator["FaissIndexer",
                      None,
                      None]:
        """Adhere to ``BaseBuilder.from_embeddings``."""
        # faiss expects a contiguous float32 CPU numpy array; a CUDA or
        # float64 tensor passed directly would be rejected, so convert once
        # here instead of relying on the caller.
        vectors = embeddings.detach().cpu().to(torch.float32).numpy()

        index: faiss.Index = faiss.index_factory(
            int(vectors.shape[1]),
            self._index_type_dsl,
            self._metric,
        )

        if isinstance(index, faiss.IndexHNSW) and self._efSearch is not None:
            index.hnsw.efSearch = self._efSearch

        if isinstance(index, faiss.IndexIVF) and self._nprobe is not None:
            index.nprobe = self._nprobe

        # IVF indexes require training before vectors can be added; for
        # Flat/HNSW, train() is a no-op.
        index.train(vectors)
        index.add(vectors)

        yield FaissIndexer(index)

    @contextlib.contextmanager
    def from_index_file(
        self,
        index_file: FilePath,
    ) -> ty.Generator["FaissIndexer",
                      None,
                      None]:
        """Adhere to ``BaseBuilder.from_index_file``."""
        index = faiss.read_index(str(index_file))

        yield FaissIndexer(index)
111
+
112
+
113
class FaissIndexer(BaseIndexer):
    """FAISS-based implementation of ``BaseIndexer``."""

    def __init__(self, index: faiss.Index) -> None:
        """Initialize a FAISS indexer wrapping an already-built index."""
        self._index = index

    def query(
        self,
        embeddings: torch.Tensor,
        top_k: int,
    ) -> tuple[torch.Tensor,
               torch.Tensor]:
        """Adhere to ``BaseIndexer.query``."""
        queries = embeddings.cpu().numpy()
        # faiss returns (distances, ids); the public contract is (ids, distances).
        distances, neighbor_ids = self._index.search(queries, top_k)
        return torch.from_numpy(neighbor_ids), torch.from_numpy(distances)

    def save(self, file_path: FilePath) -> None:
        """Adhere to ``BaseIndexer.save``."""
        faiss.write_index(self._index, str(file_path))
133
+
134
+
135
+ # helper functions
136
+
137
+
138
+ def _build_index_type_dsl(index_type: _FaissIndexType, *, m: int, nlists: int) -> str:
139
+ """Build the index_type DSL passed to ``faiss.index_factory``."""
140
+ if index_type == "HNSW":
141
+ return f"{index_type}{m},Flat"
142
+
143
+ if index_type == "IVF":
144
+ return f"{index_type}{nlists},Flat"
145
+
146
+ return "Flat"
147
+
148
+
149
def _resolve_metric_type(metric: _FaissMetric) -> int:
    """Map a metric literal onto the corresponding faiss metric constant."""
    metric_constant = faiss.METRIC_L2 if metric == "L2" else faiss.METRIC_INNER_PRODUCT
    return ty.cast(int, metric_constant)
@@ -0,0 +1,215 @@
1
+ """Milvus-based vector index implementation for the retrieval stage."""
2
+
3
+ import contextlib
4
+ import typing as ty
5
+ import uuid
6
+
7
+ import numpy as np
8
+ import pymilvus as milvus
9
+ import torch
10
+
11
+ from torch_rechub.types import FilePath
12
+
13
+ from .base import BaseBuilder, BaseIndexer
14
+
15
# Supported indexing methods (Milvus index-type names).
_MilvusIndexType = ty.Literal["FLAT", "HNSW", "IVF_FLAT"]

# Supported indexing metrics ("IP" = inner product, "L2" = Euclidean).
_MilvusMetric = ty.Literal["COSINE", "IP", "L2"]

# Default indexing method (exact, brute-force search).
_DEFAULT_MILVUS_INDEX_TYPE: _MilvusIndexType = "FLAT"

# Default indexing metric.
_DEFAULT_MILVUS_METRIC: _MilvusMetric = "COSINE"

# Default number of clusters to build an IVF index.
_DEFAULT_N_LIST = 128

# Default max number of neighbors to build an HNSW index.
_DEFAULT_M = 30

# Default name of Milvus database connection.
_DEFAULT_NAME = "rechub"

# Default host of Milvus instance.
_DEFAULT_HOST = "localhost"

# Default port of Milvus instance.
_DEFAULT_PORT = 19530

# Name of the embedding column in the Milvus database table.
_EMBEDDING_COLUMN = "embedding"
44
+
45
+
46
class MilvusBuilder(BaseBuilder):
    """Implement ``BaseBuilder`` for Milvus vector index construction."""

    def __init__(
        self,
        d: int,
        index_type: _MilvusIndexType = _DEFAULT_MILVUS_INDEX_TYPE,
        metric: _MilvusMetric = _DEFAULT_MILVUS_METRIC,
        *,
        m: int = _DEFAULT_M,
        nlist: int = _DEFAULT_N_LIST,
        ef: ty.Optional[int] = None,
        nprobe: ty.Optional[int] = None,
        name: str = _DEFAULT_NAME,
        host: str = _DEFAULT_HOST,
        port: int = _DEFAULT_PORT,
    ) -> None:
        """
        Initialize a Milvus builder.

        Parameters
        ----------
        d : int
            The dimension of embeddings.
        index_type : ``"FLAT"``, ``"HNSW"``, or ``"IVF_FLAT"``, optional
            The indexing method. Default to ``"FLAT"``.
        metric : ``"COSINE"``, ``"IP"``, or ``"L2"``, optional
            The indexing metric. Default to ``"COSINE"``.
        m : int, optional
            Max number of neighbors to build an HNSW index.
        nlist : int, optional
            Number of clusters to build an IVF index.
        ef : int or None, optional
            Number of candidate nodes during an HNSW search.
        nprobe : int or None, optional
            Number of clusters during an IVF search.
        name : str, optional
            The name of connection. Each name corresponds to one connection.
        host : str, optional
            The host of Milvus instance. Default at "localhost".
        port : int, optional
            The port of Milvus instance. Default at 19530.
        """
        self._d = d

        # connection parameters
        self._name = name
        self._host = host
        self._port = port

        # Build-time (index) and search-time parameters, keyed per index type.
        bparams: dict[str, ty.Any] = {}
        qparams: dict[str, ty.Any] = {}

        if index_type == "HNSW":
            bparams.update(M=m)
            if ef is not None:
                qparams.update(ef=ef)

        if index_type == "IVF_FLAT":
            bparams.update(nlist=nlist)
            if nprobe is not None:
                qparams.update(nprobe=nprobe)

        self._build_params = dict(
            index_type=index_type,
            metric_type=metric,
            params=bparams,
        )

        self._query_params = dict(
            metric_type=metric,
            params=qparams,
        )

    @contextlib.contextmanager
    def from_embeddings(
        self,
        embeddings: torch.Tensor,
    ) -> ty.Generator["MilvusIndexer",
                      None,
                      None]:
        """Adhere to ``BaseBuilder.from_embeddings``."""
        milvus.connections.connect(self._name, host=self._host, port=self._port)
        try:
            collection = self._build_collection(embeddings)
            try:
                yield MilvusIndexer(collection, self._query_params)
            finally:
                # The collection is ephemeral: drop it once the context exits.
                collection.drop()
        finally:
            # Always release the connection, even when the collection build
            # itself raises; otherwise the named connection would leak.
            milvus.connections.disconnect(self._name)

    @contextlib.contextmanager
    def from_index_file(
        self,
        index_file: FilePath,
    ) -> ty.Generator["MilvusIndexer",
                      None,
                      None]:
        """Adhere to ``BaseBuilder.from_index_file``."""
        raise NotImplementedError("Milvus does not support index files!")

    def _build_collection(self, embeddings: torch.Tensor) -> milvus.Collection:
        """Build, populate, index, and load a Milvus collection on the current connection."""
        fields = [
            milvus.FieldSchema(
                name="id",
                dtype=milvus.DataType.INT64,
                is_primary=True,
            ),
            milvus.FieldSchema(
                name=_EMBEDDING_COLUMN,
                dtype=milvus.DataType.FLOAT_VECTOR,
                dim=self._d,
            ),
        ]

        # Random suffix keeps concurrent builders from colliding on the name.
        collection = milvus.Collection(
            name=f"{self._name}_{uuid.uuid4().hex}",
            schema=milvus.CollectionSchema(fields=fields),
            using=self._name,
        )

        n, _ = embeddings.shape
        # FLOAT_VECTOR stores 4-byte floats; convert once so float64 or CUDA
        # tensors are accepted as well.
        vectors = embeddings.detach().cpu().to(torch.float32).numpy()
        collection.insert([np.arange(n, dtype=np.int64), vectors])
        collection.create_index(_EMBEDDING_COLUMN, index_params=self._build_params)
        collection.load()

        return collection
174
+
175
+
176
class MilvusIndexer(BaseIndexer):
    """Milvus-based implementation of ``BaseIndexer``."""

    def __init__(
        self,
        collection: milvus.Collection,
        query_params: dict[str,
                           ty.Any],
    ) -> None:
        """Initialize a Milvus indexer over a loaded collection."""
        self._collection = collection
        self._query_params = query_params

    def query(
        self,
        embeddings: torch.Tensor,
        top_k: int,
    ) -> tuple[torch.Tensor,
               torch.Tensor]:
        """Adhere to ``BaseIndexer.query``."""
        results = self._collection.search(
            data=embeddings.cpu().numpy(),
            anns_field=_EMBEDDING_COLUMN,
            param=self._query_params,
            limit=top_k,
        )

        n, _ = embeddings.shape
        # Milvus may return fewer than ``top_k`` hits per query (small
        # collections, filtered searches). Pre-fill ids with -1 so missing
        # slots are explicit, instead of crashing on a shape-mismatched
        # row assignment.
        nn_ids = np.full((n, top_k), -1, dtype=np.int64)
        nn_distances = np.zeros((n, top_k), dtype=np.float32)

        for i, result in enumerate(results):
            hits = len(result.ids)
            nn_ids[i, :hits] = result.ids
            nn_distances[i, :hits] = result.distances

        return torch.from_numpy(nn_ids), torch.from_numpy(nn_distances)

    def save(self, file_path: FilePath) -> None:
        """Adhere to ``BaseIndexer.save``."""
        raise NotImplementedError("Milvus does not support index files!")
@@ -43,6 +43,7 @@ class CTRTrainer(object):
43
43
  gpus=None,
44
44
  loss_mode=True,
45
45
  model_path="./",
46
+ model_logger=None,
46
47
  ):
47
48
  self.model = model # for uniform weights save method in one gpu or multi gpu
48
49
  if gpus is None:
@@ -70,10 +71,13 @@ class CTRTrainer(object):
70
71
  self.model_path = model_path
71
72
  # Initialize regularization loss
72
73
  self.reg_loss_fn = RegularizationLoss(**regularization_params)
74
+ self.model_logger = model_logger
73
75
 
74
76
  def train_one_epoch(self, data_loader, log_interval=10):
75
77
  self.model.train()
76
78
  total_loss = 0
79
+ epoch_loss = 0
80
+ batch_count = 0
77
81
  tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
78
82
  for i, (x_dict, y) in enumerate(tk0):
79
83
  x_dict = {k: v.to(self.device) for k, v in x_dict.items()} # tensor to GPU
@@ -93,27 +97,62 @@ class CTRTrainer(object):
93
97
  loss.backward()
94
98
  self.optimizer.step()
95
99
  total_loss += loss.item()
100
+ epoch_loss += loss.item()
101
+ batch_count += 1
96
102
  if (i + 1) % log_interval == 0:
97
103
  tk0.set_postfix(loss=total_loss / log_interval)
98
104
  total_loss = 0
99
105
 
106
+ # Return average epoch loss
107
+ return epoch_loss / batch_count if batch_count > 0 else 0
108
+
100
109
    def fit(self, train_dataloader, val_dataloader=None):
        """Run the full training loop, with optional per-epoch validation.

        Trains for ``self.n_epoch`` epochs. When ``val_dataloader`` is given,
        AUC is evaluated each epoch and drives early stopping; the best
        weights are restored before exiting. The final state dict is saved to
        ``<model_path>/model.pth``. Loggers from ``_iter_loggers()`` receive
        hyperparameters once, then per-epoch metrics, and are closed at the end.
        """
        # Record run-level hyperparameters once, before training starts.
        for logger in self._iter_loggers():
            logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_mode': self.loss_mode})

        for epoch_i in range(self.n_epoch):
            print('epoch:', epoch_i)
            train_loss = self.train_one_epoch(train_dataloader)

            # Per-epoch training metrics: average loss and current lr.
            for logger in self._iter_loggers():
                logger.log_metrics({'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}, step=epoch_i)

            if self.scheduler is not None:
                # Print the lr only every ``step_size`` epochs; step every epoch.
                if epoch_i % self.scheduler.step_size == 0:
                    print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
                self.scheduler.step()  # update lr in epoch level by scheduler

            if val_dataloader:
                auc = self.evaluate(self.model, val_dataloader)
                print('epoch:', epoch_i, 'validation: auc:', auc)

                for logger in self._iter_loggers():
                    logger.log_metrics({'val/auc': auc}, step=epoch_i)

                # Early stopper tracks the best AUC; on stop, restore the
                # best-seen weights before leaving the loop.
                if self.early_stopper.stop_training(auc, self.model.state_dict()):
                    print(f'validation: best auc: {self.early_stopper.best_auc}')
                    self.model.load_state_dict(self.early_stopper.best_weights)
                    break

        torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best auc model

        # Close loggers so buffered metrics are flushed.
        for logger in self._iter_loggers():
            logger.finish()
141
+
142
+ def _iter_loggers(self):
143
+ """Return logger instances as a list.
144
+
145
+ Returns
146
+ -------
147
+ list
148
+ Active logger instances. Empty when ``model_logger`` is ``None``.
149
+ """
150
+ if self.model_logger is None:
151
+ return []
152
+ if isinstance(self.model_logger, (list, tuple)):
153
+ return list(self.model_logger)
154
+ return [self.model_logger]
155
+
117
156
  def evaluate(self, model, data_loader):
118
157
  model.eval()
119
158
  targets, predicts = list(), list()
@@ -146,7 +185,7 @@ class CTRTrainer(object):
146
185
  predicts.extend(y_pred.tolist())
147
186
  return predicts
148
187
 
149
- def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
188
+ def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
150
189
  """Export the trained model to ONNX format.
151
190
 
152
191
  This method exports the ranking model (e.g., DeepFM, WideDeep, DCN) to ONNX format
@@ -163,6 +202,7 @@ class CTRTrainer(object):
163
202
  device (str, optional): Device for export ('cpu', 'cuda', etc.).
164
203
  If None, defaults to 'cpu' for maximum compatibility.
165
204
  verbose (bool): Print export details (default: False).
205
+ onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
166
206
 
167
207
  Returns:
168
208
  bool: True if export succeeded, False otherwise.
@@ -188,7 +228,16 @@ class CTRTrainer(object):
188
228
  export_device = device if device is not None else 'cpu'
189
229
 
190
230
  exporter = ONNXExporter(model, device=export_device)
191
- return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
231
+ return exporter.export(
232
+ output_path=output_path,
233
+ dummy_input=dummy_input,
234
+ batch_size=batch_size,
235
+ seq_length=seq_length,
236
+ opset_version=opset_version,
237
+ dynamic_batch=dynamic_batch,
238
+ verbose=verbose,
239
+ onnx_export_kwargs=onnx_export_kwargs,
240
+ )
192
241
 
193
242
  def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
194
243
  """Visualize the model's computation graph.