nextrec 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +1 -2
  3. nextrec/basic/callback.py +1 -2
  4. nextrec/basic/features.py +39 -8
  5. nextrec/basic/layers.py +1 -2
  6. nextrec/basic/loggers.py +15 -10
  7. nextrec/basic/metrics.py +1 -2
  8. nextrec/basic/model.py +87 -85
  9. nextrec/basic/session.py +150 -0
  10. nextrec/data/__init__.py +13 -2
  11. nextrec/data/data_utils.py +74 -22
  12. nextrec/data/dataloader.py +513 -0
  13. nextrec/data/preprocessor.py +494 -134
  14. nextrec/loss/listwise.py +6 -0
  15. nextrec/loss/loss_utils.py +1 -2
  16. nextrec/loss/match_losses.py +4 -5
  17. nextrec/loss/pairwise.py +6 -0
  18. nextrec/loss/pointwise.py +6 -0
  19. nextrec/models/match/dssm.py +2 -2
  20. nextrec/models/match/dssm_v2.py +2 -2
  21. nextrec/models/match/mind.py +2 -2
  22. nextrec/models/match/sdm.py +2 -2
  23. nextrec/models/match/youtube_dnn.py +2 -2
  24. nextrec/models/multi_task/esmm.py +3 -3
  25. nextrec/models/multi_task/mmoe.py +3 -3
  26. nextrec/models/multi_task/ple.py +3 -3
  27. nextrec/models/multi_task/share_bottom.py +3 -3
  28. nextrec/models/ranking/afm.py +2 -3
  29. nextrec/models/ranking/autoint.py +3 -3
  30. nextrec/models/ranking/dcn.py +3 -3
  31. nextrec/models/ranking/deepfm.py +2 -3
  32. nextrec/models/ranking/dien.py +3 -3
  33. nextrec/models/ranking/din.py +3 -3
  34. nextrec/models/ranking/fibinet.py +3 -3
  35. nextrec/models/ranking/fm.py +3 -3
  36. nextrec/models/ranking/masknet.py +3 -3
  37. nextrec/models/ranking/pnn.py +3 -3
  38. nextrec/models/ranking/widedeep.py +3 -3
  39. nextrec/models/ranking/xdeepfm.py +3 -3
  40. nextrec/utils/__init__.py +4 -8
  41. nextrec/utils/embedding.py +2 -4
  42. nextrec/utils/initializer.py +1 -2
  43. nextrec/utils/optimizer.py +1 -2
  44. {nextrec-0.1.11.dist-info → nextrec-0.2.1.dist-info}/METADATA +3 -3
  45. nextrec-0.2.1.dist-info/RECORD +54 -0
  46. nextrec/basic/dataloader.py +0 -447
  47. nextrec/utils/common.py +0 -14
  48. nextrec-0.1.11.dist-info/RECORD +0 -51
  49. {nextrec-0.1.11.dist-info → nextrec-0.2.1.dist-info}/WHEEL +0 -0
  50. {nextrec-0.1.11.dist-info → nextrec-0.2.1.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.11"
1
+ __version__ = "0.2.1"
@@ -2,8 +2,7 @@
2
2
  Activation function definitions
3
3
 
4
4
  Date: create on 27/10/2025
5
- Author:
6
- Yang Zhou,zyaztec@gmail.com
5
+ Author: Yang Zhou,zyaztec@gmail.com
7
6
  """
8
7
 
9
8
  import torch
nextrec/basic/callback.py CHANGED
@@ -2,8 +2,7 @@
2
2
  EarlyStopper definitions
3
3
 
4
4
  Date: create on 27/10/2025
5
- Author:
6
- Yang Zhou,zyaztec@gmail.com
5
+ Author: Yang Zhou,zyaztec@gmail.com
7
6
  """
8
7
 
9
8
  import copy
nextrec/basic/features.py CHANGED
@@ -2,12 +2,11 @@
2
2
  Feature definitions
3
3
 
4
4
  Date: create on 27/10/2025
5
- Author:
6
- Yang Zhou,zyaztec@gmail.com
5
+ Author: Yang Zhou,zyaztec@gmail.com
7
6
  """
8
-
9
- from typing import Optional
10
- from nextrec.utils import get_auto_embedding_dim
7
+ from __future__ import annotations
8
+ from typing import List, Sequence, Optional
9
+ from nextrec.utils.embedding import get_auto_embedding_dim
11
10
 
12
11
  class BaseFeature(object):
13
12
  def __repr__(self):
@@ -26,9 +25,9 @@ class SequenceFeature(BaseFeature):
26
25
  vocab_size: int,
27
26
  max_len: int = 20,
28
27
  embedding_name: str = '',
29
- embedding_dim: Optional[int] = 4,
28
+ embedding_dim: int | None = 4,
30
29
  combiner: str = "mean",
31
- padding_idx: Optional[int] = None,
30
+ padding_idx: int | None = None,
32
31
  init_type: str='normal',
33
32
  init_params: dict|None = None,
34
33
  l1_reg: float = 0.0,
@@ -55,7 +54,7 @@ class SparseFeature(BaseFeature):
55
54
  name: str,
56
55
  vocab_size: int,
57
56
  embedding_name: str = '',
58
- embedding_dim: int = 4,
57
+ embedding_dim: int | None = 4,
59
58
  padding_idx: int | None = None,
60
59
  init_type: str='normal',
61
60
  init_params: dict|None = None,
@@ -84,4 +83,36 @@ class DenseFeature(BaseFeature):
84
83
  self.embedding_dim = embedding_dim
85
84
 
86
85
 
86
+ class FeatureConfig:
87
+ """
88
+ Mixin that normalizes dense/sparse/sequence feature lists and target/id columns.
89
+ """
90
+
91
+ def _set_feature_config(
92
+ self,
93
+ dense_features: Sequence[DenseFeature] | None = None,
94
+ sparse_features: Sequence[SparseFeature] | None = None,
95
+ sequence_features: Sequence[SequenceFeature] | None = None,
96
+ ) -> None:
97
+ self.dense_features: List[DenseFeature] = list(dense_features) if dense_features else []
98
+ self.sparse_features: List[SparseFeature] = list(sparse_features) if sparse_features else []
99
+ self.sequence_features: List[SequenceFeature] = list(sequence_features) if sequence_features else []
100
+
101
+ self.all_features = self.dense_features + self.sparse_features + self.sequence_features
102
+ self.feature_names = [feat.name for feat in self.all_features]
103
+
104
+ def _set_target_config(
105
+ self,
106
+ target: str | Sequence[str] | None = None,
107
+ id_columns: str | Sequence[str] | None = None,
108
+ ) -> None:
109
+ self.target_columns = self._normalize_to_list(target)
110
+ self.id_columns = self._normalize_to_list(id_columns)
87
111
 
112
+ @staticmethod
113
+ def _normalize_to_list(value: str | Sequence[str] | None) -> list[str]:
114
+ if value is None:
115
+ return []
116
+ if isinstance(value, str):
117
+ return [value]
118
+ return list(value)
nextrec/basic/layers.py CHANGED
@@ -2,8 +2,7 @@
2
2
  Layer implementations used across NextRec models.
3
3
 
4
4
  Date: create on 27/10/2025, update on 19/11/2025
5
- Author:
6
- Yang Zhou,zyaztec@gmail.com
5
+ Author: Yang Zhou,zyaztec@gmail.com
7
6
  """
8
7
 
9
8
  from __future__ import annotations
nextrec/basic/loggers.py CHANGED
@@ -2,16 +2,18 @@
2
2
  NextRec Basic Loggers
3
3
 
4
4
  Date: create on 27/10/2025
5
- Author:
6
- Yang Zhou,zyaztec@gmail.com
5
+ Author: Yang Zhou,zyaztec@gmail.com
7
6
  """
8
7
 
8
+
9
9
  import os
10
10
  import re
11
11
  import sys
12
12
  import copy
13
13
  import datetime
14
14
  import logging
15
+ from pathlib import Path
16
+ from nextrec.basic.session import resolve_save_path, create_session
15
17
 
16
18
  ANSI_CODES = {
17
19
  'black': '\033[30m',
@@ -89,16 +91,19 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
89
91
 
90
92
  return result
91
93
 
92
- def setup_logger(log_dir: str | None = None):
94
+ def setup_logger(session_id: str | os.PathLike | None = None):
93
95
  """Set up a logger that logs to both console and a file with ANSI formatting.
94
- Only console output has colors; file output is stripped of ANSI codes.
96
+ Only console output has colors; file output is stripped of ANSI codes.
97
+ Logs are stored under ``nextrec_logs/<experiment_id>`` by default. A stable
98
+ log file is used per experiment so multiple components (e.g. data
99
+ processor and model training) append to the same file instead of creating
100
+ separate timestamped files.
95
101
  """
96
- if log_dir is None:
97
- project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
98
- log_dir = os.path.join(project_root, "..", "logs")
99
-
100
- os.makedirs(log_dir, exist_ok=True)
101
- log_file = os.path.join(log_dir, f"nextrec_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
102
+
103
+ session = create_session(str(session_id) if session_id is not None else None)
104
+ log_dir = session.logs_dir
105
+ log_dir.mkdir(parents=True, exist_ok=True)
106
+ log_file = log_dir / f"{session.experiment_id}.log"
102
107
 
103
108
  console_format = '%(message)s'
104
109
  file_format = '%(asctime)s - %(levelname)s - %(message)s'
nextrec/basic/metrics.py CHANGED
@@ -2,8 +2,7 @@
2
2
  Metrics computation and configuration for model evaluation.
3
3
 
4
4
  Date: create on 27/10/2025
5
- Author:
6
- Yang Zhou,zyaztec@gmail.com
5
+ Author: Yang Zhou,zyaztec@gmail.com
7
6
  """
8
7
  import logging
9
8
  import numpy as np
nextrec/basic/model.py CHANGED
@@ -2,34 +2,38 @@
2
2
  Base Model & Base Match Model Class
3
3
 
4
4
  Date: create on 27/10/2025
5
- Author:
6
- Yang Zhou,zyaztec@gmail.com
5
+ Author: Yang Zhou,zyaztec@gmail.com
7
6
  """
8
7
 
9
8
  import os
10
- import tqdm
11
- import torch
12
- import logging
13
9
  import datetime
10
+ import logging
11
+ import os
12
+ from pathlib import Path
13
+
14
14
  import numpy as np
15
15
  import pandas as pd
16
+ import torch
16
17
  import torch.nn as nn
17
18
  import torch.nn.functional as F
19
+ import tqdm
18
20
 
19
21
  from typing import Union, Literal
20
22
  from torch.utils.data import DataLoader, TensorDataset
21
23
 
22
24
  from nextrec.basic.callback import EarlyStopper
23
- from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
25
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureConfig
24
26
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics
25
27
 
28
+ from nextrec.loss import get_loss_fn
26
29
  from nextrec.data import get_column_data
30
+ from nextrec.data.dataloader import build_tensors_from_data
27
31
  from nextrec.basic.loggers import setup_logger, colorize
28
32
  from nextrec.utils import get_optimizer_fn, get_scheduler_fn
29
- from nextrec.loss import get_loss_fn
33
+ from nextrec.basic.session import resolve_save_path, create_session
30
34
 
31
35
 
32
- class BaseModel(nn.Module):
36
+ class BaseModel(FeatureConfig, nn.Module):
33
37
  @property
34
38
  def model_name(self) -> str:
35
39
  raise NotImplementedError
@@ -43,6 +47,7 @@ class BaseModel(nn.Module):
43
47
  sparse_features: list[SparseFeature] | None = None,
44
48
  sequence_features: list[SequenceFeature] | None = None,
45
49
  target: list[str] | str | None = None,
50
+ id_columns: list[str] | str | None = None,
46
51
  task: str|list[str] = 'binary',
47
52
  device: str = 'cpu',
48
53
  embedding_l1_reg: float = 0.0,
@@ -50,26 +55,40 @@ class BaseModel(nn.Module):
50
55
  embedding_l2_reg: float = 0.0,
51
56
  dense_l2_reg: float = 0.0,
52
57
  early_stop_patience: int = 20,
53
- model_path: str = './',
54
- model_id: str = 'baseline'):
58
+ session_id: str | None = None,):
55
59
 
56
60
  super(BaseModel, self).__init__()
57
61
 
58
62
  try:
59
63
  self.device = torch.device(device)
60
64
  except Exception as e:
61
- logging.warning(colorize("Invalid device , defaulting to CPU.", color='yellow'))
65
+ logging.warning("Invalid device , defaulting to CPU.")
62
66
  self.device = torch.device('cpu')
63
67
 
64
- self.dense_features = list(dense_features) if dense_features is not None else []
65
- self.sparse_features = list(sparse_features) if sparse_features is not None else []
66
- self.sequence_features = list(sequence_features) if sequence_features is not None else []
67
-
68
- if isinstance(target, str):
69
- self.target = [target]
70
- else:
71
- self.target = list(target) if target is not None else []
72
-
68
+ self.session_id = session_id
69
+ self.session = create_session(session_id)
70
+ self.session_path = Path(self.session.logs_dir)
71
+ checkpoint_dir = self.session.checkpoints_dir / self.model_name
72
+
73
+ self.checkpoint = resolve_save_path(
74
+ path=None,
75
+ default_dir=checkpoint_dir,
76
+ default_name=self.model_name,
77
+ suffix=".model",
78
+ add_timestamp=True,
79
+ )
80
+
81
+ self.best = resolve_save_path(
82
+ path="best.model",
83
+ default_dir=checkpoint_dir,
84
+ default_name="best",
85
+ suffix=".model",
86
+ )
87
+
88
+ self._set_feature_config(dense_features, sparse_features, sequence_features)
89
+ self._set_target_config(target, id_columns)
90
+
91
+ self.target = self.target_columns
73
92
  self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}
74
93
 
75
94
  self.task = task
@@ -86,14 +105,6 @@ class BaseModel(nn.Module):
86
105
  self.early_stop_patience = early_stop_patience
87
106
  self._max_gradient_norm = 1.0 # Maximum gradient norm for gradient clipping
88
107
 
89
- self.model_id = model_id
90
-
91
- model_path = os.path.abspath(os.getcwd() if model_path in [None, './'] else model_path)
92
- checkpoint_dir = os.path.join(model_path, "checkpoints", self.model_id)
93
- os.makedirs(checkpoint_dir, exist_ok=True)
94
- self.checkpoint = os.path.join(checkpoint_dir, f"{self.model_name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.model")
95
- self.best = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_best.model")
96
-
97
108
  self._logger_initialized = False
98
109
  self._verbose = 1
99
110
 
@@ -456,54 +467,15 @@ class BaseModel(nn.Module):
456
467
  def _prepare_data_loader(self, data: dict|pd.DataFrame|DataLoader, batch_size: int = 32, shuffle: bool = True):
457
468
  if isinstance(data, DataLoader):
458
469
  return data
459
- tensors = []
460
- all_features = self.dense_features + self.sparse_features + self.sequence_features
461
-
462
- for feature in all_features:
463
- column = get_column_data(data, feature.name)
464
- if column is None:
465
- raise KeyError(f"Feature {feature.name} not found in provided data.")
466
-
467
- if isinstance(feature, SequenceFeature):
468
- if isinstance(column, pd.Series):
469
- column = column.values
470
- if isinstance(column, np.ndarray) and column.dtype == object:
471
- column = np.array([np.array(seq, dtype=np.int64) if not isinstance(seq, np.ndarray) else seq for seq in column])
472
- if isinstance(column, np.ndarray) and column.ndim == 1 and column.dtype == object:
473
- column = np.vstack([c if isinstance(c, np.ndarray) else np.array(c) for c in column]) # type: ignore
474
- tensor = torch.from_numpy(np.asarray(column, dtype=np.int64)).to('cpu')
475
- else:
476
- dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
477
- tensor = self._to_tensor(column, dtype=dtype, device='cpu')
478
-
479
- tensors.append(tensor)
480
-
481
- label_tensors = []
482
- for target_name in self.target:
483
- column = get_column_data(data, target_name)
484
- if column is None:
485
- continue
486
- label_tensor = self._to_tensor(column, dtype=torch.float32, device='cpu')
487
-
488
- if label_tensor.dim() == 1:
489
- # 1D tensor: (N,) -> (N, 1)
490
- label_tensor = label_tensor.view(-1, 1)
491
- elif label_tensor.dim() == 2:
492
- if label_tensor.shape[0] == 1 and label_tensor.shape[1] > 1:
493
- label_tensor = label_tensor.t()
494
-
495
- label_tensors.append(label_tensor)
496
-
497
- if label_tensors:
498
- if len(label_tensors) == 1 and label_tensors[0].shape[1] > 1:
499
- y_tensor = label_tensors[0]
500
- else:
501
- y_tensor = torch.cat(label_tensors, dim=1)
502
-
503
- if y_tensor.shape[1] == 1:
504
- y_tensor = y_tensor.squeeze(1)
505
- tensors.append(y_tensor)
506
-
470
+ tensors = build_tensors_from_data(
471
+ data=data,
472
+ raw_data=data,
473
+ features=self.all_features,
474
+ target_columns=self.target,
475
+ id_columns=getattr(self, "id_columns", []),
476
+ on_missing_feature="raise",
477
+ )
478
+ assert tensors is not None, "No tensors were created from provided data."
507
479
  dataset = TensorDataset(*tensors)
508
480
  return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
509
481
 
@@ -549,7 +521,7 @@ class BaseModel(nn.Module):
549
521
 
550
522
  self.to(self.device)
551
523
  if not self._logger_initialized:
552
- setup_logger()
524
+ setup_logger(session_id=self.session_id)
553
525
  self._logger_initialized = True
554
526
  self._verbose = verbose
555
527
  self._set_metrics(metrics) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
@@ -976,7 +948,11 @@ class BaseModel(nn.Module):
976
948
  )
977
949
 
978
950
 
979
- def predict(self, data: str|dict|pd.DataFrame|DataLoader, batch_size: int = 32) -> np.ndarray:
951
+ def predict(self,
952
+ data: str|dict|pd.DataFrame|DataLoader,
953
+ batch_size: int = 32,
954
+ save_path: str | os.PathLike | None = None,
955
+ save_format: Literal["npy", "csv"] = "npy") -> np.ndarray:
980
956
  self.eval()
981
957
  # todo: handle file path input later
982
958
  if isinstance(data, (str, os.PathLike)):
@@ -999,12 +975,38 @@ class BaseModel(nn.Module):
999
975
 
1000
976
  if len(y_pred_list) > 0:
1001
977
  y_pred_all = np.concatenate(y_pred_list, axis=0)
1002
- return y_pred_all
1003
978
  else:
1004
- return np.array([])
979
+ y_pred_all = np.array([])
980
+
981
+ if save_path is not None:
982
+ suffix = ".npy" if save_format == "npy" else ".csv"
983
+ target_path = resolve_save_path(
984
+ path=save_path,
985
+ default_dir=self.session.predictions_dir,
986
+ default_name="predictions",
987
+ suffix=suffix,
988
+ add_timestamp=True if save_path is None else False,
989
+ )
990
+
991
+ if save_format == "npy":
992
+ np.save(target_path, y_pred_all)
993
+ else:
994
+ pd.DataFrame(y_pred_all).to_csv(target_path, index=False)
995
+
996
+ if self._verbose:
997
+ logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
998
+
999
+ return y_pred_all
1005
1000
 
1006
- def save_weights(self, model_path: str):
1007
- torch.save(self.state_dict(), model_path)
1001
+ def save_weights(self, model_path: str | os.PathLike | None):
1002
+ target_path = resolve_save_path(
1003
+ path=model_path,
1004
+ default_dir=self.session.checkpoints_dir / self.model_name,
1005
+ default_name=self.model_name,
1006
+ suffix=".model",
1007
+ add_timestamp=model_path is None,
1008
+ )
1009
+ torch.save(self.state_dict(), target_path)
1008
1010
 
1009
1011
  def load_weights(self, checkpoint):
1010
1012
  self.to(self.device)
@@ -1116,7 +1118,7 @@ class BaseModel(nn.Module):
1116
1118
  logger.info("Other Settings:")
1117
1119
  logger.info(f" Early Stop Patience: {self.early_stop_patience}")
1118
1120
  logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
1119
- logger.info(f" Model ID: {self.model_id}")
1121
+ logger.info(f" Session ID: {self.session_id}")
1120
1122
  logger.info(f" Checkpoint Path: {self.checkpoint}")
1121
1123
 
1122
1124
  logger.info("")
@@ -1161,7 +1163,7 @@ class BaseMatchModel(BaseModel):
1161
1163
  embedding_l2_reg: float = 0.0,
1162
1164
  dense_l2_reg: float = 0.0,
1163
1165
  early_stop_patience: int = 20,
1164
- model_id: str = 'baseline'):
1166
+ **kwargs):
1165
1167
 
1166
1168
  all_dense_features = []
1167
1169
  all_sparse_features = []
@@ -1192,7 +1194,7 @@ class BaseMatchModel(BaseModel):
1192
1194
  embedding_l2_reg=embedding_l2_reg,
1193
1195
  dense_l2_reg=dense_l2_reg,
1194
1196
  early_stop_patience=early_stop_patience,
1195
- model_id=model_id
1197
+ **kwargs
1196
1198
  )
1197
1199
 
1198
1200
  self.user_dense_features = list(user_dense_features) if user_dense_features else []
@@ -0,0 +1,150 @@
1
+ """Session and experiment utilities.
2
+
3
+ This module centralizes session/experiment management so the rest of the
4
+ framework writes all artifacts to a consistent location:: <pwd>/nextrec_logs/<experiment_id>/
5
+
6
+ Within that folder we keep model parameters, checkpoints, training metrics,
7
+ evaluation metrics, and consolidated log output. When users do not provide an
8
+ ``experiment_id`` a date-based identifier is generated (stable within a day) to
9
+ avoid scattering files across multiple directories. Test runs are redirected to
10
+ temporary folders so local trees are not polluted.
11
+
12
+ Date: create on 23/11/2025
13
+ Author: Yang Zhou,zyaztec@gmail.com
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import os
19
+ import tempfile
20
+ from dataclasses import dataclass
21
+ from datetime import datetime
22
+ from pathlib import Path
23
+
24
+ __all__ = [
25
+ "Session",
26
+ "resolve_save_path",
27
+ "create_session",
28
+ ]
29
+
30
+ @dataclass(frozen=True)
31
+ class Session:
32
+ """Encapsulate standard folders for a NextRec experiment."""
33
+
34
+ experiment_id: str
35
+ root: Path
36
+
37
+ @property
38
+ def logs_dir(self) -> Path:
39
+ return self._ensure_dir(self.root)
40
+
41
+ @property
42
+ def checkpoints_dir(self) -> Path:
43
+ return self._ensure_dir(self.root)
44
+
45
+ @property
46
+ def predictions_dir(self) -> Path:
47
+ return self._ensure_dir(self.root / "predictions")
48
+
49
+ @property
50
+ def processor_dir(self) -> Path:
51
+ return self._ensure_dir(self.root / "processor")
52
+
53
+ @property
54
+ def params_dir(self) -> Path:
55
+ return self._ensure_dir(self.root)
56
+
57
+ @property
58
+ def metrics_dir(self) -> Path:
59
+ return self._ensure_dir(self.root)
60
+
61
+ def save_text(self, name: str, content: str) -> Path:
62
+ """Convenience helper: write a text file under logs_dir."""
63
+ path = self.logs_dir / name
64
+ path.parent.mkdir(parents=True, exist_ok=True)
65
+ path.write_text(content, encoding="utf-8")
66
+ return path
67
+
68
+ @staticmethod
69
+ def _ensure_dir(path: Path) -> Path:
70
+ path.mkdir(parents=True, exist_ok=True)
71
+ return path
72
+
73
+ def create_session(experiment_id: str | Path | None = None) -> Session:
74
+ """Create a :class:`Session` instance with prepared directories."""
75
+
76
+ if experiment_id is not None and str(experiment_id).strip():
77
+ exp_id = str(experiment_id).strip()
78
+ else:
79
+ exp_id = "nextrec_session_" + datetime.now().strftime("%Y%m%d")
80
+
81
+ if (
82
+ os.getenv("PYTEST_CURRENT_TEST")
83
+ or os.getenv("PYTEST_RUNNING")
84
+ or os.getenv("NEXTREC_TEST_MODE") == "1"
85
+ ):
86
+ session_path = Path(tempfile.gettempdir()) / "nextrec_logs" / exp_id
87
+ else:
88
+ # export NEXTREC_LOG_DIR=/data/nextrec/logs
89
+ base_dir = Path(os.getenv("NEXTREC_LOG_DIR", Path.cwd() / "nextrec_logs"))
90
+ session_path = base_dir / exp_id
91
+
92
+ session_path.mkdir(parents=True, exist_ok=True)
93
+ root = session_path.resolve()
94
+
95
+ return Session(experiment_id=exp_id, root=root)
96
+
97
+ def resolve_save_path(
98
+ path: str | Path | None,
99
+ default_dir: str | Path,
100
+ default_name: str,
101
+ suffix: str,
102
+ add_timestamp: bool = False,
103
+ ) -> Path:
104
+ """
105
+ Normalize and create a save path.
106
+
107
+ - If ``path`` is ``None`` or has no suffix, place the file under
108
+ ``default_dir``.
109
+ - If ``path`` has no suffix, its stem is used as the file name; otherwise
110
+ ``default_name``.
111
+ - Relative paths with a suffix are also anchored under ``default_dir``.
112
+ - Enforces ``suffix`` (with leading dot) and optionally appends a
113
+ timestamp.
114
+ - Parent directories are created.
115
+ """
116
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") if add_timestamp else None
117
+
118
+ normalized_suffix = suffix if suffix.startswith(".") else f".{suffix}"
119
+
120
+ if path is not None and Path(path).suffix:
121
+ target = Path(path)
122
+ if not target.is_absolute():
123
+ target = Path(default_dir) / target
124
+ if target.suffix != normalized_suffix:
125
+ target = target.with_suffix(normalized_suffix)
126
+ if timestamp:
127
+ target = target.with_name(f"{target.stem}_{timestamp}{normalized_suffix}")
128
+ target.parent.mkdir(parents=True, exist_ok=True)
129
+ return target.resolve()
130
+
131
+ base_dir = Path(default_dir)
132
+ candidate = Path(path) if path is not None else None
133
+
134
+ if candidate is not None:
135
+ if candidate.exists() and candidate.is_dir():
136
+ base_dir = candidate
137
+ file_stem = default_name
138
+ else:
139
+ base_dir = candidate.parent if candidate.parent not in (Path("."), Path("")) else base_dir
140
+ file_stem = candidate.name or default_name
141
+ else:
142
+ file_stem = default_name
143
+
144
+ base_dir.mkdir(parents=True, exist_ok=True)
145
+ if timestamp:
146
+ file_stem = f"{file_stem}_{timestamp}"
147
+
148
+ return (base_dir / f"{file_stem}{normalized_suffix}").resolve()
149
+
150
+
nextrec/data/__init__.py CHANGED
@@ -4,16 +4,21 @@ Data utilities package for NextRec
4
4
  This package provides data processing and manipulation utilities.
5
5
 
6
6
  Date: create on 13/11/2025
7
- Author:
8
- Yang Zhou, zyaztec@gmail.com
7
+ Author: Yang Zhou, zyaztec@gmail.com
9
8
  """
10
9
 
11
10
  from nextrec.data.data_utils import (
12
11
  collate_fn,
13
12
  get_column_data,
13
+ default_output_dir,
14
14
  split_dict_random,
15
15
  build_eval_candidates,
16
+ resolve_file_paths,
17
+ iter_file_chunks,
18
+ read_table,
19
+ load_dataframes,
16
20
  )
21
+ from nextrec.basic.features import FeatureConfig
17
22
 
18
23
  # For backward compatibility, keep utils accessible
19
24
  from nextrec.data import data_utils
@@ -21,7 +26,13 @@ from nextrec.data import data_utils
21
26
  __all__ = [
22
27
  'collate_fn',
23
28
  'get_column_data',
29
+ 'default_output_dir',
24
30
  'split_dict_random',
25
31
  'build_eval_candidates',
32
+ 'resolve_file_paths',
33
+ 'iter_file_chunks',
34
+ 'read_table',
35
+ 'load_dataframes',
36
+ 'FeatureConfig',
26
37
  'data_utils',
27
38
  ]