nextrec 0.4.34-py3-none-any.whl → 0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
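
The headline change in this release is a new ONNX workflow on BaseModel in nextrec/basic/model.py: export_onnx, predict_onnx, and predict_onnx_streaming, backed by helpers imported from nextrec.utils.onnx_utils. A minimal usage sketch assembled from the signatures in the diff below (the trained model instance and the file paths are hypothetical, not part of the release):

    # Hypothetical usage sketch; `model` is a trained nextrec model instance.
    onnx_path = model.export_onnx(save_path="model.onnx", opset_version=18)

    # In-memory inference: returns a pandas DataFrame by default.
    preds = model.predict_onnx(onnx_path, data="test.csv", batch_size=256)

    # Large datasets: with save_path set and return_dataframe=False,
    # predict_onnx delegates to predict_onnx_streaming, writing batch by
    # batch to csv/parquet and returning the output path.
    out_path = model.predict_onnx(
        onnx_path,
        data="test.csv",
        batch_size=4096,
        save_path="predictions.csv",
        save_format="csv",
        return_dataframe=False,
    )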
nextrec/basic/model.py CHANGED
@@ -2,11 +2,12 @@
  Base Model & Base Match Model Class
 
  Date: create on 27/10/2025
- Checkpoint: edit on 22/01/2026
+ Checkpoint: edit on 25/01/2026
  Author: Yang Zhou,zyaztec@gmail.com
  """
 
  import getpass
+ import inspect
  import logging
  import os
  import sys
@@ -90,6 +91,15 @@ from nextrec.utils.model import (
      compute_ranking_loss,
      get_loss_list,
  )
+ from nextrec.utils.onnx_utils import (
+     OnnxModelWrapper,
+     build_onnx_input_feed,
+     create_dummy_inputs,
+     load_onnx_session,
+     merge_onnx_outputs,
+     pad_onnx_inputs,
+     pad_id_batch,
+ )
 
  from nextrec.utils.types import (
      LossName,
@@ -312,7 +322,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          """
          Prepare unified input features and labels from the given input data.
 
-
          Args:
              input_data: Input data dictionary containing 'features' and optionally 'labels', e.g., {'features': {'feat1': [...], 'feat2': [...]}, 'labels': {'label': [...]}}.
              require_labels: Whether labels are required in the input data. Default is True: for training and evaluation with labels.
@@ -367,14 +376,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              target_tensor = to_tensor(
                  target_data, dtype=torch.float32, device=self.device
              )
-             target_tensor = target_tensor.view(
+             target_tensor = target_tensor.reshape(
                  target_tensor.size(0), -1
              ) # always reshape to (batch_size, num_targets)
              target_tensors.append(target_tensor)
          if target_tensors:
              y = torch.cat(target_tensors, dim=1)
              if y.shape[1] == 1: # no need to do that again
-                 y = y.view(-1)
+                 y = y.reshape(-1)
          elif require_labels:
              raise ValueError(
                  "[BaseModel-input Error] Labels are required but none were found in the input batch."
@@ -2048,6 +2057,505 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          # Return the actual save path when not returning dataframe
          return target_path
 
+     def prepare_onnx_dataloader(
+         self,
+         data: str | dict | pd.DataFrame | DataLoader,
+         batch_size: int,
+         num_workers: int,
+     ) -> DataLoader:
+         """
+         Prepare a DataLoader for ONNX prediction.
+
+         Args:
+             data: Input data (file path, dict, DataFrame, or DataLoader).
+             batch_size: Effective batch size. When `data` is a file path, this is
+                 treated as the streaming chunk size.
+             num_workers: Number of DataLoader workers.
+
+         """
+         if isinstance(data, DataLoader):
+             if num_workers != 0:
+                 logging.warning(
+                     "[Predict ONNX Warning] num_workers parameter is ignored when data is already a DataLoader. "
+                     "The DataLoader's existing num_workers configuration will be used."
+                 )
+             return data
+         # if data is a file path, use streaming DataLoader
+         # will set batch_size=1 cause each batch is a file chunk
+         if isinstance(data, (str, os.PathLike)):
+             rec_loader = RecDataLoader(
+                 dense_features=self.dense_features,
+                 sparse_features=self.sparse_features,
+                 sequence_features=self.sequence_features,
+                 target=self.target_columns,
+                 id_columns=self.id_columns,
+             )
+             return rec_loader.create_dataloader(
+                 data=data,
+                 batch_size=1,
+                 shuffle=False,
+                 streaming=True,
+                 chunk_size=batch_size,
+             )
+         return self.prepare_data_loader(
+             data, batch_size=batch_size, shuffle=False, num_workers=num_workers
+         )
+
+     def export_onnx(
+         self,
+         save_path: str | Path | None = None,
+         batch_size: int = 1,
+         opset_version: int = 18,
+     ) -> Path:
+         """
+         Export the model to ONNX.
+
+         Usage:
+             onnx_path = model.export_onnx(
+                 save_path="model.onnx",
+                 batch_size=1,
+                 opset_version=18,
+             )
+
+         Args:
+             save_path: Path to save the ONNX model; if None, uses session root.
+             batch_size: Dummy batch size for tracing.
+             opset_version: ONNX opset version for export.
+         """
+         model_to_export = (
+             self.ddp_model.module
+             if hasattr(self, "ddp_model") and self.ddp_model is not None
+             else self
+         )
+         model_to_export = model_to_export.to(self.device)
+         model_to_export.eval()
+
+         input_names = [feat.name for feat in self.all_features]
+         dummy_inputs = create_dummy_inputs(
+             self.all_features,
+             batch_size=batch_size,
+             device=self.device,
+         )
+         wrapper = OnnxModelWrapper(model_to_export, input_names)
+         with torch.no_grad():
+             output_sample = wrapper(*dummy_inputs)
+         if isinstance(output_sample, (tuple, list)): # multiple outputs
+             output_names = [f"output_{idx}" for idx in range(len(output_sample))]
+         else:
+             output_names = ["output"]
+         target_path = get_save_path(
+             path=save_path,
+             default_dir=self.session.root,
+             default_name=f"{self.model_name}_onnx",
+             suffix="onnx",
+         )
+         export_kwargs: dict[str, Any] = {}
+         export_sig = inspect.signature(torch.onnx.export)
+         if "dynamo" in export_sig.parameters:
+             export_kwargs["dynamo"] = True
+             if opset_version < 18:
+                 logging.warning(
+                     "[BaseModel-export-onnx Warning] TorchDynamo exporter requires opset >= 18. "
+                     "Overriding opset_version to 18."
+                 )
+                 opset_version = 18
+
+         torch.onnx.export(
+             wrapper,
+             tuple(dummy_inputs),
+             target_path,
+             input_names=list(input_names),
+             output_names=list(output_names),
+             opset_version=opset_version,
+             do_constant_folding=True,
+             **export_kwargs,
+         )
+
+         logging.info(colorize(f"ONNX model exported to: {target_path}", color="green"))
+         return target_path
+
+     @overload
+     def predict_onnx(
+         self,
+         onnx_path: str | os.PathLike | Path,
+         data: str | dict | pd.DataFrame | DataLoader,
+         batch_size: int = 32,
+         save_path: str | os.PathLike | None = None,
+         save_format: str = "csv",
+         include_ids: bool | None = None,
+         id_columns: str | list[str] | None = None,
+         return_dataframe: Literal[True] = True,
+         num_workers: int = 0,
+         onnx_session: Any | None = None,
+     ) -> pd.DataFrame: ...
+
+     @overload
+     def predict_onnx(
+         self,
+         onnx_path: str | os.PathLike | Path,
+         data: str | dict | pd.DataFrame | DataLoader,
+         batch_size: int = 32,
+         save_path: None = None,
+         save_format: str = "csv",
+         include_ids: bool | None = None,
+         id_columns: str | list[str] | None = None,
+         return_dataframe: Literal[False] = False,
+         num_workers: int = 0,
+         onnx_session: Any | None = None,
+     ) -> np.ndarray: ...
+
+     @overload
+     def predict_onnx(
+         self,
+         onnx_path: str | os.PathLike | Path,
+         data: str | dict | pd.DataFrame | DataLoader,
+         batch_size: int = 32,
+         *,
+         save_path: str | os.PathLike,
+         save_format: str = "csv",
+         include_ids: bool | None = None,
+         id_columns: str | list[str] | None = None,
+         return_dataframe: Literal[False] = False,
+         num_workers: int = 0,
+         onnx_session: Any | None = None,
+     ) -> Path: ...
+
+     def predict_onnx(
+         self,
+         onnx_path: str | os.PathLike | Path,
+         data: str | dict | pd.DataFrame | DataLoader,
+         batch_size: int = 32,
+         save_path: str | os.PathLike | None = None,
+         save_format: str = "csv",
+         include_ids: bool | None = None,
+         id_columns: str | list[str] | None = None,
+         return_dataframe: bool = True,
+         num_workers: int = 0,
+         onnx_session: Any | None = None,
+     ) -> pd.DataFrame | np.ndarray | Path | None:
+         """
+         Run ONNX inference on the given data.
+
+         Args:
+             onnx_path: Path to the ONNX model file.
+             data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
+             batch_size: Batch size for prediction.
+             save_path: Optional path to save predictions; if None, predictions are not saved to disk.
+             save_format: Format to save predictions ('csv' or 'parquet').
+             include_ids: Whether to include ID columns in the output; if None, includes if id_columns are set.
+             id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
+             return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
+             num_workers: DataLoader worker count.
+             onnx_session: Optional pre-created ONNX Runtime session.
+         """
+         predict_id_columns = id_columns if id_columns is not None else self.id_columns
+         if isinstance(predict_id_columns, str):
+             predict_id_columns = [predict_id_columns]
+
+         if include_ids is None:
+             include_ids = bool(predict_id_columns)
+         include_ids = include_ids and bool(predict_id_columns)
+
+         if save_path is not None and not return_dataframe:
+             return self.predict_onnx_streaming(
+                 onnx_path=onnx_path,
+                 data=data,
+                 batch_size=batch_size,
+                 save_path=save_path,
+                 save_format=save_format,
+                 include_ids=include_ids,
+                 return_dataframe=return_dataframe,
+                 id_columns=predict_id_columns,
+                 onnx_session=onnx_session,
+                 num_workers=num_workers,
+             )
+
+         session = onnx_session or load_onnx_session(onnx_path)
+         session_inputs = session.get_inputs()
+         session_input_names = [inp.name for inp in session_inputs]
+         batch_dim = session_inputs[0].shape[0] if session_inputs else None
+         fixed_batch = (
+             batch_dim if isinstance(batch_dim, int) and batch_dim > 0 else None
+         )
+         data_loader = self.prepare_onnx_dataloader(
+             data=data,
+             batch_size=batch_size,
+             num_workers=num_workers,
+         )
+
+         y_pred_list = []
+         id_buffers = (
+             {name: [] for name in (predict_id_columns or [])} if include_ids else {}
+         )
+
+         for batch_data in progress(data_loader, description="Predicting (ONNX)"):
+             batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
+             X_input, _ = self.get_input(batch_dict, require_labels=False)
+             if X_input is None:
+                 raise ValueError(
+                     "[BaseModel-predict-onnx Error] No input features found in the prediction data."
+                 )
+             orig_batch = None
+             if fixed_batch is not None:
+                 X_input, orig_batch = pad_onnx_inputs(
+                     self.all_features, X_input, target_batch=fixed_batch
+                 )
+             feed = build_onnx_input_feed(
+                 self.all_features, X_input, input_names=session_input_names
+             )
+             outputs = session.run(None, feed)
+             y_pred_np = merge_onnx_outputs(outputs)
+             if y_pred_np.ndim == 1:
+                 y_pred_np = y_pred_np.reshape(-1, 1)
+             if orig_batch is not None and orig_batch > 0:
+                 y_pred_np = y_pred_np[:orig_batch]
+             y_pred_list.append(y_pred_np)
+
+             if include_ids and predict_id_columns and batch_dict.get("ids"):
+                 ids_dict = batch_dict["ids"]
+                 orig_id_batch = None
+                 if fixed_batch is not None:
+                     ids_dict, orig_id_batch = pad_id_batch(ids_dict, fixed_batch)
+                 for id_name in predict_id_columns:
+                     if id_name not in ids_dict:
+                         continue
+                     id_tensor = ids_dict[id_name]
+                     id_np = (
+                         id_tensor.detach().cpu().numpy()
+                         if isinstance(id_tensor, torch.Tensor)
+                         else np.asarray(id_tensor)
+                     )
+                     if orig_batch is not None and orig_batch > 0:
+                         id_np = id_np[:orig_batch]
+                     elif orig_id_batch is not None and orig_id_batch > 0:
+                         id_np = id_np[:orig_id_batch]
+                     id_buffers[id_name].append(
+                         id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np
+                     )
+
+         y_pred_all = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
+         if y_pred_all is None:
+             return pd.DataFrame() if return_dataframe else np.array([])
+
+         num_outputs = y_pred_all.shape[1] if y_pred_all.ndim > 1 else 1
+         pred_columns = (
+             list(self.target_columns[:num_outputs]) if self.target_columns else []
+         )
+         while len(pred_columns) < num_outputs:
+             pred_columns.append(f"pred_{len(pred_columns)}")
+
+         id_df = None
+         if include_ids and predict_id_columns:
+             id_arrays = {
+                 id_name: np.concatenate(id_buffers[id_name], axis=0)
+                 for id_name in predict_id_columns
+                 if id_buffers.get(id_name)
+             }
+             if id_arrays:
+                 id_df = pd.DataFrame(id_arrays)
+                 if len(id_df) and len(id_df) != len(y_pred_all):
+                     raise ValueError(
+                         f"[BaseModel-predict-onnx Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(y_pred_all)})."
+                     )
+
+         output = y_pred_all
+         if return_dataframe:
+             df_to_return = pd.DataFrame(y_pred_all, columns=pred_columns)
+             if id_df is not None:
+                 df_to_return = pd.concat([id_df, df_to_return], axis=1)
+             output = df_to_return
+
+         if save_path is not None:
+             if not check_streaming_support(save_format):
+                 logging.warning(
+                     f"[BaseModel-predict-onnx Warning] Format '{save_format}' does not support streaming writes. "
+                     "The entire result will be saved at once. Use csv or parquet for large datasets."
+                 )
+
+             suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
+             target_path = get_save_path(
+                 path=save_path,
+                 default_dir=self.session.predictions_dir,
+                 default_name="predictions",
+                 suffix=suffix,
+             )
+             if return_dataframe and isinstance(output, pd.DataFrame):
+                 df_to_save = output
+             else:
+                 df_to_save = pd.DataFrame(y_pred_all, columns=pred_columns)
+                 if id_df is not None:
+                     df_to_save = pd.concat([id_df, df_to_save], axis=1)
+
+             if save_format == "csv":
+                 df_to_save.to_csv(target_path, index=False)
+             elif save_format == "parquet":
+                 df_to_save.to_parquet(target_path, index=False)
+             elif save_format == "feather":
+                 df_to_save.to_feather(target_path)
+             elif save_format == "excel":
+                 df_to_save.to_excel(target_path, index=False)
+             elif save_format == "hdf5":
+                 df_to_save.to_hdf(target_path, key="predictions", mode="w")
+             else:
+                 raise ValueError(f"Unsupported save format: {save_format}")
+             logging.info(
+                 colorize(f"Predictions saved to: {target_path}", color="green")
+             )
+         return output
+
+     def predict_onnx_streaming(
+         self,
+         onnx_path: str | os.PathLike | Path,
+         data: str | dict | pd.DataFrame | DataLoader,
+         batch_size: int,
+         save_path: str | os.PathLike,
+         save_format: str,
+         include_ids: bool,
+         return_dataframe: bool,
+         id_columns: list[str] | None = None,
+         onnx_session: Any | None = None,
+         num_workers: int = 0,
+     ):
+         """
+         Run ONNX inference using streaming mode for large datasets.
+         """
+         session = onnx_session or load_onnx_session(onnx_path)
+         session_inputs = session.get_inputs()
+         session_input_names = [inp.name for inp in session_inputs]
+         batch_dim = session_inputs[0].shape[0] if session_inputs else None
+         fixed_batch = (
+             batch_dim if isinstance(batch_dim, int) and batch_dim > 0 else None
+         )
+         data_loader = self.prepare_onnx_dataloader(
+             data=data,
+             batch_size=batch_size,
+             num_workers=num_workers,
+         )
+
+         if not check_streaming_support(save_format):
+             logging.warning(
+                 f"[Predict ONNX Streaming Warning] Format '{save_format}' does not support streaming writes. "
+                 "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
+             )
+
+         suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
+         target_path = get_save_path(
+             path=save_path,
+             default_dir=self.session.predictions_dir,
+             default_name="predictions",
+             suffix=suffix,
+             add_timestamp=False,
+         )
+         header_written = target_path.exists() and target_path.stat().st_size > 0
+         parquet_writer = None
+         pred_columns = None
+         collected_frames = []
+
+         for batch_data in progress(data_loader, description="Predicting (ONNX)"):
+             batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
+             X_input, _ = self.get_input(batch_dict, require_labels=False)
+             if X_input is None:
+                 continue
+             orig_batch = None
+             if fixed_batch is not None:
+                 X_input, orig_batch = pad_onnx_inputs(
+                     self.all_features, X_input, target_batch=fixed_batch
+                 )
+             feed = build_onnx_input_feed(
+                 self.all_features, X_input, input_names=session_input_names
+             )
+             outputs = session.run(None, feed)
+             y_pred_np = merge_onnx_outputs(outputs)
+             if y_pred_np.ndim == 1:
+                 y_pred_np = y_pred_np.reshape(-1, 1)
+             if orig_batch is not None and orig_batch > 0:
+                 y_pred_np = y_pred_np[:orig_batch]
+             if pred_columns is None:
+                 num_outputs = y_pred_np.shape[1]
+                 pred_columns = (
+                     list(self.target_columns[:num_outputs])
+                     if self.target_columns
+                     else []
+                 )
+                 while len(pred_columns) < num_outputs:
+                     pred_columns.append(f"pred_{len(pred_columns)}")
+
+             ids = batch_dict.get("ids") if include_ids and id_columns else None
+             if ids and fixed_batch is not None:
+                 ids, orig_id_batch = pad_id_batch(ids, fixed_batch)
+             else:
+                 orig_id_batch = None
+             id_arrays_batch = {
+                 id_name: (
+                     ids[id_name].detach().cpu().numpy()
+                     if isinstance(ids[id_name], torch.Tensor)
+                     else np.asarray(ids[id_name])
+                 ).reshape(-1)
+                 for id_name in (id_columns or [])
+                 if ids and id_name in ids
+             }
+             if orig_batch is not None and orig_batch > 0:
+                 id_arrays_batch = {
+                     k: v[:orig_batch] for k, v in id_arrays_batch.items()
+                 }
+             elif orig_id_batch is not None and orig_id_batch > 0:
+                 id_arrays_batch = {
+                     k: v[:orig_id_batch] for k, v in id_arrays_batch.items()
+                 }
+
+             df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
+             if id_arrays_batch:
+                 id_df = pd.DataFrame(id_arrays_batch)
+                 if len(id_df) and len(df_batch) and len(id_df) != len(df_batch):
+                     raise ValueError(
+                         f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_batch)})."
+                     )
+                 df_batch = pd.concat([id_df, df_batch], axis=1)
+
+             should_collect = return_dataframe or save_format not in {"csv", "parquet"}
+             if should_collect:
+                 collected_frames.append(df_batch)
+
+             if save_format == "csv":
+                 df_batch.to_csv(
+                     target_path, mode="a", header=not header_written, index=False
+                 )
+                 header_written = True
+             elif save_format == "parquet":
+                 try:
+                     import pyarrow as pa
+                     import pyarrow.parquet as pq
+                 except ImportError as exc: # pragma: no cover
+                     raise ImportError(
+                         "[BaseModel-predict-onnx-streaming Error] Parquet streaming save requires pyarrow."
+                     ) from exc
+                 table = pa.Table.from_pandas(df_batch, preserve_index=False)
+                 if parquet_writer is None:
+                     parquet_writer = pq.ParquetWriter(target_path, table.schema)
+                 parquet_writer.write_table(table)
+             # Non-streaming formats are saved after collecting all batches.
+
+         if parquet_writer is not None:
+             parquet_writer.close()
+
+         if save_format in ["feather", "excel", "hdf5"] and collected_frames:
+             combined_df = pd.concat(collected_frames, ignore_index=True)
+             if save_format == "feather":
+                 combined_df.to_feather(target_path)
+             elif save_format == "excel":
+                 combined_df.to_excel(target_path, index=False)
+             elif save_format == "hdf5":
+                 combined_df.to_hdf(target_path, key="predictions", mode="w")
+
+         logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
+         if return_dataframe:
+             return (
+                 pd.concat(collected_frames, ignore_index=True)
+                 if collected_frames
+                 else pd.DataFrame(columns=pred_columns or [])
+             )
+         return target_path
+
      def save_model(
          self,
          save_path: str | Path | None = None,