rustystats 0.1.5__cp313-cp313-manylinux_2_34_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2471 @@
1
+ """
2
+ Model Diagnostics for RustyStats GLM
3
+ =====================================
4
+
5
+ This module provides comprehensive model diagnostics for assessing GLM quality.
6
+
7
+ Features:
8
+ - Overall model fit statistics
9
+ - Calibration metrics (A/E ratios, calibration curves)
10
+ - Discrimination metrics (Gini, lift, Lorenz curve)
11
+ - Per-factor diagnostics (for both fitted and unfitted factors)
12
+ - Interaction detection
13
+ - JSON export for LLM consumption
14
+
15
+ Usage:
16
+ ------
17
+ >>> result = rs.glm("y ~ x1 + C(region)", data, family="poisson").fit()
18
+ >>> diagnostics = result.diagnostics(
19
+ ... data=data,
20
+ ... categorical_factors=["region", "brand"],
21
+ ... continuous_factors=["age", "income"]
22
+ ... )
23
+ >>> print(diagnostics.to_json())
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import json
29
+ import math
30
+ from dataclasses import dataclass, field, asdict
31
+ from functools import cached_property
32
+
33
+ # Import Rust diagnostics functions
34
+ from rustystats._rustystats import (
35
+ compute_calibration_curve_py as _rust_calibration_curve,
36
+ compute_discrimination_stats_py as _rust_discrimination_stats,
37
+ compute_ae_continuous_py as _rust_ae_continuous,
38
+ compute_ae_categorical_py as _rust_ae_categorical,
39
+ compute_loss_metrics_py as _rust_loss_metrics,
40
+ compute_lorenz_curve_py as _rust_lorenz_curve,
41
+ hosmer_lemeshow_test_py as _rust_hosmer_lemeshow,
42
+ compute_fit_statistics_py as _rust_fit_statistics,
43
+ compute_residual_summary_py as _rust_residual_summary,
44
+ compute_residual_pattern_py as _rust_residual_pattern,
45
+ compute_pearson_residuals_py as _rust_pearson_residuals,
46
+ compute_deviance_residuals_py as _rust_deviance_residuals,
47
+ compute_null_deviance_py as _rust_null_deviance,
48
+ compute_unit_deviance_py as _rust_unit_deviance,
49
+ )
50
+ from typing import Optional, List, Dict, Any, Union, TYPE_CHECKING
51
+
52
+ import numpy as np
53
+
54
+ if TYPE_CHECKING:
55
+ import polars as pl
56
+
57
+
58
+ # =============================================================================
59
+ # Data Classes for Diagnostics Structure
60
+ # =============================================================================
61
+
62
+ @dataclass
63
+ class Percentiles:
64
+ """Percentile values for a continuous variable."""
65
+ p1: float
66
+ p5: float
67
+ p10: float
68
+ p25: float
69
+ p50: float
70
+ p75: float
71
+ p90: float
72
+ p95: float
73
+ p99: float
74
+
75
+
76
+ @dataclass
77
+ class ResidualSummary:
78
+ """Summary statistics for residuals (compressed: mean, std, skewness only)."""
79
+ mean: float
80
+ std: float
81
+ skewness: float
82
+
83
+
84
+ @dataclass
85
+ class CalibrationBin:
86
+ """A single bin in the calibration curve."""
87
+ bin_index: int
88
+ predicted_lower: float
89
+ predicted_upper: float
90
+ predicted_mean: float
91
+ actual_mean: float
92
+ actual_expected_ratio: float
93
+ count: int
94
+ exposure: float
95
+ actual_sum: float
96
+ predicted_sum: float
97
+ ae_confidence_interval_lower: float
98
+ ae_confidence_interval_upper: float
99
+
100
+
101
+ @dataclass
102
+ class LorenzPoint:
103
+ """A point on the Lorenz curve."""
104
+ cumulative_exposure_pct: float
105
+ cumulative_actual_pct: float
106
+ cumulative_predicted_pct: float
107
+
108
+
109
+ @dataclass
110
+ class ActualExpectedBin:
111
+ """A/E statistics for a single bin (compressed format)."""
112
+ bin: str # bin label or range
113
+ n: int # count
114
+ actual: int # actual_sum (rounded)
115
+ predicted: int # predicted_sum (rounded)
116
+ ae: float # actual/expected ratio
117
+ ae_ci: List[float] # [lower, upper] confidence interval
118
+
119
+
120
+ @dataclass
121
+ class ResidualPattern:
122
+ """Residual pattern analysis for a factor (compressed)."""
123
+ resid_corr: float # correlation_with_residuals
124
+ var_explained: float # residual_variance_explained
125
+
126
+
127
+ @dataclass
128
+ class ContinuousFactorStats:
129
+ """Univariate statistics for a continuous factor."""
130
+ mean: float
131
+ std: float
132
+ min: float
133
+ max: float
134
+ missing_count: int
135
+ percentiles: Percentiles
136
+
137
+
138
+ @dataclass
139
+ class CategoricalLevelStats:
140
+ """Statistics for a categorical level."""
141
+ level: str
142
+ count: int
143
+ percentage: float
144
+
145
+
146
+ @dataclass
147
+ class CategoricalFactorStats:
148
+ """Distribution statistics for a categorical factor (compressed: no levels array)."""
149
+ n_levels: int
150
+ n_rare_levels: int
151
+ rare_level_total_pct: float
152
+
153
+
154
+ @dataclass
155
+ class FactorSignificance:
156
+ """Statistical significance tests for a factor (compressed field names)."""
157
+ chi2: Optional[float] # Wald chi-square test statistic
158
+ p: Optional[float] # p-value for Wald test
159
+ dev_contrib: Optional[float] # Drop-in-deviance if term removed
160
+
161
+
162
+ @dataclass
163
+ class FactorDiagnostics:
164
+ """Complete diagnostics for a single factor."""
165
+ name: str
166
+ factor_type: str # "continuous" or "categorical"
167
+ in_model: bool
168
+ transformation: Optional[str]
169
+ univariate_stats: Union[ContinuousFactorStats, CategoricalFactorStats]
170
+ actual_vs_expected: List[ActualExpectedBin]
171
+ residual_pattern: ResidualPattern
172
+ significance: Optional[FactorSignificance] = None # Significance tests (only for in_model factors)
173
+
174
+
175
+ @dataclass
176
+ class InteractionCandidate:
177
+ """A potential interaction between two factors."""
178
+ factor1: str
179
+ factor2: str
180
+ interaction_strength: float
181
+ pvalue: float
182
+ n_cells: int
183
+ current_terms: Optional[List[str]] = None # How factors currently appear in model
184
+ recommendation: Optional[str] = None # Suggested action
185
+
186
+
187
+ @dataclass
188
+ class ConvergenceDetails:
189
+ """Details about model convergence."""
190
+ max_iterations_allowed: int
191
+ iterations_used: int
192
+ converged: bool
193
+ reason: str # "converged", "max_iterations_reached", "gradient_tolerance", etc.
194
+
195
+
196
+ @dataclass
197
+ class DataExploration:
198
+ """Pre-fit data exploration results."""
199
+
200
+ # Data summary
201
+ data_summary: Dict[str, Any]
202
+
203
+ # Factor statistics
204
+ factor_stats: List[Dict[str, Any]]
205
+
206
+ # Missing value analysis
207
+ missing_values: Dict[str, Any]
208
+
209
+ # Univariate significance tests (each factor vs response)
210
+ univariate_tests: List[Dict[str, Any]]
211
+
212
+ # Correlation matrix for continuous factors
213
+ correlations: Dict[str, Any]
214
+
215
+ # Cramér's V matrix for categorical factors
216
+ cramers_v: Dict[str, Any]
217
+
218
+ # Variance inflation factors (multicollinearity)
219
+ vif: List[Dict[str, Any]]
220
+
221
+ # Zero inflation check (for count data)
222
+ zero_inflation: Dict[str, Any]
223
+
224
+ # Overdispersion check
225
+ overdispersion: Dict[str, Any]
226
+
227
+ # Interaction candidates
228
+ interaction_candidates: List[InteractionCandidate]
229
+
230
+ # Response distribution
231
+ response_stats: Dict[str, Any]
232
+
233
+ def to_dict(self) -> Dict[str, Any]:
234
+ """Convert to dictionary."""
235
+ return _to_dict_recursive(self)
236
+
237
+ def to_json(self, indent: Optional[int] = None) -> str:
238
+ """Convert to JSON string."""
239
+ return json.dumps(self.to_dict(), indent=indent, default=_json_default)
240
+
241
+
242
+ @dataclass
243
+ class ModelDiagnostics:
244
+ """Complete model diagnostics output."""
245
+
246
+ # Model summary
247
+ model_summary: Dict[str, Any]
248
+
249
+ # Convergence details (especially important when converged=False)
250
+ convergence_details: Optional[ConvergenceDetails]
251
+
252
+ # Fit statistics
253
+ fit_statistics: Dict[str, float]
254
+
255
+ # Loss metrics
256
+ loss_metrics: Dict[str, float]
257
+
258
+ # Calibration
259
+ calibration: Dict[str, Any]
260
+
261
+ # Discrimination (only for applicable models)
262
+ discrimination: Optional[Dict[str, Any]]
263
+
264
+ # Residual summary
265
+ residual_summary: Dict[str, ResidualSummary]
266
+
267
+ # Per-factor diagnostics
268
+ factors: List[FactorDiagnostics]
269
+
270
+ # Interaction candidates
271
+ interaction_candidates: List[InteractionCandidate]
272
+
273
+ # Model comparison vs null
274
+ model_comparison: Dict[str, float]
275
+
276
+ # Warnings
277
+ warnings: List[Dict[str, str]]
278
+
279
+ def to_dict(self) -> Dict[str, Any]:
280
+ """Convert to dictionary, handling nested dataclasses."""
281
+ return _to_dict_recursive(self)
282
+
283
+ def to_json(self, indent: Optional[int] = None) -> str:
284
+ """Convert to JSON string."""
285
+ return json.dumps(self.to_dict(), indent=indent, default=_json_default)
286
+
287
+
288
+ def _json_default(obj):
289
+ """Handle special types for JSON serialization."""
290
+ if isinstance(obj, float):
291
+ if math.isnan(obj):
292
+ return None
293
+ if math.isinf(obj):
294
+ return None
295
+ if hasattr(obj, '__dict__'):
296
+ return obj.__dict__
297
+ return str(obj)
298
+
299
+
300
+ def _round_float(x: float, decimals: int = 4) -> float:
301
+ """Round float for token-efficient JSON output."""
302
+ if x == 0:
303
+ return 0.0
304
+ # Use fewer decimals for large numbers, more for small
305
+ if abs(x) >= 100:
306
+ return round(x, 2)
307
+ elif abs(x) >= 1:
308
+ return round(x, 4)
309
+ else:
310
+ return round(x, 6)
311
+
312
+
313
+ def _to_dict_recursive(obj) -> Any:
314
+ """Recursively convert dataclasses and handle special values."""
315
+ if isinstance(obj, dict):
316
+ return {k: _to_dict_recursive(v) for k, v in obj.items()}
317
+ elif isinstance(obj, list):
318
+ return [_to_dict_recursive(v) for v in obj]
319
+ elif hasattr(obj, '__dataclass_fields__'):
320
+ return {k: _to_dict_recursive(v) for k, v in asdict(obj).items()}
321
+ elif isinstance(obj, float):
322
+ if math.isnan(obj) or math.isinf(obj):
323
+ return None
324
+ return _round_float(obj)
325
+ elif isinstance(obj, np.ndarray):
326
+ return [_to_dict_recursive(v) for v in obj.tolist()]
327
+ elif isinstance(obj, np.floating):
328
+ return _round_float(float(obj))
329
+ elif isinstance(obj, np.integer):
330
+ return int(obj)
331
+ else:
332
+ return obj
333
+
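A minimal sketch (illustrative values only) of how the serialization helpers above behave: floats are rounded by magnitude and non-finite values become null in the JSON output.

    import json
    import math
    from dataclasses import dataclass

    @dataclass
    class _Demo:                   # hypothetical payload, not part of the package
        big: float = 1234.5678     # >= 100 -> 2 decimals
        mid: float = 3.141592      # >= 1   -> 4 decimals
        tiny: float = 0.000123456  # < 1    -> 6 decimals
        bad: float = math.nan      # NaN/inf -> None

    print(json.dumps(_to_dict_recursive(_Demo())))
    # {"big": 1234.57, "mid": 3.1416, "tiny": 0.000123, "bad": null}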
334
+
335
+ # =============================================================================
336
+ # Focused Diagnostic Components
337
+ # =============================================================================
338
+ #
339
+ # Each component handles a specific type of diagnostic computation.
340
+ # DiagnosticsComputer coordinates these components to produce unified output.
341
+ # =============================================================================
342
+
343
+ class _ResidualComputer:
344
+ """Computes and caches residuals."""
345
+
346
+ def __init__(self, y: np.ndarray, mu: np.ndarray, family: str, exposure: np.ndarray):
347
+ self.y = y
348
+ self.mu = mu
349
+ self.family = family
350
+ self.exposure = exposure
351
+ self._pearson = None
352
+ self._deviance = None
353
+ self._null_dev = None
354
+
355
+ @property
356
+ def pearson(self) -> np.ndarray:
357
+ if self._pearson is None:
358
+ self._pearson = np.asarray(_rust_pearson_residuals(self.y, self.mu, self.family))
359
+ return self._pearson
360
+
361
+ @property
362
+ def deviance(self) -> np.ndarray:
363
+ if self._deviance is None:
364
+ self._deviance = np.asarray(_rust_deviance_residuals(self.y, self.mu, self.family))
365
+ return self._deviance
366
+
367
+ @property
368
+ def null_deviance(self) -> float:
369
+ if self._null_dev is None:
370
+ self._null_dev = _rust_null_deviance(self.y, self.family, self.exposure)
371
+ return self._null_dev
372
+
373
+ def unit_deviance(self, y: np.ndarray, mu: np.ndarray) -> np.ndarray:
374
+ return np.asarray(_rust_unit_deviance(y, mu, self.family))
375
+
376
+
377
+ class _CalibrationComputer:
378
+ """Computes calibration metrics."""
379
+
380
+ def __init__(self, y: np.ndarray, mu: np.ndarray, exposure: np.ndarray):
381
+ self.y = y
382
+ self.mu = mu
383
+ self.exposure = exposure
384
+
385
+ def compute(self, n_bins: int = 10) -> Dict[str, Any]:
386
+ actual_total = float(np.sum(self.y))
387
+ predicted_total = float(np.sum(self.mu))
388
+ exposure_total = float(np.sum(self.exposure))
389
+ ae_ratio = actual_total / predicted_total if predicted_total > 0 else float('nan')
390
+
391
+ bins = self._compute_bins(n_bins)
392
+ hl_stat, hl_pvalue = self._hosmer_lemeshow(n_bins)
393
+
394
+ # Compressed format: only include problem deciles (A/E outside [0.9, 1.1])
395
+ problem_deciles = [
396
+ {
397
+ "decile": b.bin_index,
398
+ "ae": round(b.actual_expected_ratio, 2),
399
+ "n": b.count,
400
+ "ae_ci": [round(b.ae_confidence_interval_lower, 2), round(b.ae_confidence_interval_upper, 2)],
401
+ }
402
+ for b in bins
403
+ if b.actual_expected_ratio < 0.9 or b.actual_expected_ratio > 1.1
404
+ ]
405
+
406
+ return {
407
+ "ae_ratio": round(ae_ratio, 3),
408
+ "hl_pvalue": round(hl_pvalue, 4) if not np.isnan(hl_pvalue) else None,
409
+ "problem_deciles": problem_deciles,
410
+ }
411
+
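For reference, a hypothetical example of the compressed payload returned by compute() (all numbers invented); only deciles whose A/E falls outside [0.9, 1.1] appear in problem_deciles.

    example_calibration = {
        "ae_ratio": 0.998,      # overall actual / predicted
        "hl_pvalue": 0.2143,    # Hosmer-Lemeshow p-value, or None when undefined
        "problem_deciles": [    # only deciles with A/E outside [0.9, 1.1]
            {"decile": 9, "ae": 1.18, "n": 1021, "ae_ci": [1.05, 1.31]},
        ],
    }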
412
+ def _compute_bins(self, n_bins: int) -> List[CalibrationBin]:
413
+ rust_bins = _rust_calibration_curve(self.y, self.mu, self.exposure, n_bins)
414
+ return [
415
+ CalibrationBin(
416
+ bin_index=b["bin_index"], predicted_lower=b["predicted_lower"],
417
+ predicted_upper=b["predicted_upper"], predicted_mean=b["predicted_mean"],
418
+ actual_mean=b["actual_mean"], actual_expected_ratio=b["actual_expected_ratio"],
419
+ count=b["count"], exposure=b["exposure"], actual_sum=b["actual_sum"],
420
+ predicted_sum=b["predicted_sum"], ae_confidence_interval_lower=b["ae_ci_lower"],
421
+ ae_confidence_interval_upper=b["ae_ci_upper"],
422
+ )
423
+ for b in rust_bins
424
+ ]
425
+
426
+ def _hosmer_lemeshow(self, n_bins: int) -> tuple:
427
+ result = _rust_hosmer_lemeshow(self.y, self.mu, n_bins)
428
+ return result["chi2_statistic"], result["pvalue"]
429
+
430
+
431
+ class _DiscriminationComputer:
432
+ """Computes discrimination metrics."""
433
+
434
+ def __init__(self, y: np.ndarray, mu: np.ndarray, exposure: np.ndarray):
435
+ self.y = y
436
+ self.mu = mu
437
+ self.exposure = exposure
438
+
439
+ def compute(self) -> Dict[str, Any]:
440
+ stats = _rust_discrimination_stats(self.y, self.mu, self.exposure)
441
+ # Removed lorenz_curve - Gini coefficient provides sufficient discrimination info
442
+ return {
443
+ "gini": round(stats["gini"], 3),
444
+ "auc": round(stats["auc"], 3),
445
+ "ks": round(stats["ks_statistic"], 3),
446
+ "lift_10pct": round(stats["lift_at_10pct"], 3),
447
+ "lift_20pct": round(stats["lift_at_20pct"], 3),
448
+ }
449
+
450
+
451
+ # =============================================================================
452
+ # Main Diagnostics Computation
453
+ # =============================================================================
454
+
455
+ class DiagnosticsComputer:
456
+ """
457
+ Computes comprehensive model diagnostics.
458
+
459
+ Coordinates focused component classes to produce unified diagnostics output.
460
+ All results are cached for efficiency.
461
+ """
462
+
463
+ def __init__(
464
+ self,
465
+ y: np.ndarray,
466
+ mu: np.ndarray,
467
+ linear_predictor: np.ndarray,
468
+ family: str,
469
+ n_params: int,
470
+ deviance: float,
471
+ exposure: Optional[np.ndarray] = None,
472
+ feature_names: Optional[List[str]] = None,
473
+ var_power: float = 1.5,
474
+ theta: float = 1.0,
475
+ ):
476
+ self.y = np.asarray(y, dtype=np.float64)
477
+ self.mu = np.asarray(mu, dtype=np.float64)
478
+ self.linear_predictor = np.asarray(linear_predictor, dtype=np.float64)
479
+ self.family = family.lower()
480
+ self.n_params = n_params
481
+ self.deviance = deviance
482
+ self.exposure = np.asarray(exposure, dtype=np.float64) if exposure is not None else np.ones_like(y)
483
+ self.feature_names = feature_names or []
484
+ self.var_power = var_power
485
+ self.theta = theta
486
+
487
+ self.n_obs = len(y)
488
+ self.df_resid = self.n_obs - n_params
489
+
490
+ # Initialize focused components
491
+ self._residuals = _ResidualComputer(self.y, self.mu, self.family, self.exposure)
492
+ self._calibration = _CalibrationComputer(self.y, self.mu, self.exposure)
493
+ self._discrimination = _DiscriminationComputer(self.y, self.mu, self.exposure)
494
+
495
+ @property
496
+ def pearson_residuals(self) -> np.ndarray:
497
+ return self._residuals.pearson
498
+
499
+ @property
500
+ def deviance_residuals(self) -> np.ndarray:
501
+ return self._residuals.deviance
502
+
503
+ @property
504
+ def null_deviance(self) -> float:
505
+ return self._residuals.null_deviance
506
+
507
+ def _compute_unit_deviance(self, y: np.ndarray, mu: np.ndarray) -> np.ndarray:
508
+ return self._residuals.unit_deviance(y, mu)
509
+
510
+ def _compute_loss(self, y: np.ndarray, mu: np.ndarray, weights: Optional[np.ndarray] = None) -> float:
511
+ unit_dev = self._compute_unit_deviance(y, mu)
512
+ if weights is not None:
513
+ return np.average(unit_dev, weights=weights)
514
+ return np.mean(unit_dev)
515
+
516
+ def compute_fit_statistics(self) -> Dict[str, float]:
517
+ """Compute overall fit statistics using Rust backend."""
518
+ return _rust_fit_statistics(
519
+ self.y, self.mu, self.deviance, self.null_deviance, self.n_params, self.family
520
+ )
521
+
522
+ def compute_loss_metrics(self) -> Dict[str, float]:
523
+ """Compute various loss metrics using Rust backend."""
524
+ rust_loss = _rust_loss_metrics(self.y, self.mu, self.family)
525
+ return {
526
+ "family_deviance_loss": rust_loss["family_loss"],
527
+ "mse": rust_loss["mse"],
528
+ "mae": rust_loss["mae"],
529
+ "rmse": rust_loss["rmse"],
530
+ }
531
+
532
+ def compute_calibration(self, n_bins: int = 10) -> Dict[str, Any]:
533
+ """Compute calibration metrics using focused component."""
534
+ return self._calibration.compute(n_bins)
535
+
536
+ def compute_discrimination(self) -> Optional[Dict[str, Any]]:
537
+ """Compute discrimination metrics using focused component."""
538
+ return self._discrimination.compute()
539
+
540
+ def compute_residual_summary(self) -> Dict[str, ResidualSummary]:
541
+ """Compute residual summary statistics using Rust backend (compressed)."""
542
+ def summarize(resid: np.ndarray) -> ResidualSummary:
543
+ stats = _rust_residual_summary(resid)
544
+ return ResidualSummary(
545
+ mean=round(stats["mean"], 2),
546
+ std=round(stats["std"], 2),
547
+ skewness=round(stats["skewness"], 1),
548
+ )
549
+
550
+ return {
551
+ "pearson": summarize(self.pearson_residuals),
552
+ "deviance": summarize(self.deviance_residuals),
553
+ }
554
+
555
+ def compute_factor_diagnostics(
556
+ self,
557
+ data: "pl.DataFrame",
558
+ categorical_factors: List[str],
559
+ continuous_factors: List[str],
560
+ result=None, # GLMResults for significance tests
561
+ n_bins: int = 10,
562
+ rare_threshold_pct: float = 1.0,
563
+ max_categorical_levels: int = 20,
564
+ ) -> List[FactorDiagnostics]:
565
+ """Compute diagnostics for each specified factor."""
566
+ factors = []
567
+
568
+ # Process categorical factors
569
+ for name in categorical_factors:
570
+ if name not in data.columns:
571
+ continue
572
+
573
+ values = data[name].to_numpy().astype(str)
574
+ in_model = any(name in fn for fn in self.feature_names)
575
+
576
+ # Univariate stats (compressed: no levels array, info is in actual_vs_expected)
577
+ unique, counts = np.unique(values, return_counts=True)
578
+ total = len(values)
579
+ percentages = [100.0 * c / total for c in counts]
580
+
581
+ n_rare = sum(1 for pct in percentages if pct < rare_threshold_pct)
582
+ rare_pct = sum(pct for pct in percentages if pct < rare_threshold_pct)
583
+
584
+ univariate = CategoricalFactorStats(
585
+ n_levels=len(unique),
586
+ n_rare_levels=n_rare,
587
+ rare_level_total_pct=round(rare_pct, 2),
588
+ )
589
+
590
+ # A/E by level
591
+ ae_bins = self._compute_ae_categorical(
592
+ values, rare_threshold_pct, max_categorical_levels
593
+ )
594
+
595
+ # Residual pattern
596
+ resid_pattern = self._compute_residual_pattern_categorical(values)
597
+
598
+ # Factor significance (only for factors in model)
599
+ significance = self.compute_factor_significance(name, result) if in_model and result else None
600
+
601
+ factors.append(FactorDiagnostics(
602
+ name=name,
603
+ factor_type="categorical",
604
+ in_model=in_model,
605
+ transformation=self._get_transformation(name),
606
+ univariate_stats=univariate,
607
+ actual_vs_expected=ae_bins,
608
+ residual_pattern=resid_pattern,
609
+ significance=significance,
610
+ ))
611
+
612
+ # Process continuous factors
613
+ for name in continuous_factors:
614
+ if name not in data.columns:
615
+ continue
616
+
617
+ values = data[name].to_numpy().astype(np.float64)
618
+ in_model = any(name in fn for fn in self.feature_names)
619
+
620
+ # Univariate stats - batch percentile calculation
621
+ valid_mask = ~np.isnan(values) & ~np.isinf(values)
622
+ valid = values[valid_mask]
623
+
624
+ if len(valid) > 0:
625
+ # Single batched percentile call (much faster)
626
+ pcts = np.percentile(valid, [1, 5, 10, 25, 50, 75, 90, 95, 99])
627
+ percentiles = Percentiles(
628
+ p1=float(pcts[0]), p5=float(pcts[1]), p10=float(pcts[2]),
629
+ p25=float(pcts[3]), p50=float(pcts[4]), p75=float(pcts[5]),
630
+ p90=float(pcts[6]), p95=float(pcts[7]), p99=float(pcts[8]),
631
+ )
632
+ univariate = ContinuousFactorStats(
633
+ mean=float(np.mean(valid)),
634
+ std=float(np.std(valid)),
635
+ min=float(np.min(valid)),
636
+ max=float(np.max(valid)),
637
+ missing_count=int(np.sum(~valid_mask)),
638
+ percentiles=percentiles,
639
+ )
640
+ else:
641
+ nan = float('nan')
642
+ percentiles = Percentiles(p1=nan, p5=nan, p10=nan, p25=nan, p50=nan, p75=nan, p90=nan, p95=nan, p99=nan)
643
+ univariate = ContinuousFactorStats(mean=nan, std=nan, min=nan, max=nan, missing_count=len(values), percentiles=percentiles)
644
+
645
+ # A/E by quantile bins
646
+ ae_bins = self._compute_ae_continuous(values, n_bins)
647
+
648
+ # Residual pattern
649
+ resid_pattern = self._compute_residual_pattern_continuous(values, n_bins)
650
+
651
+ # Factor significance (only for factors in model)
652
+ significance = self.compute_factor_significance(name, result) if in_model and result else None
653
+
654
+ factors.append(FactorDiagnostics(
655
+ name=name,
656
+ factor_type="continuous",
657
+ in_model=in_model,
658
+ transformation=self._get_transformation(name),
659
+ univariate_stats=univariate,
660
+ actual_vs_expected=ae_bins,
661
+ residual_pattern=resid_pattern,
662
+ significance=significance,
663
+ ))
664
+
665
+ return factors
666
+
667
+ def _get_transformation(self, name: str) -> Optional[str]:
668
+ """Find transformation for a factor in the model."""
669
+ for fn in self.feature_names:
670
+ if name in fn and fn != name:
671
+ return fn
672
+ return None
673
+
674
+ def _get_factor_terms(self, name: str) -> List[str]:
675
+ """Get all model terms that include this factor."""
676
+ return [fn for fn in self.feature_names if name in fn]
677
+
678
+ def compute_factor_significance(
679
+ self,
680
+ name: str,
681
+ result, # GLMResults or FormulaGLMResults
682
+ ) -> Optional[FactorSignificance]:
683
+ """
684
+ Compute significance tests for a factor in the model.
685
+
686
+ Returns Wald chi-square test and deviance contribution.
687
+ """
688
+ if not hasattr(result, 'params') or not hasattr(result, 'bse'):
689
+ return None
690
+
691
+ # Find indices of parameters related to this factor
692
+ param_indices = []
693
+ for i, fn in enumerate(self.feature_names):
694
+ if name in fn and fn != 'Intercept':
695
+ param_indices.append(i)
696
+
697
+ if not param_indices:
698
+ return None
699
+
700
+ try:
701
+ params = np.asarray(result.params)
702
+ bse = np.asarray(result.bse())
703
+
704
+ # Wald chi-square: sum of (coef/se)^2 for all related parameters
705
+ wald_chi2 = 0.0
706
+ for idx in param_indices:
707
+ if bse[idx] > 0:
708
+ wald_chi2 += (params[idx] / bse[idx]) ** 2
709
+
710
+ # Degrees of freedom = number of parameters for this term
711
+ df = len(param_indices)
712
+
713
+ # P-value from chi-square distribution
714
+ try:
715
+ from scipy.stats import chi2
716
+ wald_pvalue = 1 - chi2.cdf(wald_chi2, df) if df > 0 else 1.0
717
+ except ImportError:
718
+ wald_pvalue = float('nan')
719
+
720
+ # Deviance contribution: approximate using sum of z^2 (scaled)
721
+ # This is an approximation; true drop-in-deviance requires refitting
722
+ deviance_contribution = float(wald_chi2) # Approximate
723
+
724
+ return FactorSignificance(
725
+ chi2=round(float(wald_chi2), 2),
726
+ p=round(float(wald_pvalue), 4),
727
+ dev_contrib=round(deviance_contribution, 2),
728
+ )
729
+ except Exception:
730
+ return None
731
+
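A worked sketch of the Wald aggregation used above, with invented coefficients for a two-level categorical term: the term chi-square is the sum of squared z-scores and the degrees of freedom equal the number of coefficients.

    import numpy as np
    from scipy.stats import chi2   # optional dependency, as in the method above

    coefs = np.array([0.21, -0.35])   # e.g. C(region)[T.B], C(region)[T.C] (hypothetical)
    ses = np.array([0.05, 0.09])

    wald_chi2 = float(np.sum((coefs / ses) ** 2))   # 17.64 + 15.12 ≈ 32.8
    df = len(coefs)                                 # 2
    pvalue = 1 - chi2.cdf(wald_chi2, df)            # ≈ 7.7e-08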
732
+ def _compute_ae_continuous(self, values: np.ndarray, n_bins: int) -> List[ActualExpectedBin]:
733
+ """Compute A/E for continuous factor using Rust backend (compressed format)."""
734
+ rust_bins = _rust_ae_continuous(values, self.y, self.mu, self.exposure, n_bins, self.family)
735
+ # Filter out empty bins (count=0)
736
+ non_empty_bins = [b for b in rust_bins if b["count"] > 0]
737
+ return [
738
+ ActualExpectedBin(
739
+ bin=b["bin_label"], # includes range for continuous
740
+ n=b["count"],
741
+ actual=int(round(b["actual_sum"])),
742
+ predicted=int(round(b["predicted_sum"])),
743
+ ae=round(b["actual_expected_ratio"], 3),
744
+ ae_ci=[round(b["ae_ci_lower"], 3), round(b["ae_ci_upper"], 3)],
745
+ )
746
+ for b in non_empty_bins
747
+ ]
748
+
749
+ def _compute_ae_categorical(
750
+ self,
751
+ values: np.ndarray,
752
+ rare_threshold_pct: float,
753
+ max_levels: int,
754
+ ) -> List[ActualExpectedBin]:
755
+ """Compute A/E for categorical factor using Rust backend (compressed format)."""
756
+ levels = [str(v) for v in values]
757
+ rust_bins = _rust_ae_categorical(levels, self.y, self.mu, self.exposure,
758
+ rare_threshold_pct, max_levels, self.family)
759
+ return [
760
+ ActualExpectedBin(
761
+ bin=b["bin_label"],
762
+ n=b["count"],
763
+ actual=int(round(b["actual_sum"])),
764
+ predicted=int(round(b["predicted_sum"])),
765
+ ae=round(b["actual_expected_ratio"], 3),
766
+ ae_ci=[round(b["ae_ci_lower"], 3), round(b["ae_ci_upper"], 3)],
767
+ )
768
+ for b in rust_bins
769
+ ]
770
+
771
+ def _compute_residual_pattern_continuous(
772
+ self,
773
+ values: np.ndarray,
774
+ n_bins: int,
775
+ ) -> ResidualPattern:
776
+ """Compute residual pattern using Rust backend (compressed: no mean_by_bin)."""
777
+ valid_mask = ~np.isnan(values) & ~np.isinf(values)
778
+
779
+ if not np.any(valid_mask):
780
+ return ResidualPattern(resid_corr=0.0, var_explained=0.0)
781
+
782
+ result = _rust_residual_pattern(values, self.pearson_residuals, n_bins)
783
+ corr = result["correlation_with_residuals"]
784
+ corr_val = float(corr) if not np.isnan(corr) else 0.0
785
+
786
+ return ResidualPattern(
787
+ resid_corr=round(corr_val, 4),
788
+ var_explained=round(corr_val ** 2, 6),
789
+ )
790
+
791
+ def _compute_residual_pattern_categorical(self, values: np.ndarray) -> ResidualPattern:
792
+ """Compute residual pattern for categorical factor (compressed)."""
793
+ unique_levels = np.unique(values)
794
+
795
+ # Compute eta-squared (variance explained)
796
+ overall_mean = np.mean(self.pearson_residuals)
797
+ ss_total = np.sum((self.pearson_residuals - overall_mean) ** 2)
798
+ ss_between = 0.0
799
+ level_means = []
800
+
801
+ for level in unique_levels:
802
+ mask = values == level
803
+ level_resid = self.pearson_residuals[mask]
804
+ level_mean = np.mean(level_resid)
805
+ level_means.append(level_mean)
806
+ ss_between += len(level_resid) * (level_mean - overall_mean) ** 2
807
+
808
+ eta_squared = ss_between / ss_total if ss_total > 0 else 0.0
809
+ mean_abs_resid = np.mean(np.abs(level_means))
810
+
811
+ return ResidualPattern(
812
+ resid_corr=round(float(mean_abs_resid), 4),
813
+ var_explained=round(float(eta_squared), 6),
814
+ )
815
+
816
+ def _linear_trend_test(self, x: np.ndarray, y: np.ndarray) -> tuple:
817
+ """Simple linear regression trend test."""
818
+ n = len(x)
819
+ if n < 3:
820
+ return float('nan'), float('nan')
821
+
822
+ x_mean = np.mean(x)
823
+ y_mean = np.mean(y)
824
+
825
+ ss_xx = np.sum((x - x_mean) ** 2)
826
+ ss_xy = np.sum((x - x_mean) * (y - y_mean))
827
+
828
+ if ss_xx == 0:
829
+ return 0.0, 1.0
830
+
831
+ slope = ss_xy / ss_xx
832
+
833
+ # Residuals from regression
834
+ y_pred = y_mean + slope * (x - x_mean)
835
+ ss_res = np.sum((y - y_pred) ** 2)
836
+
837
+ df = n - 2
838
+ mse = ss_res / df if df > 0 else 0
839
+ se_slope = np.sqrt(mse / ss_xx) if mse > 0 and ss_xx > 0 else float('nan')
840
+
841
+ if np.isnan(se_slope) or se_slope == 0:
842
+ return slope, float('nan')
843
+
844
+ t_stat = slope / se_slope
845
+
846
+ try:
847
+ from scipy.stats import t
848
+ pvalue = 2 * (1 - t.cdf(abs(t_stat), df))
849
+ except ImportError:
850
+ pvalue = float('nan')
851
+
852
+ return slope, pvalue
853
+
854
+ def detect_interactions(
855
+ self,
856
+ data: "pl.DataFrame",
857
+ factor_names: List[str],
858
+ max_factors: int = 10,
859
+ min_correlation: float = 0.01,
860
+ max_candidates: int = 5,
861
+ min_cell_count: int = 30,
862
+ ) -> List[InteractionCandidate]:
863
+ """Detect potential interactions using greedy residual-based approach."""
864
+ # First, rank factors by residual association
865
+ factor_scores = []
866
+
867
+ for name in factor_names:
868
+ if name not in data.columns:
869
+ continue
870
+
871
+ values = data[name].to_numpy()
872
+
873
+ # Check if categorical or continuous
874
+ if values.dtype == object or str(values.dtype).startswith('str'):
875
+ score = self._compute_eta_squared(values.astype(str))
876
+ else:
877
+ values = values.astype(np.float64)
878
+ valid_mask = ~np.isnan(values) & ~np.isinf(values)
879
+ if np.sum(valid_mask) < 10:
880
+ continue
881
+ score = abs(np.corrcoef(values[valid_mask], self.pearson_residuals[valid_mask])[0, 1])
882
+
883
+ if score >= min_correlation:
884
+ factor_scores.append((name, score))
885
+
886
+ # Sort and take top factors
887
+ factor_scores.sort(key=lambda x: -x[1])
888
+ top_factors = [name for name, _ in factor_scores[:max_factors]]
889
+
890
+ if len(top_factors) < 2:
891
+ return []
892
+
893
+ # Check pairwise interactions
894
+ candidates = []
895
+
896
+ for i in range(len(top_factors)):
897
+ for j in range(i + 1, len(top_factors)):
898
+ name1, name2 = top_factors[i], top_factors[j]
899
+
900
+ values1 = data[name1].to_numpy()
901
+ values2 = data[name2].to_numpy()
902
+
903
+ # Discretize both factors
904
+ bins1 = self._discretize(values1, 5)
905
+ bins2 = self._discretize(values2, 5)
906
+
907
+ # Compute interaction strength
908
+ candidate = self._compute_interaction_strength(
909
+ name1, bins1, name2, bins2, min_cell_count
910
+ )
911
+
912
+ if candidate is not None:
913
+ # Add current_terms and recommendation
914
+ terms1 = self._get_factor_terms(name1)
915
+ terms2 = self._get_factor_terms(name2)
916
+ candidate.current_terms = terms1 + terms2 if (terms1 or terms2) else None
917
+
918
+ # Generate recommendation based on current terms and factor types
919
+ candidate.recommendation = self._generate_interaction_recommendation(
920
+ name1, name2, terms1, terms2, values1, values2
921
+ )
922
+ candidates.append(candidate)
923
+
924
+ # Sort by strength and return top candidates
925
+ candidates.sort(key=lambda x: -x.interaction_strength)
926
+ return candidates[:max_candidates]
927
+
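A hedged usage sketch of the interaction scan (synthetic data; assumes the compiled rustystats extension is importable so Pearson residuals can be computed). With random data the returned list will usually be empty; the point is the call signature.

    import numpy as np
    import polars as pl

    rng = np.random.default_rng(0)
    n = 500
    data = pl.DataFrame({
        "age": rng.uniform(18, 80, n),
        "region": [str(v) for v in rng.choice(["N", "S", "E", "W"], n)],
    })
    y = rng.poisson(0.2, n).astype(np.float64)
    mu = np.full(n, max(float(y.mean()), 1e-6))   # crude constant "model"

    dc = DiagnosticsComputer(
        y=y, mu=mu, linear_predictor=np.log(mu), family="poisson",
        n_params=1, deviance=float("nan"),
    )
    print(dc.detect_interactions(data, factor_names=["age", "region"]))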
928
+ def _generate_interaction_recommendation(
929
+ self,
930
+ name1: str,
931
+ name2: str,
932
+ terms1: List[str],
933
+ terms2: List[str],
934
+ values1: np.ndarray,
935
+ values2: np.ndarray,
936
+ ) -> str:
937
+ """Generate a recommendation for how to model an interaction."""
938
+ is_cat1 = values1.dtype == object or str(values1.dtype).startswith('str')
939
+ is_cat2 = values2.dtype == object or str(values2.dtype).startswith('str')
940
+
941
+ # Check if factors have spline/polynomial terms
942
+ has_spline1 = any('bs(' in t or 'ns(' in t or 's(' in t for t in terms1)
943
+ has_spline2 = any('bs(' in t or 'ns(' in t or 's(' in t for t in terms2)
944
+ has_poly1 = any('I(' in t and '**' in t for t in terms1)
945
+ has_poly2 = any('I(' in t and '**' in t for t in terms2)
946
+
947
+ if is_cat1 and is_cat2:
948
+ return f"Consider C({name1}):C({name2}) interaction term"
949
+ elif is_cat1 and not is_cat2:
950
+ if has_spline2:
951
+ return f"Consider C({name1}):{name2} or separate splines by {name1} level"
952
+ else:
953
+ return f"Consider C({name1}):{name2} interaction term"
954
+ elif not is_cat1 and is_cat2:
955
+ if has_spline1:
956
+ return f"Consider {name1}:C({name2}) or separate splines by {name2} level"
957
+ else:
958
+ return f"Consider {name1}:C({name2}) interaction term"
959
+ else:
960
+ # Both continuous
961
+ if has_spline1 or has_spline2 or has_poly1 or has_poly2:
962
+ return f"Consider {name1}:{name2} or tensor product spline"
963
+ else:
964
+ return f"Consider {name1}:{name2} interaction or joint spline"
965
+
966
+ def _compute_eta_squared(self, categories: np.ndarray) -> float:
967
+ """Compute eta-squared for categorical association with residuals."""
968
+ unique_levels = np.unique(categories)
969
+ overall_mean = np.mean(self.pearson_residuals)
970
+ ss_total = np.sum((self.pearson_residuals - overall_mean) ** 2)
971
+
972
+ if ss_total == 0:
973
+ return 0.0
974
+
975
+ ss_between = 0.0
976
+ for level in unique_levels:
977
+ mask = categories == level
978
+ level_resid = self.pearson_residuals[mask]
979
+ level_mean = np.mean(level_resid)
980
+ ss_between += len(level_resid) * (level_mean - overall_mean) ** 2
981
+
982
+ return ss_between / ss_total
983
+
984
+ def _discretize(self, values: np.ndarray, n_bins: int) -> np.ndarray:
985
+ """Discretize values into bins."""
986
+ if values.dtype == object or str(values.dtype).startswith('str'):
987
+ # Categorical - map to integers
988
+ unique_vals = np.unique(values)
989
+ mapping = {v: i for i, v in enumerate(unique_vals)}
990
+ return np.array([mapping[v] for v in values])
991
+ else:
992
+ # Continuous - quantile bins
993
+ values = values.astype(np.float64)
994
+ valid_mask = ~np.isnan(values) & ~np.isinf(values)
995
+
996
+ if not np.any(valid_mask):
997
+ return np.zeros(len(values), dtype=int)
998
+
999
+ quantiles = np.percentile(values[valid_mask], np.linspace(0, 100, n_bins + 1))
1000
+ bins = np.digitize(values, quantiles[1:-1])
1001
+ bins[~valid_mask] = n_bins # Invalid values in separate bin
1002
+ return bins
1003
+
1004
+ def _compute_interaction_strength(
1005
+ self,
1006
+ name1: str,
1007
+ bins1: np.ndarray,
1008
+ name2: str,
1009
+ bins2: np.ndarray,
1010
+ min_cell_count: int,
1011
+ ) -> Optional[InteractionCandidate]:
1012
+ """Compute interaction strength between two discretized factors."""
1013
+ # Create interaction cells
1014
+ cell_ids = bins1 * 1000 + bins2 # Unique cell ID
1015
+ unique_cells = np.unique(cell_ids)
1016
+
1017
+ # Filter cells with sufficient data
1018
+ valid_cells = []
1019
+ cell_residuals = []
1020
+
1021
+ for cell_id in unique_cells:
1022
+ mask = cell_ids == cell_id
1023
+ if np.sum(mask) >= min_cell_count:
1024
+ valid_cells.append(cell_id)
1025
+ cell_residuals.append(self.pearson_residuals[mask])
1026
+
1027
+ if len(valid_cells) < 4:
1028
+ return None
1029
+
1030
+ # Compute variance explained by cells
1031
+ all_resid = np.concatenate(cell_residuals)
1032
+ overall_mean = np.mean(all_resid)
1033
+ ss_total = np.sum((all_resid - overall_mean) ** 2)
1034
+
1035
+ if ss_total == 0:
1036
+ return None
1037
+
1038
+ ss_model = sum(
1039
+ len(r) * (np.mean(r) - overall_mean) ** 2
1040
+ for r in cell_residuals
1041
+ )
1042
+
1043
+ r_squared = ss_model / ss_total
1044
+
1045
+ # F-test p-value
1046
+ df_model = len(valid_cells) - 1
1047
+ df_resid = len(all_resid) - len(valid_cells)
1048
+
1049
+ if df_model > 0 and df_resid > 0:
1050
+ f_stat = (ss_model / df_model) / ((ss_total - ss_model) / df_resid)
1051
+
1052
+ try:
1053
+ from scipy.stats import f
1054
+ pvalue = 1 - f.cdf(f_stat, df_model, df_resid)
1055
+ except ImportError:
1056
+ pvalue = float('nan')
1057
+ else:
1058
+ pvalue = float('nan')
1059
+
1060
+ return InteractionCandidate(
1061
+ factor1=name1,
1062
+ factor2=name2,
1063
+ interaction_strength=float(r_squared),
1064
+ pvalue=float(pvalue),
1065
+ n_cells=len(valid_cells),
1066
+ )
1067
+
1068
+ def compute_model_comparison(self) -> Dict[str, float]:
1069
+ """Compute model comparison statistics vs null model."""
1070
+ null_dev = self.null_deviance
1071
+
1072
+ # Likelihood ratio test
1073
+ lr_chi2 = null_dev - self.deviance
1074
+ lr_df = self.n_params - 1
1075
+
1076
+ try:
1077
+ from scipy.stats import chi2
1078
+ lr_pvalue = 1 - chi2.cdf(lr_chi2, lr_df) if lr_df > 0 else float('nan')
1079
+ except ImportError:
1080
+ lr_pvalue = float('nan')
1081
+
1082
+ deviance_reduction_pct = 100 * (1 - self.deviance / null_dev) if null_dev > 0 else 0
1083
+
1084
+ # AIC improvement
1085
+ null_aic = null_dev + 2 # Null model has 1 parameter
1086
+ model_aic = self.deviance + 2 * self.n_params
1087
+ aic_improvement = null_aic - model_aic
1088
+
1089
+ return {
1090
+ "likelihood_ratio_chi2": float(lr_chi2),
1091
+ "likelihood_ratio_df": lr_df,
1092
+ "likelihood_ratio_pvalue": float(lr_pvalue),
1093
+ "deviance_reduction_pct": float(deviance_reduction_pct),
1094
+ "aic_improvement": float(aic_improvement),
1095
+ }
1096
+
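A small worked example of the arithmetic above, with invented deviances (null deviance 1200, fitted deviance 1100, 5 parameters):

    null_dev, dev, n_params = 1200.0, 1100.0, 5

    lr_chi2 = null_dev - dev                                  # 100.0
    lr_df = n_params - 1                                      # 4
    deviance_reduction_pct = 100 * (1 - dev / null_dev)       # ~8.33%
    aic_improvement = (null_dev + 2) - (dev + 2 * n_params)   # 1202 - 1110 = 92.0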
1097
+ def generate_warnings(
1098
+ self,
1099
+ fit_stats: Dict[str, float],
1100
+ calibration: Dict[str, Any],
1101
+ factors: List[FactorDiagnostics],
1102
+ ) -> List[Dict[str, str]]:
1103
+ """Generate warnings based on diagnostics."""
1104
+ warnings = []
1105
+
1106
+ # High dispersion warning
1107
+ dispersion = fit_stats.get("dispersion_pearson", 1.0)
1108
+ if dispersion > 1.5:
1109
+ warnings.append({
1110
+ "type": "high_dispersion",
1111
+ "message": f"Dispersion {dispersion:.2f} suggests overdispersion. Consider quasipoisson or negbinomial."
1112
+ })
1113
+
1114
+ # Poor overall calibration
1115
+ ae_ratio = calibration.get("ae_ratio", 1.0)
1116
+ if abs(ae_ratio - 1.0) > 0.05:
1117
+ direction = "over" if ae_ratio < 1 else "under"
1118
+ warnings.append({
1119
+ "type": "poor_overall_calibration",
1120
+ "message": f"Model {direction}-predicts overall (A/E = {ae_ratio:.3f})."
1121
+ })
1122
+
1123
+ # Extreme calibration bins
1124
+ for bin in calibration.get("problem_deciles", []):
1125
+ if isinstance(bin, dict):
1126
+ ae = bin.get("ae", 1.0)
1127
+ if ae is not None and abs(ae - 1.0) > 0.3:
1128
+ warnings.append({
1129
+ "type": "poor_bin_calibration",
1130
+ "message": f"Decile {bin.get('decile', '?')} has A/E = {ae:.2f}."
1131
+ })
1132
+
1133
+ # Factors with high residual correlation (not in model)
1134
+ for factor in factors:
1135
+ if not factor.in_model:
1136
+ corr = factor.residual_pattern.resid_corr
1137
+ r2 = factor.residual_pattern.var_explained
1138
+ if r2 > 0.02:
1139
+ warnings.append({
1140
+ "type": "missing_factor",
1141
+ "message": f"Factor '{factor.name}' not in model but explains {100*r2:.1f}% of residual variance."
1142
+ })
1143
+
1144
+ return warnings
1145
+
1146
+
1147
+ # =============================================================================
1148
+ # Pre-Fit Data Exploration
1149
+ # =============================================================================
1150
+
1151
+ class DataExplorer:
1152
+ """
1153
+ Explores data before model fitting.
1154
+
1155
+ This class provides pre-fit analysis including:
1156
+ - Factor statistics (univariate distributions)
1157
+ - Interaction detection based on response variable
1158
+ - Response distribution analysis
1159
+
1160
+ Unlike DiagnosticsComputer, this does NOT require a fitted model.
1161
+ """
1162
+
1163
+ def __init__(
1164
+ self,
1165
+ y: np.ndarray,
1166
+ exposure: Optional[np.ndarray] = None,
1167
+ family: str = "poisson",
1168
+ ):
1169
+ """
1170
+ Initialize the data explorer.
1171
+
1172
+ Parameters
1173
+ ----------
1174
+ y : np.ndarray
1175
+ Response variable.
1176
+ exposure : np.ndarray, optional
1177
+ Exposure or weights.
1178
+ family : str, default="poisson"
1179
+ Family hint for appropriate statistics.
1180
+ """
1181
+ self.y = np.asarray(y, dtype=np.float64)
1182
+ self.exposure = np.asarray(exposure, dtype=np.float64) if exposure is not None else np.ones_like(self.y)
1183
+ self.family = family.lower()
1184
+ self.n_obs = len(y)
1185
+
1186
+ def compute_response_stats(self) -> Dict[str, Any]:
1187
+ """Compute response variable statistics."""
1188
+ y_rate = self.y / self.exposure
1189
+
1190
+ stats = {
1191
+ "n_observations": self.n_obs,
1192
+ "total_exposure": float(np.sum(self.exposure)),
1193
+ "total_response": float(np.sum(self.y)),
1194
+ "mean_response": float(np.mean(self.y)),
1195
+ "mean_rate": float(np.mean(y_rate)),
1196
+ "std_rate": float(np.std(y_rate)),
1197
+ "min": float(np.min(self.y)),
1198
+ "max": float(np.max(self.y)),
1199
+ "zeros_count": int(np.sum(self.y == 0)),
1200
+ "zeros_pct": float(100 * np.sum(self.y == 0) / self.n_obs),
1201
+ }
1202
+
1203
+ # Add percentiles
1204
+ percentiles = [1, 5, 10, 25, 50, 75, 90, 95, 99]
1205
+ for p in percentiles:
1206
+ stats[f"p{p}"] = float(np.percentile(y_rate, p))
1207
+
1208
+ return stats
1209
+
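A hedged usage sketch of the pre-fit path (synthetic data; the column names are assumptions):

    import numpy as np
    import polars as pl

    rng = np.random.default_rng(1)
    n = 1_000
    data = pl.DataFrame({
        "age": rng.uniform(18, 80, n),
        "region": [str(v) for v in rng.choice(["N", "S", "E", "W"], n)],
    })
    claims = rng.poisson(0.08, n).astype(np.float64)

    explorer = DataExplorer(y=claims, exposure=np.ones(n), family="poisson")
    print(explorer.compute_response_stats()["zeros_pct"])
    factor_stats = explorer.compute_factor_stats(
        data, categorical_factors=["region"], continuous_factors=["age"]
    )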
1210
+ def compute_factor_stats(
1211
+ self,
1212
+ data: "pl.DataFrame",
1213
+ categorical_factors: List[str],
1214
+ continuous_factors: List[str],
1215
+ n_bins: int = 10,
1216
+ rare_threshold_pct: float = 1.0,
1217
+ max_categorical_levels: int = 20,
1218
+ ) -> List[Dict[str, Any]]:
1219
+ """
1220
+ Compute univariate statistics for each factor.
1221
+
1222
+ Returns statistics and actual/expected rates by level/bin.
1223
+ """
1224
+ factors = []
1225
+
1226
+ # Continuous factors
1227
+ for name in continuous_factors:
1228
+ if name not in data.columns:
1229
+ continue
1230
+
1231
+ values = data[name].to_numpy().astype(np.float64)
1232
+ valid_mask = ~np.isnan(values) & ~np.isinf(values)
1233
+ valid_values = values[valid_mask]
1234
+
1235
+ if len(valid_values) == 0:
1236
+ continue
1237
+
1238
+ # Univariate stats
1239
+ stats = {
1240
+ "name": name,
1241
+ "type": "continuous",
1242
+ "mean": float(np.mean(valid_values)),
1243
+ "std": float(np.std(valid_values)),
1244
+ "min": float(np.min(valid_values)),
1245
+ "max": float(np.max(valid_values)),
1246
+ "missing_count": int(np.sum(~valid_mask)),
1247
+ "missing_pct": float(100 * np.sum(~valid_mask) / len(values)),
1248
+ }
1249
+
1250
+ # Response by quantile bins
1251
+ quantiles = np.percentile(valid_values, np.linspace(0, 100, n_bins + 1))
1252
+ bins_data = []
1253
+ bin_rates = []
1254
+ thin_cells = []
1255
+ total_exposure = np.sum(self.exposure)
1256
+
1257
+ for i in range(n_bins):
1258
+ if i == n_bins - 1:
1259
+ bin_mask = (values >= quantiles[i]) & (values <= quantiles[i + 1])
1260
+ else:
1261
+ bin_mask = (values >= quantiles[i]) & (values < quantiles[i + 1])
1262
+
1263
+ if not np.any(bin_mask):
1264
+ continue
1265
+
1266
+ y_bin = self.y[bin_mask]
1267
+ exp_bin = self.exposure[bin_mask]
1268
+ bin_exposure = float(np.sum(exp_bin))
1269
+ rate = float(np.sum(y_bin) / bin_exposure) if bin_exposure > 0 else 0
1270
+
1271
+ bins_data.append({
1272
+ "bin_index": i,
1273
+ "bin_lower": float(quantiles[i]),
1274
+ "bin_upper": float(quantiles[i + 1]),
1275
+ "count": int(np.sum(bin_mask)),
1276
+ "exposure": bin_exposure,
1277
+ "response_sum": float(np.sum(y_bin)),
1278
+ "response_rate": rate,
1279
+ })
1280
+ bin_rates.append(rate)
1281
+
1282
+ # Check for thin cells (< 1% exposure)
1283
+ if bin_exposure / total_exposure < 0.01:
1284
+ thin_cells.append(i)
1285
+
1286
+ stats["response_by_bin"] = bins_data
1287
+
1288
+ # Compute shape recommendation
1289
+ if len(bin_rates) >= 3:
1290
+ shape_hint = self._compute_shape_hint(bin_rates)
1291
+ else:
1292
+ shape_hint = {"shape": "insufficient_data", "recommendation": "linear"}
1293
+
1294
+ stats["modeling_hints"] = {
1295
+ "shape": shape_hint["shape"],
1296
+ "recommendation": shape_hint["recommendation"],
1297
+ "thin_cells": thin_cells if thin_cells else None,
1298
+ "thin_cell_warning": f"Bins {thin_cells} have <1% exposure" if thin_cells else None,
1299
+ }
1300
+
1301
+ factors.append(stats)
1302
+
1303
+ # Categorical factors
1304
+ for name in categorical_factors:
1305
+ if name not in data.columns:
1306
+ continue
1307
+
1308
+ values = data[name].to_numpy().astype(str)
1309
+ unique_levels = np.unique(values)
1310
+
1311
+ # Sort levels by exposure
1312
+ level_exposures = []
1313
+ for level in unique_levels:
1314
+ mask = values == level
1315
+ exp = np.sum(self.exposure[mask])
1316
+ level_exposures.append((level, exp))
1317
+ level_exposures.sort(key=lambda x: -x[1])
1318
+
1319
+ total_exposure = np.sum(self.exposure)
1320
+
1321
+ # Build level stats
1322
+ levels_data = []
1323
+ other_mask = np.zeros(len(values), dtype=bool)
1324
+
1325
+ for i, (level, exp) in enumerate(level_exposures):
1326
+ pct = 100 * exp / total_exposure
1327
+
1328
+ if pct < rare_threshold_pct or i >= max_categorical_levels - 1:
1329
+ other_mask |= (values == level)
1330
+ else:
1331
+ mask = values == level
1332
+ y_level = self.y[mask]
1333
+ exp_level = self.exposure[mask]
1334
+
1335
+ levels_data.append({
1336
+ "level": level,
1337
+ "count": int(np.sum(mask)),
1338
+ "exposure": float(np.sum(exp_level)),
1339
+ "exposure_pct": float(pct),
1340
+ "response_sum": float(np.sum(y_level)),
1341
+ "response_rate": float(np.sum(y_level) / np.sum(exp_level)) if np.sum(exp_level) > 0 else 0,
1342
+ })
1343
+
1344
+ # Add "Other" if needed
1345
+ if np.any(other_mask):
1346
+ y_other = self.y[other_mask]
1347
+ exp_other = self.exposure[other_mask]
1348
+ levels_data.append({
1349
+ "level": "_Other",
1350
+ "count": int(np.sum(other_mask)),
1351
+ "exposure": float(np.sum(exp_other)),
1352
+ "exposure_pct": float(100 * np.sum(exp_other) / total_exposure),
1353
+ "response_sum": float(np.sum(y_other)),
1354
+ "response_rate": float(np.sum(y_other) / np.sum(exp_other)) if np.sum(exp_other) > 0 else 0,
1355
+ })
1356
+
1357
+ # Compute modeling hints for categorical
1358
+ main_levels = [l for l in levels_data if l["level"] != "_Other"]
1359
+
1360
+ # Suggested base level: highest exposure among non-Other levels
1361
+ suggested_base = main_levels[0]["level"] if main_levels else None
1362
+
1363
+ # Check for thin cells
1364
+ thin_levels = [l["level"] for l in main_levels if l["exposure_pct"] < 1.0]
1365
+
1366
+ # Check if ordinal (levels are numeric or follow A-Z pattern)
1367
+ ordinal_hint = self._detect_ordinal_pattern(unique_levels)
1368
+
1369
+ stats = {
1370
+ "name": name,
1371
+ "type": "categorical",
1372
+ "n_levels": len(unique_levels),
1373
+ "n_levels_shown": len(levels_data),
1374
+ "levels": levels_data,
1375
+ "modeling_hints": {
1376
+ "suggested_base_level": suggested_base,
1377
+ "ordinal": ordinal_hint["is_ordinal"],
1378
+ "ordinal_pattern": ordinal_hint["pattern"],
1379
+ "thin_levels": thin_levels if thin_levels else None,
1380
+ "thin_level_warning": f"Levels {thin_levels} have <1% exposure" if thin_levels else None,
1381
+ },
1382
+ }
1383
+ factors.append(stats)
1384
+
1385
+ return factors
1386
+
1387
+ def _compute_shape_hint(self, bin_rates: List[float]) -> Dict[str, str]:
1388
+ """Analyze binned response rates to suggest transformation."""
1389
+ n = len(bin_rates)
1390
+ if n < 3:
1391
+ return {"shape": "insufficient_data", "recommendation": "linear"}
1392
+
1393
+ # Check monotonicity
1394
+ diffs = [bin_rates[i+1] - bin_rates[i] for i in range(n-1)]
1395
+ increasing = sum(1 for d in diffs if d > 0)
1396
+ decreasing = sum(1 for d in diffs if d < 0)
1397
+
1398
+ # Strong monotonic pattern
1399
+ if increasing >= n - 2:
1400
+ return {"shape": "monotonic_increasing", "recommendation": "linear or log"}
1401
+ if decreasing >= n - 2:
1402
+ return {"shape": "monotonic_decreasing", "recommendation": "linear or log"}
1403
+
1404
+ # Check for U-shape or inverted U
1405
+ mid = n // 2
1406
+ left_trend = sum(diffs[:mid])
1407
+ right_trend = sum(diffs[mid:])
1408
+
1409
+ if left_trend < 0 and right_trend > 0:
1410
+ return {"shape": "u_shaped", "recommendation": "spline or polynomial"}
1411
+ if left_trend > 0 and right_trend < 0:
1412
+ return {"shape": "inverted_u", "recommendation": "spline or polynomial"}
1413
+
1414
+ # Check for step function (large jump)
1415
+ max_diff = max(abs(d) for d in diffs)
1416
+ avg_rate = sum(bin_rates) / n
1417
+ if max_diff > avg_rate * 0.5:
1418
+ return {"shape": "step_function", "recommendation": "banding or categorical"}
1419
+
1420
+ # Non-linear but no clear pattern
1421
+ variance = sum((r - avg_rate)**2 for r in bin_rates) / n
1422
+ if variance > avg_rate * 0.1:
1423
+ return {"shape": "non_linear", "recommendation": "spline"}
1424
+
1425
+ return {"shape": "flat", "recommendation": "may not need in model"}
1426
+
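For illustration, the heuristic above classifies a few invented rate patterns as follows (the method does not use self, so a dummy instance suffices):

    import numpy as np

    explorer = DataExplorer(y=np.zeros(10))   # dummy instance for illustration
    print(explorer._compute_shape_hint([0.05, 0.07, 0.09, 0.12, 0.15])["shape"])  # monotonic_increasing
    print(explorer._compute_shape_hint([0.20, 0.12, 0.08, 0.11, 0.19])["shape"])  # u_shaped
    print(explorer._compute_shape_hint([0.10, 0.10, 0.30, 0.31, 0.30])["shape"])  # step_function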
1427
+ def _detect_ordinal_pattern(self, levels: np.ndarray) -> Dict[str, Any]:
1428
+ """Detect if categorical levels follow an ordinal pattern."""
1429
+ levels_str = [str(l) for l in levels]
1430
+
1431
+ # Check for numeric levels
1432
+ try:
1433
+ numeric = [float(l) for l in levels_str]
1434
+ return {"is_ordinal": True, "pattern": "numeric"}
1435
+ except ValueError:
1436
+ pass
1437
+
1438
+ # Check for single letter A-Z pattern
1439
+ if all(len(l) == 1 and l.isalpha() for l in levels_str):
1440
+ return {"is_ordinal": True, "pattern": "alphabetic"}
1441
+
1442
+ # Check for common ordinal patterns
1443
+ ordinal_patterns = [
1444
+ (["low", "medium", "high"], "low_medium_high"),
1445
+ (["small", "medium", "large"], "size"),
1446
+ (["young", "middle", "old"], "age"),
1447
+ (["1", "2", "3", "4", "5"], "numeric_string"),
1448
+ ]
1449
+
1450
+ levels_lower = [l.lower() for l in levels_str]
1451
+ for pattern, name in ordinal_patterns:
1452
+ if all(p in levels_lower for p in pattern):
1453
+ return {"is_ordinal": True, "pattern": name}
1454
+
1455
+ # Check for prefix + number pattern (e.g., "Region1", "Region2")
1456
+ import re
1457
+ if all(re.match(r'^[A-Za-z]+\d+$', l) for l in levels_str):
1458
+ return {"is_ordinal": True, "pattern": "prefix_numeric"}
1459
+
1460
+ return {"is_ordinal": False, "pattern": None}
1461
+
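For reference, a few invented level sets and the pattern the detector above should assign:

    # ["1", "2", "3"]          -> {"is_ordinal": True,  "pattern": "numeric"}
    # ["A", "B", "C"]          -> {"is_ordinal": True,  "pattern": "alphabetic"}
    # ["Region1", "Region2"]   -> {"is_ordinal": True,  "pattern": "prefix_numeric"}
    # ["red", "blue", "green"] -> {"is_ordinal": False, "pattern": None}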
1462
+ def compute_univariate_tests(
1463
+ self,
1464
+ data: "pl.DataFrame",
1465
+ categorical_factors: List[str],
1466
+ continuous_factors: List[str],
1467
+ ) -> List[Dict[str, Any]]:
1468
+ """
1469
+ Compute univariate significance tests for each factor vs response.
1470
+
1471
+ For continuous factors: Pearson correlation + F-test from simple regression
1472
+ For categorical factors: ANOVA F-test (eta-squared based)
1473
+ """
1474
+ results = []
1475
+ y_rate = self.y / self.exposure
1476
+
1477
+ for name in continuous_factors:
1478
+ if name not in data.columns:
1479
+ continue
1480
+
1481
+ values = data[name].to_numpy().astype(np.float64)
1482
+ valid_mask = ~np.isnan(values) & ~np.isinf(values)
1483
+
1484
+ if np.sum(valid_mask) < 10:
1485
+ continue
1486
+
1487
+ x_valid = values[valid_mask]
1488
+ y_valid = y_rate[valid_mask]
1489
+ w_valid = self.exposure[valid_mask]
1490
+
1491
+ # Weighted correlation
1492
+ x_mean = np.average(x_valid, weights=w_valid)
1493
+ y_mean = np.average(y_valid, weights=w_valid)
1494
+
1495
+ cov_xy = np.sum(w_valid * (x_valid - x_mean) * (y_valid - y_mean)) / np.sum(w_valid)
1496
+ std_x = np.sqrt(np.sum(w_valid * (x_valid - x_mean) ** 2) / np.sum(w_valid))
1497
+ std_y = np.sqrt(np.sum(w_valid * (y_valid - y_mean) ** 2) / np.sum(w_valid))
1498
+
1499
+ corr = cov_xy / (std_x * std_y) if std_x > 0 and std_y > 0 else 0.0
1500
+
1501
+ # F-test from regression
1502
+ n = len(x_valid)
1503
+ r2 = corr ** 2
1504
+ f_stat = (r2 / 1) / ((1 - r2) / (n - 2)) if r2 < 1 and n > 2 else 0
1505
+
1506
+ try:
1507
+ from scipy.stats import f
1508
+ pvalue = 1 - f.cdf(f_stat, 1, n - 2) if n > 2 else 1.0
1509
+ except ImportError:
1510
+ pvalue = float('nan')
1511
+
1512
+ results.append({
1513
+ "factor": name,
1514
+ "type": "continuous",
1515
+ "test": "correlation_f_test",
1516
+ "correlation": float(corr),
1517
+ "r_squared": float(r2),
1518
+ "f_statistic": float(f_stat),
1519
+ "pvalue": float(pvalue),
1520
+ "significant_01": pvalue < 0.01 if not np.isnan(pvalue) else False,
1521
+ "significant_05": pvalue < 0.05 if not np.isnan(pvalue) else False,
1522
+ })
1523
+
1524
+ for name in categorical_factors:
1525
+ if name not in data.columns:
1526
+ continue
1527
+
1528
+ values = data[name].to_numpy().astype(str)
1529
+
1530
+ # ANOVA: eta-squared and F-test
1531
+ eta_sq = self._compute_eta_squared_response(values)
1532
+
1533
+ unique_levels = np.unique(values)
1534
+ k = len(unique_levels)
1535
+ n = len(values)
1536
+
1537
+ if k > 1 and n > k:
1538
+ f_stat = (eta_sq / (k - 1)) / ((1 - eta_sq) / (n - k)) if eta_sq < 1 else 0
1539
+
1540
+ try:
1541
+ from scipy.stats import f
1542
+ pvalue = 1 - f.cdf(f_stat, k - 1, n - k)
1543
+ except ImportError:
1544
+ pvalue = float('nan')
1545
+ else:
1546
+ f_stat = 0.0
1547
+ pvalue = 1.0
1548
+
1549
+ results.append({
1550
+ "factor": name,
1551
+ "type": "categorical",
1552
+ "test": "anova_f_test",
1553
+ "n_levels": k,
1554
+ "eta_squared": float(eta_sq),
1555
+ "f_statistic": float(f_stat),
1556
+ "pvalue": float(pvalue),
1557
+ "significant_01": pvalue < 0.01 if not np.isnan(pvalue) else False,
1558
+ "significant_05": pvalue < 0.05 if not np.isnan(pvalue) else False,
1559
+ })
1560
+
1561
+ # Sort by p-value (most significant first)
1562
+ results.sort(key=lambda x: x["pvalue"] if not np.isnan(x["pvalue"]) else 1.0)
1563
+ return results
1564
+
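To make the continuous-factor test above concrete with invented numbers: a weighted correlation of r = 0.30 on n = 1000 rows gives a large F statistic and a p-value of essentially zero.

    r, n = 0.30, 1000
    r2 = r ** 2                               # 0.09
    f_stat = (r2 / 1) / ((1 - r2) / (n - 2))  # 0.09 / (0.91 / 998) ≈ 98.7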
1565
+ def compute_correlations(
1566
+ self,
1567
+ data: "pl.DataFrame",
1568
+ continuous_factors: List[str],
1569
+ ) -> Dict[str, Any]:
1570
+ """
1571
+ Compute pairwise correlations between continuous factors.
1572
+
1573
+ Returns correlation matrix and flags for high correlations.
1574
+ """
1575
+ valid_factors = [f for f in continuous_factors if f in data.columns]
1576
+
1577
+ if len(valid_factors) < 2:
1578
+ return {"factors": valid_factors, "matrix": [], "high_correlations": []}
1579
+
1580
+ # Build matrix of valid values
1581
+ arrays = []
1582
+ for name in valid_factors:
1583
+ arr = data[name].to_numpy().astype(np.float64)
1584
+ arrays.append(arr)
1585
+
1586
+ X = np.column_stack(arrays)
1587
+
1588
+ # Handle missing values - use pairwise complete observations
1589
+ n_factors = len(valid_factors)
1590
+ corr_matrix = np.eye(n_factors)
1591
+
1592
+ for i in range(n_factors):
1593
+ for j in range(i + 1, n_factors):
1594
+ xi, xj = X[:, i], X[:, j]
1595
+ valid = ~np.isnan(xi) & ~np.isnan(xj) & ~np.isinf(xi) & ~np.isinf(xj)
1596
+
1597
+ if np.sum(valid) > 2:
1598
+ corr = np.corrcoef(xi[valid], xj[valid])[0, 1]
1599
+ corr_matrix[i, j] = corr
1600
+ corr_matrix[j, i] = corr
1601
+ else:
1602
+ corr_matrix[i, j] = float('nan')
1603
+ corr_matrix[j, i] = float('nan')
1604
+
1605
+ # Find high correlations (|r| > 0.7)
1606
+ high_corrs = []
1607
+ for i in range(n_factors):
1608
+ for j in range(i + 1, n_factors):
1609
+ r = corr_matrix[i, j]
1610
+ if not np.isnan(r) and abs(r) > 0.7:
1611
+ high_corrs.append({
1612
+ "factor1": valid_factors[i],
1613
+ "factor2": valid_factors[j],
1614
+ "correlation": float(r),
1615
+ "severity": "high" if abs(r) > 0.9 else "moderate",
1616
+ })
1617
+
1618
+ high_corrs.sort(key=lambda x: -abs(x["correlation"]))
1619
+
1620
+ return {
1621
+ "factors": valid_factors,
1622
+ "matrix": corr_matrix.tolist(),
1623
+ "high_correlations": high_corrs,
1624
+ }
1625
+
1626
+ def compute_vif(
1627
+ self,
1628
+ data: "pl.DataFrame",
1629
+ continuous_factors: List[str],
1630
+ ) -> List[Dict[str, Any]]:
1631
+ """
1632
+ Compute Variance Inflation Factors for multicollinearity detection.
1633
+
1634
+ VIF > 5 indicates moderate multicollinearity
1635
+ VIF > 10 indicates severe multicollinearity
1636
+ """
1637
+ valid_factors = [f for f in continuous_factors if f in data.columns]
1638
+
1639
+ if len(valid_factors) < 2:
1640
+ return [{"factor": f, "vif": 1.0, "severity": "none"} for f in valid_factors]
1641
+
1642
+ # Build design matrix
1643
+ arrays = []
1644
+ for name in valid_factors:
1645
+ arr = data[name].to_numpy().astype(np.float64)
1646
+ arrays.append(arr)
1647
+
1648
+ X = np.column_stack(arrays)
1649
+
1650
+ # Remove rows with any NaN/Inf
1651
+ valid_rows = np.all(~np.isnan(X) & ~np.isinf(X), axis=1)
1652
+ X = X[valid_rows]
1653
+
1654
+ if len(X) < len(valid_factors) + 1:
1655
+ return [{"factor": f, "vif": float('nan'), "severity": "unknown"} for f in valid_factors]
1656
+
1657
+ # Standardize
1658
+ X = (X - np.mean(X, axis=0)) / (np.std(X, axis=0) + 1e-10)
1659
+
1660
+ results = []
1661
+ for i, name in enumerate(valid_factors):
1662
+ # Regress factor i on all others
1663
+ y = X[:, i]
1664
+ others = np.delete(X, i, axis=1)
1665
+
1666
+ # Add intercept
1667
+ others_with_int = np.column_stack([np.ones(len(others)), others])
1668
+
1669
+ try:
1670
+ # OLS: beta = (X'X)^-1 X'y
1671
+ beta = np.linalg.lstsq(others_with_int, y, rcond=None)[0]
1672
+ y_pred = others_with_int @ beta
1673
+
1674
+ ss_res = np.sum((y - y_pred) ** 2)
1675
+ ss_tot = np.sum((y - np.mean(y)) ** 2)
1676
+
1677
+ r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
1678
+ vif = 1 / (1 - r2) if r2 < 1 else float('inf')
1679
+ except (np.linalg.LinAlgError, ValueError):
1680
+ vif = float('nan')
1681
+
1682
+ if np.isnan(vif) or np.isinf(vif):
1683
+ severity = "unknown"
1684
+ elif vif > 10:
1685
+ severity = "severe"
1686
+ elif vif > 5:
1687
+ severity = "moderate"
1688
+ else:
1689
+ severity = "none"
1690
+
1691
+ results.append({
1692
+ "factor": name,
1693
+ "vif": float(vif) if not np.isinf(vif) else 999.0,
1694
+ "severity": severity,
1695
+ })
1696
+
1697
+ results.sort(key=lambda x: -x["vif"] if not np.isnan(x["vif"]) else 0)
1698
+ return results
1699
+
1700
+ def compute_missing_values(
1701
+ self,
1702
+ data: "pl.DataFrame",
1703
+ categorical_factors: List[str],
1704
+ continuous_factors: List[str],
1705
+ ) -> Dict[str, Any]:
1706
+ """
1707
+ Analyze missing values across all factors.
1708
+ """
1709
+ all_factors = categorical_factors + continuous_factors
1710
+ factor_missing = []
1711
+ total_rows = len(data)
1712
+
1713
+ for name in all_factors:
1714
+ if name not in data.columns:
1715
+ continue
1716
+
1717
+ col = data[name]
1718
+ n_missing = col.null_count()
1719
+ pct_missing = 100.0 * n_missing / total_rows if total_rows > 0 else 0
1720
+
1721
+ factor_missing.append({
1722
+ "factor": name,
1723
+ "n_missing": int(n_missing),
1724
+ "pct_missing": float(pct_missing),
1725
+ "severity": "high" if pct_missing > 10 else ("moderate" if pct_missing > 1 else "none"),
1726
+ })
1727
+
1728
+ factor_missing.sort(key=lambda x: -x["pct_missing"])
1729
+
1730
+ # Count rows that are complete across all analyzed factors
1731
+ present_cols = [name for name in all_factors if name in data.columns]
1732
+ n_complete_rows = (
1733
+ data.select(present_cols).drop_nulls().height if present_cols else total_rows
1734
+ )
1735
+
1736
+ return {
1737
+ "total_rows": total_rows,
1738
+ "factors_with_missing": [f for f in factor_missing if f["n_missing"] > 0],
1739
+ "n_complete_rows": total_rows - sum(f["n_missing"] for f in factor_missing),
1740
+ "summary": "No missing values" if all(f["n_missing"] == 0 for f in factor_missing) else "Missing values present",
1741
+ }
1742
+
1743
+ def compute_zero_inflation(self) -> Dict[str, Any]:
1744
+ """
1745
+ Check for zero inflation in count data.
1746
+
1747
+ Compares observed zeros to expected zeros under Poisson assumption.
1748
+ """
1749
+ y = self.y
1750
+ n = len(y)
1751
+
1752
+ observed_zeros = int(np.sum(y == 0))
1753
+ observed_zero_pct = 100.0 * observed_zeros / n if n > 0 else 0
1754
+
1755
+ # Expected zeros under Poisson: P(Y=0) = exp(-lambda) where lambda = mean
1756
+ mean_y = np.mean(y)
1757
+ if mean_y > 0:
1758
+ expected_zero_pct = 100.0 * np.exp(-mean_y)
1759
+ excess_zeros = observed_zero_pct - expected_zero_pct
1760
+ else:
1761
+ expected_zero_pct = 100.0
1762
+ excess_zeros = 0.0
1763
+
1764
+ # Severity assessment
1765
+ if excess_zeros > 20:
1766
+ severity = "severe"
1767
+ recommendation = "Consider zero-inflated model (ZIP, ZINB)"
1768
+ elif excess_zeros > 10:
1769
+ severity = "moderate"
1770
+ recommendation = "Consider zero-inflated or hurdle model"
1771
+ elif excess_zeros > 5:
1772
+ severity = "mild"
1773
+ recommendation = "Monitor; may need zero-inflated model"
1774
+ else:
1775
+ severity = "none"
1776
+ recommendation = "Standard Poisson/NegBin likely adequate"
1777
+
1778
+ return {
1779
+ "observed_zeros": observed_zeros,
1780
+ "observed_zero_pct": float(observed_zero_pct),
1781
+ "expected_zero_pct_poisson": float(expected_zero_pct),
1782
+ "excess_zero_pct": float(excess_zeros),
1783
+ "severity": severity,
1784
+ "recommendation": recommendation,
1785
+ }
1786
+
1787
+ def compute_overdispersion(self) -> Dict[str, Any]:
1788
+ """
1789
+ Check for overdispersion in count data.
1790
+
1791
+ Compares variance to mean (Poisson assumes Var = Mean).
1792
+ """
1793
+ y = self.y
1794
+ exposure = self.exposure
1795
+
1796
+ # Compute rate
1797
+ rate = y / exposure
1798
+
1799
+ # Weighted mean and variance
1800
+ total_exp = np.sum(exposure)
1801
+ mean_rate = np.sum(y) / total_exp
1802
+
1803
+ # Variance of rates (exposure-weighted)
1804
+ var_rate = np.sum(exposure * (rate - mean_rate) ** 2) / total_exp
1805
+
1806
+ # For Poisson with exposure, expected variance of rate is mean_rate / exposure
1807
+ # Aggregate expected variance
1808
+ expected_var = mean_rate * len(y) / total_exp  # equals mean_rate when mean exposure is 1
1809
+
1810
+ # Dispersion ratio
1811
+ if expected_var > 0:
1812
+ dispersion_ratio = var_rate / expected_var
1813
+ else:
1814
+ dispersion_ratio = 1.0
1815
+
1816
+ # Also compute using counts directly
1817
+ mean_count = np.mean(y)
1818
+ var_count = np.var(y, ddof=1)
1819
+ count_dispersion = var_count / mean_count if mean_count > 0 else 1.0
1820
+
1821
+ # Severity assessment
1822
+ if count_dispersion > 5:
1823
+ severity = "severe"
1824
+ recommendation = "Use Negative Binomial or QuasiPoisson"
1825
+ elif count_dispersion > 2:
1826
+ severity = "moderate"
1827
+ recommendation = "Consider Negative Binomial or QuasiPoisson"
1828
+ elif count_dispersion > 1.5:
1829
+ severity = "mild"
1830
+ recommendation = "Monitor; Poisson may underestimate standard errors"
1831
+ else:
1832
+ severity = "none"
1833
+ recommendation = "Poisson assumption reasonable"
1834
+
1835
+ return {
1836
+ "mean_count": float(mean_count),
1837
+ "var_count": float(var_count),
1838
+ "dispersion_ratio": float(count_dispersion),
1839
+ "severity": severity,
1840
+ "recommendation": recommendation,
1841
+ }
1842
+
1843
+ def compute_cramers_v(
1844
+ self,
1845
+ data: "pl.DataFrame",
1846
+ categorical_factors: List[str],
1847
+ ) -> Dict[str, Any]:
1848
+ """
1849
+ Compute Cramér's V matrix for categorical factor pairs.
1850
+
1851
+ Cramér's V measures association between categorical variables (0 to 1).
1852
+ """
1853
+ valid_factors = [f for f in categorical_factors if f in data.columns]
1854
+
1855
+ if len(valid_factors) < 2:
1856
+ return {"factors": valid_factors, "matrix": [], "high_associations": []}
1857
+
1858
+ n_factors = len(valid_factors)
1859
+ v_matrix = np.eye(n_factors)
1860
+
1861
+ for i in range(n_factors):
1862
+ for j in range(i + 1, n_factors):
1863
+ v = self._compute_cramers_v_pair(
1864
+ data[valid_factors[i]].to_numpy(),
1865
+ data[valid_factors[j]].to_numpy(),
1866
+ )
1867
+ v_matrix[i, j] = v
1868
+ v_matrix[j, i] = v
1869
+
1870
+ # Find high associations (V > 0.3)
1871
+ high_assoc = []
1872
+ for i in range(n_factors):
1873
+ for j in range(i + 1, n_factors):
1874
+ v = v_matrix[i, j]
1875
+ if not np.isnan(v) and v > 0.3:
1876
+ high_assoc.append({
1877
+ "factor1": valid_factors[i],
1878
+ "factor2": valid_factors[j],
1879
+ "cramers_v": float(v),
1880
+ "severity": "high" if v > 0.5 else "moderate",
1881
+ })
1882
+
1883
+ high_assoc.sort(key=lambda x: -x["cramers_v"])
1884
+
1885
+ return {
1886
+ "factors": valid_factors,
1887
+ "matrix": v_matrix.tolist(),
1888
+ "high_associations": high_assoc,
1889
+ }
1890
+
1891
+ def _compute_cramers_v_pair(self, x: np.ndarray, y: np.ndarray) -> float:
1892
+ """Compute Cramér's V for a pair of categorical variables."""
1893
+ # Build contingency table
1894
+ x_str = x.astype(str)
1895
+ y_str = y.astype(str)
1896
+
1897
+ x_cats = np.unique(x_str)
1898
+ y_cats = np.unique(y_str)
1899
+
1900
+ r, k = len(x_cats), len(y_cats)
1901
+ if r < 2 or k < 2:
1902
+ return 0.0
1903
+
1904
+ # Count frequencies
1905
+ contingency = np.zeros((r, k))
1906
+ for i, xc in enumerate(x_cats):
1907
+ for j, yc in enumerate(y_cats):
1908
+ contingency[i, j] = np.sum((x_str == xc) & (y_str == yc))
1909
+
1910
+ n = contingency.sum()
1911
+ if n == 0:
1912
+ return 0.0
1913
+
1914
+ # Chi-squared statistic
1915
+ row_sums = contingency.sum(axis=1, keepdims=True)
1916
+ col_sums = contingency.sum(axis=0, keepdims=True)
1917
+ expected = row_sums * col_sums / n
1918
+
1919
+ # Avoid division by zero
1920
+ with np.errstate(divide='ignore', invalid='ignore'):
1921
+ chi2 = np.sum((contingency - expected) ** 2 / expected)
1922
+ chi2 = np.nan_to_num(chi2, nan=0.0, posinf=0.0, neginf=0.0)
1923
+
1924
+ # Cramér's V
1925
+ min_dim = min(r - 1, k - 1)
1926
+ if min_dim == 0 or n == 0:
1927
+ return 0.0
1928
+
1929
+ v = np.sqrt(chi2 / (n * min_dim))
1930
+ return float(v)
1931
+
1932
+ def detect_interactions(
1933
+ self,
1934
+ data: "pl.DataFrame",
1935
+ factor_names: List[str],
1936
+ max_factors: int = 10,
1937
+ min_effect_size: float = 0.01,
1938
+ max_candidates: int = 5,
1939
+ min_cell_count: int = 30,
1940
+ ) -> List[InteractionCandidate]:
1941
+ """
1942
+ Detect potential interactions using response-based analysis.
1943
+
1944
+ This identifies factors whose combined effect on the response
1945
+ differs from their individual effects, suggesting an interaction.
1946
+ """
1947
+ # First, rank factors by their effect on response variance
1948
+ factor_scores = []
1949
+
1950
+ for name in factor_names:
1951
+ if name not in data.columns:
1952
+ continue
1953
+
1954
+ values = data[name].to_numpy()
1955
+
1956
+ # Compute eta-squared (variance explained)
1957
+ if values.dtype == object or str(values.dtype).startswith('str'):
1958
+ score = self._compute_eta_squared_response(values.astype(str))
1959
+ else:
1960
+ values = values.astype(np.float64)
1961
+ valid_mask = ~np.isnan(values) & ~np.isinf(values)
1962
+ if np.sum(valid_mask) < 10:
1963
+ continue
1964
+ # Bin continuous variables
1965
+ bins = self._discretize(values, 5)
1966
+ score = self._compute_eta_squared_response(bins.astype(str))
1967
+
1968
+ if score >= min_effect_size:
1969
+ factor_scores.append((name, score))
1970
+
1971
+ # Sort and take top factors
1972
+ factor_scores.sort(key=lambda x: -x[1])
1973
+ top_factors = [name for name, _ in factor_scores[:max_factors]]
1974
+
1975
+ if len(top_factors) < 2:
1976
+ return []
1977
+
1978
+ # Check pairwise interactions
1979
+ candidates = []
1980
+
1981
+ for i in range(len(top_factors)):
1982
+ for j in range(i + 1, len(top_factors)):
1983
+ name1, name2 = top_factors[i], top_factors[j]
1984
+
1985
+ values1 = data[name1].to_numpy()
1986
+ values2 = data[name2].to_numpy()
1987
+
1988
+ # Discretize both factors
1989
+ bins1 = self._discretize(values1, 5)
1990
+ bins2 = self._discretize(values2, 5)
1991
+
1992
+ # Compute interaction strength
1993
+ candidate = self._compute_interaction_strength_response(
1994
+ name1, bins1, name2, bins2, min_cell_count
1995
+ )
1996
+
1997
+ if candidate is not None:
1998
+ candidates.append(candidate)
1999
+
2000
+ # Sort by strength and return top candidates
2001
+ candidates.sort(key=lambda x: -x.interaction_strength)
2002
+ return candidates[:max_candidates]
2003
+
2004
+ def _compute_eta_squared_response(self, categories: np.ndarray) -> float:
2005
+ """Compute eta-squared for categorical association with response."""
2006
+ y_rate = self.y / self.exposure
2007
+ unique_levels = np.unique(categories)
2008
+ overall_mean = np.average(y_rate, weights=self.exposure)
2009
+
2010
+ ss_total = np.sum(self.exposure * (y_rate - overall_mean) ** 2)
2011
+
2012
+ if ss_total == 0:
2013
+ return 0.0
2014
+
2015
+ ss_between = 0.0
2016
+ for level in unique_levels:
2017
+ mask = categories == level
2018
+ level_rate = y_rate[mask]
2019
+ level_exp = self.exposure[mask]
2020
+ level_mean = np.average(level_rate, weights=level_exp)
2021
+ ss_between += np.sum(level_exp) * (level_mean - overall_mean) ** 2
2022
+
2023
+ return ss_between / ss_total
2024
+
2025
+ def _discretize(self, values: np.ndarray, n_bins: int) -> np.ndarray:
2026
+ """Discretize values into bins."""
2027
+ if values.dtype == object or str(values.dtype).startswith('str'):
2028
+ unique_vals = np.unique(values)
2029
+ mapping = {v: i for i, v in enumerate(unique_vals)}
2030
+ return np.array([mapping[v] for v in values])
2031
+ else:
2032
+ values = values.astype(np.float64)
2033
+ valid_mask = ~np.isnan(values) & ~np.isinf(values)
2034
+
2035
+ if not np.any(valid_mask):
2036
+ return np.zeros(len(values), dtype=int)
2037
+
2038
+ quantiles = np.percentile(values[valid_mask], np.linspace(0, 100, n_bins + 1))
2039
+ bins = np.digitize(values, quantiles[1:-1])
2040
+ bins[~valid_mask] = n_bins
2041
+ return bins
2042
+
2043
+ def _compute_interaction_strength_response(
2044
+ self,
2045
+ name1: str,
2046
+ bins1: np.ndarray,
2047
+ name2: str,
2048
+ bins2: np.ndarray,
2049
+ min_cell_count: int,
2050
+ ) -> Optional[InteractionCandidate]:
2051
+ """Compute interaction strength based on response variance."""
2052
+ y_rate = self.y / self.exposure
2053
+
2054
+ # Create interaction cells
2055
+ cell_ids = bins1 * 1000 + bins2
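+ # Each (bin1, bin2) pair is packed into a single integer key; this assumes
+ # fewer than 1000 bins per factor, which holds for the 5-bin discretization
+ # used by detect_interactions.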
2056
+ unique_cells = np.unique(cell_ids)
2057
+
2058
+ # Filter cells with sufficient data
2059
+ valid_cells = []
2060
+ cell_rates = []
2061
+ cell_weights = []
2062
+
2063
+ for cell_id in unique_cells:
2064
+ mask = cell_ids == cell_id
2065
+ if np.sum(mask) >= min_cell_count:
2066
+ valid_cells.append(cell_id)
2067
+ cell_rates.append(y_rate[mask])
2068
+ cell_weights.append(self.exposure[mask])
2069
+
2070
+ if len(valid_cells) < 4:
2071
+ return None
2072
+
2073
+ # Compute variance explained by cells
2074
+ all_rates = np.concatenate(cell_rates)
2075
+ all_weights = np.concatenate(cell_weights)
2076
+ overall_mean = np.average(all_rates, weights=all_weights)
2077
+
2078
+ ss_total = np.sum(all_weights * (all_rates - overall_mean) ** 2)
2079
+
2080
+ if ss_total == 0:
2081
+ return None
2082
+
2083
+ ss_model = sum(
2084
+ np.sum(w) * (np.average(r, weights=w) - overall_mean) ** 2
2085
+ for r, w in zip(cell_rates, cell_weights)
2086
+ )
2087
+
2088
+ r_squared = ss_model / ss_total
2089
+
2090
+ # F-test p-value
2091
+ df_model = len(valid_cells) - 1
2092
+ df_resid = len(all_rates) - len(valid_cells)
2093
+
2094
+ if df_model > 0 and df_resid > 0:
2095
+ f_stat = (ss_model / df_model) / ((ss_total - ss_model) / df_resid)
2096
+
2097
+ try:
2098
+ from scipy.stats import f
2099
+ pvalue = 1 - f.cdf(f_stat, df_model, df_resid)
2100
+ except ImportError:
2101
+ pvalue = float('nan')
2102
+ else:
2103
+ pvalue = float('nan')
2104
+
2105
+ return InteractionCandidate(
2106
+ factor1=name1,
2107
+ factor2=name2,
2108
+ interaction_strength=float(r_squared),
2109
+ pvalue=float(pvalue),
2110
+ n_cells=len(valid_cells),
2111
+ )
2112
+
2113
+
2114
+ def explore_data(
2115
+ data: "pl.DataFrame",
2116
+ response: str,
2117
+ categorical_factors: Optional[List[str]] = None,
2118
+ continuous_factors: Optional[List[str]] = None,
2119
+ exposure: Optional[str] = None,
2120
+ family: str = "poisson",
2121
+ n_bins: int = 10,
2122
+ rare_threshold_pct: float = 1.0,
2123
+ max_categorical_levels: int = 20,
2124
+ detect_interactions: bool = True,
2125
+ max_interaction_factors: int = 10,
2126
+ ) -> DataExploration:
2127
+ """
2128
+ Explore data before model fitting.
2129
+
2130
+ This function provides pre-fit analysis including factor statistics
2131
+ and interaction detection without requiring a fitted model.
2132
+
2133
+ Results are automatically saved to 'analysis/exploration.json'.
2134
+
2135
+ Parameters
2136
+ ----------
2137
+ data : pl.DataFrame
2138
+ Data to explore.
2139
+ response : str
2140
+ Name of the response variable column.
2141
+ categorical_factors : list of str, optional
2142
+ Names of categorical factors to analyze.
2143
+ continuous_factors : list of str, optional
2144
+ Names of continuous factors to analyze.
2145
+ exposure : str, optional
2146
+ Name of the exposure/weights column.
2147
+ family : str, default="poisson"
2148
+ Expected family (for appropriate statistics).
2149
+ n_bins : int, default=10
2150
+ Number of bins for continuous factors.
2151
+ rare_threshold_pct : float, default=1.0
2152
+ Threshold (%) below which categorical levels are grouped.
2153
+ max_categorical_levels : int, default=20
2154
+ Maximum categorical levels to show.
2155
+ detect_interactions : bool, default=True
2156
+ Whether to detect potential interactions.
2157
+ max_interaction_factors : int, default=10
2158
+ Maximum factors for interaction detection.
2159
+
2160
+ Returns
2161
+ -------
2162
+ DataExploration
2163
+ Pre-fit exploration results with to_json() method.
2164
+
2165
+ Examples
2166
+ --------
2167
+ >>> import rustystats as rs
2168
+ >>>
2169
+ >>> # Explore data before fitting
2170
+ >>> exploration = rs.explore_data(
2171
+ ... data=data,
2172
+ ... response="ClaimNb",
2173
+ ... categorical_factors=["Region", "VehBrand"],
2174
+ ... continuous_factors=["Age", "VehPower"],
2175
+ ... exposure="Exposure",
2176
+ ... family="poisson",
2177
+ ... )
2178
+ >>>
2179
+ >>> # View interaction candidates
2180
+ >>> for ic in exploration.interaction_candidates:
2181
+ ... print(f"{ic.factor1} x {ic.factor2}: {ic.interaction_strength:.3f}")
2182
+ >>>
2183
+ >>> # Export as JSON
2184
+ >>> print(exploration.to_json())
2185
+ """
2186
+ categorical_factors = list(dict.fromkeys(categorical_factors or [])) # Dedupe preserving order
2187
+ continuous_factors = list(dict.fromkeys(continuous_factors or [])) # Dedupe preserving order
2188
+
2189
+ # Extract response and exposure
2190
+ y = data[response].to_numpy().astype(np.float64)
2191
+ exp = data[exposure].to_numpy().astype(np.float64) if exposure else None
2192
+
2193
+ # Create explorer
2194
+ explorer = DataExplorer(y=y, exposure=exp, family=family)
2195
+
2196
+ # Compute statistics
2197
+ response_stats = explorer.compute_response_stats()
2198
+
2199
+ factor_stats = explorer.compute_factor_stats(
2200
+ data=data,
2201
+ categorical_factors=categorical_factors,
2202
+ continuous_factors=continuous_factors,
2203
+ n_bins=n_bins,
2204
+ rare_threshold_pct=rare_threshold_pct,
2205
+ max_categorical_levels=max_categorical_levels,
2206
+ )
2207
+
2208
+ # Univariate significance tests
2209
+ univariate_tests = explorer.compute_univariate_tests(
2210
+ data=data,
2211
+ categorical_factors=categorical_factors,
2212
+ continuous_factors=continuous_factors,
2213
+ )
2214
+
2215
+ # Correlations between continuous factors
2216
+ correlations = explorer.compute_correlations(
2217
+ data=data,
2218
+ continuous_factors=continuous_factors,
2219
+ )
2220
+
2221
+ # VIF for multicollinearity
2222
+ vif = explorer.compute_vif(
2223
+ data=data,
2224
+ continuous_factors=continuous_factors,
2225
+ )
2226
+
2227
+ # Missing value analysis
2228
+ missing_values = explorer.compute_missing_values(
2229
+ data=data,
2230
+ categorical_factors=categorical_factors,
2231
+ continuous_factors=continuous_factors,
2232
+ )
2233
+
2234
+ # Cramér's V for categorical pairs
2235
+ cramers_v = explorer.compute_cramers_v(
2236
+ data=data,
2237
+ categorical_factors=categorical_factors,
2238
+ )
2239
+
2240
+ # Zero inflation check (for count data)
2241
+ zero_inflation = explorer.compute_zero_inflation()
2242
+
2243
+ # Overdispersion check
2244
+ overdispersion = explorer.compute_overdispersion()
2245
+
2246
+ # Interaction detection
2247
+ interaction_candidates = []
2248
+ if detect_interactions and len(categorical_factors) + len(continuous_factors) >= 2:
2249
+ all_factors = categorical_factors + continuous_factors
2250
+ interaction_candidates = explorer.detect_interactions(
2251
+ data=data,
2252
+ factor_names=all_factors,
2253
+ max_factors=max_interaction_factors,
2254
+ min_effect_size=0.001, # Lower threshold to catch more interactions
2255
+ )
2256
+
2257
+ # Data summary
2258
+ data_summary = {
2259
+ "n_rows": len(data),
2260
+ "n_columns": len(data.columns),
2261
+ "response_column": response,
2262
+ "exposure_column": exposure,
2263
+ "n_categorical_factors": len(categorical_factors),
2264
+ "n_continuous_factors": len(continuous_factors),
2265
+ }
2266
+
2267
+ result = DataExploration(
2268
+ data_summary=data_summary,
2269
+ factor_stats=factor_stats,
2270
+ missing_values=missing_values,
2271
+ univariate_tests=univariate_tests,
2272
+ correlations=correlations,
2273
+ cramers_v=cramers_v,
2274
+ vif=vif,
2275
+ zero_inflation=zero_inflation,
2276
+ overdispersion=overdispersion,
2277
+ interaction_candidates=interaction_candidates,
2278
+ response_stats=response_stats,
2279
+ )
2280
+
2281
+ # Auto-save JSON to analysis folder
2282
+ import os
2283
+ os.makedirs("analysis", exist_ok=True)
2284
+ with open("analysis/exploration.json", "w") as f:
2285
+ f.write(result.to_json(indent=2))
2286
+
2287
+ return result
2288
+
2289
+
2290
+ # =============================================================================
2291
+ # Post-Fit Model Diagnostics
2292
+ # =============================================================================
2293
+
2294
+ def compute_diagnostics(
2295
+ result, # GLMResults or FormulaGLMResults
2296
+ data: "pl.DataFrame",
2297
+ categorical_factors: Optional[List[str]] = None,
2298
+ continuous_factors: Optional[List[str]] = None,
2299
+ n_calibration_bins: int = 10,
2300
+ n_factor_bins: int = 10,
2301
+ rare_threshold_pct: float = 1.0,
2302
+ max_categorical_levels: int = 20,
2303
+ detect_interactions: bool = True,
2304
+ max_interaction_factors: int = 10,
2305
+ ) -> ModelDiagnostics:
2306
+ """
2307
+ Compute comprehensive model diagnostics.
2308
+
2309
+ Results are automatically saved to 'analysis/diagnostics.json'.
2310
+
2311
+ Parameters
2312
+ ----------
2313
+ result : GLMResults or FormulaGLMResults
2314
+ Fitted model results.
2315
+ data : pl.DataFrame
2316
+ Original data used for fitting.
2317
+ categorical_factors : list of str, optional
2318
+ Names of categorical factors to analyze.
2319
+ continuous_factors : list of str, optional
2320
+ Names of continuous factors to analyze.
2321
+ n_calibration_bins : int, default=10
2322
+ Number of bins for calibration curve.
2323
+ n_factor_bins : int, default=10
2324
+ Number of quantile bins for continuous factors.
2325
+ rare_threshold_pct : float, default=1.0
2326
+ Threshold (%) below which categorical levels are grouped into "Other".
2327
+ max_categorical_levels : int, default=20
2328
+ Maximum number of categorical levels to show (rest grouped to "Other").
2329
+ detect_interactions : bool, default=True
2330
+ Whether to detect potential interactions.
2331
+ max_interaction_factors : int, default=10
2332
+ Maximum number of factors to consider for interaction detection.
2333
+
2334
+ Returns
2335
+ -------
2336
+ ModelDiagnostics
2337
+ Complete diagnostics object with to_json() method.
2338
+ """
2339
+ # Deduplicate factors while preserving order
2340
+ categorical_factors = list(dict.fromkeys(categorical_factors or []))
2341
+ continuous_factors = list(dict.fromkeys(continuous_factors or []))
2342
+ # Remove any overlap (a factor can't be both categorical and continuous)
2343
+ continuous_factors = [f for f in continuous_factors if f not in categorical_factors]
2344
+
2345
+ # Extract what we need from result
2346
+ # Get y from the residuals (y = mu + response_residuals)
2347
+ mu = np.asarray(result.fittedvalues, dtype=np.float64)
2348
+ response_resid = np.asarray(result.resid_response(), dtype=np.float64)
2349
+ y = mu + response_resid
2350
+
2351
+ lp = np.asarray(result.linear_predictor, dtype=np.float64)
2352
+ family = result.family if hasattr(result, 'family') else "unknown"
2353
+ n_params = len(result.params)
2354
+ deviance = result.deviance
2355
+ feature_names = result.feature_names if hasattr(result, 'feature_names') else []
2356
+
2357
+ # Exposure/weights are not currently extracted from the data; None is passed through
2358
+ exposure = None
2359
+
2360
+ # Extract family parameters
2361
+ var_power = 1.5
2362
+ theta = 1.0
2363
+ if "tweedie" in family.lower():
2364
+ # Try to parse var_power from family string
2365
+ import re
2366
+ match = re.search(r'p=(\d+\.?\d*)', family)
2367
+ if match:
2368
+ var_power = float(match.group(1))
2369
+ if "negbinomial" in family.lower() or "negativebinomial" in family.lower():
2370
+ import re
2371
+ match = re.search(r'theta=(\d+\.?\d*)', family)
2372
+ if match:
2373
+ theta = float(match.group(1))
2374
+
2375
+ # Create computer
2376
+ computer = DiagnosticsComputer(
2377
+ y=y,
2378
+ mu=mu,
2379
+ linear_predictor=lp,
2380
+ family=family,
2381
+ n_params=n_params,
2382
+ deviance=deviance,
2383
+ exposure=exposure,
2384
+ feature_names=feature_names,
2385
+ var_power=var_power,
2386
+ theta=theta,
2387
+ )
2388
+
2389
+ # Compute all diagnostics
2390
+ fit_stats = computer.compute_fit_statistics()
2391
+ loss_metrics = computer.compute_loss_metrics()
2392
+ calibration = computer.compute_calibration(n_calibration_bins)
2393
+ discrimination = computer.compute_discrimination()
2394
+ residual_summary = computer.compute_residual_summary()
2395
+
2396
+ factors = computer.compute_factor_diagnostics(
2397
+ data=data,
2398
+ categorical_factors=categorical_factors,
2399
+ continuous_factors=continuous_factors,
2400
+ result=result, # Pass result for significance tests
2401
+ n_bins=n_factor_bins,
2402
+ rare_threshold_pct=rare_threshold_pct,
2403
+ max_categorical_levels=max_categorical_levels,
2404
+ )
2405
+
2406
+ # Interaction detection
2407
+ interaction_candidates = []
2408
+ if detect_interactions and len(categorical_factors) + len(continuous_factors) >= 2:
2409
+ all_factors = categorical_factors + continuous_factors
2410
+ interaction_candidates = computer.detect_interactions(
2411
+ data=data,
2412
+ factor_names=all_factors,
2413
+ max_factors=max_interaction_factors,
2414
+ )
2415
+
2416
+ model_comparison = computer.compute_model_comparison()
2417
+ warnings = computer.generate_warnings(fit_stats, calibration, factors)
2418
+
2419
+ # Extract convergence info
2420
+ converged = result.converged if hasattr(result, 'converged') else True
2421
+ iterations = result.iterations if hasattr(result, 'iterations') else 0
2422
+ max_iter = 25  # Assumed default; the result object does not expose the limit
2423
+
2424
+ # Determine convergence reason
2425
+ if converged:
2426
+ reason = "converged"
2427
+ elif iterations >= max_iter:
2428
+ reason = "max_iterations_reached"
2429
+ else:
2430
+ reason = "unknown"
2431
+
2432
+ convergence_details = ConvergenceDetails(
2433
+ max_iterations_allowed=max_iter,
2434
+ iterations_used=iterations,
2435
+ converged=converged,
2436
+ reason=reason,
2437
+ )
2438
+
2439
+ # Model summary
2440
+ model_summary = {
2441
+ "formula": result.formula if hasattr(result, 'formula') else None,
2442
+ "family": family,
2443
+ "link": result.link if hasattr(result, 'link') else "unknown",
2444
+ "n_observations": computer.n_obs,
2445
+ "n_parameters": n_params,
2446
+ "degrees_of_freedom_residual": computer.df_resid,
2447
+ "converged": converged,
2448
+ "iterations": iterations,
2449
+ }
2450
+
2451
+ diagnostics = ModelDiagnostics(
2452
+ model_summary=model_summary,
2453
+ convergence_details=convergence_details,
2454
+ fit_statistics=fit_stats,
2455
+ loss_metrics=loss_metrics,
2456
+ calibration=calibration,
2457
+ discrimination=discrimination,
2458
+ residual_summary=residual_summary,
2459
+ factors=factors,
2460
+ interaction_candidates=interaction_candidates,
2461
+ model_comparison=model_comparison,
2462
+ warnings=warnings,
2463
+ )
2464
+
2465
+ # Auto-save JSON to analysis folder
2466
+ import os
2467
+ os.makedirs("analysis", exist_ok=True)
2468
+ with open("analysis/diagnostics.json", "w") as f:
2469
+ f.write(diagnostics.to_json(indent=2))
2470
+
2471
+ return diagnostics