p2predict 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,706 @@
1
import os
import sys
from typing import NoReturn, Optional

import click
import numpy as np
import pandas as pd
import questionary
from rich.console import Console
from rich.panel import Panel
from rich.pretty import Pretty
from rich.table import Table

from p2predict.cmdline_io import print_logo
from p2predict.explain import Explanation, explain_batch, explain_row, top_drivers
from p2predict.intervals import coverage_health, predict_interval
from p2predict.json_output import JSON_SCHEMA_VERSION, emit, emit_error
# NOTE: the JSON serializers (_interval_per_row, _explanation_to_dict,
# _whatif_to_dict) are defined LOCALLY below, on purpose. They are NOT imported
# from model_utils. The CLI ``--json`` schema is versioned (JSON_SCHEMA_VERSION)
# and locked by tests/test_json_output.py, so it stays deliberately lean. The
# model_utils copies serve the MCP surface and carry extra business-facing
# fields (price_drivers, starting_point, interval reliability/say_to_user,
# what_if summary) that the locked CLI schema intentionally omits. Keep these
# two surfaces independent. Only the pipeline helpers below are shared.
from p2predict.model_utils import (
    coerce_features as _coerce_features,
    extract_feature_info as _extract_feature_info,
    inner_pipeline as _inner_pipeline,
)
from p2predict.trained_model_io import LoadModel
from p2predict.whatif import (
    WhatIfResult,
    compute_whatif,
    interaction_is_material,
    parse_changes,
)
38
+
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Error path that respects --json. Use this instead of
42
+ # ``console.print(...); raise SystemExit(1)`` so an agent piping stdout
43
+ # still gets a parseable JSON error document on failure.
44
+ # ---------------------------------------------------------------------------
45
+
46
+
47
+ def _abort(json_mode: bool, console, code: str, message: str) -> None:
48
+ if json_mode:
49
+ emit_error("predict", code, message)
50
+ console.print(f"Aborted: {message}", style="bold red")
51
+ raise SystemExit(1)
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # Building blocks for the JSON result document. Each helper takes an
56
+ # in-memory object from the model stack and turns it into the schema
57
+ # shape defined in p2predict.json_output.
58
+ # ---------------------------------------------------------------------------
59
+
60
+
61
+ def _model_block(model_path: str, loaded: dict, target_name: str) -> dict:
62
+ return {
63
+ "path": model_path,
64
+ "algorithm": loaded.get("model_name"),
65
+ "target": target_name,
66
+ "version": loaded.get("p2predict_version"),
67
+ "log_target": bool(loaded.get("log_target", False)),
68
+ "features": list(loaded.get("features", [])),
69
+ }
70
+
71
+
72
+ def _interval_per_row(intervals) -> list[dict]:
73
+ return [
74
+ {
75
+ "low": float(ir.low),
76
+ "prediction": float(ir.prediction),
77
+ "high": float(ir.high),
78
+ # Calibration band the width came from (None = global quantile:
79
+ # old model file or calibration set too small to band).
80
+ "band": ir.band,
81
+ }
82
+ for ir in intervals
83
+ ]
84
+
85
+
86
def _explanation_to_dict(explanation: Explanation) -> dict:
    """Convert one SHAP ``Explanation`` to the locked CLI JSON shape.

    Keys, in order: ``baseline``, ``prediction``, ``log_target``,
    ``contributions`` (``{"feature", "value"}`` items, largest magnitude
    first), ``residual``, ``multiplicative_factors``, ``dollar_attribution``.

    The last two are filled only for log-target explanations that carry
    multiplicative factors:
    - factors are ranked by ``|log(factor)|``, and a non-positive factor
      ranks as 0;
    - dollar attribution is ranked by magnitude, or is ``None`` when the
      explanation has none.

    Otherwise both keys are present with value ``None``.
    """

    def ranked(mapping, field, magnitude=abs):
        # Descending by magnitude; sorted() keeps ties in their original order.
        ordered = sorted(mapping.items(), key=lambda item: magnitude(item[1]), reverse=True)
        return [{"feature": name, field: float(amount)} for name, amount in ordered]

    def log_distance(factor):
        return abs(np.log(factor)) if factor > 0 else 0.0

    doc = {
        "baseline": float(explanation.baseline),
        "prediction": float(explanation.prediction),
        "log_target": bool(explanation.log_target),
        "contributions": ranked(explanation.contributions, "value"),
        "residual": float(explanation.residual),
    }

    if not (explanation.log_target and explanation.multiplicative_factors is not None):
        doc["multiplicative_factors"] = None
        doc["dollar_attribution"] = None
        return doc

    doc["multiplicative_factors"] = ranked(
        explanation.multiplicative_factors, "factor", log_distance
    )
    dollars = explanation.dollar_attribution
    doc["dollar_attribution"] = None if dollars is None else ranked(dollars, "value")
    return doc
