nextrec 0.4.34__py3-none-any.whl → 0.5.0__py3-none-any.whl
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +7 -13
- nextrec/basic/layers.py +28 -94
- nextrec/basic/model.py +512 -4
- nextrec/cli.py +101 -18
- nextrec/data/data_processing.py +8 -13
- nextrec/data/preprocessor.py +449 -846
- nextrec/models/ranking/afm.py +4 -9
- nextrec/models/ranking/dien.py +7 -8
- nextrec/models/ranking/ffm.py +2 -2
- nextrec/models/retrieval/sdm.py +1 -2
- nextrec/models/sequential/hstu.py +0 -2
- nextrec/utils/onnx_utils.py +252 -0
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/METADATA +10 -4
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/RECORD +18 -18
- nextrec/models/multi_task/[pre]star.py +0 -192
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/WHEEL +0 -0
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/licenses/LICENSE +0 -0
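The headline change in 0.5.0 is ONNX support: a new `nextrec/utils/onnx_utils.py` helper module plus `export_onnx`, `predict_onnx`, and `predict_onnx_streaming` methods on `BaseModel` (see the `model.py` diff below). As a minimal export sketch based on the `Usage` block in the new docstring — `model` here is an assumed, already-trained `BaseModel` subclass, not something defined in this diff:

```python
# Sketch only: `model` is an assumed, trained BaseModel subclass.
onnx_path = model.export_onnx(
    save_path="model.onnx",  # None falls back to the session root
    batch_size=1,            # dummy batch size used for tracing
    opset_version=18,        # the TorchDynamo exporter requires opset >= 18
)
```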
nextrec/basic/model.py
CHANGED
```diff
@@ -2,11 +2,12 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 25/01/2026
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
 import getpass
+import inspect
 import logging
 import os
 import sys
@@ -90,6 +91,15 @@ from nextrec.utils.model import (
     compute_ranking_loss,
     get_loss_list,
 )
+from nextrec.utils.onnx_utils import (
+    OnnxModelWrapper,
+    build_onnx_input_feed,
+    create_dummy_inputs,
+    load_onnx_session,
+    merge_onnx_outputs,
+    pad_onnx_inputs,
+    pad_id_batch,
+)
 
 from nextrec.utils.types import (
     LossName,
@@ -312,7 +322,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         """
         Prepare unified input features and labels from the given input data.
 
-
         Args:
             input_data: Input data dictionary containing 'features' and optionally 'labels', e.g., {'features': {'feat1': [...], 'feat2': [...]}, 'labels': {'label': [...]}}.
             require_labels: Whether labels are required in the input data. Default is True: for training and evaluation with labels.
@@ -367,14 +376,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 target_tensor = to_tensor(
                     target_data, dtype=torch.float32, device=self.device
                 )
-                target_tensor = target_tensor.
+                target_tensor = target_tensor.reshape(
                     target_tensor.size(0), -1
                 )  # always reshape to (batch_size, num_targets)
                 target_tensors.append(target_tensor)
             if target_tensors:
                 y = torch.cat(target_tensors, dim=1)
                 if y.shape[1] == 1:  # no need to do that again
-                    y = y.
+                    y = y.reshape(-1)
             elif require_labels:
                 raise ValueError(
                     "[BaseModel-input Error] Labels are required but none were found in the input batch."
@@ -2048,6 +2057,505 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         # Return the actual save path when not returning dataframe
         return target_path
 
+    def prepare_onnx_dataloader(
+        self,
+        data: str | dict | pd.DataFrame | DataLoader,
+        batch_size: int,
+        num_workers: int,
+    ) -> DataLoader:
+        """
+        Prepare a DataLoader for ONNX prediction.
+
+        Args:
+            data: Input data (file path, dict, DataFrame, or DataLoader).
+            batch_size: Effective batch size. When `data` is a file path, this is
+                treated as the streaming chunk size.
+            num_workers: Number of DataLoader workers.
+
+        """
+        if isinstance(data, DataLoader):
+            if num_workers != 0:
+                logging.warning(
+                    "[Predict ONNX Warning] num_workers parameter is ignored when data is already a DataLoader. "
+                    "The DataLoader's existing num_workers configuration will be used."
+                )
+            return data
+        # if data is a file path, use streaming DataLoader
+        # will set batch_size=1 cause each batch is a file chunk
+        if isinstance(data, (str, os.PathLike)):
+            rec_loader = RecDataLoader(
+                dense_features=self.dense_features,
+                sparse_features=self.sparse_features,
+                sequence_features=self.sequence_features,
+                target=self.target_columns,
+                id_columns=self.id_columns,
+            )
+            return rec_loader.create_dataloader(
+                data=data,
+                batch_size=1,
+                shuffle=False,
+                streaming=True,
+                chunk_size=batch_size,
+            )
+        return self.prepare_data_loader(
+            data, batch_size=batch_size, shuffle=False, num_workers=num_workers
+        )
+
+    def export_onnx(
+        self,
+        save_path: str | Path | None = None,
+        batch_size: int = 1,
+        opset_version: int = 18,
+    ) -> Path:
+        """
+        Export the model to ONNX.
+
+        Usage:
+            onnx_path = model.export_onnx(
+                save_path="model.onnx",
+                batch_size=1,
+                opset_version=18,
+            )
+
+        Args:
+            save_path: Path to save the ONNX model; if None, uses session root.
+            batch_size: Dummy batch size for tracing.
+            opset_version: ONNX opset version for export.
+        """
+        model_to_export = (
+            self.ddp_model.module
+            if hasattr(self, "ddp_model") and self.ddp_model is not None
+            else self
+        )
+        model_to_export = model_to_export.to(self.device)
+        model_to_export.eval()
+
+        input_names = [feat.name for feat in self.all_features]
+        dummy_inputs = create_dummy_inputs(
+            self.all_features,
+            batch_size=batch_size,
+            device=self.device,
+        )
+        wrapper = OnnxModelWrapper(model_to_export, input_names)
+        with torch.no_grad():
+            output_sample = wrapper(*dummy_inputs)
+        if isinstance(output_sample, (tuple, list)):  # multiple outputs
+            output_names = [f"output_{idx}" for idx in range(len(output_sample))]
+        else:
+            output_names = ["output"]
+        target_path = get_save_path(
+            path=save_path,
+            default_dir=self.session.root,
+            default_name=f"{self.model_name}_onnx",
+            suffix="onnx",
+        )
+        export_kwargs: dict[str, Any] = {}
+        export_sig = inspect.signature(torch.onnx.export)
+        if "dynamo" in export_sig.parameters:
+            export_kwargs["dynamo"] = True
+            if opset_version < 18:
+                logging.warning(
+                    "[BaseModel-export-onnx Warning] TorchDynamo exporter requires opset >= 18. "
+                    "Overriding opset_version to 18."
+                )
+                opset_version = 18
+
+        torch.onnx.export(
+            wrapper,
+            tuple(dummy_inputs),
+            target_path,
+            input_names=list(input_names),
+            output_names=list(output_names),
+            opset_version=opset_version,
+            do_constant_folding=True,
+            **export_kwargs,
+        )
+
+        logging.info(colorize(f"ONNX model exported to: {target_path}", color="green"))
+        return target_path
+
+    @overload
+    def predict_onnx(
+        self,
+        onnx_path: str | os.PathLike | Path,
+        data: str | dict | pd.DataFrame | DataLoader,
+        batch_size: int = 32,
+        save_path: str | os.PathLike | None = None,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: Literal[True] = True,
+        num_workers: int = 0,
+        onnx_session: Any | None = None,
+    ) -> pd.DataFrame: ...
+
+    @overload
+    def predict_onnx(
+        self,
+        onnx_path: str | os.PathLike | Path,
+        data: str | dict | pd.DataFrame | DataLoader,
+        batch_size: int = 32,
+        save_path: None = None,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: Literal[False] = False,
+        num_workers: int = 0,
+        onnx_session: Any | None = None,
+    ) -> np.ndarray: ...
+
+    @overload
+    def predict_onnx(
+        self,
+        onnx_path: str | os.PathLike | Path,
+        data: str | dict | pd.DataFrame | DataLoader,
+        batch_size: int = 32,
+        *,
+        save_path: str | os.PathLike,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: Literal[False] = False,
+        num_workers: int = 0,
+        onnx_session: Any | None = None,
+    ) -> Path: ...
+
+    def predict_onnx(
+        self,
+        onnx_path: str | os.PathLike | Path,
+        data: str | dict | pd.DataFrame | DataLoader,
+        batch_size: int = 32,
+        save_path: str | os.PathLike | None = None,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: bool = True,
+        num_workers: int = 0,
+        onnx_session: Any | None = None,
+    ) -> pd.DataFrame | np.ndarray | Path | None:
+        """
+        Run ONNX inference on the given data.
+
+        Args:
+            onnx_path: Path to the ONNX model file.
+            data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
+            batch_size: Batch size for prediction.
+            save_path: Optional path to save predictions; if None, predictions are not saved to disk.
+            save_format: Format to save predictions ('csv' or 'parquet').
+            include_ids: Whether to include ID columns in the output; if None, includes if id_columns are set.
+            id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
+            return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
+            num_workers: DataLoader worker count.
+            onnx_session: Optional pre-created ONNX Runtime session.
+        """
+        predict_id_columns = id_columns if id_columns is not None else self.id_columns
+        if isinstance(predict_id_columns, str):
+            predict_id_columns = [predict_id_columns]
+
+        if include_ids is None:
+            include_ids = bool(predict_id_columns)
+        include_ids = include_ids and bool(predict_id_columns)
+
+        if save_path is not None and not return_dataframe:
+            return self.predict_onnx_streaming(
+                onnx_path=onnx_path,
+                data=data,
+                batch_size=batch_size,
+                save_path=save_path,
+                save_format=save_format,
+                include_ids=include_ids,
+                return_dataframe=return_dataframe,
+                id_columns=predict_id_columns,
+                onnx_session=onnx_session,
+                num_workers=num_workers,
+            )
+
+        session = onnx_session or load_onnx_session(onnx_path)
+        session_inputs = session.get_inputs()
+        session_input_names = [inp.name for inp in session_inputs]
+        batch_dim = session_inputs[0].shape[0] if session_inputs else None
+        fixed_batch = (
+            batch_dim if isinstance(batch_dim, int) and batch_dim > 0 else None
+        )
+        data_loader = self.prepare_onnx_dataloader(
+            data=data,
+            batch_size=batch_size,
+            num_workers=num_workers,
+        )
+
+        y_pred_list = []
+        id_buffers = (
+            {name: [] for name in (predict_id_columns or [])} if include_ids else {}
+        )
+
+        for batch_data in progress(data_loader, description="Predicting (ONNX)"):
+            batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
+            X_input, _ = self.get_input(batch_dict, require_labels=False)
+            if X_input is None:
+                raise ValueError(
+                    "[BaseModel-predict-onnx Error] No input features found in the prediction data."
+                )
+            orig_batch = None
+            if fixed_batch is not None:
+                X_input, orig_batch = pad_onnx_inputs(
+                    self.all_features, X_input, target_batch=fixed_batch
+                )
+            feed = build_onnx_input_feed(
+                self.all_features, X_input, input_names=session_input_names
+            )
+            outputs = session.run(None, feed)
+            y_pred_np = merge_onnx_outputs(outputs)
+            if y_pred_np.ndim == 1:
+                y_pred_np = y_pred_np.reshape(-1, 1)
+            if orig_batch is not None and orig_batch > 0:
+                y_pred_np = y_pred_np[:orig_batch]
+            y_pred_list.append(y_pred_np)
+
+            if include_ids and predict_id_columns and batch_dict.get("ids"):
+                ids_dict = batch_dict["ids"]
+                orig_id_batch = None
+                if fixed_batch is not None:
+                    ids_dict, orig_id_batch = pad_id_batch(ids_dict, fixed_batch)
+                for id_name in predict_id_columns:
+                    if id_name not in ids_dict:
+                        continue
+                    id_tensor = ids_dict[id_name]
+                    id_np = (
+                        id_tensor.detach().cpu().numpy()
+                        if isinstance(id_tensor, torch.Tensor)
+                        else np.asarray(id_tensor)
+                    )
+                    if orig_batch is not None and orig_batch > 0:
+                        id_np = id_np[:orig_batch]
+                    elif orig_id_batch is not None and orig_id_batch > 0:
+                        id_np = id_np[:orig_id_batch]
+                    id_buffers[id_name].append(
+                        id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np
+                    )
+
+        y_pred_all = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
+        if y_pred_all is None:
+            return pd.DataFrame() if return_dataframe else np.array([])
+
+        num_outputs = y_pred_all.shape[1] if y_pred_all.ndim > 1 else 1
+        pred_columns = (
+            list(self.target_columns[:num_outputs]) if self.target_columns else []
+        )
+        while len(pred_columns) < num_outputs:
+            pred_columns.append(f"pred_{len(pred_columns)}")
+
+        id_df = None
+        if include_ids and predict_id_columns:
+            id_arrays = {
+                id_name: np.concatenate(id_buffers[id_name], axis=0)
+                for id_name in predict_id_columns
+                if id_buffers.get(id_name)
+            }
+            if id_arrays:
+                id_df = pd.DataFrame(id_arrays)
+                if len(id_df) and len(id_df) != len(y_pred_all):
+                    raise ValueError(
+                        f"[BaseModel-predict-onnx Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(y_pred_all)})."
+                    )
+
+        output = y_pred_all
+        if return_dataframe:
+            df_to_return = pd.DataFrame(y_pred_all, columns=pred_columns)
+            if id_df is not None:
+                df_to_return = pd.concat([id_df, df_to_return], axis=1)
+            output = df_to_return
+
+        if save_path is not None:
+            if not check_streaming_support(save_format):
+                logging.warning(
+                    f"[BaseModel-predict-onnx Warning] Format '{save_format}' does not support streaming writes. "
+                    "The entire result will be saved at once. Use csv or parquet for large datasets."
+                )
+
+            suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
+            target_path = get_save_path(
+                path=save_path,
+                default_dir=self.session.predictions_dir,
+                default_name="predictions",
+                suffix=suffix,
+            )
+            if return_dataframe and isinstance(output, pd.DataFrame):
+                df_to_save = output
+            else:
+                df_to_save = pd.DataFrame(y_pred_all, columns=pred_columns)
+                if id_df is not None:
+                    df_to_save = pd.concat([id_df, df_to_save], axis=1)
+
+            if save_format == "csv":
+                df_to_save.to_csv(target_path, index=False)
+            elif save_format == "parquet":
+                df_to_save.to_parquet(target_path, index=False)
+            elif save_format == "feather":
+                df_to_save.to_feather(target_path)
+            elif save_format == "excel":
+                df_to_save.to_excel(target_path, index=False)
+            elif save_format == "hdf5":
+                df_to_save.to_hdf(target_path, key="predictions", mode="w")
+            else:
+                raise ValueError(f"Unsupported save format: {save_format}")
+            logging.info(
+                colorize(f"Predictions saved to: {target_path}", color="green")
+            )
+        return output
+
+    def predict_onnx_streaming(
+        self,
+        onnx_path: str | os.PathLike | Path,
+        data: str | dict | pd.DataFrame | DataLoader,
+        batch_size: int,
+        save_path: str | os.PathLike,
+        save_format: str,
+        include_ids: bool,
+        return_dataframe: bool,
+        id_columns: list[str] | None = None,
+        onnx_session: Any | None = None,
+        num_workers: int = 0,
+    ):
+        """
+        Run ONNX inference using streaming mode for large datasets.
+        """
+        session = onnx_session or load_onnx_session(onnx_path)
+        session_inputs = session.get_inputs()
+        session_input_names = [inp.name for inp in session_inputs]
+        batch_dim = session_inputs[0].shape[0] if session_inputs else None
+        fixed_batch = (
+            batch_dim if isinstance(batch_dim, int) and batch_dim > 0 else None
+        )
+        data_loader = self.prepare_onnx_dataloader(
+            data=data,
+            batch_size=batch_size,
+            num_workers=num_workers,
+        )
+
+        if not check_streaming_support(save_format):
+            logging.warning(
+                f"[Predict ONNX Streaming Warning] Format '{save_format}' does not support streaming writes. "
+                "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
+            )
+
+        suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
+        target_path = get_save_path(
+            path=save_path,
+            default_dir=self.session.predictions_dir,
+            default_name="predictions",
+            suffix=suffix,
+            add_timestamp=False,
+        )
+        header_written = target_path.exists() and target_path.stat().st_size > 0
+        parquet_writer = None
+        pred_columns = None
+        collected_frames = []
+
+        for batch_data in progress(data_loader, description="Predicting (ONNX)"):
+            batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
+            X_input, _ = self.get_input(batch_dict, require_labels=False)
+            if X_input is None:
+                continue
+            orig_batch = None
+            if fixed_batch is not None:
+                X_input, orig_batch = pad_onnx_inputs(
+                    self.all_features, X_input, target_batch=fixed_batch
+                )
+            feed = build_onnx_input_feed(
+                self.all_features, X_input, input_names=session_input_names
+            )
+            outputs = session.run(None, feed)
+            y_pred_np = merge_onnx_outputs(outputs)
+            if y_pred_np.ndim == 1:
+                y_pred_np = y_pred_np.reshape(-1, 1)
+            if orig_batch is not None and orig_batch > 0:
+                y_pred_np = y_pred_np[:orig_batch]
+            if pred_columns is None:
+                num_outputs = y_pred_np.shape[1]
+                pred_columns = (
+                    list(self.target_columns[:num_outputs])
+                    if self.target_columns
+                    else []
+                )
+                while len(pred_columns) < num_outputs:
+                    pred_columns.append(f"pred_{len(pred_columns)}")
+
+            ids = batch_dict.get("ids") if include_ids and id_columns else None
+            if ids and fixed_batch is not None:
+                ids, orig_id_batch = pad_id_batch(ids, fixed_batch)
+            else:
+                orig_id_batch = None
+            id_arrays_batch = {
+                id_name: (
+                    ids[id_name].detach().cpu().numpy()
+                    if isinstance(ids[id_name], torch.Tensor)
+                    else np.asarray(ids[id_name])
+                ).reshape(-1)
+                for id_name in (id_columns or [])
+                if ids and id_name in ids
+            }
+            if orig_batch is not None and orig_batch > 0:
+                id_arrays_batch = {
+                    k: v[:orig_batch] for k, v in id_arrays_batch.items()
+                }
+            elif orig_id_batch is not None and orig_id_batch > 0:
+                id_arrays_batch = {
+                    k: v[:orig_id_batch] for k, v in id_arrays_batch.items()
+                }
+
+            df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
+            if id_arrays_batch:
+                id_df = pd.DataFrame(id_arrays_batch)
+                if len(id_df) and len(df_batch) and len(id_df) != len(df_batch):
+                    raise ValueError(
+                        f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_batch)})."
+                    )
+                df_batch = pd.concat([id_df, df_batch], axis=1)
+
+            should_collect = return_dataframe or save_format not in {"csv", "parquet"}
+            if should_collect:
+                collected_frames.append(df_batch)
+
+            if save_format == "csv":
+                df_batch.to_csv(
+                    target_path, mode="a", header=not header_written, index=False
+                )
+                header_written = True
+            elif save_format == "parquet":
+                try:
+                    import pyarrow as pa
+                    import pyarrow.parquet as pq
+                except ImportError as exc:  # pragma: no cover
+                    raise ImportError(
+                        "[BaseModel-predict-onnx-streaming Error] Parquet streaming save requires pyarrow."
+                    ) from exc
+                table = pa.Table.from_pandas(df_batch, preserve_index=False)
+                if parquet_writer is None:
+                    parquet_writer = pq.ParquetWriter(target_path, table.schema)
+                parquet_writer.write_table(table)
+            # Non-streaming formats are saved after collecting all batches.
+
+        if parquet_writer is not None:
+            parquet_writer.close()
+
+        if save_format in ["feather", "excel", "hdf5"] and collected_frames:
+            combined_df = pd.concat(collected_frames, ignore_index=True)
+            if save_format == "feather":
+                combined_df.to_feather(target_path)
+            elif save_format == "excel":
+                combined_df.to_excel(target_path, index=False)
+            elif save_format == "hdf5":
+                combined_df.to_hdf(target_path, key="predictions", mode="w")
+
+        logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
+        if return_dataframe:
+            return (
+                pd.concat(collected_frames, ignore_index=True)
+                if collected_frames
+                else pd.DataFrame(columns=pred_columns or [])
+            )
+        return target_path
+
     def save_model(
         self,
         save_path: str | Path | None = None,
```