ins-pricing 0.3.3-py3-none-any.whl → 0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. ins_pricing/docs/LOSS_FUNCTIONS.md +78 -0
  2. ins_pricing/docs/modelling/BayesOpt_USAGE.md +3 -3
  3. ins_pricing/frontend/QUICKSTART.md +152 -0
  4. ins_pricing/frontend/README.md +388 -0
  5. ins_pricing/frontend/__init__.py +10 -0
  6. ins_pricing/frontend/app.py +903 -0
  7. ins_pricing/frontend/config_builder.py +352 -0
  8. ins_pricing/frontend/example_config.json +36 -0
  9. ins_pricing/frontend/example_workflows.py +979 -0
  10. ins_pricing/frontend/ft_workflow.py +316 -0
  11. ins_pricing/frontend/runner.py +388 -0
  12. ins_pricing/modelling/core/bayesopt/config_preprocess.py +12 -0
  13. ins_pricing/modelling/core/bayesopt/core.py +21 -8
  14. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +16 -6
  15. ins_pricing/modelling/core/bayesopt/models/model_gnn.py +16 -6
  16. ins_pricing/modelling/core/bayesopt/models/model_resn.py +16 -7
  17. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +2 -0
  18. ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +25 -8
  19. ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +14 -11
  20. ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +29 -10
  21. ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +28 -12
  22. ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +13 -14
  23. ins_pricing/modelling/core/bayesopt/utils/losses.py +129 -0
  24. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +18 -3
  25. ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +24 -3
  26. ins_pricing/production/predict.py +693 -635
  27. ins_pricing/setup.py +1 -1
  28. ins_pricing/utils/metrics.py +27 -3
  29. {ins_pricing-0.3.3.dist-info → ins_pricing-0.4.0.dist-info}/METADATA +162 -162
  30. {ins_pricing-0.3.3.dist-info → ins_pricing-0.4.0.dist-info}/RECORD +32 -21
  31. {ins_pricing-0.3.3.dist-info → ins_pricing-0.4.0.dist-info}/WHEEL +1 -1
  32. {ins_pricing-0.3.3.dist-info → ins_pricing-0.4.0.dist-info}/top_level.txt +0 -0
ins_pricing/frontend/ft_workflow.py
@@ -0,0 +1,316 @@
+ """
+ FT-Transformer Two-Step Workflow Helper
+ Automates the FT → XGB/ResN two-step training process.
+ """
+
+ import json
+ import copy
+ from pathlib import Path
+ from typing import Dict, Any, List, Tuple, Optional
+ import pandas as pd
+
+
+ class FTWorkflowHelper:
+     """
+     Helper for FT-Transformer two-step workflow.
+
+     Step 1: Train FT as unsupervised embedding generator
+     Step 2: Merge embeddings with raw data and train XGB/ResN
+     """
+
+     def __init__(self):
+         self.step1_config = None
+         self.step2_configs = {}
+
+     def prepare_step1_config(
+         self,
+         base_config: Dict[str, Any],
+         output_dir: str = "./ResultsFTUnsupervisedDDP",
+         ft_feature_prefix: str = "ft_emb",
+         use_ddp: bool = True,
+         nproc_per_node: int = 2,
+     ) -> Dict[str, Any]:
+         """
+         Prepare configuration for Step 1: FT unsupervised embedding.
+
+         Args:
+             base_config: Base configuration dictionary
+             output_dir: Output directory for FT embeddings
+             ft_feature_prefix: Prefix for embedding column names
+             use_ddp: Whether to use DDP for FT training
+             nproc_per_node: Number of processes for DDP
+
+         Returns:
+             Step 1 configuration
+         """
+         config = copy.deepcopy(base_config)
+
+         # Set FT role to unsupervised embedding
+         config['ft_role'] = 'unsupervised_embedding'
+         config['ft_feature_prefix'] = ft_feature_prefix
+         config['output_dir'] = output_dir
+         config['cache_predictions'] = True
+         config['prediction_cache_format'] = 'csv'
+
+         # Disable other models in step 1
+         config['stack_model_keys'] = []
+
+         # DDP settings
+         config['use_ft_ddp'] = use_ddp
+         config['use_resn_ddp'] = False
+         config['use_gnn_ddp'] = False
+         config['use_ft_data_parallel'] = False
+         config['use_resn_data_parallel'] = False
+         config['use_gnn_data_parallel'] = False
+
+         # Optuna storage
+         config['optuna_storage'] = f"{output_dir}/optuna/bayesopt.sqlite3"
+         config['optuna_study_prefix'] = 'pricing_ft_unsup'
+
+         # Runner config
+         runner = config.get('runner', {})
+         runner['mode'] = 'entry'
+         runner['model_keys'] = ['ft']
+         runner['nproc_per_node'] = nproc_per_node if use_ddp else 1
+         runner['plot_curves'] = False
+         config['runner'] = runner
+
+         # Disable plotting
+         config['plot_curves'] = False
+         plot_cfg = config.get('plot', {})
+         plot_cfg['enable'] = False
+         config['plot'] = plot_cfg
+
+         self.step1_config = config
+         return config
+
+     def generate_step2_configs(
+         self,
+         step1_config_path: str,
+         target_models: List[str] = None,
+         augmented_data_dir: str = "./DataFTUnsupervised"
+     ) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
+         """
+         Generate Step 2 configurations for XGB and/or ResN.
+
+         This requires that Step 1 has completed and embeddings are cached.
+
+         Args:
+             step1_config_path: Path to the Step 1 config file
+             target_models: Models to train in step 2 (e.g., ['xgb', 'resn'])
+             augmented_data_dir: Directory to save augmented data with embeddings
+
+         Returns:
+             Tuple of (xgb_config, resn_config) - None if not in target_models
+         """
+         if target_models is None:
+             target_models = ['xgb', 'resn']
+
+         # Load step 1 config
+         cfg_path = Path(step1_config_path)
+         with open(cfg_path, 'r', encoding='utf-8') as f:
+             cfg = json.load(f)
+
+         # Read raw data and split
+         model_name = f"{cfg['model_list'][0]}_{cfg['model_categories'][0]}"
+         data_dir = (cfg_path.parent / cfg["data_dir"]).resolve()
+         raw_path = data_dir / f"{model_name}.csv"
+
+         if not raw_path.exists():
+             raise FileNotFoundError(f"Data file not found: {raw_path}")
+
+         raw = pd.read_csv(raw_path)
+
+         # Import split function
+         try:
+             from ins_pricing.cli.utils.cli_common import split_train_test
+         except ImportError:
+             raise ImportError("Cannot import split_train_test. Ensure ins_pricing is installed.")
+
+         # Split data using same settings as step 1
+         holdout_ratio = cfg.get("holdout_ratio", cfg.get("prop_test", 0.25))
+         split_strategy = cfg.get("split_strategy", "random")
+         split_group_col = cfg.get("split_group_col")
+         split_time_col = cfg.get("split_time_col")
+         split_time_ascending = cfg.get("split_time_ascending", True)
+         rand_seed = cfg.get("rand_seed", 13)
+
+         train_df, test_df = split_train_test(
+             raw,
+             holdout_ratio=holdout_ratio,
+             strategy=split_strategy,
+             group_col=split_group_col,
+             time_col=split_time_col,
+             time_ascending=split_time_ascending,
+             rand_seed=rand_seed,
+             reset_index_mode="time_group",
+             ratio_label="holdout_ratio",
+         )
+
+         # Load cached embeddings
+         out_root = (cfg_path.parent / cfg["output_dir"]).resolve()
+         pred_prefix = cfg.get("ft_feature_prefix", "ft_emb")
+         pred_dir = out_root / "Results" / "predictions"
+
+         train_pred_path = pred_dir / f"{model_name}_{pred_prefix}_train.csv"
+         test_pred_path = pred_dir / f"{model_name}_{pred_prefix}_test.csv"
+
+         if not train_pred_path.exists() or not test_pred_path.exists():
+             raise FileNotFoundError(
+                 f"Embedding files not found. Run Step 1 first.\n"
+                 f"Expected: {train_pred_path} and {test_pred_path}"
+             )
+
+         pred_train = pd.read_csv(train_pred_path)
+         pred_test = pd.read_csv(test_pred_path)
+
+         if len(pred_train) != len(train_df) or len(pred_test) != len(test_df):
+             raise ValueError(
+                 "Prediction rows do not match split sizes; check split settings.")
+
+         # Merge embeddings with raw data
+         aug = raw.copy()
+         aug.loc[train_df.index, pred_train.columns] = pred_train.values
+         aug.loc[test_df.index, pred_test.columns] = pred_test.values
+
+         # Save augmented data
+         data_out_dir = cfg_path.parent / augmented_data_dir
+         data_out_dir.mkdir(parents=True, exist_ok=True)
+         aug_path = data_out_dir / f"{model_name}.csv"
+         aug.to_csv(aug_path, index=False)
+
+         # Get embedding column names
+         embed_cols = list(pred_train.columns)
+
+         # Generate configs
+         xgb_config = None
+         resn_config = None
+
+         if 'xgb' in target_models:
+             xgb_config = self._build_xgb_config(cfg, cfg_path, embed_cols, augmented_data_dir)
+             self.step2_configs['xgb'] = xgb_config
+
+         if 'resn' in target_models:
+             resn_config = self._build_resn_config(cfg, cfg_path, embed_cols, augmented_data_dir)
+             self.step2_configs['resn'] = resn_config
+
+         return xgb_config, resn_config
+
+     def _build_xgb_config(
+         self,
+         base_cfg: Dict[str, Any],
+         cfg_path: Path,
+         embed_cols: List[str],
+         data_dir: str
+     ) -> Dict[str, Any]:
+         """Build XGB config for Step 2."""
+         xgb_cfg = copy.deepcopy(base_cfg)
+
+         xgb_cfg["data_dir"] = str(data_dir)
+         xgb_cfg["feature_list"] = base_cfg["feature_list"] + embed_cols
+         xgb_cfg["ft_role"] = "model"
+         xgb_cfg["stack_model_keys"] = ["xgb"]
+         xgb_cfg["cache_predictions"] = False
+
+         # Disable DDP for XGB
+         xgb_cfg["use_resn_ddp"] = False
+         xgb_cfg["use_ft_ddp"] = False
+         xgb_cfg["use_gnn_ddp"] = False
+         xgb_cfg["use_resn_data_parallel"] = False
+         xgb_cfg["use_ft_data_parallel"] = False
+         xgb_cfg["use_gnn_data_parallel"] = False
+
+         xgb_cfg["output_dir"] = "./ResultsXGBFromFTUnsupervised"
+         xgb_cfg["optuna_storage"] = "./ResultsXGBFromFTUnsupervised/optuna/bayesopt.sqlite3"
+         xgb_cfg["optuna_study_prefix"] = "pricing_ft_unsup_xgb"
+         xgb_cfg["loss_name"] = "mse"
+
+         runner_cfg = xgb_cfg.get("runner", {})
+         runner_cfg["model_keys"] = ["xgb"]
+         runner_cfg["nproc_per_node"] = 1
+         runner_cfg["plot_curves"] = False
+         xgb_cfg["runner"] = runner_cfg
+
+         xgb_cfg["plot_curves"] = False
+         plot_cfg = xgb_cfg.get("plot", {})
+         plot_cfg["enable"] = False
+         xgb_cfg["plot"] = plot_cfg
+
+         return xgb_cfg
+
+     def _build_resn_config(
+         self,
+         base_cfg: Dict[str, Any],
+         cfg_path: Path,
+         embed_cols: List[str],
+         data_dir: str
+     ) -> Dict[str, Any]:
+         """Build ResNet config for Step 2."""
+         resn_cfg = copy.deepcopy(base_cfg)
+
+         resn_cfg["data_dir"] = str(data_dir)
+         resn_cfg["feature_list"] = base_cfg["feature_list"] + embed_cols
+         resn_cfg["ft_role"] = "model"
+         resn_cfg["stack_model_keys"] = ["resn"]
+         resn_cfg["cache_predictions"] = False
+
+         # Enable DDP for ResNet
+         resn_cfg["use_resn_ddp"] = True
+         resn_cfg["use_ft_ddp"] = False
+         resn_cfg["use_gnn_ddp"] = False
+         resn_cfg["use_resn_data_parallel"] = False
+         resn_cfg["use_ft_data_parallel"] = False
+         resn_cfg["use_gnn_data_parallel"] = False
+
+         resn_cfg["output_dir"] = "./ResultsResNFromFTUnsupervised"
+         resn_cfg["optuna_storage"] = "./ResultsResNFromFTUnsupervised/optuna/bayesopt.sqlite3"
+         resn_cfg["optuna_study_prefix"] = "pricing_ft_unsup_resn_ddp"
+         resn_cfg["loss_name"] = "mse"
+
+         runner_cfg = resn_cfg.get("runner", {})
+         runner_cfg["model_keys"] = ["resn"]
+         runner_cfg["nproc_per_node"] = 2
+         runner_cfg["plot_curves"] = False
+         resn_cfg["runner"] = runner_cfg
+
+         resn_cfg["plot_curves"] = False
+         plot_cfg = resn_cfg.get("plot", {})
+         plot_cfg["enable"] = False
+         resn_cfg["plot"] = plot_cfg
+
+         return resn_cfg
+
+     def save_configs(self, output_dir: str = ".") -> Dict[str, str]:
+         """
+         Save generated configs to files.
+
+         Args:
+             output_dir: Directory to save config files
+
+         Returns:
+             Dictionary mapping model names to saved file paths
+         """
+         output_path = Path(output_dir)
+         output_path.mkdir(parents=True, exist_ok=True)
+
+         saved_files = {}
+
+         if self.step1_config:
+             step1_path = output_path / "config_ft_step1_unsupervised.json"
+             with open(step1_path, 'w', encoding='utf-8') as f:
+                 json.dump(self.step1_config, f, indent=2)
+             saved_files['ft_step1'] = str(step1_path)
+
+         if 'xgb' in self.step2_configs:
+             xgb_path = output_path / "config_xgb_from_ft_step2.json"
+             with open(xgb_path, 'w', encoding='utf-8') as f:
+                 json.dump(self.step2_configs['xgb'], f, indent=2)
+             saved_files['xgb_step2'] = str(xgb_path)
+
+         if 'resn' in self.step2_configs:
+             resn_path = output_path / "config_resn_from_ft_step2.json"
+             with open(resn_path, 'w', encoding='utf-8') as f:
+                 json.dump(self.step2_configs['resn'], f, indent=2)
+             saved_files['resn_step2'] = str(resn_path)
+
+         return saved_files
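
The helper above automates the two-step pipeline: Step 1 trains FT as an unsupervised embedding generator and caches per-row embeddings under Results/predictions; Step 2 merges those embeddings into the raw data and retrains XGB and/or ResN on the augmented features. A minimal driver sketch, assuming a base config file named config_base.json (the file name is illustrative, and Step 1 must actually be run between the two calls, e.g. with the TaskRunner from runner.py below):

# Hypothetical driver for the FT -> XGB/ResN two-step workflow.
# config_base.json is an illustrative path; adjust to your project layout.
import json
from ins_pricing.frontend.ft_workflow import FTWorkflowHelper

helper = FTWorkflowHelper()

# Step 1: configure FT as an unsupervised embedding generator.
with open("config_base.json", "r", encoding="utf-8") as f:
    base_config = json.load(f)
helper.prepare_step1_config(base_config, use_ddp=True, nproc_per_node=2)
paths = helper.save_configs(output_dir=".")
# ... run Step 1 training with paths['ft_step1'] here; it caches
# <model>_ft_emb_train.csv / _test.csv under Results/predictions.

# Step 2: merge cached embeddings into the raw data and build XGB/ResN configs.
xgb_cfg, resn_cfg = helper.generate_step2_configs(
    step1_config_path=paths['ft_step1'],
    target_models=['xgb', 'resn'],
)
paths = helper.save_configs(output_dir=".")
# paths now also maps 'xgb_step2' and 'resn_step2' to ready-to-run config files.
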
ins_pricing/frontend/runner.py
@@ -0,0 +1,388 @@
+ """
+ Unified Task Runner with Real-time Logging
+ Executes model training, explanation, plotting, and other tasks based on config.
+ """
+
+ import sys
+ import threading
+ import queue
+ import time
+ import json
+ import subprocess
+ from pathlib import Path
+ from typing import Generator, Optional, Dict, Any, List, Sequence, Tuple
+ import logging
+
+
+ class LogCapture:
+     """Capture stdout and stderr for real-time display."""
+
+     def __init__(self):
+         self.log_queue = queue.Queue()
+         self.stop_flag = threading.Event()
+
+     def write(self, text: str):
+         """Write method for capturing output."""
+         if text and text.strip():
+             self.log_queue.put(text)
+
+     def flush(self):
+         """Flush method (required for file-like objects)."""
+         pass
+
+
+ class TaskRunner:
+     """
+     Run model tasks (training, explain, plotting, etc.) and capture logs.
+
+     Supports all task modes defined in config.runner.mode:
+     - entry: Standard model training
+     - explain: Model explanation (permutation, SHAP, etc.)
+     - incremental: Incremental training
+     - watchdog: Watchdog mode for monitoring
+     """
+
+     def __init__(self):
+         self.task_thread = None
+         self.log_capture = None
+
+     def _detect_task_mode(self, config_path: str) -> str:
+         """
+         Detect the task mode from config file.
+
+         Args:
+             config_path: Path to configuration JSON file
+
+         Returns:
+             Task mode string (e.g., 'entry', 'explain', 'incremental', 'watchdog')
+         """
+         try:
+             with open(config_path, 'r', encoding='utf-8') as f:
+                 config = json.load(f)
+
+             runner_config = config.get('runner', {})
+             mode = runner_config.get('mode', 'entry')
+             return str(mode).lower()
+
+         except Exception as e:
+             print(f"Warning: Could not detect task mode, defaulting to 'entry': {e}")
+             return 'entry'
+
+     def _build_cmd_from_config(self, config_path: str) -> Tuple[List[str], str]:
+         """
+         Build the command to execute based on config.runner.mode, mirroring
+         ins_pricing.cli.utils.notebook_utils.run_from_config behavior.
+         """
+         from ins_pricing.cli.utils.cli_config import set_env
+         from ins_pricing.cli.utils.notebook_utils import (
+             build_bayesopt_entry_cmd,
+             build_incremental_cmd,
+             build_explain_cmd,
+             wrap_with_watchdog,
+         )
+
+         cfg_path = Path(config_path).resolve()
+         raw = json.loads(cfg_path.read_text(encoding="utf-8", errors="replace"))
+         set_env(raw.get("env", {}))
+         runner = dict(raw.get("runner") or {})
+
+         mode = str(runner.get("mode") or "entry").strip().lower()
+         use_watchdog = bool(runner.get("use_watchdog", False))
+         if mode == "watchdog":
+             use_watchdog = True
+             mode = "entry"
+
+         idle_seconds = int(runner.get("idle_seconds", 7200))
+         max_restarts = int(runner.get("max_restarts", 50))
+         restart_delay_seconds = int(runner.get("restart_delay_seconds", 10))
+
+         if mode == "incremental":
+             inc_args = runner.get("incremental_args") or []
+             if not isinstance(inc_args, list):
+                 raise ValueError("config.runner.incremental_args must be a list of strings.")
+             cmd = build_incremental_cmd(cfg_path, extra_args=[str(x) for x in inc_args])
+             if use_watchdog:
+                 cmd = wrap_with_watchdog(
+                     cmd,
+                     idle_seconds=idle_seconds,
+                     max_restarts=max_restarts,
+                     restart_delay_seconds=restart_delay_seconds,
+                 )
+             return cmd, "incremental"
+
+         if mode == "explain":
+             exp_args = runner.get("explain_args") or []
+             if not isinstance(exp_args, list):
+                 raise ValueError("config.runner.explain_args must be a list of strings.")
+             cmd = build_explain_cmd(cfg_path, extra_args=[str(x) for x in exp_args])
+             if use_watchdog:
+                 cmd = wrap_with_watchdog(
+                     cmd,
+                     idle_seconds=idle_seconds,
+                     max_restarts=max_restarts,
+                     restart_delay_seconds=restart_delay_seconds,
+                 )
+             return cmd, "explain"
+
+         if mode != "entry":
+             raise ValueError(
+                 f"Unsupported runner.mode={mode!r}, expected 'entry', 'incremental', or 'explain'."
+             )
+
+         model_keys = runner.get("model_keys") or raw.get("model_keys") or ["ft"]
+         if not isinstance(model_keys, list):
+             raise ValueError("runner.model_keys must be a list of strings.")
+         nproc_per_node = int(runner.get("nproc_per_node", 1))
+         max_evals = int(runner.get("max_evals", raw.get("max_evals", 50)))
+         plot_curves = bool(runner.get("plot_curves", raw.get("plot_curves", True)))
+         ft_role = runner.get("ft_role", raw.get("ft_role"))
+
+         extra_args: List[str] = ["--max-evals", str(max_evals)]
+         if plot_curves:
+             extra_args.append("--plot-curves")
+         if ft_role:
+             extra_args += ["--ft-role", str(ft_role)]
+
+         cmd = build_bayesopt_entry_cmd(
+             cfg_path,
+             model_keys=[str(x) for x in model_keys],
+             nproc_per_node=nproc_per_node,
+             extra_args=extra_args,
+         )
+         if use_watchdog:
+             cmd = wrap_with_watchdog(
+                 cmd,
+                 idle_seconds=idle_seconds,
+                 max_restarts=max_restarts,
+                 restart_delay_seconds=restart_delay_seconds,
+             )
+         return cmd, "entry"
+
+     def run_task(self, config_path: str) -> Generator[str, None, None]:
+         """
+         Run task based on config file with real-time log capture.
+
+         This method automatically detects the task mode from the config file
+         (training, explain, plotting, etc.) and runs the appropriate task.
+
+         Args:
+             config_path: Path to configuration JSON file
+
+         Yields:
+             Log lines as they are generated
+         """
+         self.log_capture = LogCapture()
+
+         # Configure logging to capture both file and stream output
+         log_handler = logging.StreamHandler(self.log_capture)
+         log_handler.setLevel(logging.INFO)
+         formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+         log_handler.setFormatter(formatter)
+
+         # Add handler to root logger
+         root_logger = logging.getLogger()
+         original_handlers = root_logger.handlers.copy()
+         root_logger.addHandler(log_handler)
+
+         # Store original stdout/stderr
+         original_stdout = sys.stdout
+         original_stderr = sys.stderr
+
+         try:
+             # Detect task mode
+             task_mode = self._detect_task_mode(config_path)
+
+             # Start task in separate thread
+             exception_holder = []
+
+             def task_worker():
+                 try:
+                     sys.stdout = self.log_capture
+                     sys.stderr = self.log_capture
+
+                     # Log start
+                     cmd, task_mode = self._build_cmd_from_config(config_path)
+                     print(f"Starting task [{task_mode}] with config: {config_path}")
+                     print("=" * 80)
+
+                     # Run subprocess with streamed output
+                     proc = subprocess.Popen(
+                         cmd,
+                         stdout=subprocess.PIPE,
+                         stderr=subprocess.STDOUT,
+                         text=True,
+                         bufsize=1,
+                         cwd=str(Path(config_path).resolve().parent),
+                     )
+                     if proc.stdout is not None:
+                         for line in proc.stdout:
+                             print(line.rstrip())
+                     return_code = proc.wait()
+                     if return_code != 0:
+                         raise RuntimeError(f"Task exited with code {return_code}")
+
+                     print("=" * 80)
+                     print(f"Task [{task_mode}] completed successfully!")
+
+                 except Exception as e:
+                     exception_holder.append(e)
+                     print(f"Error during task execution: {str(e)}")
+                     import traceback
+                     print(traceback.format_exc())
+
+                 finally:
+                     sys.stdout = original_stdout
+                     sys.stderr = original_stderr
+
+             self.task_thread = threading.Thread(target=task_worker, daemon=True)
+             self.task_thread.start()
+
+             # Yield log lines as they come in
+             last_update = time.time()
+             while self.task_thread.is_alive() or not self.log_capture.log_queue.empty():
+                 try:
+                     # Try to get log with timeout
+                     log_line = self.log_capture.log_queue.get(timeout=0.1)
+                     yield log_line
+                     last_update = time.time()
+
+                 except queue.Empty:
+                     # Send heartbeat every 5 seconds
+                     if time.time() - last_update > 5:
+                         yield "."
+                         last_update = time.time()
+                     continue
+
+             # Wait for thread to complete
+             self.task_thread.join(timeout=1)
+
+             # Check for exceptions
+             if exception_holder:
+                 raise exception_holder[0]
+
+         finally:
+             # Restore original stdout/stderr
+             sys.stdout = original_stdout
+             sys.stderr = original_stderr
+
+             # Restore original logging handlers
+             root_logger.handlers = original_handlers
+
+     def run_callable(self, func, *args, **kwargs) -> Generator[str, None, None]:
+         """Run an in-process callable and stream stdout/stderr."""
+         self.log_capture = LogCapture()
+
+         log_handler = logging.StreamHandler(self.log_capture)
+         log_handler.setLevel(logging.INFO)
+         formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+         log_handler.setFormatter(formatter)
+
+         root_logger = logging.getLogger()
+         original_handlers = root_logger.handlers.copy()
+         root_logger.addHandler(log_handler)
+
+         original_stdout = sys.stdout
+         original_stderr = sys.stderr
+
+         try:
+             exception_holder = []
+
+             def task_worker():
+                 try:
+                     sys.stdout = self.log_capture
+                     sys.stderr = self.log_capture
+                     func(*args, **kwargs)
+                 except Exception as e:
+                     exception_holder.append(e)
+                     print(f"Error during task execution: {str(e)}")
+                     import traceback
+                     print(traceback.format_exc())
+                 finally:
+                     sys.stdout = original_stdout
+                     sys.stderr = original_stderr
+
+             self.task_thread = threading.Thread(target=task_worker, daemon=True)
+             self.task_thread.start()
+
+             last_update = time.time()
+             while self.task_thread.is_alive() or not self.log_capture.log_queue.empty():
+                 try:
+                     log_line = self.log_capture.log_queue.get(timeout=0.1)
+                     yield log_line
+                     last_update = time.time()
+                 except queue.Empty:
+                     if time.time() - last_update > 5:
+                         yield "."
+                         last_update = time.time()
+                     continue
+
+             self.task_thread.join(timeout=1)
+             if exception_holder:
+                 raise exception_holder[0]
+         finally:
+             sys.stdout = original_stdout
+             sys.stderr = original_stderr
+             root_logger.handlers = original_handlers
+
+     def stop_task(self):
+         """Stop the current task process."""
+         if self.log_capture:
+             self.log_capture.stop_flag.set()
+
+         if self.task_thread and self.task_thread.is_alive():
+             # Note: Thread.join() will wait for completion
+             # For forceful termination, you may need to use process-based approach
+             self.task_thread.join(timeout=5)
+
+
+ # Backward compatibility aliases
+ TrainingRunner = TaskRunner
+
+
+ class StreamToLogger:
+     """
+     Fake file-like stream object that redirects writes to a logger instance.
+     """
+
+     def __init__(self, logger, log_level=logging.INFO):
+         self.logger = logger
+         self.log_level = log_level
+         self.linebuf = ''
+
+     def write(self, buf):
+         for line in buf.rstrip().splitlines():
+             self.logger.log(self.log_level, line.rstrip())
+
+     def flush(self):
+         pass
+
+
+ def setup_logger(name: str = "task") -> logging.Logger:
+     """
+     Set up a logger for task execution.
+
+     Args:
+         name: Logger name
+
+     Returns:
+         Configured logger instance
+     """
+     logger = logging.getLogger(name)
+     logger.setLevel(logging.INFO)
+
+     # Create console handler
+     console_handler = logging.StreamHandler()
+     console_handler.setLevel(logging.INFO)
+
+     # Create formatter
+     formatter = logging.Formatter(
+         '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+         datefmt='%Y-%m-%d %H:%M:%S'
+     )
+     console_handler.setFormatter(formatter)
+
+     # Add handler to logger
+     if not logger.handlers:
+         logger.addHandler(console_handler)
+
+     return logger
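
TaskRunner.run_task is a generator: it builds the command from config.runner.mode, launches it as a subprocess, and yields captured log lines (with "." as a heartbeat after 5 idle seconds). A minimal sketch of driving it, assuming a config file named config_ft_step1_unsupervised.json in the working directory (the name matches what FTWorkflowHelper.save_configs writes; any config with a runner section works):

# Hypothetical usage of TaskRunner: stream logs from a config-driven run.
# The config path is illustrative; its runner.mode selects the task
# ('entry', 'incremental', 'explain', or 'watchdog').
from ins_pricing.frontend.runner import TaskRunner

runner = TaskRunner()
for line in runner.run_task("config_ft_step1_unsupervised.json"):
    # Each yielded item is a captured log line; "." is a heartbeat marker.
    print(line, end="" if line.endswith("\n") else "\n")
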