ins-pricing 0.5.0__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. ins_pricing/cli/BayesOpt_entry.py +15 -5
  2. ins_pricing/cli/BayesOpt_incremental.py +43 -10
  3. ins_pricing/cli/Explain_Run.py +16 -5
  4. ins_pricing/cli/Explain_entry.py +29 -8
  5. ins_pricing/cli/Pricing_Run.py +16 -5
  6. ins_pricing/cli/bayesopt_entry_runner.py +45 -12
  7. ins_pricing/cli/utils/bootstrap.py +23 -0
  8. ins_pricing/cli/utils/cli_config.py +34 -15
  9. ins_pricing/cli/utils/import_resolver.py +14 -14
  10. ins_pricing/cli/utils/notebook_utils.py +120 -106
  11. ins_pricing/cli/watchdog_run.py +15 -5
  12. ins_pricing/frontend/app.py +132 -61
  13. ins_pricing/frontend/config_builder.py +33 -0
  14. ins_pricing/frontend/example_config.json +11 -0
  15. ins_pricing/frontend/runner.py +340 -388
  16. ins_pricing/modelling/README.md +1 -1
  17. ins_pricing/modelling/__init__.py +10 -10
  18. ins_pricing/modelling/bayesopt/README.md +29 -11
  19. ins_pricing/modelling/bayesopt/config_components.py +12 -0
  20. ins_pricing/modelling/bayesopt/config_preprocess.py +50 -13
  21. ins_pricing/modelling/bayesopt/core.py +47 -19
  22. ins_pricing/modelling/bayesopt/model_plotting_mixin.py +20 -14
  23. ins_pricing/modelling/bayesopt/models/model_ft_components.py +349 -342
  24. ins_pricing/modelling/bayesopt/models/model_ft_trainer.py +11 -5
  25. ins_pricing/modelling/bayesopt/models/model_gnn.py +20 -14
  26. ins_pricing/modelling/bayesopt/models/model_resn.py +9 -3
  27. ins_pricing/modelling/bayesopt/trainers/trainer_base.py +62 -50
  28. ins_pricing/modelling/bayesopt/trainers/trainer_ft.py +61 -53
  29. ins_pricing/modelling/bayesopt/trainers/trainer_glm.py +9 -3
  30. ins_pricing/modelling/bayesopt/trainers/trainer_gnn.py +40 -32
  31. ins_pricing/modelling/bayesopt/trainers/trainer_resn.py +36 -24
  32. ins_pricing/modelling/bayesopt/trainers/trainer_xgb.py +240 -37
  33. ins_pricing/modelling/bayesopt/utils/distributed_utils.py +193 -186
  34. ins_pricing/modelling/bayesopt/utils/torch_trainer_mixin.py +23 -10
  35. ins_pricing/pricing/factors.py +67 -56
  36. ins_pricing/setup.py +1 -1
  37. ins_pricing/utils/__init__.py +7 -6
  38. ins_pricing/utils/device.py +45 -24
  39. ins_pricing/utils/logging.py +34 -1
  40. ins_pricing/utils/profiling.py +8 -4
  41. {ins_pricing-0.5.0.dist-info → ins_pricing-0.5.3.dist-info}/METADATA +182 -182
  42. {ins_pricing-0.5.0.dist-info → ins_pricing-0.5.3.dist-info}/RECORD +44 -43
  43. {ins_pricing-0.5.0.dist-info → ins_pricing-0.5.3.dist-info}/WHEEL +0 -0
  44. {ins_pricing-0.5.0.dist-info → ins_pricing-0.5.3.dist-info}/top_level.txt +0 -0