124
+
125
+
126
def _whatif_to_dict(result: WhatIfResult) -> dict:
    """Convert a ``WhatIfResult`` to the locked CLI JSON shape.

    - ``changes`` maps each changed feature to ``{"from": base, "to": new}``,
      with values passed through unchanged.
    - Predictions, deltas and the interaction term are cast to ``float``.
    - ``changed_contributions`` is listed largest magnitude first.
    - ``multiplicative_factor`` and the two intervals (``{"low", "high"}``)
      are ``None`` when absent.
    - ``interaction_is_material`` comes from ``interaction_is_material``.
    """

    def float_or_none(value):
        return None if value is None else float(value)

    def bounds(interval):
        if interval is None:
            return None
        return {"low": float(interval.low), "high": float(interval.high)}

    changes = {}
    for feature, (before, after) in result.changes.items():
        changes[feature] = {"from": before, "to": after}

    return {
        "changes": changes,
        "base_prediction": float(result.base_prediction),
        "counterfactual_prediction": float(result.counterfactual_prediction),
        "delta": float(result.delta),
        "delta_pct": float(result.delta_pct),
        "log_target": bool(result.log_target),
        "multiplicative_factor": float_or_none(result.multiplicative_factor),
        "changed_contributions": [
            {"feature": name, "value": float(amount)}
            for name, amount in sorted(
                result.changed_contributions.items(),
                key=lambda item: abs(item[1]),
                reverse=True,
            )
        ],
        "interaction_contribution": float(result.interaction_contribution),
        "interaction_is_material": bool(interaction_is_material(result)),
        "base_interval": bounds(result.base_interval),
        "cf_interval": bounds(result.cf_interval),
    }
162
+
163
+
164
+ # ---------------------------------------------------------------------------
165
+ # Rich rendering helpers. Unchanged from prior versions — they only run
166
+ # when --json is absent.
167
+ # ---------------------------------------------------------------------------
168
+
169
+
170
def _print_explanation(console, explanation: Explanation, target_name: str) -> None:
    """Print a per-feature SHAP breakdown of one prediction to *console*.

    Non-log-target models get a single table of signed additive
    contributions. It opens with the model's expected value (baseline) and
    closes with the prediction.

    Log-target models get three pieces:
    - a table of multiplicative factors, framed by ``baseline_price`` and
      ``predicted_price``;
    - a one-line note on how to read the factors;
    - when ``dollar_attribution`` is present, a second table of
      approximate, rescaled contributions.

    In both cases a warning follows when ``|residual|`` exceeds
    ``1e-3 * max(1, |prediction|)``.
    """

    def tinted(sign_source, text):
        # Rich markup: green when sign_source >= 0, red otherwise.
        tone = "green" if sign_source >= 0 else "red"
        return f"[{tone}]{text}[/{tone}]"

    def largest_first(mapping, size=abs):
        return sorted(mapping.items(), key=lambda item: size(item[1]), reverse=True)

    def log_size(factor):
        # Distance from 1 on a log scale; non-positive factors rank last.
        return abs(np.log(factor)) if factor > 0 else 0.0

    table = Table(
        title="Prediction Explanation (SHAP)",
        show_header=True,
        header_style="bold magenta",
        expand=False,
    )
    total_label = "[bold]Predicted " + target_name + "[/bold]"

    if explanation.log_target:
        table.add_column("Feature")
        table.add_column("× Factor", justify="right")
        table.add_column("Effect", justify="right")
        base_price = explanation.baseline_price
        final_price = explanation.predicted_price
        table.add_row("[dim]Baseline (geometric mean)[/dim]", "—", f"{base_price:,.2f}")
        for feature, factor in largest_first(explanation.multiplicative_factors, log_size):
            pct = (factor - 1.0) * 100.0
            table.add_row(feature, f"×{factor:.3f}", tinted(pct, f"{pct:+.1f}%"))
        table.add_row(total_label, "—", f"[bold yellow]{final_price:,.2f}[/bold yellow]")
        console.print(table)
        console.print(
            "Multiplicative factors are strict SHAP in price space "
            "(their product equals predicted / baseline).",
            style="italic dim",
        )
        dollars = explanation.dollar_attribution
        if dollars is not None:
            dollar_table = Table(
                title="Approximate Dollar Attribution (rescaled, not strict SHAP)",
                show_header=True,
                header_style="bold magenta",
                expand=False,
            )
            dollar_table.add_column("Feature")
            dollar_table.add_column("Approx. contribution", justify="right")
            for feature, amount in largest_first(dollars):
                dollar_table.add_row(feature, tinted(amount, f"{amount:+,.2f}"))
            console.print(dollar_table)
    else:
        table.add_column("Feature")
        table.add_column("Contribution", justify="right")
        table.add_row(
            "[dim]Baseline (model expected value)[/dim]",
            f"{explanation.baseline:+.2f}",
        )
        for feature, amount in largest_first(explanation.contributions):
            table.add_row(feature, tinted(amount, f"{amount:+.2f}"))
        table.add_row(total_label, f"[bold yellow]{explanation.prediction:+.2f}[/bold yellow]")
        console.print(table)

    # SHAP values should sum to the model output; flag visible gaps.
    residual = explanation.residual
    tolerance = 1e-3 * max(1.0, abs(explanation.prediction))
    if abs(residual) > tolerance:
        console.print(
            f"Note: local-accuracy residual is {residual:+.3g} "
            "(non-trivial; the SHAP/model wiring may need a look).",
            style="italic yellow",
        )
