ins-pricing 0.3.0__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/PKG-INFO +1 -1
  2. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/Explain_entry.py +50 -48
  3. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/bayesopt_entry_runner.py +73 -70
  4. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +4 -3
  5. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/models/model_gnn.py +114 -14
  6. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +6 -4
  7. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +2 -2
  8. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +7 -3
  9. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/production/predict.py +5 -4
  10. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/setup.py +1 -1
  11. ins_pricing-0.3.2/ins_pricing/utils/torch_compat.py +45 -0
  12. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing.egg-info/PKG-INFO +1 -1
  13. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing.egg-info/SOURCES.txt +1 -0
  14. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/pyproject.toml +1 -1
  15. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/MANIFEST.in +0 -0
  16. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/README.md +0 -0
  17. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/CHANGELOG.md +0 -0
  18. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/README.md +0 -0
  19. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/RELEASE_NOTES_0.2.8.md +0 -0
  20. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/__init__.py +0 -0
  21. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/BayesOpt_entry.py +0 -0
  22. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/BayesOpt_incremental.py +0 -0
  23. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/Explain_Run.py +0 -0
  24. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/Pricing_Run.py +0 -0
  25. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/__init__.py +0 -0
  26. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/__init__.py +0 -0
  27. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/cli_common.py +0 -0
  28. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/cli_config.py +0 -0
  29. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/evaluation_context.py +0 -0
  30. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/import_resolver.py +0 -0
  31. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/notebook_utils.py +0 -0
  32. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/utils/run_logging.py +0 -0
  33. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/cli/watchdog_run.py +0 -0
  34. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/docs/modelling/BayesOpt_USAGE.md +0 -0
  35. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/docs/modelling/README.md +0 -0
  36. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/exceptions.py +0 -0
  37. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/governance/README.md +0 -0
  38. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/governance/__init__.py +0 -0
  39. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/governance/approval.py +0 -0
  40. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/governance/audit.py +0 -0
  41. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/governance/registry.py +0 -0
  42. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/governance/release.py +0 -0
  43. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/__init__.py +0 -0
  44. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/BayesOpt.py +0 -0
  45. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/__init__.py +0 -0
  46. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +0 -0
  47. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +0 -0
  48. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +0 -0
  49. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/__init__.py +0 -0
  50. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/config_components.py +0 -0
  51. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/config_preprocess.py +0 -0
  52. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/core.py +0 -0
  53. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +0 -0
  54. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +0 -0
  55. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/models/__init__.py +0 -0
  56. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +0 -0
  57. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/models/model_resn.py +0 -0
  58. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -0
  59. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +0 -0
  60. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +0 -0
  61. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +0 -0
  62. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +0 -0
  63. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -0
  64. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -0
  65. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +0 -0
  66. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -0
  67. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -0
  68. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils.py +0 -0
  69. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -0
  70. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/core/evaluation.py +0 -0
  71. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/explain/__init__.py +0 -0
  72. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/explain/gradients.py +0 -0
  73. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/explain/metrics.py +0 -0
  74. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/explain/permutation.py +0 -0
  75. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/explain/shap_utils.py +0 -0
  76. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/plotting/__init__.py +0 -0
  77. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/plotting/common.py +0 -0
  78. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/plotting/curves.py +0 -0
  79. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/plotting/diagnostics.py +0 -0
  80. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/plotting/geo.py +0 -0
  81. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/modelling/plotting/importance.py +0 -0
  82. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/README.md +0 -0
  83. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/__init__.py +0 -0
  84. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/calibration.py +0 -0
  85. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/data_quality.py +0 -0
  86. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/exposure.py +0 -0
  87. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/factors.py +0 -0
  88. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/monitoring.py +0 -0
  89. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/pricing/rate_table.py +0 -0
  90. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/production/__init__.py +0 -0
  91. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/production/drift.py +0 -0
  92. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/production/monitoring.py +0 -0
  93. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/production/preprocess.py +0 -0
  94. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/production/scoring.py +0 -0
  95. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/reporting/README.md +0 -0
  96. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/reporting/__init__.py +0 -0
  97. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/reporting/report_builder.py +0 -0
  98. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/reporting/scheduler.py +0 -0
  99. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/governance/__init__.py +0 -0
  100. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/governance/test_audit.py +0 -0
  101. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/governance/test_registry.py +0 -0
  102. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/governance/test_release.py +0 -0
  103. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/conftest.py +0 -0
  104. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_cross_val_generic.py +0 -0
  105. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_distributed_utils.py +0 -0
  106. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_explain.py +0 -0
  107. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_geo_tokens_split.py +0 -0
  108. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_graph_cache.py +0 -0
  109. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_plotting.py +0 -0
  110. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_plotting_library.py +0 -0
  111. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/modelling/test_preprocessor.py +0 -0
  112. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/pricing/__init__.py +0 -0
  113. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/pricing/test_calibration.py +0 -0
  114. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/pricing/test_exposure.py +0 -0
  115. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/pricing/test_factors.py +0 -0
  116. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/pricing/test_rate_table.py +0 -0
  117. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/production/__init__.py +0 -0
  118. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/production/test_monitoring.py +0 -0
  119. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/production/test_predict.py +0 -0
  120. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/production/test_preprocess.py +0 -0
  121. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/tests/production/test_scoring.py +0 -0
  122. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/__init__.py +0 -0
  123. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/device.py +0 -0
  124. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/logging.py +0 -0
  125. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/metrics.py +0 -0
  126. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/paths.py +0 -0
  127. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/profiling.py +0 -0
  128. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing/utils/validation.py +0 -0
  129. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing.egg-info/dependency_links.txt +0 -0
  130. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing.egg-info/requires.txt +0 -0
  131. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/ins_pricing.egg-info/top_level.txt +0 -0
  132. {ins_pricing-0.3.0 → ins_pricing-0.3.2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ins_pricing
3
- Version: 0.3.0
3
+ Version: 0.3.2
4
4
  Summary: Reusable modelling, pricing, governance, and reporting utilities.
5
5
  Author: meishi125478
6
6
  License: Proprietary
@@ -491,54 +491,56 @@ def explain_from_config(args: argparse.Namespace) -> None:
491
491
  categorical_features = cfg.get("categorical_features")
492
492
  plot_path_style = runtime_cfg["plot_path_style"]
493
493
 
494
- model = ropt.BayesOptModel(
495
- train_df,
496
- test_df,
497
- model_name,
498
- cfg["target"],
499
- cfg["weight"],
500
- feature_list,
501
- task_type=str(cfg.get("task_type", "regression")),
502
- binary_resp_nme=binary_target,
503
- cate_list=categorical_features,
504
- prop_test=prop_test,
505
- rand_seed=rand_seed,
506
- epochs=int(runtime_cfg["epochs"]),
507
- use_gpu=bool(cfg.get("use_gpu", True)),
508
- output_dir=output_dir,
509
- xgb_max_depth_max=runtime_cfg["xgb_max_depth_max"],
510
- xgb_n_estimators_max=runtime_cfg["xgb_n_estimators_max"],
511
- resn_weight_decay=cfg.get("resn_weight_decay"),
512
- final_ensemble=bool(cfg.get("final_ensemble", False)),
513
- final_ensemble_k=int(cfg.get("final_ensemble_k", 3)),
514
- final_refit=bool(cfg.get("final_refit", True)),
515
- optuna_storage=runtime_cfg["optuna_storage"],
516
- optuna_study_prefix=runtime_cfg["optuna_study_prefix"],
517
- best_params_files=runtime_cfg["best_params_files"],
518
- gnn_use_approx_knn=cfg.get("gnn_use_approx_knn", True),
519
- gnn_approx_knn_threshold=cfg.get("gnn_approx_knn_threshold", 50000),
520
- gnn_graph_cache=cfg.get("gnn_graph_cache"),
521
- gnn_max_gpu_knn_nodes=cfg.get("gnn_max_gpu_knn_nodes", 200000),
522
- gnn_knn_gpu_mem_ratio=cfg.get("gnn_knn_gpu_mem_ratio", 0.9),
523
- gnn_knn_gpu_mem_overhead=cfg.get("gnn_knn_gpu_mem_overhead", 2.0),
524
- region_province_col=cfg.get("region_province_col"),
525
- region_city_col=cfg.get("region_city_col"),
526
- region_effect_alpha=cfg.get("region_effect_alpha"),
527
- geo_feature_nmes=cfg.get("geo_feature_nmes"),
528
- geo_token_hidden_dim=cfg.get("geo_token_hidden_dim"),
529
- geo_token_layers=cfg.get("geo_token_layers"),
530
- geo_token_dropout=cfg.get("geo_token_dropout"),
531
- geo_token_k_neighbors=cfg.get("geo_token_k_neighbors"),
532
- geo_token_learning_rate=cfg.get("geo_token_learning_rate"),
533
- geo_token_epochs=cfg.get("geo_token_epochs"),
534
- ft_role=str(cfg.get("ft_role", "model")),
535
- ft_feature_prefix=str(cfg.get("ft_feature_prefix", "ft_emb")),
536
- ft_num_numeric_tokens=cfg.get("ft_num_numeric_tokens"),
537
- infer_categorical_max_unique=int(cfg.get("infer_categorical_max_unique", 50)),
538
- infer_categorical_max_ratio=float(cfg.get("infer_categorical_max_ratio", 0.05)),
539
- reuse_best_params=runtime_cfg["reuse_best_params"],
540
- plot_path_style=plot_path_style,
541
- )
494
+ config_fields = getattr(ropt.BayesOptConfig, "__dataclass_fields__", {})
495
+ allowed_config_keys = set(config_fields.keys())
496
+ config_payload = {k: v for k, v in cfg.items() if k in allowed_config_keys}
497
+ config_payload.update({
498
+ "model_nme": model_name,
499
+ "resp_nme": cfg["target"],
500
+ "weight_nme": cfg["weight"],
501
+ "factor_nmes": feature_list,
502
+ "task_type": str(cfg.get("task_type", "regression")),
503
+ "binary_resp_nme": binary_target,
504
+ "cate_list": categorical_features,
505
+ "prop_test": prop_test,
506
+ "rand_seed": rand_seed,
507
+ "epochs": int(runtime_cfg["epochs"]),
508
+ "use_gpu": bool(cfg.get("use_gpu", True)),
509
+ "output_dir": output_dir,
510
+ "xgb_max_depth_max": runtime_cfg["xgb_max_depth_max"],
511
+ "xgb_n_estimators_max": runtime_cfg["xgb_n_estimators_max"],
512
+ "resn_weight_decay": cfg.get("resn_weight_decay"),
513
+ "final_ensemble": bool(cfg.get("final_ensemble", False)),
514
+ "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
515
+ "final_refit": bool(cfg.get("final_refit", True)),
516
+ "optuna_storage": runtime_cfg["optuna_storage"],
517
+ "optuna_study_prefix": runtime_cfg["optuna_study_prefix"],
518
+ "best_params_files": runtime_cfg["best_params_files"],
519
+ "gnn_use_approx_knn": cfg.get("gnn_use_approx_knn", True),
520
+ "gnn_approx_knn_threshold": cfg.get("gnn_approx_knn_threshold", 50000),
521
+ "gnn_graph_cache": cfg.get("gnn_graph_cache"),
522
+ "gnn_max_gpu_knn_nodes": cfg.get("gnn_max_gpu_knn_nodes", 200000),
523
+ "gnn_knn_gpu_mem_ratio": cfg.get("gnn_knn_gpu_mem_ratio", 0.9),
524
+ "gnn_knn_gpu_mem_overhead": cfg.get("gnn_knn_gpu_mem_overhead", 2.0),
525
+ "region_province_col": cfg.get("region_province_col"),
526
+ "region_city_col": cfg.get("region_city_col"),
527
+ "region_effect_alpha": cfg.get("region_effect_alpha"),
528
+ "geo_feature_nmes": cfg.get("geo_feature_nmes"),
529
+ "geo_token_hidden_dim": cfg.get("geo_token_hidden_dim"),
530
+ "geo_token_layers": cfg.get("geo_token_layers"),
531
+ "geo_token_dropout": cfg.get("geo_token_dropout"),
532
+ "geo_token_k_neighbors": cfg.get("geo_token_k_neighbors"),
533
+ "geo_token_learning_rate": cfg.get("geo_token_learning_rate"),
534
+ "geo_token_epochs": cfg.get("geo_token_epochs"),
535
+ "ft_role": str(cfg.get("ft_role", "model")),
536
+ "ft_feature_prefix": str(cfg.get("ft_feature_prefix", "ft_emb")),
537
+ "ft_num_numeric_tokens": cfg.get("ft_num_numeric_tokens"),
538
+ "reuse_best_params": runtime_cfg["reuse_best_params"],
539
+ "plot_path_style": plot_path_style or "nested",
540
+ })
541
+ config_payload = {k: v for k, v in config_payload.items() if v is not None}
542
+ config = ropt.BayesOptConfig(**config_payload)
543
+ model = ropt.BayesOptModel(train_df, test_df, config=config)
542
544
 
543
545
  output_overrides = resolve_explain_output_overrides(
544
546
  explain_cfg,
@@ -1223,76 +1223,79 @@ def train_from_config(args: argparse.Namespace) -> None:
1223
1223
  cfg.get("ft_feature_prefix", args.ft_feature_prefix))
1224
1224
  ft_num_numeric_tokens = cfg.get("ft_num_numeric_tokens")
1225
1225
 
1226
- model = ropt.BayesOptModel(
1227
- train_df,
1228
- test_df,
1229
- model_name,
1230
- cfg["target"],
1231
- cfg["weight"],
1232
- feature_list,
1233
- task_type=task_type,
1234
- binary_resp_nme=binary_target,
1235
- cate_list=categorical_features,
1236
- prop_test=val_ratio,
1237
- rand_seed=rand_seed,
1238
- epochs=epochs,
1239
- use_gpu=use_gpu,
1240
- use_resn_data_parallel=use_resn_dp,
1241
- use_ft_data_parallel=use_ft_dp,
1242
- use_resn_ddp=use_resn_ddp,
1243
- use_ft_ddp=use_ft_ddp,
1244
- use_gnn_data_parallel=use_gnn_dp,
1245
- use_gnn_ddp=use_gnn_ddp,
1246
- output_dir=output_dir,
1247
- xgb_max_depth_max=xgb_max_depth_max,
1248
- xgb_n_estimators_max=xgb_n_estimators_max,
1249
- resn_weight_decay=cfg.get("resn_weight_decay"),
1250
- final_ensemble=bool(cfg.get("final_ensemble", False)),
1251
- final_ensemble_k=int(cfg.get("final_ensemble_k", 3)),
1252
- final_refit=bool(cfg.get("final_refit", True)),
1253
- optuna_storage=optuna_storage,
1254
- optuna_study_prefix=optuna_study_prefix,
1255
- best_params_files=best_params_files,
1256
- gnn_use_approx_knn=gnn_use_ann,
1257
- gnn_approx_knn_threshold=gnn_threshold,
1258
- gnn_graph_cache=gnn_graph_cache,
1259
- gnn_max_gpu_knn_nodes=gnn_max_gpu_nodes,
1260
- gnn_knn_gpu_mem_ratio=gnn_gpu_mem_ratio,
1261
- gnn_knn_gpu_mem_overhead=gnn_gpu_mem_overhead,
1262
- region_province_col=region_province_col,
1263
- region_city_col=region_city_col,
1264
- region_effect_alpha=region_effect_alpha,
1265
- geo_feature_nmes=geo_feature_nmes,
1266
- geo_token_hidden_dim=geo_token_hidden_dim,
1267
- geo_token_layers=geo_token_layers,
1268
- geo_token_dropout=geo_token_dropout,
1269
- geo_token_k_neighbors=geo_token_k_neighbors,
1270
- geo_token_learning_rate=geo_token_learning_rate,
1271
- geo_token_epochs=geo_token_epochs,
1272
- ft_role=ft_role,
1273
- ft_feature_prefix=ft_feature_prefix,
1274
- ft_num_numeric_tokens=ft_num_numeric_tokens,
1275
- infer_categorical_max_unique=int(
1276
- cfg.get("infer_categorical_max_unique", 50)),
1277
- infer_categorical_max_ratio=float(
1278
- cfg.get("infer_categorical_max_ratio", 0.05)),
1279
- reuse_best_params=reuse_best_params,
1280
- bo_sample_limit=bo_sample_limit,
1281
- cache_predictions=cache_predictions,
1282
- prediction_cache_dir=prediction_cache_dir,
1283
- prediction_cache_format=prediction_cache_format,
1284
- cv_strategy=cv_strategy or split_strategy,
1285
- cv_group_col=cv_group_col or split_group_col,
1286
- cv_time_col=cv_time_col or split_time_col,
1287
- cv_time_ascending=cv_time_ascending,
1288
- cv_splits=cv_splits,
1289
- ft_oof_folds=ft_oof_folds,
1290
- ft_oof_strategy=ft_oof_strategy,
1291
- ft_oof_shuffle=ft_oof_shuffle,
1292
- save_preprocess=save_preprocess,
1293
- preprocess_artifact_path=preprocess_artifact_path,
1294
- plot_path_style=plot_path_style,
1295
- )
1226
+ config_fields = getattr(ropt.BayesOptConfig,
1227
+ "__dataclass_fields__", {})
1228
+ allowed_config_keys = set(config_fields.keys())
1229
+ config_payload = {k: v for k,
1230
+ v in cfg.items() if k in allowed_config_keys}
1231
+ config_payload.update({
1232
+ "model_nme": model_name,
1233
+ "resp_nme": cfg["target"],
1234
+ "weight_nme": cfg["weight"],
1235
+ "factor_nmes": feature_list,
1236
+ "task_type": task_type,
1237
+ "binary_resp_nme": binary_target,
1238
+ "cate_list": categorical_features,
1239
+ "prop_test": val_ratio,
1240
+ "rand_seed": rand_seed,
1241
+ "epochs": epochs,
1242
+ "use_gpu": use_gpu,
1243
+ "use_resn_data_parallel": use_resn_dp,
1244
+ "use_ft_data_parallel": use_ft_dp,
1245
+ "use_gnn_data_parallel": use_gnn_dp,
1246
+ "use_resn_ddp": use_resn_ddp,
1247
+ "use_ft_ddp": use_ft_ddp,
1248
+ "use_gnn_ddp": use_gnn_ddp,
1249
+ "output_dir": output_dir,
1250
+ "xgb_max_depth_max": xgb_max_depth_max,
1251
+ "xgb_n_estimators_max": xgb_n_estimators_max,
1252
+ "resn_weight_decay": cfg.get("resn_weight_decay"),
1253
+ "final_ensemble": bool(cfg.get("final_ensemble", False)),
1254
+ "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
1255
+ "final_refit": bool(cfg.get("final_refit", True)),
1256
+ "optuna_storage": optuna_storage,
1257
+ "optuna_study_prefix": optuna_study_prefix,
1258
+ "best_params_files": best_params_files,
1259
+ "gnn_use_approx_knn": gnn_use_ann,
1260
+ "gnn_approx_knn_threshold": gnn_threshold,
1261
+ "gnn_graph_cache": gnn_graph_cache,
1262
+ "gnn_max_gpu_knn_nodes": gnn_max_gpu_nodes,
1263
+ "gnn_knn_gpu_mem_ratio": gnn_gpu_mem_ratio,
1264
+ "gnn_knn_gpu_mem_overhead": gnn_gpu_mem_overhead,
1265
+ "region_province_col": region_province_col,
1266
+ "region_city_col": region_city_col,
1267
+ "region_effect_alpha": region_effect_alpha,
1268
+ "geo_feature_nmes": geo_feature_nmes,
1269
+ "geo_token_hidden_dim": geo_token_hidden_dim,
1270
+ "geo_token_layers": geo_token_layers,
1271
+ "geo_token_dropout": geo_token_dropout,
1272
+ "geo_token_k_neighbors": geo_token_k_neighbors,
1273
+ "geo_token_learning_rate": geo_token_learning_rate,
1274
+ "geo_token_epochs": geo_token_epochs,
1275
+ "ft_role": ft_role,
1276
+ "ft_feature_prefix": ft_feature_prefix,
1277
+ "ft_num_numeric_tokens": ft_num_numeric_tokens,
1278
+ "reuse_best_params": reuse_best_params,
1279
+ "bo_sample_limit": bo_sample_limit,
1280
+ "cache_predictions": cache_predictions,
1281
+ "prediction_cache_dir": prediction_cache_dir,
1282
+ "prediction_cache_format": prediction_cache_format,
1283
+ "cv_strategy": cv_strategy or split_strategy,
1284
+ "cv_group_col": cv_group_col or split_group_col,
1285
+ "cv_time_col": cv_time_col or split_time_col,
1286
+ "cv_time_ascending": cv_time_ascending,
1287
+ "cv_splits": cv_splits,
1288
+ "ft_oof_folds": ft_oof_folds,
1289
+ "ft_oof_strategy": ft_oof_strategy,
1290
+ "ft_oof_shuffle": ft_oof_shuffle,
1291
+ "save_preprocess": save_preprocess,
1292
+ "preprocess_artifact_path": preprocess_artifact_path,
1293
+ "plot_path_style": plot_path_style or "nested",
1294
+ })
1295
+ config_payload = {
1296
+ k: v for k, v in config_payload.items() if v is not None}
1297
+ config = ropt.BayesOptConfig(**config_payload)
1298
+ model = ropt.BayesOptModel(train_df, test_df, config=config)
1296
1299
 
1297
1300
  if plot_requested:
1298
1301
  plot_cfg = cfg.get("plot", {})
@@ -626,6 +626,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
626
626
  best_state = None
627
627
  patience_counter = 0
628
628
  is_ddp_model = isinstance(self.ft, DDP)
629
+ use_collectives = dist.is_initialized() and is_ddp_model
629
630
 
630
631
  clip_fn = None
631
632
  if self.device.type == 'cuda':
@@ -669,7 +670,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
669
670
  device=X_num_b.device)
