nextrec 0.4.8__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/callback.py +30 -15
  3. nextrec/basic/features.py +1 -0
  4. nextrec/basic/layers.py +6 -8
  5. nextrec/basic/loggers.py +14 -7
  6. nextrec/basic/metrics.py +6 -76
  7. nextrec/basic/model.py +316 -321
  8. nextrec/cli.py +185 -43
  9. nextrec/data/__init__.py +13 -16
  10. nextrec/data/batch_utils.py +3 -2
  11. nextrec/data/data_processing.py +10 -2
  12. nextrec/data/data_utils.py +9 -14
  13. nextrec/data/dataloader.py +31 -33
  14. nextrec/data/preprocessor.py +328 -255
  15. nextrec/loss/__init__.py +1 -5
  16. nextrec/loss/loss_utils.py +2 -8
  17. nextrec/models/generative/__init__.py +1 -8
  18. nextrec/models/generative/hstu.py +6 -4
  19. nextrec/models/multi_task/esmm.py +2 -2
  20. nextrec/models/multi_task/mmoe.py +2 -2
  21. nextrec/models/multi_task/ple.py +2 -2
  22. nextrec/models/multi_task/poso.py +2 -3
  23. nextrec/models/multi_task/share_bottom.py +2 -2
  24. nextrec/models/ranking/afm.py +2 -2
  25. nextrec/models/ranking/autoint.py +2 -2
  26. nextrec/models/ranking/dcn.py +2 -2
  27. nextrec/models/ranking/dcn_v2.py +2 -2
  28. nextrec/models/ranking/deepfm.py +6 -7
  29. nextrec/models/ranking/dien.py +3 -3
  30. nextrec/models/ranking/din.py +3 -3
  31. nextrec/models/ranking/eulernet.py +365 -0
  32. nextrec/models/ranking/fibinet.py +5 -5
  33. nextrec/models/ranking/fm.py +3 -7
  34. nextrec/models/ranking/lr.py +120 -0
  35. nextrec/models/ranking/masknet.py +2 -2
  36. nextrec/models/ranking/pnn.py +2 -2
  37. nextrec/models/ranking/widedeep.py +2 -2
  38. nextrec/models/ranking/xdeepfm.py +2 -2
  39. nextrec/models/representation/__init__.py +9 -0
  40. nextrec/models/{generative → representation}/rqvae.py +9 -9
  41. nextrec/models/retrieval/__init__.py +0 -0
  42. nextrec/models/{match → retrieval}/dssm.py +8 -3
  43. nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
  44. nextrec/models/{match → retrieval}/mind.py +4 -3
  45. nextrec/models/{match → retrieval}/sdm.py +4 -3
  46. nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
  47. nextrec/utils/__init__.py +60 -46
  48. nextrec/utils/config.py +8 -7
  49. nextrec/utils/console.py +371 -0
  50. nextrec/utils/{synthetic_data.py → data.py} +102 -15
  51. nextrec/utils/feature.py +15 -0
  52. nextrec/utils/torch_utils.py +411 -0
  53. {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/METADATA +6 -7
  54. nextrec-0.4.10.dist-info/RECORD +70 -0
  55. nextrec/utils/cli_utils.py +0 -58
  56. nextrec/utils/device.py +0 -78
  57. nextrec/utils/distributed.py +0 -141
  58. nextrec/utils/file.py +0 -92
  59. nextrec/utils/initializer.py +0 -79
  60. nextrec/utils/optimizer.py +0 -75
  61. nextrec/utils/tensor.py +0 -72
  62. nextrec-0.4.8.dist-info/RECORD +0 -71
  63. /nextrec/models/{match/__init__.py → ranking/ffm.py} +0 -0
  64. {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/WHEEL +0 -0
  65. {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/entry_points.txt +0 -0
  66. {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/licenses/LICENSE +0 -0
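Note on the diffs below: the most user-visible changes in this release are (1) the RecDataLoader API now takes a streaming flag in place of load_full, and (2) predict mode in the CLI is now driven primarily by checkpoint_path, with the session id, model_config.yaml, the *.pt checkpoint file, and the output location all resolved relative to the checkpoint directory. A minimal sketch of that predict-side resolution order, written as plain Python against a hypothetical config dict (the paths and names here are illustrative, not values from the package):

from pathlib import Path

# Hypothetical predict config, as nextrec's read_yaml would return it.
cfg = {
    "checkpoint_path": "./nextrec_logs/demo_session",
    "predict": {"data_path": "test.parquet", "name": "pred", "save_data_format": "csv"},
}

session_dir = Path(cfg["checkpoint_path"])
# Session id falls back to the checkpoint directory name (see predict_model below).
session_id = (cfg.get("session", {}) or {}).get("id") or session_dir.name
# model_config.yaml is auto-discovered inside the checkpoint directory when not set
# explicitly (the fallback to the config directory is omitted in this sketch).
model_cfg_path = Path(cfg["model_config"]) if "model_config" in cfg else session_dir / "model_config.yaml"

predict_cfg = cfg["predict"]
save_format = predict_cfg.get("save_data_format", predict_cfg.get("save_format", "csv"))
save_name = f"{predict_cfg.get('name', 'pred')}.{save_format}"
# Predictions end up under {checkpoint_path}/predictions/{name}.{format}.
print(session_id, model_cfg_path, session_dir / "predictions" / save_name)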
nextrec/cli.py CHANGED
@@ -18,10 +18,10 @@ Checkpoint: edit on 18/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
- import sys
  import argparse
  import logging
  import pickle
+ import sys
  import time
  from pathlib import Path
  from typing import Any, Dict, List
@@ -29,6 +29,7 @@ from typing import Any, Dict, List
  import pandas as pd
 
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+ from nextrec.basic.loggers import colorize, format_kv, setup_logger
  from nextrec.data.data_utils import split_dict_random
  from nextrec.data.dataloader import RecDataLoader
  from nextrec.data.preprocessor import DataProcessor
@@ -39,22 +40,29 @@ from nextrec.utils.config import (
  resolve_path,
  select_features,
  )
- from nextrec.utils.feature import normalize_to_list
- from nextrec.utils.file import (
+ from nextrec.utils.console import get_nextrec_version
+ from nextrec.utils.data import (
  iter_file_chunks,
  read_table,
  read_yaml,
  resolve_file_paths,
  )
- from nextrec.utils.cli_utils import (
- get_nextrec_version,
- log_startup_info,
- )
- from nextrec.basic.loggers import setup_logger
+ from nextrec.utils.feature import normalize_to_list
 
  logger = logging.getLogger(__name__)
 
 
+ def log_cli_section(title: str) -> None:
+ logger.info("")
+ logger.info(colorize(f"[{title}]", color="bright_blue", bold=True))
+ logger.info(colorize("-" * 80, color="bright_blue"))
+
+
+ def log_kv_lines(items: list[tuple[str, Any]]) -> None:
+ for label, value in items:
+ logger.info(format_kv(label, value))
+
+
  def train_model(train_config_path: str) -> None:
  """
  Train a NextRec model using the provided configuration file.
@@ -77,8 +85,17 @@ def train_model(train_config_path: str) -> None:
  artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
  session_dir = artifact_root / session_id
  setup_logger(session_id=session_id)
- logger.info(
- f"[NextRec CLI] Training start | version={get_nextrec_version()} | session_id={session_id} | artifacts={session_dir.resolve()}"
+
+ log_cli_section("CLI")
+ log_kv_lines(
+ [
+ ("Mode", "train"),
+ ("Version", get_nextrec_version()),
+ ("Session ID", session_id),
+ ("Artifacts", session_dir.resolve()),
+ ("Config", config_file.resolve()),
+ ("Command", " ".join(sys.argv)),
+ ]
  )
 
  processor_path = session_dir / "processor.pkl"
@@ -105,11 +122,53 @@ def train_model(train_config_path: str) -> None:
  cfg.get("model_config", "model_config.yaml"), config_dir
  )
 
+ log_cli_section("Config")
+ log_kv_lines(
+ [
+ ("Train config", config_file.resolve()),
+ ("Feature config", feature_cfg_path),
+ ("Model config", model_cfg_path),
+ ]
+ )
+
  feature_cfg = read_yaml(feature_cfg_path)
  model_cfg = read_yaml(model_cfg_path)
 
+ # Extract id_column from data config for GAUC metrics
+ id_column = data_cfg.get("id_column") or data_cfg.get("user_id_column")
+ id_columns = [id_column] if id_column else []
+
+ log_cli_section("Data")
+ log_kv_lines(
+ [
+ ("Data path", data_path),
+ ("Format", data_cfg.get("format", "auto")),
+ ("Streaming", streaming),
+ ("Target", target),
+ ("ID column", id_column or "(not set)"),
+ ]
+ )
+ if data_cfg.get("valid_ratio") is not None:
+ logger.info(format_kv("Valid ratio", data_cfg.get("valid_ratio")))
+ if data_cfg.get("val_path") or data_cfg.get("valid_path"):
+ logger.info(
+ format_kv(
+ "Validation path",
+ resolve_path(
+ data_cfg.get("val_path") or data_cfg.get("valid_path"), config_dir
+ ),
+ )
+ )
+
  if streaming:
  file_paths, file_type = resolve_file_paths(str(data_path))
+ log_kv_lines(
+ [
+ ("File type", file_type),
+ ("Files", len(file_paths)),
+ ("Chunk size", dataloader_chunk_size),
+ ]
+ )
  first_file = file_paths[0]
  first_chunk_size = max(1, min(dataloader_chunk_size, 1000))
  chunk_iter = iter_file_chunks(first_file, file_type, first_chunk_size)
@@ -121,14 +180,12 @@ def train_model(train_config_path: str) -> None:
 
  else:
  df = read_table(data_path, data_cfg.get("format"))
+ logger.info(format_kv("Rows", len(df)))
+ logger.info(format_kv("Columns", len(df.columns)))
  df_columns = list(df.columns)
 
  dense_names, sparse_names, sequence_names = select_features(feature_cfg, df_columns)
 
- # Extract id_column from data config for GAUC metrics
- id_column = data_cfg.get("id_column") or data_cfg.get("user_id_column")
- id_columns = [id_column] if id_column else []
-
  used_columns = dense_names + sparse_names + sequence_names + target + id_columns
 
  # keep order but drop duplicates
@@ -144,6 +201,17 @@ def train_model(train_config_path: str) -> None:
  processor, feature_cfg, dense_names, sparse_names, sequence_names
  )
 
+ log_cli_section("Features")
+ log_kv_lines(
+ [
+ ("Dense features", len(dense_names)),
+ ("Sparse features", len(sparse_names)),
+ ("Sequence features", len(sequence_names)),
+ ("Targets", len(target)),
+ ("Used columns", len(unique_used_columns)),
+ ]
+ )
+
  if streaming:
  processor.fit(str(data_path), chunk_size=dataloader_chunk_size)
  processed = None
@@ -247,7 +315,7 @@ def train_model(train_config_path: str) -> None:
  data=train_stream_source,
  batch_size=dataloader_cfg.get("train_batch_size", 512),
  shuffle=dataloader_cfg.get("train_shuffle", True),
- load_full=False,
+ streaming=True,
  chunk_size=dataloader_chunk_size,
  num_workers=dataloader_cfg.get("num_workers", 0),
  )
@@ -258,7 +326,7 @@ def train_model(train_config_path: str) -> None:
  data=str(val_data_resolved),
  batch_size=dataloader_cfg.get("valid_batch_size", 512),
  shuffle=dataloader_cfg.get("valid_shuffle", False),
- load_full=False,
+ streaming=True,
  chunk_size=dataloader_chunk_size,
  num_workers=dataloader_cfg.get("num_workers", 0),
  )
@@ -267,7 +335,7 @@ def train_model(train_config_path: str) -> None:
  data=streaming_valid_files,
  batch_size=dataloader_cfg.get("valid_batch_size", 512),
  shuffle=dataloader_cfg.get("valid_shuffle", False),
- load_full=False,
+ streaming=True,
  chunk_size=dataloader_chunk_size,
  num_workers=dataloader_cfg.get("num_workers", 0),
  )
@@ -298,6 +366,15 @@ def train_model(train_config_path: str) -> None:
  device,
  )
 
+ log_cli_section("Model")
+ log_kv_lines(
+ [
+ ("Model", model.__class__.__name__),
+ ("Device", device),
+ ("Session ID", session_id),
+ ]
+ )
+
  model.compile(
  optimizer=train_cfg.get("optimizer", "adam"),
  optimizer_params=train_cfg.get("optimizer_params", {}),
@@ -328,13 +405,30 @@ def predict_model(predict_config_path: str) -> None:
  config_dir = config_file.resolve().parent
  cfg = read_yaml(config_file)
 
- session_cfg = cfg.get("session", {}) or {}
- session_id = session_cfg.get("id", "masknet_tutorial")
- artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
- session_dir = Path(cfg.get("checkpoint_path") or (artifact_root / session_id))
+ # Checkpoint path is the primary configuration
+ if "checkpoint_path" not in cfg:
+ session_cfg = cfg.get("session", {}) or {}
+ session_id = session_cfg.get("id", "nextrec_session")
+ artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
+ session_dir = artifact_root / session_id
+ else:
+ session_dir = Path(cfg["checkpoint_path"])
+ # Auto-infer session_id from checkpoint directory name
+ session_cfg = cfg.get("session", {}) or {}
+ session_id = session_cfg.get("id") or session_dir.name
+
  setup_logger(session_id=session_id)
- logger.info(
- f"[NextRec CLI] Predict start | version={get_nextrec_version()} | session_id={session_id} | checkpoint={session_dir.resolve()}"
+
+ log_cli_section("CLI")
+ log_kv_lines(
+ [
+ ("Mode", "predict"),
+ ("Version", get_nextrec_version()),
+ ("Session ID", session_id),
+ ("Checkpoint", session_dir.resolve()),
+ ("Config", config_file.resolve()),
+ ("Command", " ".join(sys.argv)),
+ ]
  )
 
  processor_path = Path(session_dir / "processor.pkl")
@@ -342,24 +436,38 @@ def predict_model(predict_config_path: str) -> None:
  processor_path = session_dir / "processor" / "processor.pkl"
 
  predict_cfg = cfg.get("predict", {}) or {}
- model_cfg_path = resolve_path(
- cfg.get("model_config", "model_config.yaml"), config_dir
- )
- # feature_cfg_path = resolve_path(
- # cfg.get("feature_config", "feature_config.yaml"), config_dir
- # )
+
+ # Auto-find model_config in checkpoint directory if not specified
+ if "model_config" in cfg:
+ model_cfg_path = resolve_path(cfg["model_config"], config_dir)
+ else:
+ # Try to find model_config.yaml in checkpoint directory
+ auto_model_cfg = session_dir / "model_config.yaml"
+ if auto_model_cfg.exists():
+ model_cfg_path = auto_model_cfg
+ else:
+ # Fallback to config directory
+ model_cfg_path = resolve_path("model_config.yaml", config_dir)
 
  model_cfg = read_yaml(model_cfg_path)
- # feature_cfg = read_yaml(feature_cfg_path)
  model_cfg.setdefault("session_id", session_id)
  model_cfg.setdefault("params", {})
 
+ log_cli_section("Config")
+ log_kv_lines(
+ [
+ ("Predict config", config_file.resolve()),
+ ("Model config", model_cfg_path),
+ ("Processor", processor_path),
+ ]
+ )
+
  processor = DataProcessor.load(processor_path)
 
  # Load checkpoint and ensure required parameters are passed
  checkpoint_base = Path(session_dir)
  if checkpoint_base.is_dir():
- candidates = sorted(checkpoint_base.glob("*.model"))
+ candidates = sorted(checkpoint_base.glob("*.pt"))
  if not candidates:
  raise FileNotFoundError(
  f"[NextRec CLI Error]: Unable to find model checkpoint: {checkpoint_base}"
@@ -368,7 +476,7 @@ def predict_model(predict_config_path: str) -> None:
  config_dir_for_features = checkpoint_base
  else:
  model_file = (
- checkpoint_base.with_suffix(".model")
+ checkpoint_base.with_suffix(".pt")
  if checkpoint_base.suffix == ""
  else checkpoint_base
  )
@@ -418,40 +526,78 @@ def predict_model(predict_config_path: str) -> None:
  id_columns = [predict_cfg["id_column"]]
  model.id_columns = id_columns
 
+ effective_id_columns = id_columns or model.id_columns
+ log_cli_section("Features")
+ log_kv_lines(
+ [
+ ("Dense features", len(dense_features)),
+ ("Sparse features", len(sparse_features)),
+ ("Sequence features", len(sequence_features)),
+ ("Targets", len(target_cols)),
+ ("ID columns", len(effective_id_columns)),
+ ]
+ )
+
+ log_cli_section("Model")
+ log_kv_lines(
+ [
+ ("Model", model.__class__.__name__),
+ ("Checkpoint", model_file),
+ ("Device", predict_cfg.get("device", "cpu")),
+ ]
+ )
+
  rec_dataloader = RecDataLoader(
  dense_features=model.dense_features,
  sparse_features=model.sparse_features,
  sequence_features=model.sequence_features,
  target=None,
- id_columns=id_columns or model.id_columns,
+ id_columns=effective_id_columns,
  processor=processor,
  )
 
  data_path = resolve_path(predict_cfg["data_path"], config_dir)
  batch_size = predict_cfg.get("batch_size", 512)
 
+ log_cli_section("Data")
+ log_kv_lines(
+ [
+ ("Data path", data_path),
+ ("Format", predict_cfg.get("source_data_format", predict_cfg.get("data_format", "auto"))),
+ ("Batch size", batch_size),
+ ("Chunk size", predict_cfg.get("chunk_size", 20000)),
+ ("Streaming", predict_cfg.get("streaming", True)),
+ ]
+ )
+ logger.info("")
  pred_loader = rec_dataloader.create_dataloader(
  data=str(data_path),
  batch_size=batch_size,
  shuffle=False,
- load_full=predict_cfg.get("load_full", False),
+ streaming=predict_cfg.get("streaming", True),
  chunk_size=predict_cfg.get("chunk_size", 20000),
  )
 
- output_path = resolve_path(predict_cfg["output_path"], config_dir)
- output_path.parent.mkdir(parents=True, exist_ok=True)
+ # Build output path: {checkpoint_path}/predictions/{name}.{save_data_format}
+ save_format = predict_cfg.get("save_data_format", predict_cfg.get("save_format", "csv"))
+ pred_name = predict_cfg.get("name", "pred")
+ # Pass filename with extension to let model.predict handle path resolution
+ save_path = f"{pred_name}.{save_format}"
 
  start = time.time()
- model.predict(
+ logger.info("")
+ result = model.predict(
  data=pred_loader,
  batch_size=batch_size,
  include_ids=bool(id_columns),
  return_dataframe=False,
- save_path=output_path,
- save_format=predict_cfg.get("save_format", "csv"),
+ save_path=save_path,
+ save_format=save_format,
  num_workers=predict_cfg.get("num_workers", 0),
  )
  duration = time.time() - start
+ # When return_dataframe=False, result is the actual file path
+ output_path = result if isinstance(result, Path) else checkpoint_base / "predictions" / save_path
  logger.info(f"Prediction completed, results saved to: {output_path}")
  logger.info(f"Total time: {duration:.2f} seconds")
 
@@ -495,8 +641,6 @@ Examples:
  parser.add_argument("--predict_config", help="Prediction configuration file path")
  args = parser.parse_args()
 
- logger.info(get_nextrec_version())
-
  if not args.mode:
  parser.error("[NextRec CLI Error] --mode is required (train|predict)")
 
@@ -504,13 +648,11 @@ Examples:
  config_path = args.train_config
  if not config_path:
  parser.error("[NextRec CLI Error] train mode requires --train_config")
- log_startup_info(logger, mode="train", config_path=config_path)
  train_model(config_path)
  else:
  config_path = args.predict_config
  if not config_path:
  parser.error("[NextRec CLI Error] predict mode requires --predict_config")
- log_startup_info(logger, mode="predict", config_path=config_path)
  predict_model(config_path)
 
 
nextrec/data/__init__.py CHANGED
@@ -1,29 +1,26 @@
- from nextrec.data.batch_utils import collate_fn, batch_to_dict, stack_section
+ from nextrec.basic.features import FeatureSet
+ from nextrec.data import data_utils
+ from nextrec.data.batch_utils import batch_to_dict, collate_fn, stack_section
  from nextrec.data.data_processing import (
- get_column_data,
- split_dict_random,
  build_eval_candidates,
+ get_column_data,
  get_user_ids,
+ split_dict_random,
  )
-
- from nextrec.utils.file import (
- resolve_file_paths,
- iter_file_chunks,
- read_table,
- load_dataframes,
- default_output_dir,
- )
-
  from nextrec.data.dataloader import (
- TensorDictDataset,
  FileDataset,
  RecDataLoader,
+ TensorDictDataset,
  build_tensors_from_data,
  )
-
  from nextrec.data.preprocessor import DataProcessor
- from nextrec.basic.features import FeatureSet
- from nextrec.data import data_utils
+ from nextrec.utils.data import (
+ default_output_dir,
+ iter_file_chunks,
+ load_dataframes,
+ read_table,
+ resolve_file_paths,
+ )
 
  __all__ = [
  # Batch utilities
nextrec/data/batch_utils.py CHANGED
@@ -5,10 +5,11 @@ Date: create on 03/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
- import torch
- import numpy as np
  from typing import Any, Mapping
 
+ import numpy as np
+ import torch
+
 
  def stack_section(batch: list[dict], section: str):
  entries = [item.get(section) for item in batch if item.get(section) is not None]
nextrec/data/data_processing.py CHANGED
@@ -2,13 +2,16 @@
  Data processing utilities for NextRec
 
  Date: create on 03/12/2025
+ Checkpoint: edit on 19/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
- import torch
+ import hashlib
+ from typing import Any
+
  import numpy as np
  import pandas as pd
- from typing import Any
+ import torch
 
 
  def get_column_data(data: dict | pd.DataFrame, name: str):
@@ -166,3 +169,8 @@ def get_user_ids(
  return arr.reshape(arr.shape[0])
 
  return None
+
+
+ def hash_md5_mod(value: str, hash_size: int) -> int:
+ digest = hashlib.md5(value.encode("utf-8")).digest()
+ return int.from_bytes(digest, byteorder="big", signed=False) % hash_size
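The newly added hash_md5_mod helper maps a string value to a stable bucket index, the kind of utility typically used to hash high-cardinality categorical features into a fixed-size vocabulary. A self-contained usage sketch (the input value and bucket count below are arbitrary examples, not values from the package):

import hashlib

def hash_md5_mod(value: str, hash_size: int) -> int:
    # Same logic as the helper added above: MD5 digest reduced modulo hash_size.
    digest = hashlib.md5(value.encode("utf-8")).digest()
    return int.from_bytes(digest, byteorder="big", signed=False) % hash_size

# Example: bucket an arbitrary user id into one of 2**20 slots; the result is
# deterministic across processes and runs, unlike Python's built-in hash().
print(hash_md5_mod("user_12345", 1 << 20))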
nextrec/data/data_utils.py CHANGED
@@ -1,30 +1,25 @@
  """
- Data processing utilities for NextRec (Refactored)
-
- This module now re-exports functions from specialized submodules:
- - batch_utils: collate_fn, batch_to_dict
- - data_processing: get_column_data, split_dict_random, build_eval_candidates, get_user_ids
- - nextrec.utils.file_utils: resolve_file_paths, iter_file_chunks, read_table, load_dataframes, default_output_dir
+ Data processing utilities for NextRec
 
  Date: create on 27/10/2025
- Last update: 03/12/2025 (refactored)
+ Last update: 19/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
  # Import from new organized modules
- from nextrec.data.batch_utils import collate_fn, batch_to_dict, stack_section
+ from nextrec.data.batch_utils import batch_to_dict, collate_fn, stack_section
  from nextrec.data.data_processing import (
- get_column_data,
- split_dict_random,
  build_eval_candidates,
+ get_column_data,
  get_user_ids,
+ split_dict_random,
  )
- from nextrec.utils.file import (
- resolve_file_paths,
+ from nextrec.utils.data import (
+ default_output_dir,
  iter_file_chunks,
- read_table,
  load_dataframes,
- default_output_dir,
+ read_table,
+ resolve_file_paths,
  )
 
  __all__ = [
nextrec/data/dataloader.py CHANGED
@@ -2,33 +2,32 @@
  Dataloader definitions
 
  Date: create on 27/10/2025
- Checkpoint: edit on 02/12/2025
+ Checkpoint: edit on 19/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """
 
- import os
- import torch
  import logging
+ import os
+ from pathlib import Path
+ from typing import cast
+
  import numpy as np
  import pandas as pd
  import pyarrow.parquet as pq
-
- from pathlib import Path
- from typing import cast
+ import torch
+ from torch.utils.data import DataLoader, Dataset, IterableDataset
 
  from nextrec.basic.features import (
  DenseFeature,
- SparseFeature,
- SequenceFeature,
  FeatureSet,
+ SequenceFeature,
+ SparseFeature,
  )
- from nextrec.data.preprocessor import DataProcessor
- from torch.utils.data import DataLoader, Dataset, IterableDataset
-
- from nextrec.utils.tensor import to_tensor
- from nextrec.utils.file import resolve_file_paths, read_table
  from nextrec.data.batch_utils import collate_fn
  from nextrec.data.data_processing import get_column_data
+ from nextrec.data.preprocessor import DataProcessor
+ from nextrec.utils.data import read_table, resolve_file_paths
+ from nextrec.utils.torch_utils import to_tensor
 
 
  class TensorDictDataset(Dataset):
@@ -103,9 +102,8 @@ class FileDataset(FeatureSet, IterableDataset):
  self.current_file_index = 0
  for file_path in self.file_paths:
  self.current_file_index += 1
- if self.total_files == 1:
- file_name = os.path.basename(file_path)
- logging.info(f"Processing file: {file_name}")
+ # Don't log file processing here to avoid interrupting progress bars
+ # File information is already displayed in the CLI data section
  if self.file_type == "csv":
  yield from self.read_csv_chunks(file_path)
  elif self.file_type == "parquet":
@@ -191,7 +189,7 @@ class RecDataLoader(FeatureSet):
  ),
  batch_size: int = 32,
  shuffle: bool = True,
- load_full: bool = True,
+ streaming: bool = False,
  chunk_size: int = 10000,
  num_workers: int = 0,
  sampler=None,
@@ -203,7 +201,7 @@
  data: Data source, can be a dict, pd.DataFrame, file path (str), or existing DataLoader.
  batch_size: Batch size for DataLoader.
  shuffle: Whether to shuffle the data (ignored in streaming mode).
- load_full: If True, load full data into memory; if False, use streaming mode for large files.
+ streaming: If True, use streaming mode for large files; if False, load full data into memory.
  chunk_size: Chunk size for streaming mode (number of rows per chunk).
  num_workers: Number of worker processes for data loading.
  sampler: Optional sampler for DataLoader, only used for distributed training.
@@ -218,7 +216,7 @@
  path=data,
  batch_size=batch_size,
  shuffle=shuffle,
- load_full=load_full,
+ streaming=streaming,
  chunk_size=chunk_size,
  num_workers=num_workers,
  )
@@ -231,7 +229,7 @@
  path=data,
  batch_size=batch_size,
  shuffle=shuffle,
- load_full=load_full,
+ streaming=streaming,
  chunk_size=chunk_size,
  num_workers=num_workers,
  )
@@ -291,7 +289,7 @@
  path: str | os.PathLike | list[str] | list[os.PathLike],
  batch_size: int,
  shuffle: bool,
- load_full: bool,
+ streaming: bool,
  chunk_size: int = 10000,
  num_workers: int = 0,
  ) -> DataLoader:
@@ -312,8 +310,17 @@
  f"[RecDataLoader Error] Unsupported file extension in list: {suffix}"
  )
  file_type = "csv" if suffix == ".csv" else "parquet"
+ if streaming:
+ return self.load_files_streaming(
+ file_paths,
+ file_type,
+ batch_size,
+ chunk_size,
+ shuffle,
+ num_workers=num_workers,
+ )
  # Load full data into memory
- if load_full:
+ else:
  dfs = []
  total_bytes = 0
  for file_path in file_paths:
@@ -326,26 +333,17 @@
  dfs.append(df)
  except MemoryError as exc:
  raise MemoryError(
- f"[RecDataLoader Error] Out of memory while reading {file_path}. Consider using load_full=False with streaming."
+ f"[RecDataLoader Error] Out of memory while reading {file_path}. Consider using streaming=True."
  ) from exc
  try:
  combined_df = pd.concat(dfs, ignore_index=True)
  except MemoryError as exc:
  raise MemoryError(
- f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use load_full=False to stream or reduce chunk_size."
+ f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use streaming=True or reduce chunk_size."
  ) from exc
  return self.create_from_memory(
  combined_df, batch_size, shuffle, num_workers=num_workers
  )
- else:
- return self.load_files_streaming(
- file_paths,
- file_type,
- batch_size,
- chunk_size,
- shuffle,
- num_workers=num_workers,
- )
 
  def load_files_streaming(
  self,
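The API change that affects callers of this module is the create_dataloader signature: load_full: bool = True has been replaced by streaming: bool = False, so in 0.4.10 the default is an explicit in-memory load and chunked streaming is opt-in. A minimal sketch mirroring how the CLI above calls it; the feature lists, the fitted processor, and the test.parquet path are placeholders to be supplied by your own pipeline, not values defined in this diff:

from nextrec.data.dataloader import RecDataLoader

# Placeholders: dense_features / sparse_features / sequence_features are feature
# definitions built elsewhere, and processor is a fitted DataProcessor.
rec_dataloader = RecDataLoader(
    dense_features=dense_features,
    sparse_features=sparse_features,
    sequence_features=sequence_features,
    target=None,
    processor=processor,
)

pred_loader = rec_dataloader.create_dataloader(
    data="test.parquet",   # placeholder path; CSV and Parquet are both handled
    batch_size=512,
    shuffle=False,
    streaming=True,        # 0.4.8 spelling: load_full=False
    chunk_size=20000,
)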