256
+
257
+
258
def _print_interval(console, interval_result, target_name: str, coverage_pct: int) -> None:
    """Print the likely range for one prediction as a Rich table plus notes.

    Args:
        console: Rich console to print to.
        interval_result: Object exposing ``low``, ``prediction``, ``high``
            and ``band``. ``band`` is a calibration-band label, or a falsy
            value when the global quantile was used.
        target_name: Name of the predicted quantity, used in headers and prose.
        coverage_pct: Nominal coverage in percent (1-99, from ``--interval``).
    """
    # Say "N in 10" only when that is exact. Rounding to tenths would
    # misstate the coverage because round() rounds halves to even:
    # 85% would read "8 in 10", 95% "10 in 10", and 1-5% "0 in 10".
    if coverage_pct % 10 == 0:
        frequency = f"{coverage_pct // 10} in 10"
    else:
        frequency = f"{coverage_pct} in 100"

    table = Table(
        title=f"Likely range ({coverage_pct}%)",
        show_header=True,
        header_style="bold magenta",
        expand=False,
    )
    table.add_column(f"Low {target_name}", justify="right")
    table.add_column(f"Predicted {target_name}", justify="right")
    table.add_column(f"High {target_name}", justify="right")
    table.add_row(
        f"[cyan]{interval_result.low:,.2f}[/cyan]",
        f"[bold yellow]{interval_result.prediction:,.2f}[/bold yellow]",
        f"[cyan]{interval_result.high:,.2f}[/cyan]",
    )
    console.print(table)
    console.print(
        f"The {target_name.lower()} for about {frequency} similar "
        f"parts falls in this range. Quotes outside it are unusual "
        "and worth questioning.",
        style="italic dim",
    )
    # Only mention the band when the width came from a specific calibration band.
    if interval_result.band:
        console.print(
            f"Range width calibrated on similar parts ({interval_result.band}).",
            style="italic dim",
        )
