ins-pricing 0.3.0__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/PKG-INFO +1 -1
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/Explain_entry.py +50 -48
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/bayesopt_entry_runner.py +73 -70
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +4 -3
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/models/model_gnn.py +114 -14
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +6 -4
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +2 -2
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +7 -3
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/production/predict.py +5 -4
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/setup.py +1 -1
- ins_pricing-0.3.2/ins_pricing/utils/torch_compat.py +45 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing.egg-info/PKG-INFO +1 -1
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing.egg-info/SOURCES.txt +1 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/pyproject.toml +1 -1
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/MANIFEST.in +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/README.md +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/CHANGELOG.md +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/README.md +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/RELEASE_NOTES_0.2.8.md +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/BayesOpt_entry.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/BayesOpt_incremental.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/Explain_Run.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/Pricing_Run.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/cli_common.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/cli_config.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/evaluation_context.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/import_resolver.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/notebook_utils.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/run_logging.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/watchdog_run.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/docs/modelling/BayesOpt_USAGE.md +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/docs/modelling/README.md +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/exceptions.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/governance/README.md +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/governance/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/governance/approval.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/governance/audit.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/governance/registry.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/governance/release.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/BayesOpt.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/config_components.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/config_preprocess.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/core.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/models/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/models/model_resn.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/evaluation.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/explain/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/explain/gradients.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/explain/metrics.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/explain/permutation.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/explain/shap_utils.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/plotting/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/plotting/common.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/plotting/curves.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/plotting/diagnostics.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/plotting/geo.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/plotting/importance.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/README.md +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/calibration.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/data_quality.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/exposure.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/factors.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/monitoring.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/rate_table.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/production/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/production/drift.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/production/monitoring.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/production/preprocess.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/production/scoring.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/reporting/README.md +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/reporting/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/reporting/report_builder.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/reporting/scheduler.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/governance/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/governance/test_audit.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/governance/test_registry.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/governance/test_release.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/conftest.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_cross_val_generic.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_distributed_utils.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_explain.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_geo_tokens_split.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_graph_cache.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_plotting.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_plotting_library.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_preprocessor.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/pricing/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/pricing/test_calibration.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/pricing/test_exposure.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/pricing/test_factors.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/pricing/test_rate_table.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/production/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/production/test_monitoring.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/production/test_predict.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/production/test_preprocess.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/production/test_scoring.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/__init__.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/device.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/logging.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/metrics.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/paths.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/profiling.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/validation.py +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing.egg-info/dependency_links.txt +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing.egg-info/requires.txt +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing.egg-info/top_level.txt +0 -0
- {ins_pricing-0.3.0 → ins_pricing-0.3.2}/setup.cfg +0 -0
{ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/Explain_entry.py
RENAMED

@@ -491,54 +491,56 @@ def explain_from_config(args: argparse.Namespace) -> None:
     categorical_features = cfg.get("categorical_features")
     plot_path_style = runtime_cfg["plot_path_style"]

-    … [old lines 494-541: 48 deleted lines, truncated in the source diff; the first surviving fragment is `cfg["`, the last is the closing `)`]
+    config_fields = getattr(ropt.BayesOptConfig, "__dataclass_fields__", {})
+    allowed_config_keys = set(config_fields.keys())
+    config_payload = {k: v for k, v in cfg.items() if k in allowed_config_keys}
+    config_payload.update({
+        "model_nme": model_name,
+        "resp_nme": cfg["target"],
+        "weight_nme": cfg["weight"],
+        "factor_nmes": feature_list,
+        "task_type": str(cfg.get("task_type", "regression")),
+        "binary_resp_nme": binary_target,
+        "cate_list": categorical_features,
+        "prop_test": prop_test,
+        "rand_seed": rand_seed,
+        "epochs": int(runtime_cfg["epochs"]),
+        "use_gpu": bool(cfg.get("use_gpu", True)),
+        "output_dir": output_dir,
+        "xgb_max_depth_max": runtime_cfg["xgb_max_depth_max"],
+        "xgb_n_estimators_max": runtime_cfg["xgb_n_estimators_max"],
+        "resn_weight_decay": cfg.get("resn_weight_decay"),
+        "final_ensemble": bool(cfg.get("final_ensemble", False)),
+        "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
+        "final_refit": bool(cfg.get("final_refit", True)),
+        "optuna_storage": runtime_cfg["optuna_storage"],
+        "optuna_study_prefix": runtime_cfg["optuna_study_prefix"],
+        "best_params_files": runtime_cfg["best_params_files"],
+        "gnn_use_approx_knn": cfg.get("gnn_use_approx_knn", True),
+        "gnn_approx_knn_threshold": cfg.get("gnn_approx_knn_threshold", 50000),
+        "gnn_graph_cache": cfg.get("gnn_graph_cache"),
+        "gnn_max_gpu_knn_nodes": cfg.get("gnn_max_gpu_knn_nodes", 200000),
+        "gnn_knn_gpu_mem_ratio": cfg.get("gnn_knn_gpu_mem_ratio", 0.9),
+        "gnn_knn_gpu_mem_overhead": cfg.get("gnn_knn_gpu_mem_overhead", 2.0),
+        "region_province_col": cfg.get("region_province_col"),
+        "region_city_col": cfg.get("region_city_col"),
+        "region_effect_alpha": cfg.get("region_effect_alpha"),
+        "geo_feature_nmes": cfg.get("geo_feature_nmes"),
+        "geo_token_hidden_dim": cfg.get("geo_token_hidden_dim"),
+        "geo_token_layers": cfg.get("geo_token_layers"),
+        "geo_token_dropout": cfg.get("geo_token_dropout"),
+        "geo_token_k_neighbors": cfg.get("geo_token_k_neighbors"),
+        "geo_token_learning_rate": cfg.get("geo_token_learning_rate"),
+        "geo_token_epochs": cfg.get("geo_token_epochs"),
+        "ft_role": str(cfg.get("ft_role", "model")),
+        "ft_feature_prefix": str(cfg.get("ft_feature_prefix", "ft_emb")),
+        "ft_num_numeric_tokens": cfg.get("ft_num_numeric_tokens"),
+        "reuse_best_params": runtime_cfg["reuse_best_params"],
+        "plot_path_style": plot_path_style or "nested",
+    })
+    config_payload = {k: v for k, v in config_payload.items() if v is not None}
+    config = ropt.BayesOptConfig(**config_payload)
+    model = ropt.BayesOptModel(train_df, test_df, config=config)

     output_overrides = resolve_explain_output_overrides(
         explain_cfg,
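
Note: the hunk above replaces a long direct keyword-argument construction with a payload filtered against BayesOptConfig's dataclass fields. A minimal sketch of that pattern, with a hypothetical DemoConfig standing in for ropt.BayesOptConfig:

    from dataclasses import dataclass, fields

    @dataclass
    class DemoConfig:  # hypothetical stand-in for ropt.BayesOptConfig
        model_nme: str = "demo"
        epochs: int = 10

    cfg = {"model_nme": "bi_freq", "epochs": 5, "target": "clm_cnt"}  # raw config dict
    allowed = {f.name for f in fields(DemoConfig)}            # dataclass field names
    payload = {k: v for k, v in cfg.items() if k in allowed}  # drops unknown keys ("target")
    print(DemoConfig(**payload))  # DemoConfig(model_nme='bi_freq', epochs=5)

Filtering against `__dataclass_fields__` keeps unrecognized config keys from raising a TypeError in the constructor.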
{ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/bayesopt_entry_runner.py
RENAMED

@@ -1223,76 +1223,79 @@ def train_from_config(args: argparse.Namespace) -> None:
         cfg.get("ft_feature_prefix", args.ft_feature_prefix))
     ft_num_numeric_tokens = cfg.get("ft_num_numeric_tokens")

-    … [old lines 1226-1295: 70 deleted lines, truncated in the source diff]
+    config_fields = getattr(ropt.BayesOptConfig,
+                            "__dataclass_fields__", {})
+    allowed_config_keys = set(config_fields.keys())
+    config_payload = {k: v for k,
+                      v in cfg.items() if k in allowed_config_keys}
+    config_payload.update({
+        "model_nme": model_name,
+        "resp_nme": cfg["target"],
+        "weight_nme": cfg["weight"],
+        "factor_nmes": feature_list,
+        "task_type": task_type,
+        "binary_resp_nme": binary_target,
+        "cate_list": categorical_features,
+        "prop_test": val_ratio,
+        "rand_seed": rand_seed,
+        "epochs": epochs,
+        "use_gpu": use_gpu,
+        "use_resn_data_parallel": use_resn_dp,
+        "use_ft_data_parallel": use_ft_dp,
+        "use_gnn_data_parallel": use_gnn_dp,
+        "use_resn_ddp": use_resn_ddp,
+        "use_ft_ddp": use_ft_ddp,
+        "use_gnn_ddp": use_gnn_ddp,
+        "output_dir": output_dir,
+        "xgb_max_depth_max": xgb_max_depth_max,
+        "xgb_n_estimators_max": xgb_n_estimators_max,
+        "resn_weight_decay": cfg.get("resn_weight_decay"),
+        "final_ensemble": bool(cfg.get("final_ensemble", False)),
+        "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
+        "final_refit": bool(cfg.get("final_refit", True)),
+        "optuna_storage": optuna_storage,
+        "optuna_study_prefix": optuna_study_prefix,
+        "best_params_files": best_params_files,
+        "gnn_use_approx_knn": gnn_use_ann,
+        "gnn_approx_knn_threshold": gnn_threshold,
+        "gnn_graph_cache": gnn_graph_cache,
+        "gnn_max_gpu_knn_nodes": gnn_max_gpu_nodes,
+        "gnn_knn_gpu_mem_ratio": gnn_gpu_mem_ratio,
+        "gnn_knn_gpu_mem_overhead": gnn_gpu_mem_overhead,
+        "region_province_col": region_province_col,
+        "region_city_col": region_city_col,
+        "region_effect_alpha": region_effect_alpha,
+        "geo_feature_nmes": geo_feature_nmes,
+        "geo_token_hidden_dim": geo_token_hidden_dim,
+        "geo_token_layers": geo_token_layers,
+        "geo_token_dropout": geo_token_dropout,
+        "geo_token_k_neighbors": geo_token_k_neighbors,
+        "geo_token_learning_rate": geo_token_learning_rate,
+        "geo_token_epochs": geo_token_epochs,
+        "ft_role": ft_role,
+        "ft_feature_prefix": ft_feature_prefix,
+        "ft_num_numeric_tokens": ft_num_numeric_tokens,
+        "reuse_best_params": reuse_best_params,
+        "bo_sample_limit": bo_sample_limit,
+        "cache_predictions": cache_predictions,
+        "prediction_cache_dir": prediction_cache_dir,
+        "prediction_cache_format": prediction_cache_format,
+        "cv_strategy": cv_strategy or split_strategy,
+        "cv_group_col": cv_group_col or split_group_col,
+        "cv_time_col": cv_time_col or split_time_col,
+        "cv_time_ascending": cv_time_ascending,
+        "cv_splits": cv_splits,
+        "ft_oof_folds": ft_oof_folds,
+        "ft_oof_strategy": ft_oof_strategy,
+        "ft_oof_shuffle": ft_oof_shuffle,
+        "save_preprocess": save_preprocess,
+        "preprocess_artifact_path": preprocess_artifact_path,
+        "plot_path_style": plot_path_style or "nested",
+    })
+    config_payload = {
+        k: v for k, v in config_payload.items() if v is not None}
+    config = ropt.BayesOptConfig(**config_payload)
+    model = ropt.BayesOptModel(train_df, test_df, config=config)

     if plot_requested:
         plot_cfg = cfg.get("plot", {})
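
Both entry points also strip None values before constructing the config, so options left unset in the YAML fall back to the dataclass defaults instead of overriding them with None. A small illustration, again with a hypothetical DemoConfig:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class DemoConfig:  # hypothetical stand-in
        final_ensemble_k: int = 3
        gnn_graph_cache: Optional[str] = None

    payload = {"final_ensemble_k": None, "gnn_graph_cache": None}
    payload = {k: v for k, v in payload.items() if v is not None}  # both keys dropped
    print(DemoConfig(**payload))  # DemoConfig(final_ensemble_k=3, gnn_graph_cache=None)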
{ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py
RENAMED

@@ -626,6 +626,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
         best_state = None
         patience_counter = 0
         is_ddp_model = isinstance(self.ft, DDP)
+        use_collectives = dist.is_initialized() and is_ddp_model

         clip_fn = None
         if self.device.type == 'cuda':

@@ -669,7 +670,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
                         device=X_num_b.device)
                 local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
                 global_bad = local_bad
-                if …
+                if use_collectives:
                     bad = torch.tensor(
                         [local_bad],
                         device=batch_loss.device,

@@ -774,7 +775,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
                 total_n += float(end - start)
             val_loss_tensor[0] = total_val / max(total_n, 1.0)

-            if …
+            if use_collectives:
                 dist.broadcast(val_loss_tensor, src=0)
             val_loss_value = float(val_loss_tensor.item())
             prune_now = False

@@ -806,7 +807,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
             if trial.should_prune():
                 prune_now = True

-            if …
+            if use_collectives:
                 flag = torch.tensor(
                     [1 if prune_now else 0],
                     device=loss_tensor_device,
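
These hunks gate every collective on `use_collectives`, i.e. torch.distributed must have an initialized process group and the module must be DDP-wrapped; calling `dist.broadcast` or `dist.all_reduce` in a plain single-process run raises a RuntimeError ("Default process group has not been initialized"). A minimal sketch of the same gating pattern; the helper name is illustrative:

    import torch
    import torch.distributed as dist

    def aggregate_bad_batches(local_bad: int, device: torch.device,
                              use_collectives: bool) -> int:
        """Sum a per-rank flag across ranks; skip the collective off-DDP."""
        if use_collectives:  # dist.is_initialized() and isinstance(model, DDP)
            t = torch.tensor([local_bad], device=device)
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
            return int(t.item())
        return local_bad  # single process: the local value is already global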
{ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/models/model_gnn.py
RENAMED

@@ -42,6 +42,12 @@ _GNN_MPS_WARNED = False
 # Simplified GNN implementation.
 # =============================================================================

+def _adj_mm(adj: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+    """Matrix multiply that supports sparse or dense adjacency."""
+    if adj.is_sparse:
+        return torch.sparse.mm(adj, x)
+    return adj.matmul(x)
+

 class SimpleGraphLayer(nn.Module):
     def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):

@@ -52,7 +58,7 @@ class SimpleGraphLayer(nn.Module):
     def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
         # Message passing with normalized sparse adjacency: A_hat * X * W.
-        h = …
+        h = _adj_mm(adj, x)
         h = self.linear(h)
         h = self.activation(h)
         return self.dropout(h)

@@ -86,7 +92,7 @@ class SimpleGNN(nn.Module):
         h = x
         for layer in self.layers:
             h = layer(h, adj_used)
-        h = …
+        h = _adj_mm(adj_used, h)
         out = self.output(h)
         return self.output_act(out)

@@ -124,7 +130,11 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
         self.knn_gpu_mem_ratio = max(0.0, min(1.0, knn_gpu_mem_ratio))
         self.knn_gpu_mem_overhead = max(1.0, knn_gpu_mem_overhead)
         self.knn_cpu_jobs = knn_cpu_jobs
+        self.mps_dense_max_nodes = int(
+            os.environ.get("BAYESOPT_GNN_MPS_DENSE_MAX_NODES", "5000")
+        )
         self._knn_warning_emitted = False
+        self._mps_fallback_triggered = False
         self._adj_cache_meta: Optional[Dict[str, Any]] = None
         self._adj_cache_key: Optional[Tuple[Any, ...]] = None
         self._adj_cache_tensor: Optional[torch.Tensor] = None

@@ -168,11 +178,11 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
             else:
                 self.device = torch.device('cuda')
         elif torch.backends.mps.is_available():
-            self.device = torch.device('…
+            self.device = torch.device('mps')
             global _GNN_MPS_WARNED
             if not _GNN_MPS_WARNED:
                 print(
-                    "[GNN] MPS backend …
+                    "[GNN] Using MPS backend; will fall back to CPU on unsupported ops.",
                     flush=True,
                 )
                 _GNN_MPS_WARNED = True

@@ -235,6 +245,41 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
         else:
             base.register_buffer("adj_buffer", adj)

+    @staticmethod
+    def _is_mps_unsupported_error(exc: BaseException) -> bool:
+        msg = str(exc).lower()
+        if "mps" not in msg:
+            return False
+        if any(token in msg for token in ("not supported", "not implemented", "does not support", "unimplemented", "out of memory")):
+            return True
+        return "sparse" in msg
+
+    def _fallback_to_cpu(self, reason: str) -> None:
+        if self.device.type != "mps" or self._mps_fallback_triggered:
+            return
+        self._mps_fallback_triggered = True
+        print(f"[GNN] MPS op unsupported ({reason}); falling back to CPU.", flush=True)
+        self.device = torch.device("cpu")
+        self.use_pyg_knn = False
+        self.data_parallel_enabled = False
+        self.ddp_enabled = False
+        base = self._unwrap_gnn()
+        try:
+            base = base.to(self.device)
+        except Exception:
+            pass
+        self.gnn = base
+        self.invalidate_graph_cache()
+
+    def _run_with_mps_fallback(self, fn, *args, **kwargs):
+        try:
+            return fn(*args, **kwargs)
+        except (RuntimeError, NotImplementedError) as exc:
+            if self.device.type == "mps" and self._is_mps_unsupported_error(exc):
+                self._fallback_to_cpu(str(exc))
+                return fn(*args, **kwargs)
+            raise
+
     def _graph_cache_meta(self, X_df: pd.DataFrame) -> Dict[str, Any]:
         row_hash = pd.util.hash_pandas_object(X_df, index=False).values
         idx_hash = pd.util.hash_pandas_object(X_df.index, index=False).values

@@ -255,11 +300,14 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
             "knn_gpu_mem_ratio": float(self.knn_gpu_mem_ratio),
             "knn_gpu_mem_overhead": float(self.knn_gpu_mem_overhead),
         }
+        adj_format = "dense" if self.device.type == "mps" else "sparse"
         return {
             "n_samples": int(X_df.shape[0]),
             "n_features": int(X_df.shape[1]),
             "hash": hasher.hexdigest(),
             "knn_config": knn_config,
+            "adj_format": adj_format,
+            "device_type": self.device.type,
         }

     def _graph_cache_key(self, X_df: pd.DataFrame) -> Tuple[Any, ...]:

@@ -284,8 +332,7 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
         if meta_expected is None:
             meta_expected = self._graph_cache_meta(X_df)
         try:
-            payload = torch.load(self.graph_cache_path,
-                                 map_location=self.device)
+            payload = torch.load(self.graph_cache_path, map_location="cpu")
         except Exception as exc:
             print(
                 f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")

@@ -293,7 +340,13 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
         if isinstance(payload, dict) and "adj" in payload:
             meta_cached = payload.get("meta")
             if meta_cached == meta_expected:
-                … [deleted line truncated in the source diff]
+                adj = payload["adj"]
+                if self.device.type == "mps" and getattr(adj, "is_sparse", False):
+                    print(
+                        f"[GNN] Cached sparse graph incompatible with MPS; rebuilding: {self.graph_cache_path}"
+                    )
+                    return None
+                return adj.to(self.device)
             print(
                 f"[GNN] Cached graph metadata mismatch; rebuilding: {self.graph_cache_path}")
         return None

@@ -408,6 +461,11 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
             return True

     def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
+        if self.device.type == "mps":
+            return self._normalized_adj_dense(edge_index, num_nodes)
+        return self._normalized_adj_sparse(edge_index, num_nodes)
+
+    def _normalized_adj_sparse(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
         values = torch.ones(edge_index.shape[1], device=self.device)
         adj = torch.sparse_coo_tensor(
             edge_index.to(self.device), values, (num_nodes, num_nodes))

@@ -421,6 +479,21 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
             adj.indices(), norm_values, size=adj.shape)
         return adj_norm

+    def _normalized_adj_dense(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
+        if self.mps_dense_max_nodes <= 0 or num_nodes > self.mps_dense_max_nodes:
+            raise RuntimeError(
+                f"MPS dense adjacency not supported for {num_nodes} nodes; "
+                f"max={self.mps_dense_max_nodes}. Falling back to CPU."
+            )
+        edge_index = edge_index.to(self.device)
+        adj = torch.zeros((num_nodes, num_nodes), device=self.device, dtype=torch.float32)
+        adj[edge_index[0], edge_index[1]] = 1.0
+        deg = adj.sum(dim=1)
+        deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
+        adj = adj * deg_inv_sqrt.view(-1, 1)
+        adj = adj * deg_inv_sqrt.view(1, -1)
+        return adj
+
     def _tensorize_split(self, X, y, w, allow_none: bool = False):
         if X is None and allow_none:
             return None, None, None

@@ -462,17 +535,25 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
         if self._adj_cache_meta == meta_expected and self._adj_cache_tensor is not None:
             cached = self._adj_cache_tensor
             if cached.device != self.device:
-                … [3 deleted lines truncated in the source diff]
+                if self.device.type == "mps" and getattr(cached, "is_sparse", False):
+                    self._adj_cache_tensor = None
+                else:
+                    cached = cached.to(self.device)
+                    self._adj_cache_tensor = cached
+            if self._adj_cache_tensor is not None:
+                return self._adj_cache_tensor
         else:
             cache_key = self._graph_cache_key(X_df)
             if self._adj_cache_key == cache_key and self._adj_cache_tensor is not None:
                 cached = self._adj_cache_tensor
                 if cached.device != self.device:
-                    … [3 deleted lines truncated in the source diff]
+                    if self.device.type == "mps" and getattr(cached, "is_sparse", False):
+                        self._adj_cache_tensor = None
+                    else:
+                        cached = cached.to(self.device)
+                        self._adj_cache_tensor = cached
+                if self._adj_cache_tensor is not None:
+                    return self._adj_cache_tensor
         X_np = None
         if X_tensor is None:
             X_np = X_df.to_numpy(dtype=np.float32, copy=False)

@@ -511,7 +592,20 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
     def fit(self, X_train, y_train, w_train=None,
             X_val=None, y_val=None, w_val=None,
             trial: Optional[optuna.trial.Trial] = None):
+        return self._run_with_mps_fallback(
+            self._fit_impl,
+            X_train,
+            y_train,
+            w_train,
+            X_val,
+            y_val,
+            w_val,
+            trial,
+        )

+    def _fit_impl(self, X_train, y_train, w_train=None,
+                  X_val=None, y_val=None, w_val=None,
+                  trial: Optional[optuna.trial.Trial] = None):
         X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
             X_train, y_train, w_train, allow_none=False)
         has_val = X_val is not None and y_val is not None

@@ -621,6 +715,9 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
         self.best_epoch = int(best_epoch or self.epochs)

     def predict(self, X: pd.DataFrame) -> np.ndarray:
+        return self._run_with_mps_fallback(self._predict_impl, X)
+
+    def _predict_impl(self, X: pd.DataFrame) -> np.ndarray:
         self.gnn.eval()
         X_tensor, _, _ = self._tensorize_split(
             X, None, None, allow_none=False)

@@ -640,6 +737,9 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
         return y_pred.ravel()

     def encode(self, X: pd.DataFrame) -> np.ndarray:
+        return self._run_with_mps_fallback(self._encode_impl, X)
+
+    def _encode_impl(self, X: pd.DataFrame) -> np.ndarray:
         """Return per-sample node embeddings (hidden representations)."""
         base = self._unwrap_gnn()
         base.eval()

@@ -655,7 +755,7 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
             raise RuntimeError("GNN base module does not expose layers.")
         for layer in layers:
             h = layer(h, adj)
-        h = …
+        h = _adj_mm(adj, h)
         return h.detach().cpu().numpy()

     def set_params(self, params: Dict[str, Any]):
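
Because torch.sparse is only partially supported on Apple's MPS backend, the GNN now keeps a dense adjacency on MPS and dispatches every product through `_adj_mm`. The dense path applies the same symmetric normalization D^-1/2 A D^-1/2 as the sparse one; a quick self-contained check that the two layouts multiply identically:

    import torch

    def normalized_adj_dense(edge_index: torch.Tensor, n: int) -> torch.Tensor:
        # Mirrors _normalized_adj_dense: D^-1/2 * A * D^-1/2 on a dense matrix.
        adj = torch.zeros((n, n), dtype=torch.float32)
        adj[edge_index[0], edge_index[1]] = 1.0
        deg_inv_sqrt = torch.pow(adj.sum(dim=1) + 1e-8, -0.5)
        return adj * deg_inv_sqrt.view(-1, 1) * deg_inv_sqrt.view(1, -1)

    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    dense = normalized_adj_dense(edge_index, 3)
    x = torch.randn(3, 4)
    # _adj_mm picks torch.sparse.mm for sparse adjacency and matmul otherwise;
    # both layouts must agree on the message-passing product.
    assert torch.allclose(torch.sparse.mm(dense.to_sparse(), x), dense.matmul(x), atol=1e-6)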
{ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py
RENAMED

@@ -27,6 +27,7 @@ from sklearn.preprocessing import StandardScaler
 from ..config_preprocess import BayesOptConfig, OutputManager
 from ..utils import DistributedUtils, EPS, ensure_parent_dir
 from ins_pricing.utils import get_logger, GPUMemoryManager, DeviceManager
+from ins_pricing.utils.torch_compat import torch_load

 # Module-level logger
 _logger = get_logger("ins_pricing.trainer")

@@ -616,7 +617,7 @@ class TrainerBase:
                 pass
         else:
             # FT-Transformer: load state_dict and reconstruct model
-            loaded = …
+            loaded = torch_load(path, map_location='cpu', weights_only=False)
             if isinstance(loaded, dict):
                 if "state_dict" in loaded and "model_config" in loaded:
                     # New format: state_dict + model_config

@@ -1094,7 +1095,7 @@ class TrainerBase:
             split_iter = splitter

         losses: List[float] = []
-        for train_idx, val_idx in split_iter:
+        for fold_idx, (train_idx, val_idx) in enumerate(split_iter):
            X_train = X_all.iloc[train_idx]
            y_train = y_all.iloc[train_idx]
            X_val = X_all.iloc[val_idx]

@@ -1108,9 +1109,11 @@ class TrainerBase:
            model = model_builder(params)
            try:
                if fit_predict_fn:
+                   # Avoid duplicate Optuna step reports across folds.
+                   trial_for_fold = trial if fold_idx == 0 else None
                    y_pred = fit_predict_fn(
                        model, X_train, y_train, w_train,
-                       X_val, y_val, w_val,
+                       X_val, y_val, w_val, trial_for_fold
                    )
                else:
                    fit_kwargs = {}

@@ -1288,4 +1291,3 @@ class TrainerBase:
             predict_kwargs_train=predict_kwargs_train,
             predict_kwargs_test=predict_kwargs_test,
             predict_fn=predict_fn)
-
{ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py
RENAMED

@@ -12,6 +12,7 @@ from .trainer_base import TrainerBase
 from ..models import GraphNeuralNetSklearn
 from ..utils import EPS
 from ins_pricing.utils import get_logger
+from ins_pricing.utils.torch_compat import torch_load

 _logger = get_logger("ins_pricing.trainer.gnn")

@@ -300,7 +301,7 @@ class GNNTrainer(TrainerBase):
         if not os.path.exists(path):
             print(f"[load] Warning: Model file not found: {path}")
             return
-        payload = …
+        payload = torch_load(path, map_location='cpu', weights_only=False)
         if not isinstance(payload, dict):
             raise ValueError(f"Invalid GNN checkpoint: {path}")
         params = payload.get("best_params") or {}

@@ -322,4 +323,3 @@ class GNNTrainer(TrainerBase):
         self.model = model
         self.best_params = dict(params) if isinstance(params, dict) else None
         self.ctx.gnn_best = self.model
-
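
Both trainers now route checkpoint loading through the new ins_pricing/utils/torch_compat.py (+45 lines, not expanded in this diff). Judging from the call sites, `torch_load` is presumably a thin shim over `torch.load` that forwards `weights_only` only where the running PyTorch accepts it: the argument appeared in PyTorch 1.13, and its default flipped to True in 2.6, which breaks loading full pickled checkpoints. A hedged sketch of what such a shim might look like; the actual 0.3.2 implementation is not shown here:

    import inspect
    import torch

    def torch_load(path, map_location=None, weights_only=False, **kwargs):
        """torch.load wrapper: pass weights_only only when supported, so
        full-object checkpoints keep loading on PyTorch 2.6+."""
        if "weights_only" in inspect.signature(torch.load).parameters:
            return torch.load(path, map_location=map_location,
                              weights_only=weights_only, **kwargs)
        return torch.load(path, map_location=map_location, **kwargs)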