ins-pricing 0.3.0__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/PKG-INFO +162 -162
  2. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/Explain_entry.py +50 -48
  3. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/bayesopt_entry_runner.py +73 -70
  4. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +4 -3
  5. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +4 -3
  6. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +7 -3
  7. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/setup.py +1 -1
  8. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing.egg-info/PKG-INFO +162 -162
  9. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/pyproject.toml +1 -1
  10. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/setup.cfg +4 -4
  11. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/MANIFEST.in +0 -0
  12. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/README.md +0 -0
  13. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/CHANGELOG.md +0 -0
  14. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/README.md +0 -0
  15. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/RELEASE_NOTES_0.2.8.md +0 -0
  16. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/__init__.py +0 -0
  17. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/BayesOpt_entry.py +0 -0
  18. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/BayesOpt_incremental.py +0 -0
  19. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/Explain_Run.py +0 -0
  20. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/Pricing_Run.py +0 -0
  21. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/__init__.py +0 -0
  22. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/utils/__init__.py +0 -0
  23. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/utils/cli_common.py +0 -0
  24. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/utils/cli_config.py +0 -0
  25. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/utils/evaluation_context.py +0 -0
  26. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/utils/import_resolver.py +0 -0
  27. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/utils/notebook_utils.py +0 -0
  28. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/utils/run_logging.py +0 -0
  29. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/cli/watchdog_run.py +0 -0
  30. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/docs/modelling/BayesOpt_USAGE.md +0 -0
  31. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/docs/modelling/README.md +0 -0
  32. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/exceptions.py +0 -0
  33. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/governance/README.md +0 -0
  34. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/governance/__init__.py +0 -0
  35. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/governance/approval.py +0 -0
  36. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/governance/audit.py +0 -0
  37. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/governance/registry.py +0 -0
  38. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/governance/release.py +0 -0
  39. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/__init__.py +0 -0
  40. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/BayesOpt.py +0 -0
  41. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/__init__.py +0 -0
  42. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +0 -0
  43. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +0 -0
  44. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +0 -0
  45. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/__init__.py +0 -0
  46. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/config_components.py +0 -0
  47. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/config_preprocess.py +0 -0
  48. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/core.py +0 -0
  49. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +0 -0
  50. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +0 -0
  51. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/models/__init__.py +0 -0
  52. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +0 -0
  53. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/models/model_gnn.py +0 -0
  54. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/models/model_resn.py +0 -0
  55. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -0
  56. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +0 -0
  57. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +0 -0
  58. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +0 -0
  59. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +0 -0
  60. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +0 -0
  61. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -0
  62. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -0
  63. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +0 -0
  64. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -0
  65. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -0
  66. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/utils.py +0 -0
  67. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -0
  68. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/core/evaluation.py +0 -0
  69. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/explain/__init__.py +0 -0
  70. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/explain/gradients.py +0 -0
  71. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/explain/metrics.py +0 -0
  72. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/explain/permutation.py +0 -0
  73. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/explain/shap_utils.py +0 -0
  74. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/plotting/__init__.py +0 -0
  75. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/plotting/common.py +0 -0
  76. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/plotting/curves.py +0 -0
  77. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/plotting/diagnostics.py +0 -0
  78. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/plotting/geo.py +0 -0
  79. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/modelling/plotting/importance.py +0 -0
  80. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/pricing/README.md +0 -0
  81. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/pricing/__init__.py +0 -0
  82. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/pricing/calibration.py +0 -0
  83. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/pricing/data_quality.py +0 -0
  84. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/pricing/exposure.py +0 -0
  85. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/pricing/factors.py +0 -0
  86. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/pricing/monitoring.py +0 -0
  87. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/pricing/rate_table.py +0 -0
  88. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/production/__init__.py +0 -0
  89. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/production/drift.py +0 -0
  90. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/production/monitoring.py +0 -0
  91. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/production/predict.py +0 -0
  92. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/production/preprocess.py +0 -0
  93. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/production/scoring.py +0 -0
  94. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/reporting/README.md +0 -0
  95. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/reporting/__init__.py +0 -0
  96. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/reporting/report_builder.py +0 -0
  97. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/reporting/scheduler.py +0 -0
  98. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/governance/__init__.py +0 -0
  99. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/governance/test_audit.py +0 -0
  100. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/governance/test_registry.py +0 -0
  101. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/governance/test_release.py +0 -0
  102. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/modelling/conftest.py +0 -0
  103. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/modelling/test_cross_val_generic.py +0 -0
  104. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/modelling/test_distributed_utils.py +0 -0
  105. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/modelling/test_explain.py +0 -0
  106. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/modelling/test_geo_tokens_split.py +0 -0
  107. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/modelling/test_graph_cache.py +0 -0
  108. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/modelling/test_plotting.py +0 -0
  109. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/modelling/test_plotting_library.py +0 -0
  110. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/modelling/test_preprocessor.py +0 -0
  111. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/pricing/__init__.py +0 -0
  112. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/pricing/test_calibration.py +0 -0
  113. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/pricing/test_exposure.py +0 -0
  114. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/pricing/test_factors.py +0 -0
  115. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/pricing/test_rate_table.py +0 -0
  116. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/production/__init__.py +0 -0
  117. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/production/test_monitoring.py +0 -0
  118. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/production/test_predict.py +0 -0
  119. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/production/test_preprocess.py +0 -0
  120. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/tests/production/test_scoring.py +0 -0
  121. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/utils/__init__.py +0 -0
  122. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/utils/device.py +0 -0
  123. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/utils/logging.py +0 -0
  124. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/utils/metrics.py +0 -0
  125. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/utils/paths.py +0 -0
  126. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/utils/profiling.py +0 -0
  127. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing/utils/validation.py +0 -0
  128. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing.egg-info/SOURCES.txt +0 -0
  129. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing.egg-info/dependency_links.txt +0 -0
  130. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing.egg-info/requires.txt +0 -0
  131. {ins_pricing-0.3.0 → ins_pricing-0.3.1}/ins_pricing.egg-info/top_level.txt +0 -0
@@ -1,162 +1,162 @@
  Metadata-Version: 2.4
  Name: ins_pricing
- Version: 0.3.0
+ Version: 0.3.1
  Summary: Reusable modelling, pricing, governance, and reporting utilities.
  Author: meishi125478
  License: Proprietary
  Keywords: pricing,insurance,bayesopt,ml
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3 :: Only
  Classifier: Programming Language :: Python :: 3.9
  Classifier: License :: Other/Proprietary License
  Classifier: Operating System :: OS Independent
  Classifier: Intended Audience :: Developers
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  Requires-Dist: numpy>=1.20
  Requires-Dist: pandas>=1.4
  Provides-Extra: bayesopt
  Requires-Dist: torch>=1.13; extra == "bayesopt"
  Requires-Dist: optuna>=3.0; extra == "bayesopt"
  Requires-Dist: xgboost>=1.6; extra == "bayesopt"
  Requires-Dist: scikit-learn>=1.1; extra == "bayesopt"
  Requires-Dist: statsmodels>=0.13; extra == "bayesopt"
  Requires-Dist: joblib>=1.2; extra == "bayesopt"
  Requires-Dist: matplotlib>=3.5; extra == "bayesopt"
  Provides-Extra: plotting
  Requires-Dist: matplotlib>=3.5; extra == "plotting"
  Requires-Dist: scikit-learn>=1.1; extra == "plotting"
  Provides-Extra: explain
  Requires-Dist: torch>=1.13; extra == "explain"
  Requires-Dist: shap>=0.41; extra == "explain"
  Requires-Dist: scikit-learn>=1.1; extra == "explain"
  Provides-Extra: geo
  Requires-Dist: contextily>=1.3; extra == "geo"
  Requires-Dist: matplotlib>=3.5; extra == "geo"
  Provides-Extra: gnn
  Requires-Dist: torch>=1.13; extra == "gnn"
  Requires-Dist: pynndescent>=0.5; extra == "gnn"
  Requires-Dist: torch-geometric>=2.3; extra == "gnn"
  Provides-Extra: all
  Requires-Dist: torch>=1.13; extra == "all"
  Requires-Dist: optuna>=3.0; extra == "all"
  Requires-Dist: xgboost>=1.6; extra == "all"
  Requires-Dist: scikit-learn>=1.1; extra == "all"
  Requires-Dist: statsmodels>=0.13; extra == "all"
  Requires-Dist: joblib>=1.2; extra == "all"
  Requires-Dist: matplotlib>=3.5; extra == "all"
  Requires-Dist: shap>=0.41; extra == "all"
  Requires-Dist: contextily>=1.3; extra == "all"
  Requires-Dist: pynndescent>=0.5; extra == "all"
  Requires-Dist: torch-geometric>=2.3; extra == "all"

  # Insurance-Pricing

  A reusable toolkit for insurance modeling, pricing, governance, and reporting.

  ## Overview

  Insurance-Pricing (ins_pricing) is an enterprise-grade Python library designed for machine learning model training, pricing calculations, and model governance workflows in the insurance industry.

  ### Core Modules

  | Module | Description |
  |--------|-------------|
  | **modelling** | ML model training (GLM, XGBoost, ResNet, FT-Transformer, GNN) and model interpretability (SHAP, permutation importance) |
  | **pricing** | Factor table construction, numeric binning, premium calibration, exposure calculation, PSI monitoring |
  | **production** | Model prediction, batch scoring, data drift detection, production metrics monitoring |
  | **governance** | Model registry, version management, approval workflows, audit logging |
  | **reporting** | Report generation (Markdown format), report scheduling |
  | **utils** | Data validation, performance profiling, device management, logging configuration |

  ### Quick Start

  ```python
  # Model training with Bayesian optimization
  from ins_pricing import bayesopt as ropt

  model = ropt.BayesOptModel(
      train_data, test_data,
      model_name='my_model',
      resp_nme='target',
      weight_nme='weight',
      factor_nmes=feature_list,
      cate_list=categorical_features,
  )
  model.bayesopt_xgb(max_evals=100)     # Train XGBoost
  model.bayesopt_resnet(max_evals=50)   # Train ResNet
  model.bayesopt_ft(max_evals=50)       # Train FT-Transformer

  # Pricing: build factor table
  from ins_pricing.pricing import build_factor_table
  factors = build_factor_table(
      df,
      factor_col='age_band',
      loss_col='claim_amount',
      exposure_col='exposure',
  )

  # Production: batch scoring
  from ins_pricing.production import batch_score
  scores = batch_score(model.trainers['xgb'].predict, df)

  # Model governance
  from ins_pricing.governance import ModelRegistry
  registry = ModelRegistry('models.json')
  registry.register(model_name, version, metrics=metrics)
  ```

  ### Project Structure

  ```
  ins_pricing/
  ├── cli/                 # Command-line entry points
  ├── modelling/
  │   ├── core/bayesopt/   # ML model training core
  │   ├── explain/         # Model interpretability
  │   └── plotting/        # Model visualization
  ├── pricing/             # Insurance pricing module
  ├── production/          # Production deployment module
  ├── governance/          # Model governance
  ├── reporting/           # Report generation
  ├── utils/               # Utilities
  └── tests/               # Test suite
  ```

  ### Installation

  ```bash
  # Basic installation
  pip install ins_pricing

  # Full installation (all optional dependencies)
  pip install ins_pricing[all]

  # Install specific extras
  pip install ins_pricing[bayesopt]   # Model training
  pip install ins_pricing[explain]    # Model explanation
  pip install ins_pricing[plotting]   # Visualization
  pip install ins_pricing[gnn]        # Graph neural networks
  ```

  #### Multi-platform & GPU installation notes

  - **PyTorch (CPU/GPU/MPS)**: Install the correct PyTorch build for your platform/GPU first (CUDA on
    Linux/Windows, ROCm on supported AMD platforms, or MPS on Apple Silicon). Then install the
    optional extras you need (e.g., `bayesopt`, `explain`, or `gnn`). This avoids pip pulling a
    mismatched wheel.
  - **Torch Geometric (GNN)**: `torch-geometric` often requires platform-specific wheels (e.g.,
    `torch-scatter`, `torch-sparse`). Follow the official PyG installation instructions for your
    CUDA/ROCm/CPU environment, then install `ins_pricing[gnn]`.
  - **Multi-GPU**: Training code will use CUDA when available and can enable multi-GPU via
    `torch.distributed`/`DataParallel` where supported. On Windows, CUDA DDP is not supported and will
    fall back to single-GPU or DataParallel where possible.

  ### Requirements

  - Python >= 3.9
  - Core dependencies: numpy >= 1.20, pandas >= 1.4

  ### License

  Proprietary
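
The GPU notes in the README above ask users to install a platform-matched PyTorch build before pulling the `bayesopt`/`explain`/`gnn` extras. A minimal pre-flight sketch for checking that setup; `describe_accelerator` is an illustrative helper and not part of ins_pricing.

```python
# Sanity-check the pre-installed torch build before installing ins_pricing extras.
import torch


def describe_accelerator() -> str:
    """Report which accelerator this torch build can actually see."""
    if torch.cuda.is_available():
        return f"cuda ({torch.cuda.get_device_name(0)})"
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps (Apple Silicon)"
    return "cpu"


if __name__ == "__main__":
    print("torch", torch.__version__, "->", describe_accelerator())
```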
@@ -491,54 +491,56 @@ def explain_from_config(args: argparse.Namespace) -> None:
      categorical_features = cfg.get("categorical_features")
      plot_path_style = runtime_cfg["plot_path_style"]

-     model = ropt.BayesOptModel(
-         train_df,
-         test_df,
-         model_name,
-         cfg["target"],
-         cfg["weight"],
-         feature_list,
-         task_type=str(cfg.get("task_type", "regression")),
-         binary_resp_nme=binary_target,
-         cate_list=categorical_features,
-         prop_test=prop_test,
-         rand_seed=rand_seed,
-         epochs=int(runtime_cfg["epochs"]),
-         use_gpu=bool(cfg.get("use_gpu", True)),
-         output_dir=output_dir,
-         xgb_max_depth_max=runtime_cfg["xgb_max_depth_max"],
-         xgb_n_estimators_max=runtime_cfg["xgb_n_estimators_max"],
-         resn_weight_decay=cfg.get("resn_weight_decay"),
-         final_ensemble=bool(cfg.get("final_ensemble", False)),
-         final_ensemble_k=int(cfg.get("final_ensemble_k", 3)),
-         final_refit=bool(cfg.get("final_refit", True)),
-         optuna_storage=runtime_cfg["optuna_storage"],
-         optuna_study_prefix=runtime_cfg["optuna_study_prefix"],
-         best_params_files=runtime_cfg["best_params_files"],
-         gnn_use_approx_knn=cfg.get("gnn_use_approx_knn", True),
-         gnn_approx_knn_threshold=cfg.get("gnn_approx_knn_threshold", 50000),
-         gnn_graph_cache=cfg.get("gnn_graph_cache"),
-         gnn_max_gpu_knn_nodes=cfg.get("gnn_max_gpu_knn_nodes", 200000),
-         gnn_knn_gpu_mem_ratio=cfg.get("gnn_knn_gpu_mem_ratio", 0.9),
-         gnn_knn_gpu_mem_overhead=cfg.get("gnn_knn_gpu_mem_overhead", 2.0),
-         region_province_col=cfg.get("region_province_col"),
-         region_city_col=cfg.get("region_city_col"),
-         region_effect_alpha=cfg.get("region_effect_alpha"),
-         geo_feature_nmes=cfg.get("geo_feature_nmes"),
-         geo_token_hidden_dim=cfg.get("geo_token_hidden_dim"),
-         geo_token_layers=cfg.get("geo_token_layers"),
-         geo_token_dropout=cfg.get("geo_token_dropout"),
-         geo_token_k_neighbors=cfg.get("geo_token_k_neighbors"),
-         geo_token_learning_rate=cfg.get("geo_token_learning_rate"),
-         geo_token_epochs=cfg.get("geo_token_epochs"),
-         ft_role=str(cfg.get("ft_role", "model")),
-         ft_feature_prefix=str(cfg.get("ft_feature_prefix", "ft_emb")),
-         ft_num_numeric_tokens=cfg.get("ft_num_numeric_tokens"),
-         infer_categorical_max_unique=int(cfg.get("infer_categorical_max_unique", 50)),
-         infer_categorical_max_ratio=float(cfg.get("infer_categorical_max_ratio", 0.05)),
-         reuse_best_params=runtime_cfg["reuse_best_params"],
-         plot_path_style=plot_path_style,
-     )
+     config_fields = getattr(ropt.BayesOptConfig, "__dataclass_fields__", {})
+     allowed_config_keys = set(config_fields.keys())
+     config_payload = {k: v for k, v in cfg.items() if k in allowed_config_keys}
+     config_payload.update({
+         "model_nme": model_name,
+         "resp_nme": cfg["target"],
+         "weight_nme": cfg["weight"],
+         "factor_nmes": feature_list,
+         "task_type": str(cfg.get("task_type", "regression")),
+         "binary_resp_nme": binary_target,
+         "cate_list": categorical_features,
+         "prop_test": prop_test,
+         "rand_seed": rand_seed,
+         "epochs": int(runtime_cfg["epochs"]),
+         "use_gpu": bool(cfg.get("use_gpu", True)),
+         "output_dir": output_dir,
+         "xgb_max_depth_max": runtime_cfg["xgb_max_depth_max"],
+         "xgb_n_estimators_max": runtime_cfg["xgb_n_estimators_max"],
+         "resn_weight_decay": cfg.get("resn_weight_decay"),
+         "final_ensemble": bool(cfg.get("final_ensemble", False)),
+         "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
+         "final_refit": bool(cfg.get("final_refit", True)),
+         "optuna_storage": runtime_cfg["optuna_storage"],
+         "optuna_study_prefix": runtime_cfg["optuna_study_prefix"],
+         "best_params_files": runtime_cfg["best_params_files"],
+         "gnn_use_approx_knn": cfg.get("gnn_use_approx_knn", True),
+         "gnn_approx_knn_threshold": cfg.get("gnn_approx_knn_threshold", 50000),
+         "gnn_graph_cache": cfg.get("gnn_graph_cache"),
+         "gnn_max_gpu_knn_nodes": cfg.get("gnn_max_gpu_knn_nodes", 200000),
+         "gnn_knn_gpu_mem_ratio": cfg.get("gnn_knn_gpu_mem_ratio", 0.9),
+         "gnn_knn_gpu_mem_overhead": cfg.get("gnn_knn_gpu_mem_overhead", 2.0),
+         "region_province_col": cfg.get("region_province_col"),
+         "region_city_col": cfg.get("region_city_col"),
+         "region_effect_alpha": cfg.get("region_effect_alpha"),
+         "geo_feature_nmes": cfg.get("geo_feature_nmes"),
+         "geo_token_hidden_dim": cfg.get("geo_token_hidden_dim"),
+         "geo_token_layers": cfg.get("geo_token_layers"),
+         "geo_token_dropout": cfg.get("geo_token_dropout"),
+         "geo_token_k_neighbors": cfg.get("geo_token_k_neighbors"),
+         "geo_token_learning_rate": cfg.get("geo_token_learning_rate"),
+         "geo_token_epochs": cfg.get("geo_token_epochs"),
+         "ft_role": str(cfg.get("ft_role", "model")),
+         "ft_feature_prefix": str(cfg.get("ft_feature_prefix", "ft_emb")),
+         "ft_num_numeric_tokens": cfg.get("ft_num_numeric_tokens"),
+         "reuse_best_params": runtime_cfg["reuse_best_params"],
+         "plot_path_style": plot_path_style or "nested",
+     })
+     config_payload = {k: v for k, v in config_payload.items() if v is not None}
+     config = ropt.BayesOptConfig(**config_payload)
+     model = ropt.BayesOptModel(train_df, test_df, config=config)

      output_overrides = resolve_explain_output_overrides(
          explain_cfg,
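
The hunk above replaces the long keyword-argument call to `ropt.BayesOptModel` with a `BayesOptConfig` built from the config dict. Assuming the remaining `BayesOptConfig` fields all carry defaults, the new call site reduces to roughly the sketch below (toy data, illustrative only; field names are taken from the payload in the hunk).

```python
import pandas as pd
from ins_pricing import bayesopt as ropt

# Toy frames; a real run needs genuine training/test data.
train_df = pd.DataFrame({"target": [0.2, 0.4, 0.1],
                         "weight": [1.0, 1.0, 1.0],
                         "x1": [1, 2, 3]})
test_df = train_df.copy()

config = ropt.BayesOptConfig(
    model_nme="demo_model",
    resp_nme="target",
    weight_nme="weight",
    factor_nmes=["x1"],
)
model = ropt.BayesOptModel(train_df, test_df, config=config)
```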
@@ -1223,76 +1223,79 @@ def train_from_config(args: argparse.Namespace) -> None:
          cfg.get("ft_feature_prefix", args.ft_feature_prefix))
      ft_num_numeric_tokens = cfg.get("ft_num_numeric_tokens")

-     model = ropt.BayesOptModel(
-         train_df,
-         test_df,
-         model_name,
-         cfg["target"],
-         cfg["weight"],
-         feature_list,
-         task_type=task_type,
-         binary_resp_nme=binary_target,
-         cate_list=categorical_features,
-         prop_test=val_ratio,
-         rand_seed=rand_seed,
-         epochs=epochs,
-         use_gpu=use_gpu,
-         use_resn_data_parallel=use_resn_dp,
-         use_ft_data_parallel=use_ft_dp,
-         use_resn_ddp=use_resn_ddp,
-         use_ft_ddp=use_ft_ddp,
-         use_gnn_data_parallel=use_gnn_dp,
-         use_gnn_ddp=use_gnn_ddp,
-         output_dir=output_dir,
-         xgb_max_depth_max=xgb_max_depth_max,
-         xgb_n_estimators_max=xgb_n_estimators_max,
-         resn_weight_decay=cfg.get("resn_weight_decay"),
-         final_ensemble=bool(cfg.get("final_ensemble", False)),
-         final_ensemble_k=int(cfg.get("final_ensemble_k", 3)),
-         final_refit=bool(cfg.get("final_refit", True)),
-         optuna_storage=optuna_storage,
-         optuna_study_prefix=optuna_study_prefix,
-         best_params_files=best_params_files,
-         gnn_use_approx_knn=gnn_use_ann,
-         gnn_approx_knn_threshold=gnn_threshold,
-         gnn_graph_cache=gnn_graph_cache,
-         gnn_max_gpu_knn_nodes=gnn_max_gpu_nodes,
-         gnn_knn_gpu_mem_ratio=gnn_gpu_mem_ratio,
-         gnn_knn_gpu_mem_overhead=gnn_gpu_mem_overhead,
-         region_province_col=region_province_col,
-         region_city_col=region_city_col,
-         region_effect_alpha=region_effect_alpha,
-         geo_feature_nmes=geo_feature_nmes,
-         geo_token_hidden_dim=geo_token_hidden_dim,
-         geo_token_layers=geo_token_layers,
-         geo_token_dropout=geo_token_dropout,
-         geo_token_k_neighbors=geo_token_k_neighbors,
-         geo_token_learning_rate=geo_token_learning_rate,
-         geo_token_epochs=geo_token_epochs,
-         ft_role=ft_role,
-         ft_feature_prefix=ft_feature_prefix,
-         ft_num_numeric_tokens=ft_num_numeric_tokens,
-         infer_categorical_max_unique=int(
-             cfg.get("infer_categorical_max_unique", 50)),
-         infer_categorical_max_ratio=float(
-             cfg.get("infer_categorical_max_ratio", 0.05)),
-         reuse_best_params=reuse_best_params,
-         bo_sample_limit=bo_sample_limit,
-         cache_predictions=cache_predictions,
-         prediction_cache_dir=prediction_cache_dir,
-         prediction_cache_format=prediction_cache_format,
-         cv_strategy=cv_strategy or split_strategy,
-         cv_group_col=cv_group_col or split_group_col,
-         cv_time_col=cv_time_col or split_time_col,
-         cv_time_ascending=cv_time_ascending,
-         cv_splits=cv_splits,
-         ft_oof_folds=ft_oof_folds,
-         ft_oof_strategy=ft_oof_strategy,
-         ft_oof_shuffle=ft_oof_shuffle,
-         save_preprocess=save_preprocess,
-         preprocess_artifact_path=preprocess_artifact_path,
-         plot_path_style=plot_path_style,
-     )
+     config_fields = getattr(ropt.BayesOptConfig,
+                             "__dataclass_fields__", {})
+     allowed_config_keys = set(config_fields.keys())
+     config_payload = {k: v for k,
+                       v in cfg.items() if k in allowed_config_keys}
+     config_payload.update({
+         "model_nme": model_name,
+         "resp_nme": cfg["target"],
+         "weight_nme": cfg["weight"],
+         "factor_nmes": feature_list,
+         "task_type": task_type,
+         "binary_resp_nme": binary_target,
+         "cate_list": categorical_features,
+         "prop_test": val_ratio,
+         "rand_seed": rand_seed,
+         "epochs": epochs,
+         "use_gpu": use_gpu,
+         "use_resn_data_parallel": use_resn_dp,
+         "use_ft_data_parallel": use_ft_dp,
+         "use_gnn_data_parallel": use_gnn_dp,
+         "use_resn_ddp": use_resn_ddp,
+         "use_ft_ddp": use_ft_ddp,
+         "use_gnn_ddp": use_gnn_ddp,
+         "output_dir": output_dir,
+         "xgb_max_depth_max": xgb_max_depth_max,
+         "xgb_n_estimators_max": xgb_n_estimators_max,
+         "resn_weight_decay": cfg.get("resn_weight_decay"),
+         "final_ensemble": bool(cfg.get("final_ensemble", False)),
+         "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
+         "final_refit": bool(cfg.get("final_refit", True)),
+         "optuna_storage": optuna_storage,
+         "optuna_study_prefix": optuna_study_prefix,
+         "best_params_files": best_params_files,
+         "gnn_use_approx_knn": gnn_use_ann,
+         "gnn_approx_knn_threshold": gnn_threshold,
+         "gnn_graph_cache": gnn_graph_cache,
+         "gnn_max_gpu_knn_nodes": gnn_max_gpu_nodes,
+         "gnn_knn_gpu_mem_ratio": gnn_gpu_mem_ratio,
+         "gnn_knn_gpu_mem_overhead": gnn_gpu_mem_overhead,
+         "region_province_col": region_province_col,
+         "region_city_col": region_city_col,
+         "region_effect_alpha": region_effect_alpha,
+         "geo_feature_nmes": geo_feature_nmes,
+         "geo_token_hidden_dim": geo_token_hidden_dim,
+         "geo_token_layers": geo_token_layers,
+         "geo_token_dropout": geo_token_dropout,
+         "geo_token_k_neighbors": geo_token_k_neighbors,
+         "geo_token_learning_rate": geo_token_learning_rate,
+         "geo_token_epochs": geo_token_epochs,
+         "ft_role": ft_role,
+         "ft_feature_prefix": ft_feature_prefix,
+         "ft_num_numeric_tokens": ft_num_numeric_tokens,
+         "reuse_best_params": reuse_best_params,
+         "bo_sample_limit": bo_sample_limit,
+         "cache_predictions": cache_predictions,
+         "prediction_cache_dir": prediction_cache_dir,
+         "prediction_cache_format": prediction_cache_format,
+         "cv_strategy": cv_strategy or split_strategy,
+         "cv_group_col": cv_group_col or split_group_col,
+         "cv_time_col": cv_time_col or split_time_col,
+         "cv_time_ascending": cv_time_ascending,
+         "cv_splits": cv_splits,
+         "ft_oof_folds": ft_oof_folds,
+         "ft_oof_strategy": ft_oof_strategy,
+         "ft_oof_shuffle": ft_oof_shuffle,
+         "save_preprocess": save_preprocess,
+         "preprocess_artifact_path": preprocess_artifact_path,
+         "plot_path_style": plot_path_style or "nested",
+     })
+     config_payload = {
+         k: v for k, v in config_payload.items() if v is not None}
+     config = ropt.BayesOptConfig(**config_payload)
+     model = ropt.BayesOptModel(train_df, test_df, config=config)

      if plot_requested:
          plot_cfg = cfg.get("plot", {})
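
Same refactor as in `Explain_entry.py`, with two behavioural details visible in the hunk: keys that are not `BayesOptConfig` dataclass fields are dropped, and `None` values are filtered out so dataclass defaults take over (including the new `plot_path_style or "nested"` fallback). A self-contained sketch of that filtering pattern, using a stand-in dataclass rather than the real `BayesOptConfig`:

```python
from dataclasses import dataclass, fields


@dataclass
class DemoConfig:                      # stand-in, not the real BayesOptConfig
    model_nme: str
    resp_nme: str
    epochs: int = 10
    plot_path_style: str = "nested"


raw_cfg = {
    "model_nme": "demo",
    "resp_nme": "target",
    "epochs": None,        # "not set" in the JSON/YAML config
    "unknown_key": 123,    # ignored: not a dataclass field
}

# Keep only real dataclass fields, then drop None so defaults win.
allowed = {f.name for f in fields(DemoConfig)}
payload = {k: v for k, v in raw_cfg.items() if k in allowed and v is not None}
config = DemoConfig(**payload)         # epochs -> 10, plot_path_style -> "nested"
print(config)
```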
@@ -626,6 +626,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
          best_state = None
          patience_counter = 0
          is_ddp_model = isinstance(self.ft, DDP)
+         use_collectives = dist.is_initialized() and is_ddp_model

          clip_fn = None
          if self.device.type == 'cuda':
@@ -669,7 +670,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
                              device=X_num_b.device)
                  local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
                  global_bad = local_bad
-                 if dist.is_initialized():
+                 if use_collectives:
                      bad = torch.tensor(
                          [local_bad],
                          device=batch_loss.device,
@@ -774,7 +775,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
                      total_n += float(end - start)
                  val_loss_tensor[0] = total_val / max(total_n, 1.0)

-             if dist.is_initialized():
+             if use_collectives:
                  dist.broadcast(val_loss_tensor, src=0)
              val_loss_value = float(val_loss_tensor.item())
              prune_now = False
@@ -806,7 +807,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
                  if trial.should_prune():
                      prune_now = True

-             if dist.is_initialized():
+             if use_collectives:
                  flag = torch.tensor(
                      [1 if prune_now else 0],
                      device=loss_tensor_device,
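
The four hunks above swap the bare `dist.is_initialized()` checks for `use_collectives`, which additionally requires the model to be DDP-wrapped; presumably this stops non-DDP runs (single GPU or `DataParallel`) from issuing collectives just because a process group happens to exist. A generic sketch of the same guard pattern, with illustrative names:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def sync_flag(model: torch.nn.Module, local_flag: int, device: torch.device) -> int:
    """Agree on a 0/1 flag across ranks, but only when collectives are safe."""
    # Only enter collective ops when a process group exists AND this model is
    # actually DDP-wrapped, so a non-distributed trial never blocks on a collective.
    use_collectives = dist.is_initialized() and isinstance(model, DDP)
    if not use_collectives:
        return local_flag
    t = torch.tensor([local_flag], device=device, dtype=torch.int64)
    dist.all_reduce(t, op=dist.ReduceOp.MAX)  # any rank flagging => all ranks flag
    return int(t.item())
```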