nextrec 0.4.33__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -18
- nextrec/basic/asserts.py +1 -22
- nextrec/basic/callback.py +2 -2
- nextrec/basic/features.py +6 -37
- nextrec/basic/heads.py +13 -1
- nextrec/basic/layers.py +33 -123
- nextrec/basic/loggers.py +3 -2
- nextrec/basic/metrics.py +85 -4
- nextrec/basic/model.py +518 -7
- nextrec/basic/summary.py +88 -42
- nextrec/cli.py +117 -30
- nextrec/data/data_processing.py +8 -13
- nextrec/data/preprocessor.py +449 -844
- nextrec/loss/grad_norm.py +78 -76
- nextrec/models/multi_task/ple.py +1 -0
- nextrec/models/multi_task/share_bottom.py +1 -0
- nextrec/models/ranking/afm.py +4 -9
- nextrec/models/ranking/dien.py +7 -8
- nextrec/models/ranking/ffm.py +2 -2
- nextrec/models/retrieval/sdm.py +1 -2
- nextrec/models/sequential/hstu.py +0 -2
- nextrec/models/tree_base/base.py +1 -1
- nextrec/utils/__init__.py +2 -1
- nextrec/utils/config.py +1 -1
- nextrec/utils/console.py +1 -1
- nextrec/utils/onnx_utils.py +252 -0
- nextrec/utils/torch_utils.py +63 -56
- nextrec/utils/types.py +43 -0
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/METADATA +10 -4
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/RECORD +34 -42
- nextrec/models/multi_task/[pre]star.py +0 -192
- nextrec/models/representation/autorec.py +0 -0
- nextrec/models/representation/bpr.py +0 -0
- nextrec/models/representation/cl4srec.py +0 -0
- nextrec/models/representation/lightgcn.py +0 -0
- nextrec/models/representation/mf.py +0 -0
- nextrec/models/representation/s3rec.py +0 -0
- nextrec/models/sequential/sasrec.py +0 -0
- nextrec/utils/feature.py +0 -29
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/WHEEL +0 -0
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/summary.py
CHANGED
|
@@ -8,6 +8,7 @@ Author: Yang Zhou,zyaztec@gmail.com
|
|
|
8
8
|
|
|
9
9
|
from __future__ import annotations
|
|
10
10
|
|
|
11
|
+
import inspect
|
|
11
12
|
import logging
|
|
12
13
|
from typing import Any, Literal
|
|
13
14
|
|
|
@@ -34,6 +35,7 @@ class SummarySet:
|
|
|
34
35
|
scheduler_name: str | None
|
|
35
36
|
scheduler_params: dict[str, Any]
|
|
36
37
|
loss_config: Any
|
|
38
|
+
loss_params: Any
|
|
37
39
|
loss_weights: Any
|
|
38
40
|
grad_norm: Any
|
|
39
41
|
embedding_l1_reg: float
|
|
@@ -73,7 +75,7 @@ class SummarySet:
|
|
|
73
75
|
def build_data_summary(
|
|
74
76
|
self, data: Any, data_loader: DataLoader | None, sample_key: str
|
|
75
77
|
):
|
|
76
|
-
|
|
78
|
+
|
|
77
79
|
dataset = data_loader.dataset if data_loader is not None else None
|
|
78
80
|
|
|
79
81
|
train_size = get_data_length(dataset)
|
|
@@ -325,6 +327,73 @@ class SummarySet:
|
|
|
325
327
|
|
|
326
328
|
if hasattr(self, "loss_config"):
|
|
327
329
|
logger.info(f"Loss Function: {self.loss_config}")
|
|
330
|
+
|
|
331
|
+
loss_params_summary: list[str] = []
|
|
332
|
+
loss_fn = getattr(self, "loss_fn", None)
|
|
333
|
+
if loss_fn is not None:
|
|
334
|
+
loss_modules = (
|
|
335
|
+
list(loss_fn) if isinstance(loss_fn, (list, tuple)) else [loss_fn]
|
|
336
|
+
)
|
|
337
|
+
loss_config = getattr(self, "loss_config", None)
|
|
338
|
+
if isinstance(loss_config, list):
|
|
339
|
+
loss_names = loss_config
|
|
340
|
+
elif loss_config is not None:
|
|
341
|
+
loss_names = [loss_config] * len(loss_modules)
|
|
342
|
+
else:
|
|
343
|
+
loss_names = [None] * len(loss_modules)
|
|
344
|
+
|
|
345
|
+
loss_params = getattr(self, "loss_params", None)
|
|
346
|
+
if isinstance(loss_params, list):
|
|
347
|
+
explicit_params = loss_params
|
|
348
|
+
elif isinstance(loss_params, dict):
|
|
349
|
+
explicit_params = [loss_params] * len(loss_modules)
|
|
350
|
+
else:
|
|
351
|
+
explicit_params = [None] * len(loss_modules)
|
|
352
|
+
|
|
353
|
+
for idx, loss_module in enumerate(loss_modules):
|
|
354
|
+
params: dict[str, Any] = {}
|
|
355
|
+
explicit = (
|
|
356
|
+
explicit_params[idx] if idx < len(explicit_params) else None
|
|
357
|
+
)
|
|
358
|
+
if explicit:
|
|
359
|
+
params.update(explicit)
|
|
360
|
+
try:
|
|
361
|
+
signature = inspect.signature(loss_module.__class__.__init__)
|
|
362
|
+
except (TypeError, ValueError):
|
|
363
|
+
signature = None
|
|
364
|
+
if signature is not None:
|
|
365
|
+
for name, param in signature.parameters.items():
|
|
366
|
+
if name == "self" or name.startswith("_"):
|
|
367
|
+
continue
|
|
368
|
+
if hasattr(loss_module, name):
|
|
369
|
+
value = getattr(loss_module, name)
|
|
370
|
+
if callable(value):
|
|
371
|
+
continue
|
|
372
|
+
params.setdefault(name, value)
|
|
373
|
+
elif (
|
|
374
|
+
param.default is not inspect._empty
|
|
375
|
+
and param.default is not None
|
|
376
|
+
):
|
|
377
|
+
params.setdefault(name, param.default)
|
|
378
|
+
if not params:
|
|
379
|
+
continue
|
|
380
|
+
|
|
381
|
+
loss_name = loss_names[idx] if idx < len(loss_names) else None
|
|
382
|
+
if len(loss_modules) > 1:
|
|
383
|
+
header = f" [{idx}]"
|
|
384
|
+
if loss_name is not None:
|
|
385
|
+
header = f"{header} {loss_name}"
|
|
386
|
+
loss_params_summary.append(header)
|
|
387
|
+
indent = " "
|
|
388
|
+
else:
|
|
389
|
+
indent = " "
|
|
390
|
+
for key, value in params.items():
|
|
391
|
+
loss_params_summary.append(f"{indent}{key:25s}: {value}")
|
|
392
|
+
|
|
393
|
+
if loss_params_summary:
|
|
394
|
+
logger.info("Loss Params:")
|
|
395
|
+
for line in loss_params_summary:
|
|
396
|
+
logger.info(line)
|
|
328
397
|
if hasattr(self, "loss_weights"):
|
|
329
398
|
logger.info(f"Loss Weights: {self.loss_weights}")
|
|
330
399
|
if hasattr(self, "grad_norm"):
|
|
@@ -355,53 +424,30 @@ class SummarySet:
|
|
|
355
424
|
logger.info("")
|
|
356
425
|
logger.info(colorize("Data Summary", color="cyan", bold=True))
|
|
357
426
|
logger.info(colorize("-" * 80, color="cyan"))
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
if
|
|
365
|
-
for target_name, details in label_distributions.items():
|
|
366
|
-
lines = details.get("lines", [])
|
|
367
|
-
logger.info(f"{target_name}:")
|
|
368
|
-
for label, value in lines:
|
|
369
|
-
logger.info(f" {format_kv(label, value)}")
|
|
370
|
-
|
|
371
|
-
dataloader_info = self.train_data_summary.get("dataloader")
|
|
372
|
-
if isinstance(dataloader_info, dict):
|
|
373
|
-
logger.info("Train DataLoader:")
|
|
374
|
-
for key in (
|
|
375
|
-
"batch_size",
|
|
376
|
-
"num_workers",
|
|
377
|
-
"pin_memory",
|
|
378
|
-
"persistent_workers",
|
|
379
|
-
"sampler",
|
|
380
|
-
):
|
|
381
|
-
if key in dataloader_info:
|
|
382
|
-
label = key.replace("_", " ").title()
|
|
383
|
-
logger.info(
|
|
384
|
-
format_kv(label, dataloader_info[key], indent=2)
|
|
385
|
-
)
|
|
386
|
-
|
|
387
|
-
if self.valid_data_summary:
|
|
388
|
-
if self.train_data_summary:
|
|
427
|
+
for label, data_summary in (
|
|
428
|
+
("Train", self.train_data_summary),
|
|
429
|
+
("Valid", self.valid_data_summary),
|
|
430
|
+
):
|
|
431
|
+
if not data_summary:
|
|
432
|
+
continue
|
|
433
|
+
if label == "Valid" and self.train_data_summary:
|
|
389
434
|
logger.info("")
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
435
|
+
sample_key = "train_samples" if label == "Train" else "valid_samples"
|
|
436
|
+
samples = data_summary.get(sample_key)
|
|
437
|
+
if samples is not None:
|
|
438
|
+
logger.info(format_kv(f"{label} Samples", f"{samples:,}"))
|
|
393
439
|
|
|
394
|
-
label_distributions =
|
|
440
|
+
label_distributions = data_summary.get("label_distributions")
|
|
395
441
|
if isinstance(label_distributions, dict):
|
|
396
442
|
for target_name, details in label_distributions.items():
|
|
397
443
|
lines = details.get("lines", [])
|
|
398
444
|
logger.info(f"{target_name}:")
|
|
399
|
-
for
|
|
400
|
-
logger.info(f" {format_kv(
|
|
445
|
+
for line_label, value in lines:
|
|
446
|
+
logger.info(f" {format_kv(line_label, value)}")
|
|
401
447
|
|
|
402
|
-
dataloader_info =
|
|
448
|
+
dataloader_info = data_summary.get("dataloader")
|
|
403
449
|
if isinstance(dataloader_info, dict):
|
|
404
|
-
logger.info("
|
|
450
|
+
logger.info(f"{label} DataLoader:")
|
|
405
451
|
for key in (
|
|
406
452
|
"batch_size",
|
|
407
453
|
"num_workers",
|
|
@@ -410,7 +456,7 @@ class SummarySet:
|
|
|
410
456
|
"sampler",
|
|
411
457
|
):
|
|
412
458
|
if key in dataloader_info:
|
|
413
|
-
|
|
459
|
+
field_label = key.replace("_", " ").title()
|
|
414
460
|
logger.info(
|
|
415
|
-
format_kv(
|
|
461
|
+
format_kv(field_label, dataloader_info[key], indent=2)
|
|
416
462
|
)
|
nextrec/cli.py
CHANGED
|
@@ -48,7 +48,7 @@ from nextrec.utils.data import (
|
|
|
48
48
|
read_yaml,
|
|
49
49
|
resolve_file_paths,
|
|
50
50
|
)
|
|
51
|
-
from nextrec.utils.
|
|
51
|
+
from nextrec.utils.torch_utils import to_list
|
|
52
52
|
|
|
53
53
|
logger = logging.getLogger(__name__)
|
|
54
54
|
|
|
@@ -156,9 +156,7 @@ def train_model(train_config_path: str) -> None:
|
|
|
156
156
|
logger.info(
|
|
157
157
|
format_kv(
|
|
158
158
|
"Validation path",
|
|
159
|
-
resolve_path(
|
|
160
|
-
data_cfg.get("valid_path"), config_dir
|
|
161
|
-
),
|
|
159
|
+
resolve_path(data_cfg.get("valid_path"), config_dir),
|
|
162
160
|
)
|
|
163
161
|
)
|
|
164
162
|
|
|
@@ -247,7 +245,9 @@ def train_model(train_config_path: str) -> None:
|
|
|
247
245
|
|
|
248
246
|
if streaming:
|
|
249
247
|
if file_type is None:
|
|
250
|
-
raise ValueError(
|
|
248
|
+
raise ValueError(
|
|
249
|
+
"[NextRec CLI Error] Streaming mode requires a valid file_type"
|
|
250
|
+
)
|
|
251
251
|
processor.fit_from_files(
|
|
252
252
|
file_paths=streaming_train_files or file_paths,
|
|
253
253
|
file_type=file_type,
|
|
@@ -422,6 +422,49 @@ def train_model(train_config_path: str) -> None:
|
|
|
422
422
|
note=train_cfg.get("note"),
|
|
423
423
|
)
|
|
424
424
|
|
|
425
|
+
export_cfg = train_cfg.get("export_onnx")
|
|
426
|
+
if export_cfg is None:
|
|
427
|
+
export_cfg = cfg.get("export_onnx")
|
|
428
|
+
export_enabled = False
|
|
429
|
+
export_options: dict[str, Any] = {}
|
|
430
|
+
if isinstance(export_cfg, bool):
|
|
431
|
+
export_enabled = export_cfg
|
|
432
|
+
elif isinstance(export_cfg, dict):
|
|
433
|
+
export_options = export_cfg
|
|
434
|
+
export_enabled = bool(export_cfg.get("enable", False))
|
|
435
|
+
|
|
436
|
+
if export_enabled:
|
|
437
|
+
log_cli_section("ONNX Export")
|
|
438
|
+
onnx_path = None
|
|
439
|
+
if export_options.get("path") or export_options.get("save_path"):
|
|
440
|
+
logger.warning(
|
|
441
|
+
"[NextRec CLI Warning] export_onnx.path/save_path is deprecated; "
|
|
442
|
+
"ONNX will be saved to best/checkpoint paths."
|
|
443
|
+
)
|
|
444
|
+
onnx_best_path = Path(model.best_path).with_suffix(".onnx")
|
|
445
|
+
onnx_ckpt_path = Path(model.checkpoint_path).with_suffix(".onnx")
|
|
446
|
+
onnx_batch_size = export_options.get("batch_size", 1)
|
|
447
|
+
onnx_opset = export_options.get("opset_version", 18)
|
|
448
|
+
log_kv_lines(
|
|
449
|
+
[
|
|
450
|
+
("ONNX best path", onnx_best_path),
|
|
451
|
+
("ONNX checkpoint path", onnx_ckpt_path),
|
|
452
|
+
("Batch size", onnx_batch_size),
|
|
453
|
+
("Opset", onnx_opset),
|
|
454
|
+
("Dynamic batch", False),
|
|
455
|
+
]
|
|
456
|
+
)
|
|
457
|
+
model.export_onnx(
|
|
458
|
+
save_path=onnx_best_path,
|
|
459
|
+
batch_size=onnx_batch_size,
|
|
460
|
+
opset_version=onnx_opset,
|
|
461
|
+
)
|
|
462
|
+
model.export_onnx(
|
|
463
|
+
save_path=onnx_ckpt_path,
|
|
464
|
+
batch_size=onnx_batch_size,
|
|
465
|
+
opset_version=onnx_opset,
|
|
466
|
+
)
|
|
467
|
+
|
|
425
468
|
|
|
426
469
|
def predict_model(predict_config_path: str) -> None:
|
|
427
470
|
"""
|
|
@@ -492,12 +535,16 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
492
535
|
# Load checkpoint and ensure required parameters are passed
|
|
493
536
|
checkpoint_base = Path(session_dir)
|
|
494
537
|
if checkpoint_base.is_dir():
|
|
538
|
+
best_candidates = sorted(checkpoint_base.glob("*_best.pt"))
|
|
495
539
|
candidates = sorted(checkpoint_base.glob("*.pt"))
|
|
496
|
-
if
|
|
540
|
+
if best_candidates:
|
|
541
|
+
model_file = best_candidates[-1]
|
|
542
|
+
elif candidates:
|
|
543
|
+
model_file = candidates[-1]
|
|
544
|
+
else:
|
|
497
545
|
raise FileNotFoundError(
|
|
498
546
|
f"[NextRec CLI Error]: Unable to find model checkpoint: {checkpoint_base}"
|
|
499
547
|
)
|
|
500
|
-
model_file = candidates[-1]
|
|
501
548
|
config_dir_for_features = checkpoint_base
|
|
502
549
|
else:
|
|
503
550
|
model_file = (
|
|
@@ -564,11 +611,32 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
564
611
|
)
|
|
565
612
|
|
|
566
613
|
log_cli_section("Model")
|
|
614
|
+
use_onnx = bool(predict_cfg.get("use_onnx")) or bool(predict_cfg.get("onnx_path"))
|
|
615
|
+
onnx_path = predict_cfg.get("onnx_path") or cfg.get("onnx_path")
|
|
616
|
+
if onnx_path:
|
|
617
|
+
onnx_path = resolve_path(onnx_path, config_dir)
|
|
618
|
+
if use_onnx and onnx_path is None:
|
|
619
|
+
search_dir = (
|
|
620
|
+
checkpoint_base if checkpoint_base.is_dir() else checkpoint_base.parent
|
|
621
|
+
)
|
|
622
|
+
best_candidates = sorted(search_dir.glob("*_best.onnx"))
|
|
623
|
+
if best_candidates:
|
|
624
|
+
onnx_path = best_candidates[-1]
|
|
625
|
+
else:
|
|
626
|
+
candidates = sorted(search_dir.glob("*.onnx"))
|
|
627
|
+
if not candidates:
|
|
628
|
+
raise FileNotFoundError(
|
|
629
|
+
f"[NextRec CLI Error]: Unable to find ONNX model in {search_dir}"
|
|
630
|
+
)
|
|
631
|
+
onnx_path = candidates[-1]
|
|
632
|
+
|
|
567
633
|
log_kv_lines(
|
|
568
634
|
[
|
|
569
635
|
("Model", model.__class__.__name__),
|
|
570
636
|
("Checkpoint", model_file),
|
|
571
637
|
("Device", predict_cfg.get("device", "cpu")),
|
|
638
|
+
("Use ONNX", use_onnx),
|
|
639
|
+
("ONNX path", onnx_path if use_onnx else "(disabled)"),
|
|
572
640
|
]
|
|
573
641
|
)
|
|
574
642
|
|
|
@@ -582,7 +650,10 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
582
650
|
)
|
|
583
651
|
|
|
584
652
|
data_path = resolve_path(predict_cfg["data_path"], config_dir)
|
|
585
|
-
|
|
653
|
+
streaming = bool(predict_cfg.get("streaming", True))
|
|
654
|
+
chunk_size = int(predict_cfg.get("chunk_size", 20000))
|
|
655
|
+
batch_size = int(predict_cfg.get("batch_size", 512))
|
|
656
|
+
effective_batch_size = chunk_size if streaming else batch_size
|
|
586
657
|
|
|
587
658
|
log_cli_section("Data")
|
|
588
659
|
log_kv_lines(
|
|
@@ -594,18 +665,18 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
594
665
|
"source_data_format", predict_cfg.get("data_format", "auto")
|
|
595
666
|
),
|
|
596
667
|
),
|
|
597
|
-
("Batch size",
|
|
598
|
-
("Chunk size",
|
|
599
|
-
("Streaming",
|
|
668
|
+
("Batch size", effective_batch_size),
|
|
669
|
+
("Chunk size", chunk_size),
|
|
670
|
+
("Streaming", streaming),
|
|
600
671
|
]
|
|
601
672
|
)
|
|
602
673
|
logger.info("")
|
|
603
674
|
pred_loader = rec_dataloader.create_dataloader(
|
|
604
675
|
data=str(data_path),
|
|
605
|
-
batch_size=batch_size,
|
|
676
|
+
batch_size=1 if streaming else batch_size,
|
|
606
677
|
shuffle=False,
|
|
607
|
-
streaming=
|
|
608
|
-
chunk_size=
|
|
678
|
+
streaming=streaming,
|
|
679
|
+
chunk_size=chunk_size,
|
|
609
680
|
prefetch_factor=predict_cfg.get("prefetch_factor"),
|
|
610
681
|
)
|
|
611
682
|
|
|
@@ -613,27 +684,43 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
613
684
|
"save_data_format", predict_cfg.get("save_format", "csv")
|
|
614
685
|
)
|
|
615
686
|
pred_name = predict_cfg.get("name", "pred")
|
|
616
|
-
|
|
617
|
-
|
|
687
|
+
pred_name_path = Path(pred_name)
|
|
688
|
+
if pred_name_path.is_absolute():
|
|
689
|
+
save_path = pred_name_path
|
|
690
|
+
if save_path.suffix == "":
|
|
691
|
+
save_path = save_path.with_suffix(f".{save_format}")
|
|
692
|
+
else:
|
|
693
|
+
save_path = checkpoint_base / "predictions" / f"{pred_name}.{save_format}"
|
|
618
694
|
|
|
619
695
|
start = time.time()
|
|
620
696
|
logger.info("")
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
697
|
+
if use_onnx:
|
|
698
|
+
result = model.predict_onnx(
|
|
699
|
+
onnx_path=onnx_path,
|
|
700
|
+
data=pred_loader,
|
|
701
|
+
batch_size=effective_batch_size,
|
|
702
|
+
include_ids=bool(id_columns),
|
|
703
|
+
return_dataframe=False,
|
|
704
|
+
save_path=str(save_path),
|
|
705
|
+
save_format=save_format,
|
|
706
|
+
num_workers=predict_cfg.get("num_workers", 0),
|
|
707
|
+
)
|
|
708
|
+
else:
|
|
709
|
+
result = model.predict(
|
|
710
|
+
data=pred_loader,
|
|
711
|
+
batch_size=effective_batch_size,
|
|
712
|
+
include_ids=bool(id_columns),
|
|
713
|
+
return_dataframe=False,
|
|
714
|
+
save_path=str(save_path),
|
|
715
|
+
save_format=save_format,
|
|
716
|
+
num_workers=predict_cfg.get("num_workers", 0),
|
|
717
|
+
)
|
|
630
718
|
duration = time.time() - start
|
|
631
719
|
# When return_dataframe=False, result is the actual file path
|
|
632
|
-
|
|
633
|
-
result
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
)
|
|
720
|
+
if isinstance(result, (str, Path)):
|
|
721
|
+
output_path = Path(result)
|
|
722
|
+
else:
|
|
723
|
+
output_path = save_path
|
|
637
724
|
logger.info(f"Prediction completed, results saved to: {output_path}")
|
|
638
725
|
logger.info(f"Total time: {duration:.2f} seconds")
|
|
639
726
|
|
nextrec/data/data_processing.py
CHANGED
|
@@ -12,18 +12,21 @@ from typing import Any
|
|
|
12
12
|
import numpy as np
|
|
13
13
|
import pandas as pd
|
|
14
14
|
import torch
|
|
15
|
+
import polars as pl
|
|
16
|
+
|
|
15
17
|
|
|
16
18
|
from nextrec.utils.torch_utils import to_numpy
|
|
17
19
|
|
|
18
20
|
|
|
19
|
-
def get_column_data(data: dict | pd.DataFrame, name: str):
|
|
21
|
+
def get_column_data(data: dict | pd.DataFrame | pl.DataFrame, name: str):
|
|
20
22
|
|
|
21
23
|
if isinstance(data, dict):
|
|
22
24
|
return data[name] if name in data else None
|
|
23
25
|
elif isinstance(data, pd.DataFrame):
|
|
24
|
-
if name not in data.columns:
|
|
25
|
-
return None
|
|
26
26
|
return data[name].values
|
|
27
|
+
elif isinstance(data, pl.DataFrame):
|
|
28
|
+
series = data.get_column(name)
|
|
29
|
+
return series.to_numpy()
|
|
27
30
|
else:
|
|
28
31
|
raise KeyError(f"Only dict or DataFrame supported, got {type(data)}")
|
|
29
32
|
|
|
@@ -33,6 +36,8 @@ def get_data_length(data: Any) -> int | None:
|
|
|
33
36
|
return None
|
|
34
37
|
if isinstance(data, pd.DataFrame):
|
|
35
38
|
return len(data)
|
|
39
|
+
if isinstance(data, pl.DataFrame):
|
|
40
|
+
return data.height
|
|
36
41
|
if isinstance(data, dict):
|
|
37
42
|
if not data:
|
|
38
43
|
return None
|
|
@@ -92,16 +97,6 @@ def split_dict_random(data_dict, test_size=0.2, random_state=None):
|
|
|
92
97
|
return train_dict, test_dict
|
|
93
98
|
|
|
94
99
|
|
|
95
|
-
def split_data(
|
|
96
|
-
df: pd.DataFrame, test_size: float = 0.2
|
|
97
|
-
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
98
|
-
|
|
99
|
-
split_idx = int(len(df) * (1 - test_size))
|
|
100
|
-
train_df = df.iloc[:split_idx].reset_index(drop=True)
|
|
101
|
-
valid_df = df.iloc[split_idx:].reset_index(drop=True)
|
|
102
|
-
return train_df, valid_df
|
|
103
|
-
|
|
104
|
-
|
|
105
100
|
def build_eval_candidates(
|
|
106
101
|
df_all: pd.DataFrame,
|
|
107
102
|
user_col: str,
|