nextrec-0.4.8-py3-none-any.whl → nextrec-0.4.10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +30 -15
- nextrec/basic/features.py +1 -0
- nextrec/basic/layers.py +6 -8
- nextrec/basic/loggers.py +14 -7
- nextrec/basic/metrics.py +6 -76
- nextrec/basic/model.py +316 -321
- nextrec/cli.py +185 -43
- nextrec/data/__init__.py +13 -16
- nextrec/data/batch_utils.py +3 -2
- nextrec/data/data_processing.py +10 -2
- nextrec/data/data_utils.py +9 -14
- nextrec/data/dataloader.py +31 -33
- nextrec/data/preprocessor.py +328 -255
- nextrec/loss/__init__.py +1 -5
- nextrec/loss/loss_utils.py +2 -8
- nextrec/models/generative/__init__.py +1 -8
- nextrec/models/generative/hstu.py +6 -4
- nextrec/models/multi_task/esmm.py +2 -2
- nextrec/models/multi_task/mmoe.py +2 -2
- nextrec/models/multi_task/ple.py +2 -2
- nextrec/models/multi_task/poso.py +2 -3
- nextrec/models/multi_task/share_bottom.py +2 -2
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -2
- nextrec/models/ranking/dcn.py +2 -2
- nextrec/models/ranking/dcn_v2.py +2 -2
- nextrec/models/ranking/deepfm.py +6 -7
- nextrec/models/ranking/dien.py +3 -3
- nextrec/models/ranking/din.py +3 -3
- nextrec/models/ranking/eulernet.py +365 -0
- nextrec/models/ranking/fibinet.py +5 -5
- nextrec/models/ranking/fm.py +3 -7
- nextrec/models/ranking/lr.py +120 -0
- nextrec/models/ranking/masknet.py +2 -2
- nextrec/models/ranking/pnn.py +2 -2
- nextrec/models/ranking/widedeep.py +2 -2
- nextrec/models/ranking/xdeepfm.py +2 -2
- nextrec/models/representation/__init__.py +9 -0
- nextrec/models/{generative → representation}/rqvae.py +9 -9
- nextrec/models/retrieval/__init__.py +0 -0
- nextrec/models/{match → retrieval}/dssm.py +8 -3
- nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
- nextrec/models/{match → retrieval}/mind.py +4 -3
- nextrec/models/{match → retrieval}/sdm.py +4 -3
- nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
- nextrec/utils/__init__.py +60 -46
- nextrec/utils/config.py +8 -7
- nextrec/utils/console.py +371 -0
- nextrec/utils/{synthetic_data.py → data.py} +102 -15
- nextrec/utils/feature.py +15 -0
- nextrec/utils/torch_utils.py +411 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/METADATA +6 -7
- nextrec-0.4.10.dist-info/RECORD +70 -0
- nextrec/utils/cli_utils.py +0 -58
- nextrec/utils/device.py +0 -78
- nextrec/utils/distributed.py +0 -141
- nextrec/utils/file.py +0 -92
- nextrec/utils/initializer.py +0 -79
- nextrec/utils/optimizer.py +0 -75
- nextrec/utils/tensor.py +0 -72
- nextrec-0.4.8.dist-info/RECORD +0 -71
- /nextrec/models/{match/__init__.py → ranking/ffm.py} +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/WHEEL +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/licenses/LICENSE +0 -0
nextrec/utils/console.py
ADDED
@@ -0,0 +1,371 @@
+"""
+Console and CLI utilities for NextRec.
+
+This module centralizes CLI logging helpers, progress display, and metric tables.
+
+Date: create on 19/12/2025
+Checkpoint: edit on 19/12/2025
+Author: Yang Zhou, zyaztec@gmail.com
+"""
+
+from __future__ import annotations
+
+import io
+import logging
+import numbers
+import os
+import platform
+import sys
+from datetime import datetime, timedelta
+from typing import Any, Callable, Iterable, Mapping, TypeVar
+
+import numpy as np
+from rich import box
+from rich.console import Console
+from rich.progress import (
+    BarColumn,
+    MofNCompleteColumn,
+    Progress,
+    SpinnerColumn,
+    TaskProgressColumn,
+    TextColumn,
+    TimeElapsedColumn,
+    TimeRemainingColumn,
+)
+from rich.table import Table
+from rich.text import Text
+
+from nextrec.utils.feature import as_float, normalize_to_list
+
+T = TypeVar("T")
+
+
+def get_nextrec_version() -> str:
+    """
+    Best-effort version resolver for NextRec.
+
+    Prefer in-repo `nextrec.__version__`, fall back to installed package metadata.
+    """
+    try:
+        from nextrec import __version__  # type: ignore
+
+        if __version__:
+            return str(__version__)
+    except Exception:
+        pass
+
+    try:
+        from importlib.metadata import version
+
+        return version("nextrec")
+    except Exception:
+        return "unknown"
+
+
+def log_startup_info(
+    logger: logging.Logger, *, mode: str, config_path: str | None
+) -> None:
+    """Log a short, user-friendly startup banner."""
+    version = get_nextrec_version()
+    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    lines = [
+        "NextRec CLI",
+        f"- Version: {version}",
+        f"- Time: {now}",
+        f"- Mode: {mode}",
+        f"- Config: {config_path or '(not set)'}",
+        f"- Python: {platform.python_version()} ({sys.executable})",
+        f"- Platform: {platform.system()} {platform.release()} ({platform.machine()})",
+        f"- Workdir: {os.getcwd()}",
+        f"- Command: {' '.join(sys.argv)}",
+    ]
+    for line in lines:
+        logger.info(line)
+
+
+class BlackTimeElapsedColumn(TimeElapsedColumn):
+    def render(self, task) -> Text:
+        elapsed = task.finished_time if task.finished else task.elapsed
+        if elapsed is None:
+            return Text("-:--:--", style="black")
+        delta = timedelta(seconds=max(0, int(elapsed)))
+        return Text(str(delta), style="black")
+
+
+class BlackTimeRemainingColumn(TimeRemainingColumn):
+    def render(self, task) -> Text:
+        if self.elapsed_when_finished and task.finished:
+            task_time = task.finished_time
+        else:
+            task_time = task.time_remaining
+
+        if task.total is None:
+            return Text("", style="black")
+
+        if task_time is None:
+            return Text("--:--" if self.compact else "-:--:--", style="black")
+
+        minutes, seconds = divmod(int(task_time), 60)
+        hours, minutes = divmod(minutes, 60)
+
+        if self.compact and not hours:
+            formatted = f"{minutes:02d}:{seconds:02d}"
+        else:
+            formatted = f"{hours:d}:{minutes:02d}:{seconds:02d}"
+
+        return Text(formatted, style="black")
+
+
+class BlackMofNCompleteColumn(MofNCompleteColumn):
+    def render(self, task) -> Text:
+        completed = int(task.completed)
+        total = int(task.total) if task.total is not None else "?"
+        total_width = len(str(total))
+        return Text(
+            f"{completed:{total_width}d}{self.separator}{total}",
+            style="black",
+        )
+
+
+def progress(
+    iterable: Iterable[T],
+    *,
+    description: str | None = None,
+    total: int | None = None,
+    disable: bool = False,
+) -> Iterable[T]:
+    if disable:
+        for item in iterable:
+            yield item
+        return
+    resolved_total = total
+    if resolved_total is None:
+        try:
+            resolved_total = len(iterable)  # type: ignore[arg-type]
+        except TypeError:
+            resolved_total = None
+
+    progress_bar = Progress(
+        SpinnerColumn(style="black"),
+        TextColumn("{task.description}", style="black"),
+        BarColumn(
+            bar_width=36, style="black", complete_style="black", finished_style="black"
+        ),
+        TaskProgressColumn(style="black"),
+        BlackMofNCompleteColumn(),
+        BlackTimeElapsedColumn(),
+        BlackTimeRemainingColumn(),
+        refresh_per_second=12,
+    )
+
+    task_id = progress_bar.add_task(description or "Working", total=resolved_total)
+    progress_bar.start()
+    try:
+        for item in iterable:
+            yield item
+            progress_bar.advance(task_id, 1)
+    finally:
+        progress_bar.stop()
+
+
+def group_metrics_by_task(
+    metrics: Mapping[str, Any] | None,
+    target_names: list[str] | str | None,
+    default_task_name: str = "overall",
+) -> tuple[list[str], dict[str, dict[str, float]]]:
+    if not metrics:
+        return [], {}
+
+    if isinstance(target_names, str):
+        target_names = [target_names]
+    if not isinstance(target_names, list) or not target_names:
+        target_names = [default_task_name]
+
+    targets_by_len = sorted(target_names, key=len, reverse=True)
+    grouped: dict[str, dict[str, float]] = {}
+    for key, raw_value in metrics.items():
+        value = as_float(raw_value)
+        if value is None:
+            continue
+
+        matched_target: str | None = None
+        metric_name = key
+        for target in targets_by_len:
+            suffix = f"_{target}"
+            if key.endswith(suffix):
+                metric_name = key[: -len(suffix)]
+                matched_target = target
+                break
+
+        if matched_target is None:
+            matched_target = (
+                target_names[0] if len(target_names) == 1 else default_task_name
+            )
+        grouped.setdefault(matched_target, {})[metric_name] = value
+
+    task_order: list[str] = []
+    for target in target_names:
+        if target in grouped:
+            task_order.append(target)
+    for task_name in grouped:
+        if task_name not in task_order:
+            task_order.append(task_name)
+    return task_order, grouped
+
+
+def display_metrics_table(
+    epoch: int,
+    epochs: int,
+    split: str,
+    loss: float | np.ndarray | None,
+    metrics: Mapping[str, Any] | None,
+    target_names: list[str] | str | None,
+    base_metrics: list[str] | None = None,
+    is_main_process: bool = True,
+    colorize: Callable[[str], str] | None = None,
+) -> None:
+    if not is_main_process:
+        return
+
+    target_list = normalize_to_list(target_names)
+    task_order, grouped = group_metrics_by_task(metrics, target_names=target_names)
+
+    if isinstance(loss, np.ndarray) and target_list:
+        # Ensure tasks with losses are shown even when metrics are missing for some targets.
+        normalized_order: list[str] = []
+        for name in target_list:
+            if name not in normalized_order:
+                normalized_order.append(name)
+        for name in task_order:
+            if name not in normalized_order:
+                normalized_order.append(name)
+        task_order = normalized_order
+
+    if Console is None or Table is None or box is None:
+        prefix = f"Epoch {epoch}/{epochs} - {split}:"
+        segments: list[str] = []
+        if isinstance(loss, numbers.Number):
+            segments.append(f"loss={float(loss):.4f}")
+        if task_order and grouped:
+            task_strs: list[str] = []
+            for task_name in task_order:
+                metric_items = grouped.get(task_name, {})
+                if not metric_items:
+                    continue
+                metric_str = ", ".join(
+                    f"{k}={float(v):.4f}" for k, v in metric_items.items()
+                )
+                task_strs.append(f"{task_name}[{metric_str}]")
+            if task_strs:
+                segments.append(", ".join(task_strs))
+        elif metrics:
+            metric_str = ", ".join(
+                f"{k}={float(v):.4f}"
+                for k, v in metrics.items()
+                if as_float(v) is not None
+            )
+            if metric_str:
+                segments.append(metric_str)
+        if not segments:
+            return
+        msg = f"{prefix} " + ", ".join(segments)
+        if colorize is not None:
+            msg = colorize(msg)
+        logging.info(msg)
+        return
+
+    title = f"Epoch {epoch}/{epochs} - {split}"
+    if isinstance(loss, numbers.Number):
+        title += f" (loss={float(loss):.4f})"
+
+    table = Table(
+        title=title,
+        box=box.ROUNDED,
+        header_style="bold",
+        title_style="bold",
+    )
+    table.add_column("Task", style="bold")
+
+    include_loss = isinstance(loss, np.ndarray)
+    if include_loss:
+        table.add_column("loss", justify="right")
+
+    metric_names: list[str] = []
+    for task_name in task_order:
+        for metric_name in grouped.get(task_name, {}):
+            if metric_name not in metric_names:
+                metric_names.append(metric_name)
+
+    preferred_order: list[str] = []
+    if isinstance(base_metrics, list):
+        preferred_order = [m for m in base_metrics if m in metric_names]
+    remaining = [m for m in metric_names if m not in preferred_order]
+    metric_names = preferred_order + sorted(remaining)
+
+    for metric_name in metric_names:
+        table.add_column(metric_name, justify="right")
+
+    def fmt(value: float | None) -> str:
+        if value is None:
+            return "-"
+        if np.isnan(value):
+            return "nan"
+        if np.isinf(value):
+            return "inf" if value > 0 else "-inf"
+        return f"{value:.4f}"
+
+    loss_by_task: dict[str, float] = {}
+    if isinstance(loss, np.ndarray):
+        if target_list:
+            for i, task_name in enumerate(target_list):
+                if i < loss.shape[0]:
+                    loss_by_task[task_name] = float(loss[i])
+            if "overall" in task_order and "overall" not in loss_by_task:
+                loss_by_task["overall"] = float(np.sum(loss))
+        elif task_order:
+            for i, task_name in enumerate(task_order):
+                if i < loss.shape[0]:
+                    loss_by_task[task_name] = float(loss[i])
+        else:
+            task_order = ["overall"]
+            loss_by_task["overall"] = float(np.sum(loss))
+
+    if not task_order:
+        task_order = ["__overall__"]
+
+    for task_name in task_order:
+        row: list[str] = [str(task_name)]
+        if include_loss:
+            row.append(fmt(loss_by_task.get(task_name)))
+        for metric_name in metric_names:
+            row.append(fmt(grouped.get(task_name, {}).get(metric_name)))
+        table.add_row(*row)
+
+    Console().print(table)
+
+    record_console = Console(file=io.StringIO(), record=True, width=120)
+    record_console.print(table)
+    table_text = record_console.export_text(styles=False).rstrip()
+
+    root_logger = logging.getLogger()
+    record = root_logger.makeRecord(
+        root_logger.name,
+        logging.INFO,
+        __file__,
+        0,
+        "[MetricsTable]\n" + table_text,
+        args=(),
+        exc_info=None,
+        extra=None,
+    )
+
+    emitted = False
+    for handler in root_logger.handlers:
+        if isinstance(handler, logging.FileHandler):
+            handler.emit(record)
+            emitted = True
+
+    if not emitted:
+        # Fallback: no file handlers configured, use standard logging.
+        root_logger.log(logging.INFO, "[MetricsTable]\n" + table_text)
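A minimal usage sketch of the helpers above (an annotation, not part of the diff): the metric names, targets, and values are illustrative, and the per-task keys follow the "<metric>_<target>" suffix convention that group_metrics_by_task() splits on.

# Illustrative example; metric values below are made up.
import logging

import numpy as np

from nextrec.utils.console import display_metrics_table, log_startup_info, progress

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

log_startup_info(logger, mode="train", config_path="config.yaml")

# progress() wraps any iterable; the total is inferred via len() when possible.
for _ in progress(range(100), description="Training"):
    pass  # a real train step would go here

display_metrics_table(
    epoch=1,
    epochs=10,
    split="val",
    loss=np.array([0.31, 0.52]),  # one loss entry per target
    metrics={"auc_click": 0.78, "auc_like": 0.71},  # "<metric>_<target>" keys
    target_names=["click", "like"],
)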
nextrec/utils/{synthetic_data.py → data.py}
RENAMED
@@ -1,17 +1,101 @@
 """
-
+Data utilities for NextRec.
 
-This module provides
-and tutorial purposes in the NextRec framework.
+This module provides file I/O helpers and synthetic data generation.
 
-Date: create on
+Date: create on 19/12/2025
+Checkpoint: edit on 19/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
-import
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Dict, Generator, List, Optional, Tuple
+
 import numpy as np
 import pandas as pd
-
+import pyarrow.parquet as pq
+import torch
+import yaml
+
+
+def resolve_file_paths(path: str) -> tuple[list[str], str]:
+    """
+    Resolve file or directory path into a sorted list of files and file type.
+
+    Args: path: Path to a file or directory
+    Returns: tuple: (list of file paths, file type)
+    """
+    path_obj = Path(path)
+
+    if path_obj.is_file():
+        file_type = path_obj.suffix.lower().lstrip(".")
+        assert file_type in [
+            "csv",
+            "parquet",
+        ], f"Unsupported file extension: {file_type}"
+        return [str(path_obj)], file_type
+
+    if path_obj.is_dir():
+        collected_files = [p for p in path_obj.iterdir() if p.is_file()]
+        csv_files = [str(p) for p in collected_files if p.suffix.lower() == ".csv"]
+        parquet_files = [
+            str(p) for p in collected_files if p.suffix.lower() == ".parquet"
+        ]
+
+        if csv_files and parquet_files:
+            raise ValueError(
+                "Directory contains both CSV and Parquet files. Please keep a single format."
+            )
+        file_paths = csv_files if csv_files else parquet_files
+        if not file_paths:
+            raise ValueError(f"No CSV or Parquet files found in directory: {path}")
+        file_paths.sort()
+        file_type = "csv" if csv_files else "parquet"
+        return file_paths, file_type
+
+    raise ValueError(f"Invalid path: {path}")
+
+
+def read_table(path: str | Path, data_format: str | None = None) -> pd.DataFrame:
+    data_path = Path(path)
+    fmt = data_format.lower() if data_format else data_path.suffix.lower().lstrip(".")
+    if data_path.is_dir() and not fmt:
+        fmt = "parquet"
+    if fmt in {"parquet", ""}:
+        return pd.read_parquet(data_path)
+    if fmt in {"csv", "txt"}:
+        # Use low_memory=False to avoid mixed-type DtypeWarning on wide CSVs
+        return pd.read_csv(data_path, low_memory=False)
+    raise ValueError(f"Unsupported data format: {data_path}")
+
+
+def load_dataframes(file_paths: list[str], file_type: str) -> list[pd.DataFrame]:
+    return [read_table(fp, file_type) for fp in file_paths]
+
+
+def iter_file_chunks(
+    file_path: str, file_type: str, chunk_size: int
+) -> Generator[pd.DataFrame, None, None]:
+    if file_type == "csv":
+        yield from pd.read_csv(file_path, chunksize=chunk_size)
+        return
+    parquet_file = pq.ParquetFile(file_path)
+    for batch in parquet_file.iter_batches(batch_size=chunk_size):
+        yield batch.to_pandas()
+
+
+def default_output_dir(path: str) -> Path:
+    path_obj = Path(path)
+    if path_obj.is_file():
+        return path_obj.parent / f"{path_obj.stem}_preprocessed"
+    return path_obj.with_name(f"{path_obj.name}_preprocessed")
+
+
+def read_yaml(path: str | Path):
+    with open(path, "r", encoding="utf-8") as file:
+        return yaml.safe_load(file) or {}
 
 
 def generate_ranking_data(
@@ -90,13 +174,14 @@ def generate_ranking_data(
         sequence_vocabs.append(seq_vocab)
 
     if "gender" in data and "dense_0" in data:
+        dense_1 = data.get("dense_1", 0)
         # Complex label generation with feature correlation
         label_probs = 1 / (
             1
             + np.exp(
                 -(
                     data["dense_0"] * 0.3
-                    +
+                    + dense_1 * 0.2
                     + (data["gender"] - 0.5) * 0.5
                     + np.random.randn(n_samples) * 0.1
                 )
@@ -112,7 +197,7 @@ def generate_ranking_data(
     print(f"Positive rate: {data['label'].mean():.4f}")
 
     # Import here to avoid circular import
-    from nextrec.basic.features import DenseFeature,
+    from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 
     # Create feature definitions
     # Use input_dim for dense features to be compatible with both simple and complex scenarios
@@ -273,7 +358,7 @@ def generate_match_data(
     print(f"Positive rate: {data['label'].mean():.4f}")
 
     # Import here to avoid circular import
-    from nextrec.basic.features import DenseFeature,
+    from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 
     # User dense features
    user_dense_features = [DenseFeature(name="user_age", input_dim=1)]
@@ -413,15 +498,17 @@ def generate_multitask_data(
 
     # Generate multi-task labels with correlation
     # CTR (click) is relatively easier to predict
-
-
-    )
+    dense_0 = data.get("dense_0", 0)
+    dense_1 = data.get("dense_1", 0)
+    dense_2 = data.get("dense_2", 0)
+    dense_3 = data.get("dense_3", 0)
+    ctr_logits = dense_0 * 0.3 + dense_1 * 0.2 + np.random.randn(n_samples) * 0.5
     data["click"] = (1 / (1 + np.exp(-ctr_logits)) > 0.5).astype(np.float32)
 
     # CVR (conversion) depends on click and is harder
     cvr_logits = (
-
-        +
+        dense_2 * 0.2
+        + dense_3 * 0.15
         + data["click"] * 1.5  # Strong dependency on click
         + np.random.randn(n_samples) * 0.8
     )
@@ -441,7 +528,7 @@ def generate_multitask_data(
     print(f"CTCVR rate: {data['ctcvr'].mean():.4f}")
 
     # Import here to avoid circular import
-    from nextrec.basic.features import DenseFeature,
+    from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 
     # Create feature definitions
     dense_features = [
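A usage sketch of the I/O helpers consolidated into data.py (an annotation, not part of the diff); the ./interactions path is an assumed example.

# Illustrative example; ./interactions is a hypothetical data directory.
from nextrec.utils.data import iter_file_chunks, read_table, resolve_file_paths

# A directory must hold a single format: all CSV or all Parquet.
file_paths, file_type = resolve_file_paths("./interactions")

# Stream each file in chunks instead of loading it whole.
for path in file_paths:
    for chunk in iter_file_chunks(path, file_type, chunk_size=100_000):
        print(path, chunk.shape)

# Or read one table in a single call; the format falls back to the file suffix.
df = read_table(file_paths[0], data_format=file_type)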
nextrec/utils/feature.py
CHANGED
@@ -2,9 +2,13 @@
 Feature processing utilities for NextRec
 
 Date: create on 03/12/2025
+Checkpoint: edit on 19/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
+import numbers
+from typing import Any
+
 
 def normalize_to_list(value: str | list[str] | None) -> list[str]:
     if value is None:
@@ -12,3 +16,14 @@ def normalize_to_list(value: str | list[str] | None) -> list[str]:
     if isinstance(value, str):
         return [value]
     return list(value)
+
+
+def as_float(value: Any) -> float | None:
+    if isinstance(value, numbers.Number):
+        return float(value)
+    if hasattr(value, "item"):
+        try:
+            return float(value.item())
+        except Exception:
+            return None
+    return None
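A usage sketch of the feature.py helpers (an annotation, not part of the diff); the torch tensor is only there to show the .item() unwrapping path, and torch is already a dependency per the data.py imports above.

# Illustrative example of the visible behavior of both helpers.
import torch

from nextrec.utils.feature import as_float, normalize_to_list

assert normalize_to_list("click") == ["click"]
assert normalize_to_list(["click", "like"]) == ["click", "like"]

assert as_float(0.5) == 0.5
assert as_float(torch.tensor(0.5)) == 0.5  # unwrapped via .item()
assert as_float("not a number") is None    # neither a Number nor has .item()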