nextrec: 0.3.3-py3-none-any.whl → 0.3.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/loggers.py +71 -8
- nextrec/basic/model.py +41 -9
- nextrec/data/dataloader.py +2 -2
- nextrec/data/preprocessor.py +33 -69
- {nextrec-0.3.3.dist-info → nextrec-0.3.4.dist-info}/METADATA +3 -3
- {nextrec-0.3.3.dist-info → nextrec-0.3.4.dist-info}/RECORD +9 -9
- {nextrec-0.3.3.dist-info → nextrec-0.3.4.dist-info}/WHEEL +0 -0
- {nextrec-0.3.3.dist-info → nextrec-0.3.4.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED

@@ -1 +1 @@
-__version__ = "0.3.3"
+__version__ = "0.3.4"
nextrec/basic/loggers.py
CHANGED

@@ -2,17 +2,19 @@
 NextRec Basic Loggers

 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 03/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """

-
 import os
 import re
 import sys
+import json
 import copy
 import logging
-
+import numbers
+from typing import Mapping, Any
+from nextrec.basic.session import create_session, Session

 ANSI_CODES = {
     'black': '\033[30m',

@@ -77,17 +79,12 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
     """Apply ANSI color and bold formatting to the given text."""
     if not color and not bold:
         return text
-
     result = ""
-
     if bold:
         result += ANSI_BOLD
-
     if color and color in ANSI_CODES:
         result += ANSI_CODES[color]
-
     result += text + ANSI_RESET
-
     return result

 def setup_logger(session_id: str | os.PathLike | None = None):

@@ -126,3 +123,69 @@ def setup_logger(session_id: str | os.PathLike | None = None):
     logger.addHandler(console_handler)

     return logger
+
+class TrainingLogger:
+    def __init__(
+        self,
+        session: Session,
+        enable_tensorboard: bool,
+        log_name: str = "training_metrics.jsonl",
+    ) -> None:
+        self.session = session
+        self.enable_tensorboard = enable_tensorboard
+        self.log_path = session.metrics_dir / log_name
+        self.log_path.parent.mkdir(parents=True, exist_ok=True)
+
+        self.tb_writer = None
+        self.tb_dir = None
+
+        if self.enable_tensorboard:
+            self._init_tensorboard()
+
+    def _init_tensorboard(self) -> None:
+        try:
+            from torch.utils.tensorboard import SummaryWriter  # type: ignore
+        except ImportError:
+            logging.warning("[TrainingLogger] tensorboard not installed, disable tensorboard logging.")
+            self.enable_tensorboard = False
+            return
+        tb_dir = self.session.logs_dir / "tensorboard"
+        tb_dir.mkdir(parents=True, exist_ok=True)
+        self.tb_dir = tb_dir
+        self.tb_writer = SummaryWriter(log_dir=str(tb_dir))
+
+    @property
+    def tensorboard_logdir(self):
+        return self.tb_dir
+
+    def format_metrics(self, metrics: Mapping[str, Any], split: str) -> dict[str, float]:
+        formatted: dict[str, float] = {}
+        for key, value in metrics.items():
+            if isinstance(value, numbers.Number):
+                formatted[f"{split}/{key}"] = float(value)
+            elif hasattr(value, "item"):
+                try:
+                    formatted[f"{split}/{key}"] = float(value.item())
+                except Exception:
+                    continue
+        return formatted
+
+    def log_metrics(self, metrics: Mapping[str, Any], step: int, split: str = "train") -> None:
+        payload = self.format_metrics(metrics, split)
+        payload["step"] = int(step)
+        with self.log_path.open("a", encoding="utf-8") as f:
+            f.write(json.dumps(payload, ensure_ascii=False) + "\n")
+
+        if not self.tb_writer:
+            return
+        step = int(payload.get("step", 0))
+        for key, value in payload.items():
+            if key == "step":
+                continue
+            self.tb_writer.add_scalar(key, value, global_step=step)
+
+    def close(self) -> None:
+        if self.tb_writer:
+            self.tb_writer.flush()
+            self.tb_writer.close()
+            self.tb_writer = None
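The new TrainingLogger appends one JSON line per call to training_metrics.jsonl under the session's metrics directory and mirrors scalar (or tensor-like, via .item()) values to TensorBoard when torch.utils.tensorboard is importable. A minimal standalone sketch of how it could be driven; the create_session call and its argument are assumptions, since BaseModel.fit normally constructs the logger itself:

    from nextrec.basic.session import create_session   # factory assumed; see nextrec.basic.session
    from nextrec.basic.loggers import TrainingLogger

    # Hedged sketch: BaseModel.fit wires this up automatically; shown standalone for clarity.
    session = create_session("demo_run")                # argument is illustrative
    tlog = TrainingLogger(session=session, enable_tensorboard=True)

    for epoch in range(1, 4):
        # Keys are prefixed per split (e.g. "train/loss"), appended as one JSON line per call,
        # and mirrored to <logs_dir>/tensorboard when TensorBoard is installed.
        tlog.log_metrics({"loss": 0.5 / epoch, "auc": 0.70 + 0.01 * epoch}, step=epoch, split="train")

    tlog.close()  # flush and release the SummaryWriter if TensorBoard logging was active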
nextrec/basic/model.py
CHANGED

@@ -10,6 +10,8 @@ import os
 import tqdm
 import pickle
 import logging
+import getpass
+import socket
 import numpy as np
 import pandas as pd
 import torch

@@ -24,7 +26,7 @@ from nextrec.basic.callback import EarlyStopper
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
 from nextrec.data.dataloader import TensorDictDataset, RecDataLoader

-from nextrec.basic.loggers import setup_logger, colorize
+from nextrec.basic.loggers import setup_logger, colorize, TrainingLogger
 from nextrec.basic.session import resolve_save_path, create_session
 from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id

@@ -88,6 +90,7 @@ class BaseModel(FeatureSet, nn.Module):
         self.early_stop_patience = early_stop_patience
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
+        self.training_logger: TrainingLogger | None = None

     def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
         exclude_modules = exclude_modules or []

@@ -275,11 +278,13 @@ class BaseModel(FeatureSet, nn.Module):
             metrics: list[str] | dict[str, list[str]] | None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
             epochs:int=1, shuffle:bool=True, batch_size:int=32,
             user_id_column: str | None = None,
-            validation_split: float | None = None
+            validation_split: float | None = None,
+            tensorboard: bool = True,):
         self.to(self.device)
         if not self.logger_initialized:
             setup_logger(session_id=self.session_id)
             self.logger_initialized = True
+        self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard)

         self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
         self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)

@@ -303,6 +308,20 @@ class BaseModel(FeatureSet, nn.Module):
             is_streaming = True

         self.summary()
+        logging.info("")
+        if self.training_logger and self.training_logger.enable_tensorboard:
+            tb_dir = self.training_logger.tensorboard_logdir
+            if tb_dir:
+                user = getpass.getuser()
+                host = socket.gethostname()
+                tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
+                ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
+                logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
+                logging.info(colorize("To view logs, run:", color="cyan"))
+                logging.info(colorize(f" {tb_cmd}", color="cyan"))
+                logging.info(colorize("Then SSH port forward:", color="cyan"))
+                logging.info(colorize(f" {ssh_hint}", color="cyan"))
+
         logging.info("")
         logging.info(colorize("=" * 80, bold=True))
         if is_streaming:

@@ -312,7 +331,7 @@ class BaseModel(FeatureSet, nn.Module):
         logging.info(colorize("=" * 80, bold=True))
         logging.info("")
         logging.info(colorize(f"Model device: {self.device}", bold=True))
-
+
         for epoch in range(epochs):
             self.epoch_index = epoch
             if is_streaming:

@@ -326,7 +345,8 @@ class BaseModel(FeatureSet, nn.Module):
             else:
                 train_loss = train_result
                 train_metrics = None
-
+
+            train_log_payload: dict[str, float] = {}
             # handle logging for single-task and multi-task
             if self.nums_task == 1:
                 log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"

@@ -334,6 +354,9 @@ class BaseModel(FeatureSet, nn.Module):
                     metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
                     log_str += f", {metrics_str}"
                 logging.info(colorize(log_str))
+                train_log_payload["loss"] = float(train_loss)
+                if train_metrics:
+                    train_log_payload.update(train_metrics)
             else:
                 total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
                 log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"

@@ -356,12 +379,17 @@ class BaseModel(FeatureSet, nn.Module):
                     task_metric_strs.append(f"{target_name}[{metrics_str}]")
                 log_str += ", " + ", ".join(task_metric_strs)
                 logging.info(colorize(log_str))
+                train_log_payload["loss"] = float(total_loss_val)
+                if train_metrics:
+                    train_log_payload.update(train_metrics)
+            if self.training_logger:
+                self.training_logger.log_metrics(train_log_payload, step=epoch + 1, split="train")
             if valid_loader is not None:
                 # pass user_ids only if needed for GAUC metric
                 val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
                 if self.nums_task == 1:
                     metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
-                    logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
+                    logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
                 else:
                     # multi task metrics
                     task_metrics = {}

@@ -378,7 +406,9 @@ class BaseModel(FeatureSet, nn.Module):
                         if target_name in task_metrics:
                             metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
                             task_metric_strs.append(f"{target_name}[{metrics_str}]")
-                    logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+                    logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+                if val_metrics and self.training_logger:
+                    self.training_logger.log_metrics(val_metrics, step=epoch + 1, split="valid")
                 # Handle empty validation metrics
                 if not val_metrics:
                     self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)

@@ -401,6 +431,7 @@ class BaseModel(FeatureSet, nn.Module):
                     self.best_metric = primary_metric
                     improved = True
                     self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+                logging.info(" ")
                 if improved:
                     logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
                     self.save_model(self.best_path, add_timestamp=False, verbose=False)

@@ -431,6 +462,8 @@ class BaseModel(FeatureSet, nn.Module):
         if valid_loader is not None:
             logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
             self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
+        if self.training_logger:
+            self.training_logger.close()
         return self

     def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:

@@ -527,6 +560,7 @@ class BaseModel(FeatureSet, nn.Module):
                 batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
                 if batch_user_id is not None:
                     collected_user_ids.append(batch_user_id)
+        logging.info(" ")
         logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
         if len(y_true_list) > 0:
             y_true_all = np.concatenate(y_true_list, axis=0)

@@ -956,9 +990,7 @@ class BaseModel(FeatureSet, nn.Module):
         logger.info(f" Session ID: {self.session_id}")
         logger.info(f" Features Config Path: {self.features_config_path}")
         logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
-
-        logger.info("")
-        logger.info("")
+


 class BaseMatchModel(BaseModel):
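On the training side, fit() now takes a tensorboard flag (default True) and builds a TrainingLogger for the run, so per-epoch train/valid metrics land in the JSONL file and, when TensorBoard is installed, under the session's logs/tensorboard directory, with a startup hint for remote viewing. A hedged sketch of the call site; the model and data names, and the data-passing arguments, are placeholders because they sit outside this diff:

    # `model`, `train_df`, `valid_df` are placeholders for a concrete NextRec model and its data;
    # only metrics/epochs/batch_size/validation_split/tensorboard appear in this diff's signature.
    model.fit(
        train_df,
        metrics=["auc", "logloss"],
        epochs=10,
        batch_size=512,
        validation_split=0.1,
        tensorboard=True,   # new in 0.3.4; set False to skip TensorBoard logging entirely
    )

    # When TensorBoard logging is active, fit() prints the equivalent of:
    #   tensorboard --logdir <session_logs_dir>/tensorboard --port 6006
    #   ssh -L 6006:localhost:6006 <user>@<host>   # port-forward when training on a remote machine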
nextrec/data/dataloader.py
CHANGED

@@ -185,9 +185,9 @@ class RecDataLoader(FeatureSet):
                        chunk_size: int,
                        shuffle: bool) -> DataLoader:
         if shuffle:
-            logging.
+            logging.info("[RecDataLoader Info] Shuffle is ignored in streaming mode (IterableDataset).")
         if batch_size != 1:
-            logging.
+            logging.info("[RecDataLoader Info] Streaming mode enforces batch_size=1; tune chunk_size to control memory/throughput.")
         dataset = FileDataset(file_paths=file_paths, dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target_columns=self.target_columns, id_columns=self.id_columns, chunk_size=chunk_size, file_type=file_type, processor=self.processor)
         return DataLoader(dataset, batch_size=1, collate_fn=collate_fn)

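The reworded messages make the streaming contract explicit: each item the file dataset yields is already a chunk of rows, so the outer DataLoader runs with batch_size=1 and the effective batch size is whatever chunk_size produces; row-level shuffling is therefore not available. A generic PyTorch illustration of that pattern, not NextRec's FileDataset, just the underlying idea:

    import torch
    from torch.utils.data import IterableDataset, DataLoader

    class ChunkStream(IterableDataset):
        """Yields whole pre-built chunks; shuffling individual rows is not possible in this mode."""
        def __init__(self, num_rows: int, chunk_size: int):
            self.num_rows = num_rows
            self.chunk_size = chunk_size

        def __iter__(self):
            for start in range(0, self.num_rows, self.chunk_size):
                stop = min(start + self.chunk_size, self.num_rows)
                yield torch.arange(start, stop)   # stand-in for a preprocessed feature chunk

    # batch_size=1 plus an unwrapping collate_fn: each "batch" is exactly one chunk.
    loader = DataLoader(ChunkStream(num_rows=10, chunk_size=4), batch_size=1,
                        collate_fn=lambda items: items[0])
    print([chunk.tolist() for chunk in loader])   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]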
nextrec/data/preprocessor.py
CHANGED

@@ -38,26 +38,6 @@ from nextrec.__version__ import __version__


 class DataProcessor(FeatureSet):
-    """DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
-
-    Examples:
-        >>> processor = DataProcessor()
-        >>> processor.add_numeric_feature('age', scaler='standard')
-        >>> processor.add_sparse_feature('user_id', encode_method='hash', hash_size=10000)
-        >>> processor.add_sequence_feature('item_history', encode_method='label', max_len=50, pad_value=0)
-        >>> processor.add_target('label', target_type='binary')
-        >>>
-        >>> # Fit and transform data
-        >>> processor.fit(train_df)
-        >>> processed_data = processor.transform(test_df) # Returns dict of numpy arrays
-        >>>
-        >>> # Save and load processor
-        >>> processor.save('processor.pkl')
-        >>> loaded_processor = DataProcessor.load('processor.pkl')
-        >>>
-        >>> # Get vocabulary sizes for embedding layers
-        >>> vocab_sizes = processor.get_vocab_sizes()
-    """
     def __init__(self):
         self.numeric_features: Dict[str, Dict[str, Any]] = {}
         self.sparse_features: Dict[str, Dict[str, Any]] = {}

@@ -132,10 +112,10 @@ class DataProcessor(FeatureSet):
         }
         self.set_target_id(list(self.target_features.keys()), [])

-    def
+    def hash_string(self, s: str, hash_size: int) -> int:
         return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size

-    def
+    def process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
         scaler_type = config['scaler']
         fill_na = config['fill_na']

@@ -164,7 +144,7 @@ class DataProcessor(FeatureSet):
         scaler.fit(values)
         self.scalers[name] = scaler

-    def
+    def process_numeric_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
         logger = logging.getLogger()
         name = str(data.name)
         scaler_type = config['scaler']

@@ -184,7 +164,7 @@ class DataProcessor(FeatureSet):
         result = scaler.transform(values.reshape(-1, 1)).ravel()
         return result

-    def
+    def process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
         encode_method = config['encode_method']
         fill_na = config['fill_na'] # <UNK>

@@ -197,7 +177,7 @@ class DataProcessor(FeatureSet):
         elif encode_method == 'hash':
             config['vocab_size'] = config['hash_size']

-    def
+    def process_sparse_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
         name = str(data.name)
         encode_method = config['encode_method']
         fill_na = config['fill_na']

@@ -215,11 +195,11 @@ class DataProcessor(FeatureSet):
             return encoded.to_numpy()
         if encode_method == 'hash':
             hash_size = config['hash_size']
-            hash_fn = self.
+            hash_fn = self.hash_string
             return np.fromiter((hash_fn(v, hash_size) for v in sparse_series.to_numpy()), dtype=np.int64, count=sparse_series.size,)
         return np.array([], dtype=np.int64)

-    def
+    def process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
         encode_method = config['encode_method']
         separator = config['separator']

@@ -252,7 +232,7 @@ class DataProcessor(FeatureSet):
         elif encode_method == 'hash':
             config['vocab_size'] = config['hash_size']

-    def
+    def process_sequence_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
         """Optimized sequence transform with preallocation and cached vocab map."""
         name = str(data.name)
         encode_method = config['encode_method']

@@ -276,7 +256,7 @@ class DataProcessor(FeatureSet):
             config['_class_to_idx'] = class_to_idx
         else:
             class_to_idx = None # type: ignore
-        hash_fn = self.
+        hash_fn = self.hash_string
         hash_size = config.get('hash_size')
         for i, seq in enumerate(arr):
             # normalize sequence to a list of strings

@@ -301,11 +281,7 @@ class DataProcessor(FeatureSet):
             elif encode_method == 'hash':
                 if hash_size is None:
                     raise ValueError("hash_size must be set for hash encoding")
-                encoded = [
-                    hash_fn(str(token), hash_size)
-                    for token in tokens
-                    if str(token).strip()
-                ]
+                encoded = [hash_fn(str(token), hash_size) for token in tokens if str(token).strip()]
             else:
                 encoded = []
             if not encoded:

@@ -315,7 +291,7 @@ class DataProcessor(FeatureSet):
             output[i, : len(encoded)] = encoded
         return output

-    def
+    def process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
         target_type = config['target_type']
         label_map = config.get('label_map')

@@ -334,7 +310,7 @@ class DataProcessor(FeatureSet):
             config['label_map'] = label_map
             self.target_encoders[name] = label_map

-    def
+    def process_target_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
         logger = logging.getLogger()
         name = str(data.name)
         target_type = config.get('target_type')

@@ -355,13 +331,13 @@ class DataProcessor(FeatureSet):
                 result.append(0)
         return np.array(result, dtype=np.int64 if target_type == 'multiclass' else np.float32)

-    def
+    def load_dataframe_from_path(self, path: str) -> pd.DataFrame:
         """Load all data from a file or directory path into a single DataFrame."""
         file_paths, file_type = resolve_file_paths(path)
         frames = load_dataframes(file_paths, file_type)
         return pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]

-    def
+    def extract_sequence_tokens(self, value: Any, separator: str) -> list[str]:
         """Extract sequence tokens from a single value."""
         if value is None:
             return []

@@ -374,7 +350,7 @@ class DataProcessor(FeatureSet):
             return [str(v) for v in value]
         return [str(value)]

-    def
+    def fit_from_path(self, path: str, chunk_size: int) -> 'DataProcessor':
         """Fit processor statistics by streaming files to reduce memory usage."""
         logger = logging.getLogger()
         logger.info(colorize("Fitting DataProcessor (streaming path mode)...", color="cyan", bold=True))

@@ -433,7 +409,7 @@ class DataProcessor(FeatureSet):
                 series = chunk[name]
                 tokens = []
                 for val in series:
-                    tokens.extend(self.
+                    tokens.extend(self.extract_sequence_tokens(val, separator))
                 seq_vocab[name].update(tokens)

         # target features

@@ -548,7 +524,7 @@ class DataProcessor(FeatureSet):
         logger.info(colorize("DataProcessor fitted successfully (streaming path mode)", color="green", bold=True))
         return self

-    def
+    def transform_in_memory(
         self,
         data: Union[pd.DataFrame, Dict[str, Any]],
         return_dict: bool,

@@ -581,7 +557,7 @@ class DataProcessor(FeatureSet):
                 continue
             # Convert to Series for processing
             series_data = pd.Series(data_dict[name], name=name)
-            processed = self.
+            processed = self.process_numeric_feature_transform(series_data, config)
             result_dict[name] = processed

         # process sparse features

@@ -590,7 +566,7 @@ class DataProcessor(FeatureSet):
                 logger.warning(f"Sparse feature {name} not found in data")
                 continue
             series_data = pd.Series(data_dict[name], name=name)
-            processed = self.
+            processed = self.process_sparse_feature_transform(series_data, config)
             result_dict[name] = processed

         # process sequence features

@@ -599,7 +575,7 @@ class DataProcessor(FeatureSet):
                 logger.warning(f"Sequence feature {name} not found in data")
                 continue
             series_data = pd.Series(data_dict[name], name=name)
-            processed = self.
+            processed = self.process_sequence_feature_transform(series_data, config)
             result_dict[name] = processed

         # process target features

@@ -608,10 +584,10 @@ class DataProcessor(FeatureSet):
                 logger.warning(f"Target {name} not found in data")
                 continue
             series_data = pd.Series(data_dict[name], name=name)
-            processed = self.
+            processed = self.process_target_transform(series_data, config)
             result_dict[name] = processed

-        def
+        def dict_to_dataframe(result: Dict[str, np.ndarray]) -> pd.DataFrame:
             # Convert all arrays to Series/lists at once to avoid fragmentation
             columns_dict = {}
             for key, value in result.items():

@@ -629,7 +605,7 @@ class DataProcessor(FeatureSet):
         effective_format = save_format or "parquet"
         result_df = None
         if (not return_dict) or persist:
-            result_df =
+            result_df = dict_to_dataframe(result_dict)
         if persist:
             if output_path is None:
                 raise ValueError("output_path must be provided when persisting transformed data.")

@@ -649,7 +625,7 @@ class DataProcessor(FeatureSet):
         assert result_df is not None, "DataFrame is None after transform"
         return result_df

-    def
+    def transform_path(
         self,
         input_path: str,
         output_path: Optional[str],

@@ -669,13 +645,7 @@ class DataProcessor(FeatureSet):
         saved_paths = []
         for file_path in tqdm.tqdm(file_paths, desc="Transforming files", unit="file"):
             df = read_table(file_path, file_type)
-            transformed_df = self.
-                df,
-                return_dict=False,
-                persist=False,
-                save_format=None,
-                output_path=None,
-            )
+            transformed_df = self.transform_in_memory(df, return_dict=False, persist=False, save_format=None, output_path=None)
             assert isinstance(transformed_df, pd.DataFrame), "Expected DataFrame when return_dict=False"
             source_path = Path(file_path)
             target_file = output_root / f"{source_path.stem}.{target_format}"

@@ -695,9 +665,9 @@ class DataProcessor(FeatureSet):
         uses_robust = any(cfg.get("scaler") == "robust" for cfg in self.numeric_features.values())
         if uses_robust:
             logger.warning("Robust scaler requires full data; loading all files into memory. Consider smaller chunk_size or different scaler if memory is limited.")
-            data = self.
+            data = self.load_dataframe_from_path(path_str)
         else:
-            return self.
+            return self.fit_from_path(path_str, chunk_size)
         if isinstance(data, dict):
             data = pd.DataFrame(data)
         logger.info(colorize("Fitting DataProcessor...", color="cyan", bold=True))

@@ -705,22 +675,22 @@ class DataProcessor(FeatureSet):
             if name not in data.columns:
                 logger.warning(f"Numeric feature {name} not found in data")
                 continue
-            self.
+            self.process_numeric_feature_fit(data[name], config)
         for name, config in self.sparse_features.items():
             if name not in data.columns:
                 logger.warning(f"Sparse feature {name} not found in data")
                 continue
-            self.
+            self.process_sparse_feature_fit(data[name], config)
         for name, config in self.sequence_features.items():
             if name not in data.columns:
                 logger.warning(f"Sequence feature {name} not found in data")
                 continue
-            self.
+            self.process_sequence_feature_fit(data[name], config)
         for name, config in self.target_features.items():
             if name not in data.columns:
                 logger.warning(f"Target {name} not found in data")
                 continue
-            self.
+            self.process_target_fit(data[name], config)
         self.is_fitted = True
         return self

@@ -736,14 +706,8 @@ class DataProcessor(FeatureSet):
         if isinstance(data, (str, os.PathLike)):
             if return_dict:
                 raise ValueError("Path transform writes files only; set return_dict=False when passing a path.")
-            return self.
-        return self.
-            data=data,
-            return_dict=return_dict,
-            persist=output_path is not None,
-            save_format=save_format,
-            output_path=output_path,
-        )
+            return self.transform_path(str(data), output_path, save_format)
+        return self.transform_in_memory(data=data, return_dict=return_dict, persist=output_path is not None, save_format=save_format, output_path=output_path)

     def fit_transform(
         self,
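The long usage docstring was dropped from the class body in this release; the same workflow, taken from that removed docstring, still applies and only reaches the renamed helpers indirectly through fit/transform. The import path below is an assumption, and train_df/test_df are placeholder DataFrames:

    from nextrec.data.preprocessor import DataProcessor   # assumed import path

    processor = DataProcessor()
    processor.add_numeric_feature('age', scaler='standard')
    processor.add_sparse_feature('user_id', encode_method='hash', hash_size=10000)
    processor.add_sequence_feature('item_history', encode_method='label', max_len=50, pad_value=0)
    processor.add_target('label', target_type='binary')

    processor.fit(train_df)                    # a file/directory path streams via fit_from_path unless a robust scaler forces a full load
    processed = processor.transform(test_df)   # dict of numpy arrays; return_dict=False yields a DataFrame instead

    processor.save('processor.pkl')
    loaded = DataProcessor.load('processor.pkl')
    vocab_sizes = loaded.get_vocab_sizes()     # vocabulary sizes for embedding layers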
{nextrec-0.3.3.dist-info → nextrec-0.3.4.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nextrec
-Version: 0.3.3
+Version: 0.3.4
 Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
 Project-URL: Homepage, https://github.com/zerolovesea/NextRec
 Project-URL: Repository, https://github.com/zerolovesea/NextRec

@@ -63,7 +63,7 @@ Description-Content-Type: text/markdown
 
 
 
-
+

 English | [中文文档](README_zh.md)


@@ -110,7 +110,7 @@ To dive deeper, Jupyter notebooks are available:
 - [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
 - [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)

-> Current version [0.3.
+> Current version [0.3.4]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.

 ## 5-Minute Quick Start

{nextrec-0.3.3.dist-info → nextrec-0.3.4.dist-info}/RECORD
CHANGED

@@ -1,18 +1,18 @@
 nextrec/__init__.py,sha256=CvocnY2uBp0cjNkhrT6ogw0q2bN9s1GNp754FLO-7lo,1117
-nextrec/__version__.py,sha256=
+nextrec/__version__.py,sha256=oYLGMpySamd16KLiaBTfRyrAS7_oyp-TOEHmzmeumwg,22
 nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nextrec/basic/activation.py,sha256=1qs9pq4hT3BUxIiYdYs57axMCm4-JyOBFQ6x7xkHTwM,2849
 nextrec/basic/callback.py,sha256=wwh0I2kKYyywCB-sG9eQXShlpXFJIo75qApJmnI5p6c,1036
 nextrec/basic/features.py,sha256=-RRRbEPU-SFI-GtppflW6O0bKShUsV-Hg_lTGpo3AIE,4262
 nextrec/basic/layers.py,sha256=zzEseKYVnMVs1Tg5EGrFimugId15jI6HumgzjFyRqgw,23127
-nextrec/basic/loggers.py,sha256=
+nextrec/basic/loggers.py,sha256=hh9tRMmaCTaJ_sfRHIlbcqd6BcpK63vpZ_21TFCiKLI,6148
 nextrec/basic/metrics.py,sha256=8-hMZJXU5L4F8GnToxMZey5dlBrtFyRtTuI_zoQCtIo,21579
-nextrec/basic/model.py,sha256=
+nextrec/basic/model.py,sha256=afnvicyxXMgWdvhrIUaoNnZ7S-QYRYr7fTY5bdM1u_s,68829
 nextrec/basic/session.py,sha256=oaATn-nzbJ9A6SGbMut9xLV_NSh9_1KmVDeNauS06Ps,4767
 nextrec/data/__init__.py,sha256=6WgXZafzzXcv5kuxKNi67O8BJZVl_P_HM2IZCDIIhPA,1052
 nextrec/data/data_utils.py,sha256=aOyja3Yu7O2c8eIeL3P8MyUlUR5EerOUT9UeF4ATq8o,10574
-nextrec/data/dataloader.py,sha256=
-nextrec/data/preprocessor.py,sha256=
+nextrec/data/dataloader.py,sha256=2MLe69y0E1cTZyzMNgyLUCxa6lllGd1ntvwpXzxdX10,14199
+nextrec/data/preprocessor.py,sha256=lhigpjvkEqsjTRfbBBOjgGOxoPyOifwq2LoswgyIVqc,40488
 nextrec/loss/__init__.py,sha256=mO5t417BneZ8Ysa51GyjDaffjWyjzFgPXIQrrggasaQ,827
 nextrec/loss/listwise.py,sha256=gxDbO1td5IeS28jKzdE35o1KAYBRdCYoMzyZzfNLhc0,5689
 nextrec/loss/loss_utils.py,sha256=uZ4m9ChLr-UgIc5Yxm1LjwXDDepApQ-Fas8njweZ9qg,2641

@@ -51,7 +51,7 @@ nextrec/utils/common.py,sha256=NYXnBVtUCtm8epT2ZxJHn_m1SIBBI_PEjZ5VpL465ls,2009
 nextrec/utils/embedding.py,sha256=yxYSdFx0cJITh3Gf-K4SdhwRtKGcI0jOsyBgZ0NLa_c,465
 nextrec/utils/initializer.py,sha256=ffYOs5QuIns_d_-5e40iNtg6s1ftgREJN-ueq_NbDQE,1647
 nextrec/utils/optimizer.py,sha256=EUjAGFPeyou_Cv-_2HRvjzut8y_qpAQudc8L2T0k8zw,2706
-nextrec-0.3.
-nextrec-0.3.
-nextrec-0.3.
-nextrec-0.3.
+nextrec-0.3.4.dist-info/METADATA,sha256=X5fo5gymQdPXLgM1N03E58uFSQyuQOmdbUp8vXvKl0g,16319
+nextrec-0.3.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+nextrec-0.3.4.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
+nextrec-0.3.4.dist-info/RECORD,,

{nextrec-0.3.3.dist-info → nextrec-0.3.4.dist-info}/WHEEL
File without changes

{nextrec-0.3.3.dist-info → nextrec-0.3.4.dist-info}/licenses/LICENSE
File without changes