nextrec 0.1.10__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +1 -2
- nextrec/basic/callback.py +1 -2
- nextrec/basic/features.py +39 -8
- nextrec/basic/layers.py +1 -2
- nextrec/basic/loggers.py +15 -10
- nextrec/basic/metrics.py +1 -2
- nextrec/basic/model.py +87 -84
- nextrec/basic/session.py +150 -0
- nextrec/data/__init__.py +13 -2
- nextrec/data/data_utils.py +74 -22
- nextrec/data/dataloader.py +513 -0
- nextrec/data/preprocessor.py +494 -134
- nextrec/loss/listwise.py +6 -0
- nextrec/loss/loss_utils.py +1 -2
- nextrec/loss/match_losses.py +4 -5
- nextrec/loss/pairwise.py +6 -0
- nextrec/loss/pointwise.py +6 -0
- nextrec/models/match/dssm.py +2 -2
- nextrec/models/match/dssm_v2.py +2 -2
- nextrec/models/match/mind.py +2 -2
- nextrec/models/match/sdm.py +2 -2
- nextrec/models/match/youtube_dnn.py +2 -2
- nextrec/models/multi_task/esmm.py +3 -3
- nextrec/models/multi_task/mmoe.py +3 -3
- nextrec/models/multi_task/ple.py +3 -3
- nextrec/models/multi_task/share_bottom.py +3 -3
- nextrec/models/ranking/afm.py +2 -3
- nextrec/models/ranking/autoint.py +3 -3
- nextrec/models/ranking/dcn.py +3 -3
- nextrec/models/ranking/deepfm.py +2 -3
- nextrec/models/ranking/dien.py +3 -3
- nextrec/models/ranking/din.py +3 -3
- nextrec/models/ranking/fibinet.py +3 -3
- nextrec/models/ranking/fm.py +3 -3
- nextrec/models/ranking/masknet.py +3 -3
- nextrec/models/ranking/pnn.py +3 -3
- nextrec/models/ranking/widedeep.py +3 -3
- nextrec/models/ranking/xdeepfm.py +3 -3
- nextrec/utils/__init__.py +4 -8
- nextrec/utils/embedding.py +2 -4
- nextrec/utils/initializer.py +1 -2
- nextrec/utils/optimizer.py +1 -2
- {nextrec-0.1.10.dist-info → nextrec-0.2.1.dist-info}/METADATA +4 -5
- nextrec-0.2.1.dist-info/RECORD +54 -0
- nextrec/basic/dataloader.py +0 -447
- nextrec/utils/common.py +0 -14
- nextrec-0.1.10.dist-info/RECORD +0 -51
- {nextrec-0.1.10.dist-info → nextrec-0.2.1.dist-info}/WHEEL +0 -0
- {nextrec-0.1.10.dist-info → nextrec-0.2.1.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.1.10"
|
|
1
|
+
__version__ = "0.2.1"
|
nextrec/basic/activation.py
CHANGED
nextrec/basic/callback.py
CHANGED
nextrec/basic/features.py
CHANGED
|
@@ -2,12 +2,11 @@
|
|
|
2
2
|
Feature definitions
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Author:
|
|
6
|
-
Yang Zhou,zyaztec@gmail.com
|
|
5
|
+
Author: Yang Zhou,zyaztec@gmail.com
|
|
7
6
|
"""
|
|
8
|
-
|
|
9
|
-
from typing import Optional
|
|
10
|
-
from nextrec.utils import get_auto_embedding_dim
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
from typing import List, Sequence, Optional
|
|
9
|
+
from nextrec.utils.embedding import get_auto_embedding_dim
|
|
11
10
|
|
|
12
11
|
class BaseFeature(object):
|
|
13
12
|
def __repr__(self):
|
|
@@ -26,9 +25,9 @@ class SequenceFeature(BaseFeature):
|
|
|
26
25
|
vocab_size: int,
|
|
27
26
|
max_len: int = 20,
|
|
28
27
|
embedding_name: str = '',
|
|
29
|
-
embedding_dim:
|
|
28
|
+
embedding_dim: int | None = 4,
|
|
30
29
|
combiner: str = "mean",
|
|
31
|
-
padding_idx:
|
|
30
|
+
padding_idx: int | None = None,
|
|
32
31
|
init_type: str='normal',
|
|
33
32
|
init_params: dict|None = None,
|
|
34
33
|
l1_reg: float = 0.0,
|
|
@@ -55,7 +54,7 @@ class SparseFeature(BaseFeature):
|
|
|
55
54
|
name: str,
|
|
56
55
|
vocab_size: int,
|
|
57
56
|
embedding_name: str = '',
|
|
58
|
-
embedding_dim: int = 4,
|
|
57
|
+
embedding_dim: int | None = 4,
|
|
59
58
|
padding_idx: int | None = None,
|
|
60
59
|
init_type: str='normal',
|
|
61
60
|
init_params: dict|None = None,
|
|
@@ -84,4 +83,36 @@ class DenseFeature(BaseFeature):
|
|
|
84
83
|
self.embedding_dim = embedding_dim
|
|
85
84
|
|
|
86
85
|
|
|
86
|
+
class FeatureConfig:
|
|
87
|
+
"""
|
|
88
|
+
Mixin that normalizes dense/sparse/sequence feature lists and target/id columns.
|
|
89
|
+
"""
|
|
90
|
+
|
|
91
|
+
def _set_feature_config(
|
|
92
|
+
self,
|
|
93
|
+
dense_features: Sequence[DenseFeature] | None = None,
|
|
94
|
+
sparse_features: Sequence[SparseFeature] | None = None,
|
|
95
|
+
sequence_features: Sequence[SequenceFeature] | None = None,
|
|
96
|
+
) -> None:
|
|
97
|
+
self.dense_features: List[DenseFeature] = list(dense_features) if dense_features else []
|
|
98
|
+
self.sparse_features: List[SparseFeature] = list(sparse_features) if sparse_features else []
|
|
99
|
+
self.sequence_features: List[SequenceFeature] = list(sequence_features) if sequence_features else []
|
|
100
|
+
|
|
101
|
+
self.all_features = self.dense_features + self.sparse_features + self.sequence_features
|
|
102
|
+
self.feature_names = [feat.name for feat in self.all_features]
|
|
103
|
+
|
|
104
|
+
def _set_target_config(
|
|
105
|
+
self,
|
|
106
|
+
target: str | Sequence[str] | None = None,
|
|
107
|
+
id_columns: str | Sequence[str] | None = None,
|
|
108
|
+
) -> None:
|
|
109
|
+
self.target_columns = self._normalize_to_list(target)
|
|
110
|
+
self.id_columns = self._normalize_to_list(id_columns)
|
|
87
111
|
|
|
112
|
+
@staticmethod
|
|
113
|
+
def _normalize_to_list(value: str | Sequence[str] | None) -> list[str]:
|
|
114
|
+
if value is None:
|
|
115
|
+
return []
|
|
116
|
+
if isinstance(value, str):
|
|
117
|
+
return [value]
|
|
118
|
+
return list(value)
|
nextrec/basic/layers.py
CHANGED
nextrec/basic/loggers.py
CHANGED
|
@@ -2,16 +2,18 @@
|
|
|
2
2
|
NextRec Basic Loggers
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Author:
|
|
6
|
-
Yang Zhou,zyaztec@gmail.com
|
|
5
|
+
Author: Yang Zhou,zyaztec@gmail.com
|
|
7
6
|
"""
|
|
8
7
|
|
|
8
|
+
|
|
9
9
|
import os
|
|
10
10
|
import re
|
|
11
11
|
import sys
|
|
12
12
|
import copy
|
|
13
13
|
import datetime
|
|
14
14
|
import logging
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from nextrec.basic.session import resolve_save_path, create_session
|
|
15
17
|
|
|
16
18
|
ANSI_CODES = {
|
|
17
19
|
'black': '\033[30m',
|
|
@@ -89,16 +91,19 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
|
|
|
89
91
|
|
|
90
92
|
return result
|
|
91
93
|
|
|
92
|
-
def setup_logger(
|
|
94
|
+
def setup_logger(session_id: str | os.PathLike | None = None):
|
|
93
95
|
"""Set up a logger that logs to both console and a file with ANSI formatting.
|
|
94
|
-
Only console output has colors; file output is stripped of ANSI codes.
|
|
96
|
+
Only console output has colors; file output is stripped of ANSI codes.
|
|
97
|
+
Logs are stored under ``log/<experiment_id>/logs`` by default. A stable
|
|
98
|
+
log file is used per experiment so multiple components (e.g. data
|
|
99
|
+
processor and model training) append to the same file instead of creating
|
|
100
|
+
separate timestamped files.
|
|
95
101
|
"""
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
log_file = os.path.join(log_dir, f"nextrec_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
|
|
102
|
+
|
|
103
|
+
session = create_session(str(session_id) if session_id is not None else None)
|
|
104
|
+
log_dir = session.logs_dir
|
|
105
|
+
log_dir.mkdir(parents=True, exist_ok=True)
|
|
106
|
+
log_file = log_dir / f"{session.experiment_id}.log"
|
|
102
107
|
|
|
103
108
|
console_format = '%(message)s'
|
|
104
109
|
file_format = '%(asctime)s - %(levelname)s - %(message)s'
|
nextrec/basic/metrics.py
CHANGED
nextrec/basic/model.py
CHANGED
|
@@ -2,34 +2,38 @@
|
|
|
2
2
|
Base Model & Base Match Model Class
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Author:
|
|
6
|
-
Yang Zhou,zyaztec@gmail.com
|
|
5
|
+
Author: Yang Zhou,zyaztec@gmail.com
|
|
7
6
|
"""
|
|
8
7
|
|
|
9
8
|
import os
|
|
10
|
-
import tqdm
|
|
11
|
-
import torch
|
|
12
|
-
import logging
|
|
13
9
|
import datetime
|
|
10
|
+
import logging
|
|
11
|
+
import os
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
14
|
import numpy as np
|
|
15
15
|
import pandas as pd
|
|
16
|
+
import torch
|
|
16
17
|
import torch.nn as nn
|
|
17
18
|
import torch.nn.functional as F
|
|
19
|
+
import tqdm
|
|
18
20
|
|
|
19
21
|
from typing import Union, Literal
|
|
20
22
|
from torch.utils.data import DataLoader, TensorDataset
|
|
21
23
|
|
|
22
24
|
from nextrec.basic.callback import EarlyStopper
|
|
23
|
-
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
25
|
+
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureConfig
|
|
24
26
|
from nextrec.basic.metrics import configure_metrics, evaluate_metrics
|
|
25
27
|
|
|
28
|
+
from nextrec.loss import get_loss_fn
|
|
26
29
|
from nextrec.data import get_column_data
|
|
30
|
+
from nextrec.data.dataloader import build_tensors_from_data
|
|
27
31
|
from nextrec.basic.loggers import setup_logger, colorize
|
|
28
32
|
from nextrec.utils import get_optimizer_fn, get_scheduler_fn
|
|
29
|
-
from nextrec.
|
|
33
|
+
from nextrec.basic.session import resolve_save_path, create_session
|
|
30
34
|
|
|
31
35
|
|
|
32
|
-
class BaseModel(nn.Module):
|
|
36
|
+
class BaseModel(FeatureConfig, nn.Module):
|
|
33
37
|
@property
|
|
34
38
|
def model_name(self) -> str:
|
|
35
39
|
raise NotImplementedError
|
|
@@ -43,6 +47,7 @@ class BaseModel(nn.Module):
|
|
|
43
47
|
sparse_features: list[SparseFeature] | None = None,
|
|
44
48
|
sequence_features: list[SequenceFeature] | None = None,
|
|
45
49
|
target: list[str] | str | None = None,
|
|
50
|
+
id_columns: list[str] | str | None = None,
|
|
46
51
|
task: str|list[str] = 'binary',
|
|
47
52
|
device: str = 'cpu',
|
|
48
53
|
embedding_l1_reg: float = 0.0,
|
|
@@ -50,25 +55,40 @@ class BaseModel(nn.Module):
|
|
|
50
55
|
embedding_l2_reg: float = 0.0,
|
|
51
56
|
dense_l2_reg: float = 0.0,
|
|
52
57
|
early_stop_patience: int = 20,
|
|
53
|
-
|
|
58
|
+
session_id: str | None = None,):
|
|
54
59
|
|
|
55
60
|
super(BaseModel, self).__init__()
|
|
56
61
|
|
|
57
62
|
try:
|
|
58
63
|
self.device = torch.device(device)
|
|
59
64
|
except Exception as e:
|
|
60
|
-
logging.warning(
|
|
65
|
+
logging.warning("Invalid device , defaulting to CPU.")
|
|
61
66
|
self.device = torch.device('cpu')
|
|
62
67
|
|
|
63
|
-
self.
|
|
64
|
-
self.
|
|
65
|
-
self.
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
68
|
+
self.session_id = session_id
|
|
69
|
+
self.session = create_session(session_id)
|
|
70
|
+
self.session_path = Path(self.session.logs_dir)
|
|
71
|
+
checkpoint_dir = self.session.checkpoints_dir / self.model_name
|
|
72
|
+
|
|
73
|
+
self.checkpoint = resolve_save_path(
|
|
74
|
+
path=None,
|
|
75
|
+
default_dir=checkpoint_dir,
|
|
76
|
+
default_name=self.model_name,
|
|
77
|
+
suffix=".model",
|
|
78
|
+
add_timestamp=True,
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
self.best = resolve_save_path(
|
|
82
|
+
path="best.model",
|
|
83
|
+
default_dir=checkpoint_dir,
|
|
84
|
+
default_name="best",
|
|
85
|
+
suffix=".model",
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
self._set_feature_config(dense_features, sparse_features, sequence_features)
|
|
89
|
+
self._set_target_config(target, id_columns)
|
|
90
|
+
|
|
91
|
+
self.target = self.target_columns
|
|
72
92
|
self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}
|
|
73
93
|
|
|
74
94
|
self.task = task
|
|
@@ -85,14 +105,6 @@ class BaseModel(nn.Module):
|
|
|
85
105
|
self.early_stop_patience = early_stop_patience
|
|
86
106
|
self._max_gradient_norm = 1.0 # Maximum gradient norm for gradient clipping
|
|
87
107
|
|
|
88
|
-
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
89
|
-
self.model_id = model_id
|
|
90
|
-
|
|
91
|
-
checkpoint_dir = os.path.abspath(os.path.join(project_root, "..", "checkpoints"))
|
|
92
|
-
os.makedirs(checkpoint_dir, exist_ok=True)
|
|
93
|
-
self.checkpoint = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.model")
|
|
94
|
-
self.best = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_best.model")
|
|
95
|
-
|
|
96
108
|
self._logger_initialized = False
|
|
97
109
|
self._verbose = 1
|
|
98
110
|
|
|
@@ -455,54 +467,15 @@ class BaseModel(nn.Module):
|
|
|
455
467
|
def _prepare_data_loader(self, data: dict|pd.DataFrame|DataLoader, batch_size: int = 32, shuffle: bool = True):
|
|
456
468
|
if isinstance(data, DataLoader):
|
|
457
469
|
return data
|
|
458
|
-
tensors =
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
if isinstance(column, pd.Series):
|
|
468
|
-
column = column.values
|
|
469
|
-
if isinstance(column, np.ndarray) and column.dtype == object:
|
|
470
|
-
column = np.array([np.array(seq, dtype=np.int64) if not isinstance(seq, np.ndarray) else seq for seq in column])
|
|
471
|
-
if isinstance(column, np.ndarray) and column.ndim == 1 and column.dtype == object:
|
|
472
|
-
column = np.vstack([c if isinstance(c, np.ndarray) else np.array(c) for c in column]) # type: ignore
|
|
473
|
-
tensor = torch.from_numpy(np.asarray(column, dtype=np.int64)).to('cpu')
|
|
474
|
-
else:
|
|
475
|
-
dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
|
|
476
|
-
tensor = self._to_tensor(column, dtype=dtype, device='cpu')
|
|
477
|
-
|
|
478
|
-
tensors.append(tensor)
|
|
479
|
-
|
|
480
|
-
label_tensors = []
|
|
481
|
-
for target_name in self.target:
|
|
482
|
-
column = get_column_data(data, target_name)
|
|
483
|
-
if column is None:
|
|
484
|
-
continue
|
|
485
|
-
label_tensor = self._to_tensor(column, dtype=torch.float32, device='cpu')
|
|
486
|
-
|
|
487
|
-
if label_tensor.dim() == 1:
|
|
488
|
-
# 1D tensor: (N,) -> (N, 1)
|
|
489
|
-
label_tensor = label_tensor.view(-1, 1)
|
|
490
|
-
elif label_tensor.dim() == 2:
|
|
491
|
-
if label_tensor.shape[0] == 1 and label_tensor.shape[1] > 1:
|
|
492
|
-
label_tensor = label_tensor.t()
|
|
493
|
-
|
|
494
|
-
label_tensors.append(label_tensor)
|
|
495
|
-
|
|
496
|
-
if label_tensors:
|
|
497
|
-
if len(label_tensors) == 1 and label_tensors[0].shape[1] > 1:
|
|
498
|
-
y_tensor = label_tensors[0]
|
|
499
|
-
else:
|
|
500
|
-
y_tensor = torch.cat(label_tensors, dim=1)
|
|
501
|
-
|
|
502
|
-
if y_tensor.shape[1] == 1:
|
|
503
|
-
y_tensor = y_tensor.squeeze(1)
|
|
504
|
-
tensors.append(y_tensor)
|
|
505
|
-
|
|
470
|
+
tensors = build_tensors_from_data(
|
|
471
|
+
data=data,
|
|
472
|
+
raw_data=data,
|
|
473
|
+
features=self.all_features,
|
|
474
|
+
target_columns=self.target,
|
|
475
|
+
id_columns=getattr(self, "id_columns", []),
|
|
476
|
+
on_missing_feature="raise",
|
|
477
|
+
)
|
|
478
|
+
assert tensors is not None, "No tensors were created from provided data."
|
|
506
479
|
dataset = TensorDataset(*tensors)
|
|
507
480
|
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
|
508
481
|
|
|
@@ -548,7 +521,7 @@ class BaseModel(nn.Module):
|
|
|
548
521
|
|
|
549
522
|
self.to(self.device)
|
|
550
523
|
if not self._logger_initialized:
|
|
551
|
-
setup_logger()
|
|
524
|
+
setup_logger(session_id=self.session_id)
|
|
552
525
|
self._logger_initialized = True
|
|
553
526
|
self._verbose = verbose
|
|
554
527
|
self._set_metrics(metrics) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
|
|
@@ -975,7 +948,11 @@ class BaseModel(nn.Module):
|
|
|
975
948
|
)
|
|
976
949
|
|
|
977
950
|
|
|
978
|
-
def predict(self,
|
|
951
|
+
def predict(self,
|
|
952
|
+
data: str|dict|pd.DataFrame|DataLoader,
|
|
953
|
+
batch_size: int = 32,
|
|
954
|
+
save_path: str | os.PathLike | None = None,
|
|
955
|
+
save_format: Literal["npy", "csv"] = "npy") -> np.ndarray:
|
|
979
956
|
self.eval()
|
|
980
957
|
# todo: handle file path input later
|
|
981
958
|
if isinstance(data, (str, os.PathLike)):
|
|
@@ -998,12 +975,38 @@ class BaseModel(nn.Module):
|
|
|
998
975
|
|
|
999
976
|
if len(y_pred_list) > 0:
|
|
1000
977
|
y_pred_all = np.concatenate(y_pred_list, axis=0)
|
|
1001
|
-
return y_pred_all
|
|
1002
978
|
else:
|
|
1003
|
-
|
|
979
|
+
y_pred_all = np.array([])
|
|
980
|
+
|
|
981
|
+
if save_path is not None:
|
|
982
|
+
suffix = ".npy" if save_format == "npy" else ".csv"
|
|
983
|
+
target_path = resolve_save_path(
|
|
984
|
+
path=save_path,
|
|
985
|
+
default_dir=self.session.predictions_dir,
|
|
986
|
+
default_name="predictions",
|
|
987
|
+
suffix=suffix,
|
|
988
|
+
add_timestamp=True if save_path is None else False,
|
|
989
|
+
)
|
|
990
|
+
|
|
991
|
+
if save_format == "npy":
|
|
992
|
+
np.save(target_path, y_pred_all)
|
|
993
|
+
else:
|
|
994
|
+
pd.DataFrame(y_pred_all).to_csv(target_path, index=False)
|
|
995
|
+
|
|
996
|
+
if self._verbose:
|
|
997
|
+
logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
|
|
998
|
+
|
|
999
|
+
return y_pred_all
|
|
1004
1000
|
|
|
1005
|
-
def save_weights(self, model_path: str):
|
|
1006
|
-
|
|
1001
|
+
def save_weights(self, model_path: str | os.PathLike | None):
|
|
1002
|
+
target_path = resolve_save_path(
|
|
1003
|
+
path=model_path,
|
|
1004
|
+
default_dir=self.session.checkpoints_dir / self.model_name,
|
|
1005
|
+
default_name=self.model_name,
|
|
1006
|
+
suffix=".model",
|
|
1007
|
+
add_timestamp=model_path is None,
|
|
1008
|
+
)
|
|
1009
|
+
torch.save(self.state_dict(), target_path)
|
|
1007
1010
|
|
|
1008
1011
|
def load_weights(self, checkpoint):
|
|
1009
1012
|
self.to(self.device)
|
|
@@ -1115,7 +1118,7 @@ class BaseModel(nn.Module):
|
|
|
1115
1118
|
logger.info("Other Settings:")
|
|
1116
1119
|
logger.info(f" Early Stop Patience: {self.early_stop_patience}")
|
|
1117
1120
|
logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
|
|
1118
|
-
logger.info(f"
|
|
1121
|
+
logger.info(f" Session ID: {self.session_id}")
|
|
1119
1122
|
logger.info(f" Checkpoint Path: {self.checkpoint}")
|
|
1120
1123
|
|
|
1121
1124
|
logger.info("")
|
|
@@ -1160,7 +1163,7 @@ class BaseMatchModel(BaseModel):
|
|
|
1160
1163
|
embedding_l2_reg: float = 0.0,
|
|
1161
1164
|
dense_l2_reg: float = 0.0,
|
|
1162
1165
|
early_stop_patience: int = 20,
|
|
1163
|
-
|
|
1166
|
+
**kwargs):
|
|
1164
1167
|
|
|
1165
1168
|
all_dense_features = []
|
|
1166
1169
|
all_sparse_features = []
|
|
@@ -1191,7 +1194,7 @@ class BaseMatchModel(BaseModel):
|
|
|
1191
1194
|
embedding_l2_reg=embedding_l2_reg,
|
|
1192
1195
|
dense_l2_reg=dense_l2_reg,
|
|
1193
1196
|
early_stop_patience=early_stop_patience,
|
|
1194
|
-
|
|
1197
|
+
**kwargs
|
|
1195
1198
|
)
|
|
1196
1199
|
|
|
1197
1200
|
self.user_dense_features = list(user_dense_features) if user_dense_features else []
|
nextrec/basic/session.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""Session and experiment utilities.
|
|
2
|
+
|
|
3
|
+
This module centralizes session/experiment management so the rest of the
|
|
4
|
+
framework writes all artifacts to a consistent location:: <pwd>/log/<experiment_id>/
|
|
5
|
+
|
|
6
|
+
Within that folder we keep model parameters, checkpoints, training metrics,
|
|
7
|
+
evaluation metrics, and consolidated log output. When users do not provide an
|
|
8
|
+
``experiment_id`` a timestamp-based identifier is generated once per process to
|
|
9
|
+
avoid scattering files across multiple directories. Test runs are redirected to
|
|
10
|
+
temporary folders so local trees are not polluted.
|
|
11
|
+
|
|
12
|
+
Date: create on 23/11/2025
|
|
13
|
+
Author: Yang Zhou,zyaztec@gmail.com
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import os
|
|
19
|
+
import tempfile
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
from datetime import datetime
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
|
|
24
|
+
__all__ = [
|
|
25
|
+
"Session",
|
|
26
|
+
"resolve_save_path",
|
|
27
|
+
"create_session",
|
|
28
|
+
]
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True)
|
|
31
|
+
class Session:
|
|
32
|
+
"""Encapsulate standard folders for a NextRec experiment."""
|
|
33
|
+
|
|
34
|
+
experiment_id: str
|
|
35
|
+
root: Path
|
|
36
|
+
|
|
37
|
+
@property
|
|
38
|
+
def logs_dir(self) -> Path:
|
|
39
|
+
return self._ensure_dir(self.root)
|
|
40
|
+
|
|
41
|
+
@property
|
|
42
|
+
def checkpoints_dir(self) -> Path:
|
|
43
|
+
return self._ensure_dir(self.root)
|
|
44
|
+
|
|
45
|
+
@property
|
|
46
|
+
def predictions_dir(self) -> Path:
|
|
47
|
+
return self._ensure_dir(self.root / "predictions")
|
|
48
|
+
|
|
49
|
+
@property
|
|
50
|
+
def processor_dir(self) -> Path:
|
|
51
|
+
return self._ensure_dir(self.root / "processor")
|
|
52
|
+
|
|
53
|
+
@property
|
|
54
|
+
def params_dir(self) -> Path:
|
|
55
|
+
return self._ensure_dir(self.root)
|
|
56
|
+
|
|
57
|
+
@property
|
|
58
|
+
def metrics_dir(self) -> Path:
|
|
59
|
+
return self._ensure_dir(self.root)
|
|
60
|
+
|
|
61
|
+
def save_text(self, name: str, content: str) -> Path:
|
|
62
|
+
"""Convenience helper: write a text file under logs_dir."""
|
|
63
|
+
path = self.logs_dir / name
|
|
64
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
|
65
|
+
path.write_text(content, encoding="utf-8")
|
|
66
|
+
return path
|
|
67
|
+
|
|
68
|
+
@staticmethod
|
|
69
|
+
def _ensure_dir(path: Path) -> Path:
|
|
70
|
+
path.mkdir(parents=True, exist_ok=True)
|
|
71
|
+
return path
|
|
72
|
+
|
|
73
|
+
def create_session(experiment_id: str | Path | None = None) -> Session:
|
|
74
|
+
"""Create a :class:`Session` instance with prepared directories."""
|
|
75
|
+
|
|
76
|
+
if experiment_id is not None and str(experiment_id).strip():
|
|
77
|
+
exp_id = str(experiment_id).strip()
|
|
78
|
+
else:
|
|
79
|
+
exp_id = "nextrec_session_" + datetime.now().strftime("%Y%m%d")
|
|
80
|
+
|
|
81
|
+
if (
|
|
82
|
+
os.getenv("PYTEST_CURRENT_TEST")
|
|
83
|
+
or os.getenv("PYTEST_RUNNING")
|
|
84
|
+
or os.getenv("NEXTREC_TEST_MODE") == "1"
|
|
85
|
+
):
|
|
86
|
+
session_path = Path(tempfile.gettempdir()) / "nextrec_logs" / exp_id
|
|
87
|
+
else:
|
|
88
|
+
# export NEXTREC_LOG_DIR=/data/nextrec/logs
|
|
89
|
+
base_dir = Path(os.getenv("NEXTREC_LOG_DIR", Path.cwd() / "nextrec_logs"))
|
|
90
|
+
session_path = base_dir / exp_id
|
|
91
|
+
|
|
92
|
+
session_path.mkdir(parents=True, exist_ok=True)
|
|
93
|
+
root = session_path.resolve()
|
|
94
|
+
|
|
95
|
+
return Session(experiment_id=exp_id, root=root)
|
|
96
|
+
|
|
97
|
+
def resolve_save_path(
|
|
98
|
+
path: str | Path | None,
|
|
99
|
+
default_dir: str | Path,
|
|
100
|
+
default_name: str,
|
|
101
|
+
suffix: str,
|
|
102
|
+
add_timestamp: bool = False,
|
|
103
|
+
) -> Path:
|
|
104
|
+
"""
|
|
105
|
+
Normalize and create a save path.
|
|
106
|
+
|
|
107
|
+
- If ``path`` is ``None`` or has no suffix, place the file under
|
|
108
|
+
``default_dir``.
|
|
109
|
+
- If ``path`` has no suffix, its stem is used as the file name; otherwise
|
|
110
|
+
``default_name``.
|
|
111
|
+
- Relative paths with a suffix are also anchored under ``default_dir``.
|
|
112
|
+
- Enforces ``suffix`` (with leading dot) and optionally appends a
|
|
113
|
+
timestamp.
|
|
114
|
+
- Parent directories are created.
|
|
115
|
+
"""
|
|
116
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") if add_timestamp else None
|
|
117
|
+
|
|
118
|
+
normalized_suffix = suffix if suffix.startswith(".") else f".{suffix}"
|
|
119
|
+
|
|
120
|
+
if path is not None and Path(path).suffix:
|
|
121
|
+
target = Path(path)
|
|
122
|
+
if not target.is_absolute():
|
|
123
|
+
target = Path(default_dir) / target
|
|
124
|
+
if target.suffix != normalized_suffix:
|
|
125
|
+
target = target.with_suffix(normalized_suffix)
|
|
126
|
+
if timestamp:
|
|
127
|
+
target = target.with_name(f"{target.stem}_{timestamp}{normalized_suffix}")
|
|
128
|
+
target.parent.mkdir(parents=True, exist_ok=True)
|
|
129
|
+
return target.resolve()
|
|
130
|
+
|
|
131
|
+
base_dir = Path(default_dir)
|
|
132
|
+
candidate = Path(path) if path is not None else None
|
|
133
|
+
|
|
134
|
+
if candidate is not None:
|
|
135
|
+
if candidate.exists() and candidate.is_dir():
|
|
136
|
+
base_dir = candidate
|
|
137
|
+
file_stem = default_name
|
|
138
|
+
else:
|
|
139
|
+
base_dir = candidate.parent if candidate.parent not in (Path("."), Path("")) else base_dir
|
|
140
|
+
file_stem = candidate.name or default_name
|
|
141
|
+
else:
|
|
142
|
+
file_stem = default_name
|
|
143
|
+
|
|
144
|
+
base_dir.mkdir(parents=True, exist_ok=True)
|
|
145
|
+
if timestamp:
|
|
146
|
+
file_stem = f"{file_stem}_{timestamp}"
|
|
147
|
+
|
|
148
|
+
return (base_dir / f"{file_stem}{normalized_suffix}").resolve()
|
|
149
|
+
|
|
150
|
+
|
nextrec/data/__init__.py
CHANGED
|
@@ -4,16 +4,21 @@ Data utilities package for NextRec
|
|
|
4
4
|
This package provides data processing and manipulation utilities.
|
|
5
5
|
|
|
6
6
|
Date: create on 13/11/2025
|
|
7
|
-
Author:
|
|
8
|
-
Yang Zhou, zyaztec@gmail.com
|
|
7
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
9
8
|
"""
|
|
10
9
|
|
|
11
10
|
from nextrec.data.data_utils import (
|
|
12
11
|
collate_fn,
|
|
13
12
|
get_column_data,
|
|
13
|
+
default_output_dir,
|
|
14
14
|
split_dict_random,
|
|
15
15
|
build_eval_candidates,
|
|
16
|
+
resolve_file_paths,
|
|
17
|
+
iter_file_chunks,
|
|
18
|
+
read_table,
|
|
19
|
+
load_dataframes,
|
|
16
20
|
)
|
|
21
|
+
from nextrec.basic.features import FeatureConfig
|
|
17
22
|
|
|
18
23
|
# For backward compatibility, keep utils accessible
|
|
19
24
|
from nextrec.data import data_utils
|
|
@@ -21,7 +26,13 @@ from nextrec.data import data_utils
|
|
|
21
26
|
__all__ = [
|
|
22
27
|
'collate_fn',
|
|
23
28
|
'get_column_data',
|
|
29
|
+
'default_output_dir',
|
|
24
30
|
'split_dict_random',
|
|
25
31
|
'build_eval_candidates',
|
|
32
|
+
'resolve_file_paths',
|
|
33
|
+
'iter_file_chunks',
|
|
34
|
+
'read_table',
|
|
35
|
+
'load_dataframes',
|
|
36
|
+
'FeatureConfig',
|
|
26
37
|
'data_utils',
|
|
27
38
|
]
|