ins-pricing 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,9 @@
1
- """
2
- Insurance Pricing Model Training Frontend
3
- A Gradio-based web interface for configuring and running insurance pricing models.
4
- """
5
-
1
+ """
2
+ Insurance Pricing Model Training Frontend
3
+ A Gradio-based web interface for configuring and running insurance pricing models.
4
+ """
5
+
6
+ import os
6
7
  import platform
7
8
  import subprocess
8
9
  from ins_pricing.frontend.example_workflows import (
@@ -12,13 +13,12 @@ from ins_pricing.frontend.example_workflows import (
12
13
  run_predict_ft_embed,
13
14
  run_pre_oneway,
14
15
  )
15
- from ins_pricing.frontend.ft_workflow import FTWorkflowHelper
16
- from ins_pricing.frontend.runner import TaskRunner
16
+ from ins_pricing.frontend.ft_workflow import FTWorkflowHelper
17
+ from ins_pricing.frontend.runner import TaskRunner
17
18
  from ins_pricing.frontend.config_builder import ConfigBuilder
18
- import gradio as gr
19
19
  import json
20
20
  import sys
21
- import os
21
+ import inspect
22
22
  from pathlib import Path
23
23
  from typing import Optional, Dict, Any, Callable, Iterable, Tuple
24
24
  import threading
@@ -28,876 +28,914 @@ import time
28
28
  # Add parent directory to path to import ins_pricing modules
29
29
  sys.path.insert(0, str(Path(__file__).parent.parent.parent))
30
30
 
31
+ os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
32
+ os.environ.setdefault("GRADIO_TELEMETRY_ENABLED", "False")
33
+ os.environ.setdefault("GRADIO_CHECK_VERSION", "False")
34
+ os.environ.setdefault("GRADIO_VERSION_CHECK", "False")
31
35
 
32
- class PricingApp:
33
- """Main application class for the insurance pricing model tasks interface."""
34
-
35
- def __init__(self):
36
- self.config_builder = ConfigBuilder()
37
- self.runner = TaskRunner()
38
- self.ft_workflow = FTWorkflowHelper()
39
- self.current_config = {}
40
- self.current_step1_config = None
41
- self.current_config_path: Optional[Path] = None
42
- self.current_config_dir: Optional[Path] = None
43
-
44
- def load_json_config(self, file_path) -> tuple[str, Dict[str, Any], str]:
45
- """Load configuration from uploaded JSON file."""
46
- if not file_path:
47
- return "No file uploaded", {}, ""
48
-
49
- try:
50
- path = Path(file_path).resolve()
51
- with open(path, 'r', encoding='utf-8') as f:
52
- config = json.load(f)
53
- self.current_config = config
54
- self.current_config_path = path
55
- self.current_config_dir = path.parent
56
- config_json = json.dumps(config, indent=2, ensure_ascii=False)
57
- return f"Configuration loaded successfully from {path.name}", config, config_json
58
- except Exception as e:
59
- return f"Error loading config: {str(e)}", {}, ""
60
-
61
- def build_config_from_ui(
62
- self,
63
- data_dir: str,
64
- model_list: str,
65
- model_categories: str,
66
- target: str,
67
- weight: str,
68
- feature_list: str,
69
- categorical_features: str,
70
- task_type: str,
71
- prop_test: float,
72
- holdout_ratio: float,
73
- val_ratio: float,
74
- split_strategy: str,
75
- rand_seed: int,
76
- epochs: int,
77
- output_dir: str,
78
- use_gpu: bool,
79
- model_keys: str,
80
- max_evals: int,
81
- xgb_max_depth_max: int,
82
- xgb_n_estimators_max: int,
83
- ) -> tuple[str, str]:
84
- """Build configuration from UI parameters."""
85
- try:
86
- # Parse comma-separated lists
87
- model_list = [x.strip()
88
- for x in model_list.split(',') if x.strip()]
89
- model_categories = [x.strip()
90
- for x in model_categories.split(',') if x.strip()]
91
- feature_list = [x.strip()
92
- for x in feature_list.split(',') if x.strip()]
93
- categorical_features = [
94
- x.strip() for x in categorical_features.split(',') if x.strip()]
95
- model_keys = [x.strip()
96
- for x in model_keys.split(',') if x.strip()]
97
-
98
- config = self.config_builder.build_config(
99
- data_dir=data_dir,
100
- model_list=model_list,
101
- model_categories=model_categories,
102
- target=target,
103
- weight=weight,
104
- feature_list=feature_list,
105
- categorical_features=categorical_features,
106
- task_type=task_type,
107
- prop_test=prop_test,
108
- holdout_ratio=holdout_ratio,
109
- val_ratio=val_ratio,
110
- split_strategy=split_strategy,
111
- rand_seed=rand_seed,
112
- epochs=epochs,
113
- output_dir=output_dir,
114
- use_gpu=use_gpu,
115
- model_keys=model_keys,
116
- max_evals=max_evals,
117
- xgb_max_depth_max=xgb_max_depth_max,
118
- xgb_n_estimators_max=xgb_n_estimators_max,
119
- )
120
-
121
- is_valid, msg = self.config_builder.validate_config(config)
122
- if not is_valid:
123
- return f"Validation failed: {msg}", ""
124
-
125
- self.current_config = config
126
- self.current_config_path = None
127
- self.current_config_dir = None
128
- config_json = json.dumps(config, indent=2, ensure_ascii=False)
129
- return "Configuration built successfully", config_json
130
-
131
- except Exception as e:
132
- return f"Error building config: {str(e)}", ""
133
-
134
- def save_config(self, config_json: str, filename: str) -> str:
135
- """Save current configuration to file."""
136
- if not config_json:
137
- return "No configuration to save"
138
-
139
- try:
140
- config_path = Path(filename)
141
- with open(config_path, 'w', encoding='utf-8') as f:
142
- json.dump(json.loads(config_json), f,
143
- indent=2, ensure_ascii=False)
144
- return f"Configuration saved to {config_path}"
145
- except Exception as e:
146
- return f"Error saving config: {str(e)}"
147
-
148
- def run_training(self, config_json: str) -> tuple[str, str]:
149
- """
150
- Run task (training, explain, plotting, etc.) with the current configuration.
151
-
152
- The task type is automatically detected from config.runner.mode.
153
- Supported modes: entry (training), explain, incremental, watchdog, etc.
154
- """
155
- try:
156
- temp_config_path = None
157
- if config_json:
158
- config = json.loads(config_json)
159
- task_mode = config.get('runner', {}).get('mode', 'entry')
160
- base_dir = self.current_config_dir or Path.cwd()
161
- temp_config_path = (base_dir / "temp_config.json").resolve()
162
- with open(temp_config_path, 'w', encoding='utf-8') as f:
163
- json.dump(config, f, indent=2)
164
- config_path = temp_config_path
165
- elif self.current_config_path and self.current_config_path.exists():
166
- config_path = self.current_config_path
167
- config = json.loads(config_path.read_text(encoding="utf-8"))
168
- task_mode = config.get('runner', {}).get('mode', 'entry')
169
- elif self.current_config:
170
- config = self.current_config
171
- task_mode = config.get('runner', {}).get('mode', 'entry')
172
- temp_config_path = (Path.cwd() / "temp_config.json").resolve()
173
- with open(temp_config_path, 'w', encoding='utf-8') as f:
174
- json.dump(config, f, indent=2)
175
- config_path = temp_config_path
176
- else:
177
- return "No configuration provided", ""
178
-
179
- log_generator = self.runner.run_task(str(config_path))
180
-
181
- # Collect logs
182
- full_log = ""
183
- for log_line in log_generator:
184
- full_log += log_line + "\n"
185
- yield f"Task [{task_mode}] in progress...", full_log
186
-
187
- # Clean up
188
- if temp_config_path and temp_config_path.exists():
189
- temp_config_path.unlink()
190
-
191
- yield f"Task [{task_mode}] completed!", full_log
192
-
193
- except Exception as e:
194
- error_msg = f"Error during task execution: {str(e)}"
195
- yield error_msg, error_msg
196
-
197
- def prepare_ft_step1(self, config_json: str, use_ddp: bool, nproc: int) -> tuple[str, str]:
198
- """Prepare FT Step 1 configuration."""
199
- if not config_json:
200
- return "No configuration provided", ""
201
-
202
- try:
203
- config = json.loads(config_json)
204
- step1_config = self.ft_workflow.prepare_step1_config(
205
- base_config=config,
206
- use_ddp=use_ddp,
207
- nproc_per_node=int(nproc)
208
- )
209
-
210
- # Save to temp file
211
- temp_path = Path("temp_ft_step1_config.json")
212
- with open(temp_path, 'w', encoding='utf-8') as f:
213
- json.dump(step1_config, f, indent=2)
214
-
215
- self.current_step1_config = str(temp_path)
216
- step1_json = json.dumps(step1_config, indent=2, ensure_ascii=False)
217
-
218
- return "Step 1 config prepared. Click 'Run Step 1' to train FT embeddings.", step1_json
219
-
220
- except Exception as e:
221
- return f"Error preparing Step 1 config: {str(e)}", ""
222
-
223
- def prepare_ft_step2(self, step1_config_path: str, target_models: str) -> tuple[str, str, str]:
224
- """Prepare FT Step 2 configurations."""
225
- if not step1_config_path or not os.path.exists(step1_config_path):
226
- return "Step 1 config not found. Run Step 1 first.", "", ""
227
-
228
- try:
229
- models = [m.strip() for m in target_models.split(',') if m.strip()]
230
- xgb_cfg, resn_cfg = self.ft_workflow.generate_step2_configs(
231
- step1_config_path=step1_config_path,
232
- target_models=models
233
- )
234
-
235
- status_msg = f"Step 2 configs prepared for: {', '.join(models)}"
236
- xgb_json = json.dumps(
237
- xgb_cfg, indent=2, ensure_ascii=False) if xgb_cfg else ""
238
- resn_json = json.dumps(
239
- resn_cfg, indent=2, ensure_ascii=False) if resn_cfg else ""
240
-
241
- return status_msg, xgb_json, resn_json
242
-
243
- except FileNotFoundError as e:
244
- return f"Error: {str(e)}\n\nMake sure Step 1 completed successfully.", "", ""
245
- except Exception as e:
246
- return f"Error preparing Step 2 configs: {str(e)}", "", ""
247
-
248
- def open_results_folder(self, config_json: str) -> str:
249
- """Open the results folder in file explorer."""
250
- try:
251
- if config_json:
252
- config = json.loads(config_json)
253
- output_dir = config.get('output_dir', './Results')
254
- results_path = Path(output_dir).resolve()
255
- elif self.current_config_path and self.current_config_path.exists():
256
- config = json.loads(
257
- self.current_config_path.read_text(encoding="utf-8"))
258
- output_dir = config.get('output_dir', './Results')
259
- results_path = (
260
- self.current_config_path.parent / output_dir).resolve()
261
- elif self.current_config:
262
- output_dir = self.current_config.get('output_dir', './Results')
263
- results_path = Path(output_dir).resolve()
264
- else:
265
- return "No configuration loaded"
266
-
267
- if not results_path.exists():
268
- return f"Results folder does not exist yet: {results_path}"
269
-
270
- # Open folder based on OS
271
- system = platform.system()
272
- if system == "Windows":
273
- os.startfile(results_path)
274
- elif system == "Darwin": # macOS
275
- subprocess.run(["open", str(results_path)])
276
- else: # Linux
277
- subprocess.run(["xdg-open", str(results_path)])
278
-
279
- return f"Opened folder: {results_path}"
280
-
281
- except Exception as e:
282
- return f"Error opening folder: {str(e)}"
283
-
284
- def _run_workflow(self, label: str, func: Callable, *args, **kwargs):
285
- """Run a workflow function and stream logs."""
286
- try:
287
- log_generator = self.runner.run_callable(func, *args, **kwargs)
288
- full_log = ""
289
- for log_line in log_generator:
290
- full_log += log_line + "\n"
291
- yield f"{label} in progress...", full_log
292
- yield f"{label} completed!", full_log
293
- except Exception as e:
294
- error_msg = f"{label} error: {str(e)}"
295
- yield error_msg, error_msg
296
-
297
- def run_pre_oneway_ui(
298
- self,
299
- data_path: str,
300
- model_name: str,
301
- target_col: str,
302
- weight_col: str,
303
- feature_list: str,
304
- categorical_features: str,
305
- n_bins: int,
306
- holdout_ratio: float,
307
- rand_seed: int,
308
- output_dir: str,
309
- ):
310
- yield from self._run_workflow(
311
- "Pre-Oneway Plot",
312
- run_pre_oneway,
313
- data_path=data_path,
314
- model_name=model_name,
315
- target_col=target_col,
316
- weight_col=weight_col,
317
- feature_list=feature_list,
318
- categorical_features=categorical_features,
319
- n_bins=n_bins,
320
- holdout_ratio=holdout_ratio,
321
- rand_seed=rand_seed,
322
- output_dir=output_dir or None,
323
- )
324
-
325
- def run_plot_direct_ui(self, cfg_path: str, xgb_cfg_path: str, resn_cfg_path: str):
326
- yield from self._run_workflow(
327
- "Direct Plot",
328
- run_plot_direct,
329
- cfg_path=cfg_path,
330
- xgb_cfg_path=xgb_cfg_path,
331
- resn_cfg_path=resn_cfg_path,
332
- )
333
-
334
- def run_plot_embed_ui(
335
- self,
336
- cfg_path: str,
337
- xgb_cfg_path: str,
338
- resn_cfg_path: str,
339
- ft_cfg_path: str,
340
- use_runtime_ft_embedding: bool,
341
- ):
342
- yield from self._run_workflow(
343
- "Embed Plot",
344
- run_plot_embed,
345
- cfg_path=cfg_path,
346
- xgb_cfg_path=xgb_cfg_path,
347
- resn_cfg_path=resn_cfg_path,
348
- ft_cfg_path=ft_cfg_path,
349
- use_runtime_ft_embedding=use_runtime_ft_embedding,
350
- )
351
-
352
- def run_predict_ui(
353
- self,
354
- ft_cfg_path: str,
355
- xgb_cfg_path: str,
356
- resn_cfg_path: str,
357
- input_path: str,
358
- output_path: str,
359
- model_name: str,
360
- model_keys: str,
361
- ):
362
- yield from self._run_workflow(
363
- "Prediction",
364
- run_predict_ft_embed,
365
- ft_cfg_path=ft_cfg_path,
366
- xgb_cfg_path=xgb_cfg_path or None,
367
- resn_cfg_path=resn_cfg_path or None,
368
- input_path=input_path,
369
- output_path=output_path,
370
- model_name=model_name or None,
371
- model_keys=model_keys,
372
- )
373
-
374
- def run_compare_xgb_ui(
375
- self,
376
- direct_cfg_path: str,
377
- ft_cfg_path: str,
378
- ft_embed_cfg_path: str,
379
- label_direct: str,
380
- label_ft: str,
381
- use_runtime_ft_embedding: bool,
382
- n_bins_override: int,
383
- ):
384
- yield from self._run_workflow(
385
- "Compare XGB",
386
- run_compare_ft_embed,
387
- direct_cfg_path=direct_cfg_path,
388
- ft_cfg_path=ft_cfg_path,
389
- ft_embed_cfg_path=ft_embed_cfg_path,
390
- model_key="xgb",
391
- label_direct=label_direct,
392
- label_ft=label_ft,
393
- use_runtime_ft_embedding=use_runtime_ft_embedding,
394
- n_bins_override=n_bins_override,
395
- )
396
-
397
- def run_compare_resn_ui(
398
- self,
399
- direct_cfg_path: str,
400
- ft_cfg_path: str,
401
- ft_embed_cfg_path: str,
402
- label_direct: str,
403
- label_ft: str,
404
- use_runtime_ft_embedding: bool,
405
- n_bins_override: int,
406
- ):
407
- yield from self._run_workflow(
408
- "Compare ResNet",
409
- run_compare_ft_embed,
410
- direct_cfg_path=direct_cfg_path,
411
- ft_cfg_path=ft_cfg_path,
412
- ft_embed_cfg_path=ft_embed_cfg_path,
413
- model_key="resn",
414
- label_direct=label_direct,
415
- label_ft=label_ft,
416
- use_runtime_ft_embedding=use_runtime_ft_embedding,
417
- n_bins_override=n_bins_override,
418
- )
419
-
420
-
421
- def create_ui():
422
- """Create the Gradio interface."""
423
- app = PricingApp()
424
-
425
- with gr.Blocks(title="Insurance Pricing Model Training", theme=gr.themes.Soft()) as demo:
426
- gr.Markdown(
427
- """
428
- # Insurance Pricing Model Training Interface
429
- Configure and train insurance pricing models with an easy-to-use interface.
430
-
431
- **Two ways to configure:**
432
- 1. **Upload JSON Config**: Upload an existing configuration file
433
- 2. **Manual Configuration**: Fill in the parameters below
434
- """
435
- )
436
-
437
- with gr.Tab("Configuration"):
438
- with gr.Row():
439
- with gr.Column(scale=1):
440
- gr.Markdown("### Load Configuration")
441
- json_file = gr.File(
442
- label="Upload JSON Config File",
443
- file_types=[".json"],
444
- type="filepath"
445
- )
446
- load_btn = gr.Button("Load Config", variant="primary")
447
- load_status = gr.Textbox(
448
- label="Load Status", interactive=False)
449
-
450
- with gr.Column(scale=2):
451
- gr.Markdown("### Current Configuration")
452
- config_display = gr.JSON(label="Configuration", value={})
453
-
454
- gr.Markdown("---")
455
- gr.Markdown("### Manual Configuration")
456
-
457
- with gr.Row():
458
- with gr.Column():
459
- gr.Markdown("#### Data Settings")
460
- data_dir = gr.Textbox(
461
- label="Data Directory", value="./Data")
462
- model_list = gr.Textbox(
463
- label="Model List (comma-separated)", value="od")
464
- model_categories = gr.Textbox(
465
- label="Model Categories (comma-separated)", value="bc")
466
- target = gr.Textbox(
467
- label="Target Column", value="response")
468
- weight = gr.Textbox(label="Weight Column", value="weights")
469
-
470
- gr.Markdown("#### Features")
471
- feature_list = gr.Textbox(
472
- label="Feature List (comma-separated)",
473
- placeholder="feature_1, feature_2, feature_3",
474
- lines=3
475
- )
476
- categorical_features = gr.Textbox(
477
- label="Categorical Features (comma-separated)",
478
- placeholder="feature_2, feature_3",
479
- lines=2
480
- )
481
-
482
- with gr.Column():
483
- gr.Markdown("#### Model Settings")
484
- task_type = gr.Dropdown(
485
- label="Task Type",
486
- choices=["regression", "binary", "multiclass"],
487
- value="regression"
488
- )
489
- prop_test = gr.Slider(
490
- label="Test Proportion", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
491
- holdout_ratio = gr.Slider(
492
- label="Holdout Ratio", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
493
- val_ratio = gr.Slider(
494
- label="Validation Ratio", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
495
- split_strategy = gr.Dropdown(
496
- label="Split Strategy",
497
- choices=["random", "stratified", "time", "group"],
498
- value="random"
499
- )
500
- rand_seed = gr.Number(
501
- label="Random Seed", value=13, precision=0)
502
- epochs = gr.Number(label="Epochs", value=50, precision=0)
503
-
504
- with gr.Row():
505
- with gr.Column():
506
- gr.Markdown("#### Training Settings")
507
- output_dir = gr.Textbox(
508
- label="Output Directory", value="./Results")
509
- use_gpu = gr.Checkbox(label="Use GPU", value=True)
510
- model_keys = gr.Textbox(
511
- label="Model Keys (comma-separated)",
512
- value="xgb, resn",
513
- placeholder="xgb, resn, ft, gnn"
514
- )
515
- max_evals = gr.Number(
516
- label="Max Evaluations", value=50, precision=0)
517
-
518
- with gr.Column():
519
- gr.Markdown("#### XGBoost Settings")
520
- xgb_max_depth_max = gr.Number(
521
- label="XGB Max Depth", value=25, precision=0)
522
- xgb_n_estimators_max = gr.Number(
523
- label="XGB Max Estimators", value=500, precision=0)
524
-
525
- with gr.Row():
526
- build_btn = gr.Button(
527
- "Build Configuration", variant="primary", size="lg")
528
- save_config_btn = gr.Button(
529
- "Save Configuration", variant="secondary", size="lg")
530
-
531
- with gr.Row():
532
- build_status = gr.Textbox(label="Status", interactive=False)
533
- config_json = gr.Textbox(
534
- label="Generated Config (JSON)", lines=10, max_lines=20)
535
-
536
- save_filename = gr.Textbox(
537
- label="Save Filename", value="my_config.json")
538
- save_status = gr.Textbox(label="Save Status", interactive=False)
539
-
540
- with gr.Tab("Run Task"):
541
- gr.Markdown(
542
- """
543
- ### Run Model Task
544
- Click the button below to execute the task defined in your configuration.
545
- Task type is automatically detected from `config.runner.mode`:
546
- - **entry**: Standard model training
547
- - **explain**: Model explanation (permutation, SHAP, integrated gradients)
548
- - **incremental**: Incremental training
549
- - **watchdog**: Watchdog mode
550
-
551
- Task logs will appear in real-time below.
552
- """
553
- )
554
-
555
- with gr.Row():
556
- run_btn = gr.Button("Run Task", variant="primary", size="lg")
557
- run_status = gr.Textbox(label="Task Status", interactive=False)
558
-
559
- gr.Markdown("### Task Logs")
560
- log_output = gr.Textbox(
561
- label="Logs",
562
- lines=25,
563
- max_lines=50,
564
- interactive=False,
565
- autoscroll=True
566
- )
567
-
568
- gr.Markdown("---")
569
- with gr.Row():
570
- open_folder_btn = gr.Button("Open Results Folder", size="lg")
571
- folder_status = gr.Textbox(
572
- label="Status", interactive=False, scale=2)
573
-
574
- with gr.Tab("FT Two-Step Workflow"):
575
- gr.Markdown(
576
- """
577
- ### FT-Transformer Two-Step Training
578
-
579
- Automates the FT → XGB/ResN workflow:
580
- 1. **Step 1**: Train FT-Transformer as unsupervised embedding generator
581
- 2. **Step 2**: Merge embeddings with raw data and train XGB/ResN
582
-
583
- **Instructions**:
584
- 1. Load or build a base configuration in the Configuration tab
585
- 2. Prepare Step 1 config (FT embeddings)
586
- 3. Run Step 1 to generate embeddings
587
- 4. Prepare Step 2 configs (XGB/ResN using embeddings)
588
- 5. Run Step 2 with the generated configs
589
- """
590
- )
591
-
592
- with gr.Row():
593
- with gr.Column():
594
- gr.Markdown("### Step 1: FT Embedding Generation")
595
- ft_use_ddp = gr.Checkbox(
596
- label="Use DDP for FT", value=True)
597
- ft_nproc = gr.Number(
598
- label="Number of Processes (DDP)", value=2, precision=0)
599
-
600
- prepare_step1_btn = gr.Button(
601
- "Prepare Step 1 Config", variant="primary")
602
- step1_status = gr.Textbox(
603
- label="Status", interactive=False)
604
- step1_config_display = gr.Textbox(
605
- label="Step 1 Config (FT Embedding)",
606
- lines=15,
607
- max_lines=25
608
- )
609
-
610
- with gr.Column():
611
- gr.Markdown("### Step 2: Train XGB/ResN with Embeddings")
612
- target_models_input = gr.Textbox(
613
- label="Target Models (comma-separated)",
614
- value="xgb, resn",
615
- placeholder="xgb, resn"
616
- )
617
-
618
- prepare_step2_btn = gr.Button(
619
- "Prepare Step 2 Configs", variant="primary")
620
- step2_status = gr.Textbox(
621
- label="Status", interactive=False)
622
-
623
- with gr.Tab("XGB Config"):
624
- xgb_config_display = gr.Textbox(
625
- label="XGB Step 2 Config",
626
- lines=15,
627
- max_lines=25
628
- )
629
-
630
- with gr.Tab("ResN Config"):
631
- resn_config_display = gr.Textbox(
632
- label="ResN Step 2 Config",
633
- lines=15,
634
- max_lines=25
635
- )
636
-
637
- gr.Markdown("---")
638
- gr.Markdown(
639
- """
640
- ### Quick Actions
641
- After preparing configs, you can:
642
- - Copy the Step 1 config and paste it in the **Configuration** tab, then run it in **Run Task** tab
643
- - After Step 1 completes, click **Prepare Step 2 Configs**
644
- - Copy the Step 2 configs (XGB or ResN) and run them in **Run Task** tab
645
- """
646
- )
647
-
648
- with gr.Tab("Plotting"):
649
- gr.Markdown(
650
- """
651
- ### Plotting Workflows
652
- Run the plotting steps from the example notebooks.
653
- """
654
- )
655
-
656
- with gr.Tab("Pre Oneway"):
657
- with gr.Row():
658
- with gr.Column():
659
- pre_data_path = gr.Textbox(
660
- label="Data Path", value="./Data/od_bc.csv")
661
- pre_model_name = gr.Textbox(
662
- label="Model Name", value="od_bc")
663
- pre_target = gr.Textbox(
664
- label="Target Column", value="response")
665
- pre_weight = gr.Textbox(
666
- label="Weight Column", value="weights")
667
- pre_output_dir = gr.Textbox(
668
- label="Output Dir (optional)", value="")
669
- with gr.Column():
670
- pre_feature_list = gr.Textbox(
671
- label="Feature List (comma-separated)",
672
- lines=4,
673
- placeholder="feature_1, feature_2, feature_3",
674
- )
675
- pre_categorical = gr.Textbox(
676
- label="Categorical Features (comma-separated, optional)",
677
- lines=3,
678
- placeholder="feature_2, feature_3",
679
- )
680
- pre_n_bins = gr.Number(
681
- label="Bins", value=10, precision=0)
682
- pre_holdout = gr.Slider(
683
- label="Holdout Ratio",
684
- minimum=0.0,
685
- maximum=0.5,
686
- value=0.25,
687
- step=0.05,
688
- )
689
- pre_seed = gr.Number(
690
- label="Random Seed", value=13, precision=0)
691
-
692
- pre_run_btn = gr.Button("Run Pre Oneway", variant="primary")
693
- pre_status = gr.Textbox(label="Status", interactive=False)
694
- pre_log = gr.Textbox(label="Logs", lines=15,
695
- max_lines=40, interactive=False)
696
-
697
- with gr.Tab("Direct Plot"):
698
- direct_cfg_path = gr.Textbox(
699
- label="Plot Config", value="config_plot.json")
700
- direct_xgb_cfg = gr.Textbox(
701
- label="XGB Config", value="config_xgb_direct.json")
702
- direct_resn_cfg = gr.Textbox(
703
- label="ResN Config", value="config_resn_direct.json")
704
- direct_run_btn = gr.Button(
705
- "Run Direct Plot", variant="primary")
706
- direct_status = gr.Textbox(label="Status", interactive=False)
707
- direct_log = gr.Textbox(
708
- label="Logs", lines=15, max_lines=40, interactive=False)
709
-
710
- with gr.Tab("Embed Plot"):
711
- embed_cfg_path = gr.Textbox(
712
- label="Plot Config", value="config_plot.json")
713
- embed_xgb_cfg = gr.Textbox(
714
- label="XGB Embed Config", value="config_xgb_from_ft_unsupervised.json")
715
- embed_resn_cfg = gr.Textbox(
716
- label="ResN Embed Config", value="config_resn_from_ft_unsupervised.json")
717
- embed_ft_cfg = gr.Textbox(
718
- label="FT Embed Config", value="config_ft_unsupervised_ddp_embed.json")
719
- embed_runtime = gr.Checkbox(
720
- label="Use Runtime FT Embedding", value=False)
721
- embed_run_btn = gr.Button("Run Embed Plot", variant="primary")
722
- embed_status = gr.Textbox(label="Status", interactive=False)
723
- embed_log = gr.Textbox(
724
- label="Logs", lines=15, max_lines=40, interactive=False)
725
-
726
- with gr.Tab("Prediction"):
727
- gr.Markdown("### FT Embed Prediction")
728
- pred_ft_cfg = gr.Textbox(
729
- label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
730
- pred_xgb_cfg = gr.Textbox(
731
- label="XGB Config (optional)", value="config_xgb_from_ft_unsupervised.json")
732
- pred_resn_cfg = gr.Textbox(
733
- label="ResN Config (optional)", value="config_resn_from_ft_unsupervised.json")
734
- pred_input = gr.Textbox(
735
- label="Input Data", value="./Data/od_bc_new.csv")
736
- pred_output = gr.Textbox(
737
- label="Output CSV", value="./Results/predictions_ft_xgb.csv")
738
- pred_model_name = gr.Textbox(
739
- label="Model Name (optional)", value="")
740
- pred_model_keys = gr.Textbox(label="Model Keys", value="xgb, resn")
741
- pred_run_btn = gr.Button("Run Prediction", variant="primary")
742
- pred_status = gr.Textbox(label="Status", interactive=False)
743
- pred_log = gr.Textbox(label="Logs", lines=15,
744
- max_lines=40, interactive=False)
745
-
746
- with gr.Tab("Compare"):
747
- gr.Markdown("### Compare Direct vs FT-Embed Models")
748
-
749
- with gr.Tab("Compare XGB"):
750
- cmp_xgb_direct_cfg = gr.Textbox(
751
- label="Direct XGB Config", value="config_xgb_direct.json")
752
- cmp_xgb_ft_cfg = gr.Textbox(
753
- label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
754
- cmp_xgb_embed_cfg = gr.Textbox(
755
- label="FT-Embed XGB Config", value="config_xgb_from_ft_unsupervised.json")
756
- cmp_xgb_label_direct = gr.Textbox(
757
- label="Direct Label", value="XGB_raw")
758
- cmp_xgb_label_ft = gr.Textbox(
759
- label="FT Label", value="XGB_ft_embed")
760
- cmp_xgb_runtime = gr.Checkbox(
761
- label="Use Runtime FT Embedding", value=False)
762
- cmp_xgb_bins = gr.Number(
763
- label="Bins Override", value=10, precision=0)
764
- cmp_xgb_run_btn = gr.Button(
765
- "Run XGB Compare", variant="primary")
766
- cmp_xgb_status = gr.Textbox(label="Status", interactive=False)
767
- cmp_xgb_log = gr.Textbox(
768
- label="Logs", lines=15, max_lines=40, interactive=False)
769
-
770
- with gr.Tab("Compare ResNet"):
771
- cmp_resn_direct_cfg = gr.Textbox(
772
- label="Direct ResN Config", value="config_resn_direct.json")
773
- cmp_resn_ft_cfg = gr.Textbox(
774
- label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
775
- cmp_resn_embed_cfg = gr.Textbox(
776
- label="FT-Embed ResN Config", value="config_resn_from_ft_unsupervised.json")
777
- cmp_resn_label_direct = gr.Textbox(
778
- label="Direct Label", value="ResN_raw")
779
- cmp_resn_label_ft = gr.Textbox(
780
- label="FT Label", value="ResN_ft_embed")
781
- cmp_resn_runtime = gr.Checkbox(
782
- label="Use Runtime FT Embedding", value=False)
783
- cmp_resn_bins = gr.Number(
784
- label="Bins Override", value=10, precision=0)
785
- cmp_resn_run_btn = gr.Button(
786
- "Run ResNet Compare", variant="primary")
787
- cmp_resn_status = gr.Textbox(label="Status", interactive=False)
788
- cmp_resn_log = gr.Textbox(
789
- label="Logs", lines=15, max_lines=40, interactive=False)
790
-
791
- # Event handlers
792
- load_btn.click(
793
- fn=app.load_json_config,
794
- inputs=[json_file],
795
- outputs=[load_status, config_display, config_json]
796
- )
797
-
798
- build_btn.click(
799
- fn=app.build_config_from_ui,
800
- inputs=[
801
- data_dir, model_list, model_categories, target, weight,
802
- feature_list, categorical_features, task_type, prop_test,
803
- holdout_ratio, val_ratio, split_strategy, rand_seed, epochs,
804
- output_dir, use_gpu, model_keys, max_evals,
805
- xgb_max_depth_max, xgb_n_estimators_max
806
- ],
807
- outputs=[build_status, config_json]
808
- )
809
-
810
- save_config_btn.click(
811
- fn=app.save_config,
812
- inputs=[config_json, save_filename],
813
- outputs=[save_status]
814
- )
815
-
816
- run_btn.click(
817
- fn=app.run_training,
818
- inputs=[config_json],
819
- outputs=[run_status, log_output]
820
- )
821
-
822
- open_folder_btn.click(
823
- fn=app.open_results_folder,
824
- inputs=[config_json],
825
- outputs=[folder_status]
826
- )
827
-
828
- prepare_step1_btn.click(
829
- fn=app.prepare_ft_step1,
830
- inputs=[config_json, ft_use_ddp, ft_nproc],
831
- outputs=[step1_status, step1_config_display]
832
- )
833
-
834
- prepare_step2_btn.click(
835
- fn=app.prepare_ft_step2,
836
- inputs=[gr.State(
837
- lambda: app.current_step1_config or "temp_ft_step1_config.json"), target_models_input],
838
- outputs=[step2_status, xgb_config_display, resn_config_display]
839
- )
840
-
841
- pre_run_btn.click(
842
- fn=app.run_pre_oneway_ui,
843
- inputs=[
844
- pre_data_path, pre_model_name, pre_target, pre_weight,
845
- pre_feature_list, pre_categorical, pre_n_bins,
846
- pre_holdout, pre_seed, pre_output_dir
847
- ],
848
- outputs=[pre_status, pre_log]
849
- )
850
-
851
- direct_run_btn.click(
852
- fn=app.run_plot_direct_ui,
853
- inputs=[direct_cfg_path, direct_xgb_cfg, direct_resn_cfg],
854
- outputs=[direct_status, direct_log]
855
- )
856
-
857
- embed_run_btn.click(
858
- fn=app.run_plot_embed_ui,
859
- inputs=[embed_cfg_path, embed_xgb_cfg,
860
- embed_resn_cfg, embed_ft_cfg, embed_runtime],
861
- outputs=[embed_status, embed_log]
862
- )
863
-
864
- pred_run_btn.click(
865
- fn=app.run_predict_ui,
866
- inputs=[
867
- pred_ft_cfg, pred_xgb_cfg, pred_resn_cfg, pred_input,
868
- pred_output, pred_model_name, pred_model_keys
869
- ],
870
- outputs=[pred_status, pred_log]
871
- )
872
-
873
- cmp_xgb_run_btn.click(
874
- fn=app.run_compare_xgb_ui,
875
- inputs=[
876
- cmp_xgb_direct_cfg, cmp_xgb_ft_cfg, cmp_xgb_embed_cfg,
877
- cmp_xgb_label_direct, cmp_xgb_label_ft,
878
- cmp_xgb_runtime, cmp_xgb_bins
879
- ],
880
- outputs=[cmp_xgb_status, cmp_xgb_log]
881
- )
36
+ import gradio as gr
882
37
 
883
- cmp_resn_run_btn.click(
884
- fn=app.run_compare_resn_ui,
885
- inputs=[
886
- cmp_resn_direct_cfg, cmp_resn_ft_cfg, cmp_resn_embed_cfg,
887
- cmp_resn_label_direct, cmp_resn_label_ft,
888
- cmp_resn_runtime, cmp_resn_bins
889
- ],
890
- outputs=[cmp_resn_status, cmp_resn_log]
38
+
39
+
40
+
41
class FrontendDependencyError(RuntimeError):
    """Raised when a package required by the Gradio frontend is missing or incompatible."""


def _check_frontend_deps() -> None:
    """Fail fast with a clear message if frontend deps are incompatible.

    Raises:
        FrontendDependencyError: If ``gradio`` or ``huggingface_hub`` cannot
            be imported, or if the installed ``huggingface_hub`` does not
            expose ``HfFolder``. Import failures are chained as the
            exception's ``__cause__`` so the original traceback is kept.
    """
    # Any exception type is caught on purpose: a broken install can fail
    # with errors other than ImportError while the package initializes.
    try:
        import gradio  # noqa: F401
    except Exception as exc:
        raise FrontendDependencyError(
            f"Failed to import gradio: {exc}") from exc

    try:
        import huggingface_hub as hf
    except Exception as exc:
        raise FrontendDependencyError(
            f"Failed to import huggingface_hub: {exc}. "
            "Pin version with `pip install 'huggingface_hub<0.24'`."
        ) from exc

    # The import succeeded, so check for the attribute this frontend relies on.
    if not hasattr(hf, 'HfFolder'):
        raise FrontendDependencyError(
            'Incompatible huggingface_hub detected: missing HfFolder. '
            'Please install `huggingface_hub<0.24`.'
        )
65
+
66
+
67
class PricingApp:
    """Main application class for the insurance pricing model tasks interface.

    Owns the config builder, task runner and FT workflow helper, remembers
    the configuration that was last loaded or built, and exposes the
    callbacks used by the Gradio event handlers.
    """

    def __init__(self):
        self.config_builder = ConfigBuilder()
        self.runner = TaskRunner()
        self.ft_workflow = FTWorkflowHelper()
        # Last configuration loaded from a file or built from the form.
        self.current_config = {}
        # Path (as a string) of the most recently prepared FT Step 1 config.
        self.current_step1_config = None
        # Set only when the configuration came from a file on disk.
        self.current_config_path: Optional[Path] = None
        self.current_config_dir: Optional[Path] = None

    @staticmethod
    def _split_csv(text: str) -> list:
        """Split a comma-separated string into stripped, non-empty items."""
        return [item.strip() for item in text.split(',') if item.strip()]

    @staticmethod
    def _task_mode(config: Dict[str, Any]) -> str:
        """Return ``config['runner']['mode']``, defaulting to ``'entry'``."""
        return config.get('runner', {}).get('mode', 'entry')

    @staticmethod
    def _write_json(path: Path, config: Dict[str, Any]) -> None:
        """Write *config* to *path* as indented UTF-8 JSON."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

    def load_json_config(self, file_path) -> tuple[str, Dict[str, Any], str]:
        """Load configuration from uploaded JSON file.

        Returns:
            ``(status message, parsed config, pretty-printed JSON text)``.
            On failure the config is ``{}`` and the JSON text is empty.
        """
        if not file_path:
            return "No file uploaded", {}, ""

        try:
            path = Path(file_path).resolve()
            with open(path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            self.current_config = config
            self.current_config_path = path
            self.current_config_dir = path.parent
            config_json = json.dumps(config, indent=2, ensure_ascii=False)
            return f"Configuration loaded successfully from {path.name}", config, config_json
        except Exception as e:
            return f"Error loading config: {str(e)}", {}, ""

    def build_config_from_ui(
        self,
        data_dir: str,
        model_list: str,
        model_categories: str,
        target: str,
        weight: str,
        feature_list: str,
        categorical_features: str,
        task_type: str,
        prop_test: float,
        holdout_ratio: float,
        val_ratio: float,
        split_strategy: str,
        rand_seed: int,
        epochs: int,
        output_dir: str,
        use_gpu: bool,
        model_keys: str,
        max_evals: int,
        xgb_max_depth_max: int,
        xgb_n_estimators_max: int,
    ) -> tuple[str, str]:
        """Build configuration from UI parameters.

        The list-valued form fields arrive as comma-separated strings and
        are split before being handed to the config builder.

        Returns:
            ``(status message, config JSON text)``; the JSON text is empty
            when validation fails or an error occurs.
        """
        try:
            config = self.config_builder.build_config(
                data_dir=data_dir,
                model_list=self._split_csv(model_list),
                model_categories=self._split_csv(model_categories),
                target=target,
                weight=weight,
                feature_list=self._split_csv(feature_list),
                categorical_features=self._split_csv(categorical_features),
                task_type=task_type,
                prop_test=prop_test,
                holdout_ratio=holdout_ratio,
                val_ratio=val_ratio,
                split_strategy=split_strategy,
                rand_seed=rand_seed,
                epochs=epochs,
                output_dir=output_dir,
                use_gpu=use_gpu,
                model_keys=self._split_csv(model_keys),
                max_evals=max_evals,
                xgb_max_depth_max=xgb_max_depth_max,
                xgb_n_estimators_max=xgb_n_estimators_max,
            )

            is_valid, msg = self.config_builder.validate_config(config)
            if not is_valid:
                return f"Validation failed: {msg}", ""

            # A form-built config has no backing file.
            self.current_config = config
            self.current_config_path = None
            self.current_config_dir = None
            config_json = json.dumps(config, indent=2, ensure_ascii=False)
            return "Configuration built successfully", config_json

        except Exception as e:
            return f"Error building config: {str(e)}", ""

    def save_config(self, config_json: str, filename: str) -> str:
        """Save current configuration to file.

        The JSON text is parsed before the target file is opened, so an
        invalid configuration never truncates an existing file.

        Returns:
            A status message describing success or the error.
        """
        if not config_json:
            return "No configuration to save"

        try:
            config = json.loads(config_json)
            config_path = Path(filename)
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=2, ensure_ascii=False)
            return f"Configuration saved to {config_path}"
        except Exception as e:
            return f"Error saving config: {str(e)}"

    def run_training(self, config_json: str) -> Iterable[Tuple[str, str]]:
        """
        Run task (training, explain, plotting, etc.) with the current configuration.

        The task type is automatically detected from config.runner.mode.
        Supported modes: entry (training), explain, incremental, watchdog, etc.

        This is a generator: it yields ``(status, accumulated log)`` after
        every log line, then a final "completed" or error tuple.

        The configuration source is chosen in this order: the JSON text
        passed in, the file last loaded, then the in-memory configuration.
        """
        temp_config_path: Optional[Path] = None
        try:
            try:
                if config_json:
                    config = json.loads(config_json)
                    base_dir = self.current_config_dir or Path.cwd()
                    temp_config_path = (
                        base_dir / "temp_config.json").resolve()
                    self._write_json(temp_config_path, config)
                    config_path = temp_config_path
                elif self.current_config_path and self.current_config_path.exists():
                    config_path = self.current_config_path
                    config = json.loads(
                        config_path.read_text(encoding="utf-8"))
                elif self.current_config:
                    config = self.current_config
                    temp_config_path = (
                        Path.cwd() / "temp_config.json").resolve()
                    self._write_json(temp_config_path, config)
                    config_path = temp_config_path
                else:
                    # A generator's `return value` never reaches the caller,
                    # so the message has to be yielded.
                    yield "No configuration provided", ""
                    return

                task_mode = self._task_mode(config)
                log_generator = self.runner.run_task(str(config_path))

                # Collect logs
                full_log = ""
                for log_line in log_generator:
                    full_log += log_line + "\n"
                    yield f"Task [{task_mode}] in progress...", full_log
            finally:
                # Clean up the temp config even when the task fails.
                if temp_config_path is not None and temp_config_path.exists():
                    temp_config_path.unlink()

            yield f"Task [{task_mode}] completed!", full_log

        except Exception as e:
            error_msg = f"Error during task execution: {str(e)}"
            yield error_msg, error_msg

    def prepare_ft_step1(self, config_json: str, use_ddp: bool, nproc: int) -> tuple[str, str]:
        """Prepare FT Step 1 configuration.

        Writes the derived config to ``temp_ft_step1_config.json`` in the
        current working directory and remembers that path.

        Returns:
            ``(status message, Step 1 config JSON text)``.
        """
        if not config_json:
            return "No configuration provided", ""

        try:
            config = json.loads(config_json)
            step1_config = self.ft_workflow.prepare_step1_config(
                base_config=config,
                use_ddp=use_ddp,
                nproc_per_node=int(nproc)
            )

            # Save to temp file
            temp_path = Path("temp_ft_step1_config.json")
            self._write_json(temp_path, step1_config)

            self.current_step1_config = str(temp_path)
            step1_json = json.dumps(step1_config, indent=2, ensure_ascii=False)

            return "Step 1 config prepared. Click 'Run Step 1' to train FT embeddings.", step1_json

        except Exception as e:
            return f"Error preparing Step 1 config: {str(e)}", ""

    def prepare_ft_step2(self, step1_config_path: str, target_models: str) -> tuple[str, str, str]:
        """Prepare FT Step 2 configurations.

        Returns:
            ``(status message, XGB config JSON, ResN config JSON)``; a JSON
            string is empty when the helper produced no config for it.
        """
        if not step1_config_path or not os.path.exists(step1_config_path):
            return "Step 1 config not found. Run Step 1 first.", "", ""

        try:
            models = self._split_csv(target_models)
            xgb_cfg, resn_cfg = self.ft_workflow.generate_step2_configs(
                step1_config_path=step1_config_path,
                target_models=models
            )

            status_msg = f"Step 2 configs prepared for: {', '.join(models)}"
            xgb_json = json.dumps(
                xgb_cfg, indent=2, ensure_ascii=False) if xgb_cfg else ""
            resn_json = json.dumps(
                resn_cfg, indent=2, ensure_ascii=False) if resn_cfg else ""

            return status_msg, xgb_json, resn_json

        except FileNotFoundError as e:
            return f"Error: {str(e)}\n\nMake sure Step 1 completed successfully.", "", ""
        except Exception as e:
            return f"Error preparing Step 2 configs: {str(e)}", "", ""

    def open_results_folder(self, config_json: str) -> str:
        """Open the results folder in file explorer.

        The folder is read from ``output_dir`` (default ``./Results``) of
        the JSON text, the loaded file, or the in-memory config, in that
        order. Only the loaded-file case resolves it against the config
        file's directory; the others resolve against the working directory.

        Returns:
            A status message.
        """
        try:
            if config_json:
                config = json.loads(config_json)
                output_dir = config.get('output_dir', './Results')
                results_path = Path(output_dir).resolve()
            elif self.current_config_path and self.current_config_path.exists():
                config = json.loads(
                    self.current_config_path.read_text(encoding="utf-8"))
                output_dir = config.get('output_dir', './Results')
                results_path = (
                    self.current_config_path.parent / output_dir).resolve()
            elif self.current_config:
                output_dir = self.current_config.get('output_dir', './Results')
                results_path = Path(output_dir).resolve()
            else:
                return "No configuration loaded"

            if not results_path.exists():
                return f"Results folder does not exist yet: {results_path}"

            # Open folder based on OS
            system = platform.system()
            if system == "Windows":
                os.startfile(results_path)
            elif system == "Darwin":  # macOS
                subprocess.run(["open", str(results_path)])
            else:  # Linux
                subprocess.run(["xdg-open", str(results_path)])

            return f"Opened folder: {results_path}"

        except Exception as e:
            return f"Error opening folder: {str(e)}"

    def _run_workflow(self, label: str, func: Callable, *args, **kwargs) -> Iterable[Tuple[str, str]]:
        """Run a workflow function and stream logs.

        Yields ``(status, accumulated log)`` per log line, then a final
        "completed" tuple, or a single error tuple if anything raises.
        """
        try:
            log_generator = self.runner.run_callable(func, *args, **kwargs)
            full_log = ""
            for log_line in log_generator:
                full_log += log_line + "\n"
                yield f"{label} in progress...", full_log
            yield f"{label} completed!", full_log
        except Exception as e:
            error_msg = f"{label} error: {str(e)}"
            yield error_msg, error_msg

    def run_pre_oneway_ui(
        self,
        data_path: str,
        model_name: str,
        target_col: str,
        weight_col: str,
        feature_list: str,
        categorical_features: str,
        n_bins: int,
        holdout_ratio: float,
        rand_seed: int,
        output_dir: str,
    ):
        """Stream the "Pre-Oneway Plot" workflow; a blank output dir becomes ``None``."""
        yield from self._run_workflow(
            "Pre-Oneway Plot",
            run_pre_oneway,
            data_path=data_path,
            model_name=model_name,
            target_col=target_col,
            weight_col=weight_col,
            feature_list=feature_list,
            categorical_features=categorical_features,
            n_bins=n_bins,
            holdout_ratio=holdout_ratio,
            rand_seed=rand_seed,
            output_dir=output_dir or None,
        )

    def run_plot_direct_ui(self, cfg_path: str, xgb_cfg_path: str, resn_cfg_path: str):
        """Stream the "Direct Plot" workflow."""
        yield from self._run_workflow(
            "Direct Plot",
            run_plot_direct,
            cfg_path=cfg_path,
            xgb_cfg_path=xgb_cfg_path,
            resn_cfg_path=resn_cfg_path,
        )

    def run_plot_embed_ui(
        self,
        cfg_path: str,
        xgb_cfg_path: str,
        resn_cfg_path: str,
        ft_cfg_path: str,
        use_runtime_ft_embedding: bool,
    ):
        """Stream the "Embed Plot" workflow."""
        yield from self._run_workflow(
            "Embed Plot",
            run_plot_embed,
            cfg_path=cfg_path,
            xgb_cfg_path=xgb_cfg_path,
            resn_cfg_path=resn_cfg_path,
            ft_cfg_path=ft_cfg_path,
            use_runtime_ft_embedding=use_runtime_ft_embedding,
        )

    def run_predict_ui(
        self,
        ft_cfg_path: str,
        xgb_cfg_path: str,
        resn_cfg_path: str,
        input_path: str,
        output_path: str,
        model_name: str,
        model_keys: str,
    ):
        """Stream the "Prediction" workflow; blank optional fields become ``None``."""
        yield from self._run_workflow(
            "Prediction",
            run_predict_ft_embed,
            ft_cfg_path=ft_cfg_path,
            xgb_cfg_path=xgb_cfg_path or None,
            resn_cfg_path=resn_cfg_path or None,
            input_path=input_path,
            output_path=output_path,
            model_name=model_name or None,
            model_keys=model_keys,
        )

    def _run_compare(
        self,
        label: str,
        model_key: str,
        direct_cfg_path: str,
        ft_cfg_path: str,
        ft_embed_cfg_path: str,
        label_direct: str,
        label_ft: str,
        use_runtime_ft_embedding: bool,
        n_bins_override: int,
    ):
        """Shared body of the compare callbacks, parameterized by model key."""
        yield from self._run_workflow(
            label,
            run_compare_ft_embed,
            direct_cfg_path=direct_cfg_path,
            ft_cfg_path=ft_cfg_path,
            ft_embed_cfg_path=ft_embed_cfg_path,
            model_key=model_key,
            label_direct=label_direct,
            label_ft=label_ft,
            use_runtime_ft_embedding=use_runtime_ft_embedding,
            n_bins_override=n_bins_override,
        )

    def run_compare_xgb_ui(
        self,
        direct_cfg_path: str,
        ft_cfg_path: str,
        ft_embed_cfg_path: str,
        label_direct: str,
        label_ft: str,
        use_runtime_ft_embedding: bool,
        n_bins_override: int,
    ):
        """Stream the "Compare XGB" workflow (model key ``xgb``)."""
        yield from self._run_compare(
            "Compare XGB",
            "xgb",
            direct_cfg_path,
            ft_cfg_path,
            ft_embed_cfg_path,
            label_direct,
            label_ft,
            use_runtime_ft_embedding,
            n_bins_override,
        )

    def run_compare_resn_ui(
        self,
        direct_cfg_path: str,
        ft_cfg_path: str,
        ft_embed_cfg_path: str,
        label_direct: str,
        label_ft: str,
        use_runtime_ft_embedding: bool,
        n_bins_override: int,
    ):
        """Stream the "Compare ResNet" workflow (model key ``resn``)."""
        yield from self._run_compare(
            "Compare ResNet",
            "resn",
            direct_cfg_path,
            ft_cfg_path,
            ft_embed_cfg_path,
            label_direct,
            label_ft,
            use_runtime_ft_embedding,
            n_bins_override,
        )
454
+
455
+
456
def create_ui():
    """Create the Gradio interface.

    Builds a ``gr.Blocks`` app with the tabs Configuration, Run Task,
    FT Two-Step Workflow, Plotting, Prediction and Compare, then wires each
    button to the matching ``PricingApp`` callback.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # One PricingApp is created per call; every event handler below is a
    # bound method of this same instance.
    app = PricingApp()

    with gr.Blocks(title="Insurance Pricing Model Training", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # Insurance Pricing Model Training Interface
            Configure and train insurance pricing models with an easy-to-use interface.

            **Two ways to configure:**
            1. **Upload JSON Config**: Upload an existing configuration file
            2. **Manual Configuration**: Fill in the parameters below
            """
        )

        # ---- Tab: load a JSON config or build one from form fields ----
        with gr.Tab("Configuration"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Load Configuration")
                    json_file = gr.File(
                        label="Upload JSON Config File",
                        file_types=[".json"],
                        type="filepath"
                    )
                    load_btn = gr.Button("Load Config", variant="primary")
                    load_status = gr.Textbox(
                        label="Load Status", interactive=False)

                with gr.Column(scale=2):
                    gr.Markdown("### Current Configuration")
                    config_display = gr.JSON(label="Configuration", value={})

            gr.Markdown("---")
            gr.Markdown("### Manual Configuration")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Data Settings")
                    data_dir = gr.Textbox(
                        label="Data Directory", value="./Data")
                    model_list = gr.Textbox(
                        label="Model List (comma-separated)", value="od")
                    model_categories = gr.Textbox(
                        label="Model Categories (comma-separated)", value="bc")
                    target = gr.Textbox(
                        label="Target Column", value="response")
                    weight = gr.Textbox(label="Weight Column", value="weights")

                    gr.Markdown("#### Features")
                    feature_list = gr.Textbox(
                        label="Feature List (comma-separated)",
                        placeholder="feature_1, feature_2, feature_3",
                        lines=3
                    )
                    categorical_features = gr.Textbox(
                        label="Categorical Features (comma-separated)",
                        placeholder="feature_2, feature_3",
                        lines=2
                    )

                with gr.Column():
                    gr.Markdown("#### Model Settings")
                    task_type = gr.Dropdown(
                        label="Task Type",
                        choices=["regression", "binary", "multiclass"],
                        value="regression"
                    )
                    prop_test = gr.Slider(
                        label="Test Proportion", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
                    holdout_ratio = gr.Slider(
                        label="Holdout Ratio", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
                    val_ratio = gr.Slider(
                        label="Validation Ratio", minimum=0.1, maximum=0.5, value=0.25, step=0.05)
                    split_strategy = gr.Dropdown(
                        label="Split Strategy",
                        choices=["random", "stratified", "time", "group"],
                        value="random"
                    )
                    rand_seed = gr.Number(
                        label="Random Seed", value=13, precision=0)
                    epochs = gr.Number(label="Epochs", value=50, precision=0)

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Training Settings")
                    output_dir = gr.Textbox(
                        label="Output Directory", value="./Results")
                    use_gpu = gr.Checkbox(label="Use GPU", value=True)
                    model_keys = gr.Textbox(
                        label="Model Keys (comma-separated)",
                        value="xgb, resn",
                        placeholder="xgb, resn, ft, gnn"
                    )
                    max_evals = gr.Number(
                        label="Max Evaluations", value=50, precision=0)

                with gr.Column():
                    gr.Markdown("#### XGBoost Settings")
                    xgb_max_depth_max = gr.Number(
                        label="XGB Max Depth", value=25, precision=0)
                    xgb_n_estimators_max = gr.Number(
                        label="XGB Max Estimators", value=500, precision=0)

            with gr.Row():
                build_btn = gr.Button(
                    "Build Configuration", variant="primary", size="lg")
                save_config_btn = gr.Button(
                    "Save Configuration", variant="secondary", size="lg")

            with gr.Row():
                build_status = gr.Textbox(label="Status", interactive=False)
                # This textbox is the shared source of truth: loading and
                # building write into it, and the run/save/open/step-1
                # handlers below read from it.
                config_json = gr.Textbox(
                    label="Generated Config (JSON)", lines=10, max_lines=20)

            save_filename = gr.Textbox(
                label="Save Filename", value="my_config.json")
            save_status = gr.Textbox(label="Save Status", interactive=False)

        # ---- Tab: execute the configured task and stream its logs ----
        with gr.Tab("Run Task"):
            gr.Markdown(
                """
                ### Run Model Task
                Click the button below to execute the task defined in your configuration.
                Task type is automatically detected from `config.runner.mode`:
                - **entry**: Standard model training
                - **explain**: Model explanation (permutation, SHAP, integrated gradients)
                - **incremental**: Incremental training
                - **watchdog**: Watchdog mode

                Task logs will appear in real-time below.
                """
            )

            with gr.Row():
                run_btn = gr.Button("Run Task", variant="primary", size="lg")
                run_status = gr.Textbox(label="Task Status", interactive=False)

            gr.Markdown("### Task Logs")
            log_output = gr.Textbox(
                label="Logs",
                lines=25,
                max_lines=50,
                interactive=False,
                autoscroll=True
            )

            gr.Markdown("---")
            with gr.Row():
                open_folder_btn = gr.Button("Open Results Folder", size="lg")
                folder_status = gr.Textbox(
                    label="Status", interactive=False, scale=2)

        # ---- Tab: prepare configs for the FT -> XGB/ResN two-step flow ----
        with gr.Tab("FT Two-Step Workflow"):
            gr.Markdown(
                """
                ### FT-Transformer Two-Step Training

                Automates the FT → XGB/ResN workflow:
                1. **Step 1**: Train FT-Transformer as unsupervised embedding generator
                2. **Step 2**: Merge embeddings with raw data and train XGB/ResN

                **Instructions**:
                1. Load or build a base configuration in the Configuration tab
                2. Prepare Step 1 config (FT embeddings)
                3. Run Step 1 to generate embeddings
                4. Prepare Step 2 configs (XGB/ResN using embeddings)
                5. Run Step 2 with the generated configs
                """
            )

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Step 1: FT Embedding Generation")
                    ft_use_ddp = gr.Checkbox(
                        label="Use DDP for FT", value=True)
                    ft_nproc = gr.Number(
                        label="Number of Processes (DDP)", value=2, precision=0)

                    prepare_step1_btn = gr.Button(
                        "Prepare Step 1 Config", variant="primary")
                    step1_status = gr.Textbox(
                        label="Status", interactive=False)
                    step1_config_display = gr.Textbox(
                        label="Step 1 Config (FT Embedding)",
                        lines=15,
                        max_lines=25
                    )

                with gr.Column():
                    gr.Markdown("### Step 2: Train XGB/ResN with Embeddings")
                    target_models_input = gr.Textbox(
                        label="Target Models (comma-separated)",
                        value="xgb, resn",
                        placeholder="xgb, resn"
                    )

                    prepare_step2_btn = gr.Button(
                        "Prepare Step 2 Configs", variant="primary")
                    step2_status = gr.Textbox(
                        label="Status", interactive=False)

                    with gr.Tab("XGB Config"):
                        xgb_config_display = gr.Textbox(
                            label="XGB Step 2 Config",
                            lines=15,
                            max_lines=25
                        )

                    with gr.Tab("ResN Config"):
                        resn_config_display = gr.Textbox(
                            label="ResN Step 2 Config",
                            lines=15,
                            max_lines=25
                        )

            gr.Markdown("---")
            gr.Markdown(
                """
                ### Quick Actions
                After preparing configs, you can:
                - Copy the Step 1 config and paste it in the **Configuration** tab, then run it in **Run Task** tab
                - After Step 1 completes, click **Prepare Step 2 Configs**
                - Copy the Step 2 configs (XGB or ResN) and run them in **Run Task** tab
                """
            )

        # ---- Tab: plotting workflows, one sub-tab per workflow ----
        with gr.Tab("Plotting"):
            gr.Markdown(
                """
                ### Plotting Workflows
                Run the plotting steps from the example notebooks.
                """
            )

            with gr.Tab("Pre Oneway"):
                with gr.Row():
                    with gr.Column():
                        pre_data_path = gr.Textbox(
                            label="Data Path", value="./Data/od_bc.csv")
                        pre_model_name = gr.Textbox(
                            label="Model Name", value="od_bc")
                        pre_target = gr.Textbox(
                            label="Target Column", value="response")
                        pre_weight = gr.Textbox(
                            label="Weight Column", value="weights")
                        pre_output_dir = gr.Textbox(
                            label="Output Dir (optional)", value="")
                    with gr.Column():
                        pre_feature_list = gr.Textbox(
                            label="Feature List (comma-separated)",
                            lines=4,
                            placeholder="feature_1, feature_2, feature_3",
                        )
                        pre_categorical = gr.Textbox(
                            label="Categorical Features (comma-separated, optional)",
                            lines=3,
                            placeholder="feature_2, feature_3",
                        )
                        pre_n_bins = gr.Number(
                            label="Bins", value=10, precision=0)
                        pre_holdout = gr.Slider(
                            label="Holdout Ratio",
                            minimum=0.0,
                            maximum=0.5,
                            value=0.25,
                            step=0.05,
                        )
                        pre_seed = gr.Number(
                            label="Random Seed", value=13, precision=0)

                pre_run_btn = gr.Button("Run Pre Oneway", variant="primary")
                pre_status = gr.Textbox(label="Status", interactive=False)
                pre_log = gr.Textbox(label="Logs", lines=15,
                                     max_lines=40, interactive=False)

            with gr.Tab("Direct Plot"):
                direct_cfg_path = gr.Textbox(
                    label="Plot Config", value="config_plot.json")
                direct_xgb_cfg = gr.Textbox(
                    label="XGB Config", value="config_xgb_direct.json")
                direct_resn_cfg = gr.Textbox(
                    label="ResN Config", value="config_resn_direct.json")
                direct_run_btn = gr.Button(
                    "Run Direct Plot", variant="primary")
                direct_status = gr.Textbox(label="Status", interactive=False)
                direct_log = gr.Textbox(
                    label="Logs", lines=15, max_lines=40, interactive=False)

            with gr.Tab("Embed Plot"):
                embed_cfg_path = gr.Textbox(
                    label="Plot Config", value="config_plot.json")
                embed_xgb_cfg = gr.Textbox(
                    label="XGB Embed Config", value="config_xgb_from_ft_unsupervised.json")
                embed_resn_cfg = gr.Textbox(
                    label="ResN Embed Config", value="config_resn_from_ft_unsupervised.json")
                embed_ft_cfg = gr.Textbox(
                    label="FT Embed Config", value="config_ft_unsupervised_ddp_embed.json")
                embed_runtime = gr.Checkbox(
                    label="Use Runtime FT Embedding", value=False)
                embed_run_btn = gr.Button("Run Embed Plot", variant="primary")
                embed_status = gr.Textbox(label="Status", interactive=False)
                embed_log = gr.Textbox(
                    label="Logs", lines=15, max_lines=40, interactive=False)

        # ---- Tab: batch prediction with FT-embedding models ----
        with gr.Tab("Prediction"):
            gr.Markdown("### FT Embed Prediction")
            pred_ft_cfg = gr.Textbox(
                label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
            pred_xgb_cfg = gr.Textbox(
                label="XGB Config (optional)", value="config_xgb_from_ft_unsupervised.json")
            pred_resn_cfg = gr.Textbox(
                label="ResN Config (optional)", value="config_resn_from_ft_unsupervised.json")
            pred_input = gr.Textbox(
                label="Input Data", value="./Data/od_bc_new.csv")
            pred_output = gr.Textbox(
                label="Output CSV", value="./Results/predictions_ft_xgb.csv")
            pred_model_name = gr.Textbox(
                label="Model Name (optional)", value="")
            pred_model_keys = gr.Textbox(label="Model Keys", value="xgb, resn")
            pred_run_btn = gr.Button("Run Prediction", variant="primary")
            pred_status = gr.Textbox(label="Status", interactive=False)
            pred_log = gr.Textbox(label="Logs", lines=15,
                                  max_lines=40, interactive=False)

        # ---- Tab: compare direct models against FT-embedding models ----
        with gr.Tab("Compare"):
            gr.Markdown("### Compare Direct vs FT-Embed Models")

            with gr.Tab("Compare XGB"):
                cmp_xgb_direct_cfg = gr.Textbox(
                    label="Direct XGB Config", value="config_xgb_direct.json")
                cmp_xgb_ft_cfg = gr.Textbox(
                    label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
                cmp_xgb_embed_cfg = gr.Textbox(
                    label="FT-Embed XGB Config", value="config_xgb_from_ft_unsupervised.json")
                cmp_xgb_label_direct = gr.Textbox(
                    label="Direct Label", value="XGB_raw")
                cmp_xgb_label_ft = gr.Textbox(
                    label="FT Label", value="XGB_ft_embed")
                cmp_xgb_runtime = gr.Checkbox(
                    label="Use Runtime FT Embedding", value=False)
                cmp_xgb_bins = gr.Number(
                    label="Bins Override", value=10, precision=0)
                cmp_xgb_run_btn = gr.Button(
                    "Run XGB Compare", variant="primary")
                cmp_xgb_status = gr.Textbox(label="Status", interactive=False)
                cmp_xgb_log = gr.Textbox(
                    label="Logs", lines=15, max_lines=40, interactive=False)

            with gr.Tab("Compare ResNet"):
                cmp_resn_direct_cfg = gr.Textbox(
                    label="Direct ResN Config", value="config_resn_direct.json")
                cmp_resn_ft_cfg = gr.Textbox(
                    label="FT Config", value="config_ft_unsupervised_ddp_embed.json")
                cmp_resn_embed_cfg = gr.Textbox(
                    label="FT-Embed ResN Config", value="config_resn_from_ft_unsupervised.json")
                cmp_resn_label_direct = gr.Textbox(
                    label="Direct Label", value="ResN_raw")
                cmp_resn_label_ft = gr.Textbox(
                    label="FT Label", value="ResN_ft_embed")
                cmp_resn_runtime = gr.Checkbox(
                    label="Use Runtime FT Embedding", value=False)
                cmp_resn_bins = gr.Number(
                    label="Bins Override", value=10, precision=0)
                cmp_resn_run_btn = gr.Button(
                    "Run ResNet Compare", variant="primary")
                cmp_resn_status = gr.Textbox(label="Status", interactive=False)
                cmp_resn_log = gr.Textbox(
                    label="Logs", lines=15, max_lines=40, interactive=False)

        # Event handlers
        # The order of each `inputs` list must match the positional
        # parameter order of the PricingApp method it feeds.
        load_btn.click(
            fn=app.load_json_config,
            inputs=[json_file],
            outputs=[load_status, config_display, config_json]
        )

        build_btn.click(
            fn=app.build_config_from_ui,
            inputs=[
                data_dir, model_list, model_categories, target, weight,
                feature_list, categorical_features, task_type, prop_test,
                holdout_ratio, val_ratio, split_strategy, rand_seed, epochs,
                output_dir, use_gpu, model_keys, max_evals,
                xgb_max_depth_max, xgb_n_estimators_max
            ],
            outputs=[build_status, config_json]
        )

        save_config_btn.click(
            fn=app.save_config,
            inputs=[config_json, save_filename],
            outputs=[save_status]
        )

        # run_training is a generator, so status and logs update as it yields.
        run_btn.click(
            fn=app.run_training,
            inputs=[config_json],
            outputs=[run_status, log_output]
        )

        open_folder_btn.click(
            fn=app.open_results_folder,
            inputs=[config_json],
            outputs=[folder_status]
        )

        prepare_step1_btn.click(
            fn=app.prepare_ft_step1,
            inputs=[config_json, ft_use_ddp, ft_nproc],
            outputs=[step1_status, step1_config_display]
        )

        # NOTE(review): gr.State is given a callable here. Gradio appears to
        # evaluate callable initial values when the app loads, not on each
        # click, so this would resolve to "temp_ft_step1_config.json" at load
        # time and never pick up later changes to app.current_step1_config.
        # That matches the fixed filename prepare_ft_step1 writes, but
        # confirm against the Gradio version in use.
        prepare_step2_btn.click(
            fn=app.prepare_ft_step2,
            inputs=[gr.State(
                lambda: app.current_step1_config or "temp_ft_step1_config.json"), target_models_input],
            outputs=[step2_status, xgb_config_display, resn_config_display]
        )

        pre_run_btn.click(
            fn=app.run_pre_oneway_ui,
            inputs=[
                pre_data_path, pre_model_name, pre_target, pre_weight,
                pre_feature_list, pre_categorical, pre_n_bins,
                pre_holdout, pre_seed, pre_output_dir
            ],
            outputs=[pre_status, pre_log]
        )

        direct_run_btn.click(
            fn=app.run_plot_direct_ui,
            inputs=[direct_cfg_path, direct_xgb_cfg, direct_resn_cfg],
            outputs=[direct_status, direct_log]
        )

        embed_run_btn.click(
            fn=app.run_plot_embed_ui,
            inputs=[embed_cfg_path, embed_xgb_cfg,
                    embed_resn_cfg, embed_ft_cfg, embed_runtime],
            outputs=[embed_status, embed_log]
        )

        pred_run_btn.click(
            fn=app.run_predict_ui,
            inputs=[
                pred_ft_cfg, pred_xgb_cfg, pred_resn_cfg, pred_input,
                pred_output, pred_model_name, pred_model_keys
            ],
            outputs=[pred_status, pred_log]
        )

        cmp_xgb_run_btn.click(
            fn=app.run_compare_xgb_ui,
            inputs=[
                cmp_xgb_direct_cfg, cmp_xgb_ft_cfg, cmp_xgb_embed_cfg,
                cmp_xgb_label_direct, cmp_xgb_label_ft,
                cmp_xgb_runtime, cmp_xgb_bins
            ],
            outputs=[cmp_xgb_status, cmp_xgb_log]
        )

        cmp_resn_run_btn.click(
            fn=app.run_compare_resn_ui,
            inputs=[
                cmp_resn_direct_cfg, cmp_resn_ft_cfg, cmp_resn_embed_cfg,
                cmp_resn_label_direct, cmp_resn_label_ft,
                cmp_resn_runtime, cmp_resn_bins
            ],
            outputs=[cmp_resn_status, cmp_resn_log]
        )

    return demo
929
+
930
+
896
931
if __name__ == "__main__":
    # Script entry point: build the Blocks app and serve it on port 7860.
    demo = create_ui()

    launch_kwargs = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )

    # Pass analytics_enabled only when this Gradio version's launch()
    # accepts it, so an unsupported keyword is never sent.
    accepted_params = inspect.signature(demo.launch).parameters
    if "analytics_enabled" in accepted_params:
        launch_kwargs.update(analytics_enabled=False)

    demo.launch(**launch_kwargs)