@@ -1,388 +1,340 @@
1
- """
2
- Unified Task Runner with Real-time Logging
3
- Executes model training, explanation, plotting, and other tasks based on config.
4
- """
5
-
6
- import sys
7
- import threading
8
- import queue
9
- import time
10
- import json
11
- import subprocess
12
- from pathlib import Path
13
- from typing import Generator, Optional, Dict, Any, List, Sequence, Tuple
14
- import logging
15
-
16
-
17
- class LogCapture:
18
- """Capture stdout and stderr for real-time display."""
19
-
20
- def __init__(self):
21
- self.log_queue = queue.Queue()
22
- self.stop_flag = threading.Event()
23
-
24
- def write(self, text: str):
25
- """Write method for capturing output."""
26
- if text and text.strip():
27
- self.log_queue.put(text)
28
-
29
- def flush(self):
30
- """Flush method (required for file-like objects)."""
31
- pass
32
-
33
-
34
- class TaskRunner:
35
- """
36
- Run model tasks (training, explain, plotting, etc.) and capture logs.
37
-
38
- Supports all task modes defined in config.runner.mode:
39
- - entry: Standard model training
40
- - explain: Model explanation (permutation, SHAP, etc.)
41
- - incremental: Incremental training
42
- - watchdog: Watchdog mode for monitoring
43
- """
44
-
45
- def __init__(self):
46
- self.task_thread = None
47
- self.log_capture = None
48
-
49
- def _detect_task_mode(self, config_path: str) -> str:
50
- """
51
- Detect the task mode from config file.
52
-
53
- Args:
54
- config_path: Path to configuration JSON file
55
-
56
- Returns:
57
- Task mode string (e.g., 'entry', 'explain', 'incremental', 'watchdog')
58
- """
59
- try:
60
- with open(config_path, 'r', encoding='utf-8') as f:
61
- config = json.load(f)
62
-
63
- runner_config = config.get('runner', {})
64
- mode = runner_config.get('mode', 'entry')
65
- return str(mode).lower()
66
-
67
- except Exception as e:
68
- print(f"Warning: Could not detect task mode, defaulting to 'entry': {e}")
69
- return 'entry'
70
-
71
- def _build_cmd_from_config(self, config_path: str) -> Tuple[List[str], str]:
72
- """
73
- Build the command to execute based on config.runner.mode, mirroring
74
- ins_pricing.cli.utils.notebook_utils.run_from_config behavior.
75
- """
76
- from ins_pricing.cli.utils.cli_config import set_env
77
- from ins_pricing.cli.utils.notebook_utils import (
78
- build_bayesopt_entry_cmd,
79
- build_incremental_cmd,
80
- build_explain_cmd,
81
- wrap_with_watchdog,
82
- )
83
-
84
- cfg_path = Path(config_path).resolve()
85
- raw = json.loads(cfg_path.read_text(encoding="utf-8", errors="replace"))
86
- set_env(raw.get("env", {}))
87
- runner = dict(raw.get("runner") or {})
88
-
89
- mode = str(runner.get("mode") or "entry").strip().lower()
90
- use_watchdog = bool(runner.get("use_watchdog", False))
91
- if mode == "watchdog":
92
- use_watchdog = True
93
- mode = "entry"
94
-
95
- idle_seconds = int(runner.get("idle_seconds", 7200))
96
- max_restarts = int(runner.get("max_restarts", 50))
97
- restart_delay_seconds = int(runner.get("restart_delay_seconds", 10))
98
-
99
- if mode == "incremental":
100
- inc_args = runner.get("incremental_args") or []
101
- if not isinstance(inc_args, list):
102
- raise ValueError("config.runner.incremental_args must be a list of strings.")
103
- cmd = build_incremental_cmd(cfg_path, extra_args=[str(x) for x in inc_args])
104
- if use_watchdog:
105
- cmd = wrap_with_watchdog(
106
- cmd,
107
- idle_seconds=idle_seconds,
108
- max_restarts=max_restarts,
109
- restart_delay_seconds=restart_delay_seconds,
110
- )
111
- return cmd, "incremental"
112
-
113
- if mode == "explain":
114
- exp_args = runner.get("explain_args") or []
115
- if not isinstance(exp_args, list):
116
- raise ValueError("config.runner.explain_args must be a list of strings.")
117
- cmd = build_explain_cmd(cfg_path, extra_args=[str(x) for x in exp_args])
118
- if use_watchdog:
119
- cmd = wrap_with_watchdog(
120
- cmd,
121
- idle_seconds=idle_seconds,
122
- max_restarts=max_restarts,
123
- restart_delay_seconds=restart_delay_seconds,
124
- )
125
- return cmd, "explain"
126
-
127
- if mode != "entry":
128
- raise ValueError(
129
- f"Unsupported runner.mode={mode!r}, expected 'entry', 'incremental', or 'explain'."
130
- )
131
-
132
- model_keys = runner.get("model_keys") or raw.get("model_keys") or ["ft"]
133
- if not isinstance(model_keys, list):
134
- raise ValueError("runner.model_keys must be a list of strings.")
135
- nproc_per_node = int(runner.get("nproc_per_node", 1))
136
- max_evals = int(runner.get("max_evals", raw.get("max_evals", 50)))
137
- plot_curves = bool(runner.get("plot_curves", raw.get("plot_curves", True)))
138
- ft_role = runner.get("ft_role", raw.get("ft_role"))
139
-
140
- extra_args: List[str] = ["--max-evals", str(max_evals)]
141
- if plot_curves:
142
- extra_args.append("--plot-curves")
143
- if ft_role:
144
- extra_args += ["--ft-role", str(ft_role)]
145
-
146
- cmd = build_bayesopt_entry_cmd(
147
- cfg_path,
148
- model_keys=[str(x) for x in model_keys],
149
- nproc_per_node=nproc_per_node,
150
- extra_args=extra_args,
151
- )
152
- if use_watchdog:
153
- cmd = wrap_with_watchdog(
154
- cmd,
155
- idle_seconds=idle_seconds,
156
- max_restarts=max_restarts,
157
- restart_delay_seconds=restart_delay_seconds,
158
- )
159
- return cmd, "entry"
160
-
161
- def run_task(self, config_path: str) -> Generator[str, None, None]:
162
- """
163
- Run task based on config file with real-time log capture.
164
-
165
- This method automatically detects the task mode from the config file
166
- (training, explain, plotting, etc.) and runs the appropriate task.
167
-
168
- Args:
169
- config_path: Path to configuration JSON file
170
-
171
- Yields:
172
- Log lines as they are generated
173
- """
174
- self.log_capture = LogCapture()
175
-
176
- # Configure logging to capture both file and stream output
177
- log_handler = logging.StreamHandler(self.log_capture)
178
- log_handler.setLevel(logging.INFO)
179
- formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
180
- log_handler.setFormatter(formatter)
181
-
182
- # Add handler to root logger
183
- root_logger = logging.getLogger()
184
- original_handlers = root_logger.handlers.copy()
185
- root_logger.addHandler(log_handler)
186
-
187
- # Store original stdout/stderr
188
- original_stdout = sys.stdout
189
- original_stderr = sys.stderr
190
-
191
- try:
192
- # Detect task mode
193
- task_mode = self._detect_task_mode(config_path)
194
-
195
- # Start task in separate thread
196
- exception_holder = []
197
-
198
- def task_worker():
199
- try:
200
- sys.stdout = self.log_capture
201
- sys.stderr = self.log_capture
202
-
203
- # Log start
204
- cmd, task_mode = self._build_cmd_from_config(config_path)
205
- print(f"Starting task [{task_mode}] with config: {config_path}")
206
- print("=" * 80)
207
-
208
- # Run subprocess with streamed output
209
- proc = subprocess.Popen(
210
- cmd,
211
- stdout=subprocess.PIPE,
212
- stderr=subprocess.STDOUT,
213
- text=True,
214
- bufsize=1,
215
- cwd=str(Path(config_path).resolve().parent),
216
- )
217
- if proc.stdout is not None:
218
- for line in proc.stdout:
219
- print(line.rstrip())
220
- return_code = proc.wait()
221
- if return_code != 0:
222
- raise RuntimeError(f"Task exited with code {return_code}")
223
-
224
- print("=" * 80)
225
- print(f"Task [{task_mode}] completed successfully!")
226
-
227
- except Exception as e:
228
- exception_holder.append(e)
229
- print(f"Error during task execution: {str(e)}")
230
- import traceback
231
- print(traceback.format_exc())
232
-
233
- finally:
234
- sys.stdout = original_stdout
235
- sys.stderr = original_stderr
236
-
237
- self.task_thread = threading.Thread(target=task_worker, daemon=True)
238
- self.task_thread.start()
239
-
240
- # Yield log lines as they come in
241
- last_update = time.time()
242
- while self.task_thread.is_alive() or not self.log_capture.log_queue.empty():
243
- try:
244
- # Try to get log with timeout
245
- log_line = self.log_capture.log_queue.get(timeout=0.1)
246
- yield log_line
247
- last_update = time.time()
248
-
249
- except queue.Empty:
250
- # Send heartbeat every 5 seconds
251
- if time.time() - last_update > 5:
252
- yield "."
253
- last_update = time.time()
254
- continue
255
-
256
- # Wait for thread to complete
257
- self.task_thread.join(timeout=1)
258
-
259
- # Check for exceptions
260
- if exception_holder:
261
- raise exception_holder[0]
262
-
263
- finally:
264
- # Restore original stdout/stderr
265
- sys.stdout = original_stdout
266
- sys.stderr = original_stderr
267
-
268
- # Restore original logging handlers
269
- root_logger.handlers = original_handlers
270
-
271
- def run_callable(self, func, *args, **kwargs) -> Generator[str, None, None]:
272
- """Run an in-process callable and stream stdout/stderr."""
273
- self.log_capture = LogCapture()
274
-
275
- log_handler = logging.StreamHandler(self.log_capture)
276
- log_handler.setLevel(logging.INFO)
277
- formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
278
- log_handler.setFormatter(formatter)
279
-
280
- root_logger = logging.getLogger()
281
- original_handlers = root_logger.handlers.copy()
282
- root_logger.addHandler(log_handler)
283
-
284
- original_stdout = sys.stdout
285
- original_stderr = sys.stderr
286
-
287
- try:
288
- exception_holder = []
289
-
290
- def task_worker():
291
- try:
292
- sys.stdout = self.log_capture
293
- sys.stderr = self.log_capture
294
- func(*args, **kwargs)
295
- except Exception as e:
296
- exception_holder.append(e)
297
- print(f"Error during task execution: {str(e)}")
298
- import traceback
299
- print(traceback.format_exc())
300
- finally:
301
- sys.stdout = original_stdout
302
- sys.stderr = original_stderr
303
-
304
- self.task_thread = threading.Thread(target=task_worker, daemon=True)
305
- self.task_thread.start()
306
-
307
- last_update = time.time()
308
- while self.task_thread.is_alive() or not self.log_capture.log_queue.empty():
309
- try:
310
- log_line = self.log_capture.log_queue.get(timeout=0.1)
311
- yield log_line
312
- last_update = time.time()
313
- except queue.Empty:
314
- if time.time() - last_update > 5:
315
- yield "."
316
- last_update = time.time()
317
- continue
318
-
319
- self.task_thread.join(timeout=1)
320
- if exception_holder:
321
- raise exception_holder[0]
322
- finally:
323
- sys.stdout = original_stdout
324
- sys.stderr = original_stderr
325
- root_logger.handlers = original_handlers
326
-
327
- def stop_task(self):
328
- """Stop the current task process."""
329
- if self.log_capture:
330
- self.log_capture.stop_flag.set()
331
-
332
- if self.task_thread and self.task_thread.is_alive():
333
- # Note: Thread.join() will wait for completion
334
- # For forceful termination, you may need to use process-based approach
335
- self.task_thread.join(timeout=5)
336
-
337
-
338
- # Backward compatibility aliases
339
- TrainingRunner = TaskRunner
340
-
341
-
342
- class StreamToLogger:
343
- """
344
- Fake file-like stream object that redirects writes to a logger instance.
345
- """
346
-
347
- def __init__(self, logger, log_level=logging.INFO):
348
- self.logger = logger
349
- self.log_level = log_level
350
- self.linebuf = ''
351
-
352
- def write(self, buf):
353
- for line in buf.rstrip().splitlines():
354
- self.logger.log(self.log_level, line.rstrip())
355
-
356
- def flush(self):
357
- pass
358
-
359
-
360
- def setup_logger(name: str = "task") -> logging.Logger:
361
- """
362
- Set up a logger for task execution.
363
-
364
- Args:
365
- name: Logger name
366
-
367
- Returns:
368
- Configured logger instance
369
- """
370
- logger = logging.getLogger(name)
371
- logger.setLevel(logging.INFO)
372
-
373
- # Create console handler
374
- console_handler = logging.StreamHandler()
375
- console_handler.setLevel(logging.INFO)
376
-
377
- # Create formatter
378
- formatter = logging.Formatter(
379
- '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
380
- datefmt='%Y-%m-%d %H:%M:%S'
381
- )
382
- console_handler.setFormatter(formatter)
383
-
384
- # Add handler to logger
385
- if not logger.handlers:
386
- logger.addHandler(console_handler)
387
-
388
- return logger
1
+ """
2
+ Unified Task Runner with Real-time Logging
3
+ Executes model training, explanation, plotting, and other tasks based on config.
4
+ """
5
+
6
+ import sys
7
+ import os
8
+ import threading
9
+ import queue
10
+ import time
11
+ import json
12
+ import subprocess
13
+ from pathlib import Path
14
+ from typing import Generator, Optional, Dict, Any, List, Sequence, Tuple
15
+ import logging
16
+
17
+ from ins_pricing.utils import get_logger, log_print
18
+
19
+ _logger = get_logger("ins_pricing.frontend.runner")
20
+ _logger.propagate = False
21
+ if not _logger.handlers:
22
+ _handler = logging.StreamHandler()
23
+ _handler.setFormatter(logging.Formatter("%(message)s"))
24
+ _logger.addHandler(_handler)
25
+
26
+
27
+ def _log(*args, **kwargs) -> None:
28
+ log_print(_logger, *args, **kwargs)
29
+
30
+ class LogCapture:
31
+ """Capture stdout and stderr for real-time display."""
32
+
33
+ def __init__(self):
34
+ self.log_queue = queue.Queue()
35
+ self.stop_flag = threading.Event()
36
+
37
+ def write(self, text: str):
38
+ """Write method for capturing output."""
39
+ if text and text.strip():
40
+ self.log_queue.put(text)
41
+
42
+ def flush(self):
43
+ """Flush method (required for file-like objects)."""
44
+ pass
45
+
46
+
47
+ class TaskRunner:
48
+ """
49
+ Run model tasks (training, explain, plotting, etc.) and capture logs.
50
+
51
+ Supports all task modes defined in config.runner.mode:
52
+ - entry: Standard model training
53
+ - explain: Model explanation (permutation, SHAP, etc.)
54
+ - incremental: Incremental training
55
+ - watchdog: Watchdog mode for monitoring
56
+ """
57
+
58
+ def __init__(self):
59
+ self.task_thread = None
60
+ self.log_capture = None
61
+ self._proc: Optional[subprocess.Popen] = None
62
+
63
+ def _detect_task_mode(self, config_path: str) -> str:
64
+ """
65
+ Detect the task mode from config file.
66
+
67
+ Args:
68
+ config_path: Path to configuration JSON file
69
+
70
+ Returns:
71
+ Task mode string (e.g., 'entry', 'explain', 'incremental', 'watchdog')
72
+ """
73
+ try:
74
+ with open(config_path, 'r', encoding='utf-8') as f:
75
+ config = json.load(f)
76
+
77
+ runner_config = config.get('runner', {})
78
+ mode = runner_config.get('mode', 'entry')
79
+ return str(mode).lower()
80
+
81
+ except Exception as e:
82
+ _log(f"Warning: Could not detect task mode, defaulting to 'entry': {e}")
83
+ return 'entry'
84
+
85
+ def _build_cmd_from_config(self, config_path: str) -> Tuple[List[str], str]:
86
+ """Build the command to execute based on config.runner.mode."""
87
+ from ins_pricing.cli.utils.notebook_utils import build_cmd_from_config
88
+
89
+ return build_cmd_from_config(config_path)
90
+
91
+ def run_task(self, config_path: str) -> Generator[str, None, None]:
92
+ """
93
+ Run task based on config file with real-time log capture.
94
+
95
+ This method automatically detects the task mode from the config file
96
+ (training, explain, plotting, etc.) and runs the appropriate task.
97
+
98
+ Args:
99
+ config_path: Path to configuration JSON file
100
+
101
+ Yields:
102
+ Log lines as they are generated
103
+ """
104
+ self.log_capture = LogCapture()
105
+
106
+ # Configure logging to capture both file and stream output
107
+ log_handler = logging.StreamHandler(self.log_capture)
108
+ log_handler.setLevel(logging.INFO)
109
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
110
+ log_handler.setFormatter(formatter)
111
+
112
+ # Add handler to root logger
113
+ root_logger = logging.getLogger()
114
+ original_handlers = root_logger.handlers.copy()
115
+ root_logger.addHandler(log_handler)
116
+
117
+ # Store original stdout/stderr
118
+ original_stdout = sys.stdout
119
+ original_stderr = sys.stderr
120
+
121
+ try:
122
+ # Detect task mode
123
+ task_mode = self._detect_task_mode(config_path)
124
+
125
+ # Start task in separate thread
126
+ exception_holder = []
127
+
128
+ def task_worker():
129
+ try:
130
+ sys.stdout = self.log_capture
131
+ sys.stderr = self.log_capture
132
+
133
+ # Log start
134
+ cmd, task_mode = self._build_cmd_from_config(config_path)
135
+ _log(f"Starting task [{task_mode}] with config: {config_path}")
136
+ _log("=" * 80)
137
+
138
+ # Run subprocess with streamed output
139
+ proc = subprocess.Popen(
140
+ cmd,
141
+ stdout=subprocess.PIPE,
142
+ stderr=subprocess.STDOUT,
143
+ text=True,
144
+ bufsize=1,
145
+ cwd=str(Path(config_path).resolve().parent),
146
+ )
147
+ self._proc = proc
148
+ if proc.stdout is not None:
149
+ for line in proc.stdout:
150
+ _log(line.rstrip())
151
+ return_code = proc.wait()
152
+ if return_code != 0:
153
+ raise RuntimeError(f"Task exited with code {return_code}")
154
+
155
+ _log("=" * 80)
156
+ _log(f"Task [{task_mode}] completed successfully!")
157
+
158
+ except Exception as e:
159
+ exception_holder.append(e)
160
+ _log(f"Error during task execution: {str(e)}")
161
+ import traceback
162
+ _log(traceback.format_exc())
163
+
164
+ finally:
165
+ self._proc = None
166
+ sys.stdout = original_stdout
167
+ sys.stderr = original_stderr
168
+
169
+ self.task_thread = threading.Thread(target=task_worker, daemon=True)
170
+ self.task_thread.start()
171
+
172
+ # Yield log lines as they come in
173
+ last_update = time.time()
174
+ while self.task_thread.is_alive() or not self.log_capture.log_queue.empty():
175
+ try:
176
+ # Try to get log with timeout
177
+ log_line = self.log_capture.log_queue.get(timeout=0.1)
178
+ yield log_line
179
+ last_update = time.time()
180
+
181
+ except queue.Empty:
182
+ # Send heartbeat every 5 seconds
183
+ if time.time() - last_update > 5:
184
+ yield "."
185
+ last_update = time.time()
186
+ continue
187
+
188
+ # Wait for thread to complete
189
+ self.task_thread.join(timeout=1)
190
+
191
+ # Check for exceptions
192
+ if exception_holder:
193
+ raise exception_holder[0]
194
+
195
+ finally:
196
+ # Restore original stdout/stderr
197
+ sys.stdout = original_stdout
198
+ sys.stderr = original_stderr
199
+
200
+ # Restore original logging handlers
201
+ root_logger.handlers = original_handlers
202
+
203
+ def run_callable(self, func, *args, **kwargs) -> Generator[str, None, None]:
204
+ """Run an in-process callable and stream stdout/stderr."""
205
+ self.log_capture = LogCapture()
206
+
207
+ log_handler = logging.StreamHandler(self.log_capture)
208
+ log_handler.setLevel(logging.INFO)
209
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
210
+ log_handler.setFormatter(formatter)
211
+
212
+ root_logger = logging.getLogger()
213
+ original_handlers = root_logger.handlers.copy()
214
+ root_logger.addHandler(log_handler)
215
+
216
+ original_stdout = sys.stdout
217
+ original_stderr = sys.stderr
218
+
219
+ try:
220
+ exception_holder = []
221
+
222
+ def task_worker():
223
+ try:
224
+ sys.stdout = self.log_capture
225
+ sys.stderr = self.log_capture
226
+ func(*args, **kwargs)
227
+ except Exception as e:
228
+ exception_holder.append(e)
229
+ _log(f"Error during task execution: {str(e)}")
230
+ import traceback
231
+ _log(traceback.format_exc())
232
+ finally:
233
+ sys.stdout = original_stdout
234
+ sys.stderr = original_stderr
235
+
236
+ self.task_thread = threading.Thread(target=task_worker, daemon=True)
237
+ self.task_thread.start()
238
+
239
+ last_update = time.time()
240
+ while self.task_thread.is_alive() or not self.log_capture.log_queue.empty():
241
+ try:
242
+ log_line = self.log_capture.log_queue.get(timeout=0.1)
243
+ yield log_line
244
+ last_update = time.time()
245
+ except queue.Empty:
246
+ if time.time() - last_update > 5:
247
+ yield "."
248
+ last_update = time.time()
249
+ continue
250
+
251
+ self.task_thread.join(timeout=1)
252
+ if exception_holder:
253
+ raise exception_holder[0]
254
+ finally:
255
+ sys.stdout = original_stdout
256
+ sys.stderr = original_stderr
257
+ root_logger.handlers = original_handlers
258
+
259
+ def stop_task(self):
260
+ """Stop the current task process."""
261
+ if self.log_capture:
262
+ self.log_capture.stop_flag.set()
263
+
264
+ proc = self._proc
265
+ if proc is not None and proc.poll() is None:
266
+ try:
267
+ if os.name == "nt":
268
+ subprocess.run(
269
+ ["taskkill", "/PID", str(proc.pid), "/T", "/F"],
270
+ stdout=subprocess.DEVNULL,
271
+ stderr=subprocess.DEVNULL,
272
+ check=False,
273
+ )
274
+ else:
275
+ proc.terminate()
276
+ try:
277
+ proc.wait(timeout=5)
278
+ except Exception:
279
+ proc.kill()
280
+ except Exception:
281
+ try:
282
+ proc.kill()
283
+ except Exception:
284
+ pass
285
+
286
+ if self.task_thread and self.task_thread.is_alive():
287
+ self.task_thread.join(timeout=5)
288
+
289
+
290
+ # Backward compatibility aliases
291
+ TrainingRunner = TaskRunner
292
+
293
+
294
+ class StreamToLogger:
295
+ """
296
+ Fake file-like stream object that redirects writes to a logger instance.
297
+ """
298
+
299
+ def __init__(self, logger, log_level=logging.INFO):
300
+ self.logger = logger
301
+ self.log_level = log_level
302
+ self.linebuf = ''
303
+
304
+ def write(self, buf):
305
+ for line in buf.rstrip().splitlines():
306
+ self.logger.log(self.log_level, line.rstrip())
307
+
308
+ def flush(self):
309
+ pass
310
+
311
+
312
+ def setup_logger(name: str = "task") -> logging.Logger:
313
+ """
314
+ Set up a logger for task execution.
315
+
316
+ Args:
317
+ name: Logger name
318
+
319
+ Returns:
320
+ Configured logger instance
321
+ """
322
+ logger = logging.getLogger(name)
323
+ logger.setLevel(logging.INFO)
324
+
325
+ # Create console handler
326
+ console_handler = logging.StreamHandler()
327
+ console_handler.setLevel(logging.INFO)
328
+
329
+ # Create formatter
330
+ formatter = logging.Formatter(
331
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
332
+ datefmt='%Y-%m-%d %H:%M:%S'
333
+ )
334
+ console_handler.setFormatter(formatter)
335
+
336
+ # Add handler to logger
337
+ if not logger.handlers:
338
+ logger.addHandler(console_handler)
339
+
340
+ return logger