nextrec 0.4.8-py3-none-any.whl → 0.4.10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/callback.py +30 -15
  3. nextrec/basic/features.py +1 -0
  4. nextrec/basic/layers.py +6 -8
  5. nextrec/basic/loggers.py +14 -7
  6. nextrec/basic/metrics.py +6 -76
  7. nextrec/basic/model.py +316 -321
  8. nextrec/cli.py +185 -43
  9. nextrec/data/__init__.py +13 -16
  10. nextrec/data/batch_utils.py +3 -2
  11. nextrec/data/data_processing.py +10 -2
  12. nextrec/data/data_utils.py +9 -14
  13. nextrec/data/dataloader.py +31 -33
  14. nextrec/data/preprocessor.py +328 -255
  15. nextrec/loss/__init__.py +1 -5
  16. nextrec/loss/loss_utils.py +2 -8
  17. nextrec/models/generative/__init__.py +1 -8
  18. nextrec/models/generative/hstu.py +6 -4
  19. nextrec/models/multi_task/esmm.py +2 -2
  20. nextrec/models/multi_task/mmoe.py +2 -2
  21. nextrec/models/multi_task/ple.py +2 -2
  22. nextrec/models/multi_task/poso.py +2 -3
  23. nextrec/models/multi_task/share_bottom.py +2 -2
  24. nextrec/models/ranking/afm.py +2 -2
  25. nextrec/models/ranking/autoint.py +2 -2
  26. nextrec/models/ranking/dcn.py +2 -2
  27. nextrec/models/ranking/dcn_v2.py +2 -2
  28. nextrec/models/ranking/deepfm.py +6 -7
  29. nextrec/models/ranking/dien.py +3 -3
  30. nextrec/models/ranking/din.py +3 -3
  31. nextrec/models/ranking/eulernet.py +365 -0
  32. nextrec/models/ranking/fibinet.py +5 -5
  33. nextrec/models/ranking/fm.py +3 -7
  34. nextrec/models/ranking/lr.py +120 -0
  35. nextrec/models/ranking/masknet.py +2 -2
  36. nextrec/models/ranking/pnn.py +2 -2
  37. nextrec/models/ranking/widedeep.py +2 -2
  38. nextrec/models/ranking/xdeepfm.py +2 -2
  39. nextrec/models/representation/__init__.py +9 -0
  40. nextrec/models/{generative → representation}/rqvae.py +9 -9
  41. nextrec/models/retrieval/__init__.py +0 -0
  42. nextrec/models/{match → retrieval}/dssm.py +8 -3
  43. nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
  44. nextrec/models/{match → retrieval}/mind.py +4 -3
  45. nextrec/models/{match → retrieval}/sdm.py +4 -3
  46. nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
  47. nextrec/utils/__init__.py +60 -46
  48. nextrec/utils/config.py +8 -7
  49. nextrec/utils/console.py +371 -0
  50. nextrec/utils/{synthetic_data.py → data.py} +102 -15
  51. nextrec/utils/feature.py +15 -0
  52. nextrec/utils/torch_utils.py +411 -0
  53. {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/METADATA +6 -7
  54. nextrec-0.4.10.dist-info/RECORD +70 -0
  55. nextrec/utils/cli_utils.py +0 -58
  56. nextrec/utils/device.py +0 -78
  57. nextrec/utils/distributed.py +0 -141
  58. nextrec/utils/file.py +0 -92
  59. nextrec/utils/initializer.py +0 -79
  60. nextrec/utils/optimizer.py +0 -75
  61. nextrec/utils/tensor.py +0 -72
  62. nextrec-0.4.8.dist-info/RECORD +0 -71
  63. /nextrec/models/{match/__init__.py → ranking/ffm.py} +0 -0
  64. {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/WHEEL +0 -0
  65. {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/entry_points.txt +0 -0
  66. {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/licenses/LICENSE +0 -0
nextrec/utils/console.py ADDED
@@ -0,0 +1,371 @@
+ """
+ Console and CLI utilities for NextRec.
+
+ This module centralizes CLI logging helpers, progress display, and metric tables.
+
+ Date: create on 19/12/2025
+ Checkpoint: edit on 19/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
+ """
+
+ from __future__ import annotations
+
+ import io
+ import logging
+ import numbers
+ import os
+ import platform
+ import sys
+ from datetime import datetime, timedelta
+ from typing import Any, Callable, Iterable, Mapping, TypeVar
+
+ import numpy as np
+ from rich import box
+ from rich.console import Console
+ from rich.progress import (
+     BarColumn,
+     MofNCompleteColumn,
+     Progress,
+     SpinnerColumn,
+     TaskProgressColumn,
+     TextColumn,
+     TimeElapsedColumn,
+     TimeRemainingColumn,
+ )
+ from rich.table import Table
+ from rich.text import Text
+
+ from nextrec.utils.feature import as_float, normalize_to_list
+
+ T = TypeVar("T")
+
+
+ def get_nextrec_version() -> str:
+     """
+     Best-effort version resolver for NextRec.
+
+     Prefer in-repo `nextrec.__version__`, fall back to installed package metadata.
+     """
+     try:
+         from nextrec import __version__  # type: ignore
+
+         if __version__:
+             return str(__version__)
+     except Exception:
+         pass
+
+     try:
+         from importlib.metadata import version
+
+         return version("nextrec")
+     except Exception:
+         return "unknown"
+
+
+ def log_startup_info(
+     logger: logging.Logger, *, mode: str, config_path: str | None
+ ) -> None:
+     """Log a short, user-friendly startup banner."""
+     version = get_nextrec_version()
+     now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+     lines = [
+         "NextRec CLI",
+         f"- Version: {version}",
+         f"- Time: {now}",
+         f"- Mode: {mode}",
+         f"- Config: {config_path or '(not set)'}",
+         f"- Python: {platform.python_version()} ({sys.executable})",
+         f"- Platform: {platform.system()} {platform.release()} ({platform.machine()})",
+         f"- Workdir: {os.getcwd()}",
+         f"- Command: {' '.join(sys.argv)}",
+     ]
+     for line in lines:
+         logger.info(line)
+
+
+ class BlackTimeElapsedColumn(TimeElapsedColumn):
+     def render(self, task) -> Text:
+         elapsed = task.finished_time if task.finished else task.elapsed
+         if elapsed is None:
+             return Text("-:--:--", style="black")
+         delta = timedelta(seconds=max(0, int(elapsed)))
+         return Text(str(delta), style="black")
+
+
+ class BlackTimeRemainingColumn(TimeRemainingColumn):
+     def render(self, task) -> Text:
+         if self.elapsed_when_finished and task.finished:
+             task_time = task.finished_time
+         else:
+             task_time = task.time_remaining
+
+         if task.total is None:
+             return Text("", style="black")
+
+         if task_time is None:
+             return Text("--:--" if self.compact else "-:--:--", style="black")
+
+         minutes, seconds = divmod(int(task_time), 60)
+         hours, minutes = divmod(minutes, 60)
+
+         if self.compact and not hours:
+             formatted = f"{minutes:02d}:{seconds:02d}"
+         else:
+             formatted = f"{hours:d}:{minutes:02d}:{seconds:02d}"
+
+         return Text(formatted, style="black")
+
+
+ class BlackMofNCompleteColumn(MofNCompleteColumn):
+     def render(self, task) -> Text:
+         completed = int(task.completed)
+         total = int(task.total) if task.total is not None else "?"
+         total_width = len(str(total))
+         return Text(
+             f"{completed:{total_width}d}{self.separator}{total}",
+             style="black",
+         )
+
+
+ def progress(
+     iterable: Iterable[T],
+     *,
+     description: str | None = None,
+     total: int | None = None,
+     disable: bool = False,
+ ) -> Iterable[T]:
+     if disable:
+         for item in iterable:
+             yield item
+         return
+     resolved_total = total
+     if resolved_total is None:
+         try:
+             resolved_total = len(iterable)  # type: ignore[arg-type]
+         except TypeError:
+             resolved_total = None
+
+     progress_bar = Progress(
+         SpinnerColumn(style="black"),
+         TextColumn("{task.description}", style="black"),
+         BarColumn(
+             bar_width=36, style="black", complete_style="black", finished_style="black"
+         ),
+         TaskProgressColumn(style="black"),
+         BlackMofNCompleteColumn(),
+         BlackTimeElapsedColumn(),
+         BlackTimeRemainingColumn(),
+         refresh_per_second=12,
+     )
+
+     task_id = progress_bar.add_task(description or "Working", total=resolved_total)
+     progress_bar.start()
+     try:
+         for item in iterable:
+             yield item
+             progress_bar.advance(task_id, 1)
+     finally:
+         progress_bar.stop()
+
+
+ def group_metrics_by_task(
+     metrics: Mapping[str, Any] | None,
+     target_names: list[str] | str | None,
+     default_task_name: str = "overall",
+ ) -> tuple[list[str], dict[str, dict[str, float]]]:
+     if not metrics:
+         return [], {}
+
+     if isinstance(target_names, str):
+         target_names = [target_names]
+     if not isinstance(target_names, list) or not target_names:
+         target_names = [default_task_name]
+
+     targets_by_len = sorted(target_names, key=len, reverse=True)
+     grouped: dict[str, dict[str, float]] = {}
+     for key, raw_value in metrics.items():
+         value = as_float(raw_value)
+         if value is None:
+             continue
+
+         matched_target: str | None = None
+         metric_name = key
+         for target in targets_by_len:
+             suffix = f"_{target}"
+             if key.endswith(suffix):
+                 metric_name = key[: -len(suffix)]
+                 matched_target = target
+                 break
+
+         if matched_target is None:
+             matched_target = (
+                 target_names[0] if len(target_names) == 1 else default_task_name
+             )
+         grouped.setdefault(matched_target, {})[metric_name] = value
+
+     task_order: list[str] = []
+     for target in target_names:
+         if target in grouped:
+             task_order.append(target)
+     for task_name in grouped:
+         if task_name not in task_order:
+             task_order.append(task_name)
+     return task_order, grouped
+
+
+ def display_metrics_table(
+     epoch: int,
+     epochs: int,
+     split: str,
+     loss: float | np.ndarray | None,
+     metrics: Mapping[str, Any] | None,
+     target_names: list[str] | str | None,
+     base_metrics: list[str] | None = None,
+     is_main_process: bool = True,
+     colorize: Callable[[str], str] | None = None,
+ ) -> None:
+     if not is_main_process:
+         return
+
+     target_list = normalize_to_list(target_names)
+     task_order, grouped = group_metrics_by_task(metrics, target_names=target_names)
+
+     if isinstance(loss, np.ndarray) and target_list:
+         # Ensure tasks with losses are shown even when metrics are missing for some targets.
+         normalized_order: list[str] = []
+         for name in target_list:
+             if name not in normalized_order:
+                 normalized_order.append(name)
+         for name in task_order:
+             if name not in normalized_order:
+                 normalized_order.append(name)
+         task_order = normalized_order
+
+     if Console is None or Table is None or box is None:
+         prefix = f"Epoch {epoch}/{epochs} - {split}:"
+         segments: list[str] = []
+         if isinstance(loss, numbers.Number):
+             segments.append(f"loss={float(loss):.4f}")
+         if task_order and grouped:
+             task_strs: list[str] = []
+             for task_name in task_order:
+                 metric_items = grouped.get(task_name, {})
+                 if not metric_items:
+                     continue
+                 metric_str = ", ".join(
+                     f"{k}={float(v):.4f}" for k, v in metric_items.items()
+                 )
+                 task_strs.append(f"{task_name}[{metric_str}]")
+             if task_strs:
+                 segments.append(", ".join(task_strs))
+         elif metrics:
+             metric_str = ", ".join(
+                 f"{k}={float(v):.4f}"
+                 for k, v in metrics.items()
+                 if as_float(v) is not None
+             )
+             if metric_str:
+                 segments.append(metric_str)
+         if not segments:
+             return
+         msg = f"{prefix} " + ", ".join(segments)
+         if colorize is not None:
+             msg = colorize(msg)
+         logging.info(msg)
+         return
+
+     title = f"Epoch {epoch}/{epochs} - {split}"
+     if isinstance(loss, numbers.Number):
+         title += f" (loss={float(loss):.4f})"
+
+     table = Table(
+         title=title,
+         box=box.ROUNDED,
+         header_style="bold",
+         title_style="bold",
+     )
+     table.add_column("Task", style="bold")
+
+     include_loss = isinstance(loss, np.ndarray)
+     if include_loss:
+         table.add_column("loss", justify="right")
+
+     metric_names: list[str] = []
+     for task_name in task_order:
+         for metric_name in grouped.get(task_name, {}):
+             if metric_name not in metric_names:
+                 metric_names.append(metric_name)
+
+     preferred_order: list[str] = []
+     if isinstance(base_metrics, list):
+         preferred_order = [m for m in base_metrics if m in metric_names]
+     remaining = [m for m in metric_names if m not in preferred_order]
+     metric_names = preferred_order + sorted(remaining)
+
+     for metric_name in metric_names:
+         table.add_column(metric_name, justify="right")
+
+     def fmt(value: float | None) -> str:
+         if value is None:
+             return "-"
+         if np.isnan(value):
+             return "nan"
+         if np.isinf(value):
+             return "inf" if value > 0 else "-inf"
+         return f"{value:.4f}"
+
+     loss_by_task: dict[str, float] = {}
+     if isinstance(loss, np.ndarray):
+         if target_list:
+             for i, task_name in enumerate(target_list):
+                 if i < loss.shape[0]:
+                     loss_by_task[task_name] = float(loss[i])
+             if "overall" in task_order and "overall" not in loss_by_task:
+                 loss_by_task["overall"] = float(np.sum(loss))
+         elif task_order:
+             for i, task_name in enumerate(task_order):
+                 if i < loss.shape[0]:
+                     loss_by_task[task_name] = float(loss[i])
+         else:
+             task_order = ["overall"]
+             loss_by_task["overall"] = float(np.sum(loss))
+
+     if not task_order:
+         task_order = ["__overall__"]
+
+     for task_name in task_order:
+         row: list[str] = [str(task_name)]
+         if include_loss:
+             row.append(fmt(loss_by_task.get(task_name)))
+         for metric_name in metric_names:
+             row.append(fmt(grouped.get(task_name, {}).get(metric_name)))
+         table.add_row(*row)
+
+     Console().print(table)
+
+     record_console = Console(file=io.StringIO(), record=True, width=120)
+     record_console.print(table)
+     table_text = record_console.export_text(styles=False).rstrip()
+
+     root_logger = logging.getLogger()
+     record = root_logger.makeRecord(
+         root_logger.name,
+         logging.INFO,
+         __file__,
+         0,
+         "[MetricsTable]\n" + table_text,
+         args=(),
+         exc_info=None,
+         extra=None,
+     )
+
+     emitted = False
+     for handler in root_logger.handlers:
+         if isinstance(handler, logging.FileHandler):
+             handler.emit(record)
+             emitted = True
+
+     if not emitted:
+         # Fallback: no file handlers configured, use standard logging.
+         root_logger.log(logging.INFO, "[MetricsTable]\n" + table_text)
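For orientation, a minimal usage sketch of the new console helpers follows. It assumes nextrec 0.4.10 is installed; the epoch numbers, metric keys, and target names are invented for illustration:

# Hypothetical sketch: drive the rich progress bar and the per-task metrics table.
import time
import numpy as np
from nextrec.utils.console import progress, display_metrics_table

# progress() infers the total via len() when the iterable is sized.
for _ in progress(range(50), description="Training"):
    time.sleep(0.01)

# Metric keys follow the "<metric>_<target>" suffix convention that
# group_metrics_by_task() strips; an ndarray loss adds a per-task loss column.
display_metrics_table(
    epoch=1,
    epochs=5,
    split="val",
    loss=np.array([0.42, 0.37]),
    metrics={"auc_click": 0.81, "auc_conversion": 0.74},
    target_names=["click", "conversion"],
)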
nextrec/utils/{synthetic_data.py → data.py} RENAMED
@@ -1,17 +1,101 @@
  """
- Synthetic Data Generation Utilities
+ Data utilities for NextRec.
 
- This module provides utilities for generating synthetic datasets for testing
- and tutorial purposes in the NextRec framework.
+ This module provides file I/O helpers and synthetic data generation.
 
- Date: create on 06/12/2025
+ Date: create on 19/12/2025
+ Checkpoint: edit on 19/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
- import torch
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Dict, Generator, List, Optional, Tuple
+
  import numpy as np
  import pandas as pd
- from typing import Optional, Dict, List, Tuple
+ import pyarrow.parquet as pq
+ import torch
+ import yaml
+
+
+ def resolve_file_paths(path: str) -> tuple[list[str], str]:
+     """
+     Resolve file or directory path into a sorted list of files and file type.
+
+     Args: path: Path to a file or directory
+     Returns: tuple: (list of file paths, file type)
+     """
+     path_obj = Path(path)
+
+     if path_obj.is_file():
+         file_type = path_obj.suffix.lower().lstrip(".")
+         assert file_type in [
+             "csv",
+             "parquet",
+         ], f"Unsupported file extension: {file_type}"
+         return [str(path_obj)], file_type
+
+     if path_obj.is_dir():
+         collected_files = [p for p in path_obj.iterdir() if p.is_file()]
+         csv_files = [str(p) for p in collected_files if p.suffix.lower() == ".csv"]
+         parquet_files = [
+             str(p) for p in collected_files if p.suffix.lower() == ".parquet"
+         ]
+
+         if csv_files and parquet_files:
+             raise ValueError(
+                 "Directory contains both CSV and Parquet files. Please keep a single format."
+             )
+         file_paths = csv_files if csv_files else parquet_files
+         if not file_paths:
+             raise ValueError(f"No CSV or Parquet files found in directory: {path}")
+         file_paths.sort()
+         file_type = "csv" if csv_files else "parquet"
+         return file_paths, file_type
+
+     raise ValueError(f"Invalid path: {path}")
+
+
+ def read_table(path: str | Path, data_format: str | None = None) -> pd.DataFrame:
+     data_path = Path(path)
+     fmt = data_format.lower() if data_format else data_path.suffix.lower().lstrip(".")
+     if data_path.is_dir() and not fmt:
+         fmt = "parquet"
+     if fmt in {"parquet", ""}:
+         return pd.read_parquet(data_path)
+     if fmt in {"csv", "txt"}:
+         # Use low_memory=False to avoid mixed-type DtypeWarning on wide CSVs
+         return pd.read_csv(data_path, low_memory=False)
+     raise ValueError(f"Unsupported data format: {data_path}")
+
+
+ def load_dataframes(file_paths: list[str], file_type: str) -> list[pd.DataFrame]:
+     return [read_table(fp, file_type) for fp in file_paths]
+
+
+ def iter_file_chunks(
+     file_path: str, file_type: str, chunk_size: int
+ ) -> Generator[pd.DataFrame, None, None]:
+     if file_type == "csv":
+         yield from pd.read_csv(file_path, chunksize=chunk_size)
+         return
+     parquet_file = pq.ParquetFile(file_path)
+     for batch in parquet_file.iter_batches(batch_size=chunk_size):
+         yield batch.to_pandas()
+
+
+ def default_output_dir(path: str) -> Path:
+     path_obj = Path(path)
+     if path_obj.is_file():
+         return path_obj.parent / f"{path_obj.stem}_preprocessed"
+     return path_obj.with_name(f"{path_obj.name}_preprocessed")
+
+
+ def read_yaml(path: str | Path):
+     with open(path, "r", encoding="utf-8") as file:
+         return yaml.safe_load(file) or {}
 
 
  def generate_ranking_data(
@@ -90,13 +174,14 @@ def generate_ranking_data(
      sequence_vocabs.append(seq_vocab)
 
      if "gender" in data and "dense_0" in data:
+         dense_1 = data.get("dense_1", 0)
          # Complex label generation with feature correlation
          label_probs = 1 / (
              1
              + np.exp(
                  -(
                      data["dense_0"] * 0.3
-                     + data["dense_1"] * 0.2
+                     + dense_1 * 0.2
                      + (data["gender"] - 0.5) * 0.5
                      + np.random.randn(n_samples) * 0.1
                  )
@@ -112,7 +197,7 @@ def generate_ranking_data(
      print(f"Positive rate: {data['label'].mean():.4f}")
 
      # Import here to avoid circular import
-     from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+     from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 
      # Create feature definitions
      # Use input_dim for dense features to be compatible with both simple and complex scenarios
@@ -273,7 +358,7 @@ def generate_match_data(
      print(f"Positive rate: {data['label'].mean():.4f}")
 
      # Import here to avoid circular import
-     from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+     from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 
      # User dense features
      user_dense_features = [DenseFeature(name="user_age", input_dim=1)]
@@ -413,15 +498,17 @@ def generate_multitask_data(
 
      # Generate multi-task labels with correlation
      # CTR (click) is relatively easier to predict
-     ctr_logits = (
-         data["dense_0"] * 0.3 + data["dense_1"] * 0.2 + np.random.randn(n_samples) * 0.5
-     )
+     dense_0 = data.get("dense_0", 0)
+     dense_1 = data.get("dense_1", 0)
+     dense_2 = data.get("dense_2", 0)
+     dense_3 = data.get("dense_3", 0)
+     ctr_logits = dense_0 * 0.3 + dense_1 * 0.2 + np.random.randn(n_samples) * 0.5
      data["click"] = (1 / (1 + np.exp(-ctr_logits)) > 0.5).astype(np.float32)
 
      # CVR (conversion) depends on click and is harder
      cvr_logits = (
-         data["dense_2"] * 0.2
-         + data["dense_3"] * 0.15
+         dense_2 * 0.2
+         + dense_3 * 0.15
          + data["click"] * 1.5  # Strong dependency on click
          + np.random.randn(n_samples) * 0.8
      )
@@ -441,7 +528,7 @@ def generate_multitask_data(
      print(f"CTCVR rate: {data['ctcvr'].mean():.4f}")
 
      # Import here to avoid circular import
-     from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+     from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 
      # Create feature definitions
      dense_features = [
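A short sketch of how the relocated I/O helpers in nextrec/utils/data.py compose; the directory path below is hypothetical:

# Hypothetical sketch: resolve a dataset directory, then load or stream it.
from nextrec.utils.data import resolve_file_paths, read_table, iter_file_chunks

# A directory must contain a single format; mixing CSV and Parquet raises ValueError.
file_paths, file_type = resolve_file_paths("./data/train")

# Eager load of one shard...
df = read_table(file_paths[0], file_type)

# ...or chunked streaming (pd.read_csv chunksize for CSV, pyarrow iter_batches for Parquet).
for chunk in iter_file_chunks(file_paths[0], file_type, chunk_size=10_000):
    print(chunk.shape)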
nextrec/utils/feature.py CHANGED
@@ -2,9 +2,13 @@
  Feature processing utilities for NextRec
 
  Date: create on 03/12/2025
+ Checkpoint: edit on 19/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
+ import numbers
+ from typing import Any
+
 
  def normalize_to_list(value: str | list[str] | None) -> list[str]:
      if value is None:
@@ -12,3 +16,14 @@ def normalize_to_list(value: str | list[str] | None) -> list[str]:
      if isinstance(value, str):
          return [value]
      return list(value)
+
+
+ def as_float(value: Any) -> float | None:
+     if isinstance(value, numbers.Number):
+         return float(value)
+     if hasattr(value, "item"):
+         try:
+             return float(value.item())
+         except Exception:
+             return None
+     return None
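And a sketch of the new as_float contract, with illustrative inputs:

# Illustrative inputs: numbers pass through, zero-dim tensors/arrays go via
# .item(), and anything non-numeric maps to None instead of raising.
import numpy as np
import torch
from nextrec.utils.feature import as_float

assert as_float(3) == 3.0                    # numbers.Number
assert as_float(np.float32(0.5)) == 0.5      # NumPy scalar
assert as_float(torch.tensor(1.25)) == 1.25  # zero-dim tensor via .item()
assert as_float("auc") is None               # non-numeric -> None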