286
+
287
+
288
def _print_whatif(console, result: WhatIfResult, target_name: str) -> None:
    """Print a what-if comparison as three Rich tables and a reading note.

    1. Headline: base vs counterfactual prediction, with likely ranges when
       ``base_interval`` is set. The change is shown in absolute terms and
       as a percentage; log-target models with a multiplicative factor also
       get a ``×factor``.
    2. The feature values that were changed.
    3. Drivers of the change, largest first: per-feature SHAP deltas, or
       multiplicative factors for log-target models. An "Other interaction
       effects" row is added when ``interaction_is_material`` says so.
    """

    def tinted(sign_source, text):
        # Rich markup: green when sign_source >= 0, red otherwise.
        tone = "green" if sign_source >= 0 else "red"
        return f"[{tone}]{text}[/{tone}]"

    def range_cell(interval):
        # Zero or one extra cell so the row matches the optional range column.
        if interval is None:
            return []
        return [f"{interval.low:,.2f} – {interval.high:,.2f}"]

    show_range = result.base_interval is not None

    # --- 1. headline --------------------------------------------------------
    headline = Table(
        title="What-if Analysis",
        show_header=True,
        header_style="bold magenta",
        expand=False,
    )
    headline.add_column("Scenario")
    headline.add_column(f"Predicted {target_name}", justify="right")
    if show_range:
        headline.add_column("Likely range", justify="right")
    headline.add_row("Base", f"{result.base_prediction:,.2f}", *range_cell(result.base_interval))
    headline.add_row(
        "Counterfactual",
        f"{result.counterfactual_prediction:,.2f}",
        *range_cell(result.cf_interval),
    )

    # Every part of the change cell is coloured by the sign of the delta.
    overall = result.multiplicative_factor
    change = tinted(result.delta, f"{result.delta:+,.2f}")
    if result.log_target and overall is not None:
        pct = (overall - 1.0) * 100.0
        pct_text = tinted(result.delta, f"{pct:+.1f}%")
        change += f" ({pct_text}, ×{overall:.3f})"
    else:
        pct_text = tinted(result.delta, f"{result.delta_pct:+.1f}%")
        change += f" ({pct_text})"
    headline.add_row("Change", change, *([""] if show_range else []))
    console.print(headline)

    # --- 2. changed inputs --------------------------------------------------
    changed = Table(
        title="Changes applied",
        show_header=True,
        header_style="bold magenta",
        expand=False,
    )
    changed.add_column("Feature")
    changed.add_column("Base", justify="right")
    changed.add_column("Counterfactual", justify="right")
    for feature, (before, after) in result.changes.items():
        changed.add_row(feature, str(before), str(after))
    console.print(changed)

    # --- 3. drivers of the change -------------------------------------------
    if result.log_target:
        drivers = Table(
            title="Drivers of the change (SHAP × factor)",
            show_header=True,
            header_style="bold magenta",
            expand=False,
        )
        drivers.add_column("Feature")
        drivers.add_column("× Factor", justify="right")
        drivers.add_column("Effect", justify="right")
    else:
        drivers = Table(
            title="Drivers of the change (SHAP)",
            show_header=True,
            header_style="bold magenta",
            expand=False,
        )
        drivers.add_column("Feature")
        drivers.add_column("Contribution", justify="right")
        drivers.add_column("Share", justify="right")

    interaction = result.interaction_contribution
    total = sum(abs(v) for v in result.changed_contributions.values()) + abs(interaction)
    # Guard the share denominator against an all-zero change.
    denominator = total if total > 1e-12 else 1.0

    def factor_cells(factor):
        pct = (factor - 1.0) * 100.0
        return [f"×{factor:.3f}", tinted(pct, f"{pct:+.1f}%")]

    def share_cells(amount):
        share = abs(amount) / denominator * 100.0
        return [tinted(amount, f"{amount:+,.2f}"), f"{share:.0f}%"]

    use_factors = result.log_target and result.changed_multiplicative_factors is not None
    ranked = sorted(
        result.changed_contributions.items(),
        key=lambda item: abs(item[1]),
        reverse=True,
    )
    for feature, amount in ranked:
        if use_factors:
            drivers.add_row(feature, *factor_cells(result.changed_multiplicative_factors[feature]))
        else:
            drivers.add_row(feature, *share_cells(amount))

    if interaction_is_material(result):
        label = "[dim]Other interaction effects[/dim]"
        if result.log_target and result.interaction_multiplicative_factor is not None:
            drivers.add_row(label, *factor_cells(result.interaction_multiplicative_factor))
        else:
            drivers.add_row(label, *share_cells(interaction))
    console.print(drivers)

    if result.log_target:
        note = "Factors multiply: × Region × Supplier × ... = total change factor."
    else:
        note = (
            "Contributions add up to the total change. "
            "Features you didn't change can still show up here when there are interactions in the model."
        )
    console.print(note, style="italic dim")