670
671
  local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
671
672
  global_bad = local_bad
672
- if dist.is_initialized():
673
+ if use_collectives:
673
674
  bad = torch.tensor(
674
675
  [local_bad],
675
676
  device=batch_loss.device,
@@ -774,7 +775,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
774
775
  total_n += float(end - start)
775
776
  val_loss_tensor[0] = total_val / max(total_n, 1.0)
776
777
 
777
- if dist.is_initialized():
778
+ if use_collectives:
778
779
  dist.broadcast(val_loss_tensor, src=0)
779
780
  val_loss_value = float(val_loss_tensor.item())
780
781
  prune_now = False
@@ -806,7 +807,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
806
807
  if trial.should_prune():
807
808
  prune_now = True
808
809
 
809
- if dist.is_initialized():
810
+ if use_collectives:
810
811
  flag = torch.tensor(
811
812
  [1 if prune_now else 0],
812
813
  device=loss_tensor_device,
@@ -42,6 +42,12 @@ _GNN_MPS_WARNED = False
42
42
  # Simplified GNN implementation.
43
43
  # =============================================================================
44
44
 
45
+ def _adj_mm(adj: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
46
+ """Matrix multiply that supports sparse or dense adjacency."""
47
+ if adj.is_sparse:
48
+ return torch.sparse.mm(adj, x)
49
+ return adj.matmul(x)
50
+
45
51
 
46
52
  class SimpleGraphLayer(nn.Module):
47
53
  def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
@@ -52,7 +58,7 @@ class SimpleGraphLayer(nn.Module):
52
58
 
53
59
  def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
54
60
  # Message passing with normalized sparse adjacency: A_hat * X * W.
55
- h = torch.sparse.mm(adj, x)
61
+ h = _adj_mm(adj, x)
56
62
  h = self.linear(h)
57
63
  h = self.activation(h)
58
64
  return self.dropout(h)
@@ -86,7 +92,7 @@ class SimpleGNN(nn.Module):
86
92
  h = x
87
93
  for layer in self.layers:
88
94
  h = layer(h, adj_used)
89
- h = torch.sparse.mm(adj_used, h)
95
+ h = _adj_mm(adj_used, h)
90
96
  out = self.output(h)
91
97
  return self.output_act(out)
92
98
 
@@ -124,7 +130,11 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
124
130
  self.knn_gpu_mem_ratio = max(0.0, min(1.0, knn_gpu_mem_ratio))
125
131
  self.knn_gpu_mem_overhead = max(1.0, knn_gpu_mem_overhead)
126
132
  self.knn_cpu_jobs = knn_cpu_jobs
133
+ self.mps_dense_max_nodes = int(
134
+ os.environ.get("BAYESOPT_GNN_MPS_DENSE_MAX_NODES", "5000")
135
+ )
127
136
  self._knn_warning_emitted = False
137
+ self._mps_fallback_triggered = False
128
138
  self._adj_cache_meta: Optional[Dict[str, Any]] = None
129
139
  self._adj_cache_key: Optional[Tuple[Any, ...]] = None
130
140
  self._adj_cache_tensor: Optional[torch.Tensor] = None
@@ -168,11 +178,11 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
168
178
  else:
169
179
  self.device = torch.device('cuda')
170
180
  elif torch.backends.mps.is_available():
171
- self.device = torch.device('cpu')
181
+ self.device = torch.device('mps')
172
182
  global _GNN_MPS_WARNED
173
183
  if not _GNN_MPS_WARNED:
174
184
  print(
175
- "[GNN] MPS backend does not support sparse ops; falling back to CPU.",
185
+ "[GNN] Using MPS backend; will fall back to CPU on unsupported ops.",
176
186
  flush=True,
177
187
  )
178
188
  _GNN_MPS_WARNED = True
@@ -235,6 +245,41 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
235
245
  else:
236
246
  base.register_buffer("adj_buffer", adj)
237
247
 
248
+ @staticmethod
249
+ def _is_mps_unsupported_error(exc: BaseException) -> bool:
250
+ msg = str(exc).lower()
251
+ if "mps" not in msg:
252
+ return False
253
+ if any(token in msg for token in ("not supported", "not implemented", "does not support", "unimplemented", "out of memory")):
254
+ return True
255
+ return "sparse" in msg
256
+
257
+ def _fallback_to_cpu(self, reason: str) -> None:
258
+ if self.device.type != "mps" or self._mps_fallback_triggered:
259
+ return
260
+ self._mps_fallback_triggered = True
261
+ print(f"[GNN] MPS op unsupported ({reason}); falling back to CPU.", flush=True)
262
+ self.device = torch.device("cpu")
263
+ self.use_pyg_knn = False
264
+ self.data_parallel_enabled = False
265
+ self.ddp_enabled = False
266
+ base = self._unwrap_gnn()
267
+ try:
268
+ base = base.to(self.device)
269
+ except Exception:
270
+ pass
271
+ self.gnn = base
272
+ self.invalidate_graph_cache()
273
+
274
+ def _run_with_mps_fallback(self, fn, *args, **kwargs):
275
+ try:
276
+ return fn(*args, **kwargs)
277
+ except (RuntimeError, NotImplementedError) as exc:
278
+ if self.device.type == "mps" and self._is_mps_unsupported_error(exc):
279
+ self._fallback_to_cpu(str(exc))
280
+ return fn(*args, **kwargs)
281
+ raise
282
+
238
283
  def _graph_cache_meta(self, X_df: pd.DataFrame) -> Dict[str, Any]:
239
284
  row_hash = pd.util.hash_pandas_object(X_df, index=False).values
240
285
  idx_hash = pd.util.hash_pandas_object(X_df.index, index=False).values
@@ -255,11 +300,14 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
255
300
  "knn_gpu_mem_ratio": float(self.knn_gpu_mem_ratio),
256
301
  "knn_gpu_mem_overhead": float(self.knn_gpu_mem_overhead),
257
302
  }
303
+ adj_format = "dense" if self.device.type == "mps" else "sparse"
258
304
  return {
259
305
  "n_samples": int(X_df.shape[0]),
260
306
  "n_features": int(X_df.shape[1]),
261
307
  "hash": hasher.hexdigest(),
262
308
  "knn_config": knn_config,
309
+ "adj_format": adj_format,
310
+ "device_type": self.device.type,
263
311
  }
264
312
 
265
313
  def _graph_cache_key(self, X_df: pd.DataFrame) -> Tuple[Any, ...]:
@@ -284,8 +332,7 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
284
332
  if meta_expected is None:
285
333
  meta_expected = self._graph_cache_meta(X_df)
286
334
  try:
287
- payload = torch.load(self.graph_cache_path,
288
- map_location=self.device)
335
+ payload = torch.load(self.graph_cache_path, map_location="cpu")
289
336
  except Exception as exc:
290
337
  print(
291
338
  f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
@@ -293,7 +340,13 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
293
340
  if isinstance(payload, dict) and "adj" in payload:
294
341
  meta_cached = payload.get("meta")
295
342
  if meta_cached == meta_expected:
296
- return payload["adj"].to(self.device)
343
+ adj = payload["adj"]
344
+ if self.device.type == "mps" and getattr(adj, "is_sparse", False):
345
+ print(
346
+ f"[GNN] Cached sparse graph incompatible with MPS; rebuilding: {self.graph_cache_path}"
347
+ )
348
+ return None
349
+ return adj.to(self.device)
297
350
  print(
298
351
  f"[GNN] Cached graph metadata mismatch; rebuilding: {self.graph_cache_path}")
299
352
  return None
@@ -408,6 +461,11 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
408
461
  return True
409
462
 
410
463
  def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
464
+ if self.device.type == "mps":
465
+ return self._normalized_adj_dense(edge_index, num_nodes)
466
+ return self._normalized_adj_sparse(edge_index, num_nodes)
467
+
468
+ def _normalized_adj_sparse(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
411
469
  values = torch.ones(edge_index.shape[1], device=self.device)
412
470
  adj = torch.sparse_coo_tensor(
413
471
  edge_index.to(self.device), values, (num_nodes, num_nodes))
@@ -421,6 +479,21 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
421
479
  adj.indices(), norm_values, size=adj.shape)
422
480
  return adj_norm
423
481
 
482
+ def _normalized_adj_dense(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
483
+ if self.mps_dense_max_nodes <= 0 or num_nodes > self.mps_dense_max_nodes:
484
+ raise RuntimeError(
485
+ f"MPS dense adjacency not supported for {num_nodes} nodes; "
486
+ f"max={self.mps_dense_max_nodes}. Falling back to CPU."
487
+ )
488
+ edge_index = edge_index.to(self.device)
489
+ adj = torch.zeros((num_nodes, num_nodes), device=self.device, dtype=torch.float32)
490
+ adj[edge_index[0], edge_index[1]] = 1.0
491
+ deg = adj.sum(dim=1)
492
+ deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
493
+ adj = adj * deg_inv_sqrt.view(-1, 1)
494
+ adj = adj * deg_inv_sqrt.view(1, -1)
495
+ return adj
496
+
424
497
  def _tensorize_split(self, X, y, w, allow_none: bool = False):
425
498
  if X is None and allow_none:
426
499
  return None, None, None
@@ -462,17 +535,25 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
462
535
  if self._adj_cache_meta == meta_expected and self._adj_cache_tensor is not None:
463
536
  cached = self._adj_cache_tensor
464
537
  if cached.device != self.device:
465
- cached = cached.to(self.device)
466
- self._adj_cache_tensor = cached
467
- return cached
538
+ if self.device.type == "mps" and getattr(cached, "is_sparse", False):
539
+ self._adj_cache_tensor = None
540
+ else:
541
+ cached = cached.to(self.device)
542
+ self._adj_cache_tensor = cached
543
+ if self._adj_cache_tensor is not None:
544
+ return self._adj_cache_tensor
468
545
  else:
469
546
  cache_key = self._graph_cache_key(X_df)
470
547
  if self._adj_cache_key == cache_key and self._adj_cache_tensor is not None:
471
548
  cached = self._adj_cache_tensor
472
549
  if cached.device != self.device:
473
- cached = cached.to(self.device)
474
- self._adj_cache_tensor = cached
475
- return cached
550
+ if self.device.type == "mps" and getattr(cached, "is_sparse", False):
551
+ self._adj_cache_tensor = None
552
+ else:
553
+ cached = cached.to(self.device)
554
+ self._adj_cache_tensor = cached
555
+ if self._adj_cache_tensor is not None:
556
+ return self._adj_cache_tensor
476
557
  X_np = None
477
558
  if X_tensor is None:
478
559
  X_np = X_df.to_numpy(dtype=np.float32, copy=False)
@@ -511,7 +592,20 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
511
592
  def fit(self, X_train, y_train, w_train=None,
512
593
  X_val=None, y_val=None, w_val=None,
513
594
  trial: Optional[optuna.trial.Trial] = None):
595
+ return self._run_with_mps_fallback(
596
+ self._fit_impl,
597
+ X_train,
598
+ y_train,
599
+ w_train,
600
+ X_val,
601
+ y_val,
602
+ w_val,
603
+ trial,
604
+ )
514
605
 
606
+ def _fit_impl(self, X_train, y_train, w_train=None,
607
+ X_val=None, y_val=None, w_val=None,
608
+ trial: Optional[optuna.trial.Trial] = None):
515
609
  X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
516
610
  X_train, y_train, w_train, allow_none=False)
517
611
  has_val = X_val is not None and y_val is not None
@@ -621,6 +715,9 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
621
715
  self.best_epoch = int(best_epoch or self.epochs)
622
716
 
623
717
  def predict(self, X: pd.DataFrame) -> np.ndarray:
718
+ return self._run_with_mps_fallback(self._predict_impl, X)
719
+
720
+ def _predict_impl(self, X: pd.DataFrame) -> np.ndarray:
624
721
  self.gnn.eval()
625
722
  X_tensor, _, _ = self._tensorize_split(
626
723
  X, None, None, allow_none=False)
@@ -640,6 +737,9 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
640
737
  return y_pred.ravel()
641
738
 
642
739
  def encode(self, X: pd.DataFrame) -> np.ndarray:
740
+ return self._run_with_mps_fallback(self._encode_impl, X)
741
+
742
+ def _encode_impl(self, X: pd.DataFrame) -> np.ndarray:
643
743
  """Return per-sample node embeddings (hidden representations)."""
644
744
  base = self._unwrap_gnn()
645
745
  base.eval()
@@ -655,7 +755,7 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
655
755
  raise RuntimeError("GNN base module does not expose layers.")
656
756
  for layer in layers:
657
757
  h = layer(h, adj)
658
- h = torch.sparse.mm(adj, h)
758
+ h = _adj_mm(adj, h)
659
759
  return h.detach().cpu().numpy()
660
760
 
661
761
  def set_params(self, params: Dict[str, Any]):
@@ -27,6 +27,7 @@ from sklearn.preprocessing import StandardScaler
27
27
  from ..config_preprocess import BayesOptConfig, OutputManager
28
28
  from ..utils import DistributedUtils, EPS, ensure_parent_dir
29
29
  from ins_pricing.utils import get_logger, GPUMemoryManager, DeviceManager
30
+ from ins_pricing.utils.torch_compat import torch_load
30
31
 
31
32
  # Module-level logger
32
33
  _logger = get_logger("ins_pricing.trainer")
@@ -616,7 +617,7 @@ class TrainerBase:
616
617
  pass
617
618
  else:
618
619
  # FT-Transformer: load state_dict and reconstruct model
619
- loaded = torch.load(path, map_location='cpu', weights_only=False)
620
+ loaded = torch_load(path, map_location='cpu', weights_only=False)
620
621
  if isinstance(loaded, dict):
621
622
  if "state_dict" in loaded and "model_config" in loaded:
622
623
  # New format: state_dict + model_config
@@ -1094,7 +1095,7 @@ class TrainerBase:
1094
1095
  split_iter = splitter
1095
1096
 
1096
1097
  losses: List[float] = []
1097
- for train_idx, val_idx in split_iter:
1098
+ for fold_idx, (train_idx, val_idx) in enumerate(split_iter):
1098
1099
  X_train = X_all.iloc[train_idx]
1099
1100
  y_train = y_all.iloc[train_idx]
1100
1101
  X_val = X_all.iloc[val_idx]
@@ -1108,9 +1109,11 @@ class TrainerBase:
1108
1109
  model = model_builder(params)
1109
1110
  try:
1110
1111
  if fit_predict_fn:
1112
+ # Avoid duplicate Optuna step reports across folds.
1113
+ trial_for_fold = trial if fold_idx == 0 else None
1111
1114
  y_pred = fit_predict_fn(
1112
1115
  model, X_train, y_train, w_train,
1113
- X_val, y_val, w_val, trial
1116
+ X_val, y_val, w_val, trial_for_fold
1114
1117
  )
1115
1118
  else:
1116
1119
  fit_kwargs = {}
@@ -1288,4 +1291,3 @@ class TrainerBase:
1288
1291
  predict_kwargs_train=predict_kwargs_train,
1289
1292
  predict_kwargs_test=predict_kwargs_test,
1290
1293
  predict_fn=predict_fn)
1291
-
@@ -12,6 +12,7 @@ from .trainer_base import TrainerBase
12
12
  from ..models import GraphNeuralNetSklearn
13
13
  from ..utils import EPS
14
14
  from ins_pricing.utils import get_logger
15
+ from ins_pricing.utils.torch_compat import torch_load
15
16
 
16
17
  _logger = get_logger("ins_pricing.trainer.gnn")
17
18
 
@@ -300,7 +301,7 @@ class GNNTrainer(TrainerBase):
300
301
  if not os.path.exists(path):
301
302
  print(f"[load] Warning: Model file not found: {path}")
302
303
  return
303
- payload = torch.load(path, map_location='cpu', weights_only=False)
304
+ payload = torch_load(path, map_location='cpu', weights_only=False)
304
305
  if not isinstance(payload, dict):
305
306
  raise ValueError(f"Invalid GNN checkpoint: {path}")
306
307
  params = payload.get("best_params") or {}
@@ -322,4 +323,3 @@ class GNNTrainer(TrainerBase):
322
323
  self.model = model
323
324
  self.best_params = dict(params) if isinstance(params, dict) else None
324
325
  self.ctx.gnn_best = self.model
325
-