ins-pricing 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,941 @@
1
+ """
2
+ Insurance Pricing Model Training Frontend
3
+ A Gradio-based web interface for configuring and running insurance pricing models.
4
+ """
5
+
6
+ import os
7
+ import platform
8
+ import subprocess
9
+ from ins_pricing.frontend.example_workflows import (
10
+ run_compare_ft_embed,
11
+ run_plot_direct,
12
+ run_plot_embed,
13
+ run_predict_ft_embed,
14
+ run_pre_oneway,
15
+ )
16
+ from ins_pricing.frontend.ft_workflow import FTWorkflowHelper
17
+ from ins_pricing.frontend.runner import TaskRunner
18
+ from ins_pricing.frontend.config_builder import ConfigBuilder
19
+ import json
20
+ import sys
21
+ import inspect
22
+ from pathlib import Path
23
+ from typing import Optional, Dict, Any, Callable, Iterable, Tuple
24
+ import threading
25
+ import queue
26
+ import time
27
+
28
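+ # Prepend the directory three levels above this file to sys.path. NOTE: the ins_pricing imports above have already executed, so this only affects imports performed after this line.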
29
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
30
+
31
+ os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
32
+ os.environ.setdefault("GRADIO_TELEMETRY_ENABLED", "False")
33
+ os.environ.setdefault("GRADIO_CHECK_VERSION", "False")
34
+ os.environ.setdefault("GRADIO_VERSION_CHECK", "False")
35
+
36
+ import gradio as gr
37
+
38
+
39
+
40
+
41
class FrontendDependencyError(RuntimeError):
    """Raised when a package needed by the frontend is missing or incompatible."""


def _check_frontend_deps() -> None:
    """Fail fast with a clear message if frontend deps are incompatible.

    Raises:
        FrontendDependencyError: if ``gradio`` or ``huggingface_hub`` cannot be
            imported, or if the imported ``huggingface_hub`` has no
            ``HfFolder`` attribute. For import failures the original
            exception is chained as ``__cause__`` so the real traceback is
            kept.
    """
    try:
        import gradio  # noqa: F401
    except Exception as exc:
        raise FrontendDependencyError(
            f"Failed to import gradio: {exc}") from exc

    try:
        import huggingface_hub as hf
    except Exception as exc:
        raise FrontendDependencyError(
            f"Failed to import huggingface_hub: {exc}. "
            "Pin version with `pip install 'huggingface_hub<0.24'`."
        ) from exc

    # The import alone is not enough: the attribute probed here must exist.
    if not hasattr(hf, 'HfFolder'):
        raise FrontendDependencyError(
            'Incompatible huggingface_hub detected: missing HfFolder. '
            'Please install `huggingface_hub<0.24`.'
        )
65
+
66
+
67
class PricingApp:
    """Main application class for the insurance pricing model tasks interface.

    Keeps the configuration most recently loaded or built through the UI and
    provides the callbacks that the Gradio widgets are wired to. Plain
    callbacks return status strings (errors are reported as text rather than
    raised); streaming callbacks are generators yielding
    ``(status, accumulated_log)`` pairs.
    """

    def __init__(self):
        self.config_builder = ConfigBuilder()
        self.runner = TaskRunner()
        self.ft_workflow = FTWorkflowHelper()
        # Configuration most recently loaded from a file or built from the UI.
        self.current_config = {}
        # Path (as str) of the Step 1 config last written by prepare_ft_step1.
        self.current_step1_config = None
        # Set only when the configuration came from a file on disk.
        self.current_config_path: Optional[Path] = None
        self.current_config_dir: Optional[Path] = None

    @staticmethod
    def _split_csv(raw: Optional[str]) -> list[str]:
        """Split a comma-separated string into stripped, non-empty items.

        ``None`` and the empty string both give an empty list.
        """
        if not raw:
            return []
        return [item.strip() for item in raw.split(',') if item.strip()]

    @staticmethod
    def _task_mode(config: Dict[str, Any]) -> str:
        """Return ``config['runner']['mode']``, defaulting to ``'entry'``."""
        return config.get('runner', {}).get('mode', 'entry')

    @staticmethod
    def _write_json(path: Path, payload: Any) -> None:
        """Write *payload* to *path* as UTF-8 JSON with a 2-space indent."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2)

    def load_json_config(self, file_path) -> tuple[str, Dict[str, Any], str]:
        """Load configuration from uploaded JSON file.

        Returns:
            ``(status message, parsed config, pretty-printed JSON text)``.
            On failure, or when no file is given, the config is ``{}`` and
            the JSON text is empty.
        """
        if not file_path:
            return "No file uploaded", {}, ""

        try:
            path = Path(file_path).resolve()
            with open(path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            self.current_config = config
            self.current_config_path = path
            self.current_config_dir = path.parent
            config_json = json.dumps(config, indent=2, ensure_ascii=False)
            return f"Configuration loaded successfully from {path.name}", config, config_json
        except Exception as e:
            return f"Error loading config: {str(e)}", {}, ""

    def build_config_from_ui(
        self,
        data_dir: str,
        model_list: str,
        model_categories: str,
        target: str,
        weight: str,
        feature_list: str,
        categorical_features: str,
        task_type: str,
        prop_test: float,
        holdout_ratio: float,
        val_ratio: float,
        split_strategy: str,
        rand_seed: int,
        epochs: int,
        output_dir: str,
        use_gpu: bool,
        model_keys: str,
        max_evals: int,
        xgb_max_depth_max: int,
        xgb_n_estimators_max: int,
    ) -> tuple[str, str]:
        """Build configuration from UI parameters.

        The five list-like fields (``model_list``, ``model_categories``,
        ``feature_list``, ``categorical_features``, ``model_keys``) arrive as
        comma-separated text and are split before being handed to the config
        builder. The built config is validated; it replaces the current
        configuration only when validation passes.

        Returns:
            ``(status message, pretty-printed JSON text)``; the JSON text is
            empty when building or validation fails.
        """
        try:
            config = self.config_builder.build_config(
                data_dir=data_dir,
                model_list=self._split_csv(model_list),
                model_categories=self._split_csv(model_categories),
                target=target,
                weight=weight,
                feature_list=self._split_csv(feature_list),
                categorical_features=self._split_csv(categorical_features),
                task_type=task_type,
                prop_test=prop_test,
                holdout_ratio=holdout_ratio,
                val_ratio=val_ratio,
                split_strategy=split_strategy,
                rand_seed=rand_seed,
                epochs=epochs,
                output_dir=output_dir,
                use_gpu=use_gpu,
                model_keys=self._split_csv(model_keys),
                max_evals=max_evals,
                xgb_max_depth_max=xgb_max_depth_max,
                xgb_n_estimators_max=xgb_n_estimators_max,
            )

            is_valid, msg = self.config_builder.validate_config(config)
            if not is_valid:
                return f"Validation failed: {msg}", ""

            self.current_config = config
            # A UI-built config has no backing file.
            self.current_config_path = None
            self.current_config_dir = None
            config_json = json.dumps(config, indent=2, ensure_ascii=False)
            return "Configuration built successfully", config_json

        except Exception as e:
            return f"Error building config: {str(e)}", ""

    def save_config(self, config_json: str, filename: str) -> str:
        """Save current configuration to file.

        Args:
            config_json: JSON text to save.
            filename: Destination path.

        Returns:
            A status message; failures are reported as text, not raised.
        """
        if not config_json:
            return "No configuration to save"
        if not filename or not filename.strip():
            return "No filename provided"

        try:
            # Parse first: opening with 'w' truncates the target, so invalid
            # JSON must be rejected before the file is touched.
            config = json.loads(config_json)
            config_path = Path(filename)
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=2, ensure_ascii=False)
            return f"Configuration saved to {config_path}"
        except Exception as e:
            return f"Error saving config: {str(e)}"

    def run_training(self, config_json: str) -> Iterable[Tuple[str, str]]:
        """
        Run task (training, explain, plotting, etc.) with the current configuration.

        The task type is automatically detected from config.runner.mode.
        Supported modes: entry (training), explain, incremental, watchdog, etc.

        The configuration is taken from the first available source:
        ``config_json``, then the loaded config file on disk, then the
        in-memory current config. Sources without a backing file are written
        to a ``temp_config.json`` that is removed when the generator
        finishes, whether it succeeds, fails or is closed early.

        Yields:
            ``(status, accumulated_log)`` pairs: one per log line while the
            task runs, then a final "completed" pair. A missing configuration
            or any error is reported as a single yielded pair.
        """
        temp_config_path: Optional[Path] = None
        try:
            if config_json:
                config = json.loads(config_json)
                task_mode = self._task_mode(config)
                base_dir = self.current_config_dir or Path.cwd()
                temp_config_path = (base_dir / "temp_config.json").resolve()
                self._write_json(temp_config_path, config)
                config_path = temp_config_path
            elif self.current_config_path and self.current_config_path.exists():
                config_path = self.current_config_path
                config = json.loads(config_path.read_text(encoding="utf-8"))
                task_mode = self._task_mode(config)
            elif self.current_config:
                config = self.current_config
                task_mode = self._task_mode(config)
                temp_config_path = (Path.cwd() / "temp_config.json").resolve()
                self._write_json(temp_config_path, config)
                config_path = temp_config_path
            else:
                # This is a generator: a `return value` would be discarded,
                # so the message has to be yielded to reach the consumer.
                yield "No configuration provided", ""
                return

            full_log = ""
            for log_line in self.runner.run_task(str(config_path)):
                full_log += log_line + "\n"
                yield f"Task [{task_mode}] in progress...", full_log

            yield f"Task [{task_mode}] completed!", full_log

        except Exception as e:
            error_msg = f"Error during task execution: {str(e)}"
            yield error_msg, error_msg
        finally:
            if temp_config_path is not None:
                try:
                    temp_config_path.unlink(missing_ok=True)
                except OSError:
                    # Best-effort cleanup: a leftover temp file must not
                    # replace the task's own result or error.
                    pass

    def prepare_ft_step1(self, config_json: str, use_ddp: bool, nproc: int) -> tuple[str, str]:
        """Prepare FT Step 1 configuration.

        Derives the Step 1 config from ``config_json`` through the FT
        workflow helper, writes it to ``temp_ft_step1_config.json`` (relative
        to the current working directory) and remembers that path in
        ``self.current_step1_config``.

        Returns:
            ``(status message, Step 1 config as JSON text)``; the JSON text
            is empty on failure.
        """
        if not config_json:
            return "No configuration provided", ""

        try:
            config = json.loads(config_json)
            step1_config = self.ft_workflow.prepare_step1_config(
                base_config=config,
                use_ddp=use_ddp,
                nproc_per_node=int(nproc)
            )

            # Save to temp file
            temp_path = Path("temp_ft_step1_config.json")
            self._write_json(temp_path, step1_config)

            self.current_step1_config = str(temp_path)
            step1_json = json.dumps(step1_config, indent=2, ensure_ascii=False)

            return "Step 1 config prepared. Click 'Run Step 1' to train FT embeddings.", step1_json

        except Exception as e:
            return f"Error preparing Step 1 config: {str(e)}", ""

    def prepare_ft_step2(self, step1_config_path: str, target_models: str) -> tuple[str, str, str]:
        """Prepare FT Step 2 configurations.

        Args:
            step1_config_path: Path of the Step 1 config file; must exist.
            target_models: Comma-separated model keys.

        Returns:
            ``(status message, XGB config JSON, ResN config JSON)``. A JSON
            text is empty when the helper returned nothing for that model or
            when preparation failed.
        """
        if not step1_config_path or not os.path.exists(step1_config_path):
            return "Step 1 config not found. Run Step 1 first.", "", ""

        try:
            models = self._split_csv(target_models)
            xgb_cfg, resn_cfg = self.ft_workflow.generate_step2_configs(
                step1_config_path=step1_config_path,
                target_models=models
            )

            status_msg = f"Step 2 configs prepared for: {', '.join(models)}"
            xgb_json = json.dumps(
                xgb_cfg, indent=2, ensure_ascii=False) if xgb_cfg else ""
            resn_json = json.dumps(
                resn_cfg, indent=2, ensure_ascii=False) if resn_cfg else ""

            return status_msg, xgb_json, resn_json

        except FileNotFoundError as e:
            return f"Error: {str(e)}\n\nMake sure Step 1 completed successfully.", "", ""
        except Exception as e:
            return f"Error preparing Step 2 configs: {str(e)}", "", ""

    def open_results_folder(self, config_json: str) -> str:
        """Open the results folder in file explorer.

        The folder is the config's ``output_dir`` (default ``./Results``),
        taken from ``config_json``, else the loaded config file, else the
        in-memory current config.

        Returns:
            A status message; failures are reported as text, not raised.
        """
        try:
            if config_json:
                config = json.loads(config_json)
                output_dir = config.get('output_dir', './Results')
                # NOTE(review): a relative output_dir is resolved against the
                # process CWD here but against the config file's directory in
                # the branch below - confirm which one the runner uses.
                results_path = Path(output_dir).resolve()
            elif self.current_config_path and self.current_config_path.exists():
                config = json.loads(
                    self.current_config_path.read_text(encoding="utf-8"))
                output_dir = config.get('output_dir', './Results')
                results_path = (
                    self.current_config_path.parent / output_dir).resolve()
            elif self.current_config:
                output_dir = self.current_config.get('output_dir', './Results')
                results_path = Path(output_dir).resolve()
            else:
                return "No configuration loaded"

            if not results_path.exists():
                return f"Results folder does not exist yet: {results_path}"

            # Open folder based on OS
            system = platform.system()
            if system == "Windows":
                os.startfile(results_path)
            elif system == "Darwin":  # macOS
                # check=True: a non-zero exit must be reported as an error
                # instead of a false "Opened folder" message.
                subprocess.run(["open", str(results_path)], check=True)
            else:  # Linux
                subprocess.run(["xdg-open", str(results_path)], check=True)

            return f"Opened folder: {results_path}"

        except Exception as e:
            return f"Error opening folder: {str(e)}"

    def _run_workflow(self, label: str, func: Callable, *args, **kwargs) -> Iterable[Tuple[str, str]]:
        """Run a workflow function and stream logs.

        Yields ``(status, accumulated_log)`` after every log line produced by
        ``self.runner.run_callable(func, *args, **kwargs)``, then a final
        "completed" pair. An exception is reported as one yielded pair with
        the error text in both slots.
        """
        try:
            log_generator = self.runner.run_callable(func, *args, **kwargs)
            full_log = ""
            for log_line in log_generator:
                full_log += log_line + "\n"
                yield f"{label} in progress...", full_log
            yield f"{label} completed!", full_log
        except Exception as e:
            error_msg = f"{label} error: {str(e)}"
            yield error_msg, error_msg

    def run_pre_oneway_ui(
        self,
        data_path: str,
        model_name: str,
        target_col: str,
        weight_col: str,
        feature_list: str,
        categorical_features: str,
        n_bins: int,
        holdout_ratio: float,
        rand_seed: int,
        output_dir: str,
    ) -> Iterable[Tuple[str, str]]:
        """Stream ``run_pre_oneway``; an empty ``output_dir`` is passed as None."""
        yield from self._run_workflow(
            "Pre-Oneway Plot",
            run_pre_oneway,
            data_path=data_path,
            model_name=model_name,
            target_col=target_col,
            weight_col=weight_col,
            feature_list=feature_list,
            categorical_features=categorical_features,
            n_bins=n_bins,
            holdout_ratio=holdout_ratio,
            rand_seed=rand_seed,
            output_dir=output_dir or None,
        )

    def run_plot_direct_ui(self, cfg_path: str, xgb_cfg_path: str, resn_cfg_path: str) -> Iterable[Tuple[str, str]]:
        """Stream ``run_plot_direct`` with the three config paths."""
        yield from self._run_workflow(
            "Direct Plot",
            run_plot_direct,
            cfg_path=cfg_path,
            xgb_cfg_path=xgb_cfg_path,
            resn_cfg_path=resn_cfg_path,
        )

    def run_plot_embed_ui(
        self,
        cfg_path: str,
        xgb_cfg_path: str,
        resn_cfg_path: str,
        ft_cfg_path: str,
        use_runtime_ft_embedding: bool,
    ) -> Iterable[Tuple[str, str]]:
        """Stream ``run_plot_embed`` with the given config paths."""
        yield from self._run_workflow(
            "Embed Plot",
            run_plot_embed,
            cfg_path=cfg_path,
            xgb_cfg_path=xgb_cfg_path,
            resn_cfg_path=resn_cfg_path,
            ft_cfg_path=ft_cfg_path,
            use_runtime_ft_embedding=use_runtime_ft_embedding,
        )

    def run_predict_ui(
        self,
        ft_cfg_path: str,
        xgb_cfg_path: str,
        resn_cfg_path: str,
        input_path: str,
        output_path: str,
        model_name: str,
        model_keys: str,
    ) -> Iterable[Tuple[str, str]]:
        """Stream ``run_predict_ft_embed``.

        Empty ``xgb_cfg_path``, ``resn_cfg_path`` and ``model_name`` values
        are passed on as ``None``.
        """
        yield from self._run_workflow(
            "Prediction",
            run_predict_ft_embed,
            ft_cfg_path=ft_cfg_path,
            xgb_cfg_path=xgb_cfg_path or None,
            resn_cfg_path=resn_cfg_path or None,
            input_path=input_path,
            output_path=output_path,
            model_name=model_name or None,
            model_keys=model_keys,
        )

    def run_compare_xgb_ui(
        self,
        direct_cfg_path: str,
        ft_cfg_path: str,
        ft_embed_cfg_path: str,
        label_direct: str,
        label_ft: str,
        use_runtime_ft_embedding: bool,
        n_bins_override: int,
    ) -> Iterable[Tuple[str, str]]:
        """Stream ``run_compare_ft_embed`` with ``model_key="xgb"``."""
        yield from self._run_workflow(
            "Compare XGB",
            run_compare_ft_embed,
            direct_cfg_path=direct_cfg_path,
            ft_cfg_path=ft_cfg_path,
            ft_embed_cfg_path=ft_embed_cfg_path,
            model_key="xgb",
            label_direct=label_direct,
            label_ft=label_ft,
            use_runtime_ft_embedding=use_runtime_ft_embedding,
            n_bins_override=n_bins_override,
        )

    def run_compare_resn_ui(
        self,
        direct_cfg_path: str,
        ft_cfg_path: str,
        ft_embed_cfg_path: str,
        label_direct: str,
        label_ft: str,
        use_runtime_ft_embedding: bool,
        n_bins_override: int,
    ) -> Iterable[Tuple[str, str]]:
        """Stream ``run_compare_ft_embed`` with ``model_key="resn"``."""
        yield from self._run_workflow(
            "Compare ResNet",
            run_compare_ft_embed,
            direct_cfg_path=direct_cfg_path,
            ft_cfg_path=ft_cfg_path,
            ft_embed_cfg_path=ft_embed_cfg_path,
            model_key="resn",
            label_direct=label_direct,
            label_ft=label_ft,
            use_runtime_ft_embedding=use_runtime_ft_embedding,
            n_bins_override=n_bins_override,
        )
454
+
455
+
456
def create_ui():
    """Create the Gradio interface.

    Builds a ``gr.Blocks`` layout with the tabs "Configuration", "Run Task",
    "FT Two-Step Workflow", "Plotting", "Prediction" and "Compare", and wires
    every button to the matching callback of a single ``PricingApp``
    instance created here.

    Returns:
        The assembled ``gr.Blocks`` object (not yet launched).
    """
    app = PricingApp()

    with gr.Blocks(title="Insurance Pricing Model Training", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # Insurance Pricing Model Training Interface
            Configure and train insurance pricing models with an easy-to-use interface.

            **Two ways to configure:**
            1. **Upload JSON Config**: Upload an existing configuration file
            2. **Manual Configuration**: Fill in the parameters below
            """
        )

        with gr.Tab("Configuration"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Load Configuration")
                    json_file = gr.File(
                        label="Upload JSON Config File",
                        file_types=[".json"],
                        type="filepath"
                    )
                    load_btn = gr.Button("Load Config", variant="primary")
                    load_status = gr.Textbox(
                        label="Load Status", interactive=False)

                with gr.Column(scale=2):
                    gr.Markdown("### Current Configuration")
                    config_display = gr.JSON(label="Configuration", value={})

            gr.Markdown("---")
            gr.Markdown("### Manual Configuration")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Data Settings")
                    data_dir = gr.Textbox(
                        label="Data Directory", value="./Data")
                    model_list = gr.Textbox(
                        label="Model List (comma-separated)", value="od")
                    model_categories = gr.Textbox(
                        label="Model Categories (comma-separated)", value="bc")
                    target = gr.Textbox(
                        label="Target Column", value="response")
                    weight = gr.Textbox(label="Weight Column", value="weights")

                    gr.Markdown("#### Features")
                    feature_list = gr.Textbox(
                        label="Feature List (comma-separated)",
                        placeholder="feature_1, feature_2, feature_3",
                        lines=3
                    )
                    categorical_features = gr.Textbox(
                        label="Categorical Features (comma-separated)",
                        placeholder="feature_2, feature_3",
                        lines=2
                    )

                with gr.Column():
                    gr.Markdown("#### Model Settings")
                    task_type = gr.Dropdown(
                        label="Task Type",
                        choices=["regression", "binary", "multiclass"],
                        value="regression"
                    )
                    prop_test = gr.Slider(
                        label="Test Proportion", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
                    holdout_ratio = gr.Slider(
                        label="Holdout Ratio", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
                    val_ratio = gr.Slider(
                        label="Validation Ratio", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
                    split_strategy = gr.Dropdown(
                        label="Split Strategy",
                        choices=["random", "stratified", "time", "group"],
                        value="random"
                    )
                    rand_seed = gr.Number(
                        label="Random Seed", value=13, precision=0)
                    epochs = gr.Number(label="Epochs", value=50, precision=0)

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Training Settings")
                    output_dir = gr.Textbox(
                        label="Output Directory", value="./Results")
                    use_gpu = gr.Checkbox(label="Use GPU", value=True)
                    model_keys = gr.Textbox(
                        label="Model Keys (comma-separated)",
                        value="xgb, resn",
                        placeholder="xgb, resn, ft, gnn"
                    )
                    max_evals = gr.Number(
                        label="Max Evaluations", value=50, precision=0)

                with gr.Column():
                    gr.Markdown("#### XGBoost Settings")
                    xgb_max_depth_max = gr.Number(
                        label="XGB Max Depth", value=25, precision=0)
                    xgb_n_estimators_max = gr.Number(
                        label="XGB Max Estimators", value=500, precision=0)

            with gr.Row():
                build_btn = gr.Button(
                    "Build Configuration", variant="primary", size="lg")
                save_config_btn = gr.Button(
                    "Save Configuration", variant="secondary", size="lg")

            with gr.Row():
                build_status = gr.Textbox(label="Status", interactive=False)
                # Shared JSON text box: filled by Load/Build and read by the
                # save, run, open-folder and FT Step 1 handlers below.
                config_json = gr.Textbox(
                    label="Generated Config (JSON)", lines=10, max_lines=20)

            save_filename = gr.Textbox(
                label="Save Filename", value="my_config.json")
            save_status = gr.Textbox(label="Save Status", interactive=False)

        with gr.Tab("Run Task"):
            gr.Markdown(
                """
                ### Run Model Task
                Click the button below to execute the task defined in your configuration.
                Task type is automatically detected from `config.runner.mode`:
                - **entry**: Standard model training
                - **explain**: Model explanation (permutation, SHAP, integrated gradients)
                - **incremental**: Incremental training
                - **watchdog**: Watchdog mode

                Task logs will appear in real-time below.
                """
            )

            with gr.Row():
                run_btn = gr.Button("Run Task", variant="primary", size="lg")
                run_status = gr.Textbox(label="Task Status", interactive=False)

            gr.Markdown("### Task Logs")
            log_output = gr.Textbox(
                label="Logs",
                lines=25,
                max_lines=50,
                interactive=False,
                autoscroll=True
            )

            gr.Markdown("---")
            with gr.Row():
                open_folder_btn = gr.Button("Open Results Folder", size="lg")
                folder_status = gr.Textbox(
                    label="Status", interactive=False, scale=2)

        with gr.Tab("FT Two-Step Workflow"):
            gr.Markdown(
                """
                ### FT-Transformer Two-Step Training

                Automates the FT → XGB/ResN workflow:
                1. **Step 1**: Train FT-Transformer as unsupervised embedding generator
                2. **Step 2**: Merge embeddings with raw data and train XGB/ResN

                **Instructions**:
                1. Load or build a base configuration in the Configuration tab
                2. Prepare Step 1 config (FT embeddings)
                3. Run Step 1 to generate embeddings
                4. Prepare Step 2 configs (XGB/ResN using embeddings)
                5. Run Step 2 with the generated configs
                """
            )

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Step 1: FT Embedding Generation")
                    ft_use_ddp = gr.Checkbox(
                        label="Use DDP for FT", value=True)
                    ft_nproc = gr.Number(
                        label="Number of Processes (DDP)", value=2, precision=0)

                    prepare_step1_btn = gr.Button(
                        "Prepare Step 1 Config", variant="primary")
                    step1_status = gr.Textbox(
                        label="Status", interactive=False)
                    step1_config_display = gr.Textbox(
                        label="Step 1 Config (FT Embedding)",
                        lines=15,
                        max_lines=25
                    )

                with gr.Column():
                    gr.Markdown("### Step 2: Train XGB/ResN with Embeddings")
                    target_models_input = gr.Textbox(
                        label="Target Models (comma-separated)",
                        value="xgb, resn",
                        placeholder="xgb, resn"
                    )

                    prepare_step2_btn = gr.Button(
                        "Prepare Step 2 Configs", variant="primary")
                    step2_status = gr.Textbox(
                        label="Status", interactive=False)

                    with gr.Tab("XGB Config"):
                        xgb_config_display = gr.Textbox(
                            label="XGB Step 2 Config",
                            lines=15,
                            max_lines=25
                        )

                    with gr.Tab("ResN Config"):
                        resn_config_display = gr.Textbox(
                            label="ResN Step 2 Config",
                            lines=15,
                            max_lines=25
                        )

            gr.Markdown("---")
            gr.Markdown(
                """
                ### Quick Actions
                After preparing configs, you can:
                - Copy the Step 1 config and paste it in the **Configuration** tab, then run it in **Run Task** tab
                - After Step 1 completes, click **Prepare Step 2 Configs**
                - Copy the Step 2 configs (XGB or ResN) and run them in **Run Task** tab
                """
            )

        with gr.Tab("Plotting"):
            gr.Markdown(
                """
                ### Plotting Workflows
                Run the plotting steps from the example notebooks.
                """
            )

            with gr.Tab("Pre Oneway"):
                with gr.Row():
                    with gr.Column():
                        pre_data_path = gr.Textbox(
                            label="Data Path", value="./Data/od_bc.csv")
                        pre_model_name = gr.Textbox(
                            label="Model Name", value="od_bc")
                        pre_target = gr.Textbox(
                            label="Target Column", value="response")
                        pre_weight = gr.Textbox(
                            label="Weight Column", value="weights")
                        pre_output_dir = gr.Textbox(
                            label="Output Dir (optional)", value="")
                    with gr.Column():
                        pre_feature_list = gr.Textbox(
                            label="Feature List (comma-separated)",
                            lines=4,
                            placeholder="feature_1, feature_2, feature_3",
                        )
                        pre_categorical = gr.Textbox(
                            label="Categorical Features (comma-separated, optional)",
                            lines=3,
                            placeholder="feature_2, feature_3",
                        )
                        pre_n_bins = gr.Number(
                            label="Bins", value=10, precision=0)
                        pre_holdout = gr.Slider(
                            label="Holdout Ratio",
                            minimum=0.0,
                            maximum=0.5,
                            value=0.25,
                            step=0.05,
                        )
                        pre_seed = gr.Number(
                            label="Random Seed", value=13, precision=0)

                pre_run_btn = gr.Button("Run Pre Oneway", variant="primary")
                pre_status = gr.Textbox(label="Status", interactive=False)
                pre_log = gr.Textbox(label="Logs", lines=15,
                                     max_lines=40, interactive=False)

            with gr.Tab("Direct Plot"):
                direct_cfg_path = gr.Textbox(
                    label="Plot Config", value="config_plot.json")
                direct_xgb_cfg = gr.Textbox(
                    label="XGB Config", value="config_xgb_direct.json")
                direct_resn_cfg = gr.Textbox(
                    label="ResN Config", value="config_resn_direct.json")
                direct_run_btn = gr.Button(
                    "Run Direct Plot", variant="primary")
                direct_status = gr.Textbox(label="Status", interactive=False)
                direct_log = gr.Textbox(
                    label="Logs", lines=15, max_lines=40, interactive=False)

            with gr.Tab("Embed Plot"):
                embed_cfg_path = gr.Textbox(
                    label="Plot Config", value="config_plot.json")
                embed_xgb_cfg = gr.Textbox(
                    label="XGB Embed Config", value="config_xgb_from_ft_unsupervised.json")
                embed_resn_cfg = gr.Textbox(
                    label="ResN Embed Config", value="config_resn_from_ft_unsupervised.json")
                embed_ft_cfg = gr.Textbox(
                    label="FT Embed Config", value="config_ft_unsupervised_ddp_embed.json")
                embed_runtime = gr.Checkbox(
                    label="Use Runtime FT Embedding", value=False)
                embed_run_btn = gr.Button("Run Embed Plot", variant="primary")
                embed_status = gr.Textbox(label="Status", interactive=False)
                embed_log = gr.Textbox(
                    label="Logs", lines=15, max_lines=40, interactive=False)

        with gr.Tab("Prediction"):
            gr.Markdown("### FT Embed Prediction")
            pred_ft_cfg = gr.Textbox(
                label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
            pred_xgb_cfg = gr.Textbox(
                label="XGB Config (optional)", value="config_xgb_from_ft_unsupervised.json")
            pred_resn_cfg = gr.Textbox(
                label="ResN Config (optional)", value="config_resn_from_ft_unsupervised.json")
            pred_input = gr.Textbox(
                label="Input Data", value="./Data/od_bc_new.csv")
            pred_output = gr.Textbox(
                label="Output CSV", value="./Results/predictions_ft_xgb.csv")
            pred_model_name = gr.Textbox(
                label="Model Name (optional)", value="")
            pred_model_keys = gr.Textbox(label="Model Keys", value="xgb, resn")
            pred_run_btn = gr.Button("Run Prediction", variant="primary")
            pred_status = gr.Textbox(label="Status", interactive=False)
            pred_log = gr.Textbox(label="Logs", lines=15,
                                  max_lines=40, interactive=False)

        with gr.Tab("Compare"):
            gr.Markdown("### Compare Direct vs FT-Embed Models")

            with gr.Tab("Compare XGB"):
                cmp_xgb_direct_cfg = gr.Textbox(
                    label="Direct XGB Config", value="config_xgb_direct.json")
                cmp_xgb_ft_cfg = gr.Textbox(
                    label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
                cmp_xgb_embed_cfg = gr.Textbox(
                    label="FT-Embed XGB Config", value="config_xgb_from_ft_unsupervised.json")
                cmp_xgb_label_direct = gr.Textbox(
                    label="Direct Label", value="XGB_raw")
                cmp_xgb_label_ft = gr.Textbox(
                    label="FT Label", value="XGB_ft_embed")
                cmp_xgb_runtime = gr.Checkbox(
                    label="Use Runtime FT Embedding", value=False)
                cmp_xgb_bins = gr.Number(
                    label="Bins Override", value=10, precision=0)
                cmp_xgb_run_btn = gr.Button(
                    "Run XGB Compare", variant="primary")
                cmp_xgb_status = gr.Textbox(label="Status", interactive=False)
                cmp_xgb_log = gr.Textbox(
                    label="Logs", lines=15, max_lines=40, interactive=False)

            with gr.Tab("Compare ResNet"):
                cmp_resn_direct_cfg = gr.Textbox(
                    label="Direct ResN Config", value="config_resn_direct.json")
                cmp_resn_ft_cfg = gr.Textbox(
                    label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
                cmp_resn_embed_cfg = gr.Textbox(
                    label="FT-Embed ResN Config", value="config_resn_from_ft_unsupervised.json")
                cmp_resn_label_direct = gr.Textbox(
                    label="Direct Label", value="ResN_raw")
                cmp_resn_label_ft = gr.Textbox(
                    label="FT Label", value="ResN_ft_embed")
                cmp_resn_runtime = gr.Checkbox(
                    label="Use Runtime FT Embedding", value=False)
                cmp_resn_bins = gr.Number(
                    label="Bins Override", value=10, precision=0)
                cmp_resn_run_btn = gr.Button(
                    "Run ResNet Compare", variant="primary")
                cmp_resn_status = gr.Textbox(label="Status", interactive=False)
                cmp_resn_log = gr.Textbox(
                    label="Logs", lines=15, max_lines=40, interactive=False)

        # Event handlers
        load_btn.click(
            fn=app.load_json_config,
            inputs=[json_file],
            outputs=[load_status, config_display, config_json]
        )

        build_btn.click(
            fn=app.build_config_from_ui,
            inputs=[
                data_dir, model_list, model_categories, target, weight,
                feature_list, categorical_features, task_type, prop_test,
                holdout_ratio, val_ratio, split_strategy, rand_seed, epochs,
                output_dir, use_gpu, model_keys, max_evals,
                xgb_max_depth_max, xgb_n_estimators_max
            ],
            outputs=[build_status, config_json]
        )

        save_config_btn.click(
            fn=app.save_config,
            inputs=[config_json, save_filename],
            outputs=[save_status]
        )

        run_btn.click(
            fn=app.run_training,
            inputs=[config_json],
            outputs=[run_status, log_output]
        )

        open_folder_btn.click(
            fn=app.open_results_folder,
            inputs=[config_json],
            outputs=[folder_status]
        )

        prepare_step1_btn.click(
            fn=app.prepare_ft_step1,
            inputs=[config_json, ft_use_ddp, ft_nproc],
            outputs=[step1_status, step1_config_display]
        )

        def _prepare_step2(target_models):
            # Look up the Step 1 path when the button is clicked. A gr.State
            # built from a callable is evaluated when the app loads, so it
            # would never see the path stored later by prepare_ft_step1.
            step1_path = app.current_step1_config or "temp_ft_step1_config.json"
            return app.prepare_ft_step2(step1_path, target_models)

        prepare_step2_btn.click(
            fn=_prepare_step2,
            inputs=[target_models_input],
            outputs=[step2_status, xgb_config_display, resn_config_display]
        )

        pre_run_btn.click(
            fn=app.run_pre_oneway_ui,
            inputs=[
                pre_data_path, pre_model_name, pre_target, pre_weight,
                pre_feature_list, pre_categorical, pre_n_bins,
                pre_holdout, pre_seed, pre_output_dir
            ],
            outputs=[pre_status, pre_log]
        )

        direct_run_btn.click(
            fn=app.run_plot_direct_ui,
            inputs=[direct_cfg_path, direct_xgb_cfg, direct_resn_cfg],
            outputs=[direct_status, direct_log]
        )

        embed_run_btn.click(
            fn=app.run_plot_embed_ui,
            inputs=[embed_cfg_path, embed_xgb_cfg,
                    embed_resn_cfg, embed_ft_cfg, embed_runtime],
            outputs=[embed_status, embed_log]
        )

        pred_run_btn.click(
            fn=app.run_predict_ui,
            inputs=[
                pred_ft_cfg, pred_xgb_cfg, pred_resn_cfg, pred_input,
                pred_output, pred_model_name, pred_model_keys
            ],
            outputs=[pred_status, pred_log]
        )

        cmp_xgb_run_btn.click(
            fn=app.run_compare_xgb_ui,
            inputs=[
                cmp_xgb_direct_cfg, cmp_xgb_ft_cfg, cmp_xgb_embed_cfg,
                cmp_xgb_label_direct, cmp_xgb_label_ft,
                cmp_xgb_runtime, cmp_xgb_bins
            ],
            outputs=[cmp_xgb_status, cmp_xgb_log]
        )

        cmp_resn_run_btn.click(
            fn=app.run_compare_resn_ui,
            inputs=[
                cmp_resn_direct_cfg, cmp_resn_ft_cfg, cmp_resn_embed_cfg,
                cmp_resn_label_direct, cmp_resn_label_ft,
                cmp_resn_runtime, cmp_resn_bins
            ],
            outputs=[cmp_resn_status, cmp_resn_log]
        )

    return demo
929
+
930
+
931
if __name__ == "__main__":
    # Script entry point: build the interface and serve it on port 7860,
    # listening on all network interfaces.
    demo = create_ui()
    launch_kwargs = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )
    # Pass `analytics_enabled` only when this launch() signature accepts it.
    if "analytics_enabled" in inspect.signature(demo.launch).parameters:
        launch_kwargs.update(analytics_enabled=False)
    demo.launch(**launch_kwargs)