ins-pricing 0.3.3-py3-none-any.whl → 0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/docs/LOSS_FUNCTIONS.md +78 -0
- ins_pricing/docs/modelling/BayesOpt_USAGE.md +3 -3
- ins_pricing/frontend/QUICKSTART.md +152 -0
- ins_pricing/frontend/README.md +388 -0
- ins_pricing/frontend/__init__.py +10 -0
- ins_pricing/frontend/app.py +903 -0
- ins_pricing/frontend/config_builder.py +352 -0
- ins_pricing/frontend/example_config.json +36 -0
- ins_pricing/frontend/example_workflows.py +979 -0
- ins_pricing/frontend/ft_workflow.py +316 -0
- ins_pricing/frontend/runner.py +388 -0
- ins_pricing/modelling/core/bayesopt/config_preprocess.py +12 -0
- ins_pricing/modelling/core/bayesopt/core.py +21 -8
- ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +16 -6
- ins_pricing/modelling/core/bayesopt/models/model_gnn.py +16 -6
- ins_pricing/modelling/core/bayesopt/models/model_resn.py +16 -7
- ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +2 -0
- ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +25 -8
- ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +14 -11
- ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +29 -10
- ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +28 -12
- ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +13 -14
- ins_pricing/modelling/core/bayesopt/utils/losses.py +129 -0
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +18 -3
- ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +24 -3
- ins_pricing/production/predict.py +693 -635
- ins_pricing/setup.py +1 -1
- ins_pricing/utils/metrics.py +27 -3
- {ins_pricing-0.3.3.dist-info → ins_pricing-0.4.0.dist-info}/METADATA +162 -162
- {ins_pricing-0.3.3.dist-info → ins_pricing-0.4.0.dist-info}/RECORD +32 -21
- {ins_pricing-0.3.3.dist-info → ins_pricing-0.4.0.dist-info}/WHEEL +1 -1
- {ins_pricing-0.3.3.dist-info → ins_pricing-0.4.0.dist-info}/top_level.txt +0 -0
ins_pricing/frontend/ft_workflow.py
@@ -0,0 +1,316 @@
"""
FT-Transformer Two-Step Workflow Helper
Automates the FT → XGB/ResN two-step training process.
"""

import json
import copy
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional
import pandas as pd


class FTWorkflowHelper:
    """
    Helper for FT-Transformer two-step workflow.

    Step 1: Train FT as unsupervised embedding generator
    Step 2: Merge embeddings with raw data and train XGB/ResN
    """

    def __init__(self):
        self.step1_config = None
        self.step2_configs = {}

    def prepare_step1_config(
        self,
        base_config: Dict[str, Any],
        output_dir: str = "./ResultsFTUnsupervisedDDP",
        ft_feature_prefix: str = "ft_emb",
        use_ddp: bool = True,
        nproc_per_node: int = 2,
    ) -> Dict[str, Any]:
        """
        Prepare configuration for Step 1: FT unsupervised embedding.

        Args:
            base_config: Base configuration dictionary
            output_dir: Output directory for FT embeddings
            ft_feature_prefix: Prefix for embedding column names
            use_ddp: Whether to use DDP for FT training
            nproc_per_node: Number of processes for DDP

        Returns:
            Step 1 configuration
        """
        config = copy.deepcopy(base_config)

        # Set FT role to unsupervised embedding
        config['ft_role'] = 'unsupervised_embedding'
        config['ft_feature_prefix'] = ft_feature_prefix
        config['output_dir'] = output_dir
        config['cache_predictions'] = True
        config['prediction_cache_format'] = 'csv'

        # Disable other models in step 1
        config['stack_model_keys'] = []

        # DDP settings
        config['use_ft_ddp'] = use_ddp
        config['use_resn_ddp'] = False
        config['use_gnn_ddp'] = False
        config['use_ft_data_parallel'] = False
        config['use_resn_data_parallel'] = False
        config['use_gnn_data_parallel'] = False

        # Optuna storage
        config['optuna_storage'] = f"{output_dir}/optuna/bayesopt.sqlite3"
        config['optuna_study_prefix'] = 'pricing_ft_unsup'

        # Runner config
        runner = config.get('runner', {})
        runner['mode'] = 'entry'
        runner['model_keys'] = ['ft']
        runner['nproc_per_node'] = nproc_per_node if use_ddp else 1
        runner['plot_curves'] = False
        config['runner'] = runner

        # Disable plotting
        config['plot_curves'] = False
        plot_cfg = config.get('plot', {})
        plot_cfg['enable'] = False
        config['plot'] = plot_cfg

        self.step1_config = config
        return config

    def generate_step2_configs(
        self,
        step1_config_path: str,
        target_models: List[str] = None,
        augmented_data_dir: str = "./DataFTUnsupervised"
    ) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Generate Step 2 configurations for XGB and/or ResN.

        This requires that Step 1 has completed and embeddings are cached.

        Args:
            step1_config_path: Path to the Step 1 config file
            target_models: Models to train in step 2 (e.g., ['xgb', 'resn'])
            augmented_data_dir: Directory to save augmented data with embeddings

        Returns:
            Tuple of (xgb_config, resn_config) - None if not in target_models
        """
        if target_models is None:
            target_models = ['xgb', 'resn']

        # Load step 1 config
        cfg_path = Path(step1_config_path)
        with open(cfg_path, 'r', encoding='utf-8') as f:
            cfg = json.load(f)

        # Read raw data and split
        model_name = f"{cfg['model_list'][0]}_{cfg['model_categories'][0]}"
        data_dir = (cfg_path.parent / cfg["data_dir"]).resolve()
        raw_path = data_dir / f"{model_name}.csv"

        if not raw_path.exists():
            raise FileNotFoundError(f"Data file not found: {raw_path}")

        raw = pd.read_csv(raw_path)

        # Import split function
        try:
            from ins_pricing.cli.utils.cli_common import split_train_test
        except ImportError:
            raise ImportError("Cannot import split_train_test. Ensure ins_pricing is installed.")

        # Split data using same settings as step 1
        holdout_ratio = cfg.get("holdout_ratio", cfg.get("prop_test", 0.25))
        split_strategy = cfg.get("split_strategy", "random")
        split_group_col = cfg.get("split_group_col")
        split_time_col = cfg.get("split_time_col")
        split_time_ascending = cfg.get("split_time_ascending", True)
        rand_seed = cfg.get("rand_seed", 13)

        train_df, test_df = split_train_test(
            raw,
            holdout_ratio=holdout_ratio,
            strategy=split_strategy,
            group_col=split_group_col,
            time_col=split_time_col,
            time_ascending=split_time_ascending,
            rand_seed=rand_seed,
            reset_index_mode="time_group",
            ratio_label="holdout_ratio",
        )

        # Load cached embeddings
        out_root = (cfg_path.parent / cfg["output_dir"]).resolve()
        pred_prefix = cfg.get("ft_feature_prefix", "ft_emb")
        pred_dir = out_root / "Results" / "predictions"

        train_pred_path = pred_dir / f"{model_name}_{pred_prefix}_train.csv"
        test_pred_path = pred_dir / f"{model_name}_{pred_prefix}_test.csv"

        if not train_pred_path.exists() or not test_pred_path.exists():
            raise FileNotFoundError(
                f"Embedding files not found. Run Step 1 first.\n"
                f"Expected: {train_pred_path} and {test_pred_path}"
            )

        pred_train = pd.read_csv(train_pred_path)
        pred_test = pd.read_csv(test_pred_path)

        if len(pred_train) != len(train_df) or len(pred_test) != len(test_df):
            raise ValueError(
                "Prediction rows do not match split sizes; check split settings.")

        # Merge embeddings with raw data
        aug = raw.copy()
        aug.loc[train_df.index, pred_train.columns] = pred_train.values
        aug.loc[test_df.index, pred_test.columns] = pred_test.values

        # Save augmented data
        data_out_dir = cfg_path.parent / augmented_data_dir
        data_out_dir.mkdir(parents=True, exist_ok=True)
        aug_path = data_out_dir / f"{model_name}.csv"
        aug.to_csv(aug_path, index=False)

        # Get embedding column names
        embed_cols = list(pred_train.columns)

        # Generate configs
        xgb_config = None
        resn_config = None

        if 'xgb' in target_models:
            xgb_config = self._build_xgb_config(cfg, cfg_path, embed_cols, augmented_data_dir)
            self.step2_configs['xgb'] = xgb_config

        if 'resn' in target_models:
            resn_config = self._build_resn_config(cfg, cfg_path, embed_cols, augmented_data_dir)
            self.step2_configs['resn'] = resn_config

        return xgb_config, resn_config

    def _build_xgb_config(
        self,
        base_cfg: Dict[str, Any],
        cfg_path: Path,
        embed_cols: List[str],
        data_dir: str
    ) -> Dict[str, Any]:
        """Build XGB config for Step 2."""
        xgb_cfg = copy.deepcopy(base_cfg)

        xgb_cfg["data_dir"] = str(data_dir)
        xgb_cfg["feature_list"] = base_cfg["feature_list"] + embed_cols
        xgb_cfg["ft_role"] = "model"
        xgb_cfg["stack_model_keys"] = ["xgb"]
        xgb_cfg["cache_predictions"] = False

        # Disable DDP for XGB
        xgb_cfg["use_resn_ddp"] = False
        xgb_cfg["use_ft_ddp"] = False
        xgb_cfg["use_gnn_ddp"] = False
        xgb_cfg["use_resn_data_parallel"] = False
        xgb_cfg["use_ft_data_parallel"] = False
        xgb_cfg["use_gnn_data_parallel"] = False

        xgb_cfg["output_dir"] = "./ResultsXGBFromFTUnsupervised"
        xgb_cfg["optuna_storage"] = "./ResultsXGBFromFTUnsupervised/optuna/bayesopt.sqlite3"
        xgb_cfg["optuna_study_prefix"] = "pricing_ft_unsup_xgb"
        xgb_cfg["loss_name"] = "mse"

        runner_cfg = xgb_cfg.get("runner", {})
        runner_cfg["model_keys"] = ["xgb"]
        runner_cfg["nproc_per_node"] = 1
        runner_cfg["plot_curves"] = False
        xgb_cfg["runner"] = runner_cfg

        xgb_cfg["plot_curves"] = False
        plot_cfg = xgb_cfg.get("plot", {})
        plot_cfg["enable"] = False
        xgb_cfg["plot"] = plot_cfg

        return xgb_cfg

    def _build_resn_config(
        self,
        base_cfg: Dict[str, Any],
        cfg_path: Path,
        embed_cols: List[str],
        data_dir: str
    ) -> Dict[str, Any]:
        """Build ResNet config for Step 2."""
        resn_cfg = copy.deepcopy(base_cfg)

        resn_cfg["data_dir"] = str(data_dir)
        resn_cfg["feature_list"] = base_cfg["feature_list"] + embed_cols
        resn_cfg["ft_role"] = "model"
        resn_cfg["stack_model_keys"] = ["resn"]
        resn_cfg["cache_predictions"] = False

        # Enable DDP for ResNet
        resn_cfg["use_resn_ddp"] = True
        resn_cfg["use_ft_ddp"] = False
        resn_cfg["use_gnn_ddp"] = False
        resn_cfg["use_resn_data_parallel"] = False
        resn_cfg["use_ft_data_parallel"] = False
        resn_cfg["use_gnn_data_parallel"] = False

        resn_cfg["output_dir"] = "./ResultsResNFromFTUnsupervised"
        resn_cfg["optuna_storage"] = "./ResultsResNFromFTUnsupervised/optuna/bayesopt.sqlite3"
        resn_cfg["optuna_study_prefix"] = "pricing_ft_unsup_resn_ddp"
        resn_cfg["loss_name"] = "mse"

        runner_cfg = resn_cfg.get("runner", {})
        runner_cfg["model_keys"] = ["resn"]
        runner_cfg["nproc_per_node"] = 2
        runner_cfg["plot_curves"] = False
        resn_cfg["runner"] = runner_cfg

        resn_cfg["plot_curves"] = False
        plot_cfg = resn_cfg.get("plot", {})
        plot_cfg["enable"] = False
        resn_cfg["plot"] = plot_cfg

        return resn_cfg

    def save_configs(self, output_dir: str = ".") -> Dict[str, str]:
        """
        Save generated configs to files.

        Args:
            output_dir: Directory to save config files

        Returns:
            Dictionary mapping model names to saved file paths
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        saved_files = {}

        if self.step1_config:
            step1_path = output_path / "config_ft_step1_unsupervised.json"
            with open(step1_path, 'w', encoding='utf-8') as f:
                json.dump(self.step1_config, f, indent=2)
            saved_files['ft_step1'] = str(step1_path)

        if 'xgb' in self.step2_configs:
            xgb_path = output_path / "config_xgb_from_ft_step2.json"
            with open(xgb_path, 'w', encoding='utf-8') as f:
                json.dump(self.step2_configs['xgb'], f, indent=2)
            saved_files['xgb_step2'] = str(xgb_path)

        if 'resn' in self.step2_configs:
            resn_path = output_path / "config_resn_from_ft_step2.json"
            with open(resn_path, 'w', encoding='utf-8') as f:
                json.dump(self.step2_configs['resn'], f, indent=2)
            saved_files['resn_step2'] = str(resn_path)

        return saved_files
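For context, a minimal sketch of how this helper might be driven end to end. The base_config keys shown below (model_list, model_categories, data_dir, feature_list) are assumptions based only on the fields the helper reads, and the paths are illustrative; the actual Step 1 training run between the two calls is launched separately (for example via the runner module added in the same release, shown in the next hunk).

from ins_pricing.frontend.ft_workflow import FTWorkflowHelper

# Hypothetical base config: only keys this helper actually consumes are shown.
base_config = {
    "model_list": ["motor"],
    "model_categories": ["tpl"],
    "data_dir": "./Data",
    "feature_list": ["age", "vehicle_value", "region"],
    "max_evals": 50,
}

helper = FTWorkflowHelper()

# Step 1: configure FT as an unsupervised embedding generator and save the config.
helper.prepare_step1_config(base_config, use_ddp=True, nproc_per_node=2)
paths = helper.save_configs(output_dir="./configs")

# ... run Step 1 training with paths["ft_step1"] so the embedding CSVs get cached ...

# Step 2: merge cached embeddings into the raw data and emit XGB/ResN configs.
xgb_cfg, resn_cfg = helper.generate_step2_configs(
    step1_config_path=paths["ft_step1"],
    target_models=["xgb", "resn"],
)
helper.save_configs(output_dir="./configs")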
ins_pricing/frontend/runner.py
@@ -0,0 +1,388 @@
"""
Unified Task Runner with Real-time Logging
Executes model training, explanation, plotting, and other tasks based on config.
"""

import sys
import threading
import queue
import time
import json
import subprocess
from pathlib import Path
from typing import Generator, Optional, Dict, Any, List, Sequence, Tuple
import logging


class LogCapture:
    """Capture stdout and stderr for real-time display."""

    def __init__(self):
        self.log_queue = queue.Queue()
        self.stop_flag = threading.Event()

    def write(self, text: str):
        """Write method for capturing output."""
        if text and text.strip():
            self.log_queue.put(text)

    def flush(self):
        """Flush method (required for file-like objects)."""
        pass


class TaskRunner:
    """
    Run model tasks (training, explain, plotting, etc.) and capture logs.

    Supports all task modes defined in config.runner.mode:
    - entry: Standard model training
    - explain: Model explanation (permutation, SHAP, etc.)
    - incremental: Incremental training
    - watchdog: Watchdog mode for monitoring
    """

    def __init__(self):
        self.task_thread = None
        self.log_capture = None

    def _detect_task_mode(self, config_path: str) -> str:
        """
        Detect the task mode from config file.

        Args:
            config_path: Path to configuration JSON file

        Returns:
            Task mode string (e.g., 'entry', 'explain', 'incremental', 'watchdog')
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)

            runner_config = config.get('runner', {})
            mode = runner_config.get('mode', 'entry')
            return str(mode).lower()

        except Exception as e:
            print(f"Warning: Could not detect task mode, defaulting to 'entry': {e}")
            return 'entry'

    def _build_cmd_from_config(self, config_path: str) -> Tuple[List[str], str]:
        """
        Build the command to execute based on config.runner.mode, mirroring
        ins_pricing.cli.utils.notebook_utils.run_from_config behavior.
        """
        from ins_pricing.cli.utils.cli_config import set_env
        from ins_pricing.cli.utils.notebook_utils import (
            build_bayesopt_entry_cmd,
            build_incremental_cmd,
            build_explain_cmd,
            wrap_with_watchdog,
        )

        cfg_path = Path(config_path).resolve()
        raw = json.loads(cfg_path.read_text(encoding="utf-8", errors="replace"))
        set_env(raw.get("env", {}))
        runner = dict(raw.get("runner") or {})

        mode = str(runner.get("mode") or "entry").strip().lower()
        use_watchdog = bool(runner.get("use_watchdog", False))
        if mode == "watchdog":
            use_watchdog = True
            mode = "entry"

        idle_seconds = int(runner.get("idle_seconds", 7200))
        max_restarts = int(runner.get("max_restarts", 50))
        restart_delay_seconds = int(runner.get("restart_delay_seconds", 10))

        if mode == "incremental":
            inc_args = runner.get("incremental_args") or []
            if not isinstance(inc_args, list):
                raise ValueError("config.runner.incremental_args must be a list of strings.")
            cmd = build_incremental_cmd(cfg_path, extra_args=[str(x) for x in inc_args])
            if use_watchdog:
                cmd = wrap_with_watchdog(
                    cmd,
                    idle_seconds=idle_seconds,
                    max_restarts=max_restarts,
                    restart_delay_seconds=restart_delay_seconds,
                )
            return cmd, "incremental"

        if mode == "explain":
            exp_args = runner.get("explain_args") or []
            if not isinstance(exp_args, list):
                raise ValueError("config.runner.explain_args must be a list of strings.")
            cmd = build_explain_cmd(cfg_path, extra_args=[str(x) for x in exp_args])
            if use_watchdog:
                cmd = wrap_with_watchdog(
                    cmd,
                    idle_seconds=idle_seconds,
                    max_restarts=max_restarts,
                    restart_delay_seconds=restart_delay_seconds,
                )
            return cmd, "explain"

        if mode != "entry":
            raise ValueError(
                f"Unsupported runner.mode={mode!r}, expected 'entry', 'incremental', or 'explain'."
            )

        model_keys = runner.get("model_keys") or raw.get("model_keys") or ["ft"]
        if not isinstance(model_keys, list):
            raise ValueError("runner.model_keys must be a list of strings.")
        nproc_per_node = int(runner.get("nproc_per_node", 1))
        max_evals = int(runner.get("max_evals", raw.get("max_evals", 50)))
        plot_curves = bool(runner.get("plot_curves", raw.get("plot_curves", True)))
        ft_role = runner.get("ft_role", raw.get("ft_role"))

        extra_args: List[str] = ["--max-evals", str(max_evals)]
        if plot_curves:
            extra_args.append("--plot-curves")
        if ft_role:
            extra_args += ["--ft-role", str(ft_role)]

        cmd = build_bayesopt_entry_cmd(
            cfg_path,
            model_keys=[str(x) for x in model_keys],
            nproc_per_node=nproc_per_node,
            extra_args=extra_args,
        )
        if use_watchdog:
            cmd = wrap_with_watchdog(
                cmd,
                idle_seconds=idle_seconds,
                max_restarts=max_restarts,
                restart_delay_seconds=restart_delay_seconds,
            )
        return cmd, "entry"

    def run_task(self, config_path: str) -> Generator[str, None, None]:
        """
        Run task based on config file with real-time log capture.

        This method automatically detects the task mode from the config file
        (training, explain, plotting, etc.) and runs the appropriate task.

        Args:
            config_path: Path to configuration JSON file

        Yields:
            Log lines as they are generated
        """
        self.log_capture = LogCapture()

        # Configure logging to capture both file and stream output
        log_handler = logging.StreamHandler(self.log_capture)
        log_handler.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        log_handler.setFormatter(formatter)

        # Add handler to root logger
        root_logger = logging.getLogger()
        original_handlers = root_logger.handlers.copy()
        root_logger.addHandler(log_handler)

        # Store original stdout/stderr
        original_stdout = sys.stdout
        original_stderr = sys.stderr

        try:
            # Detect task mode
            task_mode = self._detect_task_mode(config_path)

            # Start task in separate thread
            exception_holder = []

            def task_worker():
                try:
                    sys.stdout = self.log_capture
                    sys.stderr = self.log_capture

                    # Log start
                    cmd, task_mode = self._build_cmd_from_config(config_path)
                    print(f"Starting task [{task_mode}] with config: {config_path}")
                    print("=" * 80)

                    # Run subprocess with streamed output
                    proc = subprocess.Popen(
                        cmd,
                        stdout=subprocess.PIPE,
                        stderr=subprocess.STDOUT,
                        text=True,
                        bufsize=1,
                        cwd=str(Path(config_path).resolve().parent),
                    )
                    if proc.stdout is not None:
                        for line in proc.stdout:
                            print(line.rstrip())
                    return_code = proc.wait()
                    if return_code != 0:
                        raise RuntimeError(f"Task exited with code {return_code}")

                    print("=" * 80)
                    print(f"Task [{task_mode}] completed successfully!")

                except Exception as e:
                    exception_holder.append(e)
                    print(f"Error during task execution: {str(e)}")
                    import traceback
                    print(traceback.format_exc())

                finally:
                    sys.stdout = original_stdout
                    sys.stderr = original_stderr

            self.task_thread = threading.Thread(target=task_worker, daemon=True)
            self.task_thread.start()

            # Yield log lines as they come in
            last_update = time.time()
            while self.task_thread.is_alive() or not self.log_capture.log_queue.empty():
                try:
                    # Try to get log with timeout
                    log_line = self.log_capture.log_queue.get(timeout=0.1)
                    yield log_line
                    last_update = time.time()

                except queue.Empty:
                    # Send heartbeat every 5 seconds
                    if time.time() - last_update > 5:
                        yield "."
                        last_update = time.time()
                    continue

            # Wait for thread to complete
            self.task_thread.join(timeout=1)

            # Check for exceptions
            if exception_holder:
                raise exception_holder[0]

        finally:
            # Restore original stdout/stderr
            sys.stdout = original_stdout
            sys.stderr = original_stderr

            # Restore original logging handlers
            root_logger.handlers = original_handlers

    def run_callable(self, func, *args, **kwargs) -> Generator[str, None, None]:
        """Run an in-process callable and stream stdout/stderr."""
        self.log_capture = LogCapture()

        log_handler = logging.StreamHandler(self.log_capture)
        log_handler.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        log_handler.setFormatter(formatter)

        root_logger = logging.getLogger()
        original_handlers = root_logger.handlers.copy()
        root_logger.addHandler(log_handler)

        original_stdout = sys.stdout
        original_stderr = sys.stderr

        try:
            exception_holder = []

            def task_worker():
                try:
                    sys.stdout = self.log_capture
                    sys.stderr = self.log_capture
                    func(*args, **kwargs)
                except Exception as e:
                    exception_holder.append(e)
                    print(f"Error during task execution: {str(e)}")
                    import traceback
                    print(traceback.format_exc())
                finally:
                    sys.stdout = original_stdout
                    sys.stderr = original_stderr

            self.task_thread = threading.Thread(target=task_worker, daemon=True)
            self.task_thread.start()

            last_update = time.time()
            while self.task_thread.is_alive() or not self.log_capture.log_queue.empty():
                try:
                    log_line = self.log_capture.log_queue.get(timeout=0.1)
                    yield log_line
                    last_update = time.time()
                except queue.Empty:
                    if time.time() - last_update > 5:
                        yield "."
                        last_update = time.time()
                    continue

            self.task_thread.join(timeout=1)
            if exception_holder:
                raise exception_holder[0]
        finally:
            sys.stdout = original_stdout
            sys.stderr = original_stderr
            root_logger.handlers = original_handlers

    def stop_task(self):
        """Stop the current task process."""
        if self.log_capture:
            self.log_capture.stop_flag.set()

        if self.task_thread and self.task_thread.is_alive():
            # Note: Thread.join() will wait for completion
            # For forceful termination, you may need to use process-based approach
            self.task_thread.join(timeout=5)


# Backward compatibility aliases
TrainingRunner = TaskRunner


class StreamToLogger:
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def write(self, buf):
        for line in buf.rstrip().splitlines():
            self.logger.log(self.log_level, line.rstrip())

    def flush(self):
        pass


def setup_logger(name: str = "task") -> logging.Logger:
    """
    Set up a logger for task execution.

    Args:
        name: Logger name

    Returns:
        Configured logger instance
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # Create console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    # Create formatter
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    console_handler.setFormatter(formatter)

    # Add handler to logger
    if not logger.handlers:
        logger.addHandler(console_handler)

    return logger