torch-rechub 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/tracking.py +198 -0
- torch_rechub/data/__init__.py +0 -0
- torch_rechub/data/convert.py +67 -0
- torch_rechub/data/dataset.py +120 -0
- torch_rechub/trainers/ctr_trainer.py +40 -1
- torch_rechub/trainers/match_trainer.py +39 -1
- torch_rechub/trainers/mtl_trainer.py +49 -1
- torch_rechub/trainers/seq_trainer.py +59 -2
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.0.6.dist-info}/METADATA +13 -5
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.0.6.dist-info}/RECORD +12 -8
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.0.6.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.5.dist-info → torch_rechub-0.0.6.dist-info}/licenses/LICENSE +0 -0
torch_rechub/basic/tracking.py (new file)
@@ -0,0 +1,198 @@
+"""Experiment tracking utilities for Torch-RecHub.
+
+This module exposes lightweight adapters for common visualization and
+experiment tracking tools, namely Weights & Biases (wandb), SwanLab, and
+TensorBoardX.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union
+
+
+class BaseLogger(ABC):
+    """Base interface for experiment tracking backends.
+
+    Methods
+    -------
+    log_metrics(metrics, step=None)
+        Record scalar metrics at a given step.
+    log_hyperparams(params)
+        Store hyperparameters and run configuration.
+    finish()
+        Flush pending logs and release resources.
+    """
+
+    @abstractmethod
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        """Log metrics to the tracking backend.
+
+        Parameters
+        ----------
+        metrics : dict of str to Any
+            Metric name-value pairs to record.
+        step : int, optional
+            Explicit global step or epoch index. When ``None``, the backend
+            uses its own default step handling.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        """Log experiment hyperparameters.
+
+        Parameters
+        ----------
+        params : dict of str to Any
+            Hyperparameters or configuration values to persist with the run.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def finish(self) -> None:
+        """Finalize logging and free any backend resources."""
+        raise NotImplementedError
+
+
+class WandbLogger(BaseLogger):
+    """Weights & Biases logger implementation.
+
+    Parameters
+    ----------
+    project : str
+        Name of the wandb project to log to.
+    name : str, optional
+        Display name for the run.
+    config : dict, optional
+        Initial hyperparameter configuration to record.
+    tags : list of str, optional
+        Optional tags for grouping runs.
+    notes : str, optional
+        Long-form notes shown in the run overview.
+    dir : str, optional
+        Local directory for wandb artifacts and cache.
+    **kwargs : dict
+        Additional keyword arguments forwarded to ``wandb.init``.
+
+    Raises
+    ------
+    ImportError
+        If ``wandb`` is not installed in the current environment.
+    """
+
+    def __init__(self, project: str, name: Optional[str] = None, config: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, notes: Optional[str] = None, dir: Optional[str] = None, **kwargs):
+        try:
+            import wandb
+            self._wandb = wandb
+        except ImportError:
+            raise ImportError("wandb is not installed. Install it with: pip install wandb")
+
+        self.run = self._wandb.init(project=project, name=name, config=config, tags=tags, notes=notes, dir=dir, **kwargs)
+
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        if step is not None:
+            self._wandb.log(metrics, step=step)
+        else:
+            self._wandb.log(metrics)
+
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        if self.run is not None:
+            self.run.config.update(params)
+
+    def finish(self) -> None:
+        if self.run is not None:
+            self.run.finish()
+
+
+class SwanLabLogger(BaseLogger):
+    """SwanLab logger implementation.
+
+    Parameters
+    ----------
+    project : str, optional
+        Project identifier for grouping experiments.
+    experiment_name : str, optional
+        Display name for the experiment or run.
+    description : str, optional
+        Text description shown alongside the run.
+    config : dict, optional
+        Hyperparameters or configuration to log at startup.
+    logdir : str, optional
+        Directory where logs and artifacts are stored.
+    **kwargs : dict
+        Additional keyword arguments forwarded to ``swanlab.init``.
+
+    Raises
+    ------
+    ImportError
+        If ``swanlab`` is not installed in the current environment.
+    """
+
+    def __init__(self, project: Optional[str] = None, experiment_name: Optional[str] = None, description: Optional[str] = None, config: Optional[Dict[str, Any]] = None, logdir: Optional[str] = None, **kwargs):
+        try:
+            import swanlab
+            self._swanlab = swanlab
+        except ImportError:
+            raise ImportError("swanlab is not installed. Install it with: pip install swanlab")
+
+        self.run = self._swanlab.init(project=project, experiment_name=experiment_name, description=description, config=config, logdir=logdir, **kwargs)
+
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        if step is not None:
+            self._swanlab.log(metrics, step=step)
+        else:
+            self._swanlab.log(metrics)
+
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        if self.run is not None:
+            self.run.config.update(params)
+
+    def finish(self) -> None:
+        self._swanlab.finish()
+
+
+class TensorBoardXLogger(BaseLogger):
+    """TensorBoardX logger implementation.
+
+    Parameters
+    ----------
+    log_dir : str
+        Directory where event files will be written.
+    comment : str, default=""
+        Comment appended to the log directory name.
+    **kwargs : dict
+        Additional keyword arguments forwarded to
+        ``tensorboardX.SummaryWriter``.
+
+    Raises
+    ------
+    ImportError
+        If ``tensorboardX`` is not installed in the current environment.
+    """
+
+    def __init__(self, log_dir: str, comment: str = "", **kwargs):
+        try:
+            from tensorboardX import SummaryWriter
+            self._SummaryWriter = SummaryWriter
+        except ImportError:
+            raise ImportError("tensorboardX is not installed. Install it with: pip install tensorboardX")
+
+        self.writer = self._SummaryWriter(log_dir=log_dir, comment=comment, **kwargs)
+        self._step = 0
+
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        if step is None:
+            step = self._step
+            self._step += 1
+
+        for key, value in metrics.items():
+            if value is not None:
+                if isinstance(value, (int, float)):
+                    self.writer.add_scalar(key, value, step)
+
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        hparam_str = "\n".join([f"{k}: {v}" for k, v in params.items()])
+        self.writer.add_text("hyperparameters", hparam_str, 0)
+
+    def finish(self) -> None:
+        if self.writer is not None:
+            self.writer.close()
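A minimal usage sketch of the new logger API (not part of the diff; the import path follows the file location above, and "./runs" is an arbitrary example directory):

    from torch_rechub.basic.tracking import TensorBoardXLogger

    logger = TensorBoardXLogger(log_dir="./runs")
    logger.log_hyperparams({"n_epoch": 10, "learning_rate": 1e-3})
    logger.log_metrics({"train/loss": 0.42}, step=0)
    logger.finish()

WandbLogger and SwanLabLogger expose the same three methods, so backends can be swapped without touching training code.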
torch_rechub/data/__init__.py (new empty file, no content to show)
torch_rechub/data/convert.py (new file)
@@ -0,0 +1,67 @@
+"""Utilities for converting array-like data structures into PyTorch tensors."""
+
+import numpy.typing as npt
+import pyarrow as pa
+import pyarrow.compute as pc
+import pyarrow.types as pt
+import torch
+
+
+def pa_array_to_tensor(arr: pa.Array) -> torch.Tensor:
+    """
+    Convert a PyArrow array to a PyTorch tensor.
+
+    Parameters
+    ----------
+    arr : pa.Array
+        The given PyArrow array.
+
+    Returns
+    -------
+    torch.Tensor: The result PyTorch tensor.
+
+    Raises
+    ------
+    TypeError
+        if the array type or the value type (when nested) is unsupported.
+    ValueError
+        if the nested array is ragged (unequal lengths of each row).
+    """
+    if _is_supported_scalar(arr.type):
+        arr = pc.cast(arr, pa.float32())
+        return torch.from_numpy(_to_writable_numpy(arr))
+
+    if not _is_supported_list(arr.type):
+        raise TypeError(f"Unsupported array type: {arr.type}")
+
+    if not _is_supported_scalar(val_type := arr.type.value_type):
+        raise TypeError(f"Unsupported value type in the nested array: {val_type}")
+
+    if len(pc.unique(pc.list_value_length(arr))) > 1:
+        raise ValueError("Cannot convert the ragged nested array.")
+
+    arr = pc.cast(arr, pa.list_(pa.float32()))
+    np_arr = _to_writable_numpy(arr.values)  # type: ignore[attr-defined]
+
+    # For empty list-of-lists, define output shape as (0, 0); otherwise infer width.
+    return torch.from_numpy(np_arr.reshape(len(arr), -1 if len(arr) > 0 else 0))
+
+
+# helper functions
+
+
+def _is_supported_list(t: pa.DataType) -> bool:
+    """Check if the given PyArrow data type is a supported list."""
+    return pt.is_fixed_size_list(t) or pt.is_large_list(t) or pt.is_list(t)
+
+
+def _is_supported_scalar(t: pa.DataType) -> bool:
+    """Check if the given PyArrow data type is a supported scalar type."""
+    return pt.is_boolean(t) or pt.is_floating(t) or pt.is_integer(t) or pt.is_null(t)
+
+
+def _to_writable_numpy(arr: pa.Array) -> npt.NDArray:
+    """Dump a PyArrow array into a writable NumPy array."""
+    # Force the NumPy array to be writable. PyArrow's to_numpy() often returns a
+    # read-only view for zero-copy, which PyTorch's from_numpy() does not support.
+    return arr.to_numpy(writable=True, zero_copy_only=False)
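A short sketch of how pa_array_to_tensor behaves (not part of the diff; the values are illustrative):

    import pyarrow as pa
    from torch_rechub.data.convert import pa_array_to_tensor

    # A scalar column is cast to float32 and becomes a 1-D tensor of shape (3,).
    dense = pa_array_to_tensor(pa.array([1, 2, 3]))

    # A nested column with equal-length rows becomes a 2-D tensor of shape (2, 2).
    seqs = pa_array_to_tensor(pa.array([[1, 2], [3, 4]]))

    # Ragged rows such as pa.array([[1], [2, 3]]) raise ValueError, and
    # unsupported element types (e.g. strings) raise TypeError.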
torch_rechub/data/dataset.py (new file)
@@ -0,0 +1,120 @@
+"""Dataset implementations providing streaming, batch-wise data access for PyTorch."""
+
+import os
+import typing as ty
+
+import pyarrow.dataset as pd
+import torch
+from torch.utils.data import IterableDataset, get_worker_info
+
+from .convert import pa_array_to_tensor
+
+# Type for path to a file
+_FilePath = ty.Union[str, os.PathLike]
+
+# The default batch size when reading a Parquet dataset
+_DEFAULT_BATCH_SIZE = 1024
+
+
+class ParquetIterableDataset(IterableDataset):
+    """
+    IterableDataset that streams data from one or more Parquet files.
+
+    Parameters
+    ----------
+    file_paths : list[_FilePath]
+        Paths to Parquet files.
+    columns : list[str], optional
+        Column names to select. If ``None``, all columns are read.
+    batch_size : int, default DEFAULT_BATCH_SIZE
+        Number of rows per streamed batch.
+
+    Notes
+    -----
+    This dataset reads data lazily and never loads the entire Parquet dataset to memory.
+    The current worker receives a partition of ``file_paths`` and builds its own PyArrow
+    Dataset and Scanner. Iteration yields dictionaries mapping column names to PyTorch
+    tensors created via NumPy, one batch at a time.
+
+    Examples
+    --------
+    >>> ds = ParquetIterableDataset(
+    ...     ["/data/train1.parquet", "/data/train2.parquet"],
+    ...     columns=["x", "y", "label"],
+    ...     batch_size=1024,
+    ... )
+    >>> loader = DataLoader(ds, batch_size=None)
+    >>> # Now iterate over batches.
+    >>> for batch in loader:
+    ...     x, y, label = batch["x"], batch["y"], batch["label"]
+    ...     # Do some work.
+    ...     ...
+    """
+
+    def __init__(
+        self,
+        file_paths: ty.Sequence[_FilePath],
+        /,
+        columns: ty.Optional[ty.Sequence[str]] = None,
+        batch_size: int = _DEFAULT_BATCH_SIZE,
+    ) -> None:
+        """Initialize this instance."""
+        self._file_paths = tuple(map(str, file_paths))
+        self._columns = None if columns is None else tuple(columns)
+        self._batch_size = batch_size
+
+    def __iter__(self) -> ty.Iterator[dict[str, torch.Tensor]]:
+        """
+        Stream Parquet data as mapped PyTorch tensors.
+
+        Build a PyArrow Dataset from the current worker's assigned file partition, then
+        create a Scanner to lazily read batches of the selected columns. Each batch is
+        converted to a dict mapping column names to PyTorch tensors (via NumPy).
+
+        Returns
+        -------
+        Iterator[dict[str, torch.Tensor]]
+            An iterator that yields one converted batch at a time.
+        """
+        if not (partition := self._get_partition()):
+            return
+
+        # Build the dataset for the current worker.
+        ds = pd.dataset(partition, format="parquet")
+
+        # Create a scanner. This does not read data.
+        columns = None if self._columns is None else list(self._columns)
+        scanner = ds.scanner(columns=columns, batch_size=self._batch_size)
+
+        for batch in scanner.to_batches():
+            data_dict: dict[str, torch.Tensor] = {}
+            for name, array in zip(batch.column_names, batch.columns):
+                data_dict[name] = pa_array_to_tensor(array)
+            yield data_dict
+
+    # private interfaces
+
+    def _get_partition(self) -> tuple[str, ...]:
+        """
+        Get the partition of file paths for the current worker.
+
+        This method splits the full list of file paths into contiguous partitions with
+        a nearly equal size by the total number of workers and the current worker ID.
+
+        If running in the main process (i.e., no worker information is available), the
+        entire list of file paths is returned.
+
+        Returns
+        -------
+        tuple[str, ...]
+            The partition of file paths for the current worker.
+        """
+        if (info := get_worker_info()) is None:
+            return self._file_paths
+
+        n = len(self._file_paths)
+        per_worker = (n + info.num_workers - 1) // info.num_workers
+
+        start = info.id * per_worker
+        end = n if (end := start + per_worker) > n else end
+        return self._file_paths[start:end]
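A sketch of multi-worker streaming (not part of the diff; file and column names are placeholders). Because the dataset already yields ready-made batches, the DataLoader is created with batch_size=None, and with num_workers=2 the _get_partition helper gives each worker a contiguous slice of the file list, two files per worker here:

    from torch.utils.data import DataLoader
    from torch_rechub.data.dataset import ParquetIterableDataset

    files = ["part-0.parquet", "part-1.parquet", "part-2.parquet", "part-3.parquet"]
    ds = ParquetIterableDataset(files, columns=["user_id", "item_id", "label"], batch_size=512)
    loader = DataLoader(ds, batch_size=None, num_workers=2)

    for batch in loader:
        label = batch["label"]  # torch.Tensor with up to 512 rows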
torch_rechub/trainers/ctr_trainer.py
@@ -43,6 +43,7 @@ class CTRTrainer(object):
         gpus=None,
         loss_mode=True,
         model_path="./",
+        model_logger=None,
     ):
         self.model = model  # for uniform weights save method in one gpu or multi gpu
         if gpus is None:
@@ -70,10 +71,13 @@ class CTRTrainer(object):
         self.model_path = model_path
         # Initialize regularization loss
         self.reg_loss_fn = RegularizationLoss(**regularization_params)
+        self.model_logger = model_logger

    def train_one_epoch(self, data_loader, log_interval=10):
         self.model.train()
         total_loss = 0
+        epoch_loss = 0
+        batch_count = 0
         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
         for i, (x_dict, y) in enumerate(tk0):
             x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
@@ -93,27 +97,62 @@ class CTRTrainer(object):
             loss.backward()
             self.optimizer.step()
             total_loss += loss.item()
+            epoch_loss += loss.item()
+            batch_count += 1
             if (i + 1) % log_interval == 0:
                 tk0.set_postfix(loss=total_loss / log_interval)
                 total_loss = 0

+        # Return average epoch loss
+        return epoch_loss / batch_count if batch_count > 0 else 0
+
     def fit(self, train_dataloader, val_dataloader=None):
+        for logger in self._iter_loggers():
+            logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_mode': self.loss_mode})
+
         for epoch_i in range(self.n_epoch):
             print('epoch:', epoch_i)
-            self.train_one_epoch(train_dataloader)
+            train_loss = self.train_one_epoch(train_dataloader)
+
+            for logger in self._iter_loggers():
+                logger.log_metrics({'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}, step=epoch_i)
+
             if self.scheduler is not None:
                 if epoch_i % self.scheduler.step_size == 0:
                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
                 self.scheduler.step()  # update lr in epoch level by scheduler
+
             if val_dataloader:
                 auc = self.evaluate(self.model, val_dataloader)
                 print('epoch:', epoch_i, 'validation: auc:', auc)
+
+                for logger in self._iter_loggers():
+                    logger.log_metrics({'val/auc': auc}, step=epoch_i)
+
                 if self.early_stopper.stop_training(auc, self.model.state_dict()):
                     print(f'validation: best auc: {self.early_stopper.best_auc}')
                     self.model.load_state_dict(self.early_stopper.best_weights)
                     break
+
         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best auc model

+        for logger in self._iter_loggers():
+            logger.finish()
+
+    def _iter_loggers(self):
+        """Return logger instances as a list.
+
+        Returns
+        -------
+        list
+            Active logger instances. Empty when ``model_logger`` is ``None``.
+        """
+        if self.model_logger is None:
+            return []
+        if isinstance(self.model_logger, (list, tuple)):
+            return list(self.model_logger)
+        return [self.model_logger]
+
     def evaluate(self, model, data_loader):
         model.eval()
         targets, predicts = list(), list()
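A hedged sketch of wiring a logger into this trainer (not part of the diff; the model, dataloaders, and project name are placeholders, and the remaining CTRTrainer arguments are assumed to keep their defaults):

    from torch_rechub.basic.tracking import WandbLogger
    from torch_rechub.trainers import CTRTrainer

    logger = WandbLogger(project="rechub-demo", config={"model": "WideDeep"})
    trainer = CTRTrainer(model, model_logger=logger)
    trainer.fit(train_dataloader, val_dataloader)
    # fit() logs the hyperparameters once, then 'train/loss' and 'val/auc'
    # per epoch, and calls logger.finish() when training ends.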
torch_rechub/trainers/match_trainer.py
@@ -39,6 +39,7 @@ class MatchTrainer(object):
         device="cpu",
         gpus=None,
         model_path="./",
+        model_logger=None,
     ):
         self.model = model  # for uniform weights save method in one gpu or multi gpu
         if gpus is None:
@@ -73,10 +74,13 @@ class MatchTrainer(object):
         self.model_path = model_path
         # Initialize regularization loss
         self.reg_loss_fn = RegularizationLoss(**regularization_params)
+        self.model_logger = model_logger

    def train_one_epoch(self, data_loader, log_interval=10):
         self.model.train()
         total_loss = 0
+        epoch_loss = 0
+        batch_count = 0
         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
         for i, (x_dict, y) in enumerate(tk0):
             x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
@@ -114,14 +118,26 @@ class MatchTrainer(object):
             loss.backward()
             self.optimizer.step()
             total_loss += loss.item()
+            epoch_loss += loss.item()
+            batch_count += 1
             if (i + 1) % log_interval == 0:
                 tk0.set_postfix(loss=total_loss / log_interval)
                 total_loss = 0

+        # Return average epoch loss
+        return epoch_loss / batch_count if batch_count > 0 else 0
+
     def fit(self, train_dataloader, val_dataloader=None):
+        for logger in self._iter_loggers():
+            logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_mode': self.mode})
+
         for epoch_i in range(self.n_epoch):
             print('epoch:', epoch_i)
-            self.train_one_epoch(train_dataloader)
+            train_loss = self.train_one_epoch(train_dataloader)
+
+            for logger in self._iter_loggers():
+                logger.log_metrics({'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}, step=epoch_i)
+
             if self.scheduler is not None:
                 if epoch_i % self.scheduler.step_size == 0:
                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
@@ -130,12 +146,34 @@ class MatchTrainer(object):
             if val_dataloader:
                 auc = self.evaluate(self.model, val_dataloader)
                 print('epoch:', epoch_i, 'validation: auc:', auc)
+
+                for logger in self._iter_loggers():
+                    logger.log_metrics({'val/auc': auc}, step=epoch_i)
+
                 if self.early_stopper.stop_training(auc, self.model.state_dict()):
                     print(f'validation: best auc: {self.early_stopper.best_auc}')
                     self.model.load_state_dict(self.early_stopper.best_weights)
                     break
+
         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best auc model

+        for logger in self._iter_loggers():
+            logger.finish()
+
+    def _iter_loggers(self):
+        """Return logger instances as a list.
+
+        Returns
+        -------
+        list
+            Active logger instances. Empty when ``model_logger`` is ``None``.
+        """
+        if self.model_logger is None:
+            return []
+        if isinstance(self.model_logger, (list, tuple)):
+            return list(self.model_logger)
+        return [self.model_logger]
+
     def evaluate(self, model, data_loader):
         model.eval()
         targets, predicts = list(), list()
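Since _iter_loggers accepts a list or tuple, several backends can record the same run. A hedged sketch (not part of the diff; trainer arguments other than model_logger are placeholders):

    from torch_rechub.basic.tracking import SwanLabLogger, TensorBoardXLogger
    from torch_rechub.trainers import MatchTrainer

    loggers = [SwanLabLogger(project="rechub-demo"), TensorBoardXLogger(log_dir="./runs/match")]
    trainer = MatchTrainer(model, model_logger=loggers)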
torch_rechub/trainers/mtl_trainer.py
@@ -47,6 +47,7 @@ class MTLTrainer(object):
         device="cpu",
         gpus=None,
         model_path="./",
+        model_logger=None,
     ):
         self.model = model
         if gpus is None:
@@ -104,6 +105,7 @@ class MTLTrainer(object):
         self.model_path = model_path
         # Initialize regularization loss
         self.reg_loss_fn = RegularizationLoss(**regularization_params)
+        self.model_logger = model_logger

    def train_one_epoch(self, data_loader):
         self.model.train()
@@ -163,21 +165,42 @@ class MTLTrainer(object):
     def fit(self, train_dataloader, val_dataloader, mode='base', seed=0):
         total_log = []

+        # Log hyperparameters once
+        for logger in self._iter_loggers():
+            logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self._current_lr(), 'adaptive_method': self.adaptive_method})
+
         for epoch_i in range(self.n_epoch):
             _log_per_epoch = self.train_one_epoch(train_dataloader)

+            # Collect metrics
+            logs = {f'train/task_{task_id}_loss': loss_val for task_id, loss_val in enumerate(_log_per_epoch)}
+            lr_value = self._current_lr()
+            if lr_value is not None:
+                logs['learning_rate'] = lr_value
+
             if self.scheduler is not None:
                 if epoch_i % self.scheduler.step_size == 0:
                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
                 self.scheduler.step()  # update lr in epoch level by scheduler
+
             scores = self.evaluate(self.model, val_dataloader)
             print('epoch:', epoch_i, 'validation scores: ', scores)

-            for score in scores:
+            for task_id, score in enumerate(scores):
+                logs[f'val/task_{task_id}_score'] = score
                 _log_per_epoch.append(score)
+            logs['auc'] = scores[self.earlystop_taskid]
+
+            if self.loss_weight:
+                for task_id, weight in enumerate(self.loss_weight):
+                    logs[f'loss_weight/task_{task_id}'] = weight.item()

             total_log.append(_log_per_epoch)

+            # Log metrics once per epoch
+            for logger in self._iter_loggers():
+                logger.log_metrics(logs, step=epoch_i)
+
             if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
                 print('validation best auc of main task %d: %.6f' % (self.earlystop_taskid, self.early_stopper.best_auc))
                 self.model.load_state_dict(self.early_stopper.best_weights)
@@ -185,8 +208,33 @@ class MTLTrainer(object):

         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model_{}_{}.pth".format(mode, seed)))  # save best auc model

+        for logger in self._iter_loggers():
+            logger.finish()
+
         return total_log

+    def _iter_loggers(self):
+        """Return logger instances as a list.
+
+        Returns
+        -------
+        list
+            Active logger instances. Empty when ``model_logger`` is ``None``.
+        """
+        if self.model_logger is None:
+            return []
+        if isinstance(self.model_logger, (list, tuple)):
+            return list(self.model_logger)
+        return [self.model_logger]
+
+    def _current_lr(self):
+        """Fetch current learning rate regardless of adaptive method."""
+        if self.adaptive_method == "metabalance":
+            return self.share_optimizer.param_groups[0]['lr'] if hasattr(self, 'share_optimizer') else None
+        if hasattr(self, 'optimizer'):
+            return self.optimizer.param_groups[0]['lr']
+        return None
+
     def evaluate(self, model, data_loader):
         model.eval()
         targets, predicts = list(), list()
torch_rechub/trainers/seq_trainer.py
@@ -46,7 +46,22 @@ class SeqTrainer(object):
    ... )
    """

-    def __init__(
+    def __init__(
+        self,
+        model,
+        optimizer_fn=torch.optim.Adam,
+        optimizer_params=None,
+        scheduler_fn=None,
+        scheduler_params=None,
+        n_epoch=10,
+        earlystop_patience=10,
+        device='cpu',
+        gpus=None,
+        model_path='./',
+        loss_type='cross_entropy',
+        loss_params=None,
+        model_logger=None
+    ):
         self.model = model  # for uniform weights save method in one gpu or multi gpu
         if gpus is None:
             gpus = []
@@ -74,9 +89,11 @@ class SeqTrainer(object):
             loss_params = {"ignore_index": 0}
         self.loss_fn = nn.CrossEntropyLoss(**loss_params)

+        self.loss_type = loss_type
         self.n_epoch = n_epoch
         self.early_stopper = EarlyStopper(patience=earlystop_patience)
         self.model_path = model_path
+        self.model_logger = model_logger

    def fit(self, train_dataloader, val_dataloader=None):
         """训练模型.
@@ -90,10 +107,18 @@ class SeqTrainer(object):
         """
         history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}

+        for logger in self._iter_loggers():
+            logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_type': self.loss_type})
+
         for epoch_i in range(self.n_epoch):
             print('epoch:', epoch_i)
             # 训练阶段
-            self.train_one_epoch(train_dataloader)
+            train_loss = self.train_one_epoch(train_dataloader)
+            history['train_loss'].append(train_loss)
+
+            # Collect metrics
+            logs = {'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}
+
             if self.scheduler is not None:
                 if epoch_i % self.scheduler.step_size == 0:
                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
@@ -105,6 +130,10 @@ class SeqTrainer(object):
                 history['val_loss'].append(val_loss)
                 history['val_accuracy'].append(val_accuracy)

+                logs['val/loss'] = val_loss
+                logs['val/accuracy'] = val_accuracy
+                logs['auc'] = val_accuracy  # For compatibility with EarlyStopper
+
                 print(f"epoch: {epoch_i}, validation: loss: {val_loss:.4f}, accuracy: {val_accuracy:.4f}")

                 # 早停
@@ -113,9 +142,30 @@ class SeqTrainer(object):
                     self.model.load_state_dict(self.early_stopper.best_weights)
                     break

+            for logger in self._iter_loggers():
+                logger.log_metrics(logs, step=epoch_i)
+
         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best model
+
+        for logger in self._iter_loggers():
+            logger.finish()
+
         return history

+    def _iter_loggers(self):
+        """Return logger instances as a list.
+
+        Returns
+        -------
+        list
+            Active logger instances. Empty when ``model_logger`` is ``None``.
+        """
+        if self.model_logger is None:
+            return []
+        if isinstance(self.model_logger, (list, tuple)):
+            return list(self.model_logger)
+        return [self.model_logger]
+
     def train_one_epoch(self, data_loader, log_interval=10):
         """Train the model for a single epoch.

@@ -128,6 +178,8 @@ class SeqTrainer(object):
         """
         self.model.train()
         total_loss = 0
+        epoch_loss = 0
+        batch_count = 0
         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
         for i, (seq_tokens, seq_positions, seq_time_diffs, targets) in enumerate(tk0):
             # Move tensors to the target device
@@ -152,10 +204,15 @@ class SeqTrainer(object):
             self.optimizer.step()

             total_loss += loss.item()
+            epoch_loss += loss.item()
+            batch_count += 1
             if (i + 1) % log_interval == 0:
                 tk0.set_postfix(loss=total_loss / log_interval)
                 total_loss = 0

+        # Return average epoch loss
+        return epoch_loss / batch_count if batch_count > 0 else 0
+
     def evaluate(self, data_loader):
         """Evaluate the model on a validation/test data loader.

torch_rechub-0.0.6.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torch-rechub
-Version: 0.0.5
+Version: 0.0.6
 Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
 Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
 Project-URL: Documentation, https://www.torch-rechub.com
@@ -28,19 +28,26 @@ Requires-Dist: scikit-learn>=0.24.0
 Requires-Dist: torch>=1.10.0
 Requires-Dist: tqdm>=4.60.0
 Requires-Dist: transformers>=4.46.3
+Provides-Extra: bigdata
+Requires-Dist: pyarrow~=21.0; extra == 'bigdata'
 Provides-Extra: dev
 Requires-Dist: bandit>=1.7.0; extra == 'dev'
 Requires-Dist: flake8>=3.8.0; extra == 'dev'
 Requires-Dist: isort==5.13.2; extra == 'dev'
 Requires-Dist: mypy>=0.800; extra == 'dev'
 Requires-Dist: pre-commit>=2.20.0; extra == 'dev'
+Requires-Dist: pyarrow-stubs>=20.0; extra == 'dev'
 Requires-Dist: pytest-cov>=2.0; extra == 'dev'
 Requires-Dist: pytest>=6.0; extra == 'dev'
 Requires-Dist: toml>=0.10.2; extra == 'dev'
 Requires-Dist: yapf==0.43.0; extra == 'dev'
 Provides-Extra: onnx
-Requires-Dist: onnx>=1.
-Requires-Dist: onnxruntime>=1.
+Requires-Dist: onnx>=1.14.0; extra == 'onnx'
+Requires-Dist: onnxruntime>=1.14.0; extra == 'onnx'
+Provides-Extra: tracking
+Requires-Dist: swanlab>=0.1.0; extra == 'tracking'
+Requires-Dist: tensorboardx>=2.5; extra == 'tracking'
+Requires-Dist: wandb>=0.13.0; extra == 'tracking'
 Provides-Extra: visualization
 Requires-Dist: graphviz>=0.20; extra == 'visualization'
 Requires-Dist: torchview>=0.2.6; extra == 'visualization'
@@ -89,7 +96,8 @@ Description-Content-Type: text/markdown
 * **易于配置:** 通过配置文件或命令行参数轻松调整实验设置。
 * **可复现性:** 旨在确保实验结果的可复现性。
 * **ONNX 导出:** 支持将训练好的模型导出为 ONNX 格式,便于部署到生产环境。
-*
+* **跨引擎数据处理:** 现已支持基于 PySpark 的数据处理与转换,方便在大数据管道中落地。
+* **实验可视化与跟踪:** 内置 WandB、SwanLab、TensorBoardX 三种可视化/追踪工具的统一集成。

 ## 📖 目录

@@ -399,4 +407,4 @@ ctr_trainer.visualization(save_path="model.pdf", dpi=300) # 保存为高清 PDF

 ---

-*最后更新: [2025-12-
+*最后更新: [2025-12-11]*
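The two new extras group the optional dependencies for the features added in this release: installing with pip's standard extras syntax, e.g. pip install "torch-rechub[tracking]", pulls in wandb, swanlab, and tensorboardX for the logger adapters, while "torch-rechub[bigdata]" pulls in pyarrow for the Parquet streaming dataset.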
torch_rechub-0.0.6.dist-info/RECORD
@@ -8,6 +8,10 @@ torch_rechub/basic/layers.py,sha256=URWk78dlffMOAhDVDhOhugcr4nmwEa192AI1diktC-4,
 torch_rechub/basic/loss_func.py,sha256=6bjljqpiuUP6O8-wUbGd8FSvflY5Dp_DV_57OuQVMz4,7969
 torch_rechub/basic/metaoptimizer.py,sha256=y-oT4MV3vXnSQ5Zd_ZEHP1KClITEi3kbZa6RKjlkYw8,3093
 torch_rechub/basic/metric.py,sha256=9JsaJJGvT6VRvsLoM2Y171CZxESsjYTofD3qnMI-bPM,8443
+torch_rechub/basic/tracking.py,sha256=7-aoyKJxyqb8GobpjRjFsgPYWsBDOV44BYOC_vMoCto,6608
+torch_rechub/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+torch_rechub/data/convert.py,sha256=clGFEbDSDpdZBvscWatfjtuXMZUzgy1kiEAg4w_q7VM,2241
+torch_rechub/data/dataset.py,sha256=fDDQ5N3x99KPfy0Ux4LRQbFlWbLg_dvKTO1WUEbEN04,4111
 torch_rechub/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 torch_rechub/models/generative/__init__.py,sha256=TsCdVIhOcalQwqKZKjEuNbHKyIjyclapKGNwYfFR7TM,135
 torch_rechub/models/generative/hllm.py,sha256=6Vrp5Bh0fTFHCn7C-3EqzOyc7UunOyEY9TzAKGHrW-8,9669
@@ -45,11 +49,11 @@ torch_rechub/models/ranking/edcn.py,sha256=6f_S8I6Ir16kCIU54R4EfumWfUFOND5KDKUPH
 torch_rechub/models/ranking/fibinet.py,sha256=fmEJ9WkO8Mn0RtK_8aRHlnQFh_jMBPO0zODoHZPWmDA,2234
 torch_rechub/models/ranking/widedeep.py,sha256=eciRvWRBHLlctabLLS5NB7k3MnqrWXCBdpflOU6jMB0,1636
 torch_rechub/trainers/__init__.py,sha256=NSa2DqgfE1HGDyj40YgrbtUrfBHBxNBpw57XtaAB_jE,148
-torch_rechub/trainers/ctr_trainer.py,sha256=
-torch_rechub/trainers/match_trainer.py,sha256=
+torch_rechub/trainers/ctr_trainer.py,sha256=e0xS-W48BOixN0ogksWOcVJNKFiO3g2oNA_hlHytRqk,14138
+torch_rechub/trainers/match_trainer.py,sha256=atkO-gfDuTk6lh-WvaJOh5kgn6HPzbQQN42Rvz8kyXY,16327
 torch_rechub/trainers/matching.md,sha256=vIBQ3UMmVpUpyk38rrkelFwm_wXVXqMOuqzYZ4M8bzw,30
-torch_rechub/trainers/mtl_trainer.py,sha256=
-torch_rechub/trainers/seq_trainer.py,sha256=
+torch_rechub/trainers/mtl_trainer.py,sha256=n3T-ctWACSyl0awBQixOlZUQ8I5cfGyZzgKV09EF8hw,18293
+torch_rechub/trainers/seq_trainer.py,sha256=pyY70kAjTWdKrnAYZynql1PPNtveYDLMB_1hbpCHa48,19217
 torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 torch_rechub/utils/data.py,sha256=vzLAAVt6dujg_vbGhQewiJc0l6JzwzdcM_9EjoOz898,19882
 torch_rechub/utils/hstu_utils.py,sha256=qLON_pJDC-kDyQn1PoN_HaHi5xTNCwZPgJeV51Z61Lc,6207
@@ -58,7 +62,7 @@ torch_rechub/utils/model_utils.py,sha256=VLhSbTpupxrFyyY3NzMQ32PPmo5YHm1T96u9KDl
 torch_rechub/utils/mtl.py,sha256=AxU05ezizCuLdbPuCg1ZXE0WAStzuxaS5Sc3nwMCBpI,5737
 torch_rechub/utils/onnx_export.py,sha256=LRHyZaR9zZJyg6xtuqQHWmusWq-yEvw9EhlmoEwcqsg,8364
 torch_rechub/utils/visualization.py,sha256=Djv8W5SkCk3P2dol5VXf0_eanIhxDwRd7fzNOQY4uiU,9506
-torch_rechub-0.0.
-torch_rechub-0.0.
-torch_rechub-0.0.
-torch_rechub-0.0.
+torch_rechub-0.0.6.dist-info/METADATA,sha256=OihjWb0yCI1bmTEoCYAC6pI6cCgl5KS5uSrAGZwv7yY,18470
+torch_rechub-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+torch_rechub-0.0.6.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
+torch_rechub-0.0.6.dist-info/RECORD,,
{torch_rechub-0.0.5.dist-info → torch_rechub-0.0.6.dist-info}/WHEEL: file without changes
{torch_rechub-0.0.5.dist-info → torch_rechub-0.0.6.dist-info}/licenses/LICENSE: file without changes