ins-pricing 0.2.9__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/CHANGELOG.md +93 -0
- ins_pricing/README.md +11 -0
- ins_pricing/cli/Explain_entry.py +50 -48
- ins_pricing/cli/bayesopt_entry_runner.py +699 -569
- ins_pricing/cli/utils/evaluation_context.py +320 -0
- ins_pricing/cli/utils/import_resolver.py +350 -0
- ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +449 -0
- ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +406 -0
- ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +247 -0
- ins_pricing/modelling/core/bayesopt/config_components.py +351 -0
- ins_pricing/modelling/core/bayesopt/config_preprocess.py +3 -4
- ins_pricing/modelling/core/bayesopt/core.py +153 -94
- ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +122 -34
- ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +298 -142
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +86 -0
- ins_pricing/modelling/core/bayesopt/utils/constants.py +183 -0
- ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +186 -0
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +126 -0
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +540 -0
- ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +591 -0
- ins_pricing/modelling/core/bayesopt/utils.py +98 -1496
- ins_pricing/modelling/core/bayesopt/utils_backup.py +1503 -0
- ins_pricing/setup.py +1 -1
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/METADATA +14 -1
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/RECORD +27 -14
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/WHEEL +0 -0
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
"""Data classes for evaluation and reporting context.
|
|
2
|
+
|
|
3
|
+
These data classes group related parameters together to shorten function signatures
|
|
4
|
+
and improve code readability.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any, Dict, List, Optional
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class ModelIdentity:
    """Name, key, version and task type that together pin down one model."""

    model_name: str
    model_key: str
    version: str
    task_type: str = "regression"

    @property
    def full_name(self) -> str:
        """``model_name`` and ``model_key`` joined by a forward slash."""
        return "{}/{}".format(self.model_name, self.model_key)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class DataFingerprint:
    """Fingerprint information for data provenance tracking.

    Every field is stored as a string so the record can be serialized as-is.
    """

    path: str
    sha256_prefix: str = ""
    size: str = ""
    mtime: str = ""

    @classmethod
    def from_dict(cls, d: Optional[Dict[str, Any]]) -> "DataFingerprint":
        """Create a fingerprint from a mapping.

        Args:
            d: Mapping with optional ``path``, ``sha256_prefix``, ``size`` and
                ``mtime`` keys. ``None`` is treated like an empty mapping, which
                matches the other ``from_dict`` constructors in this module.

        Returns:
            A ``DataFingerprint``. Missing keys and ``None`` values become ``""``.
            Before, ``None`` values were stored as the misleading literal ``"None"``.
        """

        def _as_str(value: Any) -> str:
            return "" if value is None else str(value)

        d = d or {}
        return cls(
            path=_as_str(d.get("path")),
            sha256_prefix=_as_str(d.get("sha256_prefix")),
            size=_as_str(d.get("size")),
            mtime=_as_str(d.get("mtime")),
        )

    def to_dict(self) -> Dict[str, str]:
        """Convert to a plain dictionary with the same four keys."""
        return {
            "path": self.path,
            "sha256_prefix": self.sha256_prefix,
            "size": self.size,
            "mtime": self.mtime,
        }
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclass
class CalibrationConfig:
    """Configuration for prediction calibration."""

    enable: bool = False
    method: str = "sigmoid"
    max_rows: Optional[int] = None
    seed: Optional[int] = None

    @classmethod
    def from_dict(cls, d: Optional[Dict[str, Any]]) -> "CalibrationConfig":
        """Create from a config mapping.

        An empty or ``None`` mapping yields the defaults. Calibration is enabled
        when ``enable`` is truthy or when a ``method`` is given. A missing, null
        or empty ``method`` falls back to ``"sigmoid"`` rather than being stored
        as the string ``"None"``.
        """
        if not d:
            return cls()
        method = d.get("method")
        return cls(
            enable=bool(d.get("enable", False) or method),
            method=str(method) if method else "sigmoid",
            max_rows=d.get("max_rows"),
            seed=d.get("seed"),
        )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass
class ThresholdConfig:
    """Configuration for classification threshold selection."""

    enable: bool = False
    metric: str = "f1"
    value: Optional[float] = None
    min_positive_rate: Optional[float] = None
    grid: int = 99
    max_rows: Optional[int] = None
    seed: Optional[int] = None

    @classmethod
    def from_dict(cls, d: Optional[Dict[str, Any]]) -> "ThresholdConfig":
        """Create from a config mapping.

        An empty or ``None`` mapping yields the defaults. Selection is enabled
        when ``enable`` is truthy, a ``metric`` is given, or a fixed ``value`` is
        provided. Null ``metric``/``grid`` entries fall back to their defaults
        (``"f1"`` / ``99``) instead of producing ``"None"`` or a ``TypeError``.

        Raises:
            ValueError: if ``value`` or ``grid`` is present but not numeric.
        """
        if not d:
            return cls()
        metric = d.get("metric")
        value = d.get("value")
        grid = d.get("grid")
        return cls(
            enable=bool(d.get("enable", False) or metric or value is not None),
            metric=str(metric) if metric else "f1",
            value=float(value) if value is not None else None,
            min_positive_rate=d.get("min_positive_rate"),
            grid=int(grid) if grid is not None else 99,
            max_rows=d.get("max_rows"),
            seed=d.get("seed"),
        )
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass
class BootstrapConfig:
    """Configuration for bootstrap confidence intervals."""

    enable: bool = False
    metrics: Optional[List[str]] = None
    n_samples: int = 200
    ci: float = 0.95
    seed: Optional[int] = None

    @classmethod
    def from_dict(cls, d: Optional[Dict[str, Any]]) -> "BootstrapConfig":
        """Create from a config mapping.

        An empty or ``None`` mapping yields the defaults. Null ``n_samples``/``ci``
        entries fall back to ``200`` / ``0.95`` instead of raising ``TypeError``.

        ``metrics`` is normalised as follows:

        * a single string becomes a one-element list, so it is not iterated
          character by character;
        * any other iterable is copied into a new list, so the caller's config
          object is not shared.

        Raises:
            ValueError: if ``n_samples`` or ``ci`` is present but not numeric.
        """
        if not d:
            return cls()
        metrics = d.get("metrics")
        if isinstance(metrics, str):
            metrics = [metrics]
        elif metrics is not None:
            metrics = list(metrics)
        n_samples = d.get("n_samples")
        ci = d.get("ci")
        return cls(
            enable=bool(d.get("enable", False)),
            metrics=metrics,
            n_samples=int(n_samples) if n_samples is not None else 200,
            ci=float(ci) if ci is not None else 0.95,
            seed=d.get("seed"),
        )
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@dataclass
class ReportConfig:
    """Configuration for report generation."""

    output_dir: Optional[str] = None
    group_cols: Optional[List[str]] = None
    time_col: Optional[str] = None
    time_freq: str = "M"
    time_ascending: bool = True

    @classmethod
    def from_dict(cls, d: Optional[Dict[str, Any]]) -> "ReportConfig":
        """Create from a config mapping using the ``report_*`` keys.

        A ``None`` mapping yields the defaults. A null ``report_time_freq`` falls
        back to ``"M"``. A null ``report_time_ascending`` falls back to ``True``;
        previously ``bool(None)`` silently flipped it to ``False``.
        """
        d = d or {}
        time_freq = d.get("report_time_freq")
        ascending = d.get("report_time_ascending")
        return cls(
            output_dir=d.get("report_output_dir"),
            group_cols=d.get("report_group_cols"),
            time_col=d.get("report_time_col"),
            time_freq=str(time_freq) if time_freq else "M",
            time_ascending=True if ascending is None else bool(ascending),
        )
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@dataclass
class RegistryConfig:
    """Configuration for model registry."""

    register: bool = False
    path: Optional[str] = None
    tags: Dict[str, Any] = field(default_factory=dict)
    status: str = "candidate"

    @classmethod
    def from_dict(cls, d: Optional[Dict[str, Any]]) -> "RegistryConfig":
        """Create from a config mapping using the ``register_model``/``registry_*`` keys.

        A ``None`` mapping yields the defaults. Tags are copied into a new dict.
        A null or empty ``registry_status`` falls back to ``"candidate"`` instead
        of being stored as ``"None"``.
        """
        d = d or {}
        status = d.get("registry_status")
        return cls(
            register=bool(d.get("register_model", False)),
            path=d.get("registry_path"),
            tags=dict(d.get("registry_tags") or {}),
            status=str(status) if status else "candidate",
        )
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@dataclass
class MetricsResult:
    """Container for the outputs of a metrics computation.

    All fields default to empty/None, so a bare ``MetricsResult()`` is valid.
    """

    # Metric name -> scalar value.
    metrics: Dict[str, float] = field(default_factory=dict)
    # Threshold-selection details; None when no such details were recorded.
    threshold_info: Optional[Dict[str, Any]] = None
    # Calibration details; None when no such details were recorded.
    calibration_info: Optional[Dict[str, Any]] = None
    # Per-metric bootstrap output; the inner dict layout is defined by the producer.
    bootstrap_results: Dict[str, Dict[str, float]] = field(default_factory=dict)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
@dataclass
class EvaluationContext:
    """Complete context for model evaluation and reporting.

    This groups all the parameters needed for _evaluate_and_report into a single
    object, reducing the function signature from 19+ parameters to 1.
    """

    # Model identification
    identity: ModelIdentity

    # Data info
    data_path: Path
    data_fingerprint: DataFingerprint
    config_sha: str
    run_id: str

    # Prediction column
    pred_col: str

    # Configuration
    calibration: CalibrationConfig = field(default_factory=CalibrationConfig)
    threshold: ThresholdConfig = field(default_factory=ThresholdConfig)
    bootstrap: BootstrapConfig = field(default_factory=BootstrapConfig)
    report: ReportConfig = field(default_factory=ReportConfig)
    registry: RegistryConfig = field(default_factory=RegistryConfig)

    # Pre-computed reports
    psi_report_df: Optional[pd.DataFrame] = None

    # Full config dict (for artifact collection)
    cfg: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_params(
        cls,
        model_name: str,
        model_key: str,
        cfg: Dict[str, Any],
        data_path: Path,
        data_fingerprint: Dict[str, Any],
        run_id: str,
        config_sha: str,
        pred_col: str,
        calibration_cfg: Dict[str, Any],
        threshold_cfg: Dict[str, Any],
        bootstrap_cfg: Dict[str, Any],
        report_output_dir: Optional[str],
        report_group_cols: Optional[List[str]],
        report_time_col: Optional[str],
        report_time_freq: str,
        report_time_ascending: bool,
        register_model: bool,
        registry_path: Optional[str],
        registry_tags: Dict[str, Any],
        registry_status: str,
        psi_report_df: Optional[pd.DataFrame] = None,
    ) -> "EvaluationContext":
        """Create from individual parameters (for backward compatibility).

        Derived values:

        * ``version`` is ``f"{model_key}_{run_id}"``;
        * ``task_type`` is ``cfg["task_type"]``, defaulting to ``"regression"``
          when the key is missing or null.

        ``None`` is tolerated for several arguments, each falling back to the
        same default as the corresponding dataclass:

        * ``cfg`` and ``data_fingerprint``;
        * ``report_time_freq`` (``"M"``) and ``report_time_ascending`` (``True``);
        * ``registry_tags`` and ``registry_status`` (``"candidate"``).

        ``registry_tags`` is copied so the context does not alias the caller's
        dict. ``RegistryConfig.from_dict`` copies tags the same way.
        """
        if cfg is None:
            cfg = {}
        task_type = str(cfg.get("task_type") or "regression")
        version = f"{model_key}_{run_id}"

        return cls(
            identity=ModelIdentity(
                model_name=model_name,
                model_key=model_key,
                version=version,
                task_type=task_type,
            ),
            data_path=data_path,
            data_fingerprint=DataFingerprint.from_dict(data_fingerprint or {}),
            config_sha=config_sha,
            run_id=run_id,
            pred_col=pred_col,
            calibration=CalibrationConfig.from_dict(calibration_cfg),
            threshold=ThresholdConfig.from_dict(threshold_cfg),
            bootstrap=BootstrapConfig.from_dict(bootstrap_cfg),
            report=ReportConfig(
                output_dir=report_output_dir,
                group_cols=report_group_cols,
                time_col=report_time_col,
                time_freq=report_time_freq or "M",
                time_ascending=(
                    True if report_time_ascending is None else bool(report_time_ascending)
                ),
            ),
            registry=RegistryConfig(
                register=bool(register_model),
                path=registry_path,
                tags=dict(registry_tags or {}),
                status=registry_status or "candidate",
            ),
            psi_report_df=psi_report_df,
            cfg=cfg,
        )
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
@dataclass
class TrainingContext:
    """Process coordinates (world size / ranks) for training orchestration."""

    world_size: int = 1
    rank: int = 0
    local_rank: int = 0
    is_distributed: bool = False

    @property
    def is_main_process(self) -> bool:
        """True for rank 0 of a distributed run, and always True otherwise."""
        if self.is_distributed:
            return self.rank == 0
        return True

    @classmethod
    def from_env(cls) -> "TrainingContext":
        """Build a context from ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK``.

        Unset or non-integer variables fall back to 1, 0 and 0 respectively.
        The run counts as distributed when the world size is greater than one.
        """
        import os

        fallbacks = {"WORLD_SIZE": 1, "RANK": 0, "LOCAL_RANK": 0}
        parsed = {}
        for env_name, fallback in fallbacks.items():
            raw = os.environ.get(env_name, fallback)
            try:
                parsed[env_name] = int(raw)
            except (TypeError, ValueError):
                parsed[env_name] = fallback

        world = parsed["WORLD_SIZE"]
        return cls(
            world_size=world,
            rank=parsed["RANK"],
            local_rank=parsed["LOCAL_RANK"],
            is_distributed=world > 1,
        )
|
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
"""Unified import resolver for CLI modules.
|
|
2
|
+
|
|
3
|
+
This module provides a single source of truth for all import fallback chains,
|
|
4
|
+
eliminating the need for nested try/except blocks in multiple CLI files.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from ins_pricing.cli.utils.import_resolver import resolve_imports
|
|
8
|
+
imports = resolve_imports()
|
|
9
|
+
ropt = imports.bayesopt
|
|
10
|
+
PLOT_MODEL_LABELS = imports.PLOT_MODEL_LABELS
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import importlib
|
|
16
|
+
import sys
|
|
17
|
+
from dataclasses import dataclass, field
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class ResolvedImports:
    """Holder for every symbol located by ``resolve_imports``.

    Each field starts empty (None, ``{}`` or ``[]``) and is overwritten with
    whatever the resolver manages to find.
    """

    # --- bayesopt package (module object) ---
    bayesopt: Any = None

    # --- from cli_common ---
    PLOT_MODEL_LABELS: Dict[str, Tuple[str, str]] = field(default_factory=dict)
    PYTORCH_TRAINERS: List[str] = field(default_factory=list)
    build_model_names: Optional[Callable] = None
    dedupe_preserve_order: Optional[Callable] = None
    load_dataset: Optional[Callable] = None
    parse_model_pairs: Optional[Callable] = None
    resolve_data_path: Optional[Callable] = None
    resolve_path: Optional[Callable] = None
    fingerprint_file: Optional[Callable] = None
    coerce_dataset_types: Optional[Callable] = None
    split_train_test: Optional[Callable] = None

    # --- from cli_config ---
    add_config_json_arg: Optional[Callable] = None
    add_output_dir_arg: Optional[Callable] = None
    resolve_and_load_config: Optional[Callable] = None
    resolve_data_config: Optional[Callable] = None
    resolve_report_config: Optional[Callable] = None
    resolve_split_config: Optional[Callable] = None
    resolve_runtime_config: Optional[Callable] = None
    resolve_output_dirs: Optional[Callable] = None

    # --- from the evaluation module ---
    bootstrap_ci: Optional[Callable] = None
    calibrate_predictions: Optional[Callable] = None
    metrics_report: Optional[Callable] = None
    select_threshold: Optional[Callable] = None

    # --- governance registry, production drift/monitoring, report builder ---
    ModelArtifact: Optional[Type] = None
    ModelRegistry: Optional[Type] = None
    drift_psi_report: Optional[Callable] = None
    group_metrics: Optional[Callable] = None
    ReportPayload: Optional[Type] = None
    write_report: Optional[Callable] = None

    # --- from run_logging ---
    configure_run_logging: Optional[Callable] = None

    # --- from plotting diagnostics ---
    plot_loss_curve: Optional[Callable] = None
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _try_import(module_path: str, attr_name: Optional[str] = None) -> Optional[Any]:
|
|
74
|
+
"""Attempt to import a module or attribute, returning None on failure."""
|
|
75
|
+
try:
|
|
76
|
+
module = importlib.import_module(module_path)
|
|
77
|
+
if attr_name:
|
|
78
|
+
return getattr(module, attr_name, None)
|
|
79
|
+
return module
|
|
80
|
+
except Exception:
|
|
81
|
+
return None
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _try_import_from_paths(
    paths: List[str],
    attr_name: Optional[str] = None
) -> Optional[Any]:
    """Return the first non-None result of ``_try_import`` over *paths*.

    Paths are tried lazily, in order, each with the same *attr_name*; later
    paths are not imported once one succeeds. Returns ``None`` if none do.
    """
    attempts = (_try_import(candidate, attr_name) for candidate in paths)
    return next((hit for hit in attempts if hit is not None), None)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _resolve_bayesopt() -> Optional[Any]:
    """Find the bayesopt module, or return None if no candidate imports.

    Candidates, in priority order: the package-qualified
    ``ins_pricing.modelling.core.bayesopt``, then top-level ``bayesopt`` and
    ``BayesOpt``.
    """
    return _try_import_from_paths([
        "ins_pricing.modelling.core.bayesopt",
        "bayesopt",
        "BayesOpt",
    ])
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _resolve_cli_common() -> Dict[str, Any]:
    """Resolve CLI common utilities across the candidate module paths.

    Paths are tried in order. For each attribute the first non-None value wins.
    Later paths are consulted only to fill attributes that are still missing,
    and the search stops once every attribute has been resolved.

    Returns:
        Mapping of attribute name to resolved object. It is empty when no
        candidate module could be imported. An attribute maps to ``None`` when
        modules were found but none of them define it.
    """
    paths = [
        "ins_pricing.cli.utils.cli_common",
        "cli.utils.cli_common",
        "utils.cli_common",
    ]

    attrs = [
        "PLOT_MODEL_LABELS",
        "PYTORCH_TRAINERS",
        "build_model_names",
        "dedupe_preserve_order",
        "load_dataset",
        "parse_model_pairs",
        "resolve_data_path",
        "resolve_path",
        "fingerprint_file",
        "coerce_dataset_types",
        "split_train_test",
    ]

    results: Dict[str, Any] = {}
    for path in paths:
        module = _try_import(path)
        if module is None:
            continue
        for attr in attrs:
            if results.get(attr) is None:
                results[attr] = getattr(module, attr, None)
        # Stop only when nothing is missing. Breaking on a partial match would
        # leave the remaining helpers as None even when a later fallback path
        # could supply them.
        if all(results.get(attr) is not None for attr in attrs):
            break

    return results
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _resolve_cli_config() -> Dict[str, Any]:
    """Resolve CLI config utilities across the candidate module paths.

    Paths are tried in order. For each attribute the first non-None value wins.
    Later paths are consulted only to fill attributes that are still missing,
    and the search stops once every attribute has been resolved.

    Returns:
        Mapping of attribute name to resolved object. It is empty when no
        candidate module could be imported. An attribute maps to ``None`` when
        modules were found but none of them define it.
    """
    paths = [
        "ins_pricing.cli.utils.cli_config",
        "cli.utils.cli_config",
        "utils.cli_config",
    ]

    attrs = [
        "add_config_json_arg",
        "add_output_dir_arg",
        "resolve_and_load_config",
        "resolve_data_config",
        "resolve_report_config",
        "resolve_split_config",
        "resolve_runtime_config",
        "resolve_output_dirs",
    ]

    results: Dict[str, Any] = {}
    for path in paths:
        module = _try_import(path)
        if module is None:
            continue
        for attr in attrs:
            if results.get(attr) is None:
                results[attr] = getattr(module, attr, None)
        # Stop only when nothing is missing. Breaking on a partial match would
        # leave the remaining helpers as None even when a later fallback path
        # could supply them.
        if all(results.get(attr) is not None for attr in attrs):
            break

    return results
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def _resolve_evaluation() -> Dict[str, Any]:
    """Look up the evaluation helpers from the first usable evaluation module.

    Each importable candidate replaces the whole result. The search stops at
    the first module that provides at least one of the helpers. The result is
    empty when no candidate imports.
    """
    helper_names = (
        "bootstrap_ci",
        "calibrate_predictions",
        "metrics_report",
        "select_threshold",
    )
    found: Dict[str, Any] = {}
    for candidate in ("ins_pricing.modelling.core.evaluation", "evaluation"):
        module = _try_import(candidate)
        if module is None:
            continue
        found = {name: getattr(module, name, None) for name in helper_names}
        if any(obj is not None for obj in found.values()):
            break
    return found
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _resolve_governance() -> Dict[str, Any]:
    """Look up registry, drift, monitoring and report-builder symbols.

    Each group is resolved independently from its first importable module. A
    group whose modules cannot be imported contributes no keys at all.
    """
    # (candidate module paths, {result key: attribute name on the module})
    lookups = (
        (
            ["ins_pricing.governance.registry"],
            {"ModelArtifact": "ModelArtifact", "ModelRegistry": "ModelRegistry"},
        ),
        (
            ["ins_pricing.production"],
            {"drift_psi_report": "psi_report"},
        ),
        (
            ["ins_pricing.production.monitoring"],
            {"group_metrics": "group_metrics"},
        ),
        (
            ["ins_pricing.reporting.report_builder"],
            {"ReportPayload": "ReportPayload", "write_report": "write_report"},
        ),
    )

    resolved: Dict[str, Any] = {}
    for candidates, wanted in lookups:
        for candidate in candidates:
            module = _try_import(candidate)
            if module is None:
                continue
            for key, attr in wanted.items():
                resolved[key] = getattr(module, attr, None)
            break
    return resolved
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _resolve_logging() -> Dict[str, Any]:
    """Fetch ``configure_run_logging`` from the first importable run_logging module.

    Returns an empty dict when none of the candidates can be imported.
    """
    for candidate in (
        "ins_pricing.cli.utils.run_logging",
        "cli.utils.run_logging",
        "utils.run_logging",
    ):
        module = _try_import(candidate)
        if module is not None:
            return {"configure_run_logging": getattr(module, "configure_run_logging", None)}
    return {}
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _resolve_plotting() -> Dict[str, Any]:
    """Fetch ``plot_loss_curve`` from the first importable diagnostics module.

    Returns an empty dict when none of the candidates can be imported.
    """
    for candidate in (
        "ins_pricing.modelling.plotting.diagnostics",
        "ins_pricing.plotting.diagnostics",
    ):
        module = _try_import(candidate)
        if module is not None:
            return {"plot_loss_curve": getattr(module, "plot_loss_curve", None)}
    return {}
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def resolve_imports() -> ResolvedImports:
    """Resolve all imports from the bayesopt ecosystem.

    This function attempts to import modules from multiple possible locations,
    handling the various ways the package might be installed or run. Resolver
    groups run in a fixed order:

    1. bayesopt, CLI common, CLI config;
    2. evaluation, governance;
    3. logging, plotting.

    Returns:
        ResolvedImports object containing all resolved imports.
        ``PLOT_MODEL_LABELS`` and ``PYTORCH_TRAINERS`` are always containers,
        never None. The other entries are None when they could not be resolved.
    """
    imports = ResolvedImports()

    def _assign(source: Dict[str, Any], names: Tuple[str, ...]) -> None:
        # Copy each named entry from a resolver result onto `imports`.
        for name in names:
            setattr(imports, name, source.get(name))

    # Resolve bayesopt core
    imports.bayesopt = _resolve_bayesopt()

    # Resolve CLI common utilities
    cli_common = _resolve_cli_common()
    # The resolver stores an explicit None when a module was found but lacks the
    # name, so `.get(key, default)` alone would not apply the container default.
    imports.PLOT_MODEL_LABELS = cli_common.get("PLOT_MODEL_LABELS") or {}
    imports.PYTORCH_TRAINERS = cli_common.get("PYTORCH_TRAINERS") or []
    _assign(cli_common, (
        "build_model_names",
        "dedupe_preserve_order",
        "load_dataset",
        "parse_model_pairs",
        "resolve_data_path",
        "resolve_path",
        "fingerprint_file",
        "coerce_dataset_types",
        "split_train_test",
    ))

    # Resolve CLI config utilities
    _assign(_resolve_cli_config(), (
        "add_config_json_arg",
        "add_output_dir_arg",
        "resolve_and_load_config",
        "resolve_data_config",
        "resolve_report_config",
        "resolve_split_config",
        "resolve_runtime_config",
        "resolve_output_dirs",
    ))

    # Resolve evaluation utilities
    _assign(_resolve_evaluation(), (
        "bootstrap_ci",
        "calibrate_predictions",
        "metrics_report",
        "select_threshold",
    ))

    # Resolve governance and reporting
    _assign(_resolve_governance(), (
        "ModelArtifact",
        "ModelRegistry",
        "drift_psi_report",
        "group_metrics",
        "ReportPayload",
        "write_report",
    ))

    # Resolve logging
    _assign(_resolve_logging(), ("configure_run_logging",))

    # Resolve plotting
    _assign(_resolve_plotting(), ("plot_loss_curve",))

    return imports
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
# Convenience function for backward compatibility
|
|
346
|
+
def setup_sys_path() -> None:
    """Put the repository root at the front of ``sys.path`` if it is absent.

    The root is taken as ``parents[3]`` of this file's resolved path. For
    ``ins_pricing/cli/utils/import_resolver.py`` that is the directory that
    contains the ``ins_pricing`` package.
    """
    root = str(Path(__file__).resolve().parents[3])
    if root in sys.path:
        return
    sys.path.insert(0, root)
|