nextrec 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/model.py +288 -181
- nextrec/basic/summary.py +21 -4
- nextrec/cli.py +35 -15
- nextrec/data/__init__.py +0 -52
- nextrec/data/batch_utils.py +1 -1
- nextrec/data/data_processing.py +1 -35
- nextrec/data/data_utils.py +0 -4
- nextrec/data/dataloader.py +125 -103
- nextrec/data/preprocessor.py +141 -92
- nextrec/loss/__init__.py +0 -36
- nextrec/models/generative/__init__.py +0 -9
- nextrec/models/tree_base/__init__.py +0 -15
- nextrec/models/tree_base/base.py +14 -23
- nextrec/utils/__init__.py +0 -119
- nextrec/utils/data.py +39 -119
- nextrec/utils/model.py +5 -14
- {nextrec-0.5.1.dist-info → nextrec-0.5.3.dist-info}/METADATA +4 -5
- {nextrec-0.5.1.dist-info → nextrec-0.5.3.dist-info}/RECORD +22 -22
- {nextrec-0.5.1.dist-info → nextrec-0.5.3.dist-info}/WHEEL +0 -0
- {nextrec-0.5.1.dist-info → nextrec-0.5.3.dist-info}/entry_points.txt +0 -0
- {nextrec-0.5.1.dist-info → nextrec-0.5.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class

 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 01/02/2026
 Author: Yang Zhou,zyaztec@gmail.com
 """

@@ -13,20 +13,18 @@ import os
 import sys
 import pickle
 import socket
+import multiprocessing as mp
 from pathlib import Path
 from typing import Any, Literal, cast, overload

 import numpy as np
 import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
+import polars as pl
+import swanlab
+import wandb

-try:
-    import swanlab  # type: ignore
-except ModuleNotFoundError:
-    swanlab = None
-try:
-    import wandb  # type: ignore
-except ModuleNotFoundError:
-    wandb = None

 import torch
 import torch.distributed as dist
@@ -65,15 +63,9 @@ from nextrec.data.dataloader import (
     TensorDictDataset,
     build_tensors_from_data,
 )
-from nextrec.
-from nextrec.loss import (
-    BPRLoss,
-    GradNormLossWeighting,
-    HingeLoss,
-    InfoNCELoss,
-    SampledSoftmaxLoss,
-    TripletLoss,
-)
+from nextrec.loss.grad_norm import GradNormLossWeighting
+from nextrec.loss.listwise import InfoNCELoss, SampledSoftmaxLoss
+from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
 from nextrec.utils.loss import get_loss_fn
 from nextrec.loss.grad_norm import get_grad_norm_shared_params
 from nextrec.utils.console import display_metrics_table, progress
@@ -111,8 +103,6 @@ from nextrec.utils.types import (
     MetricsName,
 )

-from nextrec.utils.data import FILE_FORMAT_CONFIG
-

 class BaseModel(SummarySet, FeatureSet, nn.Module):
     @property
@@ -1619,14 +1609,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 )
             )
             return {}
-
-        # logging.info(
-        #     colorize(
-        #         format_kv(
-        #             "Evaluation samples", y_true_all.shape[0]
-        #         ),
-        #     )
-        # )
+
         logging.info("")
         metrics_dict = evaluate_metrics(
             y_true=y_true_all,
@@ -1643,106 +1626,141 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
     @overload
     def predict(
         self,
-        data: str |
+        data: str | os.PathLike | DataLoader,
         batch_size: int = 32,
         save_path: str | os.PathLike | None = None,
         save_format: str = "csv",
-        include_ids: bool | None = None,
-        id_columns: str | list[str] | None = None,
         return_dataframe: Literal[True] = True,
         stream_chunk_size: int = 10000,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        num_processes: int = 1,
+        processor: Any | None = None,
     ) -> pd.DataFrame: ...

     @overload
     def predict(
         self,
-        data: str |
+        data: str | os.PathLike | DataLoader,
         batch_size: int = 32,
         save_path: None = None,
         save_format: str = "csv",
-        include_ids: bool | None = None,
-        id_columns: str | list[str] | None = None,
        return_dataframe: Literal[False] = False,
         stream_chunk_size: int = 10000,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        num_processes: int = 1,
+        processor: Any | None = None,
     ) -> np.ndarray: ...

     @overload
     def predict(
         self,
-        data: str | dict | pd.DataFrame | DataLoader,
+        data: str | os.PathLike | dict | pd.DataFrame | DataLoader,
         batch_size: int = 32,
         *,
         save_path: str | os.PathLike,
         save_format: str = "csv",
-        include_ids: bool | None = None,
-        id_columns: str | list[str] | None = None,
         return_dataframe: Literal[False] = False,
         stream_chunk_size: int = 10000,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        num_processes: int = 1,
+        processor: Any | None = None,
     ) -> Path: ...

     def predict(
         self,
-        data: str | dict | pd.DataFrame | DataLoader,
+        data: str | os.PathLike | dict | pd.DataFrame | DataLoader,
         batch_size: int = 32,
         save_path: str | os.PathLike | None = None,
         save_format: str = "csv",
-        include_ids: bool | None = None,
-        id_columns: str | list[str] | None = None,
         return_dataframe: bool = True,
         stream_chunk_size: int = 10000,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        num_processes: int = 1,
+        processor: Any | None = None,
     ) -> pd.DataFrame | np.ndarray | Path | None:
         """
         Make predictions on the given data.

         Args:
-            data: Input data for prediction (file path
+            data: Input data for prediction (file path or DataLoader).
             batch_size: Batch size for prediction (per process when distributed).
             save_path: Optional path to save predictions; if None, predictions are not saved to disk.
             save_format: Format to save predictions ('csv' or 'parquet').
-            include_ids: Whether to include ID columns in the output; if None, includes if id_columns are set.
-            id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
             return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
             stream_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
             num_workers: DataLoader worker count.
+            prefetch_factor: Number of batches prefetched per worker (only when num_workers > 0).
+            num_processes: Number of inference processes for streaming file inference.
+            processor: Optional DataProcessor for transforming input data.

         Note:
-            predict does not support distributed mode currently
+            predict does not support distributed mode currently; streaming file inference can use
+            multiple processes via num_processes > 1, which may change output order.
         """
         self.eval()
-        # Use prediction-time id_columns if provided, otherwise fall back to model's id_columns
-        predict_id_columns = id_columns if id_columns is not None else self.id_columns
-        if isinstance(predict_id_columns, str):
-            predict_id_columns = [predict_id_columns]
-
-        if include_ids is None:
-            include_ids = bool(predict_id_columns)
-        include_ids = include_ids and bool(predict_id_columns)

-        #
-        if
+        # streaming mode prediction
+        if (
+            save_path is not None
+            and not return_dataframe
+            and isinstance(data, (str, os.PathLike, DataLoader))
+        ):
+            if num_processes > 1 and not isinstance(data, (str, os.PathLike)):
+                raise ValueError(
+                    "[BaseModel-predict Error] Multi-process streaming requires data to be a file path."
+                )
+            if num_workers != 0:
+                logging.info(
+                    "[BaseModel-predict-streaming Info] Streaming mode enforces num_workers=0."
+                )
+                logging.info("")
             return self.predict_streaming(
                 data=data,
                 batch_size=batch_size,
                 save_path=save_path,
                 save_format=save_format,
-                include_ids=include_ids,
                 stream_chunk_size=stream_chunk_size,
                 return_dataframe=return_dataframe,
-
+                num_workers=0,
+                num_processes=num_processes,
+                processor=processor,
             )

-
+        return self.predict_in_memory(
+            data=data,
+            batch_size=batch_size,
+            save_path=save_path,
+            save_format=save_format,
+            return_dataframe=return_dataframe,
+            stream_chunk_size=stream_chunk_size,
+            num_workers=num_workers,
+            prefetch_factor=prefetch_factor,
+            processor=processor,
+        )
+
+    def predict_in_memory(
+        self,
+        data: str | os.PathLike | dict | pd.DataFrame | DataLoader,
+        batch_size: int = 32,
+        save_path: str | os.PathLike | None = None,
+        save_format: str = "csv",
+        return_dataframe: bool = True,
+        stream_chunk_size: int = 10000,
+        num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        processor: Any | None = None,
+    ) -> pd.DataFrame | np.ndarray | Path | None:
+
+        predict_id_columns = self.id_columns
+        if isinstance(predict_id_columns, str):
+            predict_id_columns = [predict_id_columns]
+        include_ids = bool(predict_id_columns)
         if isinstance(data, DataLoader):
             data_loader = data
-            if num_workers != 0:
-                logging.warning(
-                    "[Predict Warning] num_workers parameter is ignored when data is already a DataLoader. "
-                    "The DataLoader's existing num_workers configuration will be used."
-                )
         elif isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
                 dense_features=self.dense_features,
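A minimal usage sketch of the dispatch introduced above; `model` stands for any trained `BaseModel` subclass, and the file names, batch size, and chunk size are placeholders rather than nextrec fixtures:

# Hypothetical call sites for the predict() dispatch shown above.

# In-memory path: return_dataframe stays True, so predict() forwards to
# predict_in_memory() and returns a pandas DataFrame.
df = model.predict("eval.parquet", batch_size=512)

# Streaming path: save_path is set, return_dataframe is False, and data is
# a file path, so predict() forwards to predict_streaming(); with
# num_processes > 1 the merged output may not preserve input row order.
out_path = model.predict(
    "eval.parquet",
    batch_size=512,
    save_path="predictions.parquet",
    save_format="parquet",
    return_dataframe=False,
    stream_chunk_size=50_000,
    num_processes=4,
)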
@@ -1750,6 +1768,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 sequence_features=self.sequence_features,
                 target=self.target_columns,
                 id_columns=predict_id_columns,
+                processor=processor,
             )
             data_loader = rec_loader.create_dataloader(
                 data=data,
@@ -1757,6 +1776,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 shuffle=False,
                 streaming=True,
                 chunk_size=stream_chunk_size,
+                num_workers=0,
+                prefetch_factor=prefetch_factor,
             )
         else:
             data_loader = self.prepare_data_loader(
@@ -1834,23 +1855,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 else y_pred_all
             )
         if save_path is not None:
-
-            if not check_streaming_support(save_format):
-                logging.warning(
-                    f"
-                    "The entire result will be saved at once. Use csv or parquet for large datasets."
+            if save_format not in {"csv", "parquet"}:
+                raise ValueError(
+                    f"Unsupported save format: {save_format}. "
+                    "Supported: csv, parquet"
                 )
-
-            # Get file extension from format
-            from nextrec.utils.data import FILE_FORMAT_CONFIG
-
-            suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
-
             target_path = get_save_path(
                 path=save_path,
                 default_dir=self.session.predictions_dir,
                 default_name="predictions",
-                suffix=
+                suffix=f".{save_format}",
                 add_timestamp=True if save_path is None else False,
             )
             if isinstance(output, pd.DataFrame):
@@ -1870,12 +1884,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 df_to_save.to_csv(target_path, index=False)
             elif save_format == "parquet":
                 df_to_save.to_parquet(target_path, index=False)
-            elif save_format == "feather":
-                df_to_save.to_feather(target_path)
-            elif save_format == "excel":
-                df_to_save.to_excel(target_path, index=False)
-            elif save_format == "hdf5":
-                df_to_save.to_hdf(target_path, key="predictions", mode="w")
             else:
                 raise ValueError(f"Unsupported save format: {save_format}")

@@ -1886,37 +1894,64 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):

     def predict_streaming(
         self,
-        data: str |
+        data: str | os.PathLike | DataLoader,
         batch_size: int,
         save_path: str | os.PathLike,
         save_format: str,
-        include_ids: bool,
         stream_chunk_size: int,
         return_dataframe: bool,
-
+        num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        num_processes: int = 1,
+        processor: Any | None = None,
+        shard_rank: int = 0,
+        shard_count: int = 1,
     ):
         """
         Make predictions on the given data using streaming mode for large datasets.

         Args:
-            data: Input data for prediction (file path
+            data: Input data for prediction (file path or DataLoader).
             batch_size: Batch size for prediction.
             save_path: Path to save predictions.
             save_format: Format to save predictions ('csv' or 'parquet').
-            include_ids: Whether to include ID columns in the output.
             stream_chunk_size: Number of rows per chunk when using streaming mode.
             return_dataframe: Whether to return predictions as a pandas DataFrame.
-
-
-
+            num_workers: DataLoader worker count.
+            prefetch_factor: Number of batches prefetched per worker (only when num_workers > 0).
+            num_processes: Number of inference processes for streaming file inference.
+            processor: Optional DataProcessor for transforming input data.
+            shard_rank: Process shard rank for multi-process inference.
+            shard_count: Total number of shards for multi-process inference.
         """
+        predict_id_columns = self.id_columns
+        if isinstance(predict_id_columns, str):
+            predict_id_columns = [predict_id_columns]
+        include_ids = bool(predict_id_columns)
+
+        # Multi-process streaming
+        if num_processes > 1:
+            return self.predict_streaming_multiprocess(
+                data=data,
+                batch_size=batch_size,
+                save_path=save_path,
+                save_format=save_format,
+                stream_chunk_size=stream_chunk_size,
+                return_dataframe=return_dataframe,
+                num_workers=num_workers,
+                prefetch_factor=None,  # disable prefetching in multi-process mode
+                num_processes=num_processes,
+                processor=processor,
+            )
+        # Single-process streaming
         if isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
                 dense_features=self.dense_features,
                 sparse_features=self.sparse_features,
                 sequence_features=self.sequence_features,
                 target=self.target_columns,
-                id_columns=
+                id_columns=predict_id_columns,
+                processor=processor,
             )
             data_loader = rec_loader.create_dataloader(
                 data=data,
@@ -1924,53 +1959,41 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 shuffle=False,
                 streaming=True,
                 chunk_size=stream_chunk_size,
+                num_workers=num_workers,
+                prefetch_factor=None if num_workers == 0 else prefetch_factor,
+                shard_rank=shard_rank,
+                shard_count=shard_count,
             )
         elif not isinstance(data, DataLoader):
-            data_loader = self.prepare_data_loader(
-                data
-                batch_size=batch_size,
-                shuffle=False,
+            raise TypeError(
+                "[BaseModel-predict-streaming Error] data must be a file path or a DataLoader."
             )
-        else:
+        else:  # data is a DataLoader
             data_loader = data

-        if (
-            data_loader.num_workers > 0
-            and "Streaming" in data_loader.dataset.__class__.__name__
-        ):
-            logging.warning(
-                f"[Predict Streaming Warning] Detected DataLoader with num_workers={data_loader.num_workers} "
-                "and streaming dataset. This may cause data duplication! "
-                "When using streaming mode, set num_workers=0 to avoid reading data multiple times."
-            )
-
-        # Check streaming support and prepare file path
-        if not check_streaming_support(save_format):
-            logging.warning(
-                f"[Predict Streaming Warning] Format '{save_format}' does not support streaming writes. "
-                "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
+        if save_format not in {"csv", "parquet"}:
+            raise ValueError(
+                f"Unsupported save format: {save_format}. Supported: csv, parquet"
             )
-
-        suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
-
         target_path = get_save_path(
             path=save_path,
             default_dir=self.session.predictions_dir,
             default_name="predictions",
-            suffix=
+            suffix=f".{save_format}",
             add_timestamp=True if save_path is None else False,
         )
         target_path.parent.mkdir(parents=True, exist_ok=True)
-        header_written = target_path.exists()
+        header_written = target_path.exists()
         parquet_writer = None
+
         pred_columns = None
-        collected_frames: list[pd.DataFrame] = (
-            []
-        )  # used when return_dataframe=True or for non-streaming formats
+        cached_frames = []  # used when return_dataframe=True

+        disable_progress = shard_count > 1
         with torch.no_grad():
-            for batch_data in progress(
+            for batch_data in progress(
+                data_loader, description="Predicting", disable=disable_progress
+            ):
                 batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
                 X_input, _ = self.get_input(batch_dict, require_labels=False)
                 y_pred = self.forward(X_input)
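The loop entered above appends each scored chunk straight to disk. As a self-contained sketch of the incremental Parquet pattern it relies on (file name and stand-in chunks are made up, not nextrec data):

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

# One ParquetWriter is opened lazily from the first chunk's schema; every
# later chunk is appended, so the full result never sits in memory.
writer = None
for chunk in ([{"pred_0": 0.1}, {"pred_0": 0.9}], [{"pred_0": 0.4}]):
    table = pa.Table.from_pandas(pd.DataFrame(chunk), preserve_index=False)
    if writer is None:
        writer = pq.ParquetWriter("predictions.parquet", table.schema)
    writer.write_table(table)
if writer is not None:
    writer.close()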
@@ -1989,14 +2012,18 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                     while len(pred_columns) < num_outputs:
                         pred_columns.append(f"pred_{len(pred_columns)}")

-                ids =
+                ids = (
+                    batch_dict.get("ids")
+                    if include_ids and predict_id_columns
+                    else None
+                )
                 id_arrays_batch = {
                     id_name: (
                         ids[id_name].detach().cpu().numpy()
                         if isinstance(ids[id_name], torch.Tensor)
                         else np.asarray(ids[id_name])
                     ).reshape(-1)
-                    for id_name in (
+                    for id_name in (predict_id_columns or [])
                     if ids and id_name in ids
                 }
@@ -2015,48 +2042,123 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                         target_path, mode="a", header=not header_written, index=False
                     )
                     header_written = True
+                    if return_dataframe:
+                        cached_frames.append(df_batch)
                 elif save_format == "parquet":
-                    try:
-                        import pyarrow as pa
-                        import pyarrow.parquet as pq
-                    except ImportError as exc:  # pragma: no cover
-                        raise ImportError(
-                            "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow."
-                        ) from exc
                     table = pa.Table.from_pandas(df_batch, preserve_index=False)
                     if parquet_writer is None:
                         parquet_writer = pq.ParquetWriter(target_path, table.schema)
                     parquet_writer.write_table(table)
+                    if return_dataframe:
+                        cached_frames.append(df_batch)
                 else:
                     # Non-streaming formats: collect all data
-                    collected_frames.append(df_batch)
-
-                if return_dataframe and save_format in ["csv", "parquet"]:
-                    collected_frames.append(df_batch)
+                    cached_frames.append(df_batch)

         # Close writers
         if parquet_writer is not None:
             parquet_writer.close()
-        # For non-streaming formats, save collected data
-        if save_format in ["feather", "excel", "hdf5"] and collected_frames:
-            combined_df = pd.concat(collected_frames, ignore_index=True)
-            if save_format == "feather":
-                combined_df.to_feather(target_path)
-            elif save_format == "excel":
-                combined_df.to_excel(target_path, index=False)
-            elif save_format == "hdf5":
-                combined_df.to_hdf(target_path, key="predictions", mode="w")

         logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
         if return_dataframe:
             return (
-                pd.concat(
-                if
+                pd.concat(cached_frames, ignore_index=True)
+                if cached_frames
                 else pd.DataFrame(columns=pred_columns or [])
             )
         # Return the actual save path when not returning dataframe
         return target_path

+    def predict_streaming_multiprocess(
+        self,
+        data: str | os.PathLike | DataLoader,
+        batch_size: int,
+        save_path: str | os.PathLike,
+        save_format: str,
+        stream_chunk_size: int,
+        return_dataframe: bool,
+        num_workers: int,
+        prefetch_factor: int | None,
+        num_processes: int,
+        processor: Any | None,
+    ):
+        target_path = Path(
+            get_save_path(
+                path=save_path,
+                default_dir=self.session.predictions_dir,
+                default_name="predictions",
+                suffix=f".{save_format}",
+                add_timestamp=True if save_path is None else False,
+            )
+        )
+        parts_dir = target_path.parent / f".{target_path.stem}_parts"
+        parts_dir.mkdir(parents=True, exist_ok=True)
+        part_paths = [
+            parts_dir / f"{target_path.stem}.part{rank}{target_path.suffix}"
+            for rank in range(num_processes)
+        ]
+
+        ctx = mp.get_context("spawn")
+        processes = []
+        for rank in range(num_processes):
+            process = ctx.Process(
+                target=predict_streaming_worker,
+                args=(
+                    self,
+                    data,
+                    batch_size,
+                    part_paths[rank],
+                    save_format,
+                    stream_chunk_size,
+                    num_workers,
+                    prefetch_factor,
+                    processor,
+                    rank,
+                    num_processes,
+                ),
+            )
+            process.start()
+            processes.append(process)
+
+        for process in progress(
+            iter(processes), description="Predicting...", total=None
+        ):
+            process.join()
+
+        for process in processes:
+            if process.exitcode not in (0, None):
+                raise RuntimeError(
+                    "[BaseModel-predict-streaming Error] One or more inference processes failed."
+                )
+        # Merge part files
+        existing_parts = [p for p in part_paths if p.exists()]
+        if existing_parts:
+            target_path.parent.mkdir(parents=True, exist_ok=True)
+            if save_format == "csv":
+                lazy_frames = [pl.scan_csv(p) for p in existing_parts]
+                pl.concat(lazy_frames).sink_csv(target_path)
+            elif save_format == "parquet":
+                lazy_frames = [pl.scan_parquet(p) for p in existing_parts]
+                pl.concat(lazy_frames).sink_parquet(target_path)
+            else:
+                raise ValueError(
+                    f"Unsupported save format: {save_format}. Supported: csv, parquet"
+                )
+
+        for part_path in part_paths:
+            if part_path.exists():
+                part_path.unlink()
+        if parts_dir.exists() and not any(parts_dir.iterdir()):
+            parts_dir.rmdir()
+
+        logging.info(
+            colorize(
+                f"Predictions saved to: {target_path} (merged from {num_processes} parts)",
+                color="green",
+            )
+        )
+        return target_path
+
     def prepare_onnx_dataloader(
         self,
         data: str | dict | pd.DataFrame | DataLoader,
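Two details of predict_streaming_multiprocess above are worth spelling out. First, the merge step is lazy: each worker's part file is scanned, the scans are concatenated, and the result is streamed into the final file, so the merge never materializes all predictions at once. The same pattern in isolation (paths are placeholders):

from pathlib import Path
import polars as pl

# Lazily scan each shard written by a worker, then stream the concatenation
# into the final file; only small batches are resident at any moment.
parts = sorted(Path(".predictions_parts").glob("predictions.part*.parquet"))
pl.concat([pl.scan_parquet(p) for p in parts]).sink_parquet("predictions.parquet")

Second, the workers are started from a "spawn" context, which requires the target function to be importable by name (hence predict_streaming_worker is defined at module scope near the end of this file) and pickles every argument, including the model itself via args=(self, ...), into each child process. Note also that after the unconditional join(), Process.exitcode can no longer be None, so the exitcode not in (0, None) check behaves like exitcode != 0.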
@@ -2074,11 +2176,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):

         """
         if isinstance(data, DataLoader):
-            if num_workers != 0:
-                logging.warning(
-                    "[Predict ONNX Warning] num_workers parameter is ignored when data is already a DataLoader. "
-                    "The DataLoader's existing num_workers configuration will be used."
-                )
             return data
         # if data is a file path, use streaming DataLoader
         # will set batch_size=1 cause each batch is a file chunk
@@ -2366,18 +2463,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         output = df_to_return

         if save_path is not None:
-            if not check_streaming_support(save_format):
-                logging.warning(
-                    f"
-                    "
+            if save_format not in {"csv", "parquet"}:
+                raise ValueError(
+                    f"Unsupported save format: {save_format}. "
+                    "Supported: csv, parquet"
                 )
-
-            suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
             target_path = get_save_path(
                 path=save_path,
                 default_dir=self.session.predictions_dir,
                 default_name="predictions",
-                suffix=
+                suffix=f".{save_format}",
             )
             if return_dataframe and isinstance(output, pd.DataFrame):
                 df_to_save = output
@@ -2390,12 +2485,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 df_to_save.to_csv(target_path, index=False)
             elif save_format == "parquet":
                 df_to_save.to_parquet(target_path, index=False)
-            elif save_format == "feather":
-                df_to_save.to_feather(target_path)
-            elif save_format == "excel":
-                df_to_save.to_excel(target_path, index=False)
-            elif save_format == "hdf5":
-                df_to_save.to_hdf(target_path, key="predictions", mode="w")
             else:
                 raise ValueError(f"Unsupported save format: {save_format}")
             logging.info(
@@ -2432,24 +2521,21 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             num_workers=num_workers,
         )

-        if not check_streaming_support(save_format):
-            logging.warning(
-                f"
-                "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
+        if save_format not in {"csv", "parquet"}:
+            raise ValueError(
+                f"Unsupported save format: {save_format}. " "Supported: csv, parquet"
             )
-
-        suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
         target_path = get_save_path(
             path=save_path,
             default_dir=self.session.predictions_dir,
             default_name="predictions",
-            suffix=
+            suffix=f".{save_format}",
             add_timestamp=False,
         )
         header_written = target_path.exists() and target_path.stat().st_size > 0
         parquet_writer = None
         pred_columns = None
-        collected_frames = []
+        cached_frames = []

         for batch_data in progress(data_loader, description="Predicting (ONNX)"):
             batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
@@ -2514,7 +2600,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):

         should_collect = return_dataframe or save_format not in {"csv", "parquet"}
         if should_collect:
-            collected_frames.append(df_batch)
+            cached_frames.append(df_batch)

         if save_format == "csv":
             df_batch.to_csv(
@@ -2538,20 +2624,11 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         if parquet_writer is not None:
             parquet_writer.close()

-        if save_format in ["feather", "excel", "hdf5"] and collected_frames:
-            combined_df = pd.concat(collected_frames, ignore_index=True)
-            if save_format == "feather":
-                combined_df.to_feather(target_path)
-            elif save_format == "excel":
-                combined_df.to_excel(target_path, index=False)
-            elif save_format == "hdf5":
-                combined_df.to_hdf(target_path, key="predictions", mode="w")
-
         logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
         if return_dataframe:
             return (
-                pd.concat(
-                if
+                pd.concat(cached_frames, ignore_index=True)
+                if cached_frames
                 else pd.DataFrame(columns=pred_columns or [])
             )
         return target_path
@@ -2738,6 +2815,36 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         return model


+def predict_streaming_worker(
+    model: "BaseModel",
+    data_path: str | os.PathLike,
+    batch_size: int,
+    save_path: str | os.PathLike,
+    save_format: str,
+    stream_chunk_size: int,
+    num_workers: int,
+    prefetch_factor: int | None,
+    processor: Any | None,
+    shard_rank: int,
+    shard_count: int,
+) -> None:
+    model.eval()
+    model.predict_streaming(
+        data=data_path,
+        batch_size=batch_size,
+        save_path=save_path,
+        save_format=save_format,
+        stream_chunk_size=stream_chunk_size,
+        return_dataframe=False,
+        num_workers=num_workers,
+        prefetch_factor=prefetch_factor,
+        processor=processor,
+        num_processes=1,
+        shard_rank=shard_rank,
+        shard_count=shard_count,
+    )
+
+
 class BaseMatchModel(BaseModel):
     """
     Base class for match (retrieval/recall) models