412
+
413
+
414
def _parse_feature_spec(spec: str) -> dict:
    """Parse a ``"name:value,name:value"`` spec (the -p syntax) into a dict.

    Each comma-separated item is split on its FIRST colon, so values may
    themselves contain colons (e.g. ``"time:10:30"``). Whitespace around
    names and values is stripped. Empty items (e.g. from a trailing comma)
    are skipped, and a later duplicate name overrides an earlier one.
    Values stay strings; type coercion is left to the caller.

    Raises:
        ValueError: if an item has no colon or an empty name, or if the spec
            contains no items at all.
    """
    parsed: dict = {}
    for raw in spec.split(","):
        item = raw.strip()
        if not item:
            continue
        name, sep, value = item.partition(":")
        name = name.strip()
        if not sep or not name:
            raise ValueError(
                f"malformed feature value {item!r}; expected name:value "
                '(e.g. "weight:100,color:red").'
            )
        parsed[name] = value.strip()
    if not parsed:
        raise ValueError('no feature values given; expected e.g. "weight:100,color:red".')
    return parsed


# ``main`` is the click command that runs predictions. It loads a model, then
# runs exactly one mode:
#   - inline (-p);
#   - batch (-i), which writes predictions back into the same CSV;
#   - interactive prompts.
# Each mode can add a likely-range interval (--interval) and a SHAP
# explanation (--explain); inline mode can also run a what-if comparison
# (--whatif). Under --json a single JSON document is emitted to stdout
# instead of Rich tables. It has no docstring on purpose: click would turn
# one into --help text.
@click.command()
@click.option("-m", "--model", type=click.Path(exists=True),
              help="Path to the trained model file (.model)")
@click.option("-p", "--predict_using",
              help='Feature values, e.g. "weight:100,color:red"')
@click.option("-i", "--predict_file", type=click.Path(exists=True),
              help="CSV file containing feature values for batch prediction")
@click.option("--explain", "explain_flag", is_flag=True, default=False,
              help="Show per-feature SHAP attribution alongside the prediction.")
@click.option("--interval", "interval_coverage", type=click.IntRange(1, 99),
              default=None,
              help="Show the model's likely range for the prediction "
                   "(e.g. --interval 90 for a range that contains the value "
                   "for about 9 in 10 similar parts). Range is calibrated on "
                   "the training holdout; see the README for the math.")
@click.option("--whatif", "whatif_spec", default=None,
              help='What-if comparison (requires -p). Takes the base scenario '
                   'from -p and compares it to a counterfactual where one or '
                   'more features change. Same format as -p, e.g. '
                   '--whatif "Region:EU,Supplier:B". Not supported with -i.')
@click.option("--json", "json_mode", is_flag=True, default=False,
              help="Emit machine-readable JSON to stdout instead of "
                   "Rich-formatted tables. Useful for agents and scripts. "
                   "See p2predict.json_output for the schema.")
