ins-pricing 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/frontend/README.md +31 -0
- ins_pricing/frontend/app.py +915 -877
- ins_pricing/setup.py +1 -1
- {ins_pricing-0.4.0.dist-info → ins_pricing-0.4.1.dist-info}/METADATA +1 -1
- {ins_pricing-0.4.0.dist-info → ins_pricing-0.4.1.dist-info}/RECORD +7 -7
- {ins_pricing-0.4.0.dist-info → ins_pricing-0.4.1.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.0.dist-info → ins_pricing-0.4.1.dist-info}/top_level.txt +0 -0
ins_pricing/frontend/app.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Insurance Pricing Model Training Frontend
|
|
3
|
-
A Gradio-based web interface for configuring and running insurance pricing models.
|
|
4
|
-
"""
|
|
5
|
-
|
|
1
|
+
"""
|
|
2
|
+
Insurance Pricing Model Training Frontend
|
|
3
|
+
A Gradio-based web interface for configuring and running insurance pricing models.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import os
|
|
6
7
|
import platform
|
|
7
8
|
import subprocess
|
|
8
9
|
from ins_pricing.frontend.example_workflows import (
|
|
@@ -12,13 +13,12 @@ from ins_pricing.frontend.example_workflows import (
|
|
|
12
13
|
run_predict_ft_embed,
|
|
13
14
|
run_pre_oneway,
|
|
14
15
|
)
|
|
15
|
-
from ins_pricing.frontend.ft_workflow import FTWorkflowHelper
|
|
16
|
-
from ins_pricing.frontend.runner import TaskRunner
|
|
16
|
+
from ins_pricing.frontend.ft_workflow import FTWorkflowHelper
|
|
17
|
+
from ins_pricing.frontend.runner import TaskRunner
|
|
17
18
|
from ins_pricing.frontend.config_builder import ConfigBuilder
|
|
18
|
-
import gradio as gr
|
|
19
19
|
import json
|
|
20
20
|
import sys
|
|
21
|
-
import
|
|
21
|
+
import inspect
|
|
22
22
|
from pathlib import Path
|
|
23
23
|
from typing import Optional, Dict, Any, Callable, Iterable, Tuple
|
|
24
24
|
import threading
|
|
@@ -28,876 +28,914 @@ import time
|
|
|
28
28
|
# Add parent directory to path to import ins_pricing modules
|
|
29
29
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
30
30
|
|
|
31
|
+
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
|
|
32
|
+
os.environ.setdefault("GRADIO_TELEMETRY_ENABLED", "False")
|
|
33
|
+
os.environ.setdefault("GRADIO_CHECK_VERSION", "False")
|
|
34
|
+
os.environ.setdefault("GRADIO_VERSION_CHECK", "False")
|
|
31
35
|
|
|
32
|
-
|
|
33
|
-
"""Main application class for the insurance pricing model tasks interface."""
|
|
34
|
-
|
|
35
|
-
def __init__(self):
|
|
36
|
-
self.config_builder = ConfigBuilder()
|
|
37
|
-
self.runner = TaskRunner()
|
|
38
|
-
self.ft_workflow = FTWorkflowHelper()
|
|
39
|
-
self.current_config = {}
|
|
40
|
-
self.current_step1_config = None
|
|
41
|
-
self.current_config_path: Optional[Path] = None
|
|
42
|
-
self.current_config_dir: Optional[Path] = None
|
|
43
|
-
|
|
44
|
-
def load_json_config(self, file_path) -> tuple[str, Dict[str, Any], str]:
|
|
45
|
-
"""Load configuration from uploaded JSON file."""
|
|
46
|
-
if not file_path:
|
|
47
|
-
return "No file uploaded", {}, ""
|
|
48
|
-
|
|
49
|
-
try:
|
|
50
|
-
path = Path(file_path).resolve()
|
|
51
|
-
with open(path, 'r', encoding='utf-8') as f:
|
|
52
|
-
config = json.load(f)
|
|
53
|
-
self.current_config = config
|
|
54
|
-
self.current_config_path = path
|
|
55
|
-
self.current_config_dir = path.parent
|
|
56
|
-
config_json = json.dumps(config, indent=2, ensure_ascii=False)
|
|
57
|
-
return f"Configuration loaded successfully from {path.name}", config, config_json
|
|
58
|
-
except Exception as e:
|
|
59
|
-
return f"Error loading config: {str(e)}", {}, ""
|
|
60
|
-
|
|
61
|
-
def build_config_from_ui(
|
|
62
|
-
self,
|
|
63
|
-
data_dir: str,
|
|
64
|
-
model_list: str,
|
|
65
|
-
model_categories: str,
|
|
66
|
-
target: str,
|
|
67
|
-
weight: str,
|
|
68
|
-
feature_list: str,
|
|
69
|
-
categorical_features: str,
|
|
70
|
-
task_type: str,
|
|
71
|
-
prop_test: float,
|
|
72
|
-
holdout_ratio: float,
|
|
73
|
-
val_ratio: float,
|
|
74
|
-
split_strategy: str,
|
|
75
|
-
rand_seed: int,
|
|
76
|
-
epochs: int,
|
|
77
|
-
output_dir: str,
|
|
78
|
-
use_gpu: bool,
|
|
79
|
-
model_keys: str,
|
|
80
|
-
max_evals: int,
|
|
81
|
-
xgb_max_depth_max: int,
|
|
82
|
-
xgb_n_estimators_max: int,
|
|
83
|
-
) -> tuple[str, str]:
|
|
84
|
-
"""Build configuration from UI parameters."""
|
|
85
|
-
try:
|
|
86
|
-
# Parse comma-separated lists
|
|
87
|
-
model_list = [x.strip()
|
|
88
|
-
for x in model_list.split(',') if x.strip()]
|
|
89
|
-
model_categories = [x.strip()
|
|
90
|
-
for x in model_categories.split(',') if x.strip()]
|
|
91
|
-
feature_list = [x.strip()
|
|
92
|
-
for x in feature_list.split(',') if x.strip()]
|
|
93
|
-
categorical_features = [
|
|
94
|
-
x.strip() for x in categorical_features.split(',') if x.strip()]
|
|
95
|
-
model_keys = [x.strip()
|
|
96
|
-
for x in model_keys.split(',') if x.strip()]
|
|
97
|
-
|
|
98
|
-
config = self.config_builder.build_config(
|
|
99
|
-
data_dir=data_dir,
|
|
100
|
-
model_list=model_list,
|
|
101
|
-
model_categories=model_categories,
|
|
102
|
-
target=target,
|
|
103
|
-
weight=weight,
|
|
104
|
-
feature_list=feature_list,
|
|
105
|
-
categorical_features=categorical_features,
|
|
106
|
-
task_type=task_type,
|
|
107
|
-
prop_test=prop_test,
|
|
108
|
-
holdout_ratio=holdout_ratio,
|
|
109
|
-
val_ratio=val_ratio,
|
|
110
|
-
split_strategy=split_strategy,
|
|
111
|
-
rand_seed=rand_seed,
|
|
112
|
-
epochs=epochs,
|
|
113
|
-
output_dir=output_dir,
|
|
114
|
-
use_gpu=use_gpu,
|
|
115
|
-
model_keys=model_keys,
|
|
116
|
-
max_evals=max_evals,
|
|
117
|
-
xgb_max_depth_max=xgb_max_depth_max,
|
|
118
|
-
xgb_n_estimators_max=xgb_n_estimators_max,
|
|
119
|
-
)
|
|
120
|
-
|
|
121
|
-
is_valid, msg = self.config_builder.validate_config(config)
|
|
122
|
-
if not is_valid:
|
|
123
|
-
return f"Validation failed: {msg}", ""
|
|
124
|
-
|
|
125
|
-
self.current_config = config
|
|
126
|
-
self.current_config_path = None
|
|
127
|
-
self.current_config_dir = None
|
|
128
|
-
config_json = json.dumps(config, indent=2, ensure_ascii=False)
|
|
129
|
-
return "Configuration built successfully", config_json
|
|
130
|
-
|
|
131
|
-
except Exception as e:
|
|
132
|
-
return f"Error building config: {str(e)}", ""
|
|
133
|
-
|
|
134
|
-
def save_config(self, config_json: str, filename: str) -> str:
|
|
135
|
-
"""Save current configuration to file."""
|
|
136
|
-
if not config_json:
|
|
137
|
-
return "No configuration to save"
|
|
138
|
-
|
|
139
|
-
try:
|
|
140
|
-
config_path = Path(filename)
|
|
141
|
-
with open(config_path, 'w', encoding='utf-8') as f:
|
|
142
|
-
json.dump(json.loads(config_json), f,
|
|
143
|
-
indent=2, ensure_ascii=False)
|
|
144
|
-
return f"Configuration saved to {config_path}"
|
|
145
|
-
except Exception as e:
|
|
146
|
-
return f"Error saving config: {str(e)}"
|
|
147
|
-
|
|
148
|
-
def run_training(self, config_json: str) -> tuple[str, str]:
|
|
149
|
-
"""
|
|
150
|
-
Run task (training, explain, plotting, etc.) with the current configuration.
|
|
151
|
-
|
|
152
|
-
The task type is automatically detected from config.runner.mode.
|
|
153
|
-
Supported modes: entry (training), explain, incremental, watchdog, etc.
|
|
154
|
-
"""
|
|
155
|
-
try:
|
|
156
|
-
temp_config_path = None
|
|
157
|
-
if config_json:
|
|
158
|
-
config = json.loads(config_json)
|
|
159
|
-
task_mode = config.get('runner', {}).get('mode', 'entry')
|
|
160
|
-
base_dir = self.current_config_dir or Path.cwd()
|
|
161
|
-
temp_config_path = (base_dir / "temp_config.json").resolve()
|
|
162
|
-
with open(temp_config_path, 'w', encoding='utf-8') as f:
|
|
163
|
-
json.dump(config, f, indent=2)
|
|
164
|
-
config_path = temp_config_path
|
|
165
|
-
elif self.current_config_path and self.current_config_path.exists():
|
|
166
|
-
config_path = self.current_config_path
|
|
167
|
-
config = json.loads(config_path.read_text(encoding="utf-8"))
|
|
168
|
-
task_mode = config.get('runner', {}).get('mode', 'entry')
|
|
169
|
-
elif self.current_config:
|
|
170
|
-
config = self.current_config
|
|
171
|
-
task_mode = config.get('runner', {}).get('mode', 'entry')
|
|
172
|
-
temp_config_path = (Path.cwd() / "temp_config.json").resolve()
|
|
173
|
-
with open(temp_config_path, 'w', encoding='utf-8') as f:
|
|
174
|
-
json.dump(config, f, indent=2)
|
|
175
|
-
config_path = temp_config_path
|
|
176
|
-
else:
|
|
177
|
-
return "No configuration provided", ""
|
|
178
|
-
|
|
179
|
-
log_generator = self.runner.run_task(str(config_path))
|
|
180
|
-
|
|
181
|
-
# Collect logs
|
|
182
|
-
full_log = ""
|
|
183
|
-
for log_line in log_generator:
|
|
184
|
-
full_log += log_line + "\n"
|
|
185
|
-
yield f"Task [{task_mode}] in progress...", full_log
|
|
186
|
-
|
|
187
|
-
# Clean up
|
|
188
|
-
if temp_config_path and temp_config_path.exists():
|
|
189
|
-
temp_config_path.unlink()
|
|
190
|
-
|
|
191
|
-
yield f"Task [{task_mode}] completed!", full_log
|
|
192
|
-
|
|
193
|
-
except Exception as e:
|
|
194
|
-
error_msg = f"Error during task execution: {str(e)}"
|
|
195
|
-
yield error_msg, error_msg
|
|
196
|
-
|
|
197
|
-
def prepare_ft_step1(self, config_json: str, use_ddp: bool, nproc: int) -> tuple[str, str]:
|
|
198
|
-
"""Prepare FT Step 1 configuration."""
|
|
199
|
-
if not config_json:
|
|
200
|
-
return "No configuration provided", ""
|
|
201
|
-
|
|
202
|
-
try:
|
|
203
|
-
config = json.loads(config_json)
|
|
204
|
-
step1_config = self.ft_workflow.prepare_step1_config(
|
|
205
|
-
base_config=config,
|
|
206
|
-
use_ddp=use_ddp,
|
|
207
|
-
nproc_per_node=int(nproc)
|
|
208
|
-
)
|
|
209
|
-
|
|
210
|
-
# Save to temp file
|
|
211
|
-
temp_path = Path("temp_ft_step1_config.json")
|
|
212
|
-
with open(temp_path, 'w', encoding='utf-8') as f:
|
|
213
|
-
json.dump(step1_config, f, indent=2)
|
|
214
|
-
|
|
215
|
-
self.current_step1_config = str(temp_path)
|
|
216
|
-
step1_json = json.dumps(step1_config, indent=2, ensure_ascii=False)
|
|
217
|
-
|
|
218
|
-
return "Step 1 config prepared. Click 'Run Step 1' to train FT embeddings.", step1_json
|
|
219
|
-
|
|
220
|
-
except Exception as e:
|
|
221
|
-
return f"Error preparing Step 1 config: {str(e)}", ""
|
|
222
|
-
|
|
223
|
-
def prepare_ft_step2(self, step1_config_path: str, target_models: str) -> tuple[str, str, str]:
|
|
224
|
-
"""Prepare FT Step 2 configurations."""
|
|
225
|
-
if not step1_config_path or not os.path.exists(step1_config_path):
|
|
226
|
-
return "Step 1 config not found. Run Step 1 first.", "", ""
|
|
227
|
-
|
|
228
|
-
try:
|
|
229
|
-
models = [m.strip() for m in target_models.split(',') if m.strip()]
|
|
230
|
-
xgb_cfg, resn_cfg = self.ft_workflow.generate_step2_configs(
|
|
231
|
-
step1_config_path=step1_config_path,
|
|
232
|
-
target_models=models
|
|
233
|
-
)
|
|
234
|
-
|
|
235
|
-
status_msg = f"Step 2 configs prepared for: {', '.join(models)}"
|
|
236
|
-
xgb_json = json.dumps(
|
|
237
|
-
xgb_cfg, indent=2, ensure_ascii=False) if xgb_cfg else ""
|
|
238
|
-
resn_json = json.dumps(
|
|
239
|
-
resn_cfg, indent=2, ensure_ascii=False) if resn_cfg else ""
|
|
240
|
-
|
|
241
|
-
return status_msg, xgb_json, resn_json
|
|
242
|
-
|
|
243
|
-
except FileNotFoundError as e:
|
|
244
|
-
return f"Error: {str(e)}\n\nMake sure Step 1 completed successfully.", "", ""
|
|
245
|
-
except Exception as e:
|
|
246
|
-
return f"Error preparing Step 2 configs: {str(e)}", "", ""
|
|
247
|
-
|
|
248
|
-
def open_results_folder(self, config_json: str) -> str:
|
|
249
|
-
"""Open the results folder in file explorer."""
|
|
250
|
-
try:
|
|
251
|
-
if config_json:
|
|
252
|
-
config = json.loads(config_json)
|
|
253
|
-
output_dir = config.get('output_dir', './Results')
|
|
254
|
-
results_path = Path(output_dir).resolve()
|
|
255
|
-
elif self.current_config_path and self.current_config_path.exists():
|
|
256
|
-
config = json.loads(
|
|
257
|
-
self.current_config_path.read_text(encoding="utf-8"))
|
|
258
|
-
output_dir = config.get('output_dir', './Results')
|
|
259
|
-
results_path = (
|
|
260
|
-
self.current_config_path.parent / output_dir).resolve()
|
|
261
|
-
elif self.current_config:
|
|
262
|
-
output_dir = self.current_config.get('output_dir', './Results')
|
|
263
|
-
results_path = Path(output_dir).resolve()
|
|
264
|
-
else:
|
|
265
|
-
return "No configuration loaded"
|
|
266
|
-
|
|
267
|
-
if not results_path.exists():
|
|
268
|
-
return f"Results folder does not exist yet: {results_path}"
|
|
269
|
-
|
|
270
|
-
# Open folder based on OS
|
|
271
|
-
system = platform.system()
|
|
272
|
-
if system == "Windows":
|
|
273
|
-
os.startfile(results_path)
|
|
274
|
-
elif system == "Darwin": # macOS
|
|
275
|
-
subprocess.run(["open", str(results_path)])
|
|
276
|
-
else: # Linux
|
|
277
|
-
subprocess.run(["xdg-open", str(results_path)])
|
|
278
|
-
|
|
279
|
-
return f"Opened folder: {results_path}"
|
|
280
|
-
|
|
281
|
-
except Exception as e:
|
|
282
|
-
return f"Error opening folder: {str(e)}"
|
|
283
|
-
|
|
284
|
-
def _run_workflow(self, label: str, func: Callable, *args, **kwargs):
|
|
285
|
-
"""Run a workflow function and stream logs."""
|
|
286
|
-
try:
|
|
287
|
-
log_generator = self.runner.run_callable(func, *args, **kwargs)
|
|
288
|
-
full_log = ""
|
|
289
|
-
for log_line in log_generator:
|
|
290
|
-
full_log += log_line + "\n"
|
|
291
|
-
yield f"{label} in progress...", full_log
|
|
292
|
-
yield f"{label} completed!", full_log
|
|
293
|
-
except Exception as e:
|
|
294
|
-
error_msg = f"{label} error: {str(e)}"
|
|
295
|
-
yield error_msg, error_msg
|
|
296
|
-
|
|
297
|
-
def run_pre_oneway_ui(
|
|
298
|
-
self,
|
|
299
|
-
data_path: str,
|
|
300
|
-
model_name: str,
|
|
301
|
-
target_col: str,
|
|
302
|
-
weight_col: str,
|
|
303
|
-
feature_list: str,
|
|
304
|
-
categorical_features: str,
|
|
305
|
-
n_bins: int,
|
|
306
|
-
holdout_ratio: float,
|
|
307
|
-
rand_seed: int,
|
|
308
|
-
output_dir: str,
|
|
309
|
-
):
|
|
310
|
-
yield from self._run_workflow(
|
|
311
|
-
"Pre-Oneway Plot",
|
|
312
|
-
run_pre_oneway,
|
|
313
|
-
data_path=data_path,
|
|
314
|
-
model_name=model_name,
|
|
315
|
-
target_col=target_col,
|
|
316
|
-
weight_col=weight_col,
|
|
317
|
-
feature_list=feature_list,
|
|
318
|
-
categorical_features=categorical_features,
|
|
319
|
-
n_bins=n_bins,
|
|
320
|
-
holdout_ratio=holdout_ratio,
|
|
321
|
-
rand_seed=rand_seed,
|
|
322
|
-
output_dir=output_dir or None,
|
|
323
|
-
)
|
|
324
|
-
|
|
325
|
-
def run_plot_direct_ui(self, cfg_path: str, xgb_cfg_path: str, resn_cfg_path: str):
|
|
326
|
-
yield from self._run_workflow(
|
|
327
|
-
"Direct Plot",
|
|
328
|
-
run_plot_direct,
|
|
329
|
-
cfg_path=cfg_path,
|
|
330
|
-
xgb_cfg_path=xgb_cfg_path,
|
|
331
|
-
resn_cfg_path=resn_cfg_path,
|
|
332
|
-
)
|
|
333
|
-
|
|
334
|
-
def run_plot_embed_ui(
|
|
335
|
-
self,
|
|
336
|
-
cfg_path: str,
|
|
337
|
-
xgb_cfg_path: str,
|
|
338
|
-
resn_cfg_path: str,
|
|
339
|
-
ft_cfg_path: str,
|
|
340
|
-
use_runtime_ft_embedding: bool,
|
|
341
|
-
):
|
|
342
|
-
yield from self._run_workflow(
|
|
343
|
-
"Embed Plot",
|
|
344
|
-
run_plot_embed,
|
|
345
|
-
cfg_path=cfg_path,
|
|
346
|
-
xgb_cfg_path=xgb_cfg_path,
|
|
347
|
-
resn_cfg_path=resn_cfg_path,
|
|
348
|
-
ft_cfg_path=ft_cfg_path,
|
|
349
|
-
use_runtime_ft_embedding=use_runtime_ft_embedding,
|
|
350
|
-
)
|
|
351
|
-
|
|
352
|
-
def run_predict_ui(
|
|
353
|
-
self,
|
|
354
|
-
ft_cfg_path: str,
|
|
355
|
-
xgb_cfg_path: str,
|
|
356
|
-
resn_cfg_path: str,
|
|
357
|
-
input_path: str,
|
|
358
|
-
output_path: str,
|
|
359
|
-
model_name: str,
|
|
360
|
-
model_keys: str,
|
|
361
|
-
):
|
|
362
|
-
yield from self._run_workflow(
|
|
363
|
-
"Prediction",
|
|
364
|
-
run_predict_ft_embed,
|
|
365
|
-
ft_cfg_path=ft_cfg_path,
|
|
366
|
-
xgb_cfg_path=xgb_cfg_path or None,
|
|
367
|
-
resn_cfg_path=resn_cfg_path or None,
|
|
368
|
-
input_path=input_path,
|
|
369
|
-
output_path=output_path,
|
|
370
|
-
model_name=model_name or None,
|
|
371
|
-
model_keys=model_keys,
|
|
372
|
-
)
|
|
373
|
-
|
|
374
|
-
def run_compare_xgb_ui(
|
|
375
|
-
self,
|
|
376
|
-
direct_cfg_path: str,
|
|
377
|
-
ft_cfg_path: str,
|
|
378
|
-
ft_embed_cfg_path: str,
|
|
379
|
-
label_direct: str,
|
|
380
|
-
label_ft: str,
|
|
381
|
-
use_runtime_ft_embedding: bool,
|
|
382
|
-
n_bins_override: int,
|
|
383
|
-
):
|
|
384
|
-
yield from self._run_workflow(
|
|
385
|
-
"Compare XGB",
|
|
386
|
-
run_compare_ft_embed,
|
|
387
|
-
direct_cfg_path=direct_cfg_path,
|
|
388
|
-
ft_cfg_path=ft_cfg_path,
|
|
389
|
-
ft_embed_cfg_path=ft_embed_cfg_path,
|
|
390
|
-
model_key="xgb",
|
|
391
|
-
label_direct=label_direct,
|
|
392
|
-
label_ft=label_ft,
|
|
393
|
-
use_runtime_ft_embedding=use_runtime_ft_embedding,
|
|
394
|
-
n_bins_override=n_bins_override,
|
|
395
|
-
)
|
|
396
|
-
|
|
397
|
-
def run_compare_resn_ui(
|
|
398
|
-
self,
|
|
399
|
-
direct_cfg_path: str,
|
|
400
|
-
ft_cfg_path: str,
|
|
401
|
-
ft_embed_cfg_path: str,
|
|
402
|
-
label_direct: str,
|
|
403
|
-
label_ft: str,
|
|
404
|
-
use_runtime_ft_embedding: bool,
|
|
405
|
-
n_bins_override: int,
|
|
406
|
-
):
|
|
407
|
-
yield from self._run_workflow(
|
|
408
|
-
"Compare ResNet",
|
|
409
|
-
run_compare_ft_embed,
|
|
410
|
-
direct_cfg_path=direct_cfg_path,
|
|
411
|
-
ft_cfg_path=ft_cfg_path,
|
|
412
|
-
ft_embed_cfg_path=ft_embed_cfg_path,
|
|
413
|
-
model_key="resn",
|
|
414
|
-
label_direct=label_direct,
|
|
415
|
-
label_ft=label_ft,
|
|
416
|
-
use_runtime_ft_embedding=use_runtime_ft_embedding,
|
|
417
|
-
n_bins_override=n_bins_override,
|
|
418
|
-
)
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
def create_ui():
|
|
422
|
-
"""Create the Gradio interface."""
|
|
423
|
-
app = PricingApp()
|
|
424
|
-
|
|
425
|
-
with gr.Blocks(title="Insurance Pricing Model Training", theme=gr.themes.Soft()) as demo:
|
|
426
|
-
gr.Markdown(
|
|
427
|
-
"""
|
|
428
|
-
# Insurance Pricing Model Training Interface
|
|
429
|
-
Configure and train insurance pricing models with an easy-to-use interface.
|
|
430
|
-
|
|
431
|
-
**Two ways to configure:**
|
|
432
|
-
1. **Upload JSON Config**: Upload an existing configuration file
|
|
433
|
-
2. **Manual Configuration**: Fill in the parameters below
|
|
434
|
-
"""
|
|
435
|
-
)
|
|
436
|
-
|
|
437
|
-
with gr.Tab("Configuration"):
|
|
438
|
-
with gr.Row():
|
|
439
|
-
with gr.Column(scale=1):
|
|
440
|
-
gr.Markdown("### Load Configuration")
|
|
441
|
-
json_file = gr.File(
|
|
442
|
-
label="Upload JSON Config File",
|
|
443
|
-
file_types=[".json"],
|
|
444
|
-
type="filepath"
|
|
445
|
-
)
|
|
446
|
-
load_btn = gr.Button("Load Config", variant="primary")
|
|
447
|
-
load_status = gr.Textbox(
|
|
448
|
-
label="Load Status", interactive=False)
|
|
449
|
-
|
|
450
|
-
with gr.Column(scale=2):
|
|
451
|
-
gr.Markdown("### Current Configuration")
|
|
452
|
-
config_display = gr.JSON(label="Configuration", value={})
|
|
453
|
-
|
|
454
|
-
gr.Markdown("---")
|
|
455
|
-
gr.Markdown("### Manual Configuration")
|
|
456
|
-
|
|
457
|
-
with gr.Row():
|
|
458
|
-
with gr.Column():
|
|
459
|
-
gr.Markdown("#### Data Settings")
|
|
460
|
-
data_dir = gr.Textbox(
|
|
461
|
-
label="Data Directory", value="./Data")
|
|
462
|
-
model_list = gr.Textbox(
|
|
463
|
-
label="Model List (comma-separated)", value="od")
|
|
464
|
-
model_categories = gr.Textbox(
|
|
465
|
-
label="Model Categories (comma-separated)", value="bc")
|
|
466
|
-
target = gr.Textbox(
|
|
467
|
-
label="Target Column", value="response")
|
|
468
|
-
weight = gr.Textbox(label="Weight Column", value="weights")
|
|
469
|
-
|
|
470
|
-
gr.Markdown("#### Features")
|
|
471
|
-
feature_list = gr.Textbox(
|
|
472
|
-
label="Feature List (comma-separated)",
|
|
473
|
-
placeholder="feature_1, feature_2, feature_3",
|
|
474
|
-
lines=3
|
|
475
|
-
)
|
|
476
|
-
categorical_features = gr.Textbox(
|
|
477
|
-
label="Categorical Features (comma-separated)",
|
|
478
|
-
placeholder="feature_2, feature_3",
|
|
479
|
-
lines=2
|
|
480
|
-
)
|
|
481
|
-
|
|
482
|
-
with gr.Column():
|
|
483
|
-
gr.Markdown("#### Model Settings")
|
|
484
|
-
task_type = gr.Dropdown(
|
|
485
|
-
label="Task Type",
|
|
486
|
-
choices=["regression", "binary", "multiclass"],
|
|
487
|
-
value="regression"
|
|
488
|
-
)
|
|
489
|
-
prop_test = gr.Slider(
|
|
490
|
-
label="Test Proportion", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
|
|
491
|
-
holdout_ratio = gr.Slider(
|
|
492
|
-
label="Holdout Ratio", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
|
|
493
|
-
val_ratio = gr.Slider(
|
|
494
|
-
label="Validation Ratio", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
|
|
495
|
-
split_strategy = gr.Dropdown(
|
|
496
|
-
label="Split Strategy",
|
|
497
|
-
choices=["random", "stratified", "time", "group"],
|
|
498
|
-
value="random"
|
|
499
|
-
)
|
|
500
|
-
rand_seed = gr.Number(
|
|
501
|
-
label="Random Seed", value=13, precision=0)
|
|
502
|
-
epochs = gr.Number(label="Epochs", value=50, precision=0)
|
|
503
|
-
|
|
504
|
-
with gr.Row():
|
|
505
|
-
with gr.Column():
|
|
506
|
-
gr.Markdown("#### Training Settings")
|
|
507
|
-
output_dir = gr.Textbox(
|
|
508
|
-
label="Output Directory", value="./Results")
|
|
509
|
-
use_gpu = gr.Checkbox(label="Use GPU", value=True)
|
|
510
|
-
model_keys = gr.Textbox(
|
|
511
|
-
label="Model Keys (comma-separated)",
|
|
512
|
-
value="xgb, resn",
|
|
513
|
-
placeholder="xgb, resn, ft, gnn"
|
|
514
|
-
)
|
|
515
|
-
max_evals = gr.Number(
|
|
516
|
-
label="Max Evaluations", value=50, precision=0)
|
|
517
|
-
|
|
518
|
-
with gr.Column():
|
|
519
|
-
gr.Markdown("#### XGBoost Settings")
|
|
520
|
-
xgb_max_depth_max = gr.Number(
|
|
521
|
-
label="XGB Max Depth", value=25, precision=0)
|
|
522
|
-
xgb_n_estimators_max = gr.Number(
|
|
523
|
-
label="XGB Max Estimators", value=500, precision=0)
|
|
524
|
-
|
|
525
|
-
with gr.Row():
|
|
526
|
-
build_btn = gr.Button(
|
|
527
|
-
"Build Configuration", variant="primary", size="lg")
|
|
528
|
-
save_config_btn = gr.Button(
|
|
529
|
-
"Save Configuration", variant="secondary", size="lg")
|
|
530
|
-
|
|
531
|
-
with gr.Row():
|
|
532
|
-
build_status = gr.Textbox(label="Status", interactive=False)
|
|
533
|
-
config_json = gr.Textbox(
|
|
534
|
-
label="Generated Config (JSON)", lines=10, max_lines=20)
|
|
535
|
-
|
|
536
|
-
save_filename = gr.Textbox(
|
|
537
|
-
label="Save Filename", value="my_config.json")
|
|
538
|
-
save_status = gr.Textbox(label="Save Status", interactive=False)
|
|
539
|
-
|
|
540
|
-
with gr.Tab("Run Task"):
|
|
541
|
-
gr.Markdown(
|
|
542
|
-
"""
|
|
543
|
-
### Run Model Task
|
|
544
|
-
Click the button below to execute the task defined in your configuration.
|
|
545
|
-
Task type is automatically detected from `config.runner.mode`:
|
|
546
|
-
- **entry**: Standard model training
|
|
547
|
-
- **explain**: Model explanation (permutation, SHAP, integrated gradients)
|
|
548
|
-
- **incremental**: Incremental training
|
|
549
|
-
- **watchdog**: Watchdog mode
|
|
550
|
-
|
|
551
|
-
Task logs will appear in real-time below.
|
|
552
|
-
"""
|
|
553
|
-
)
|
|
554
|
-
|
|
555
|
-
with gr.Row():
|
|
556
|
-
run_btn = gr.Button("Run Task", variant="primary", size="lg")
|
|
557
|
-
run_status = gr.Textbox(label="Task Status", interactive=False)
|
|
558
|
-
|
|
559
|
-
gr.Markdown("### Task Logs")
|
|
560
|
-
log_output = gr.Textbox(
|
|
561
|
-
label="Logs",
|
|
562
|
-
lines=25,
|
|
563
|
-
max_lines=50,
|
|
564
|
-
interactive=False,
|
|
565
|
-
autoscroll=True
|
|
566
|
-
)
|
|
567
|
-
|
|
568
|
-
gr.Markdown("---")
|
|
569
|
-
with gr.Row():
|
|
570
|
-
open_folder_btn = gr.Button("Open Results Folder", size="lg")
|
|
571
|
-
folder_status = gr.Textbox(
|
|
572
|
-
label="Status", interactive=False, scale=2)
|
|
573
|
-
|
|
574
|
-
with gr.Tab("FT Two-Step Workflow"):
|
|
575
|
-
gr.Markdown(
|
|
576
|
-
"""
|
|
577
|
-
### FT-Transformer Two-Step Training
|
|
578
|
-
|
|
579
|
-
Automates the FT → XGB/ResN workflow:
|
|
580
|
-
1. **Step 1**: Train FT-Transformer as unsupervised embedding generator
|
|
581
|
-
2. **Step 2**: Merge embeddings with raw data and train XGB/ResN
|
|
582
|
-
|
|
583
|
-
**Instructions**:
|
|
584
|
-
1. Load or build a base configuration in the Configuration tab
|
|
585
|
-
2. Prepare Step 1 config (FT embeddings)
|
|
586
|
-
3. Run Step 1 to generate embeddings
|
|
587
|
-
4. Prepare Step 2 configs (XGB/ResN using embeddings)
|
|
588
|
-
5. Run Step 2 with the generated configs
|
|
589
|
-
"""
|
|
590
|
-
)
|
|
591
|
-
|
|
592
|
-
with gr.Row():
|
|
593
|
-
with gr.Column():
|
|
594
|
-
gr.Markdown("### Step 1: FT Embedding Generation")
|
|
595
|
-
ft_use_ddp = gr.Checkbox(
|
|
596
|
-
label="Use DDP for FT", value=True)
|
|
597
|
-
ft_nproc = gr.Number(
|
|
598
|
-
label="Number of Processes (DDP)", value=2, precision=0)
|
|
599
|
-
|
|
600
|
-
prepare_step1_btn = gr.Button(
|
|
601
|
-
"Prepare Step 1 Config", variant="primary")
|
|
602
|
-
step1_status = gr.Textbox(
|
|
603
|
-
label="Status", interactive=False)
|
|
604
|
-
step1_config_display = gr.Textbox(
|
|
605
|
-
label="Step 1 Config (FT Embedding)",
|
|
606
|
-
lines=15,
|
|
607
|
-
max_lines=25
|
|
608
|
-
)
|
|
609
|
-
|
|
610
|
-
with gr.Column():
|
|
611
|
-
gr.Markdown("### Step 2: Train XGB/ResN with Embeddings")
|
|
612
|
-
target_models_input = gr.Textbox(
|
|
613
|
-
label="Target Models (comma-separated)",
|
|
614
|
-
value="xgb, resn",
|
|
615
|
-
placeholder="xgb, resn"
|
|
616
|
-
)
|
|
617
|
-
|
|
618
|
-
prepare_step2_btn = gr.Button(
|
|
619
|
-
"Prepare Step 2 Configs", variant="primary")
|
|
620
|
-
step2_status = gr.Textbox(
|
|
621
|
-
label="Status", interactive=False)
|
|
622
|
-
|
|
623
|
-
with gr.Tab("XGB Config"):
|
|
624
|
-
xgb_config_display = gr.Textbox(
|
|
625
|
-
label="XGB Step 2 Config",
|
|
626
|
-
lines=15,
|
|
627
|
-
max_lines=25
|
|
628
|
-
)
|
|
629
|
-
|
|
630
|
-
with gr.Tab("ResN Config"):
|
|
631
|
-
resn_config_display = gr.Textbox(
|
|
632
|
-
label="ResN Step 2 Config",
|
|
633
|
-
lines=15,
|
|
634
|
-
max_lines=25
|
|
635
|
-
)
|
|
636
|
-
|
|
637
|
-
gr.Markdown("---")
|
|
638
|
-
gr.Markdown(
|
|
639
|
-
"""
|
|
640
|
-
### Quick Actions
|
|
641
|
-
After preparing configs, you can:
|
|
642
|
-
- Copy the Step 1 config and paste it in the **Configuration** tab, then run it in **Run Task** tab
|
|
643
|
-
- After Step 1 completes, click **Prepare Step 2 Configs**
|
|
644
|
-
- Copy the Step 2 configs (XGB or ResN) and run them in **Run Task** tab
|
|
645
|
-
"""
|
|
646
|
-
)
|
|
647
|
-
|
|
648
|
-
with gr.Tab("Plotting"):
|
|
649
|
-
gr.Markdown(
|
|
650
|
-
"""
|
|
651
|
-
### Plotting Workflows
|
|
652
|
-
Run the plotting steps from the example notebooks.
|
|
653
|
-
"""
|
|
654
|
-
)
|
|
655
|
-
|
|
656
|
-
with gr.Tab("Pre Oneway"):
|
|
657
|
-
with gr.Row():
|
|
658
|
-
with gr.Column():
|
|
659
|
-
pre_data_path = gr.Textbox(
|
|
660
|
-
label="Data Path", value="./Data/od_bc.csv")
|
|
661
|
-
pre_model_name = gr.Textbox(
|
|
662
|
-
label="Model Name", value="od_bc")
|
|
663
|
-
pre_target = gr.Textbox(
|
|
664
|
-
label="Target Column", value="response")
|
|
665
|
-
pre_weight = gr.Textbox(
|
|
666
|
-
label="Weight Column", value="weights")
|
|
667
|
-
pre_output_dir = gr.Textbox(
|
|
668
|
-
label="Output Dir (optional)", value="")
|
|
669
|
-
with gr.Column():
|
|
670
|
-
pre_feature_list = gr.Textbox(
|
|
671
|
-
label="Feature List (comma-separated)",
|
|
672
|
-
lines=4,
|
|
673
|
-
placeholder="feature_1, feature_2, feature_3",
|
|
674
|
-
)
|
|
675
|
-
pre_categorical = gr.Textbox(
|
|
676
|
-
label="Categorical Features (comma-separated, optional)",
|
|
677
|
-
lines=3,
|
|
678
|
-
placeholder="feature_2, feature_3",
|
|
679
|
-
)
|
|
680
|
-
pre_n_bins = gr.Number(
|
|
681
|
-
label="Bins", value=10, precision=0)
|
|
682
|
-
pre_holdout = gr.Slider(
|
|
683
|
-
label="Holdout Ratio",
|
|
684
|
-
minimum=0.0,
|
|
685
|
-
maximum=0.5,
|
|
686
|
-
value=0.25,
|
|
687
|
-
step=0.05,
|
|
688
|
-
)
|
|
689
|
-
pre_seed = gr.Number(
|
|
690
|
-
label="Random Seed", value=13, precision=0)
|
|
691
|
-
|
|
692
|
-
pre_run_btn = gr.Button("Run Pre Oneway", variant="primary")
|
|
693
|
-
pre_status = gr.Textbox(label="Status", interactive=False)
|
|
694
|
-
pre_log = gr.Textbox(label="Logs", lines=15,
|
|
695
|
-
max_lines=40, interactive=False)
|
|
696
|
-
|
|
697
|
-
with gr.Tab("Direct Plot"):
|
|
698
|
-
direct_cfg_path = gr.Textbox(
|
|
699
|
-
label="Plot Config", value="config_plot.json")
|
|
700
|
-
direct_xgb_cfg = gr.Textbox(
|
|
701
|
-
label="XGB Config", value="config_xgb_direct.json")
|
|
702
|
-
direct_resn_cfg = gr.Textbox(
|
|
703
|
-
label="ResN Config", value="config_resn_direct.json")
|
|
704
|
-
direct_run_btn = gr.Button(
|
|
705
|
-
"Run Direct Plot", variant="primary")
|
|
706
|
-
direct_status = gr.Textbox(label="Status", interactive=False)
|
|
707
|
-
direct_log = gr.Textbox(
|
|
708
|
-
label="Logs", lines=15, max_lines=40, interactive=False)
|
|
709
|
-
|
|
710
|
-
with gr.Tab("Embed Plot"):
|
|
711
|
-
embed_cfg_path = gr.Textbox(
|
|
712
|
-
label="Plot Config", value="config_plot.json")
|
|
713
|
-
embed_xgb_cfg = gr.Textbox(
|
|
714
|
-
label="XGB Embed Config", value="config_xgb_from_ft_unsupervised.json")
|
|
715
|
-
embed_resn_cfg = gr.Textbox(
|
|
716
|
-
label="ResN Embed Config", value="config_resn_from_ft_unsupervised.json")
|
|
717
|
-
embed_ft_cfg = gr.Textbox(
|
|
718
|
-
label="FT Embed Config", value="config_ft_unsupervised_ddp_embed.json")
|
|
719
|
-
embed_runtime = gr.Checkbox(
|
|
720
|
-
label="Use Runtime FT Embedding", value=False)
|
|
721
|
-
embed_run_btn = gr.Button("Run Embed Plot", variant="primary")
|
|
722
|
-
embed_status = gr.Textbox(label="Status", interactive=False)
|
|
723
|
-
embed_log = gr.Textbox(
|
|
724
|
-
label="Logs", lines=15, max_lines=40, interactive=False)
|
|
725
|
-
|
|
726
|
-
with gr.Tab("Prediction"):
|
|
727
|
-
gr.Markdown("### FT Embed Prediction")
|
|
728
|
-
pred_ft_cfg = gr.Textbox(
|
|
729
|
-
label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
|
|
730
|
-
pred_xgb_cfg = gr.Textbox(
|
|
731
|
-
label="XGB Config (optional)", value="config_xgb_from_ft_unsupervised.json")
|
|
732
|
-
pred_resn_cfg = gr.Textbox(
|
|
733
|
-
label="ResN Config (optional)", value="config_resn_from_ft_unsupervised.json")
|
|
734
|
-
pred_input = gr.Textbox(
|
|
735
|
-
label="Input Data", value="./Data/od_bc_new.csv")
|
|
736
|
-
pred_output = gr.Textbox(
|
|
737
|
-
label="Output CSV", value="./Results/predictions_ft_xgb.csv")
|
|
738
|
-
pred_model_name = gr.Textbox(
|
|
739
|
-
label="Model Name (optional)", value="")
|
|
740
|
-
pred_model_keys = gr.Textbox(label="Model Keys", value="xgb, resn")
|
|
741
|
-
pred_run_btn = gr.Button("Run Prediction", variant="primary")
|
|
742
|
-
pred_status = gr.Textbox(label="Status", interactive=False)
|
|
743
|
-
pred_log = gr.Textbox(label="Logs", lines=15,
|
|
744
|
-
max_lines=40, interactive=False)
|
|
745
|
-
|
|
746
|
-
with gr.Tab("Compare"):
|
|
747
|
-
gr.Markdown("### Compare Direct vs FT-Embed Models")
|
|
748
|
-
|
|
749
|
-
with gr.Tab("Compare XGB"):
|
|
750
|
-
cmp_xgb_direct_cfg = gr.Textbox(
|
|
751
|
-
label="Direct XGB Config", value="config_xgb_direct.json")
|
|
752
|
-
cmp_xgb_ft_cfg = gr.Textbox(
|
|
753
|
-
label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
|
|
754
|
-
cmp_xgb_embed_cfg = gr.Textbox(
|
|
755
|
-
label="FT-Embed XGB Config", value="config_xgb_from_ft_unsupervised.json")
|
|
756
|
-
cmp_xgb_label_direct = gr.Textbox(
|
|
757
|
-
label="Direct Label", value="XGB_raw")
|
|
758
|
-
cmp_xgb_label_ft = gr.Textbox(
|
|
759
|
-
label="FT Label", value="XGB_ft_embed")
|
|
760
|
-
cmp_xgb_runtime = gr.Checkbox(
|
|
761
|
-
label="Use Runtime FT Embedding", value=False)
|
|
762
|
-
cmp_xgb_bins = gr.Number(
|
|
763
|
-
label="Bins Override", value=10, precision=0)
|
|
764
|
-
cmp_xgb_run_btn = gr.Button(
|
|
765
|
-
"Run XGB Compare", variant="primary")
|
|
766
|
-
cmp_xgb_status = gr.Textbox(label="Status", interactive=False)
|
|
767
|
-
cmp_xgb_log = gr.Textbox(
|
|
768
|
-
label="Logs", lines=15, max_lines=40, interactive=False)
|
|
769
|
-
|
|
770
|
-
with gr.Tab("Compare ResNet"):
|
|
771
|
-
cmp_resn_direct_cfg = gr.Textbox(
|
|
772
|
-
label="Direct ResN Config", value="config_resn_direct.json")
|
|
773
|
-
cmp_resn_ft_cfg = gr.Textbox(
|
|
774
|
-
label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
|
|
775
|
-
cmp_resn_embed_cfg = gr.Textbox(
|
|
776
|
-
label="FT-Embed ResN Config", value="config_resn_from_ft_unsupervised.json")
|
|
777
|
-
cmp_resn_label_direct = gr.Textbox(
|
|
778
|
-
label="Direct Label", value="ResN_raw")
|
|
779
|
-
cmp_resn_label_ft = gr.Textbox(
|
|
780
|
-
label="FT Label", value="ResN_ft_embed")
|
|
781
|
-
cmp_resn_runtime = gr.Checkbox(
|
|
782
|
-
label="Use Runtime FT Embedding", value=False)
|
|
783
|
-
cmp_resn_bins = gr.Number(
|
|
784
|
-
label="Bins Override", value=10, precision=0)
|
|
785
|
-
cmp_resn_run_btn = gr.Button(
|
|
786
|
-
"Run ResNet Compare", variant="primary")
|
|
787
|
-
cmp_resn_status = gr.Textbox(label="Status", interactive=False)
|
|
788
|
-
cmp_resn_log = gr.Textbox(
|
|
789
|
-
label="Logs", lines=15, max_lines=40, interactive=False)
|
|
790
|
-
|
|
791
|
-
# Event handlers
|
|
792
|
-
load_btn.click(
|
|
793
|
-
fn=app.load_json_config,
|
|
794
|
-
inputs=[json_file],
|
|
795
|
-
outputs=[load_status, config_display, config_json]
|
|
796
|
-
)
|
|
797
|
-
|
|
798
|
-
build_btn.click(
|
|
799
|
-
fn=app.build_config_from_ui,
|
|
800
|
-
inputs=[
|
|
801
|
-
data_dir, model_list, model_categories, target, weight,
|
|
802
|
-
feature_list, categorical_features, task_type, prop_test,
|
|
803
|
-
holdout_ratio, val_ratio, split_strategy, rand_seed, epochs,
|
|
804
|
-
output_dir, use_gpu, model_keys, max_evals,
|
|
805
|
-
xgb_max_depth_max, xgb_n_estimators_max
|
|
806
|
-
],
|
|
807
|
-
outputs=[build_status, config_json]
|
|
808
|
-
)
|
|
809
|
-
|
|
810
|
-
save_config_btn.click(
|
|
811
|
-
fn=app.save_config,
|
|
812
|
-
inputs=[config_json, save_filename],
|
|
813
|
-
outputs=[save_status]
|
|
814
|
-
)
|
|
815
|
-
|
|
816
|
-
run_btn.click(
|
|
817
|
-
fn=app.run_training,
|
|
818
|
-
inputs=[config_json],
|
|
819
|
-
outputs=[run_status, log_output]
|
|
820
|
-
)
|
|
821
|
-
|
|
822
|
-
open_folder_btn.click(
|
|
823
|
-
fn=app.open_results_folder,
|
|
824
|
-
inputs=[config_json],
|
|
825
|
-
outputs=[folder_status]
|
|
826
|
-
)
|
|
827
|
-
|
|
828
|
-
prepare_step1_btn.click(
|
|
829
|
-
fn=app.prepare_ft_step1,
|
|
830
|
-
inputs=[config_json, ft_use_ddp, ft_nproc],
|
|
831
|
-
outputs=[step1_status, step1_config_display]
|
|
832
|
-
)
|
|
833
|
-
|
|
834
|
-
prepare_step2_btn.click(
|
|
835
|
-
fn=app.prepare_ft_step2,
|
|
836
|
-
inputs=[gr.State(
|
|
837
|
-
lambda: app.current_step1_config or "temp_ft_step1_config.json"), target_models_input],
|
|
838
|
-
outputs=[step2_status, xgb_config_display, resn_config_display]
|
|
839
|
-
)
|
|
840
|
-
|
|
841
|
-
pre_run_btn.click(
|
|
842
|
-
fn=app.run_pre_oneway_ui,
|
|
843
|
-
inputs=[
|
|
844
|
-
pre_data_path, pre_model_name, pre_target, pre_weight,
|
|
845
|
-
pre_feature_list, pre_categorical, pre_n_bins,
|
|
846
|
-
pre_holdout, pre_seed, pre_output_dir
|
|
847
|
-
],
|
|
848
|
-
outputs=[pre_status, pre_log]
|
|
849
|
-
)
|
|
850
|
-
|
|
851
|
-
direct_run_btn.click(
|
|
852
|
-
fn=app.run_plot_direct_ui,
|
|
853
|
-
inputs=[direct_cfg_path, direct_xgb_cfg, direct_resn_cfg],
|
|
854
|
-
outputs=[direct_status, direct_log]
|
|
855
|
-
)
|
|
856
|
-
|
|
857
|
-
embed_run_btn.click(
|
|
858
|
-
fn=app.run_plot_embed_ui,
|
|
859
|
-
inputs=[embed_cfg_path, embed_xgb_cfg,
|
|
860
|
-
embed_resn_cfg, embed_ft_cfg, embed_runtime],
|
|
861
|
-
outputs=[embed_status, embed_log]
|
|
862
|
-
)
|
|
863
|
-
|
|
864
|
-
pred_run_btn.click(
|
|
865
|
-
fn=app.run_predict_ui,
|
|
866
|
-
inputs=[
|
|
867
|
-
pred_ft_cfg, pred_xgb_cfg, pred_resn_cfg, pred_input,
|
|
868
|
-
pred_output, pred_model_name, pred_model_keys
|
|
869
|
-
],
|
|
870
|
-
outputs=[pred_status, pred_log]
|
|
871
|
-
)
|
|
872
|
-
|
|
873
|
-
cmp_xgb_run_btn.click(
|
|
874
|
-
fn=app.run_compare_xgb_ui,
|
|
875
|
-
inputs=[
|
|
876
|
-
cmp_xgb_direct_cfg, cmp_xgb_ft_cfg, cmp_xgb_embed_cfg,
|
|
877
|
-
cmp_xgb_label_direct, cmp_xgb_label_ft,
|
|
878
|
-
cmp_xgb_runtime, cmp_xgb_bins
|
|
879
|
-
],
|
|
880
|
-
outputs=[cmp_xgb_status, cmp_xgb_log]
|
|
881
|
-
)
|
|
36
|
+
import gradio as gr
|
|
882
37
|
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class FrontendDependencyError(RuntimeError):
    """Raised when the frontend's third-party dependencies cannot be used.

    Signals that ``gradio`` or ``huggingface_hub`` failed to import, or that the
    installed ``huggingface_hub`` is missing an API the frontend relies on.
    """
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _check_frontend_deps() -> None:
    """Fail fast with a clear message if frontend deps are incompatible.

    Checks that ``gradio`` and ``huggingface_hub`` can be imported, and that the
    installed ``huggingface_hub`` still exposes ``HfFolder``.

    Raises:
        FrontendDependencyError: if either import fails (the original exception
            is chained as ``__cause__``), or if ``huggingface_hub.HfFolder`` is
            missing.
    """
    try:
        import gradio  # noqa: F401
    except Exception as exc:  # broad on purpose: report any import-time failure uniformly
        raise FrontendDependencyError(f"Failed to import gradio: {exc}") from exc

    try:
        import huggingface_hub as hf
    except Exception as exc:  # broad on purpose: report any import-time failure uniformly
        raise FrontendDependencyError(
            f"Failed to import huggingface_hub: {exc}. "
            "Pin version with `pip install 'huggingface_hub<0.24'`."
        ) from exc

    # Newer huggingface_hub releases no longer provide HfFolder; the remedy
    # suggested to the user is pinning to a version below 0.24.
    if not hasattr(hf, 'HfFolder'):
        raise FrontendDependencyError(
            'Incompatible huggingface_hub detected: missing HfFolder. '
            'Please install `huggingface_hub<0.24`.'
        )
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class PricingApp:
    """Main application class for the insurance pricing model tasks interface.

    Holds the helpers that build configurations, run tasks and prepare the
    FT two-step workflow, together with the configuration currently being
    worked on. Its methods are used as Gradio event handlers by ``create_ui``.
    """

    # Scratch file handed to the task runner; it is removed after each run.
    _TEMP_CONFIG_NAME = "temp_config.json"

    def __init__(self):
        self.config_builder = ConfigBuilder()
        self.runner = TaskRunner()
        self.ft_workflow = FTWorkflowHelper()
        # Last configuration loaded from disk or built from the UI.
        self.current_config = {}
        # Path (as a string) of the most recently prepared FT Step 1 config.
        self.current_step1_config = None
        # File the current config was loaded from (None when built in the UI);
        # its directory is used as the base for relative paths.
        self.current_config_path: Optional[Path] = None
        self.current_config_dir: Optional[Path] = None

    @staticmethod
    def _split_csv(text: Optional[str]) -> list[str]:
        """Split a comma-separated UI string into stripped, non-empty items."""
        return [item.strip() for item in (text or "").split(",") if item.strip()]

    @staticmethod
    def _write_json(config: Any, path: Path) -> None:
        """Write *config* to *path* as indented UTF-8 JSON."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

    def load_json_config(self, file_path) -> tuple[str, Dict[str, Any], str]:
        """Load configuration from uploaded JSON file.

        Args:
            file_path: Path of the uploaded file, or a falsy value if none.

        Returns:
            ``(status message, parsed config, pretty-printed JSON)``; on
            failure the config is ``{}`` and the JSON string is empty.
        """
        if not file_path:
            return "No file uploaded", {}, ""

        try:
            path = Path(file_path).resolve()
            with open(path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            self.current_config = config
            self.current_config_path = path
            self.current_config_dir = path.parent
            config_json = json.dumps(config, indent=2, ensure_ascii=False)
            return f"Configuration loaded successfully from {path.name}", config, config_json
        except Exception as e:
            return f"Error loading config: {str(e)}", {}, ""

    def build_config_from_ui(
        self,
        data_dir: str,
        model_list: str,
        model_categories: str,
        target: str,
        weight: str,
        feature_list: str,
        categorical_features: str,
        task_type: str,
        prop_test: float,
        holdout_ratio: float,
        val_ratio: float,
        split_strategy: str,
        rand_seed: int,
        epochs: int,
        output_dir: str,
        use_gpu: bool,
        model_keys: str,
        max_evals: int,
        xgb_max_depth_max: int,
        xgb_n_estimators_max: int,
    ) -> tuple[str, str]:
        """Build configuration from UI parameters.

        Comma-separated text fields are split into lists before being passed
        to ``ConfigBuilder.build_config``. The result is then validated.

        Returns:
            ``(status message, config JSON)``; the JSON string is empty when
            validation fails or an error occurs.
        """
        try:
            config = self.config_builder.build_config(
                data_dir=data_dir,
                model_list=self._split_csv(model_list),
                model_categories=self._split_csv(model_categories),
                target=target,
                weight=weight,
                feature_list=self._split_csv(feature_list),
                categorical_features=self._split_csv(categorical_features),
                task_type=task_type,
                prop_test=prop_test,
                holdout_ratio=holdout_ratio,
                val_ratio=val_ratio,
                split_strategy=split_strategy,
                rand_seed=rand_seed,
                epochs=epochs,
                output_dir=output_dir,
                use_gpu=use_gpu,
                model_keys=self._split_csv(model_keys),
                max_evals=max_evals,
                xgb_max_depth_max=xgb_max_depth_max,
                xgb_n_estimators_max=xgb_n_estimators_max,
            )

            is_valid, msg = self.config_builder.validate_config(config)
            if not is_valid:
                return f"Validation failed: {msg}", ""

            self.current_config = config
            # A UI-built config has no file on disk to anchor relative paths.
            self.current_config_path = None
            self.current_config_dir = None
            config_json = json.dumps(config, indent=2, ensure_ascii=False)
            return "Configuration built successfully", config_json

        except Exception as e:
            return f"Error building config: {str(e)}", ""

    def save_config(self, config_json: str, filename: str) -> str:
        """Save current configuration to file.

        The JSON text is parsed *before* the target file is opened. Invalid
        JSON therefore never truncates an existing file.

        Returns:
            A human-readable status message.
        """
        if not config_json:
            return "No configuration to save"

        try:
            config = json.loads(config_json)
            config_path = Path(filename)
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=2, ensure_ascii=False)
            return f"Configuration saved to {config_path}"
        except Exception as e:
            return f"Error saving config: {str(e)}"

    def run_training(self, config_json: str):
        """
        Run task (training, explain, plotting, etc.) with the current configuration.

        The task type is automatically detected from config.runner.mode.
        Supported modes: entry (training), explain, incremental, watchdog, etc.

        Config source priority: the JSON textbox, then the loaded config file,
        then the in-memory config. When a scratch config file is written, it
        is always removed afterwards, even if the runner fails or the stream
        is closed early.

        Yields:
            ``(status message, accumulated log text)`` tuples.
        """
        temp_config_path: Optional[Path] = None
        try:
            if config_json:
                config = json.loads(config_json)
                # Put the scratch file next to the loaded config (if any), so
                # relative paths inside it resolve as they would there.
                base_dir = self.current_config_dir or Path.cwd()
                temp_config_path = (base_dir / self._TEMP_CONFIG_NAME).resolve()
                self._write_json(config, temp_config_path)
                config_path = temp_config_path
            elif self.current_config_path and self.current_config_path.exists():
                config_path = self.current_config_path
                config = json.loads(config_path.read_text(encoding="utf-8"))
            elif self.current_config:
                config = self.current_config
                temp_config_path = (Path.cwd() / self._TEMP_CONFIG_NAME).resolve()
                self._write_json(config, temp_config_path)
                config_path = temp_config_path
            else:
                # This is a generator: ``return value`` would never reach the
                # UI, so the message must be yielded.
                yield "No configuration provided", ""
                return

            task_mode = (config.get('runner') or {}).get('mode', 'entry')

            full_log = ""
            for log_line in self.runner.run_task(str(config_path)):
                full_log += log_line + "\n"
                yield f"Task [{task_mode}] in progress...", full_log

            yield f"Task [{task_mode}] completed!", full_log

        except Exception as e:
            error_msg = f"Error during task execution: {str(e)}"
            yield error_msg, error_msg
        finally:
            if temp_config_path is not None:
                try:
                    temp_config_path.unlink(missing_ok=True)
                except OSError:
                    # Best-effort cleanup; never mask the task's own result.
                    pass

    def prepare_ft_step1(self, config_json: str, use_ddp: bool, nproc: int) -> tuple[str, str]:
        """Prepare FT Step 1 configuration.

        Writes the Step 1 config to ``temp_ft_step1_config.json`` in the
        current working directory and remembers that path for Step 2.

        Returns:
            ``(status message, Step 1 config JSON)``.
        """
        if not config_json:
            return "No configuration provided", ""

        try:
            config = json.loads(config_json)
            step1_config = self.ft_workflow.prepare_step1_config(
                base_config=config,
                use_ddp=use_ddp,
                nproc_per_node=int(nproc)
            )

            temp_path = Path("temp_ft_step1_config.json")
            with open(temp_path, 'w', encoding='utf-8') as f:
                json.dump(step1_config, f, indent=2)

            self.current_step1_config = str(temp_path)
            step1_json = json.dumps(step1_config, indent=2, ensure_ascii=False)

            return "Step 1 config prepared. Click 'Run Step 1' to train FT embeddings.", step1_json

        except Exception as e:
            return f"Error preparing Step 1 config: {str(e)}", ""

    def prepare_ft_step2(self, step1_config_path: str, target_models: str) -> tuple[str, str, str]:
        """Prepare FT Step 2 configurations.

        Returns:
            ``(status message, XGB config JSON, ResN config JSON)``; a config
            JSON is empty when that model was not generated.
        """
        if not step1_config_path or not os.path.exists(step1_config_path):
            return "Step 1 config not found. Run Step 1 first.", "", ""

        try:
            models = self._split_csv(target_models)
            xgb_cfg, resn_cfg = self.ft_workflow.generate_step2_configs(
                step1_config_path=step1_config_path,
                target_models=models
            )

            status_msg = f"Step 2 configs prepared for: {', '.join(models)}"
            xgb_json = json.dumps(
                xgb_cfg, indent=2, ensure_ascii=False) if xgb_cfg else ""
            resn_json = json.dumps(
                resn_cfg, indent=2, ensure_ascii=False) if resn_cfg else ""

            return status_msg, xgb_json, resn_json

        except FileNotFoundError as e:
            return f"Error: {str(e)}\n\nMake sure Step 1 completed successfully.", "", ""
        except Exception as e:
            return f"Error preparing Step 2 configs: {str(e)}", "", ""

    def open_results_folder(self, config_json: str) -> str:
        """Open the results folder in file explorer.

        ``output_dir`` is resolved against the same base directory that
        ``run_training`` uses for that config source. This is the loaded
        config's directory if there is one, otherwise the CWD.

        Returns:
            A human-readable status message.
        """
        try:
            if config_json:
                config = json.loads(config_json)
                base_dir = self.current_config_dir or Path.cwd()
            elif self.current_config_path and self.current_config_path.exists():
                config = json.loads(
                    self.current_config_path.read_text(encoding="utf-8"))
                base_dir = self.current_config_path.parent
            elif self.current_config:
                config = self.current_config
                base_dir = Path.cwd()
            else:
                return "No configuration loaded"

            output_dir = config.get('output_dir', './Results')
            # Joining with an absolute output_dir yields that absolute path.
            results_path = (base_dir / output_dir).resolve()

            if not results_path.exists():
                return f"Results folder does not exist yet: {results_path}"

            system = platform.system()
            if system == "Windows":
                os.startfile(results_path)
            elif system == "Darwin":  # macOS
                subprocess.run(["open", str(results_path)])
            else:  # Linux
                subprocess.run(["xdg-open", str(results_path)])

            return f"Opened folder: {results_path}"

        except Exception as e:
            return f"Error opening folder: {str(e)}"

    def _run_workflow(self, label: str, func: Callable, *args, **kwargs):
        """Run a workflow function via the task runner and stream its logs.

        Yields:
            ``(status message, accumulated log text)`` tuples.
        """
        try:
            log_generator = self.runner.run_callable(func, *args, **kwargs)
            full_log = ""
            for log_line in log_generator:
                full_log += log_line + "\n"
                yield f"{label} in progress...", full_log
            yield f"{label} completed!", full_log
        except Exception as e:
            error_msg = f"{label} error: {str(e)}"
            yield error_msg, error_msg

    def run_pre_oneway_ui(
        self,
        data_path: str,
        model_name: str,
        target_col: str,
        weight_col: str,
        feature_list: str,
        categorical_features: str,
        n_bins: int,
        holdout_ratio: float,
        rand_seed: int,
        output_dir: str,
    ):
        """Stream logs from ``run_pre_oneway``; an empty output_dir becomes None."""
        yield from self._run_workflow(
            "Pre-Oneway Plot",
            run_pre_oneway,
            data_path=data_path,
            model_name=model_name,
            target_col=target_col,
            weight_col=weight_col,
            feature_list=feature_list,
            categorical_features=categorical_features,
            n_bins=n_bins,
            holdout_ratio=holdout_ratio,
            rand_seed=rand_seed,
            output_dir=output_dir or None,
        )

    def run_plot_direct_ui(self, cfg_path: str, xgb_cfg_path: str, resn_cfg_path: str):
        """Stream logs from ``run_plot_direct``."""
        yield from self._run_workflow(
            "Direct Plot",
            run_plot_direct,
            cfg_path=cfg_path,
            xgb_cfg_path=xgb_cfg_path,
            resn_cfg_path=resn_cfg_path,
        )

    def run_plot_embed_ui(
        self,
        cfg_path: str,
        xgb_cfg_path: str,
        resn_cfg_path: str,
        ft_cfg_path: str,
        use_runtime_ft_embedding: bool,
    ):
        """Stream logs from ``run_plot_embed``."""
        yield from self._run_workflow(
            "Embed Plot",
            run_plot_embed,
            cfg_path=cfg_path,
            xgb_cfg_path=xgb_cfg_path,
            resn_cfg_path=resn_cfg_path,
            ft_cfg_path=ft_cfg_path,
            use_runtime_ft_embedding=use_runtime_ft_embedding,
        )

    def run_predict_ui(
        self,
        ft_cfg_path: str,
        xgb_cfg_path: str,
        resn_cfg_path: str,
        input_path: str,
        output_path: str,
        model_name: str,
        model_keys: str,
    ):
        """Stream logs from ``run_predict_ft_embed``; empty optional fields become None."""
        yield from self._run_workflow(
            "Prediction",
            run_predict_ft_embed,
            ft_cfg_path=ft_cfg_path,
            xgb_cfg_path=xgb_cfg_path or None,
            resn_cfg_path=resn_cfg_path or None,
            input_path=input_path,
            output_path=output_path,
            model_name=model_name or None,
            model_keys=model_keys,
        )

    def run_compare_xgb_ui(
        self,
        direct_cfg_path: str,
        ft_cfg_path: str,
        ft_embed_cfg_path: str,
        label_direct: str,
        label_ft: str,
        use_runtime_ft_embedding: bool,
        n_bins_override: int,
    ):
        """Stream logs from ``run_compare_ft_embed`` with ``model_key="xgb"``."""
        yield from self._run_workflow(
            "Compare XGB",
            run_compare_ft_embed,
            direct_cfg_path=direct_cfg_path,
            ft_cfg_path=ft_cfg_path,
            ft_embed_cfg_path=ft_embed_cfg_path,
            model_key="xgb",
            label_direct=label_direct,
            label_ft=label_ft,
            use_runtime_ft_embedding=use_runtime_ft_embedding,
            n_bins_override=n_bins_override,
        )

    def run_compare_resn_ui(
        self,
        direct_cfg_path: str,
        ft_cfg_path: str,
        ft_embed_cfg_path: str,
        label_direct: str,
        label_ft: str,
        use_runtime_ft_embedding: bool,
        n_bins_override: int,
    ):
        """Stream logs from ``run_compare_ft_embed`` with ``model_key="resn"``."""
        yield from self._run_workflow(
            "Compare ResNet",
            run_compare_ft_embed,
            direct_cfg_path=direct_cfg_path,
            ft_cfg_path=ft_cfg_path,
            ft_embed_cfg_path=ft_embed_cfg_path,
            model_key="resn",
            label_direct=label_direct,
            label_ft=label_ft,
            use_runtime_ft_embedding=use_runtime_ft_embedding,
            n_bins_override=n_bins_override,
        )
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
def create_ui():
|
|
457
|
+
"""Create the Gradio interface."""
|
|
458
|
+
app = PricingApp()
|
|
459
|
+
|
|
460
|
+
with gr.Blocks(title="Insurance Pricing Model Training", theme=gr.themes.Soft()) as demo:
|
|
461
|
+
gr.Markdown(
|
|
462
|
+
"""
|
|
463
|
+
# Insurance Pricing Model Training Interface
|
|
464
|
+
Configure and train insurance pricing models with an easy-to-use interface.
|
|
465
|
+
|
|
466
|
+
**Two ways to configure:**
|
|
467
|
+
1. **Upload JSON Config**: Upload an existing configuration file
|
|
468
|
+
2. **Manual Configuration**: Fill in the parameters below
|
|
469
|
+
"""
|
|
470
|
+
)
|
|
471
|
+
|
|
472
|
+
with gr.Tab("Configuration"):
|
|
473
|
+
with gr.Row():
|
|
474
|
+
with gr.Column(scale=1):
|
|
475
|
+
gr.Markdown("### Load Configuration")
|
|
476
|
+
json_file = gr.File(
|
|
477
|
+
label="Upload JSON Config File",
|
|
478
|
+
file_types=[".json"],
|
|
479
|
+
type="filepath"
|
|
480
|
+
)
|
|
481
|
+
load_btn = gr.Button("Load Config", variant="primary")
|
|
482
|
+
load_status = gr.Textbox(
|
|
483
|
+
label="Load Status", interactive=False)
|
|
484
|
+
|
|
485
|
+
with gr.Column(scale=2):
|
|
486
|
+
gr.Markdown("### Current Configuration")
|
|
487
|
+
config_display = gr.JSON(label="Configuration", value={})
|
|
488
|
+
|
|
489
|
+
gr.Markdown("---")
|
|
490
|
+
gr.Markdown("### Manual Configuration")
|
|
491
|
+
|
|
492
|
+
with gr.Row():
|
|
493
|
+
with gr.Column():
|
|
494
|
+
gr.Markdown("#### Data Settings")
|
|
495
|
+
data_dir = gr.Textbox(
|
|
496
|
+
label="Data Directory", value="./Data")
|
|
497
|
+
model_list = gr.Textbox(
|
|
498
|
+
label="Model List (comma-separated)", value="od")
|
|
499
|
+
model_categories = gr.Textbox(
|
|
500
|
+
label="Model Categories (comma-separated)", value="bc")
|
|
501
|
+
target = gr.Textbox(
|
|
502
|
+
label="Target Column", value="response")
|
|
503
|
+
weight = gr.Textbox(label="Weight Column", value="weights")
|
|
504
|
+
|
|
505
|
+
gr.Markdown("#### Features")
|
|
506
|
+
feature_list = gr.Textbox(
|
|
507
|
+
label="Feature List (comma-separated)",
|
|
508
|
+
placeholder="feature_1, feature_2, feature_3",
|
|
509
|
+
lines=3
|
|
510
|
+
)
|
|
511
|
+
categorical_features = gr.Textbox(
|
|
512
|
+
label="Categorical Features (comma-separated)",
|
|
513
|
+
placeholder="feature_2, feature_3",
|
|
514
|
+
lines=2
|
|
515
|
+
)
|
|
516
|
+
|
|
517
|
+
with gr.Column():
|
|
518
|
+
gr.Markdown("#### Model Settings")
|
|
519
|
+
task_type = gr.Dropdown(
|
|
520
|
+
label="Task Type",
|
|
521
|
+
choices=["regression", "binary", "multiclass"],
|
|
522
|
+
value="regression"
|
|
523
|
+
)
|
|
524
|
+
prop_test = gr.Slider(
|
|
525
|
+
label="Test Proportion", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
|
|
526
|
+
holdout_ratio = gr.Slider(
|
|
527
|
+
label="Holdout Ratio", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
|
|
528
|
+
val_ratio = gr.Slider(
|
|
529
|
+
label="Validation Ratio", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
|
|
530
|
+
split_strategy = gr.Dropdown(
|
|
531
|
+
label="Split Strategy",
|
|
532
|
+
choices=["random", "stratified", "time", "group"],
|
|
533
|
+
value="random"
|
|
534
|
+
)
|
|
535
|
+
rand_seed = gr.Number(
|
|
536
|
+
label="Random Seed", value=13, precision=0)
|
|
537
|
+
epochs = gr.Number(label="Epochs", value=50, precision=0)
|
|
538
|
+
|
|
539
|
+
with gr.Row():
|
|
540
|
+
with gr.Column():
|
|
541
|
+
gr.Markdown("#### Training Settings")
|
|
542
|
+
output_dir = gr.Textbox(
|
|
543
|
+
label="Output Directory", value="./Results")
|
|
544
|
+
use_gpu = gr.Checkbox(label="Use GPU", value=True)
|
|
545
|
+
model_keys = gr.Textbox(
|
|
546
|
+
label="Model Keys (comma-separated)",
|
|
547
|
+
value="xgb, resn",
|
|
548
|
+
placeholder="xgb, resn, ft, gnn"
|
|
549
|
+
)
|
|
550
|
+
max_evals = gr.Number(
|
|
551
|
+
label="Max Evaluations", value=50, precision=0)
|
|
552
|
+
|
|
553
|
+
with gr.Column():
|
|
554
|
+
gr.Markdown("#### XGBoost Settings")
|
|
555
|
+
xgb_max_depth_max = gr.Number(
|
|
556
|
+
label="XGB Max Depth", value=25, precision=0)
|
|
557
|
+
xgb_n_estimators_max = gr.Number(
|
|
558
|
+
label="XGB Max Estimators", value=500, precision=0)
|
|
559
|
+
|
|
560
|
+
with gr.Row():
|
|
561
|
+
build_btn = gr.Button(
|
|
562
|
+
"Build Configuration", variant="primary", size="lg")
|
|
563
|
+
save_config_btn = gr.Button(
|
|
564
|
+
"Save Configuration", variant="secondary", size="lg")
|
|
565
|
+
|
|
566
|
+
with gr.Row():
|
|
567
|
+
build_status = gr.Textbox(label="Status", interactive=False)
|
|
568
|
+
config_json = gr.Textbox(
|
|
569
|
+
label="Generated Config (JSON)", lines=10, max_lines=20)
|
|
570
|
+
|
|
571
|
+
save_filename = gr.Textbox(
|
|
572
|
+
label="Save Filename", value="my_config.json")
|
|
573
|
+
save_status = gr.Textbox(label="Save Status", interactive=False)
|
|
574
|
+
|
|
575
|
+
with gr.Tab("Run Task"):
|
|
576
|
+
gr.Markdown(
|
|
577
|
+
"""
|
|
578
|
+
### Run Model Task
|
|
579
|
+
Click the button below to execute the task defined in your configuration.
|
|
580
|
+
Task type is automatically detected from `config.runner.mode`:
|
|
581
|
+
- **entry**: Standard model training
|
|
582
|
+
- **explain**: Model explanation (permutation, SHAP, integrated gradients)
|
|
583
|
+
- **incremental**: Incremental training
|
|
584
|
+
- **watchdog**: Watchdog mode
|
|
585
|
+
|
|
586
|
+
Task logs will appear in real-time below.
|
|
587
|
+
"""
|
|
588
|
+
)
|
|
589
|
+
|
|
590
|
+
with gr.Row():
|
|
591
|
+
run_btn = gr.Button("Run Task", variant="primary", size="lg")
|
|
592
|
+
run_status = gr.Textbox(label="Task Status", interactive=False)
|
|
593
|
+
|
|
594
|
+
gr.Markdown("### Task Logs")
|
|
595
|
+
log_output = gr.Textbox(
|
|
596
|
+
label="Logs",
|
|
597
|
+
lines=25,
|
|
598
|
+
max_lines=50,
|
|
599
|
+
interactive=False,
|
|
600
|
+
autoscroll=True
|
|
601
|
+
)
|
|
602
|
+
|
|
603
|
+
gr.Markdown("---")
|
|
604
|
+
with gr.Row():
|
|
605
|
+
open_folder_btn = gr.Button("Open Results Folder", size="lg")
|
|
606
|
+
folder_status = gr.Textbox(
|
|
607
|
+
label="Status", interactive=False, scale=2)
|
|
608
|
+
|
|
609
|
+
with gr.Tab("FT Two-Step Workflow"):
|
|
610
|
+
gr.Markdown(
|
|
611
|
+
"""
|
|
612
|
+
### FT-Transformer Two-Step Training
|
|
613
|
+
|
|
614
|
+
Automates the FT → XGB/ResN workflow:
|
|
615
|
+
1. **Step 1**: Train FT-Transformer as unsupervised embedding generator
|
|
616
|
+
2. **Step 2**: Merge embeddings with raw data and train XGB/ResN
|
|
617
|
+
|
|
618
|
+
**Instructions**:
|
|
619
|
+
1. Load or build a base configuration in the Configuration tab
|
|
620
|
+
2. Prepare Step 1 config (FT embeddings)
|
|
621
|
+
3. Run Step 1 to generate embeddings
|
|
622
|
+
4. Prepare Step 2 configs (XGB/ResN using embeddings)
|
|
623
|
+
5. Run Step 2 with the generated configs
|
|
624
|
+
"""
|
|
625
|
+
)
|
|
626
|
+
|
|
627
|
+
with gr.Row():
|
|
628
|
+
with gr.Column():
|
|
629
|
+
gr.Markdown("### Step 1: FT Embedding Generation")
|
|
630
|
+
ft_use_ddp = gr.Checkbox(
|
|
631
|
+
label="Use DDP for FT", value=True)
|
|
632
|
+
ft_nproc = gr.Number(
|
|
633
|
+
label="Number of Processes (DDP)", value=2, precision=0)
|
|
634
|
+
|
|
635
|
+
prepare_step1_btn = gr.Button(
|
|
636
|
+
"Prepare Step 1 Config", variant="primary")
|
|
637
|
+
step1_status = gr.Textbox(
|
|
638
|
+
label="Status", interactive=False)
|
|
639
|
+
step1_config_display = gr.Textbox(
|
|
640
|
+
label="Step 1 Config (FT Embedding)",
|
|
641
|
+
lines=15,
|
|
642
|
+
max_lines=25
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
with gr.Column():
|
|
646
|
+
gr.Markdown("### Step 2: Train XGB/ResN with Embeddings")
|
|
647
|
+
target_models_input = gr.Textbox(
|
|
648
|
+
label="Target Models (comma-separated)",
|
|
649
|
+
value="xgb, resn",
|
|
650
|
+
placeholder="xgb, resn"
|
|
651
|
+
)
|
|
652
|
+
|
|
653
|
+
prepare_step2_btn = gr.Button(
|
|
654
|
+
"Prepare Step 2 Configs", variant="primary")
|
|
655
|
+
step2_status = gr.Textbox(
|
|
656
|
+
label="Status", interactive=False)
|
|
657
|
+
|
|
658
|
+
with gr.Tab("XGB Config"):
|
|
659
|
+
xgb_config_display = gr.Textbox(
|
|
660
|
+
label="XGB Step 2 Config",
|
|
661
|
+
lines=15,
|
|
662
|
+
max_lines=25
|
|
663
|
+
)
|
|
664
|
+
|
|
665
|
+
with gr.Tab("ResN Config"):
|
|
666
|
+
resn_config_display = gr.Textbox(
|
|
667
|
+
label="ResN Step 2 Config",
|
|
668
|
+
lines=15,
|
|
669
|
+
max_lines=25
|
|
670
|
+
)
|
|
671
|
+
|
|
672
|
+
gr.Markdown("---")
|
|
673
|
+
gr.Markdown(
|
|
674
|
+
"""
|
|
675
|
+
### Quick Actions
|
|
676
|
+
After preparing configs, you can:
|
|
677
|
+
- Copy the Step 1 config and paste it in the **Configuration** tab, then run it in **Run Task** tab
|
|
678
|
+
- After Step 1 completes, click **Prepare Step 2 Configs**
|
|
679
|
+
- Copy the Step 2 configs (XGB or ResN) and run them in **Run Task** tab
|
|
680
|
+
"""
|
|
681
|
+
)
|
|
682
|
+
|
|
683
|
+
with gr.Tab("Plotting"):
|
|
684
|
+
gr.Markdown(
|
|
685
|
+
"""
|
|
686
|
+
### Plotting Workflows
|
|
687
|
+
Run the plotting steps from the example notebooks.
|
|
688
|
+
"""
|
|
689
|
+
)
|
|
690
|
+
|
|
691
|
+
with gr.Tab("Pre Oneway"):
|
|
692
|
+
with gr.Row():
|
|
693
|
+
with gr.Column():
|
|
694
|
+
pre_data_path = gr.Textbox(
|
|
695
|
+
label="Data Path", value="./Data/od_bc.csv")
|
|
696
|
+
pre_model_name = gr.Textbox(
|
|
697
|
+
label="Model Name", value="od_bc")
|
|
698
|
+
pre_target = gr.Textbox(
|
|
699
|
+
label="Target Column", value="response")
|
|
700
|
+
pre_weight = gr.Textbox(
|
|
701
|
+
label="Weight Column", value="weights")
|
|
702
|
+
pre_output_dir = gr.Textbox(
|
|
703
|
+
label="Output Dir (optional)", value="")
|
|
704
|
+
with gr.Column():
|
|
705
|
+
pre_feature_list = gr.Textbox(
|
|
706
|
+
label="Feature List (comma-separated)",
|
|
707
|
+
lines=4,
|
|
708
|
+
placeholder="feature_1, feature_2, feature_3",
|
|
709
|
+
)
|
|
710
|
+
pre_categorical = gr.Textbox(
|
|
711
|
+
label="Categorical Features (comma-separated, optional)",
|
|
712
|
+
lines=3,
|
|
713
|
+
placeholder="feature_2, feature_3",
|
|
714
|
+
)
|
|
715
|
+
pre_n_bins = gr.Number(
|
|
716
|
+
label="Bins", value=10, precision=0)
|
|
717
|
+
pre_holdout = gr.Slider(
|
|
718
|
+
label="Holdout Ratio",
|
|
719
|
+
minimum=0.0,
|
|
720
|
+
maximum=0.5,
|
|
721
|
+
value=0.25,
|
|
722
|
+
step=0.05,
|
|
723
|
+
)
|
|
724
|
+
pre_seed = gr.Number(
|
|
725
|
+
label="Random Seed", value=13, precision=0)
|
|
726
|
+
|
|
727
|
+
pre_run_btn = gr.Button("Run Pre Oneway", variant="primary")
|
|
728
|
+
pre_status = gr.Textbox(label="Status", interactive=False)
|
|
729
|
+
pre_log = gr.Textbox(label="Logs", lines=15,
|
|
730
|
+
max_lines=40, interactive=False)
|
|
731
|
+
|
|
732
|
+
with gr.Tab("Direct Plot"):
|
|
733
|
+
direct_cfg_path = gr.Textbox(
|
|
734
|
+
label="Plot Config", value="config_plot.json")
|
|
735
|
+
direct_xgb_cfg = gr.Textbox(
|
|
736
|
+
label="XGB Config", value="config_xgb_direct.json")
|
|
737
|
+
direct_resn_cfg = gr.Textbox(
|
|
738
|
+
label="ResN Config", value="config_resn_direct.json")
|
|
739
|
+
direct_run_btn = gr.Button(
|
|
740
|
+
"Run Direct Plot", variant="primary")
|
|
741
|
+
direct_status = gr.Textbox(label="Status", interactive=False)
|
|
742
|
+
direct_log = gr.Textbox(
|
|
743
|
+
label="Logs", lines=15, max_lines=40, interactive=False)
|
|
744
|
+
|
|
745
|
+
with gr.Tab("Embed Plot"):
|
|
746
|
+
embed_cfg_path = gr.Textbox(
|
|
747
|
+
label="Plot Config", value="config_plot.json")
|
|
748
|
+
embed_xgb_cfg = gr.Textbox(
|
|
749
|
+
label="XGB Embed Config", value="config_xgb_from_ft_unsupervised.json")
|
|
750
|
+
embed_resn_cfg = gr.Textbox(
|
|
751
|
+
label="ResN Embed Config", value="config_resn_from_ft_unsupervised.json")
|
|
752
|
+
embed_ft_cfg = gr.Textbox(
|
|
753
|
+
label="FT Embed Config", value="config_ft_unsupervised_ddp_embed.json")
|
|
754
|
+
embed_runtime = gr.Checkbox(
|
|
755
|
+
label="Use Runtime FT Embedding", value=False)
|
|
756
|
+
embed_run_btn = gr.Button("Run Embed Plot", variant="primary")
|
|
757
|
+
embed_status = gr.Textbox(label="Status", interactive=False)
|
|
758
|
+
embed_log = gr.Textbox(
|
|
759
|
+
label="Logs", lines=15, max_lines=40, interactive=False)
|
|
760
|
+
|
|
761
|
+
with gr.Tab("Prediction"):
|
|
762
|
+
gr.Markdown("### FT Embed Prediction")
|
|
763
|
+
pred_ft_cfg = gr.Textbox(
|
|
764
|
+
label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
|
|
765
|
+
pred_xgb_cfg = gr.Textbox(
|
|
766
|
+
label="XGB Config (optional)", value="config_xgb_from_ft_unsupervised.json")
|
|
767
|
+
pred_resn_cfg = gr.Textbox(
|
|
768
|
+
label="ResN Config (optional)", value="config_resn_from_ft_unsupervised.json")
|
|
769
|
+
pred_input = gr.Textbox(
|
|
770
|
+
label="Input Data", value="./Data/od_bc_new.csv")
|
|
771
|
+
pred_output = gr.Textbox(
|
|
772
|
+
label="Output CSV", value="./Results/predictions_ft_xgb.csv")
|
|
773
|
+
pred_model_name = gr.Textbox(
|
|
774
|
+
label="Model Name (optional)", value="")
|
|
775
|
+
pred_model_keys = gr.Textbox(label="Model Keys", value="xgb, resn")
|
|
776
|
+
pred_run_btn = gr.Button("Run Prediction", variant="primary")
|
|
777
|
+
pred_status = gr.Textbox(label="Status", interactive=False)
|
|
778
|
+
pred_log = gr.Textbox(label="Logs", lines=15,
|
|
779
|
+
max_lines=40, interactive=False)
|
|
780
|
+
|
|
781
|
+
with gr.Tab("Compare"):
|
|
782
|
+
gr.Markdown("### Compare Direct vs FT-Embed Models")
|
|
783
|
+
|
|
784
|
+
with gr.Tab("Compare XGB"):
|
|
785
|
+
cmp_xgb_direct_cfg = gr.Textbox(
|
|
786
|
+
label="Direct XGB Config", value="config_xgb_direct.json")
|
|
787
|
+
cmp_xgb_ft_cfg = gr.Textbox(
|
|
788
|
+
label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
|
|
789
|
+
cmp_xgb_embed_cfg = gr.Textbox(
|
|
790
|
+
label="FT-Embed XGB Config", value="config_xgb_from_ft_unsupervised.json")
|
|
791
|
+
cmp_xgb_label_direct = gr.Textbox(
|
|
792
|
+
label="Direct Label", value="XGB_raw")
|
|
793
|
+
cmp_xgb_label_ft = gr.Textbox(
|
|
794
|
+
label="FT Label", value="XGB_ft_embed")
|
|
795
|
+
cmp_xgb_runtime = gr.Checkbox(
|
|
796
|
+
label="Use Runtime FT Embedding", value=False)
|
|
797
|
+
cmp_xgb_bins = gr.Number(
|
|
798
|
+
label="Bins Override", value=10, precision=0)
|
|
799
|
+
cmp_xgb_run_btn = gr.Button(
|
|
800
|
+
"Run XGB Compare", variant="primary")
|
|
801
|
+
cmp_xgb_status = gr.Textbox(label="Status", interactive=False)
|
|
802
|
+
cmp_xgb_log = gr.Textbox(
|
|
803
|
+
label="Logs", lines=15, max_lines=40, interactive=False)
|
|
804
|
+
|
|
805
|
+
with gr.Tab("Compare ResNet"):
|
|
806
|
+
cmp_resn_direct_cfg = gr.Textbox(
|
|
807
|
+
label="Direct ResN Config", value="config_resn_direct.json")
|
|
808
|
+
cmp_resn_ft_cfg = gr.Textbox(
|
|
809
|
+
label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
|
|
810
|
+
cmp_resn_embed_cfg = gr.Textbox(
|
|
811
|
+
label="FT-Embed ResN Config", value="config_resn_from_ft_unsupervised.json")
|
|
812
|
+
cmp_resn_label_direct = gr.Textbox(
|
|
813
|
+
label="Direct Label", value="ResN_raw")
|
|
814
|
+
cmp_resn_label_ft = gr.Textbox(
|
|
815
|
+
label="FT Label", value="ResN_ft_embed")
|
|
816
|
+
cmp_resn_runtime = gr.Checkbox(
|
|
817
|
+
label="Use Runtime FT Embedding", value=False)
|
|
818
|
+
cmp_resn_bins = gr.Number(
|
|
819
|
+
label="Bins Override", value=10, precision=0)
|
|
820
|
+
cmp_resn_run_btn = gr.Button(
|
|
821
|
+
"Run ResNet Compare", variant="primary")
|
|
822
|
+
cmp_resn_status = gr.Textbox(label="Status", interactive=False)
|
|
823
|
+
cmp_resn_log = gr.Textbox(
|
|
824
|
+
label="Logs", lines=15, max_lines=40, interactive=False)
|
|
825
|
+
|
|
826
|
+
# Event handlers
|
|
827
|
+
load_btn.click(
|
|
828
|
+
fn=app.load_json_config,
|
|
829
|
+
inputs=[json_file],
|
|
830
|
+
outputs=[load_status, config_display, config_json]
|
|
831
|
+
)
|
|
832
|
+
|
|
833
|
+
build_btn.click(
|
|
834
|
+
fn=app.build_config_from_ui,
|
|
835
|
+
inputs=[
|
|
836
|
+
data_dir, model_list, model_categories, target, weight,
|
|
837
|
+
feature_list, categorical_features, task_type, prop_test,
|
|
838
|
+
holdout_ratio, val_ratio, split_strategy, rand_seed, epochs,
|
|
839
|
+
output_dir, use_gpu, model_keys, max_evals,
|
|
840
|
+
xgb_max_depth_max, xgb_n_estimators_max
|
|
841
|
+
],
|
|
842
|
+
outputs=[build_status, config_json]
|
|
843
|
+
)
|
|
844
|
+
|
|
845
|
+
save_config_btn.click(
|
|
846
|
+
fn=app.save_config,
|
|
847
|
+
inputs=[config_json, save_filename],
|
|
848
|
+
outputs=[save_status]
|
|
849
|
+
)
|
|
850
|
+
|
|
851
|
+
run_btn.click(
|
|
852
|
+
fn=app.run_training,
|
|
853
|
+
inputs=[config_json],
|
|
854
|
+
outputs=[run_status, log_output]
|
|
855
|
+
)
|
|
856
|
+
|
|
857
|
+
open_folder_btn.click(
|
|
858
|
+
fn=app.open_results_folder,
|
|
859
|
+
inputs=[config_json],
|
|
860
|
+
outputs=[folder_status]
|
|
861
|
+
)
|
|
862
|
+
|
|
863
|
+
prepare_step1_btn.click(
|
|
864
|
+
fn=app.prepare_ft_step1,
|
|
865
|
+
inputs=[config_json, ft_use_ddp, ft_nproc],
|
|
866
|
+
outputs=[step1_status, step1_config_display]
|
|
867
|
+
)
|
|
868
|
+
|
|
869
|
+
prepare_step2_btn.click(
|
|
870
|
+
fn=app.prepare_ft_step2,
|
|
871
|
+
inputs=[gr.State(
|
|
872
|
+
lambda: app.current_step1_config or "temp_ft_step1_config.json"), target_models_input],
|
|
873
|
+
outputs=[step2_status, xgb_config_display, resn_config_display]
|
|
874
|
+
)
|
|
875
|
+
|
|
876
|
+
pre_run_btn.click(
|
|
877
|
+
fn=app.run_pre_oneway_ui,
|
|
878
|
+
inputs=[
|
|
879
|
+
pre_data_path, pre_model_name, pre_target, pre_weight,
|
|
880
|
+
pre_feature_list, pre_categorical, pre_n_bins,
|
|
881
|
+
pre_holdout, pre_seed, pre_output_dir
|
|
882
|
+
],
|
|
883
|
+
outputs=[pre_status, pre_log]
|
|
884
|
+
)
|
|
885
|
+
|
|
886
|
+
direct_run_btn.click(
|
|
887
|
+
fn=app.run_plot_direct_ui,
|
|
888
|
+
inputs=[direct_cfg_path, direct_xgb_cfg, direct_resn_cfg],
|
|
889
|
+
outputs=[direct_status, direct_log]
|
|
890
|
+
)
|
|
891
|
+
|
|
892
|
+
embed_run_btn.click(
|
|
893
|
+
fn=app.run_plot_embed_ui,
|
|
894
|
+
inputs=[embed_cfg_path, embed_xgb_cfg,
|
|
895
|
+
embed_resn_cfg, embed_ft_cfg, embed_runtime],
|
|
896
|
+
outputs=[embed_status, embed_log]
|
|
897
|
+
)
|
|
898
|
+
|
|
899
|
+
pred_run_btn.click(
|
|
900
|
+
fn=app.run_predict_ui,
|
|
901
|
+
inputs=[
|
|
902
|
+
pred_ft_cfg, pred_xgb_cfg, pred_resn_cfg, pred_input,
|
|
903
|
+
pred_output, pred_model_name, pred_model_keys
|
|
904
|
+
],
|
|
905
|
+
outputs=[pred_status, pred_log]
|
|
906
|
+
)
|
|
907
|
+
|
|
908
|
+
cmp_xgb_run_btn.click(
|
|
909
|
+
fn=app.run_compare_xgb_ui,
|
|
910
|
+
inputs=[
|
|
911
|
+
cmp_xgb_direct_cfg, cmp_xgb_ft_cfg, cmp_xgb_embed_cfg,
|
|
912
|
+
cmp_xgb_label_direct, cmp_xgb_label_ft,
|
|
913
|
+
cmp_xgb_runtime, cmp_xgb_bins
|
|
914
|
+
],
|
|
915
|
+
outputs=[cmp_xgb_status, cmp_xgb_log]
|
|
916
|
+
)
|
|
917
|
+
|
|
918
|
+
cmp_resn_run_btn.click(
|
|
919
|
+
fn=app.run_compare_resn_ui,
|
|
920
|
+
inputs=[
|
|
921
|
+
cmp_resn_direct_cfg, cmp_resn_ft_cfg, cmp_resn_embed_cfg,
|
|
922
|
+
cmp_resn_label_direct, cmp_resn_label_ft,
|
|
923
|
+
cmp_resn_runtime, cmp_resn_bins
|
|
924
|
+
],
|
|
925
|
+
outputs=[cmp_resn_status, cmp_resn_log]
|
|
926
|
+
)
|
|
927
|
+
|
|
928
|
+
return demo
|
|
929
|
+
|
|
930
|
+
|
|
896
931
|
if __name__ == "__main__":
|
|
897
932
|
demo = create_ui()
|
|
898
|
-
|
|
899
|
-
server_name
|
|
900
|
-
server_port
|
|
901
|
-
share
|
|
902
|
-
show_error
|
|
903
|
-
|
|
933
|
+
launch_kwargs = {
|
|
934
|
+
"server_name": "0.0.0.0",
|
|
935
|
+
"server_port": 7860,
|
|
936
|
+
"share": False,
|
|
937
|
+
"show_error": True,
|
|
938
|
+
}
|
|
939
|
+
if "analytics_enabled" in inspect.signature(demo.launch).parameters:
|
|
940
|
+
launch_kwargs["analytics_enabled"] = False
|
|
941
|
+
demo.launch(**launch_kwargs)
|