nextrec 0.4.33__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -18
- nextrec/basic/asserts.py +1 -22
- nextrec/basic/callback.py +2 -2
- nextrec/basic/features.py +6 -37
- nextrec/basic/heads.py +13 -1
- nextrec/basic/layers.py +33 -123
- nextrec/basic/loggers.py +3 -2
- nextrec/basic/metrics.py +85 -4
- nextrec/basic/model.py +518 -7
- nextrec/basic/summary.py +88 -42
- nextrec/cli.py +117 -30
- nextrec/data/data_processing.py +8 -13
- nextrec/data/preprocessor.py +449 -844
- nextrec/loss/grad_norm.py +78 -76
- nextrec/models/multi_task/ple.py +1 -0
- nextrec/models/multi_task/share_bottom.py +1 -0
- nextrec/models/ranking/afm.py +4 -9
- nextrec/models/ranking/dien.py +7 -8
- nextrec/models/ranking/ffm.py +2 -2
- nextrec/models/retrieval/sdm.py +1 -2
- nextrec/models/sequential/hstu.py +0 -2
- nextrec/models/tree_base/base.py +1 -1
- nextrec/utils/__init__.py +2 -1
- nextrec/utils/config.py +1 -1
- nextrec/utils/console.py +1 -1
- nextrec/utils/onnx_utils.py +252 -0
- nextrec/utils/torch_utils.py +63 -56
- nextrec/utils/types.py +43 -0
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/METADATA +10 -4
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/RECORD +34 -42
- nextrec/models/multi_task/[pre]star.py +0 -192
- nextrec/models/representation/autorec.py +0 -0
- nextrec/models/representation/bpr.py +0 -0
- nextrec/models/representation/cl4srec.py +0 -0
- nextrec/models/representation/lightgcn.py +0 -0
- nextrec/models/representation/mf.py +0 -0
- nextrec/models/representation/s3rec.py +0 -0
- nextrec/models/sequential/sasrec.py +0 -0
- nextrec/utils/feature.py +0 -29
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/WHEEL +0 -0
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/grad_norm.py
CHANGED
@@ -2,12 +2,40 @@
 GradNorm loss weighting for multi-task learning.
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 22/01/2026
 Author: Yang Zhou,zyaztec@gmail.com
 
 Reference:
     Chen, Zhao, et al. "GradNorm: Gradient Normalization for Adaptive Loss Balancing
     in Deep Multitask Networks." ICML 2018.
+
+pseudocode:
+---
+Initialize w_i = 1
+Record L_i(0)
+
+for each step:
+    1. Forward: compute each task loss L_i
+    2. Compute G_i = ||∇_W (w_i * L_i)||
+    3. Compute r_i = (L_i / L_i(0)) / mean(...)
+    4. Compute target: Ĝ_i = mean(G) * r_i^α
+    5. L_grad = sum |G_i - Ĝ_i|
+    6. Update w_i using ∇ L_grad
+    7. Backprop with sum_i (w_i * L_i) to update model
+
+伪代码:
+---
+初始化 w_i = 1
+记录 L_i(0)
+
+for each step:
+    1. 前向算各 task loss: L_i
+    2. 计算 G_i = ||∇_W (w_i * L_i)||
+    3. 计算 r_i = (L_i / L_i(0)) / mean(...)
+    4. 计算 target: Ĝ_i = mean(G) * r_i^α
+    5. L_grad = sum |G_i - Ĝ_i|
+    6. 对 w_i 用 ∇L_grad 更新
+    7. 用 ∑ w_i * L_i 反传更新模型
 """
 
 from __future__ import annotations
@@ -15,6 +43,7 @@ from __future__ import annotations
 from typing import Iterable
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -23,7 +52,15 @@ def get_grad_norm_shared_params(
     model,
     shared_modules=None,
 ):
+    """
+    Get shared parameters for GradNorm.
+
+    Args:
+        model: A pytorch model instance containing grad_norm_shared_modules attribute.
+        shared_modules: Optional list of module names to consider as shared.
+    """
     if not shared_modules:
+        # If no specific shared modules are provided, consider all parameters as shared
        return [p for p in model.parameters() if p.requires_grad]
     shared_params = []
     seen = set()
@@ -35,26 +72,10 @@ get_grad_norm_shared_params(
             if param.requires_grad and id(param) not in seen:
                 shared_params.append(param)
                 seen.add(id(param))
-    if not shared_params:
-        return [p for p in model.parameters() if p.requires_grad]
     return shared_params
 
 
 class GradNormLossWeighting:
-    """
-    Adaptive multi-task loss weighting with GradNorm.
-
-    Args:
-        nums_task: Number of tasks.
-        alpha: GradNorm balancing strength.
-        lr: Learning rate for the weight optimizer.
-        init_weights: Optional initial weights per task.
-        device: Torch device for weights.
-        ema_decay: Optional EMA decay for smoothing loss ratios.
-        init_ema_steps: Number of steps to build EMA for initial losses.
-        init_ema_decay: EMA decay for initial losses when init_ema_steps > 0.
-        eps: Small value for numerical stability.
-    """
 
     def __init__(
         self,
@@ -63,58 +84,43 @@ class GradNormLossWeighting:
         lr: float = 0.025,
         init_weights: Iterable[float] | None = None,
         device: torch.device | str | None = None,
-        ema_decay: float | None = None,
-        init_ema_steps: int = 0,
-        init_ema_decay: float = 0.9,
         eps: float = 1e-8,
     ) -> None:
+        """
+        Adaptive multi-task loss weighting with GradNorm.
+
+        Args:
+            nums_task: Number of tasks.
+            alpha: GradNorm balancing strength.
+            lr: Learning rate for the weight optimizer.
+            init_weights: Optional initial weights per task.
+            device: Torch device for weights.
+            eps: Small value for numerical stability.
+
+
+
+        """
+
         if nums_task <= 1:
             raise ValueError("GradNorm requires nums_task > 1.")
+
         self.nums_task = nums_task
         self.alpha = alpha
         self.eps = eps
-        if ema_decay is not None:
-            ema_decay = ema_decay
-            if ema_decay < 0.0 or ema_decay >= 1.0:
-                raise ValueError("ema_decay must be in [0.0, 1.0).")
-        self.ema_decay = ema_decay
-        self.init_ema_steps = init_ema_steps
-        if self.init_ema_steps < 0:
-            raise ValueError("init_ema_steps must be >= 0.")
-        self.init_ema_decay = init_ema_decay
-        if self.init_ema_decay < 0.0 or self.init_ema_decay >= 1.0:
-            raise ValueError("init_ema_decay must be in [0.0, 1.0).")
-        self.init_ema_count = 0
 
         if init_weights is None:
             weights = torch.ones(self.nums_task, dtype=torch.float32)
         else:
             weights = torch.tensor(list(init_weights), dtype=torch.float32)
-
-            raise ValueError(
-                "init_weights length must match nums_task for GradNorm."
-            )
+
         if device is not None:
             weights = weights.to(device)
         self.weights = nn.Parameter(weights)
         self.optimizer = torch.optim.Adam([self.weights], lr=float(lr))
 
         self.initial_losses = None
-        self.initial_losses_ema = None
-        self.loss_ema = None
         self.pending_grad = None
 
-    def to(self, device):
-        device = torch.device(device)
-        self.weights.data = self.weights.data.to(device)
-        if self.initial_losses is not None:
-            self.initial_losses = self.initial_losses.to(device)
-        if self.initial_losses_ema is not None:
-            self.initial_losses_ema = self.initial_losses_ema.to(device)
-        if self.loss_ema is not None:
-            self.loss_ema = self.loss_ema.to(device)
-        return self
-
     def compute_weighted_loss(
         self,
         task_losses: list[torch.Tensor],
@@ -122,6 +128,8 @@ class GradNormLossWeighting:
     ) -> torch.Tensor:
         """
         Return weighted total loss and update task weights with GradNorm.
+
+        BaseModel will use this method to compute the weighted loss when self.grad_norm is enabled.
         """
         if len(task_losses) != self.nums_task:
             raise ValueError(
@@ -136,19 +144,7 @@ class GradNormLossWeighting:
             [loss.item() for loss in task_losses], device=self.weights.device
         )
         if self.initial_losses is None:
-
-            if self.initial_losses_ema is None:
-                self.initial_losses_ema = loss_values
-            else:
-                self.initial_losses_ema = (
-                    self.init_ema_decay * self.initial_losses_ema
-                    + (1.0 - self.init_ema_decay) * loss_values
-                )
-            self.init_ema_count += 1
-            if self.init_ema_count >= self.init_ema_steps:
-                self.initial_losses = self.initial_losses_ema.clone()
-            else:
-                self.initial_losses = loss_values
+            self.initial_losses = loss_values.clone()
 
         weights_detached = self.weights.detach()
         weighted_losses = [
@@ -157,25 +153,14 @@ class GradNormLossWeighting:
         total_loss = torch.stack(weighted_losses).sum()
 
         grad_norms = self.compute_grad_norms(task_losses, shared_params)
+
+        # compute inverse training rate, inv rate = loss_ratio / mean(loss_ratio)
         with torch.no_grad():
-            if self.ema_decay is not None:
-                if self.loss_ema is None:
-                    self.loss_ema = loss_values
-                else:
-                    self.loss_ema = (
-                        self.ema_decay * self.loss_ema
-                        + (1.0 - self.ema_decay) * loss_values
-                    )
-                ratio_source = self.loss_ema
-            else:
-                ratio_source = loss_values
             if self.initial_losses is not None:
                 base_initial = self.initial_losses
-            elif self.initial_losses_ema is not None:
-                base_initial = self.initial_losses_ema
             else:
                 base_initial = loss_values
-            loss_ratios =
+            loss_ratios = loss_values / (base_initial + self.eps)
             inv_rate = loss_ratios / (loss_ratios.mean() + self.eps)
             target = grad_norms.mean() * (inv_rate**self.alpha)
 
@@ -187,6 +172,7 @@ class GradNormLossWeighting:
 
     def compute_grad_norms(self, task_losses, shared_params):
         grad_norms = []
+        # compute gradient norms for each task, gradient norms = sqrt(sum(grad^2))
         for i, task_loss in enumerate(task_losses):
             grads = torch.autograd.grad(
                 self.weights[i] * task_loss,
@@ -230,3 +216,19 @@ class GradNormLossWeighting:
             self.weights.copy_(w)
 
         self.pending_grad = None
+
+    def sync(self) -> None:
+        """
+        Synchronize GradNorm buffers across DDP ranks.
+
+        - pending_grad: averaged so all ranks update weights consistently
+        - initial_losses: averaged so the baseline loss is consistent
+        """
+
+        world_size = dist.get_world_size()
+        if self.pending_grad is not None:
+            dist.all_reduce(self.pending_grad, op=dist.ReduceOp.SUM)
+            self.pending_grad /= world_size
+        if self.initial_losses is not None:
+            dist.all_reduce(self.initial_losses, op=dist.ReduceOp.SUM)
+            self.initial_losses /= world_size
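The pseudocode added to the module docstring corresponds to a short, self-contained training step. The sketch below is only an illustration of that recipe, not the package's GradNormLossWeighting class: the toy two-task model, the MSE losses, and the alpha/lr values are assumptions made for the example. In a DDP setting the new sync() method complements this by all-reducing pending_grad and initial_losses so every rank applies the same weight update.

```python
# Minimal sketch of the GradNorm step described in the docstring pseudocode.
# Illustration only: the toy model, losses, and hyperparameters are assumed.
import torch
import torch.nn as nn

torch.manual_seed(0)

shared = nn.Linear(8, 16)                       # shared trunk W
heads = nn.ModuleList([nn.Linear(16, 1), nn.Linear(16, 1)])
w = nn.Parameter(torch.ones(2))                 # task weights w_i, initialized to 1
w_opt = torch.optim.Adam([w], lr=0.025)         # optimizer for the weights only
model_opt = torch.optim.Adam(list(shared.parameters()) + list(heads.parameters()), lr=1e-3)
alpha, eps = 1.5, 1e-8
initial_losses = None                           # L_i(0), recorded on the first step

for step in range(3):
    x = torch.randn(32, 8)
    y = [torch.randn(32, 1), torch.randn(32, 1)]
    h = shared(x)
    task_losses = [nn.functional.mse_loss(heads[i](h), y[i]) for i in range(2)]

    loss_values = torch.tensor([l.item() for l in task_losses])
    w_detached = w.detach().clone()
    if initial_losses is None:
        initial_losses = loss_values.clone()    # record L_i(0)

    # G_i = ||grad of (w_i * L_i) w.r.t. the shared parameters||
    shared_params = [p for p in shared.parameters() if p.requires_grad]
    grad_norms = []
    for i, loss in enumerate(task_losses):
        grads = torch.autograd.grad(w[i] * loss, shared_params,
                                    retain_graph=True, create_graph=True)
        grad_norms.append(torch.norm(torch.cat([g.reshape(-1) for g in grads])))
    grad_norms = torch.stack(grad_norms)

    # r_i = (L_i / L_i(0)) / mean(...), target Ĝ_i = mean(G) * r_i^alpha
    loss_ratios = loss_values / (initial_losses + eps)
    inv_rate = loss_ratios / (loss_ratios.mean() + eps)
    target = (grad_norms.mean() * inv_rate**alpha).detach()

    # L_grad = sum |G_i - Ĝ_i| updates only the task weights
    grad_loss = (grad_norms - target).abs().sum()
    w_opt.zero_grad()
    grad_loss.backward(retain_graph=True)
    w_opt.step()

    # The model itself is trained with sum_i w_i * L_i (weights detached)
    total = (w_detached * torch.stack(task_losses)).sum()
    model_opt.zero_grad()
    total.backward()
    model_opt.step()

    # Renormalize so the weights keep summing to the number of tasks
    with torch.no_grad():
        w.copy_(2.0 * w.clamp_min(eps) / w.clamp_min(eps).sum())
```

Keeping two optimizers apart, one for the task weights and one for the model, mirrors the class above, which builds torch.optim.Adam([self.weights], lr=lr) separately from the model's own optimizer.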
nextrec/models/multi_task/ple.py
CHANGED
nextrec/models/ranking/afm.py
CHANGED
@@ -156,7 +156,7 @@ class AFM(BaseModel):
         # First-order dense part
         if self.linear_dense is not None:
             dense_inputs = [
-                x[f.name].float().
+                x[f.name].float().reshape(batch_size, -1) for f in self.dense_features
             ]
             dense_stack = torch.cat(dense_inputs, dim=1) if dense_inputs else None
             if dense_stack is not None:
@@ -170,7 +170,7 @@ class AFM(BaseModel):
                 term = emb(x[feature.name].long())  # [B, 1]
             else:  # SequenceFeature
                 seq_input = x[feature.name].long()  # [B, 1]
-                if feature.max_len is not None
+                if feature.max_len is not None:
                     seq_input = seq_input[:, -feature.max_len :]
                 mask = self.input_mask(x, feature, seq_input).squeeze(1)  # [B, 1]
                 seq_weight = emb(seq_input).squeeze(-1)  # [B, L]
@@ -186,16 +186,11 @@ class AFM(BaseModel):
         for feature in self.fm_features:
             value = x.get(f"{feature.name}_value")
             if value is not None:
-                value = value.float()
-                if value.dim() == 1:
-                    value = value.unsqueeze(-1)
+                value = value.float().reshape(batch_size, -1)
             else:
                 if isinstance(feature, SequenceFeature):
                     seq_input = x[feature.name].long()
-                    if (
-                        feature.max_len is not None
-                        and seq_input.size(1) > feature.max_len
-                    ):
+                    if feature.max_len is not None:
                         seq_input = seq_input[:, -feature.max_len :]
                     value = self.input_mask(x, feature, seq_input).sum(dim=2)  # [B, 1]
                 else:
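The AFM hunks above, like the DIEN and SDM hunks below, replace per-case unsqueeze/dim() handling of dense and value tensors with a single reshape to [batch_size, -1]. A minimal sketch of why that normalization helps, using toy tensors rather than nextrec feature objects (the feature names in the comments are made up):

```python
# reshape(batch_size, -1) maps both [B] and [B, d] dense inputs to a 2-D
# layout so they concatenate cleanly on dim=1. Toy tensors only.
import torch

batch_size = 4
scalar_feat = torch.randn(batch_size)        # [B]     e.g. a single "age" value
vector_feat = torch.randn(batch_size, 3)     # [B, 3]  e.g. a small dense vector

dense_inputs = [
    scalar_feat.float().reshape(batch_size, -1),   # [B, 1]
    vector_feat.float().reshape(batch_size, -1),   # [B, 3]
]
dense_stack = torch.cat(dense_inputs, dim=1)        # [B, 4]
print(dense_stack.shape)                            # torch.Size([4, 4])

# Reshaping covers both layouts in one call, which is what the removed
# unsqueeze/dim() checks were doing case by case.
```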
nextrec/models/ranking/dien.py
CHANGED
@@ -390,13 +390,13 @@ class DIEN(BaseModel):
                 dim=-1,
             )
             score_t = self.attention_layer.attention_net(concat_feat)  # [B, 1]
-            att_scores_list.append(score_t)
+            att_scores_list.append(score_t.unsqueeze(1))
 
         # [B, seq_len, 1]
         att_scores = torch.cat(att_scores_list, dim=1)
 
-        scores_flat = att_scores
-        mask_flat = mask
+        scores_flat = att_scores[..., 0]  # [B, seq_len]
+        mask_flat = mask[..., 0]  # [B, seq_len]
 
         scores_flat = scores_flat.masked_fill(mask_flat == 0, -1e9)
         att_weights = torch.softmax(scores_flat, dim=1)  # [B, seq_len]
@@ -437,8 +437,7 @@ class DIEN(BaseModel):
 
         for feat in self.dense_features:
             val = x[feat.name].float()
-
-            val = val.unsqueeze(1)
+            val = val.view(val.size(0), -1)
             other_embeddings.append(val)
 
         concat_input = torch.cat(other_embeddings, dim=-1)  # [B, total_dim]
@@ -460,15 +459,15 @@ class DIEN(BaseModel):
         interest_states = interest_states[:, :-1, :]
         pos_seq = behavior_emb[:, 1:, :]
         neg_seq = neg_behavior_emb[:, 1:, :]
-        aux_mask = mask[:, 1:,
+        aux_mask = mask[:, 1:, 0]
 
         if aux_mask.sum() == 0:
             return torch.tensor(0.0, device=self.device)
 
         pos_input = torch.cat([interest_states, pos_seq], dim=-1)
         neg_input = torch.cat([interest_states, neg_seq], dim=-1)
-        pos_logits = self.auxiliary_net(pos_input)
-        neg_logits = self.auxiliary_net(neg_input)
+        pos_logits = self.auxiliary_net(pos_input)[..., 0]
+        neg_logits = self.auxiliary_net(neg_input)[..., 0]
 
         pos_loss = F.binary_cross_entropy_with_logits(
             pos_logits, torch.ones_like(pos_logits), reduction="none"
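The first DIEN hunk keeps the per-step scores as [B, seq_len, 1] and drops the trailing axis before masking and softmax. The sketch below reproduces that masked attention pooling pattern with toy tensors; the scores are random stand-ins for attention_net outputs, so only the shapes match the hunk.

```python
# Masked attention pooling over a padded behavior sequence: stack per-step
# [B, 1] scores into [B, seq_len, 1], squeeze the last axis, mask padding
# with -1e9, softmax over time, then pool. Toy tensors only.
import torch

B, seq_len, D = 2, 5, 8
behavior_emb = torch.randn(B, seq_len, D)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]]).unsqueeze(-1)               # [B, seq_len, 1]

score_list = [torch.randn(B, 1) for _ in range(seq_len)]            # one score per step
att_scores = torch.cat([s.unsqueeze(1) for s in score_list], dim=1)  # [B, seq_len, 1]

scores_flat = att_scores[..., 0]                                     # [B, seq_len]
mask_flat = mask[..., 0]                                             # [B, seq_len]
scores_flat = scores_flat.masked_fill(mask_flat == 0, -1e9)          # ignore padding
att_weights = torch.softmax(scores_flat, dim=1)                      # [B, seq_len]

pooled = (att_weights.unsqueeze(-1) * behavior_emb).sum(dim=1)       # [B, D]
print(att_weights.shape, pooled.shape)
```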
nextrec/models/ranking/ffm.py
CHANGED
@@ -190,7 +190,7 @@ class FFM(BaseModel):
             return emb(x[feature.name].long())
 
         seq_input = x[feature.name].long()
-        if feature.max_len is not None
+        if feature.max_len is not None:
             seq_input = seq_input[:, -feature.max_len :]
         seq_emb = emb(seq_input)  # [B, L, D]
         mask = self.input_mask(x, feature, seq_input)
@@ -224,7 +224,7 @@ class FFM(BaseModel):
             term = emb(x[feature.name].long())  # [B, 1]
         else:
             seq_input = x[feature.name].long()
-            if feature.max_len is not None
+            if feature.max_len is not None:
                 seq_input = seq_input[:, -feature.max_len :]
             mask = self.input_mask(x, feature, seq_input).squeeze(1)  # [B, L]
             seq_weight = emb(seq_input).squeeze(-1)  # [B, L]
nextrec/models/retrieval/sdm.py
CHANGED
@@ -223,8 +223,7 @@ class SDM(BaseMatchModel):
         for feat in self.user_dense_features:
             if feat.name in user_input:
                 val = user_input[feat.name].float()
-
-                val = val.unsqueeze(1)
+                val = val.reshape(val.size(0), -1)
                 dense_features.append(val)
         if dense_features:
             features_list.append(torch.cat(dense_features, dim=1))
nextrec/models/sequential/hstu.py
CHANGED
@@ -438,8 +438,6 @@
         return self.causal_mask[:seq_len, :seq_len]
 
     def trim_sequence(self, seq: torch.Tensor) -> torch.Tensor:
-        if seq.size(1) <= self.max_seq_len:
-            return seq
         return seq[:, -self.max_seq_len :]
 
     def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
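The length guard removed from trim_sequence is redundant because slicing with a negative bound already returns the full sequence when it is shorter than max_seq_len. A quick check with toy tensors (the max_seq_len value is assumed for the demo):

```python
# Negative-bound slicing clamps to the tensor's actual length, so no
# explicit "shorter than max_seq_len" branch is needed.
import torch

max_seq_len = 4
short = torch.arange(6).reshape(2, 3)      # seq_len=3 < max_seq_len
long = torch.arange(12).reshape(2, 6)      # seq_len=6 > max_seq_len

print(short[:, -max_seq_len:].shape)       # torch.Size([2, 3]) -- unchanged
print(long[:, -max_seq_len:].shape)        # torch.Size([2, 4]) -- trimmed
```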
nextrec/models/tree_base/base.py
CHANGED
@@ -29,7 +29,7 @@ from nextrec.data.dataloader import RecDataLoader
 from nextrec.data.data_processing import get_column_data
 from nextrec.utils.console import display_metrics_table
 from nextrec.utils.data import FILE_FORMAT_CONFIG, check_streaming_support
-from nextrec.utils.
+from nextrec.utils.torch_utils import to_list
 from nextrec.utils.torch_utils import to_numpy
 
 
nextrec/utils/__init__.py
CHANGED
@@ -36,7 +36,7 @@ from .data import (
     resolve_file_paths,
 )
 from .embedding import get_auto_embedding_dim
-from .
+from .torch_utils import as_float, to_list
 from .model import (
     compute_pair_scores,
     get_mlp_output_dim,
@@ -90,6 +90,7 @@ __all__ = [
     "normalize_task_loss",
     # Feature utilities
     "to_list",
+    "as_float",
     # Config utilities
     "resolve_path",
     "safe_value",
nextrec/utils/config.py
CHANGED
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
 import pandas as pd
 import torch
 
-from nextrec.utils.
+from nextrec.utils.torch_utils import to_list
 
 if TYPE_CHECKING:
     from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
nextrec/utils/console.py
CHANGED