nextrec 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on 25/01/2026
+Checkpoint: edit on 01/02/2026
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -13,20 +13,18 @@ import os
 import sys
 import pickle
 import socket
+import multiprocessing as mp
 from pathlib import Path
 from typing import Any, Literal, cast, overload
 
 import numpy as np
 import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
+import polars as pl
+import swanlab
+import wandb
 
-try:
-    import swanlab  # type: ignore
-except ModuleNotFoundError:
-    swanlab = None
-try:
-    import wandb  # type: ignore
-except ModuleNotFoundError:
-    wandb = None
 
 import torch
 import torch.distributed as dist
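Note: this hunk turns swanlab and wandb from optional, guarded imports into unconditional ones, and adds pyarrow, polars, and multiprocessing at module scope. In 0.5.2, importing nextrec.basic.model therefore raises ModuleNotFoundError when any of these packages is missing, instead of degrading gracefully. A minimal preflight sketch (dependency names read off this hunk, not from package metadata):

import importlib.util

# With 0.5.2 these are hard imports at module load time, so check up front.
for dep in ("pyarrow", "polars", "swanlab", "wandb"):
    if importlib.util.find_spec(dep) is None:
        print(f"nextrec 0.5.2 needs '{dep}' but it is not installed")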
@@ -65,15 +63,9 @@ from nextrec.data.dataloader import (
     TensorDictDataset,
     build_tensors_from_data,
 )
-from nextrec.utils.data import check_streaming_support
-from nextrec.loss import (
-    BPRLoss,
-    GradNormLossWeighting,
-    HingeLoss,
-    InfoNCELoss,
-    SampledSoftmaxLoss,
-    TripletLoss,
-)
+from nextrec.loss.grad_norm import GradNormLossWeighting
+from nextrec.loss.listwise import InfoNCELoss, SampledSoftmaxLoss
+from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
 from nextrec.utils.loss import get_loss_fn
 from nextrec.loss.grad_norm import get_grad_norm_shared_params
 from nextrec.utils.console import display_metrics_table, progress
@@ -111,8 +103,6 @@ from nextrec.utils.types import (
     MetricsName,
 )
 
-from nextrec.utils.data import FILE_FORMAT_CONFIG
-
 
 class BaseModel(SummarySet, FeatureSet, nn.Module):
     @property
@@ -1619,14 +1609,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 )
             )
             return {}
-        # if self.is_main_process:
-        #     logging.info(
-        #         colorize(
-        #             format_kv(
-        #                 "Evaluation samples", y_true_all.shape[0]
-        #             ),
-        #         )
-        #     )
+
         logging.info("")
         metrics_dict = evaluate_metrics(
             y_true=y_true_all,
@@ -1643,106 +1626,141 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
     @overload
     def predict(
         self,
-        data: str | dict | pd.DataFrame | DataLoader,
+        data: str | os.PathLike | DataLoader,
         batch_size: int = 32,
         save_path: str | os.PathLike | None = None,
         save_format: str = "csv",
-        include_ids: bool | None = None,
-        id_columns: str | list[str] | None = None,
         return_dataframe: Literal[True] = True,
         stream_chunk_size: int = 10000,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        num_processes: int = 1,
+        processor: Any | None = None,
     ) -> pd.DataFrame: ...
 
     @overload
     def predict(
         self,
-        data: str | dict | pd.DataFrame | DataLoader,
+        data: str | os.PathLike | DataLoader,
         batch_size: int = 32,
         save_path: None = None,
         save_format: str = "csv",
-        include_ids: bool | None = None,
-        id_columns: str | list[str] | None = None,
        return_dataframe: Literal[False] = False,
        stream_chunk_size: int = 10000,
        num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        num_processes: int = 1,
+        processor: Any | None = None,
     ) -> np.ndarray: ...
 
     @overload
     def predict(
         self,
-        data: str | dict | pd.DataFrame | DataLoader,
+        data: str | os.PathLike | dict | pd.DataFrame | DataLoader,
         batch_size: int = 32,
         *,
         save_path: str | os.PathLike,
         save_format: str = "csv",
-        include_ids: bool | None = None,
-        id_columns: str | list[str] | None = None,
         return_dataframe: Literal[False] = False,
         stream_chunk_size: int = 10000,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        num_processes: int = 1,
+        processor: Any | None = None,
     ) -> Path: ...
 
     def predict(
         self,
-        data: str | dict | pd.DataFrame | DataLoader,
+        data: str | os.PathLike | dict | pd.DataFrame | DataLoader,
         batch_size: int = 32,
         save_path: str | os.PathLike | None = None,
         save_format: str = "csv",
-        include_ids: bool | None = None,
-        id_columns: str | list[str] | None = None,
         return_dataframe: bool = True,
         stream_chunk_size: int = 10000,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        num_processes: int = 1,
+        processor: Any | None = None,
     ) -> pd.DataFrame | np.ndarray | Path | None:
         """
         Make predictions on the given data.
 
         Args:
-            data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
+            data: Input data for prediction (file path or DataLoader).
             batch_size: Batch size for prediction (per process when distributed).
             save_path: Optional path to save predictions; if None, predictions are not saved to disk.
             save_format: Format to save predictions ('csv' or 'parquet').
-            include_ids: Whether to include ID columns in the output; if None, includes if id_columns are set.
-            id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
             return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
             stream_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
             num_workers: DataLoader worker count.
+            prefetch_factor: Number of batches prefetched per worker (only when num_workers > 0).
+            num_processes: Number of inference processes for streaming file inference.
+            processor: Optional DataProcessor for transforming input data.
 
         Note:
-            predict does not support distributed mode currently, consider it as a single-process operation.
+            predict does not support distributed mode currently; streaming file inference can use
+            multiple processes via num_processes > 1, which may change output order.
         """
         self.eval()
-        # Use prediction-time id_columns if provided, otherwise fall back to model's id_columns
-        predict_id_columns = id_columns if id_columns is not None else self.id_columns
-        if isinstance(predict_id_columns, str):
-            predict_id_columns = [predict_id_columns]
-
-        if include_ids is None:
-            include_ids = bool(predict_id_columns)
-        include_ids = include_ids and bool(predict_id_columns)
 
-        # Use streaming mode for large file saves without loading all data into memory
-        if save_path is not None and not return_dataframe:
+        # streaming mode prediction
+        if (
+            save_path is not None
+            and not return_dataframe
+            and isinstance(data, (str, os.PathLike))
+        ):
+            if num_processes > 1 and not isinstance(data, (str, os.PathLike)):
+                raise ValueError(
+                    "[BaseModel-predict Error] Multi-process streaming requires data to be a file path."
+                )
+            if num_workers != 0:
+                logging.info(
+                    "[BaseModel-predict-streaming Info] Streaming mode enforces num_workers=0."
+                )
+                logging.info("")
             return self.predict_streaming(
                 data=data,
                 batch_size=batch_size,
                 save_path=save_path,
                 save_format=save_format,
-                include_ids=include_ids,
                 stream_chunk_size=stream_chunk_size,
                 return_dataframe=return_dataframe,
-                id_columns=predict_id_columns,
+                num_workers=0,
+                num_processes=num_processes,
+                processor=processor,
             )
 
-        # Create DataLoader based on data type
+        return self.predict_in_memory(
+            data=data,
+            batch_size=batch_size,
+            save_path=save_path,
+            save_format=save_format,
+            return_dataframe=return_dataframe,
+            stream_chunk_size=stream_chunk_size,
+            num_workers=num_workers,
+            prefetch_factor=prefetch_factor,
+            processor=processor,
+        )
+
+    def predict_in_memory(
+        self,
+        data: str | os.PathLike | dict | pd.DataFrame | DataLoader,
+        batch_size: int = 32,
+        save_path: str | os.PathLike | None = None,
+        save_format: str = "csv",
+        return_dataframe: bool = True,
+        stream_chunk_size: int = 10000,
+        num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        processor: Any | None = None,
+    ) -> pd.DataFrame | np.ndarray | Path | None:
+
+        predict_id_columns = self.id_columns
+        if isinstance(predict_id_columns, str):
+            predict_id_columns = [predict_id_columns]
+        include_ids = bool(predict_id_columns)
         if isinstance(data, DataLoader):
             data_loader = data
-            if num_workers != 0:
-                logging.warning(
-                    "[Predict Warning] num_workers parameter is ignored when data is already a DataLoader. "
-                    "The DataLoader's existing num_workers configuration will be used."
-                )
         elif isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
                 dense_features=self.dense_features,
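Note: the hunk above splits predict into a streaming path (file in, file out) and a new predict_in_memory path, drops the include_ids/id_columns parameters (ID columns now always come from the model), and adds prefetch_factor, num_processes, and processor. A hedged usage sketch against the new signature; the model instance, feature names, and file paths below are placeholders:

# Streaming branch: requires a file path, a save_path, and return_dataframe=False.
out_path = model.predict(
    data="interactions.parquet",
    batch_size=512,
    save_path="preds.parquet",
    save_format="parquet",
    return_dataframe=False,
    stream_chunk_size=50_000,
    num_processes=4,  # multi-process streaming; output row order may change
)

# In-memory branch: dicts and DataFrames are still accepted here.
df = model.predict(data={"user_id": [1, 2], "item_id": [7, 9]}, batch_size=2)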
@@ -1750,6 +1768,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             sequence_features=self.sequence_features,
             target=self.target_columns,
             id_columns=predict_id_columns,
+            processor=processor,
         )
         data_loader = rec_loader.create_dataloader(
             data=data,
@@ -1757,6 +1776,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             shuffle=False,
             streaming=True,
             chunk_size=stream_chunk_size,
+            num_workers=0,
+            prefetch_factor=prefetch_factor,
         )
     else:
         data_loader = self.prepare_data_loader(
@@ -1834,23 +1855,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             else y_pred_all
         )
         if save_path is not None:
-            # Check streaming write support
-            if not check_streaming_support(save_format):
-                logging.warning(
-                    f"[BaseModel-predict Warning] Format '{save_format}' does not support streaming writes. "
-                    "The entire result will be saved at once. Use csv or parquet for large datasets."
+            if save_format not in {"csv", "parquet"}:
+                raise ValueError(
+                    f"Unsupported save format: {save_format}. "
+                    "Supported: csv, parquet"
                 )
-
-            # Get file extension from format
-            from nextrec.utils.data import FILE_FORMAT_CONFIG
-
-            suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
-
             target_path = get_save_path(
                 path=save_path,
                 default_dir=self.session.predictions_dir,
                 default_name="predictions",
-                suffix=suffix,
+                suffix=f".{save_format}",
                 add_timestamp=True if save_path is None else False,
             )
             if isinstance(output, pd.DataFrame):
@@ -1870,12 +1884,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 df_to_save.to_csv(target_path, index=False)
             elif save_format == "parquet":
                 df_to_save.to_parquet(target_path, index=False)
-            elif save_format == "feather":
-                df_to_save.to_feather(target_path)
-            elif save_format == "excel":
-                df_to_save.to_excel(target_path, index=False)
-            elif save_format == "hdf5":
-                df_to_save.to_hdf(target_path, key="predictions", mode="w")
             else:
                 raise ValueError(f"Unsupported save format: {save_format}")
 
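Note: feather, excel, and hdf5 outputs are removed here and in the ONNX paths below; an unrecognized save_format is now a hard ValueError rather than a warning. A possible migration for callers that relied on a dropped format (paths are placeholders):

# Save via the supported parquet path, then convert with pandas if needed.
df = model.predict(data="interactions.parquet", return_dataframe=True)
df.to_feather("predictions.feather")  # pandas writes feather through pyarrow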
@@ -1886,37 +1894,64 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
 
     def predict_streaming(
         self,
-        data: str | dict | pd.DataFrame | DataLoader,
+        data: str | os.PathLike | DataLoader,
         batch_size: int,
         save_path: str | os.PathLike,
         save_format: str,
-        include_ids: bool,
         stream_chunk_size: int,
         return_dataframe: bool,
-        id_columns: list[str] | None = None,
+        num_workers: int = 0,
+        prefetch_factor: int | None = None,
+        num_processes: int = 1,
+        processor: Any | None = None,
+        shard_rank: int = 0,
+        shard_count: int = 1,
     ):
         """
         Make predictions on the given data using streaming mode for large datasets.
 
         Args:
-            data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
+            data: Input data for prediction (file path or DataLoader).
             batch_size: Batch size for prediction.
             save_path: Path to save predictions.
             save_format: Format to save predictions ('csv' or 'parquet').
-            include_ids: Whether to include ID columns in the output.
             stream_chunk_size: Number of rows per chunk when using streaming mode.
             return_dataframe: Whether to return predictions as a pandas DataFrame.
-            id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
-        Note:
-            This method uses streaming writes to handle large datasets without loading all data into memory.
+            num_workers: DataLoader worker count.
+            prefetch_factor: Number of batches prefetched per worker (only when num_workers > 0).
+            num_processes: Number of inference processes for streaming file inference.
+            processor: Optional DataProcessor for transforming input data.
+            shard_rank: Process shard rank for multi-process inference.
+            shard_count: Total number of shards for multi-process inference.
         """
+        predict_id_columns = self.id_columns
+        if isinstance(predict_id_columns, str):
+            predict_id_columns = [predict_id_columns]
+        include_ids = bool(predict_id_columns)
+
+        # Multi-process streaming
+        if num_processes > 1:
+            return self.predict_streaming_multiprocess(
+                data=data,
+                batch_size=batch_size,
+                save_path=save_path,
+                save_format=save_format,
+                stream_chunk_size=stream_chunk_size,
+                return_dataframe=return_dataframe,
+                num_workers=num_workers,
+                prefetch_factor=None,  # disable prefetching in multi-process mode
+                num_processes=num_processes,
+                processor=processor,
+            )
+        # Single-process streaming
         if isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
                 dense_features=self.dense_features,
                 sparse_features=self.sparse_features,
                 sequence_features=self.sequence_features,
                 target=self.target_columns,
-                id_columns=id_columns,
+                id_columns=predict_id_columns,
+                processor=processor,
             )
             data_loader = rec_loader.create_dataloader(
                 data=data,
@@ -1924,53 +1959,41 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 shuffle=False,
                 streaming=True,
                 chunk_size=stream_chunk_size,
+                num_workers=num_workers,
+                prefetch_factor=None if num_workers == 0 else prefetch_factor,
+                shard_rank=shard_rank,
+                shard_count=shard_count,
             )
         elif not isinstance(data, DataLoader):
-            data_loader = self.prepare_data_loader(
-                data,
-                batch_size=batch_size,
-                shuffle=False,
+            raise TypeError(
+                "[BaseModel-predict-streaming Error] data must be a file path or a DataLoader."
             )
-        else:
+        else:  # data is a DataLoader
             data_loader = data
 
-        if hasattr(data_loader, "num_workers") and data_loader.num_workers > 0:
-            if (
-                hasattr(data_loader.dataset, "__class__")
-                and "Streaming" in data_loader.dataset.__class__.__name__
-            ):
-                logging.warning(
-                    f"[Predict Streaming Warning] Detected DataLoader with num_workers={data_loader.num_workers} "
-                    "and streaming dataset. This may cause data duplication! "
-                    "When using streaming mode, set num_workers=0 to avoid reading data multiple times."
-                )
-
-        # Check streaming support and prepare file path
-        if not check_streaming_support(save_format):
-            logging.warning(
-                f"[Predict Streaming Warning] Format '{save_format}' does not support streaming writes. "
-                "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
+        if save_format not in {"csv", "parquet"}:
+            raise ValueError(
+                f"Unsupported save format: {save_format}. Supported: csv, parquet"
             )
-
-        suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
-
         target_path = get_save_path(
             path=save_path,
             default_dir=self.session.predictions_dir,
             default_name="predictions",
-            suffix=suffix,
+            suffix=f".{save_format}",
             add_timestamp=True if save_path is None else False,
         )
         target_path.parent.mkdir(parents=True, exist_ok=True)
-        header_written = target_path.exists() and target_path.stat().st_size > 0
+        header_written = target_path.exists()
        parquet_writer = None
+
        pred_columns = None
-        collected_frames = (
-            []
-        )  # used when return_dataframe=True or for non-streaming formats
+        cached_frames = []  # used when return_dataframe=True
 
+        disable_progress = shard_count > 1
        with torch.no_grad():
-            for batch_data in progress(data_loader, description="Predicting"):
+            for batch_data in progress(
+                data_loader, description="Predicting", disable=disable_progress
+            ):
                batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
                X_input, _ = self.get_input(batch_dict, require_labels=False)
                y_pred = self.forward(X_input)
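Note: the new shard_rank/shard_count arguments are forwarded to RecDataLoader.create_dataloader; the actual dispatch logic lives there and is not part of this diff. Illustrative only, assuming chunks are dealt out round-robin by index:

def shard_owns_chunk(chunk_index: int, shard_rank: int, shard_count: int) -> bool:
    # Hypothetical: each process reads the same file but keeps only every
    # shard_count-th chunk, offset by its rank.
    return chunk_index % shard_count == shard_rank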
@@ -1989,14 +2012,18 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 while len(pred_columns) < num_outputs:
                     pred_columns.append(f"pred_{len(pred_columns)}")
 
-                ids = batch_dict.get("ids") if include_ids and id_columns else None
+                ids = (
+                    batch_dict.get("ids")
+                    if include_ids and predict_id_columns
+                    else None
+                )
                 id_arrays_batch = {
                     id_name: (
                         ids[id_name].detach().cpu().numpy()
                         if isinstance(ids[id_name], torch.Tensor)
                         else np.asarray(ids[id_name])
                     ).reshape(-1)
-                    for id_name in (id_columns or [])
+                    for id_name in (predict_id_columns or [])
                     if ids and id_name in ids
                 }
 
@@ -2015,48 +2042,123 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                         target_path, mode="a", header=not header_written, index=False
                     )
                     header_written = True
+                    if return_dataframe:
+                        cached_frames.append(df_batch)
                 elif save_format == "parquet":
-                    try:
-                        import pyarrow as pa
-                        import pyarrow.parquet as pq
-                    except ImportError as exc:  # pragma: no cover
-                        raise ImportError(
-                            "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow."
-                        ) from exc
                     table = pa.Table.from_pandas(df_batch, preserve_index=False)
                     if parquet_writer is None:
                         parquet_writer = pq.ParquetWriter(target_path, table.schema)
                     parquet_writer.write_table(table)
+                    if return_dataframe:
+                        cached_frames.append(df_batch)
                 else:
                     # Non-streaming formats: collect all data
-                    collected_frames.append(df_batch)
-
-                if return_dataframe and save_format in ["csv", "parquet"]:
-                    collected_frames.append(df_batch)
+                    cached_frames.append(df_batch)
 
         # Close writers
         if parquet_writer is not None:
             parquet_writer.close()
-        # For non-streaming formats, save collected data
-        if save_format in ["feather", "excel", "hdf5"] and collected_frames:
-            combined_df = pd.concat(collected_frames, ignore_index=True)
-            if save_format == "feather":
-                combined_df.to_feather(target_path)
-            elif save_format == "excel":
-                combined_df.to_excel(target_path, index=False)
-            elif save_format == "hdf5":
-                combined_df.to_hdf(target_path, key="predictions", mode="w")
 
         logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
         if return_dataframe:
             return (
-                pd.concat(collected_frames, ignore_index=True)
-                if collected_frames
+                pd.concat(cached_frames, ignore_index=True)
+                if cached_frames
                 else pd.DataFrame(columns=pred_columns or [])
             )
         # Return the actual save path when not returning dataframe
         return target_path
 
+    def predict_streaming_multiprocess(
+        self,
+        data: str | os.PathLike | DataLoader,
+        batch_size: int,
+        save_path: str | os.PathLike,
+        save_format: str,
+        stream_chunk_size: int,
+        return_dataframe: bool,
+        num_workers: int,
+        prefetch_factor: int | None,
+        num_processes: int,
+        processor: Any | None,
+    ):
+        target_path = Path(
+            get_save_path(
+                path=save_path,
+                default_dir=self.session.predictions_dir,
+                default_name="predictions",
+                suffix=f".{save_format}",
+                add_timestamp=True if save_path is None else False,
+            )
+        )
+        parts_dir = target_path.parent / f".{target_path.stem}_parts"
+        parts_dir.mkdir(parents=True, exist_ok=True)
+        part_paths = [
+            parts_dir / f"{target_path.stem}.part{rank}{target_path.suffix}"
+            for rank in range(num_processes)
+        ]
+
+        ctx = mp.get_context("spawn")
+        processes = []
+        for rank in range(num_processes):
+            process = ctx.Process(
+                target=predict_streaming_worker,
+                args=(
+                    self,
+                    data,
+                    batch_size,
+                    part_paths[rank],
+                    save_format,
+                    stream_chunk_size,
+                    num_workers,
+                    prefetch_factor,
+                    processor,
+                    rank,
+                    num_processes,
+                ),
+            )
+            process.start()
+            processes.append(process)
+
+        for process in progress(
+            iter(processes), description="Predicting...", total=None
+        ):
+            process.join()
+
+        for process in processes:
+            if process.exitcode not in (0, None):
+                raise RuntimeError(
+                    "[BaseModel-predict-streaming Error] One or more inference processes failed."
+                )
+        # Merge part files
+        existing_parts = [p for p in part_paths if p.exists()]
+        if existing_parts:
+            target_path.parent.mkdir(parents=True, exist_ok=True)
+            if save_format == "csv":
+                lazy_frames = [pl.scan_csv(p) for p in existing_parts]
+                pl.concat(lazy_frames).sink_csv(target_path)
+            elif save_format == "parquet":
+                lazy_frames = [pl.scan_parquet(p) for p in existing_parts]
+                pl.concat(lazy_frames).sink_parquet(target_path)
+            else:
+                raise ValueError(
+                    f"Unsupported save format: {save_format}. Supported: csv, parquet"
+                )
+
+        for part_path in part_paths:
+            if part_path.exists():
+                part_path.unlink()
+        if parts_dir.exists() and not any(parts_dir.iterdir()):
+            parts_dir.rmdir()
+
+        logging.info(
+            colorize(
+                f"Predictions saved to: {target_path} (merged from {num_processes} parts)",
+                color="green",
+            )
+        )
+        return target_path
+
     def prepare_onnx_dataloader(
         self,
         data: str | dict | pd.DataFrame | DataLoader,
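Note: the merge step above concatenates the per-process part files lazily with polars, so the combined result streams to disk instead of being materialized in memory. The same pattern standalone (part paths are placeholders):

import polars as pl

parts = ["preds.part0.csv", "preds.part1.csv"]
pl.concat([pl.scan_csv(p) for p in parts]).sink_csv("preds.csv")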
@@ -2074,11 +2176,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
 
         """
         if isinstance(data, DataLoader):
-            if num_workers != 0:
-                logging.warning(
-                    "[Predict ONNX Warning] num_workers parameter is ignored when data is already a DataLoader. "
-                    "The DataLoader's existing num_workers configuration will be used."
-                )
             return data
         # if data is a file path, use streaming DataLoader
         # will set batch_size=1 cause each batch is a file chunk
@@ -2366,18 +2463,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         output = df_to_return
 
         if save_path is not None:
-            if not check_streaming_support(save_format):
-                logging.warning(
-                    f"[BaseModel-predict-onnx Warning] Format '{save_format}' does not support streaming writes. "
-                    "The entire result will be saved at once. Use csv or parquet for large datasets."
+            if save_format not in {"csv", "parquet"}:
+                raise ValueError(
+                    f"Unsupported save format: {save_format}. "
+                    "Supported: csv, parquet"
                 )
-
-            suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
             target_path = get_save_path(
                 path=save_path,
                 default_dir=self.session.predictions_dir,
                 default_name="predictions",
-                suffix=suffix,
+                suffix=f".{save_format}",
             )
             if return_dataframe and isinstance(output, pd.DataFrame):
                 df_to_save = output
@@ -2390,12 +2485,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 df_to_save.to_csv(target_path, index=False)
             elif save_format == "parquet":
                 df_to_save.to_parquet(target_path, index=False)
-            elif save_format == "feather":
-                df_to_save.to_feather(target_path)
-            elif save_format == "excel":
-                df_to_save.to_excel(target_path, index=False)
-            elif save_format == "hdf5":
-                df_to_save.to_hdf(target_path, key="predictions", mode="w")
             else:
                 raise ValueError(f"Unsupported save format: {save_format}")
             logging.info(
@@ -2432,24 +2521,21 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             num_workers=num_workers,
         )
 
-        if not check_streaming_support(save_format):
-            logging.warning(
-                f"[Predict ONNX Streaming Warning] Format '{save_format}' does not support streaming writes. "
-                "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
+        if save_format not in {"csv", "parquet"}:
+            raise ValueError(
+                f"Unsupported save format: {save_format}. " "Supported: csv, parquet"
             )
-
-        suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
         target_path = get_save_path(
             path=save_path,
             default_dir=self.session.predictions_dir,
             default_name="predictions",
-            suffix=suffix,
+            suffix=f".{save_format}",
             add_timestamp=False,
         )
         header_written = target_path.exists() and target_path.stat().st_size > 0
         parquet_writer = None
         pred_columns = None
-        collected_frames = []
+        cached_frames = []
 
         for batch_data in progress(data_loader, description="Predicting (ONNX)"):
             batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
@@ -2514,7 +2600,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
 
             should_collect = return_dataframe or save_format not in {"csv", "parquet"}
             if should_collect:
-                collected_frames.append(df_batch)
+                cached_frames.append(df_batch)
 
             if save_format == "csv":
                 df_batch.to_csv(
@@ -2538,20 +2624,11 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         if parquet_writer is not None:
             parquet_writer.close()
 
-        if save_format in ["feather", "excel", "hdf5"] and collected_frames:
-            combined_df = pd.concat(collected_frames, ignore_index=True)
-            if save_format == "feather":
-                combined_df.to_feather(target_path)
-            elif save_format == "excel":
-                combined_df.to_excel(target_path, index=False)
-            elif save_format == "hdf5":
-                combined_df.to_hdf(target_path, key="predictions", mode="w")
-
         logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
         if return_dataframe:
             return (
-                pd.concat(collected_frames, ignore_index=True)
-                if collected_frames
+                pd.concat(cached_frames, ignore_index=True)
+                if cached_frames
                 else pd.DataFrame(columns=pred_columns or [])
             )
         return target_path
@@ -2738,6 +2815,36 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         return model
 
 
+def predict_streaming_worker(
+    model: "BaseModel",
+    data_path: str | os.PathLike,
+    batch_size: int,
+    save_path: str | os.PathLike,
+    save_format: str,
+    stream_chunk_size: int,
+    num_workers: int,
+    prefetch_factor: int | None,
+    processor: Any | None,
+    shard_rank: int,
+    shard_count: int,
+) -> None:
+    model.eval()
+    model.predict_streaming(
+        data=data_path,
+        batch_size=batch_size,
+        save_path=save_path,
+        save_format=save_format,
+        stream_chunk_size=stream_chunk_size,
+        return_dataframe=False,
+        num_workers=num_workers,
+        prefetch_factor=prefetch_factor,
+        processor=processor,
+        num_processes=1,
+        shard_rank=shard_rank,
+        shard_count=shard_count,
+    )
+
+
 class BaseMatchModel(BaseModel):
     """
     Base class for match (retrieval/recall) models
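Note: predict_streaming_worker appears to be added at module level rather than as a method, with the model passed in explicitly; that matters because the spawn context pickles each Process target, and top-level functions pickle cleanly by reference. A minimal sketch of the same pattern:

import multiprocessing as mp

def work(rank: int) -> None:  # top level, so spawn can pickle it by reference
    print(f"worker {rank} finished")

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    procs = [ctx.Process(target=work, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()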