def main(model, predict_using, predict_file, explain_flag, interval_coverage,
         whatif_spec, json_mode):
    # Under --json, Rich output is fully suppressed (quiet console) so any
    # console.print() that escapes a guard cannot corrupt the JSON document
    # on stdout. The schema is the contract; this is the belt. A quiet console
    # also opens no file handle that would be left unclosed.
    console = Console(quiet=True) if json_mode else Console()

    if not json_mode:
        print("")
        print_logo()
        print("")

    if not model:
        if json_mode:
            _abort(json_mode, console, "missing_model",
                   "--model (or -m) is required when using --json.")
        # Interactive fallback: prompt for the model path.
        model = questionary.path("Enter model file path (.model)").ask()
        if not model:
            _abort(json_mode, console, "missing_model",
                   "please enter the path to the trained model.")

    loaded = LoadModel(model)
    trained = loaded["model"]
    if not trained:
        _abort(json_mode, console, "corrupt_model",
               "the selected model is corrupt.")

    if not json_mode:
        console.print(f"'{model}' successfully loaded.", style="bold white")
        if loaded.get("log_target"):
            console.print("(log-target transform active)", style="italic")
        print("")

    inner = _inner_pipeline(trained)
    feature_types, all_categories = _extract_feature_info(inner)

    if not json_mode:
        table = Table(title="Model Features", show_header=True, header_style="bold magenta")
        table.add_column("Feature", style="dim", width=20)
        table.add_column("Type", justify="right")
        for feature, feature_type in feature_types.items():
            table.add_row(feature, feature_type)
        console.print(table)

        console.print(f"\nTarget feature: [bold blue]'{loaded['target_feature']}'[/bold blue]")

        if all_categories:
            console.print("\nCategorical Features:")
            for feature, categories in all_categories.items():
                console.print(f"[bold]{feature}[/bold]")
                for category in categories:
                    console.print(f" • {category}")
                console.print("")
        else:
            console.print("No categorical features to display.", style="italic")

        console.print("\n" + "=" * 50 + "\n")

    background = loaded.get("background_sample")
    target_name = loaded["target_feature"]
    calibration = loaded.get("calibration")

    # Decide whether the model can support a likely-range interval at all,
    # and whether to soft-warn about a small calibration set.
    interval_soft_warning: Optional[str] = None
    if interval_coverage is not None:
        warning = coverage_health(calibration)
        if warning and (calibration is None or calibration.get("n_calibration", 0) == 0):
            if not json_mode:
                console.print(
                    f"Likely range disabled: {warning}.", style="italic yellow"
                )
            interval_coverage = None
        elif warning:
            interval_soft_warning = warning

    # --whatif is inline-only (needs a single base scenario).
    if whatif_spec is not None and predict_file:
        _abort(json_mode, console, "whatif_in_batch",
               "--whatif is not supported in batch mode (-i). "
               "Use -p to specify a base scenario.")

    # Build the JSON response throughout. Mode-dependent blocks get
    # added as we go; we emit the whole thing at the end under --json.
    response: dict = {
        "schema_version": JSON_SCHEMA_VERSION,
        "command": "predict",
        "model": _model_block(model, loaded, target_name),
    }

    features_dict = {}
    if predict_using:
        response["mode"] = "inline"
        # Malformed -p input becomes a clean (JSON-aware) error, not a traceback.
        try:
            features_dict = _parse_feature_spec(predict_using)
        except ValueError as exc:
            _abort(json_mode, console, "bad_input", str(exc))
        features_df = _coerce_features(pd.DataFrame([features_dict]), feature_types)
        y = trained.predict(features_df)
        features_df["prediction"] = y
        if not json_mode:
            console.print(Panel(Pretty(features_df), title="Prediction"))

        response["predictions"] = [
            {"input": features_dict, "prediction": float(y[0])}
        ]

        if interval_coverage is not None:
            [interval_result] = predict_interval(
                trained, features_df[loaded["features"]],
                calibration, coverage=interval_coverage / 100.0,
            )
            if not json_mode:
                print("")
                _print_interval(console, interval_result, target_name, interval_coverage)
                if interval_soft_warning:
                    console.print(f"Note: {interval_soft_warning}.", style="italic yellow")
            response["interval"] = {
                "coverage": interval_coverage / 100.0,
                "per_row": _interval_per_row([interval_result]),
                "soft_warning": interval_soft_warning,
            }

        if explain_flag:
            explanation = explain_row(trained, features_df[loaded["features"]], background)
            if not json_mode:
                print("")
                _print_explanation(console, explanation, target_name)
            response["explanation"] = [_explanation_to_dict(explanation)]

        if whatif_spec is not None:
            try:
                changes = parse_changes(whatif_spec)
            except ValueError as exc:
                _abort(json_mode, console, "bad_whatif", str(exc))
            try:
                whatif_result = compute_whatif(
                    trained,
                    features_df[loaded["features"]],
                    changes,
                    feature_types,
                    background_X=background,
                    # Ranges are attached only when --interval survived the
                    # health check above; 90 is a placeholder otherwise.
                    calibration=calibration if interval_coverage is not None else None,
                    coverage=(interval_coverage or 90) / 100.0,
                )
            except ValueError as exc:
                _abort(json_mode, console, "bad_whatif", str(exc))
            if not json_mode:
                print("")
                _print_whatif(console, whatif_result, target_name)
            response["whatif"] = _whatif_to_dict(whatif_result)

    elif predict_file:
        response["mode"] = "batch"
        features_df = pd.read_csv(predict_file)
        features_df = _coerce_features(features_df, feature_types)
        y = trained.predict(features_df)
        features_df[target_name] = y

        # Select the model's feature columns once. Re-selecting them per row
        # copied the whole frame n times (quadratic in the CSV size).
        feature_frame = features_df[loaded["features"]]
        response["predictions"] = [
            {"input": feature_frame.iloc[i].to_dict(), "prediction": float(y[i])}
            for i in range(len(features_df))
        ]

        if interval_coverage is not None:
            intervals = predict_interval(
                trained, feature_frame,
                calibration, coverage=interval_coverage / 100.0,
            )
            features_df[f"{target_name}_low"] = [ir.low for ir in intervals]
            features_df[f"{target_name}_high"] = [ir.high for ir in intervals]
            response["interval"] = {
                "coverage": interval_coverage / 100.0,
                "per_row": _interval_per_row(intervals),
                "soft_warning": interval_soft_warning,
            }
        if explain_flag:
            top1, top2, top3 = [], [], []
            per_row_explanations = []
            explanations = explain_batch(trained, feature_frame, background)
            for ex in explanations:
                per_row_explanations.append(_explanation_to_dict(ex))
                drivers = top_drivers(ex, n=3)
                formatted = []
                for col, value in drivers:
                    if ex.log_target:
                        # Log-target drivers are multiplicative factors.
                        pct = (value - 1.0) * 100.0
                        formatted.append(f"{col} ({pct:+.1f}%)")
                    else:
                        formatted.append(f"{col} ({value:+.2f})")
                # Pad so every row fills all three driver columns.
                while len(formatted) < 3:
                    formatted.append("")
                top1.append(formatted[0])
                top2.append(formatted[1])
                top3.append(formatted[2])
            features_df["top1_driver"] = top1
            features_df["top2_driver"] = top2
            features_df["top3_driver"] = top3
            response["explanation"] = per_row_explanations

        # Results are written back into the input CSV.
        features_df.to_csv(predict_file, index=False)
        response["batch"] = {
            "csv_path": str(predict_file), "n_rows": int(len(features_df)),
        }
        if not json_mode:
            console.print(Panel(Pretty(features_df), title="Prediction"))

    else:
        # Interactive mode is incompatible with --json.
        if json_mode:
            _abort(json_mode, console, "missing_input",
                   "interactive mode is not supported with --json. "
                   "Use -p (inline) or -i (batch).")
        if whatif_spec is not None:
            # --whatif only runs in inline mode; say so instead of dropping it silently.
            console.print(
                "Note: --whatif requires -p; it is ignored in interactive mode.",
                style="italic yellow",
            )
        response["mode"] = "interactive"
        for feature in loaded["features"]:
            if feature in all_categories:
                value = questionary.select(
                    f"Select a value for {feature}:",
                    choices=[str(c) for c in all_categories[feature]],
                ).ask()
            else:
                value = questionary.text(f"Enter a numeric value for {feature}:").ask()
            # ask() returns None on Ctrl-C; an empty answer is also rejected.
            if not value:
                _abort(json_mode, console, "missing_input",
                       f"please enter a value for {feature}.")
            features_dict[feature] = value

        features_df = _coerce_features(pd.DataFrame([features_dict]), feature_types)
        y = trained.predict(features_df)
        features_df["prediction"] = y

        table = Table(title="Prediction Results", show_header=True, header_style="bold magenta")
        for column in features_df.columns:
            table.add_column(column, style="cyan", justify="right")
        table.add_row(*[str(val) for val in features_df.iloc[0]])
        console.print(Panel(table, expand=False, border_style="green", padding=(1, 1)))

        prediction_value = features_df["prediction"].iloc[0]
        console.print(
            f"\n[bold]Predicted {loaded['target_feature']}:[/bold] "
            f"[yellow]{prediction_value:.2f}[/yellow]"
        )
        if interval_coverage is not None:
            print("")
            [interval_result] = predict_interval(
                trained, features_df[loaded["features"]],
                calibration, coverage=interval_coverage / 100.0,
            )
            _print_interval(console, interval_result, target_name, interval_coverage)
            if interval_soft_warning:
                console.print(f"Note: {interval_soft_warning}.", style="italic yellow")
        if explain_flag:
            print("")
            explanation = explain_row(
                trained, features_df[loaded["features"]], background
            )
            _print_explanation(console, explanation, target_name)

    if json_mode:
        emit(response)
    return y
703
+
704
+
705
if __name__ == "__main__":
    # Runs only when the module is executed as a script; hands control to
    # the click command ``main``, which parses sys.argv itself.
    main()