nextrec 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.3.3"
+ __version__ = "0.3.4"
nextrec/basic/loggers.py CHANGED
@@ -2,17 +2,19 @@
  NextRec Basic Loggers
 
  Date: create on 27/10/2025
- Checkpoint: edit on 29/11/2025
+ Checkpoint: edit on 03/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
-
  import os
  import re
  import sys
+ import json
  import copy
  import logging
- from nextrec.basic.session import create_session
+ import numbers
+ from typing import Mapping, Any
+ from nextrec.basic.session import create_session, Session
 
  ANSI_CODES = {
      'black': '\033[30m',
@@ -77,17 +79,12 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
      """Apply ANSI color and bold formatting to the given text."""
      if not color and not bold:
          return text
-
      result = ""
-
      if bold:
          result += ANSI_BOLD
-
      if color and color in ANSI_CODES:
          result += ANSI_CODES[color]
-
      result += text + ANSI_RESET
-
      return result
 
  def setup_logger(session_id: str | os.PathLike | None = None):
@@ -126,3 +123,69 @@ def setup_logger(session_id: str | os.PathLike | None = None):
      logger.addHandler(console_handler)
 
      return logger
+
+ class TrainingLogger:
+     def __init__(
+         self,
+         session: Session,
+         enable_tensorboard: bool,
+         log_name: str = "training_metrics.jsonl",
+     ) -> None:
+         self.session = session
+         self.enable_tensorboard = enable_tensorboard
+         self.log_path = session.metrics_dir / log_name
+         self.log_path.parent.mkdir(parents=True, exist_ok=True)
+
+         self.tb_writer = None
+         self.tb_dir = None
+
+         if self.enable_tensorboard:
+             self._init_tensorboard()
+
+     def _init_tensorboard(self) -> None:
+         try:
+             from torch.utils.tensorboard import SummaryWriter  # type: ignore
+         except ImportError:
+             logging.warning("[TrainingLogger] tensorboard not installed, disable tensorboard logging.")
+             self.enable_tensorboard = False
+             return
+         tb_dir = self.session.logs_dir / "tensorboard"
+         tb_dir.mkdir(parents=True, exist_ok=True)
+         self.tb_dir = tb_dir
+         self.tb_writer = SummaryWriter(log_dir=str(tb_dir))
+
+     @property
+     def tensorboard_logdir(self):
+         return self.tb_dir
+
+     def format_metrics(self, metrics: Mapping[str, Any], split: str) -> dict[str, float]:
+         formatted: dict[str, float] = {}
+         for key, value in metrics.items():
+             if isinstance(value, numbers.Number):
+                 formatted[f"{split}/{key}"] = float(value)
+             elif hasattr(value, "item"):
+                 try:
+                     formatted[f"{split}/{key}"] = float(value.item())
+                 except Exception:
+                     continue
+         return formatted
+
+     def log_metrics(self, metrics: Mapping[str, Any], step: int, split: str = "train") -> None:
+         payload = self.format_metrics(metrics, split)
+         payload["step"] = int(step)
+         with self.log_path.open("a", encoding="utf-8") as f:
+             f.write(json.dumps(payload, ensure_ascii=False) + "\n")
+
+         if not self.tb_writer:
+             return
+         step = int(payload.get("step", 0))
+         for key, value in payload.items():
+             if key == "step":
+                 continue
+             self.tb_writer.add_scalar(key, value, global_step=step)
+
+     def close(self) -> None:
+         if self.tb_writer:
+             self.tb_writer.flush()
+             self.tb_writer.close()
+             self.tb_writer = None
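A minimal usage sketch of the new TrainingLogger, assuming a session created via create_session exposes the metrics_dir and logs_dir attributes the class relies on; the create_session call below is a placeholder, since its signature is not part of this diff.

    from nextrec.basic.session import create_session
    from nextrec.basic.loggers import TrainingLogger

    session = create_session("demo_run")  # placeholder call; the real signature may differ
    tlogger = TrainingLogger(session=session, enable_tensorboard=False)

    # Each call appends one JSON line to <metrics_dir>/training_metrics.jsonl, with keys
    # prefixed by the split, e.g. {"train/loss": 0.42, "train/auc": 0.71, "step": 1}.
    tlogger.log_metrics({"loss": 0.42, "auc": 0.71}, step=1, split="train")
    tlogger.close()  # flushes and closes the TensorBoard writer when one was created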
nextrec/basic/model.py CHANGED
@@ -10,6 +10,8 @@ import os
  import tqdm
  import pickle
  import logging
+ import getpass
+ import socket
  import numpy as np
  import pandas as pd
  import torch
@@ -24,7 +26,7 @@ from nextrec.basic.callback import EarlyStopper
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
  from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
 
- from nextrec.basic.loggers import setup_logger, colorize
+ from nextrec.basic.loggers import setup_logger, colorize, TrainingLogger
  from nextrec.basic.session import resolve_save_path, create_session
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
 
@@ -88,6 +90,7 @@ class BaseModel(FeatureSet, nn.Module):
  self.early_stop_patience = early_stop_patience
  self.max_gradient_norm = 1.0
  self.logger_initialized = False
+ self.training_logger: TrainingLogger | None = None
 
  def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
  exclude_modules = exclude_modules or []
@@ -275,11 +278,13 @@ class BaseModel(FeatureSet, nn.Module):
  metrics: list[str] | dict[str, list[str]] | None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
  epochs:int=1, shuffle:bool=True, batch_size:int=32,
  user_id_column: str | None = None,
- validation_split: float | None = None):
+ validation_split: float | None = None,
+ tensorboard: bool = True,):
  self.to(self.device)
  if not self.logger_initialized:
  setup_logger(session_id=self.session_id)
  self.logger_initialized = True
+ self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
 
  self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
  self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
@@ -303,6 +308,20 @@ class BaseModel(FeatureSet, nn.Module):
  is_streaming = True
 
  self.summary()
+ logging.info("")
+ if self.training_logger and self.training_logger.enable_tensorboard:
+ tb_dir = self.training_logger.tensorboard_logdir
+ if tb_dir:
+ user = getpass.getuser()
+ host = socket.gethostname()
+ tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
+ ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
+ logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
+ logging.info(colorize("To view logs, run:", color="cyan"))
+ logging.info(colorize(f" {tb_cmd}", color="cyan"))
+ logging.info(colorize("Then SSH port forward:", color="cyan"))
+ logging.info(colorize(f" {ssh_hint}", color="cyan"))
+
  logging.info("")
  logging.info(colorize("=" * 80, bold=True))
  if is_streaming:
@@ -312,7 +331,7 @@ class BaseModel(FeatureSet, nn.Module):
  logging.info(colorize("=" * 80, bold=True))
  logging.info("")
  logging.info(colorize(f"Model device: {self.device}", bold=True))
-
+
  for epoch in range(epochs):
  self.epoch_index = epoch
  if is_streaming:
@@ -326,7 +345,8 @@ class BaseModel(FeatureSet, nn.Module):
  else:
  train_loss = train_result
  train_metrics = None
-
+
+ train_log_payload: dict[str, float] = {}
  # handle logging for single-task and multi-task
  if self.nums_task == 1:
  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
@@ -334,6 +354,9 @@ class BaseModel(FeatureSet, nn.Module):
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
  log_str += f", {metrics_str}"
  logging.info(colorize(log_str))
+ train_log_payload["loss"] = float(train_loss)
+ if train_metrics:
+ train_log_payload.update(train_metrics)
  else:
  total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
@@ -356,12 +379,17 @@ class BaseModel(FeatureSet, nn.Module):
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
  log_str += ", " + ", ".join(task_metric_strs)
  logging.info(colorize(log_str))
+ train_log_payload["loss"] = float(total_loss_val)
+ if train_metrics:
+ train_log_payload.update(train_metrics)
+ if self.training_logger:
+ self.training_logger.log_metrics(train_log_payload, step=epoch + 1, split="train")
  if valid_loader is not None:
  # pass user_ids only if needed for GAUC metric
  val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
  if self.nums_task == 1:
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
- logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
+ logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
  else:
  # multi task metrics
  task_metrics = {}
@@ -378,7 +406,9 @@ class BaseModel(FeatureSet, nn.Module):
  if target_name in task_metrics:
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
- logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+ logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+ if val_metrics and self.training_logger:
+ self.training_logger.log_metrics(val_metrics, step=epoch + 1, split="valid")
  # Handle empty validation metrics
  if not val_metrics:
  self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
@@ -401,6 +431,7 @@ class BaseModel(FeatureSet, nn.Module):
  self.best_metric = primary_metric
  improved = True
  self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+ logging.info(" ")
  if improved:
  logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
  self.save_model(self.best_path, add_timestamp=False, verbose=False)
@@ -431,6 +462,8 @@ class BaseModel(FeatureSet, nn.Module):
  if valid_loader is not None:
  logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
  self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
+ if self.training_logger:
+ self.training_logger.close()
  return self
 
  def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
@@ -527,6 +560,7 @@ class BaseModel(FeatureSet, nn.Module):
  batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
  if batch_user_id is not None:
  collected_user_ids.append(batch_user_id)
+ logging.info(" ")
  logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
  if len(y_true_list) > 0:
  y_true_all = np.concatenate(y_true_list, axis=0)
@@ -956,9 +990,7 @@ class BaseModel(FeatureSet, nn.Module):
  logger.info(f" Session ID: {self.session_id}")
  logger.info(f" Features Config Path: {self.features_config_path}")
  logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
-
- logger.info("")
- logger.info("")
+
 
 
  class BaseMatchModel(BaseModel):
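A hedged sketch of the new tensorboard flag on BaseModel.fit; the positional training-data argument is a placeholder, since the leading parameters of fit are not shown in this diff.

    model.fit(train_data, metrics=["auc", "logloss"], epochs=3, batch_size=256, tensorboard=True)
    # With tensorboard=True (the default), fit creates a TrainingLogger that appends per-epoch
    # train/valid metrics to training_metrics.jsonl in the session's metrics directory and, when
    # torch.utils.tensorboard is installed, writes the same scalars under <logs_dir>/tensorboard,
    # printing the command to view them, e.g.:
    #   tensorboard --logdir <session logs dir>/tensorboard --port 6006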
nextrec/data/dataloader.py CHANGED
@@ -185,9 +185,9 @@ class RecDataLoader(FeatureSet):
  chunk_size: int,
  shuffle: bool) -> DataLoader:
  if shuffle:
- logging.warning("[RecDataLoader Warning] Shuffle is ignored in streaming mode (IterableDataset).")
+ logging.info("[RecDataLoader Info] Shuffle is ignored in streaming mode (IterableDataset).")
  if batch_size != 1:
- logging.warning("[RecDataLoader Warning] Streaming mode enforces batch_size=1; tune chunk_size to control memory/throughput.")
+ logging.info("[RecDataLoader Info] Streaming mode enforces batch_size=1; tune chunk_size to control memory/throughput.")
  dataset = FileDataset(file_paths=file_paths, dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target_columns=self.target_columns, id_columns=self.id_columns, chunk_size=chunk_size, file_type=file_type, processor=self.processor)
  return DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
 
nextrec/data/preprocessor.py CHANGED
@@ -38,26 +38,6 @@ from nextrec.__version__ import __version__
 
 
  class DataProcessor(FeatureSet):
- """DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
-
- Examples:
- >>> processor = DataProcessor()
- >>> processor.add_numeric_feature('age', scaler='standard')
- >>> processor.add_sparse_feature('user_id', encode_method='hash', hash_size=10000)
- >>> processor.add_sequence_feature('item_history', encode_method='label', max_len=50, pad_value=0)
- >>> processor.add_target('label', target_type='binary')
- >>>
- >>> # Fit and transform data
- >>> processor.fit(train_df)
- >>> processed_data = processor.transform(test_df) # Returns dict of numpy arrays
- >>>
- >>> # Save and load processor
- >>> processor.save('processor.pkl')
- >>> loaded_processor = DataProcessor.load('processor.pkl')
- >>>
- >>> # Get vocabulary sizes for embedding layers
- >>> vocab_sizes = processor.get_vocab_sizes()
- """
  def __init__(self):
  self.numeric_features: Dict[str, Dict[str, Any]] = {}
  self.sparse_features: Dict[str, Dict[str, Any]] = {}
@@ -132,10 +112,10 @@ class DataProcessor(FeatureSet):
  }
  self.set_target_id(list(self.target_features.keys()), [])
 
- def _hash_string(self, s: str, hash_size: int) -> int:
+ def hash_string(self, s: str, hash_size: int) -> int:
  return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size
 
- def _process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
+ def process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
  name = str(data.name)
  scaler_type = config['scaler']
  fill_na = config['fill_na']
@@ -164,7 +144,7 @@ class DataProcessor(FeatureSet):
  scaler.fit(values)
  self.scalers[name] = scaler
 
- def _process_numeric_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
+ def process_numeric_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
  logger = logging.getLogger()
  name = str(data.name)
  scaler_type = config['scaler']
@@ -184,7 +164,7 @@ class DataProcessor(FeatureSet):
  result = scaler.transform(values.reshape(-1, 1)).ravel()
  return result
 
- def _process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
+ def process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
  name = str(data.name)
  encode_method = config['encode_method']
  fill_na = config['fill_na'] # <UNK>
@@ -197,7 +177,7 @@ class DataProcessor(FeatureSet):
  elif encode_method == 'hash':
  config['vocab_size'] = config['hash_size']
 
- def _process_sparse_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
+ def process_sparse_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
  name = str(data.name)
  encode_method = config['encode_method']
  fill_na = config['fill_na']
@@ -215,11 +195,11 @@ class DataProcessor(FeatureSet):
  return encoded.to_numpy()
  if encode_method == 'hash':
  hash_size = config['hash_size']
- hash_fn = self._hash_string
+ hash_fn = self.hash_string
  return np.fromiter((hash_fn(v, hash_size) for v in sparse_series.to_numpy()), dtype=np.int64, count=sparse_series.size,)
  return np.array([], dtype=np.int64)
 
- def _process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
+ def process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
  name = str(data.name)
  encode_method = config['encode_method']
  separator = config['separator']
@@ -252,7 +232,7 @@ class DataProcessor(FeatureSet):
  elif encode_method == 'hash':
  config['vocab_size'] = config['hash_size']
 
- def _process_sequence_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
+ def process_sequence_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
  """Optimized sequence transform with preallocation and cached vocab map."""
  name = str(data.name)
  encode_method = config['encode_method']
@@ -276,7 +256,7 @@ class DataProcessor(FeatureSet):
  config['_class_to_idx'] = class_to_idx
  else:
  class_to_idx = None # type: ignore
- hash_fn = self._hash_string
+ hash_fn = self.hash_string
  hash_size = config.get('hash_size')
  for i, seq in enumerate(arr):
  # normalize sequence to a list of strings
@@ -301,11 +281,7 @@ class DataProcessor(FeatureSet):
  elif encode_method == 'hash':
  if hash_size is None:
  raise ValueError("hash_size must be set for hash encoding")
- encoded = [
- hash_fn(str(token), hash_size)
- for token in tokens
- if str(token).strip()
- ]
+ encoded = [hash_fn(str(token), hash_size) for token in tokens if str(token).strip()]
  else:
  encoded = []
  if not encoded:
@@ -315,7 +291,7 @@ class DataProcessor(FeatureSet):
  output[i, : len(encoded)] = encoded
  return output
 
- def _process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
+ def process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
  name = str(data.name)
  target_type = config['target_type']
  label_map = config.get('label_map')
@@ -334,7 +310,7 @@ class DataProcessor(FeatureSet):
  config['label_map'] = label_map
  self.target_encoders[name] = label_map
 
- def _process_target_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
+ def process_target_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
  logger = logging.getLogger()
  name = str(data.name)
  target_type = config.get('target_type')
@@ -355,13 +331,13 @@ class DataProcessor(FeatureSet):
  result.append(0)
  return np.array(result, dtype=np.int64 if target_type == 'multiclass' else np.float32)
 
- def _load_dataframe_from_path(self, path: str) -> pd.DataFrame:
+ def load_dataframe_from_path(self, path: str) -> pd.DataFrame:
  """Load all data from a file or directory path into a single DataFrame."""
  file_paths, file_type = resolve_file_paths(path)
  frames = load_dataframes(file_paths, file_type)
  return pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
 
- def _extract_sequence_tokens(self, value: Any, separator: str) -> list[str]:
+ def extract_sequence_tokens(self, value: Any, separator: str) -> list[str]:
  """Extract sequence tokens from a single value."""
  if value is None:
  return []
@@ -374,7 +350,7 @@ class DataProcessor(FeatureSet):
  return [str(v) for v in value]
  return [str(value)]
 
- def _fit_from_path(self, path: str, chunk_size: int) -> 'DataProcessor':
+ def fit_from_path(self, path: str, chunk_size: int) -> 'DataProcessor':
  """Fit processor statistics by streaming files to reduce memory usage."""
  logger = logging.getLogger()
  logger.info(colorize("Fitting DataProcessor (streaming path mode)...", color="cyan", bold=True))
@@ -433,7 +409,7 @@ class DataProcessor(FeatureSet):
  series = chunk[name]
  tokens = []
  for val in series:
- tokens.extend(self._extract_sequence_tokens(val, separator))
+ tokens.extend(self.extract_sequence_tokens(val, separator))
  seq_vocab[name].update(tokens)
 
  # target features
@@ -548,7 +524,7 @@ class DataProcessor(FeatureSet):
  logger.info(colorize("DataProcessor fitted successfully (streaming path mode)", color="green", bold=True))
  return self
 
- def _transform_in_memory(
+ def transform_in_memory(
  self,
  data: Union[pd.DataFrame, Dict[str, Any]],
  return_dict: bool,
@@ -581,7 +557,7 @@ class DataProcessor(FeatureSet):
  continue
  # Convert to Series for processing
  series_data = pd.Series(data_dict[name], name=name)
- processed = self._process_numeric_feature_transform(series_data, config)
+ processed = self.process_numeric_feature_transform(series_data, config)
  result_dict[name] = processed
 
  # process sparse features
@@ -590,7 +566,7 @@ class DataProcessor(FeatureSet):
  logger.warning(f"Sparse feature {name} not found in data")
  continue
  series_data = pd.Series(data_dict[name], name=name)
- processed = self._process_sparse_feature_transform(series_data, config)
+ processed = self.process_sparse_feature_transform(series_data, config)
  result_dict[name] = processed
 
  # process sequence features
@@ -599,7 +575,7 @@ class DataProcessor(FeatureSet):
  logger.warning(f"Sequence feature {name} not found in data")
  continue
  series_data = pd.Series(data_dict[name], name=name)
- processed = self._process_sequence_feature_transform(series_data, config)
+ processed = self.process_sequence_feature_transform(series_data, config)
  result_dict[name] = processed
 
  # process target features
@@ -608,10 +584,10 @@ class DataProcessor(FeatureSet):
  logger.warning(f"Target {name} not found in data")
  continue
  series_data = pd.Series(data_dict[name], name=name)
- processed = self._process_target_transform(series_data, config)
+ processed = self.process_target_transform(series_data, config)
  result_dict[name] = processed
 
- def _dict_to_dataframe(result: Dict[str, np.ndarray]) -> pd.DataFrame:
+ def dict_to_dataframe(result: Dict[str, np.ndarray]) -> pd.DataFrame:
  # Convert all arrays to Series/lists at once to avoid fragmentation
  columns_dict = {}
  for key, value in result.items():
@@ -629,7 +605,7 @@ class DataProcessor(FeatureSet):
  effective_format = save_format or "parquet"
  result_df = None
  if (not return_dict) or persist:
- result_df = _dict_to_dataframe(result_dict)
+ result_df = dict_to_dataframe(result_dict)
  if persist:
  if output_path is None:
  raise ValueError("output_path must be provided when persisting transformed data.")
@@ -649,7 +625,7 @@ class DataProcessor(FeatureSet):
  assert result_df is not None, "DataFrame is None after transform"
  return result_df
 
- def _transform_path(
+ def transform_path(
  self,
  input_path: str,
  output_path: Optional[str],
@@ -669,13 +645,7 @@ class DataProcessor(FeatureSet):
  saved_paths = []
  for file_path in tqdm.tqdm(file_paths, desc="Transforming files", unit="file"):
  df = read_table(file_path, file_type)
- transformed_df = self._transform_in_memory(
- df,
- return_dict=False,
- persist=False,
- save_format=None,
- output_path=None,
- )
+ transformed_df = self.transform_in_memory(df, return_dict=False, persist=False, save_format=None, output_path=None)
  assert isinstance(transformed_df, pd.DataFrame), "Expected DataFrame when return_dict=False"
  source_path = Path(file_path)
  target_file = output_root / f"{source_path.stem}.{target_format}"
@@ -695,9 +665,9 @@ class DataProcessor(FeatureSet):
  uses_robust = any(cfg.get("scaler") == "robust" for cfg in self.numeric_features.values())
  if uses_robust:
  logger.warning("Robust scaler requires full data; loading all files into memory. Consider smaller chunk_size or different scaler if memory is limited.")
- data = self._load_dataframe_from_path(path_str)
+ data = self.load_dataframe_from_path(path_str)
  else:
- return self._fit_from_path(path_str, chunk_size)
+ return self.fit_from_path(path_str, chunk_size)
  if isinstance(data, dict):
  data = pd.DataFrame(data)
  logger.info(colorize("Fitting DataProcessor...", color="cyan", bold=True))
@@ -705,22 +675,22 @@ class DataProcessor(FeatureSet):
  if name not in data.columns:
  logger.warning(f"Numeric feature {name} not found in data")
  continue
- self._process_numeric_feature_fit(data[name], config)
+ self.process_numeric_feature_fit(data[name], config)
  for name, config in self.sparse_features.items():
  if name not in data.columns:
  logger.warning(f"Sparse feature {name} not found in data")
  continue
- self._process_sparse_feature_fit(data[name], config)
+ self.process_sparse_feature_fit(data[name], config)
  for name, config in self.sequence_features.items():
  if name not in data.columns:
  logger.warning(f"Sequence feature {name} not found in data")
  continue
- self._process_sequence_feature_fit(data[name], config)
+ self.process_sequence_feature_fit(data[name], config)
  for name, config in self.target_features.items():
  if name not in data.columns:
  logger.warning(f"Target {name} not found in data")
  continue
- self._process_target_fit(data[name], config)
+ self.process_target_fit(data[name], config)
  self.is_fitted = True
  return self
 
@@ -736,14 +706,8 @@ class DataProcessor(FeatureSet):
  if isinstance(data, (str, os.PathLike)):
  if return_dict:
  raise ValueError("Path transform writes files only; set return_dict=False when passing a path.")
- return self._transform_path(str(data), output_path, save_format)
- return self._transform_in_memory(
- data=data,
- return_dict=return_dict,
- persist=output_path is not None,
- save_format=save_format,
- output_path=output_path,
- )
+ return self.transform_path(str(data), output_path, save_format)
+ return self.transform_in_memory(data=data, return_dict=return_dict, persist=output_path is not None, save_format=save_format, output_path=output_path)
 
  def fit_transform(
  self,
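The renames above drop the leading underscore from DataProcessor's per-feature helpers (e.g. _transform_in_memory becomes transform_in_memory), making them part of the public surface. A minimal sketch under that assumption; the add_* calls follow the docstring removed earlier in this diff, and the column names are placeholders.

    import pandas as pd
    from nextrec.data.preprocessor import DataProcessor

    processor = DataProcessor()
    processor.add_sparse_feature('user_id', encode_method='hash', hash_size=10000)
    processor.add_target('label', target_type='binary')

    train_df = pd.DataFrame({'user_id': ['u1', 'u2'], 'label': [0, 1]})
    processor.fit(train_df)

    # transform() now delegates to transform_in_memory() for in-memory data
    # and to transform_path() when given a file or directory path.
    arrays = processor.transform(train_df, return_dict=True)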
nextrec-0.3.3.dist-info/METADATA → nextrec-0.3.4.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: nextrec
- Version: 0.3.3
+ Version: 0.3.4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -63,7 +63,7 @@ Description-Content-Type: text/markdown
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.3.3-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.3.4-orange.svg)
 
  English | [中文文档](README_zh.md)
 
@@ -110,7 +110,7 @@ To dive deeper, Jupyter notebooks are available:
  - [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
  - [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)
 
- > Current version [0.3.3]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
+ > Current version [0.3.4]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
 
  ## 5-Minute Quick Start
 
nextrec-0.3.3.dist-info/RECORD → nextrec-0.3.4.dist-info/RECORD
@@ -1,18 +1,18 @@
  nextrec/__init__.py,sha256=CvocnY2uBp0cjNkhrT6ogw0q2bN9s1GNp754FLO-7lo,1117
- nextrec/__version__.py,sha256=8KcCYTXH99C2-gCLuPILJvtT9YftRWJsartIx6TQ2ZY,22
+ nextrec/__version__.py,sha256=oYLGMpySamd16KLiaBTfRyrAS7_oyp-TOEHmzmeumwg,22
  nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  nextrec/basic/activation.py,sha256=1qs9pq4hT3BUxIiYdYs57axMCm4-JyOBFQ6x7xkHTwM,2849
  nextrec/basic/callback.py,sha256=wwh0I2kKYyywCB-sG9eQXShlpXFJIo75qApJmnI5p6c,1036
  nextrec/basic/features.py,sha256=-RRRbEPU-SFI-GtppflW6O0bKShUsV-Hg_lTGpo3AIE,4262
  nextrec/basic/layers.py,sha256=zzEseKYVnMVs1Tg5EGrFimugId15jI6HumgzjFyRqgw,23127
- nextrec/basic/loggers.py,sha256=VNed0LagpoPSUl2itW8hHT-BSqJHTlQY5pVxIVmm6AE,3733
+ nextrec/basic/loggers.py,sha256=hh9tRMmaCTaJ_sfRHIlbcqd6BcpK63vpZ_21TFCiKLI,6148
  nextrec/basic/metrics.py,sha256=8-hMZJXU5L4F8GnToxMZey5dlBrtFyRtTuI_zoQCtIo,21579
- nextrec/basic/model.py,sha256=vtxPuGePgf7lFXItremzKIJmKe4pcSGEZ16TBLw7wcc,67059
+ nextrec/basic/model.py,sha256=afnvicyxXMgWdvhrIUaoNnZ7S-QYRYr7fTY5bdM1u_s,68829
  nextrec/basic/session.py,sha256=oaATn-nzbJ9A6SGbMut9xLV_NSh9_1KmVDeNauS06Ps,4767
  nextrec/data/__init__.py,sha256=6WgXZafzzXcv5kuxKNi67O8BJZVl_P_HM2IZCDIIhPA,1052
  nextrec/data/data_utils.py,sha256=aOyja3Yu7O2c8eIeL3P8MyUlUR5EerOUT9UeF4ATq8o,10574
- nextrec/data/dataloader.py,sha256=JsEVInyZ1nQXLAbRDPPN3Y47wOvWxHHOy-ikLa6sOrg,14211
- nextrec/data/preprocessor.py,sha256=Mg0unoalwNsa_OIPq8myxj3rNCrHqfTwB1IpBCdXbnI,41734
+ nextrec/data/dataloader.py,sha256=2MLe69y0E1cTZyzMNgyLUCxa6lllGd1ntvwpXzxdX10,14199
+ nextrec/data/preprocessor.py,sha256=lhigpjvkEqsjTRfbBBOjgGOxoPyOifwq2LoswgyIVqc,40488
  nextrec/loss/__init__.py,sha256=mO5t417BneZ8Ysa51GyjDaffjWyjzFgPXIQrrggasaQ,827
  nextrec/loss/listwise.py,sha256=gxDbO1td5IeS28jKzdE35o1KAYBRdCYoMzyZzfNLhc0,5689
  nextrec/loss/loss_utils.py,sha256=uZ4m9ChLr-UgIc5Yxm1LjwXDDepApQ-Fas8njweZ9qg,2641
@@ -51,7 +51,7 @@ nextrec/utils/common.py,sha256=NYXnBVtUCtm8epT2ZxJHn_m1SIBBI_PEjZ5VpL465ls,2009
  nextrec/utils/embedding.py,sha256=yxYSdFx0cJITh3Gf-K4SdhwRtKGcI0jOsyBgZ0NLa_c,465
  nextrec/utils/initializer.py,sha256=ffYOs5QuIns_d_-5e40iNtg6s1ftgREJN-ueq_NbDQE,1647
  nextrec/utils/optimizer.py,sha256=EUjAGFPeyou_Cv-_2HRvjzut8y_qpAQudc8L2T0k8zw,2706
- nextrec-0.3.3.dist-info/METADATA,sha256=MR4cHVPwRpyI0RBfooMTu2jZIUcPU-Ztp0AhGAMz37w,16319
- nextrec-0.3.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- nextrec-0.3.3.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
- nextrec-0.3.3.dist-info/RECORD,,
+ nextrec-0.3.4.dist-info/METADATA,sha256=X5fo5gymQdPXLgM1N03E58uFSQyuQOmdbUp8vXvKl0g,16319
+ nextrec-0.3.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ nextrec-0.3.4.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
+ nextrec-0.3.4.dist-info/RECORD,,