torch-rechub 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/layers.py +213 -150
- torch_rechub/basic/loss_func.py +62 -47
- torch_rechub/basic/tracking.py +198 -0
- torch_rechub/data/__init__.py +0 -0
- torch_rechub/data/convert.py +67 -0
- torch_rechub/data/dataset.py +107 -0
- torch_rechub/models/generative/hstu.py +48 -33
- torch_rechub/serving/__init__.py +50 -0
- torch_rechub/serving/annoy.py +133 -0
- torch_rechub/serving/base.py +107 -0
- torch_rechub/serving/faiss.py +154 -0
- torch_rechub/serving/milvus.py +215 -0
- torch_rechub/trainers/ctr_trainer.py +52 -3
- torch_rechub/trainers/match_trainer.py +52 -3
- torch_rechub/trainers/mtl_trainer.py +61 -3
- torch_rechub/trainers/seq_trainer.py +93 -17
- torch_rechub/types.py +5 -0
- torch_rechub/utils/data.py +167 -137
- torch_rechub/utils/hstu_utils.py +87 -76
- torch_rechub/utils/model_utils.py +10 -12
- torch_rechub/utils/onnx_export.py +98 -45
- torch_rechub/utils/quantization.py +128 -0
- torch_rechub/utils/visualization.py +4 -12
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/METADATA +20 -5
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/RECORD +27 -17
- torch_rechub/trainers/matching.md +0 -3
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.1.0.dist-info}/licenses/LICENSE +0 -0
torch_rechub/serving/base.py (new file)
@@ -0,0 +1,107 @@
+"""Base abstraction for vector indexers used in the retrieval stage."""
+
+import abc
+import typing as ty
+
+import torch
+
+from torch_rechub.types import FilePath
+
+
+class BaseBuilder(abc.ABC):
+    """
+    Abstract base class for vector index construction.
+
+    A builder owns all build-time configuration and produces a ``BaseIndexer`` through a
+    context-managed build operation.
+
+    Examples
+    --------
+    >>> builder = BaseBuilder(...)
+    >>> embeddings = torch.randn(1000, 128)
+    >>> with builder.from_embeddings(embeddings) as indexer:
+    ...     ids, scores = indexer.query(embeddings[:2], top_k=5)
+    ...     indexer.save("index.bin")
+    >>> with builder.from_index_file("index.bin") as indexer:
+    ...     ids, scores = indexer.query(embeddings[:2], top_k=5)
+    """
+
+    @abc.abstractmethod
+    def from_embeddings(
+        self,
+        embeddings: torch.Tensor,
+    ) -> ty.ContextManager["BaseIndexer"]:
+        """
+        Build a vector index from the embeddings.
+
+        Parameters
+        ----------
+        embeddings : torch.Tensor
+            A 2D tensor (n, d) containing embedding vectors to build a new index.
+
+        Returns
+        -------
+        ContextManager[BaseIndexer]
+            A context manager that yields a fully initialized ``BaseIndexer``.
+        """
+
+    @abc.abstractmethod
+    def from_index_file(
+        self,
+        index_file: FilePath,
+    ) -> ty.ContextManager["BaseIndexer"]:
+        """
+        Build a vector index from the index file.
+
+        Parameters
+        ----------
+        index_file : FilePath
+            Path to a serialized index on disk to be loaded.
+
+        Returns
+        -------
+        ContextManager[BaseIndexer]
+            A context manager that yields a fully initialized ``BaseIndexer``.
+        """
+
+
+class BaseIndexer(abc.ABC):
+    """Abstract base class for vector indexers in the retrieval stage."""
+
+    @abc.abstractmethod
+    def query(
+        self,
+        embeddings: torch.Tensor,
+        top_k: int,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Query the vector index.
+
+        Parameters
+        ----------
+        embeddings : torch.Tensor
+            A 2D tensor (n, d) containing embedding vectors to query the index.
+        top_k : int
+            The number of nearest items to retrieve for each vector.
+
+        Returns
+        -------
+        torch.Tensor
+            A 2D tensor of shape (n, top_k), containing the retrieved nearest neighbor
+            IDs for each vector, ordered by descending relevance.
+        torch.Tensor
+            A 2D tensor of shape (n, top_k), containing the relevance distances of the
+            nearest neighbors for each vector.
+        """
+
+    @abc.abstractmethod
+    def save(self, file_path: FilePath) -> None:
+        """
+        Persist the index to local disk.
+
+        Parameters
+        ----------
+        file_path : FilePath
+            Destination path where the index will be saved.
+        """
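To make the contract concrete, here is a minimal brute-force implementation pair. It is purely illustrative (``ExactBuilder`` and ``ExactIndexer`` are hypothetical names, not part of the package) and assumes the module above is importable:

```python
import contextlib
import typing as ty

import torch

from torch_rechub.serving.base import BaseBuilder, BaseIndexer
from torch_rechub.types import FilePath


class ExactIndexer(BaseIndexer):
    """Hypothetical exact (brute-force) L2 indexer satisfying the ABC above."""

    def __init__(self, embeddings: torch.Tensor) -> None:
        self._embeddings = embeddings

    def query(self, embeddings: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Full pairwise L2 distances, then the top_k smallest per query row.
        dists = torch.cdist(embeddings, self._embeddings)
        nn_dists, nn_ids = torch.topk(dists, top_k, dim=1, largest=False)
        return nn_ids, nn_dists

    def save(self, file_path: FilePath) -> None:
        torch.save(self._embeddings, file_path)


class ExactBuilder(BaseBuilder):
    """Hypothetical builder producing ``ExactIndexer`` instances."""

    @contextlib.contextmanager
    def from_embeddings(self, embeddings: torch.Tensor) -> ty.Generator[ExactIndexer, None, None]:
        yield ExactIndexer(embeddings)

    @contextlib.contextmanager
    def from_index_file(self, index_file: FilePath) -> ty.Generator[ExactIndexer, None, None]:
        yield ExactIndexer(torch.load(index_file))
```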
torch_rechub/serving/faiss.py (new file)
@@ -0,0 +1,154 @@
+"""FAISS-based vector index implementation for the retrieval stage."""
+
+import contextlib
+import typing as ty
+
+import faiss
+import torch
+
+from torch_rechub.types import FilePath
+
+from .base import BaseBuilder, BaseIndexer
+
+# Type for indexing methods.
+_FaissIndexType = ty.Literal["Flat", "HNSW", "IVF"]
+
+# Type for indexing metrics.
+_FaissMetric = ty.Literal["IP", "L2"]
+
+# Default indexing method.
+_DEFAULT_FAISS_INDEX_TYPE: _FaissIndexType = "Flat"
+
+# Default indexing metric.
+_DEFAULT_FAISS_METRIC: _FaissMetric = "L2"
+
+# Default number of clusters to build an IVF index.
+_DEFAULT_N_LISTS = 100
+
+# Default max number of neighbors to build an HNSW index.
+_DEFAULT_M = 32
+
+
+class FaissBuilder(BaseBuilder):
+    """Implement ``BaseBuilder`` for FAISS vector index construction."""
+
+    def __init__(
+        self,
+        index_type: _FaissIndexType = _DEFAULT_FAISS_INDEX_TYPE,
+        metric: _FaissMetric = _DEFAULT_FAISS_METRIC,
+        *,
+        m: int = _DEFAULT_M,
+        nlists: int = _DEFAULT_N_LISTS,
+        efSearch: ty.Optional[int] = None,
+        nprobe: ty.Optional[int] = None,
+    ) -> None:
+        """
+        Initialize a FAISS builder.
+
+        Parameters
+        ----------
+        index_type : ``"Flat"``, ``"HNSW"``, or ``"IVF"``, optional
+            The indexing method. Defaults to ``"Flat"``.
+        metric : ``"IP"`` or ``"L2"``, optional
+            The indexing metric. Defaults to ``"L2"``.
+        m : int, optional
+            Max number of neighbors to build an HNSW index.
+        nlists : int, optional
+            Number of clusters to build an IVF index.
+        efSearch : int or None, optional
+            Number of candidate nodes during an HNSW search.
+        nprobe : int or None, optional
+            Number of clusters to probe during an IVF search.
+        """
+        self._index_type_dsl = _build_index_type_dsl(index_type, m=m, nlists=nlists)
+        self._metric = _resolve_metric_type(metric)
+
+        self._efSearch = efSearch
+        self._nprobe = nprobe
+
+    @contextlib.contextmanager
+    def from_embeddings(
+        self,
+        embeddings: torch.Tensor,
+    ) -> ty.Generator["FaissIndexer", None, None]:
+        """Adhere to ``BaseBuilder.from_embeddings``."""
+        index: faiss.Index = faiss.index_factory(
+            embeddings.shape[1],
+            self._index_type_dsl,
+            self._metric,
+        )
+
+        if isinstance(index, faiss.IndexHNSW) and self._efSearch is not None:
+            index.hnsw.efSearch = self._efSearch
+
+        if isinstance(index, faiss.IndexIVF) and self._nprobe is not None:
+            index.nprobe = self._nprobe
+
+        vectors = embeddings.cpu().numpy()  # FAISS expects numpy arrays; also handles GPU tensors
+        index.train(vectors)
+        index.add(vectors)
+
+        try:
+            yield FaissIndexer(index)
+        finally:
+            pass
+
+    @contextlib.contextmanager
+    def from_index_file(
+        self,
+        index_file: FilePath,
+    ) -> ty.Generator["FaissIndexer", None, None]:
+        """Adhere to ``BaseBuilder.from_index_file``."""
+        index = faiss.read_index(str(index_file))
+
+        try:
+            yield FaissIndexer(index)
+        finally:
+            pass
+
+
+class FaissIndexer(BaseIndexer):
+    """FAISS-based implementation of ``BaseIndexer``."""
+
+    def __init__(self, index: faiss.Index) -> None:
+        """Initialize a FAISS indexer."""
+        self._index = index
+
+    def query(
+        self,
+        embeddings: torch.Tensor,
+        top_k: int,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Adhere to ``BaseIndexer.query``."""
+        dists, ids = self._index.search(embeddings.cpu().numpy(), top_k)
+        return torch.from_numpy(ids), torch.from_numpy(dists)
+
+    def save(self, file_path: FilePath) -> None:
+        """Adhere to ``BaseIndexer.save``."""
+        faiss.write_index(self._index, str(file_path))
+
+
+# helper functions
+
+
+def _build_index_type_dsl(index_type: _FaissIndexType, *, m: int, nlists: int) -> str:
+    """Build the index_type DSL passed to ``faiss.index_factory``."""
+    if index_type == "HNSW":
+        return f"{index_type}{m},Flat"
+
+    if index_type == "IVF":
+        return f"{index_type}{nlists},Flat"
+
+    return "Flat"
+
+
+def _resolve_metric_type(metric: _FaissMetric) -> int:
+    """Resolve the metric type from a string literal to an integer."""
+    if metric == "L2":
+        return ty.cast(int, faiss.METRIC_L2)
+
+    return ty.cast(int, faiss.METRIC_INNER_PRODUCT)
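A usage sketch for the classes above, assuming ``faiss`` is installed; the file name and parameter values are illustrative. Note that with the ``"IP"`` metric the returned distances are inner-product scores, so higher means more similar:

```python
import torch

from torch_rechub.serving.faiss import FaissBuilder

item_embeddings = torch.randn(10_000, 64)
user_embeddings = item_embeddings[:3]

# IVF index with inner-product scoring; probe 8 of the 256 clusters per query.
builder = FaissBuilder("IVF", "IP", nlists=256, nprobe=8)

with builder.from_embeddings(item_embeddings) as indexer:
    ids, scores = indexer.query(user_embeddings, top_k=10)  # both shaped (3, 10)
    indexer.save("items.faiss")

# Reload the persisted index later without rebuilding it.
with builder.from_index_file("items.faiss") as indexer:
    ids, scores = indexer.query(user_embeddings, top_k=10)
```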
torch_rechub/serving/milvus.py (new file)
@@ -0,0 +1,215 @@
+"""Milvus-based vector index implementation for the retrieval stage."""
+
+import contextlib
+import typing as ty
+import uuid
+
+import numpy as np
+import pymilvus as milvus
+import torch
+
+from torch_rechub.types import FilePath
+
+from .base import BaseBuilder, BaseIndexer
+
+# Type for indexing methods.
+_MilvusIndexType = ty.Literal["FLAT", "HNSW", "IVF_FLAT"]
+
+# Type for indexing metrics.
+_MilvusMetric = ty.Literal["COSINE", "IP", "L2"]
+
+# Default indexing method.
+_DEFAULT_MILVUS_INDEX_TYPE: _MilvusIndexType = "FLAT"
+
+# Default indexing metric.
+_DEFAULT_MILVUS_METRIC: _MilvusMetric = "COSINE"
+
+# Default number of clusters to build an IVF index.
+_DEFAULT_N_LIST = 128
+
+# Default max number of neighbors to build an HNSW index.
+_DEFAULT_M = 30
+
+# Default name of the Milvus database connection.
+_DEFAULT_NAME = "rechub"
+
+# Default host of the Milvus instance.
+_DEFAULT_HOST = "localhost"
+
+# Default port of the Milvus instance.
+_DEFAULT_PORT = 19530
+
+# Name of the embedding column in the Milvus database table.
+_EMBEDDING_COLUMN = "embedding"
+
+
+class MilvusBuilder(BaseBuilder):
+    """Implement ``BaseBuilder`` for Milvus vector index construction."""
+
+    def __init__(
+        self,
+        d: int,
+        index_type: _MilvusIndexType = _DEFAULT_MILVUS_INDEX_TYPE,
+        metric: _MilvusMetric = _DEFAULT_MILVUS_METRIC,
+        *,
+        m: int = _DEFAULT_M,
+        nlist: int = _DEFAULT_N_LIST,
+        ef: ty.Optional[int] = None,
+        nprobe: ty.Optional[int] = None,
+        name: str = _DEFAULT_NAME,
+        host: str = _DEFAULT_HOST,
+        port: int = _DEFAULT_PORT,
+    ) -> None:
+        """
+        Initialize a Milvus builder.
+
+        Parameters
+        ----------
+        d : int
+            The dimension of the embeddings.
+        index_type : ``"FLAT"``, ``"HNSW"``, or ``"IVF_FLAT"``, optional
+            The indexing method. Defaults to ``"FLAT"``.
+        metric : ``"COSINE"``, ``"IP"``, or ``"L2"``, optional
+            The indexing metric. Defaults to ``"COSINE"``.
+        m : int, optional
+            Max number of neighbors to build an HNSW index.
+        nlist : int, optional
+            Number of clusters to build an IVF index.
+        ef : int or None, optional
+            Number of candidate nodes during an HNSW search.
+        nprobe : int or None, optional
+            Number of clusters to probe during an IVF search.
+        name : str, optional
+            The name of the connection. Each name corresponds to one connection.
+        host : str, optional
+            The host of the Milvus instance. Defaults to ``"localhost"``.
+        port : int, optional
+            The port of the Milvus instance. Defaults to ``19530``.
+        """
+        self._d = d
+
+        # connection parameters
+        self._name = name
+        self._host = host
+        self._port = port
+
+        bparams: dict[str, ty.Any] = {}
+        qparams: dict[str, ty.Any] = {}
+
+        if index_type == "HNSW":
+            bparams.update(M=m)
+            if ef is not None:
+                qparams.update(ef=ef)
+
+        if index_type == "IVF_FLAT":
+            bparams.update(nlist=nlist)
+            if nprobe is not None:
+                qparams.update(nprobe=nprobe)
+
+        self._build_params = dict(
+            index_type=index_type,
+            metric_type=metric,
+            params=bparams,
+        )
+
+        self._query_params = dict(
+            metric_type=metric,
+            params=qparams,
+        )
+
+    @contextlib.contextmanager
+    def from_embeddings(
+        self,
+        embeddings: torch.Tensor,
+    ) -> ty.Generator["MilvusIndexer", None, None]:
+        """Adhere to ``BaseBuilder.from_embeddings``."""
+        milvus.connections.connect(self._name, host=self._host, port=self._port)
+        collection = self._build_collection(embeddings)
+
+        try:
+            yield MilvusIndexer(collection, self._query_params)
+        finally:
+            collection.drop()
+            milvus.connections.disconnect(self._name)
+
+    @contextlib.contextmanager
+    def from_index_file(
+        self,
+        index_file: FilePath,
+    ) -> ty.Generator["MilvusIndexer", None, None]:
+        """Adhere to ``BaseBuilder.from_index_file``."""
+        raise NotImplementedError("Milvus does not support index files!")
+
+    def _build_collection(self, embeddings: torch.Tensor) -> milvus.Collection:
+        """Build a Milvus collection with the current connection."""
+        fields = [
+            milvus.FieldSchema(
+                name="id",
+                dtype=milvus.DataType.INT64,
+                is_primary=True,
+            ),
+            milvus.FieldSchema(
+                name=_EMBEDDING_COLUMN,
+                dtype=milvus.DataType.FLOAT_VECTOR,
+                dim=self._d,
+            ),
+        ]
+
+        collection = milvus.Collection(
+            name=f"{self._name}_{uuid.uuid4().hex}",
+            schema=milvus.CollectionSchema(fields=fields),
+            using=self._name,
+        )
+
+        n, _ = embeddings.shape
+        collection.insert([np.arange(n, dtype=np.int64), embeddings.cpu().numpy()])
+        collection.create_index(_EMBEDDING_COLUMN, index_params=self._build_params)
+        collection.load()
+
+        return collection
+
+
+class MilvusIndexer(BaseIndexer):
+    """Milvus-based implementation of ``BaseIndexer``."""
+
+    def __init__(
+        self,
+        collection: milvus.Collection,
+        query_params: dict[str, ty.Any],
+    ) -> None:
+        """Initialize a Milvus indexer."""
+        self._collection = collection
+        self._query_params = query_params
+
+    def query(
+        self,
+        embeddings: torch.Tensor,
+        top_k: int,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Adhere to ``BaseIndexer.query``."""
+        results = self._collection.search(
+            data=embeddings.cpu().numpy(),
+            anns_field=_EMBEDDING_COLUMN,
+            param=self._query_params,
+            limit=top_k,
+        )
+
+        n, _ = embeddings.shape
+        nn_ids = np.zeros((n, top_k), dtype=np.int64)
+        nn_distances = np.zeros((n, top_k), dtype=np.float32)
+
+        for i, result in enumerate(results):
+            nn_ids[i] = result.ids
+            nn_distances[i] = result.distances
+
+        return torch.from_numpy(nn_ids), torch.from_numpy(nn_distances)
+
+    def save(self, file_path: FilePath) -> None:
+        """Adhere to ``BaseIndexer.save``."""
+        raise NotImplementedError("Milvus does not support index files!")
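A usage sketch for the Milvus backend, assuming a Milvus server is reachable at ``localhost:19530``; parameter values are illustrative. Unlike the FAISS backend, the collection only lives for the duration of the ``with`` block, and ``save``/``from_index_file`` raise ``NotImplementedError``:

```python
import torch

from torch_rechub.serving.milvus import MilvusBuilder

item_embeddings = torch.randn(5_000, 64)

# HNSW index with cosine similarity; ef controls the search-time candidate pool.
builder = MilvusBuilder(d=64, index_type="HNSW", metric="COSINE", m=16, ef=64)

with builder.from_embeddings(item_embeddings) as indexer:
    ids, scores = indexer.query(item_embeddings[:2], top_k=5)  # both shaped (2, 5)
# On exit, the temporary collection is dropped and the connection is closed.
```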
torch_rechub/trainers/ctr_trainer.py
@@ -43,6 +43,7 @@ class CTRTrainer(object):
         gpus=None,
         loss_mode=True,
         model_path="./",
+        model_logger=None,
     ):
         self.model = model  # for uniform weights save method in one gpu or multi gpu
         if gpus is None:
@@ -70,10 +71,13 @@ class CTRTrainer(object):
         self.model_path = model_path
         # Initialize regularization loss
         self.reg_loss_fn = RegularizationLoss(**regularization_params)
+        self.model_logger = model_logger
 
     def train_one_epoch(self, data_loader, log_interval=10):
         self.model.train()
         total_loss = 0
+        epoch_loss = 0
+        batch_count = 0
         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
         for i, (x_dict, y) in enumerate(tk0):
             x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
@@ -93,27 +97,62 @@ class CTRTrainer(object):
             loss.backward()
             self.optimizer.step()
             total_loss += loss.item()
+            epoch_loss += loss.item()
+            batch_count += 1
             if (i + 1) % log_interval == 0:
                 tk0.set_postfix(loss=total_loss / log_interval)
                 total_loss = 0
 
+        # Return average epoch loss
+        return epoch_loss / batch_count if batch_count > 0 else 0
+
     def fit(self, train_dataloader, val_dataloader=None):
+        for logger in self._iter_loggers():
+            logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_mode': self.loss_mode})
+
         for epoch_i in range(self.n_epoch):
             print('epoch:', epoch_i)
-            self.train_one_epoch(train_dataloader)
+            train_loss = self.train_one_epoch(train_dataloader)
+
+            for logger in self._iter_loggers():
+                logger.log_metrics({'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}, step=epoch_i)
+
             if self.scheduler is not None:
                 if epoch_i % self.scheduler.step_size == 0:
                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
                     self.scheduler.step()  # update lr in epoch level by scheduler
+
             if val_dataloader:
                 auc = self.evaluate(self.model, val_dataloader)
                 print('epoch:', epoch_i, 'validation: auc:', auc)
+
+                for logger in self._iter_loggers():
+                    logger.log_metrics({'val/auc': auc}, step=epoch_i)
+
                 if self.early_stopper.stop_training(auc, self.model.state_dict()):
                     print(f'validation: best auc: {self.early_stopper.best_auc}')
                     self.model.load_state_dict(self.early_stopper.best_weights)
                     break
+
         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best auc model
 
+        for logger in self._iter_loggers():
+            logger.finish()
+
+    def _iter_loggers(self):
+        """Return logger instances as a list.
+
+        Returns
+        -------
+        list
+            Active logger instances. Empty when ``model_logger`` is ``None``.
+        """
+        if self.model_logger is None:
+            return []
+        if isinstance(self.model_logger, (list, tuple)):
+            return list(self.model_logger)
+        return [self.model_logger]
+
     def evaluate(self, model, data_loader):
         model.eval()
         targets, predicts = list(), list()
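``fit`` only ever calls ``log_hyperparams``, ``log_metrics``, and ``finish`` on each logger, so ``model_logger`` is duck-typed: any object with those three methods works (the package's own trackers live in the new ``torch_rechub/basic/tracking.py``). A minimal sketch with a hypothetical stdout logger:

```python
class PrintLogger:
    """Hypothetical logger satisfying the interface fit() relies on."""

    def log_hyperparams(self, params):
        print("hparams:", params)

    def log_metrics(self, metrics, step):
        print(f"step {step}:", metrics)

    def finish(self):
        print("run finished")


# A single logger or a list/tuple of loggers is accepted (see _iter_loggers above).
trainer = CTRTrainer(model, model_logger=[PrintLogger()])  # other constructor args elided
```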
@@ -146,7 +185,7 @@ class CTRTrainer(object):
             predicts.extend(y_pred.tolist())
         return predicts
 
-    def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
+    def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False, onnx_export_kwargs=None):
         """Export the trained model to ONNX format.
 
         This method exports the ranking model (e.g., DeepFM, WideDeep, DCN) to ONNX format
@@ -163,6 +202,7 @@ class CTRTrainer(object):
             device (str, optional): Device for export ('cpu', 'cuda', etc.).
                 If None, defaults to 'cpu' for maximum compatibility.
             verbose (bool): Print export details (default: False).
+            onnx_export_kwargs (dict, optional): Extra kwargs forwarded to ``torch.onnx.export``.
 
         Returns:
             bool: True if export succeeded, False otherwise.
@@ -188,7 +228,16 @@ class CTRTrainer(object):
         export_device = device if device is not None else 'cpu'
 
         exporter = ONNXExporter(model, device=export_device)
-        return exporter.export(
+        return exporter.export(
+            output_path=output_path,
+            dummy_input=dummy_input,
+            batch_size=batch_size,
+            seq_length=seq_length,
+            opset_version=opset_version,
+            dynamic_batch=dynamic_batch,
+            verbose=verbose,
+            onnx_export_kwargs=onnx_export_kwargs,
+        )
 
     def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
         """Visualize the model's computation graph.