torch-rechub 0.0.6__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,154 @@
+ """FAISS-based vector index implementation for the retrieval stage."""
+
+ import contextlib
+ import typing as ty
+
+ import faiss
+ import torch
+
+ from torch_rechub.types import FilePath
+
+ from .base import BaseBuilder, BaseIndexer
+
+ # Type for indexing methods.
+ _FaissIndexType = ty.Literal["Flat", "HNSW", "IVF"]
+
+ # Type for indexing metrics.
+ _FaissMetric = ty.Literal["IP", "L2"]
+
+ # Default indexing method.
+ _DEFAULT_FAISS_INDEX_TYPE: _FaissIndexType = "Flat"
+
+ # Default indexing metric.
+ _DEFAULT_FAISS_METRIC: _FaissMetric = "L2"
+
+ # Default number of clusters to build an IVF index.
+ _DEFAULT_N_LISTS = 100
+
+ # Default max number of neighbors to build an HNSW index.
+ _DEFAULT_M = 32
+
+
+ class FaissBuilder(BaseBuilder):
+     """Implement ``BaseBuilder`` for FAISS vector index construction."""
+
+     def __init__(
+         self,
+         index_type: _FaissIndexType = _DEFAULT_FAISS_INDEX_TYPE,
+         metric: _FaissMetric = _DEFAULT_FAISS_METRIC,
+         *,
+         m: int = _DEFAULT_M,
+         nlists: int = _DEFAULT_N_LISTS,
+         efSearch: ty.Optional[int] = None,
+         nprobe: ty.Optional[int] = None,
+     ) -> None:
+         """
+         Initialize a FAISS builder.
+
+         Parameters
+         ----------
+         index_type : ``"Flat"``, ``"HNSW"``, or ``"IVF"``, optional
+             The index type. Defaults to ``"Flat"``.
+         metric : ``"IP"`` or ``"L2"``, optional
+             The indexing metric. Defaults to ``"L2"``.
+         m : int, optional
+             Max number of neighbors when building an HNSW index.
+         nlists : int, optional
+             Number of clusters when building an IVF index.
+         efSearch : int or None, optional
+             Number of candidate nodes visited during an HNSW search.
+         nprobe : int or None, optional
+             Number of clusters probed during an IVF search.
+         """
+         self._index_type_dsl = _build_index_type_dsl(index_type, m=m, nlists=nlists)
+         self._metric = _resolve_metric_type(metric)
+
+         self._efSearch = efSearch
+         self._nprobe = nprobe
+
+     @contextlib.contextmanager
+     def from_embeddings(
+         self,
+         embeddings: torch.Tensor,
+     ) -> ty.Generator["FaissIndexer", None, None]:
+         """Adhere to ``BaseBuilder.from_embeddings``."""
+         index: faiss.Index = faiss.index_factory(
+             embeddings.shape[1],
+             self._index_type_dsl,
+             self._metric,
+         )
+
+         if isinstance(index, faiss.IndexHNSW) and self._efSearch is not None:
+             index.hnsw.efSearch = self._efSearch
+
+         if isinstance(index, faiss.IndexIVF) and self._nprobe is not None:
+             index.nprobe = self._nprobe
+
+         # FAISS expects NumPy arrays, not torch tensors (cf. ``FaissIndexer.query``).
+         data = embeddings.detach().cpu().numpy()
+         index.train(data)
+         index.add(data)
+
+         try:
+             yield FaissIndexer(index)
+         finally:
+             pass
+
+     @contextlib.contextmanager
+     def from_index_file(
+         self,
+         index_file: FilePath,
+     ) -> ty.Generator["FaissIndexer", None, None]:
+         """Adhere to ``BaseBuilder.from_index_file``."""
+         index = faiss.read_index(str(index_file))
+
+         try:
+             yield FaissIndexer(index)
+         finally:
+             pass
+
+
+ class FaissIndexer(BaseIndexer):
+     """FAISS-based implementation of ``BaseIndexer``."""
+
+     def __init__(self, index: faiss.Index) -> None:
+         """Initialize a FAISS indexer."""
+         self._index = index
+
+     def query(
+         self,
+         embeddings: torch.Tensor,
+         top_k: int,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Adhere to ``BaseIndexer.query``."""
+         dists, ids = self._index.search(embeddings.cpu().numpy(), top_k)
+         return torch.from_numpy(ids), torch.from_numpy(dists)
+
+     def save(self, file_path: FilePath) -> None:
+         """Adhere to ``BaseIndexer.save``."""
+         faiss.write_index(self._index, str(file_path))
+
+
+ # Helper functions.
+
+
+ def _build_index_type_dsl(index_type: _FaissIndexType, *, m: int, nlists: int) -> str:
+     """Build the index-type DSL string passed to ``faiss.index_factory``."""
+     if index_type == "HNSW":
+         return f"{index_type}{m},Flat"
+
+     if index_type == "IVF":
+         return f"{index_type}{nlists},Flat"
+
+     return "Flat"
+
+
+ def _resolve_metric_type(metric: _FaissMetric) -> int:
+     """Resolve the metric type from a string literal to an integer."""
+     if metric == "L2":
+         return ty.cast(int, faiss.METRIC_L2)
+
+     return ty.cast(int, faiss.METRIC_INNER_PRODUCT)
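
For orientation, a minimal usage sketch of the two new classes (editorial example, not part of the diff; the import path, embedding sizes, and hyperparameters are illustrative):

    import torch
    # Hypothetical import path; the diff does not show where this module lives:
    # from torch_rechub.<retrieval_module> import FaissBuilder

    item_embeddings = torch.randn(10_000, 64)  # 10k items, 64-dim vectors
    user_embeddings = torch.randn(8, 64)       # a small query batch

    builder = FaissBuilder("HNSW", "IP", m=32, efSearch=64)
    with builder.from_embeddings(item_embeddings) as indexer:
        ids, dists = indexer.query(user_embeddings, top_k=10)  # both (8, 10) tensors
        indexer.save("items.faiss")  # reload later via builder.from_index_file(...)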
@@ -0,0 +1,215 @@
+ """Milvus-based vector index implementation for the retrieval stage."""
+
+ import contextlib
+ import typing as ty
+ import uuid
+
+ import numpy as np
+ import pymilvus as milvus
+ import torch
+
+ from torch_rechub.types import FilePath
+
+ from .base import BaseBuilder, BaseIndexer
+
+ # Type for indexing methods.
+ _MilvusIndexType = ty.Literal["FLAT", "HNSW", "IVF_FLAT"]
+
+ # Type for indexing metrics.
+ _MilvusMetric = ty.Literal["COSINE", "IP", "L2"]
+
+ # Default indexing method.
+ _DEFAULT_MILVUS_INDEX_TYPE: _MilvusIndexType = "FLAT"
+
+ # Default indexing metric.
+ _DEFAULT_MILVUS_METRIC: _MilvusMetric = "COSINE"
+
+ # Default number of clusters to build an IVF index.
+ _DEFAULT_N_LIST = 128
+
+ # Default max number of neighbors to build an HNSW index.
+ _DEFAULT_M = 30
+
+ # Default name of the Milvus database connection.
+ _DEFAULT_NAME = "rechub"
+
+ # Default host of the Milvus instance.
+ _DEFAULT_HOST = "localhost"
+
+ # Default port of the Milvus instance.
+ _DEFAULT_PORT = 19530
+
+ # Name of the embedding column in the Milvus database table.
+ _EMBEDDING_COLUMN = "embedding"
+
+
+ class MilvusBuilder(BaseBuilder):
+     """Implement ``BaseBuilder`` for Milvus vector index construction."""
+
+     def __init__(
+         self,
+         d: int,
+         index_type: _MilvusIndexType = _DEFAULT_MILVUS_INDEX_TYPE,
+         metric: _MilvusMetric = _DEFAULT_MILVUS_METRIC,
+         *,
+         m: int = _DEFAULT_M,
+         nlist: int = _DEFAULT_N_LIST,
+         ef: ty.Optional[int] = None,
+         nprobe: ty.Optional[int] = None,
+         name: str = _DEFAULT_NAME,
+         host: str = _DEFAULT_HOST,
+         port: int = _DEFAULT_PORT,
+     ) -> None:
+         """
+         Initialize a Milvus builder.
+
+         Parameters
+         ----------
+         d : int
+             The dimension of the embeddings.
+         index_type : ``"FLAT"``, ``"HNSW"``, or ``"IVF_FLAT"``, optional
+             The index type. Defaults to ``"FLAT"``.
+         metric : ``"COSINE"``, ``"IP"``, or ``"L2"``, optional
+             The indexing metric. Defaults to ``"COSINE"``.
+         m : int, optional
+             Max number of neighbors when building an HNSW index.
+         nlist : int, optional
+             Number of clusters when building an IVF index.
+         ef : int or None, optional
+             Number of candidate nodes visited during an HNSW search.
+         nprobe : int or None, optional
+             Number of clusters probed during an IVF search.
+         name : str, optional
+             The name of the connection. Each name corresponds to one connection.
+         host : str, optional
+             The host of the Milvus instance. Defaults to ``"localhost"``.
+         port : int, optional
+             The port of the Milvus instance. Defaults to ``19530``.
+         """
+         self._d = d
+
+         # Connection parameters.
+         self._name = name
+         self._host = host
+         self._port = port
+
+         bparams: dict[str, ty.Any] = {}
+         qparams: dict[str, ty.Any] = {}
+
+         if index_type == "HNSW":
+             bparams.update(M=m)
+             if ef is not None:
+                 qparams.update(ef=ef)
+
+         if index_type == "IVF_FLAT":
+             bparams.update(nlist=nlist)
+             if nprobe is not None:
+                 qparams.update(nprobe=nprobe)
+
+         self._build_params = dict(
+             index_type=index_type,
+             metric_type=metric,
+             params=bparams,
+         )
+
+         self._query_params = dict(
+             metric_type=metric,
+             params=qparams,
+         )
+
+     @contextlib.contextmanager
+     def from_embeddings(
+         self,
+         embeddings: torch.Tensor,
+     ) -> ty.Generator["MilvusIndexer", None, None]:
+         """Adhere to ``BaseBuilder.from_embeddings``."""
+         milvus.connections.connect(self._name, host=self._host, port=self._port)
+         collection = self._build_collection(embeddings)
+
+         try:
+             yield MilvusIndexer(collection, self._query_params)
+         finally:
+             collection.drop()
+             milvus.connections.disconnect(self._name)
+
+     @contextlib.contextmanager
+     def from_index_file(
+         self,
+         index_file: FilePath,
+     ) -> ty.Generator["MilvusIndexer", None, None]:
+         """Adhere to ``BaseBuilder.from_index_file``."""
+         raise NotImplementedError("Milvus does not support index files!")
+
+     def _build_collection(self, embeddings: torch.Tensor) -> milvus.Collection:
+         """Build a Milvus collection on the current connection."""
+         fields = [
+             milvus.FieldSchema(
+                 name="id",
+                 dtype=milvus.DataType.INT64,
+                 is_primary=True,
+             ),
+             milvus.FieldSchema(
+                 name=_EMBEDDING_COLUMN,
+                 dtype=milvus.DataType.FLOAT_VECTOR,
+                 dim=self._d,
+             ),
+         ]
+
+         collection = milvus.Collection(
+             name=f"{self._name}_{uuid.uuid4().hex}",
+             schema=milvus.CollectionSchema(fields=fields),
+             using=self._name,
+         )
+
+         n, _ = embeddings.shape
+         collection.insert([np.arange(n, dtype=np.int64), embeddings.cpu().numpy()])
+         collection.create_index(_EMBEDDING_COLUMN, index_params=self._build_params)
+         collection.load()
+
+         return collection
+
+
+ class MilvusIndexer(BaseIndexer):
+     """Milvus-based implementation of ``BaseIndexer``."""
+
+     def __init__(
+         self,
+         collection: milvus.Collection,
+         query_params: dict[str, ty.Any],
+     ) -> None:
+         """Initialize a Milvus indexer."""
+         self._collection = collection
+         self._query_params = query_params
+
+     def query(
+         self,
+         embeddings: torch.Tensor,
+         top_k: int,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Adhere to ``BaseIndexer.query``."""
+         results = self._collection.search(
+             data=embeddings.cpu().numpy(),
+             anns_field=_EMBEDDING_COLUMN,
+             param=self._query_params,
+             limit=top_k,
+         )
+
+         n, _ = embeddings.shape
+         nn_ids = np.zeros((n, top_k), dtype=np.int64)
+         nn_distances = np.zeros((n, top_k), dtype=np.float32)
+
+         for i, result in enumerate(results):
+             nn_ids[i] = result.ids
+             nn_distances[i] = result.distances
+
+         return torch.from_numpy(nn_ids), torch.from_numpy(nn_distances)
+
+     def save(self, file_path: FilePath) -> None:
+         """Adhere to ``BaseIndexer.save``."""
+         raise NotImplementedError("Milvus does not support index files!")
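
A matching sketch for the Milvus path (editorial example, not part of the diff; it assumes a Milvus server reachable at localhost:19530). Unlike the FAISS builder, the collection only lives for the duration of the context, and ``save``/``from_index_file`` deliberately raise:

    import torch

    item_embeddings = torch.randn(10_000, 64)

    builder = MilvusBuilder(d=64, index_type="IVF_FLAT", metric="IP", nlist=128, nprobe=16)
    with builder.from_embeddings(item_embeddings) as indexer:
        ids, dists = indexer.query(torch.randn(8, 64), top_k=10)
    # On exit the temporary collection is dropped and the connection is closed.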
@@ -185,7 +185,7 @@ class CTRTrainer(object):
                  predicts.extend(y_pred.tolist())
          return predicts
 
-     def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
+     def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
          """Export the trained model to ONNX format.
 
          This method exports the ranking model (e.g., DeepFM, WideDeep, DCN) to ONNX format
@@ -202,6 +202,7 @@ class CTRTrainer(object):
              device (str, optional): Device for export ('cpu', 'cuda', etc.).
                  If None, defaults to 'cpu' for maximum compatibility.
              verbose (bool): Print export details (default: False).
+             onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
 
          Returns:
              bool: True if export succeeded, False otherwise.
@@ -227,7 +228,16 @@ class CTRTrainer(object):
          export_device = device if device is not None else 'cpu'
 
          exporter = ONNXExporter(model, device=export_device)
-         return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
+         return exporter.export(
+             output_path=output_path,
+             dummy_input=dummy_input,
+             batch_size=batch_size,
+             seq_length=seq_length,
+             opset_version=opset_version,
+             dynamic_batch=dynamic_batch,
+             verbose=verbose,
+             onnx_export_kwargs=onnx_export_kwargs,
+         )
 
      def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
          """Visualize the model's computation graph.
@@ -215,7 +215,7 @@ class MatchTrainer(object):
                  predicts.append(y_pred.data)
          return torch.cat(predicts, dim=0)
 
-     def export_onnx(self, output_path, mode=None, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
+     def export_onnx(self, output_path, mode=None, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
          """Export the trained matching model to ONNX format.
 
          This method exports matching/retrieval models (e.g., DSSM, YoutubeDNN, MIND)
@@ -237,6 +237,7 @@ class MatchTrainer(object):
              device (str, optional): Device for export ('cpu', 'cuda', etc.).
                  If None, defaults to 'cpu' for maximum compatibility.
              verbose (bool): Print export details (default: False).
+             onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
 
          Returns:
              bool: True if export succeeded, False otherwise.
@@ -270,7 +271,17 @@ class MatchTrainer(object):
 
          try:
              exporter = ONNXExporter(model, device=export_device)
-             return exporter.export(output_path=output_path, mode=mode, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
+             return exporter.export(
+                 output_path=output_path,
+                 mode=mode,
+                 dummy_input=dummy_input,
+                 batch_size=batch_size,
+                 seq_length=seq_length,
+                 opset_version=opset_version,
+                 dynamic_batch=dynamic_batch,
+                 verbose=verbose,
+                 onnx_export_kwargs=onnx_export_kwargs,
+             )
          finally:
              # Restore original mode
              if hasattr(model, 'mode'):
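
``MatchTrainer`` gains the same passthrough; the only extra knob is ``mode``, which selects which tower of a two-tower model to export (the ``"user"``/``"item"`` values below are the usual torch-rechub convention and are assumed here; ``export_onnx`` restores the model's original mode in its ``finally`` block):

    # Given a fitted MatchTrainer instance `trainer`:
    trainer.export_onnx("user_tower.onnx", mode="user")
    trainer.export_onnx("item_tower.onnx", mode="item",
                        onnx_export_kwargs={"export_params": True})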
@@ -261,7 +261,7 @@ class MTLTrainer(object):
                  predicts.extend(y_preds.tolist())
          return predicts
 
-     def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
+     def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
          """Export the trained multi-task model to ONNX format.
 
          This method exports multi-task learning models (e.g., MMOE, PLE, ESMM, SharedBottom)
@@ -283,6 +283,7 @@ class MTLTrainer(object):
              device (str, optional): Device for export ('cpu', 'cuda', etc.).
                  If None, defaults to 'cpu' for maximum compatibility.
              verbose (bool): Print export details (default: False).
+             onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
 
          Returns:
              bool: True if export succeeded, False otherwise.
@@ -304,7 +305,16 @@ class MTLTrainer(object):
          export_device = device if device is not None else 'cpu'
 
          exporter = ONNXExporter(model, device=export_device)
-         return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
+         return exporter.export(
+             output_path=output_path,
+             dummy_input=dummy_input,
+             batch_size=batch_size,
+             seq_length=seq_length,
+             opset_version=opset_version,
+             dynamic_batch=dynamic_batch,
+             verbose=verbose,
+             onnx_export_kwargs=onnx_export_kwargs,
+         )
 
      def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
          """Visualize the model's computation graph.
@@ -255,7 +255,7 @@ class SeqTrainer(object):
 
          return avg_loss, accuracy
 
-     def export_onnx(self, output_path, batch_size=2, seq_length=50, vocab_size=None, opset_version=14, dynamic_batch=True, device=None, verbose=False):
+     def export_onnx(self, output_path, batch_size=2, seq_length=50, vocab_size=None, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
          """Export the trained sequence generation model to ONNX format.
 
          This method exports sequence generation models (e.g., HSTU) to ONNX format.
@@ -273,6 +273,7 @@ class SeqTrainer(object):
              device (str, optional): Device for export ('cpu', 'cuda', etc.).
                  If None, defaults to 'cpu' for maximum compatibility.
              verbose (bool): Print export details (default: False).
+             onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
 
          Returns:
              bool: True if export succeeded, False otherwise.
@@ -321,20 +322,38 @@ class SeqTrainer(object):
 
          try:
              with torch.no_grad():
-                 torch.onnx.export(
-                     model,
-                     (dummy_seq_tokens, dummy_seq_time_diffs),
-                     output_path,
-                     input_names=["seq_tokens", "seq_time_diffs"],
-                     output_names=["output"],
-                     dynamic_axes=dynamic_axes,
-                     opset_version=opset_version,
-                     do_constant_folding=True,
-                     verbose=verbose,
-                     dynamo=False  # Use legacy exporter for dynamic_axes support
-                 )
+                 import inspect
+
+                 export_kwargs = {
+                     "f": output_path,
+                     "input_names": ["seq_tokens", "seq_time_diffs"],
+                     "output_names": ["output"],
+                     "dynamic_axes": dynamic_axes,
+                     "opset_version": opset_version,
+                     "do_constant_folding": True,
+                     "verbose": verbose,
+                 }
+
+                 if onnx_export_kwargs:
+                     overlap = set(export_kwargs.keys()) & set(onnx_export_kwargs.keys())
+                     overlap.discard("dynamo")
+                     if overlap:
+                         raise ValueError("onnx_export_kwargs contains keys that overlap with explicit args: "
+                                          f"{sorted(overlap)}. Please set them via export_onnx() parameters instead.")
+                     export_kwargs.update(onnx_export_kwargs)
+
+                 # Auto-pick the exporter:
+                 # - dynamic_axes present => prefer the legacy exporter (dynamo=False) for dynamic batch/seq
+                 # - otherwise prefer the dynamo exporter (dynamo=True) on newer torch
+                 sig = inspect.signature(torch.onnx.export)
+                 if "dynamo" in sig.parameters:
+                     if "dynamo" not in export_kwargs:
+                         export_kwargs["dynamo"] = False if dynamic_axes is not None else True
+                 else:
+                     export_kwargs.pop("dynamo", None)
+
+                 torch.onnx.export(model, (dummy_seq_tokens, dummy_seq_time_diffs), **export_kwargs)
 
              if verbose:
                  print(f"Successfully exported ONNX model to: {output_path}")
torch_rechub/types.py ADDED
@@ -0,0 +1,5 @@
+ import os
+ import typing as ty
+
+ #: Type for path to a file.
+ FilePath = ty.Union[str, os.PathLike]
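
Since ``pathlib.Path`` implements ``os.PathLike``, the alias accepts both plain strings and ``Path`` objects, which is why the FAISS code above wraps its argument in ``str(...)``. A tiny sketch (``load_index`` is a hypothetical consumer):

    from pathlib import Path

    def load_index(path: FilePath) -> str:
        return str(path)  # normalize either variant to a string path

    load_index("items.faiss")        # str
    load_index(Path("items.faiss"))  # pathlib.Path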