ennbo 0.1.2__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- enn/__init__.py +25 -13
- enn/benchmarks/__init__.py +3 -0
- enn/benchmarks/ackley.py +5 -0
- enn/benchmarks/ackley_class.py +17 -0
- enn/benchmarks/ackley_core.py +12 -0
- enn/benchmarks/double_ackley.py +24 -0
- enn/enn/candidates.py +14 -0
- enn/enn/conditional_posterior_draw_internals.py +15 -0
- enn/enn/draw_internals.py +15 -0
- enn/enn/enn.py +16 -269
- enn/enn/enn_class.py +423 -0
- enn/enn/enn_conditional.py +325 -0
- enn/enn/enn_fit.py +69 -70
- enn/enn/enn_hash.py +79 -0
- enn/enn/enn_index.py +92 -0
- enn/enn/enn_like_protocol.py +35 -0
- enn/enn/enn_normal.py +0 -1
- enn/enn/enn_params.py +3 -22
- enn/enn/enn_params_class.py +24 -0
- enn/enn/enn_util.py +60 -46
- enn/enn/neighbor_data.py +14 -0
- enn/enn/neighbors.py +14 -0
- enn/enn/posterior_flags.py +8 -0
- enn/enn/weighted_stats.py +14 -0
- enn/turbo/components/__init__.py +41 -0
- enn/turbo/components/acquisition.py +13 -0
- enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
- enn/turbo/components/builder.py +22 -0
- enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
- enn/turbo/components/enn_surrogate.py +115 -0
- enn/turbo/components/gp_surrogate.py +144 -0
- enn/turbo/components/hnr_acq_optimizer.py +83 -0
- enn/turbo/components/incumbent_selector.py +11 -0
- enn/turbo/components/incumbent_selector_protocol.py +16 -0
- enn/turbo/components/no_incumbent_selector.py +21 -0
- enn/turbo/components/no_surrogate.py +49 -0
- enn/turbo/components/pareto_acq_optimizer.py +49 -0
- enn/turbo/components/posterior_result.py +12 -0
- enn/turbo/components/protocols.py +13 -0
- enn/turbo/components/random_acq_optimizer.py +21 -0
- enn/turbo/components/scalar_incumbent_selector.py +39 -0
- enn/turbo/components/surrogate_protocol.py +32 -0
- enn/turbo/components/surrogate_result.py +12 -0
- enn/turbo/components/surrogates.py +5 -0
- enn/turbo/components/thompson_acq_optimizer.py +49 -0
- enn/turbo/components/trust_region_protocol.py +24 -0
- enn/turbo/components/ucb_acq_optimizer.py +49 -0
- enn/turbo/config/__init__.py +87 -0
- enn/turbo/config/acq_type.py +8 -0
- enn/turbo/config/acquisition.py +26 -0
- enn/turbo/config/base.py +4 -0
- enn/turbo/config/candidate_gen_config.py +49 -0
- enn/turbo/config/candidate_rv.py +7 -0
- enn/turbo/config/draw_acquisition_config.py +14 -0
- enn/turbo/config/enn_index_driver.py +6 -0
- enn/turbo/config/enn_surrogate_config.py +42 -0
- enn/turbo/config/enums.py +7 -0
- enn/turbo/config/factory.py +118 -0
- enn/turbo/config/gp_surrogate_config.py +14 -0
- enn/turbo/config/hnr_optimizer_config.py +7 -0
- enn/turbo/config/init_config.py +17 -0
- enn/turbo/config/init_strategies/__init__.py +9 -0
- enn/turbo/config/init_strategies/hybrid_init.py +23 -0
- enn/turbo/config/init_strategies/init_strategy.py +19 -0
- enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
- enn/turbo/config/morbo_tr_config.py +82 -0
- enn/turbo/config/nds_optimizer_config.py +7 -0
- enn/turbo/config/no_surrogate_config.py +14 -0
- enn/turbo/config/no_tr_config.py +31 -0
- enn/turbo/config/optimizer_config.py +72 -0
- enn/turbo/config/pareto_acquisition_config.py +14 -0
- enn/turbo/config/raasp_driver.py +6 -0
- enn/turbo/config/raasp_optimizer_config.py +7 -0
- enn/turbo/config/random_acquisition_config.py +14 -0
- enn/turbo/config/rescalarize.py +7 -0
- enn/turbo/config/surrogate.py +12 -0
- enn/turbo/config/trust_region.py +34 -0
- enn/turbo/config/turbo_tr_config.py +71 -0
- enn/turbo/config/ucb_acquisition_config.py +14 -0
- enn/turbo/config/validation.py +45 -0
- enn/turbo/hypervolume.py +30 -0
- enn/turbo/impl_helpers.py +68 -0
- enn/turbo/morbo_trust_region.py +131 -70
- enn/turbo/no_trust_region.py +32 -39
- enn/turbo/optimizer.py +300 -0
- enn/turbo/optimizer_config.py +8 -0
- enn/turbo/proposal.py +36 -38
- enn/turbo/sampling.py +21 -0
- enn/turbo/strategies/__init__.py +9 -0
- enn/turbo/strategies/lhd_only_strategy.py +36 -0
- enn/turbo/strategies/optimization_strategy.py +19 -0
- enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
- enn/turbo/tr_helpers.py +202 -0
- enn/turbo/turbo_gp.py +0 -1
- enn/turbo/turbo_gp_base.py +0 -1
- enn/turbo/turbo_gp_fit.py +187 -0
- enn/turbo/turbo_gp_noisy.py +0 -1
- enn/turbo/turbo_optimizer_utils.py +98 -0
- enn/turbo/turbo_trust_region.py +126 -58
- enn/turbo/turbo_utils.py +98 -161
- enn/turbo/types/__init__.py +7 -0
- enn/turbo/types/appendable_array.py +85 -0
- enn/turbo/types/gp_data_prep.py +13 -0
- enn/turbo/types/gp_fit_result.py +11 -0
- enn/turbo/types/obs_lists.py +10 -0
- enn/turbo/types/prepare_ask_result.py +14 -0
- enn/turbo/types/tell_inputs.py +14 -0
- {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/METADATA +18 -11
- ennbo-0.1.7.dist-info/RECORD +111 -0
- enn/enn/__init__.py +0 -4
- enn/turbo/__init__.py +0 -11
- enn/turbo/base_turbo_impl.py +0 -144
- enn/turbo/lhd_only_impl.py +0 -49
- enn/turbo/turbo_config.py +0 -72
- enn/turbo/turbo_enn_impl.py +0 -201
- enn/turbo/turbo_mode.py +0 -10
- enn/turbo/turbo_mode_impl.py +0 -76
- enn/turbo/turbo_one_impl.py +0 -302
- enn/turbo/turbo_optimizer.py +0 -525
- enn/turbo/turbo_zero_impl.py +0 -29
- ennbo-0.1.2.dist-info/RECORD +0 -29
- {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
- {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
enn/__init__.py,sha256=Iv1cXM6xUCjGDN2x7qj-A9DFy92-6iXQOJke-7_DBbQ,1465
|
|
2
|
+
enn/benchmarks/__init__.py,sha256=50yqq29luUZQmH618jLq006MwimJBYCvlR9PMSs97MM,79
|
|
3
|
+
enn/benchmarks/ackley.py,sha256=LMx_8C4swL3oKbzm70S8wHe_fp7Z0JCbyoi4kx2E9YY,163
|
|
4
|
+
enn/benchmarks/ackley_class.py,sha256=AwLwopw7bUQ0QaN1yFmVdNPZdL4948lqGfKIvyXk7Os,519
|
|
5
|
+
enn/benchmarks/ackley_core.py,sha256=zIGlZuIkzfv2OLLU_R7hSSYnm0TCm7G7Nt7KlDXuu5k,328
|
|
6
|
+
enn/benchmarks/double_ackley.py,sha256=IXZntZ1byJxIlFh3aj3Rzw_2cnjih8c9nyKTt_iM0P8,766
|
|
7
|
+
enn/enn/candidates.py,sha256=axVT0bs0dt4ZEXmy_d540U1h0JJ-YIAqYFaUPaiXwa0,276
|
|
8
|
+
enn/enn/conditional_posterior_draw_internals.py,sha256=Rr9sntjci85ke-LEoZPbxq4le7yKK1qKsOnrSfwA7Jw,317
|
|
9
|
+
enn/enn/draw_internals.py,sha256=o0MEpVV03x8WNdKmOq4yDTerDmxaDc7jeo17b1XxvAo,297
|
|
10
|
+
enn/enn/enn.py,sha256=JNrvi6wDLvYDEVaQ5d94M9uGgBdZWgT49rausgIABOA,358
|
|
11
|
+
enn/enn/enn_class.py,sha256=BAzPZMpaink-ddfRD0N_vZTM1IyVJL7Y6cNe_0Xs_ps,15131
|
|
12
|
+
enn/enn/enn_conditional.py,sha256=rDR3NJq0JHnKUqRu7sma7XR8g1PPeUAJUlbt-iI63YI,10535
|
|
13
|
+
enn/enn/enn_fit.py,sha256=hNAG1FHMbZNBeq0T4CIXp-1aiGKRbupL2nnnlQ4fb0A,4573
|
|
14
|
+
enn/enn/enn_hash.py,sha256=yk8N8EAf-IXsXS_NkoyTtlW0SyxJqzdl6Y8dinB2JYo,3151
|
|
15
|
+
enn/enn/enn_index.py,sha256=nWuWhuA-MqyPZmTAAwJLB4sdv1n7jMmRBuRTS5LzZjI,3060
|
|
16
|
+
enn/enn/enn_like_protocol.py,sha256=bzRvvE-ggqB0CqpdshZb2JJPzyV-BKNYZOjalX54ufQ,925
|
|
17
|
+
enn/enn/enn_normal.py,sha256=MOPsoC3KUCNQ_UMd_Ynkpyg9KSWNdRibaqIyAXdj4yo,661
|
|
18
|
+
enn/enn/enn_params.py,sha256=9NfT2mX3xbXBfPrdKNgVwsu6zWLBNXO0h1IF9NDpfQI,127
|
|
19
|
+
enn/enn/enn_params_class.py,sha256=hfzU2ftyRXMTDTW3euK-v-2ir--fjRXdoh6OYN1c1bo,854
|
|
20
|
+
enn/enn/enn_util.py,sha256=YWJjtbIPpEBN4wJ5EaTrvJB686BKkDxTykewhslzJe8,4543
|
|
21
|
+
enn/enn/neighbor_data.py,sha256=K8Djik2kCfsHSlG8RSvpcudRVmm6l78QXUj1V2_pwMc,272
|
|
22
|
+
enn/enn/neighbors.py,sha256=PUEqtgwUYk7mPD8Gm6RtW3yfDp9J6DNlnuMzYz-XHHI,275
|
|
23
|
+
enn/enn/posterior_flags.py,sha256=La3ZdSy9ZoXtaR4P2D2MOQdaZJ1Vfaz5xK_Jy_-2wWE,187
|
|
24
|
+
enn/enn/weighted_stats.py,sha256=mrghm8hrfrxrQ-jTSrakUjJ5iIdeD7Ilx2H31KaTU50,277
|
|
25
|
+
enn/turbo/hypervolume.py,sha256=lMTA9LlBXNo9kt83POfI4DfcAuvzYD3MgcN8AjIFVQM,910
|
|
26
|
+
enn/turbo/impl_helpers.py,sha256=hlfNBqWMzCJi6PtP8hJ8jlZNv68HIIBSVfsJXQQYgJM,1695
|
|
27
|
+
enn/turbo/morbo_trust_region.py,sha256=AUDeGsnljIYhCflPtPlQTe3TuEIHvWeNp_i07M4mCzQ,8690
|
|
28
|
+
enn/turbo/no_trust_region.py,sha256=3OKAzYrn_F8jXrZyE9LiYHIzJqEl6byjXeuzszwzm5s,1679
|
|
29
|
+
enn/turbo/optimizer.py,sha256=wCTV7a2nwejSgQTw4D2eKWcqTqgZMz_UnxmwshZ2HoU,11146
|
|
30
|
+
enn/turbo/optimizer_config.py,sha256=nIEu-e2rRk0XbEbiC8WkOsTk36go0ocdKsAOjAWD3mI,174
|
|
31
|
+
enn/turbo/proposal.py,sha256=00SHsdgFXi8_d734kpXadOBC6Hz0HO9Pqo4tWuqoKlw,4430
|
|
32
|
+
enn/turbo/sampling.py,sha256=y1ivaEwOYPTo8LVIyMUJ4J5mb6XGXn1EuhSwTYsqvGU,662
|
|
33
|
+
enn/turbo/tr_helpers.py,sha256=KsQBFq9uykeBPuPVpBAhEyZcO1HotbWFkKmKca2zCk0,5417
|
|
34
|
+
enn/turbo/turbo_gp.py,sha256=PY_R4YNPtH7-d2MR2M2eoRnlTgSAkT3Bv1hLGzs7IcQ,1068
|
|
35
|
+
enn/turbo/turbo_gp_base.py,sha256=EtFp6yUber1EIOsKjdNPBjjG6gg3jy-EZHgN4GAseDA,637
|
|
36
|
+
enn/turbo/turbo_gp_fit.py,sha256=JGcDP36DL1-lj93INHjyWCc0EqnhJh1psLPNLiMMy_w,5916
|
|
37
|
+
enn/turbo/turbo_gp_noisy.py,sha256=SCcZPRM5G8kZHdUjHE6whoobOVXOwSZRB6YwlwWFBCg,1079
|
|
38
|
+
enn/turbo/turbo_optimizer_utils.py,sha256=Pl_i5aN3pAT9l4u_0HabJyHo1wBIWvcE3qaSqeAz4QU,3349
|
|
39
|
+
enn/turbo/turbo_trust_region.py,sha256=sNMXYsPbFRfmYxHRUWz8604G3T_mKtJN0WYFTiw_DPg,6608
|
|
40
|
+
enn/turbo/turbo_utils.py,sha256=DPUoaCSAHU2WYSLGh29rREIgYpj2rEir95UBx0w-yVI,7788
|
|
41
|
+
enn/turbo/components/__init__.py,sha256=xDYcJhyIbnaTC2nw84Zt0Y_kKNU7Z1k1DFdQXhrgC68,906
|
|
42
|
+
enn/turbo/components/acquisition.py,sha256=IyBWLG2QyLPduW1wW4LA37iW3XSCp0Z_m259yeMjMnk,398
|
|
43
|
+
enn/turbo/components/acquisition_optimizer_protocol.py,sha256=_PC4ngQbacB8otvPZroCsx5wWZK1PfRjo_bl6aqWPLI,454
|
|
44
|
+
enn/turbo/components/builder.py,sha256=BrhhhkP1k_XjVJt7pnlhHVsnA1En77RBu-vqvffJNvY,822
|
|
45
|
+
enn/turbo/components/chebyshev_incumbent_selector.py,sha256=nJNPAMBIafkL09BzeFPecA5LrUvi13xjJnzGmd71qdo,2453
|
|
46
|
+
enn/turbo/components/enn_surrogate.py,sha256=9_UgTJ3Lk3hsbcdsbUXegRTDQtQa41cFca_ttRFEy7k,4598
|
|
47
|
+
enn/turbo/components/gp_surrogate.py,sha256=yvLknU_NE7S4qrPcDBniB9UYtebvS9_ak5VlUqfzU9A,6012
|
|
48
|
+
enn/turbo/components/hnr_acq_optimizer.py,sha256=wfmJjA9D98HjOyybvlvERru_pVvWTD1QV1REyI2CeBU,2836
|
|
49
|
+
enn/turbo/components/incumbent_selector.py,sha256=mPwtqGdG22b_ZiNinpV4dxKm3Vkvcl4MeYwo9vAVRlg,378
|
|
50
|
+
enn/turbo/components/incumbent_selector_protocol.py,sha256=5LVVsLvCmOeJ9vTIxQ3g59m-15fEZvkL22yzg8abfHM,379
|
|
51
|
+
enn/turbo/components/no_incumbent_selector.py,sha256=toLi-DxPRIlEqCtCOFkq5na626NZRkFnH3tb81e6FXo,429
|
|
52
|
+
enn/turbo/components/no_surrogate.py,sha256=7I6A4ceMhOU5y50blPTM3KgUtzFZyJsxKUVOr9bifmk,1777
|
|
53
|
+
enn/turbo/components/pareto_acq_optimizer.py,sha256=cPXjAtYf3P_xOCYF9ga8FIcLSEoTJ6_VUQjAOJ9TSIs,1853
|
|
54
|
+
enn/turbo/components/posterior_result.py,sha256=LM3ZeE1eT-8gyOhkpa3PF0THXRKH0URTQO0ija1eMbc,235
|
|
55
|
+
enn/turbo/components/protocols.py,sha256=Rwpln-22_9TS1PB5XFYG-B_wqUOOgfaW54tdZbW6nv4,371
|
|
56
|
+
enn/turbo/components/random_acq_optimizer.py,sha256=TEMclz-zQ0ncccmxj58dfxW1kwEwUFfEJG2DzV1gkII,514
|
|
57
|
+
enn/turbo/components/scalar_incumbent_selector.py,sha256=YXdpH9eLwZlYBOyroAyiuQPvJfj9G4jn_wKGEk0PkVk,1092
|
|
58
|
+
enn/turbo/components/surrogate_protocol.py,sha256=kXoeyrdX005yUS-ObaVzx_PeyCV6gLHsnOC36_tHMBY,985
|
|
59
|
+
enn/turbo/components/surrogate_result.py,sha256=ax_y4qZ2a99XdwAbLqHjYAY0ck607oDUFonOtGwgLGU,243
|
|
60
|
+
enn/turbo/components/surrogates.py,sha256=RxEl4VeIaFlLEwas3S3XZ3scxIjuS70Fa0Q8cGrm6LU,174
|
|
61
|
+
enn/turbo/components/thompson_acq_optimizer.py,sha256=Eb_3KPrZRyMBTfgO9OmUQnkXNxcsBZdcb6ecmbpNm4Q,1847
|
|
62
|
+
enn/turbo/components/trust_region_protocol.py,sha256=X8RCN8FtrWpua9VRGHZtSz_wNMxqBdqopULZyopLdPk,879
|
|
63
|
+
enn/turbo/components/ucb_acq_optimizer.py,sha256=RO_9HGCVb6xAGDL8iFJqF4ZkapPbG7tWCvGhLh2C7es,1797
|
|
64
|
+
enn/turbo/config/__init__.py,sha256=dH8GVPL0MJuCdyA_Mo1v__tlLtjd1G2Se3yM8UGIP1Q,1924
|
|
65
|
+
enn/turbo/config/acq_type.py,sha256=P89H41Xt26pCfCt1ODByRDv-TU-Zb5XKlgLcFHb-cpo,144
|
|
66
|
+
enn/turbo/config/acquisition.py,sha256=cPb6PrF6cgcr4c57ylPAspGb-ScDZflyKamnohc_R2c,891
|
|
67
|
+
enn/turbo/config/base.py,sha256=LhX0siqu7YbJ9BxDqC0MujupYwM-B6RLPg2NjdF7Owc,137
|
|
68
|
+
enn/turbo/config/candidate_gen_config.py,sha256=g5OzuMLfDoU-htcjdI4N5V5fnhXdQa_drNF5Ee7ws3o,1507
|
|
69
|
+
enn/turbo/config/candidate_rv.py,sha256=dQumGSHF3uE4HTC0Ix3yzmmsAaLsSs_m8NB93Kpji6g,128
|
|
70
|
+
enn/turbo/config/draw_acquisition_config.py,sha256=nIg6jlWjgjoNyHMKBI5i3zMxu4-WQSitEQGy3h6cUQQ,386
|
|
71
|
+
enn/turbo/config/enn_index_driver.py,sha256=qDrNXatK4nMvUi8h_rkg9Uv_M9WfLdlsTHTP4X3COjE,94
|
|
72
|
+
enn/turbo/config/enn_surrogate_config.py,sha256=dsvDkgPVm5qv8zhE3AvUN9yX37cyGrno0E0GNwodXoY,1269
|
|
73
|
+
enn/turbo/config/enums.py,sha256=nlxqrB3sivvs2_1fMTLGCg0MfoSQ0ZUG8kREM3raSfY,274
|
|
74
|
+
enn/turbo/config/factory.py,sha256=eZNw6Zao7zrKeQHaeLRVDP2tdqsVguSZFoBiW31UkUI,4458
|
|
75
|
+
enn/turbo/config/gp_surrogate_config.py,sha256=tXSF4XPrvNfmZvEHJavGK5AaJ7ZyWWjV105qe4y_DZM,341
|
|
76
|
+
enn/turbo/config/hnr_optimizer_config.py,sha256=oo1Cn0jcaTZXz8L0730KvGVD6m6xiQxmINMVWg6Fj68,130
|
|
77
|
+
enn/turbo/config/init_config.py,sha256=AZhu75v_O-8I3Li6ptgb-uaNv2aP6xX4BF0v58qVS24,610
|
|
78
|
+
enn/turbo/config/morbo_tr_config.py,sha256=RPUEm5wooZja38QwTvMLkA1jRf9cvyBOkMQ-WOVsX8k,2021
|
|
79
|
+
enn/turbo/config/nds_optimizer_config.py,sha256=wuNFqlVVlDLU0tU2XdogLOTMni90-ngLB19nt7XKJ6A,130
|
|
80
|
+
enn/turbo/config/no_surrogate_config.py,sha256=Tf0CXESMbwXlX7l2Tk1HGzH1lNY3c0qk995bHGpZaPk,341
|
|
81
|
+
enn/turbo/config/no_tr_config.py,sha256=5QXJQtjKwNMXs7QqhgLgA3HngCh73g0a-gWzt_OzWNU,789
|
|
82
|
+
enn/turbo/config/optimizer_config.py,sha256=WDYVyAWeX5GTBihlgArJ4XrcZAONw_0QdR2ylHyJUpE,2243
|
|
83
|
+
enn/turbo/config/pareto_acquisition_config.py,sha256=UsH4r4RPWAXgPDgcxlUJ9N0x4u1ZB5Wn3OBDu4Fxg6k,384
|
|
84
|
+
enn/turbo/config/raasp_driver.py,sha256=NY4YZdqhPrusbpPjSLRY2MiDv5_f7CpUXlyVsGx3an8,91
|
|
85
|
+
enn/turbo/config/raasp_optimizer_config.py,sha256=LQUe6nYx8gw3b2OXUGodkK_EK3jvIKSBc9DLhKGOzkM,132
|
|
86
|
+
enn/turbo/config/random_acquisition_config.py,sha256=ktNKCqnAOt69Cbsx9r5pVcgcb1gUYLnRrbureSq4BGA,384
|
|
87
|
+
enn/turbo/config/rescalarize.py,sha256=dB80E62uAGmNRnWrV8aLoTjehFB3pWN1XMkgMh03RAc,144
|
|
88
|
+
enn/turbo/config/surrogate.py,sha256=S1BqKv9Fd-ZmfAXNUKdLFrYhSNY_zdBNL45kvD8jhUo,380
|
|
89
|
+
enn/turbo/config/trust_region.py,sha256=yjDnCq74R81jLu_6e7fCHy5s_m8CZGVABxAPEVUXyTM,769
|
|
90
|
+
enn/turbo/config/turbo_tr_config.py,sha256=5yGVk0_es5yYjYH94GV29q1nWNoY32hed9mmzF0tKkw,2188
|
|
91
|
+
enn/turbo/config/ucb_acquisition_config.py,sha256=00SICPVJZaADCu1Mv0DTDgFfBVeilc5IBzlnjjyb5ko,375
|
|
92
|
+
enn/turbo/config/validation.py,sha256=OCnULdJY6I7QCHqSJWEgURVwcq6CLgu214LTWNWHYiY,1974
|
|
93
|
+
enn/turbo/config/init_strategies/__init__.py,sha256=XgowibgAkggbioBgARZ7LR4y1bUHopZs6luoFMRdgsw,187
|
|
94
|
+
enn/turbo/config/init_strategies/hybrid_init.py,sha256=GDzBMvYXJMfWkaNpnDFNfe8yl7XAhIPaNVS4jkK9nD8,643
|
|
95
|
+
enn/turbo/config/init_strategies/init_strategy.py,sha256=yRKGd9EjMh09rLKejZOkcj7pAL7vnqbpNjr9ax_oLrs,458
|
|
96
|
+
enn/turbo/config/init_strategies/lhd_only_init.py,sha256=84xUlaifq98Vonr0ltEoq8xArS1knRuI0y8qhhNe1QE,638
|
|
97
|
+
enn/turbo/strategies/__init__.py,sha256=PA-p1M-4w8Miv1wLqZF0FODyhyVr3_0vED-qcwiaD8E,251
|
|
98
|
+
enn/turbo/strategies/lhd_only_strategy.py,sha256=nThr0BBHOb0fBialCU7ToYz-BEodS72ZMyHInkiybnc,1154
|
|
99
|
+
enn/turbo/strategies/optimization_strategy.py,sha256=xwiU_P9_1YBt1UaPthNsMo__QvEo9dREc4mJwr05dBA,548
|
|
100
|
+
enn/turbo/strategies/turbo_hybrid_strategy.py,sha256=H-JtlAl3NR0MziiPETCFAY89AusWpObLy-j-xhgWaaU,4567
|
|
101
|
+
enn/turbo/types/__init__.py,sha256=mIP7fNkcgQ1WLPeCSMTzcwbZuNiMlrCt8JLphfml8KU,280
|
|
102
|
+
enn/turbo/types/appendable_array.py,sha256=YAxXuC6QD_52tzYBssJgjP6Dm3rQSYS1NxE3FwKWT4w,2812
|
|
103
|
+
enn/turbo/types/gp_data_prep.py,sha256=KZACZCyqpoDGCMQ4WhRTGAnZhKO-FHtIBB3_dGfAVcc,222
|
|
104
|
+
enn/turbo/types/gp_fit_result.py,sha256=7jvGQxT0XZx8rbOFdmm_nLwV_L3KwybjwEV-uZTLLXM,190
|
|
105
|
+
enn/turbo/types/obs_lists.py,sha256=LytRAkOgRvORMLS-lpP97AvSj2DfrFu2kyHM7JLi7GE,164
|
|
106
|
+
enn/turbo/types/prepare_ask_result.py,sha256=uTFdFZ-e7s1ZmJuyf6l1H4qDTxAmwq5UznP5pVc3r4Y,286
|
|
107
|
+
enn/turbo/types/tell_inputs.py,sha256=OqkkL55MpVVCzhFJuLajYuxrQ6WsJRZ4R521Zn7SBr4,261
|
|
108
|
+
ennbo-0.1.7.dist-info/METADATA,sha256=niGm5jUopcIGOZYrrYyffQk_m3bFaHfGS50YVSy8nRU,6085
|
|
109
|
+
ennbo-0.1.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
110
|
+
ennbo-0.1.7.dist-info/licenses/LICENSE,sha256=KTA0NjGalsl_JGrjT_x6SSq9ZYVO3gQ-hLVMEaekc5w,1070
|
|
111
|
+
ennbo-0.1.7.dist-info/RECORD,,
|
enn/enn/__init__.py
DELETED
enn/turbo/__init__.py
DELETED
enn/turbo/base_turbo_impl.py
DELETED
|
@@ -1,144 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from typing import TYPE_CHECKING, Any, Callable
|
|
4
|
-
|
|
5
|
-
if TYPE_CHECKING:
|
|
6
|
-
import numpy as np
|
|
7
|
-
from numpy.random import Generator
|
|
8
|
-
|
|
9
|
-
from .turbo_config import TurboConfig
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class BaseTurboImpl:
|
|
13
|
-
def __init__(self, config: TurboConfig) -> None:
|
|
14
|
-
self._config = config
|
|
15
|
-
|
|
16
|
-
def get_x_center(
|
|
17
|
-
self,
|
|
18
|
-
x_obs_list: list,
|
|
19
|
-
y_obs_list: list,
|
|
20
|
-
rng: Generator,
|
|
21
|
-
tr_state: Any = None,
|
|
22
|
-
) -> np.ndarray | None:
|
|
23
|
-
import numpy as np
|
|
24
|
-
|
|
25
|
-
from .turbo_utils import argmax_random_tie
|
|
26
|
-
|
|
27
|
-
y_array = np.asarray(y_obs_list, dtype=float)
|
|
28
|
-
if y_array.size == 0:
|
|
29
|
-
return None
|
|
30
|
-
x_array = np.asarray(x_obs_list, dtype=float)
|
|
31
|
-
|
|
32
|
-
# For morbo: scalarize raw y observations
|
|
33
|
-
if self._config.tr_type == "morbo" and tr_state is not None:
|
|
34
|
-
if y_array.ndim == 1:
|
|
35
|
-
y_array = y_array.reshape(-1, tr_state.num_metrics)
|
|
36
|
-
scalarized = tr_state.scalarize(y_array, clip=True)
|
|
37
|
-
idx = argmax_random_tie(scalarized, rng=rng)
|
|
38
|
-
else:
|
|
39
|
-
idx = argmax_random_tie(y_array, rng=rng)
|
|
40
|
-
|
|
41
|
-
return x_array[idx]
|
|
42
|
-
|
|
43
|
-
def needs_tr_list(self) -> bool:
|
|
44
|
-
return False
|
|
45
|
-
|
|
46
|
-
def create_trust_region(
|
|
47
|
-
self,
|
|
48
|
-
num_dim: int,
|
|
49
|
-
num_arms: int,
|
|
50
|
-
rng: Generator,
|
|
51
|
-
num_metrics: int | None = None,
|
|
52
|
-
) -> Any:
|
|
53
|
-
if self._config.tr_type == "none":
|
|
54
|
-
from .no_trust_region import NoTrustRegion
|
|
55
|
-
|
|
56
|
-
return NoTrustRegion(num_dim=num_dim, num_arms=num_arms)
|
|
57
|
-
elif self._config.tr_type == "turbo":
|
|
58
|
-
from .turbo_trust_region import TurboTrustRegion
|
|
59
|
-
|
|
60
|
-
return TurboTrustRegion(num_dim=num_dim, num_arms=num_arms)
|
|
61
|
-
elif self._config.tr_type == "morbo":
|
|
62
|
-
from .morbo_trust_region import MorboTrustRegion
|
|
63
|
-
|
|
64
|
-
effective_num_metrics = num_metrics or self._config.num_metrics
|
|
65
|
-
if effective_num_metrics is None:
|
|
66
|
-
raise ValueError("num_metrics required for tr_type='morbo'")
|
|
67
|
-
return MorboTrustRegion(
|
|
68
|
-
num_dim=num_dim,
|
|
69
|
-
num_arms=num_arms,
|
|
70
|
-
num_metrics=effective_num_metrics,
|
|
71
|
-
rng=rng,
|
|
72
|
-
)
|
|
73
|
-
else:
|
|
74
|
-
raise ValueError(f"Unknown tr_type: {self._config.tr_type!r}")
|
|
75
|
-
|
|
76
|
-
def try_early_ask(
|
|
77
|
-
self,
|
|
78
|
-
num_arms: int,
|
|
79
|
-
x_obs_list: list,
|
|
80
|
-
draw_initial_fn: Callable[[int], np.ndarray],
|
|
81
|
-
get_init_lhd_points_fn: Callable[[int], np.ndarray],
|
|
82
|
-
) -> np.ndarray | None:
|
|
83
|
-
return None
|
|
84
|
-
|
|
85
|
-
def handle_restart(
|
|
86
|
-
self,
|
|
87
|
-
x_obs_list: list,
|
|
88
|
-
y_obs_list: list,
|
|
89
|
-
yvar_obs_list: list,
|
|
90
|
-
init_idx: int,
|
|
91
|
-
num_init: int,
|
|
92
|
-
) -> tuple[bool, int]:
|
|
93
|
-
if self._config.tr_type == "morbo":
|
|
94
|
-
x_obs_list.clear()
|
|
95
|
-
y_obs_list.clear()
|
|
96
|
-
yvar_obs_list.clear()
|
|
97
|
-
return True, 0
|
|
98
|
-
return False, init_idx
|
|
99
|
-
|
|
100
|
-
def prepare_ask(
|
|
101
|
-
self,
|
|
102
|
-
x_obs_list: list,
|
|
103
|
-
y_obs_list: list,
|
|
104
|
-
yvar_obs_list: list,
|
|
105
|
-
num_dim: int,
|
|
106
|
-
gp_num_steps: int,
|
|
107
|
-
rng: Any | None = None,
|
|
108
|
-
) -> tuple[Any, float | None, float | None, np.ndarray | None]:
|
|
109
|
-
return None, None, None, None
|
|
110
|
-
|
|
111
|
-
def select_candidates(
|
|
112
|
-
self,
|
|
113
|
-
x_cand: np.ndarray,
|
|
114
|
-
num_arms: int,
|
|
115
|
-
num_dim: int,
|
|
116
|
-
rng: Generator,
|
|
117
|
-
fallback_fn: Callable[[np.ndarray, int], np.ndarray],
|
|
118
|
-
from_unit_fn: Callable[[np.ndarray], np.ndarray],
|
|
119
|
-
tr_state: Any = None,
|
|
120
|
-
) -> np.ndarray:
|
|
121
|
-
raise NotImplementedError("Subclasses must implement select_candidates")
|
|
122
|
-
|
|
123
|
-
def update_trust_region(
|
|
124
|
-
self,
|
|
125
|
-
tr_state: Any,
|
|
126
|
-
x_obs_list: list,
|
|
127
|
-
y_obs_list: list,
|
|
128
|
-
x_center: np.ndarray | None = None,
|
|
129
|
-
k: int | None = None,
|
|
130
|
-
) -> None:
|
|
131
|
-
import numpy as np
|
|
132
|
-
|
|
133
|
-
x_obs_array = np.asarray(x_obs_list, dtype=float)
|
|
134
|
-
y_obs_array = np.asarray(y_obs_list, dtype=float)
|
|
135
|
-
if hasattr(tr_state, "update_xy"):
|
|
136
|
-
tr_state.update_xy(x_obs_array, y_obs_array, k=k)
|
|
137
|
-
else:
|
|
138
|
-
tr_state.update(y_obs_array)
|
|
139
|
-
|
|
140
|
-
def estimate_y(self, x_unit: np.ndarray, y_observed: np.ndarray) -> np.ndarray:
|
|
141
|
-
return y_observed
|
|
142
|
-
|
|
143
|
-
def get_mu_sigma(self, x_unit: np.ndarray) -> tuple[np.ndarray, np.ndarray] | None:
|
|
144
|
-
return None
|
enn/turbo/lhd_only_impl.py
DELETED
|
@@ -1,49 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from typing import TYPE_CHECKING, Any, Callable
|
|
4
|
-
|
|
5
|
-
if TYPE_CHECKING:
|
|
6
|
-
import numpy as np
|
|
7
|
-
from numpy.random import Generator
|
|
8
|
-
|
|
9
|
-
from .base_turbo_impl import BaseTurboImpl
|
|
10
|
-
from .turbo_config import LHDOnlyConfig
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class LHDOnlyImpl(BaseTurboImpl):
|
|
14
|
-
def __init__(self, config: LHDOnlyConfig) -> None:
|
|
15
|
-
super().__init__(config)
|
|
16
|
-
|
|
17
|
-
def get_x_center(
|
|
18
|
-
self,
|
|
19
|
-
x_obs_list: list,
|
|
20
|
-
y_obs_list: list,
|
|
21
|
-
rng: Generator,
|
|
22
|
-
tr_state: Any = None,
|
|
23
|
-
) -> np.ndarray | None:
|
|
24
|
-
return None
|
|
25
|
-
|
|
26
|
-
def select_candidates(
|
|
27
|
-
self,
|
|
28
|
-
x_cand: np.ndarray,
|
|
29
|
-
num_arms: int,
|
|
30
|
-
num_dim: int,
|
|
31
|
-
rng: Generator,
|
|
32
|
-
fallback_fn: Callable[[np.ndarray, int], np.ndarray],
|
|
33
|
-
from_unit_fn: Callable[[np.ndarray], np.ndarray],
|
|
34
|
-
tr_state: Any = None, # noqa: ARG002
|
|
35
|
-
) -> np.ndarray:
|
|
36
|
-
from .turbo_utils import latin_hypercube
|
|
37
|
-
|
|
38
|
-
unit = latin_hypercube(num_arms, num_dim, rng=rng)
|
|
39
|
-
return from_unit_fn(unit)
|
|
40
|
-
|
|
41
|
-
def update_trust_region(
|
|
42
|
-
self,
|
|
43
|
-
tr_state: Any,
|
|
44
|
-
x_obs_list: list,
|
|
45
|
-
y_obs_list: list,
|
|
46
|
-
x_center: np.ndarray | None = None,
|
|
47
|
-
k: int | None = None,
|
|
48
|
-
) -> None:
|
|
49
|
-
pass
|
enn/turbo/turbo_config.py
DELETED
|
@@ -1,72 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from dataclasses import dataclass
|
|
4
|
-
from typing import Literal
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
@dataclass(frozen=True)
|
|
8
|
-
class TurboConfig:
|
|
9
|
-
k: int | None = None
|
|
10
|
-
num_candidates: int | None = None
|
|
11
|
-
num_init: int | None = None
|
|
12
|
-
|
|
13
|
-
# Experimental
|
|
14
|
-
trailing_obs: int | None = None
|
|
15
|
-
tr_type: Literal["turbo", "morbo", "none"] = "turbo"
|
|
16
|
-
num_metrics: int | None = None
|
|
17
|
-
|
|
18
|
-
def __post_init__(self) -> None:
|
|
19
|
-
if self.tr_type not in ["turbo", "morbo", "none"]:
|
|
20
|
-
raise ValueError(
|
|
21
|
-
f"tr_type must be 'turbo', 'morbo', or 'none', got {self.tr_type!r}"
|
|
22
|
-
)
|
|
23
|
-
if self.num_metrics is not None and self.num_metrics < 1:
|
|
24
|
-
raise ValueError(f"num_metrics must be >= 1, got {self.num_metrics}")
|
|
25
|
-
if self.tr_type == "turbo":
|
|
26
|
-
if self.num_metrics is not None and self.num_metrics != 1:
|
|
27
|
-
raise ValueError(
|
|
28
|
-
f"num_metrics must be 1 for tr_type='turbo', got {self.num_metrics}"
|
|
29
|
-
)
|
|
30
|
-
if self.tr_type == "none":
|
|
31
|
-
if self.num_metrics is not None and self.num_metrics != 1:
|
|
32
|
-
raise ValueError(
|
|
33
|
-
f"num_metrics must be 1 for tr_type='none', got {self.num_metrics}"
|
|
34
|
-
)
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
@dataclass(frozen=True)
|
|
38
|
-
class TurboOneConfig(TurboConfig):
|
|
39
|
-
pass
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
@dataclass(frozen=True)
|
|
43
|
-
class TurboZeroConfig(TurboConfig):
|
|
44
|
-
pass
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
@dataclass(frozen=True)
|
|
48
|
-
class LHDOnlyConfig(TurboConfig):
|
|
49
|
-
pass
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
@dataclass(frozen=True)
|
|
53
|
-
class TurboENNConfig(TurboConfig):
|
|
54
|
-
acq_type: Literal["thompson", "pareto", "ucb"] = "pareto"
|
|
55
|
-
num_fit_samples: int | None = None
|
|
56
|
-
num_fit_candidates: int | None = None
|
|
57
|
-
scale_x: bool = False
|
|
58
|
-
|
|
59
|
-
def __post_init__(self) -> None:
|
|
60
|
-
super().__post_init__()
|
|
61
|
-
if self.acq_type not in ["thompson", "pareto", "ucb"]:
|
|
62
|
-
raise ValueError(
|
|
63
|
-
f"acq_type must be 'thompson', 'pareto', or 'ucb', got {self.acq_type!r}"
|
|
64
|
-
)
|
|
65
|
-
if self.num_fit_samples is None and self.acq_type != "pareto":
|
|
66
|
-
raise ValueError(f"num_fit_samples required for acq_type={self.acq_type!r}")
|
|
67
|
-
if self.num_fit_samples is not None and int(self.num_fit_samples) <= 0:
|
|
68
|
-
raise ValueError(f"num_fit_samples must be > 0, got {self.num_fit_samples}")
|
|
69
|
-
if self.num_fit_candidates is not None and int(self.num_fit_candidates) <= 0:
|
|
70
|
-
raise ValueError(
|
|
71
|
-
f"num_fit_candidates must be > 0, got {self.num_fit_candidates}"
|
|
72
|
-
)
|
enn/turbo/turbo_enn_impl.py
DELETED
|
@@ -1,201 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from typing import TYPE_CHECKING, Any, Callable
|
|
4
|
-
|
|
5
|
-
if TYPE_CHECKING:
|
|
6
|
-
import numpy as np
|
|
7
|
-
from numpy.random import Generator
|
|
8
|
-
|
|
9
|
-
from .base_turbo_impl import BaseTurboImpl
|
|
10
|
-
from .turbo_config import TurboENNConfig
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class TurboENNImpl(BaseTurboImpl):
|
|
14
|
-
def __init__(self, config: TurboENNConfig) -> None:
|
|
15
|
-
super().__init__(config)
|
|
16
|
-
self._enn: Any | None = None
|
|
17
|
-
self._fitted_params: Any | None = None
|
|
18
|
-
self._fitted_n_obs: int = 0
|
|
19
|
-
|
|
20
|
-
def get_x_center(
|
|
21
|
-
self,
|
|
22
|
-
x_obs_list: list,
|
|
23
|
-
y_obs_list: list,
|
|
24
|
-
rng: Generator,
|
|
25
|
-
tr_state: Any = None,
|
|
26
|
-
) -> np.ndarray | None:
|
|
27
|
-
import numpy as np
|
|
28
|
-
|
|
29
|
-
from .turbo_utils import argmax_random_tie
|
|
30
|
-
|
|
31
|
-
if len(y_obs_list) == 0:
|
|
32
|
-
return None
|
|
33
|
-
if self._enn is None or self._fitted_params is None:
|
|
34
|
-
return super().get_x_center(x_obs_list, y_obs_list, rng, tr_state)
|
|
35
|
-
if self._fitted_n_obs != len(x_obs_list):
|
|
36
|
-
raise RuntimeError(
|
|
37
|
-
f"ENN fitted on {self._fitted_n_obs} obs but get_x_center called with {len(x_obs_list)}"
|
|
38
|
-
)
|
|
39
|
-
|
|
40
|
-
y_array = np.asarray(y_obs_list, dtype=float)
|
|
41
|
-
x_array = np.asarray(x_obs_list, dtype=float)
|
|
42
|
-
k = self._config.k if self._config.k is not None else 10
|
|
43
|
-
|
|
44
|
-
# For morbo: top-k per metric → union → scalarize mu
|
|
45
|
-
if self._config.tr_type == "morbo" and tr_state is not None:
|
|
46
|
-
if y_array.ndim == 1:
|
|
47
|
-
y_array = y_array.reshape(-1, tr_state.num_metrics)
|
|
48
|
-
num_metrics = y_array.shape[1]
|
|
49
|
-
|
|
50
|
-
# Find top-k indices for each metric and take union
|
|
51
|
-
union_indices = set()
|
|
52
|
-
for m in range(num_metrics):
|
|
53
|
-
num_top = min(k, len(y_array))
|
|
54
|
-
top_m = np.argpartition(-y_array[:, m], num_top - 1)[:num_top]
|
|
55
|
-
union_indices.update(top_m.tolist())
|
|
56
|
-
union_indices = np.array(sorted(union_indices), dtype=int)
|
|
57
|
-
|
|
58
|
-
x_union = x_array[union_indices]
|
|
59
|
-
posterior = self._enn.posterior(x_union, params=self._fitted_params)
|
|
60
|
-
mu = posterior.mu # (len(union), num_metrics)
|
|
61
|
-
|
|
62
|
-
scalarized = tr_state.scalarize(mu, clip=False)
|
|
63
|
-
best_idx_in_union = argmax_random_tie(scalarized, rng=rng)
|
|
64
|
-
return x_union[best_idx_in_union]
|
|
65
|
-
else:
|
|
66
|
-
# Single-objective: original logic
|
|
67
|
-
num_top = min(k, len(y_array))
|
|
68
|
-
top_indices = np.argpartition(-y_array, num_top - 1)[:num_top]
|
|
69
|
-
|
|
70
|
-
x_top = x_array[top_indices]
|
|
71
|
-
posterior = self._enn.posterior(x_top, params=self._fitted_params)
|
|
72
|
-
mu = posterior.mu[:, 0]
|
|
73
|
-
|
|
74
|
-
best_idx_in_top = argmax_random_tie(mu, rng=rng)
|
|
75
|
-
return x_top[best_idx_in_top]
|
|
76
|
-
|
|
77
|
-
def needs_tr_list(self) -> bool:
|
|
78
|
-
return True
|
|
79
|
-
|
|
80
|
-
def handle_restart(
|
|
81
|
-
self,
|
|
82
|
-
x_obs_list: list,
|
|
83
|
-
y_obs_list: list,
|
|
84
|
-
yvar_obs_list: list,
|
|
85
|
-
init_idx: int,
|
|
86
|
-
num_init: int,
|
|
87
|
-
) -> tuple[bool, int]:
|
|
88
|
-
x_obs_list.clear()
|
|
89
|
-
y_obs_list.clear()
|
|
90
|
-
yvar_obs_list.clear()
|
|
91
|
-
return True, 0
|
|
92
|
-
|
|
93
|
-
def prepare_ask(
|
|
94
|
-
self,
|
|
95
|
-
x_obs_list: list,
|
|
96
|
-
y_obs_list: list,
|
|
97
|
-
yvar_obs_list: list,
|
|
98
|
-
num_dim: int,
|
|
99
|
-
gp_num_steps: int,
|
|
100
|
-
rng: Any | None = None,
|
|
101
|
-
) -> tuple[Any, float | None, float | None, np.ndarray | None]:
|
|
102
|
-
from .proposal import mk_enn
|
|
103
|
-
|
|
104
|
-
k = self._config.k if self._config.k is not None else 10
|
|
105
|
-
self._enn, self._fitted_params = mk_enn(
|
|
106
|
-
x_obs_list,
|
|
107
|
-
y_obs_list,
|
|
108
|
-
yvar_obs_list=yvar_obs_list,
|
|
109
|
-
k=k,
|
|
110
|
-
num_fit_samples=self._config.num_fit_samples,
|
|
111
|
-
num_fit_candidates=self._config.num_fit_candidates,
|
|
112
|
-
scale_x=self._config.scale_x,
|
|
113
|
-
rng=rng,
|
|
114
|
-
params_warm_start=self._fitted_params,
|
|
115
|
-
)
|
|
116
|
-
self._fitted_n_obs = len(x_obs_list)
|
|
117
|
-
return None, None, None, None
|
|
118
|
-
|
|
119
|
-
def select_candidates(
|
|
120
|
-
self,
|
|
121
|
-
x_cand: np.ndarray,
|
|
122
|
-
num_arms: int,
|
|
123
|
-
num_dim: int,
|
|
124
|
-
rng: Generator,
|
|
125
|
-
fallback_fn: Callable[[np.ndarray, int], np.ndarray],
|
|
126
|
-
from_unit_fn: Callable[[np.ndarray], np.ndarray],
|
|
127
|
-
tr_state: Any = None, # noqa: ARG002
|
|
128
|
-
) -> np.ndarray:
|
|
129
|
-
import numpy as np
|
|
130
|
-
|
|
131
|
-
from enn.enn.enn_params import ENNParams
|
|
132
|
-
|
|
133
|
-
acq_type = self._config.acq_type
|
|
134
|
-
k = self._config.k
|
|
135
|
-
|
|
136
|
-
if self._enn is None:
|
|
137
|
-
return fallback_fn(x_cand, num_arms)
|
|
138
|
-
|
|
139
|
-
if self._fitted_params is not None:
|
|
140
|
-
params = self._fitted_params
|
|
141
|
-
else:
|
|
142
|
-
k_val = k if k is not None else 10
|
|
143
|
-
params = ENNParams(k=k_val, epi_var_scale=1.0, ale_homoscedastic_scale=0.0)
|
|
144
|
-
|
|
145
|
-
posterior = self._enn.posterior(x_cand, params=params)
|
|
146
|
-
mu = posterior.mu[:, 0]
|
|
147
|
-
se = posterior.se[:, 0]
|
|
148
|
-
|
|
149
|
-
if acq_type == "pareto":
|
|
150
|
-
from enn.enn.enn_util import arms_from_pareto_fronts
|
|
151
|
-
|
|
152
|
-
x_arms = arms_from_pareto_fronts(x_cand, mu, se, num_arms, rng)
|
|
153
|
-
elif acq_type == "ucb":
|
|
154
|
-
scores = mu + se
|
|
155
|
-
shuffled_indices = rng.permutation(len(scores))
|
|
156
|
-
shuffled_scores = scores[shuffled_indices]
|
|
157
|
-
top_k_in_shuffled = np.argpartition(-shuffled_scores, num_arms - 1)[
|
|
158
|
-
:num_arms
|
|
159
|
-
]
|
|
160
|
-
idx = shuffled_indices[top_k_in_shuffled]
|
|
161
|
-
x_arms = x_cand[idx]
|
|
162
|
-
elif acq_type == "thompson":
|
|
163
|
-
samples = posterior.sample(num_samples=1, rng=rng)
|
|
164
|
-
scores = samples[:, 0, 0]
|
|
165
|
-
shuffled_indices = rng.permutation(len(scores))
|
|
166
|
-
shuffled_scores = scores[shuffled_indices]
|
|
167
|
-
top_k_in_shuffled = np.argpartition(-shuffled_scores, num_arms - 1)[
|
|
168
|
-
:num_arms
|
|
169
|
-
]
|
|
170
|
-
idx = shuffled_indices[top_k_in_shuffled]
|
|
171
|
-
x_arms = x_cand[idx]
|
|
172
|
-
else:
|
|
173
|
-
raise ValueError(f"Unknown acq_type: {acq_type}")
|
|
174
|
-
|
|
175
|
-
return from_unit_fn(x_arms)
|
|
176
|
-
|
|
177
|
-
def estimate_y(self, x_unit: np.ndarray, y_observed: np.ndarray) -> np.ndarray:
|
|
178
|
-
if self._enn is None or self._fitted_params is None:
|
|
179
|
-
return y_observed
|
|
180
|
-
posterior = self._enn.posterior(x_unit, params=self._fitted_params)
|
|
181
|
-
# For multi-metric (morbo), return full mu; for single-metric, return 1D
|
|
182
|
-
if posterior.mu.shape[1] > 1:
|
|
183
|
-
return posterior.mu
|
|
184
|
-
return posterior.mu[:, 0]
|
|
185
|
-
|
|
186
|
-
def get_mu_sigma(self, x_unit: np.ndarray) -> tuple[np.ndarray, np.ndarray] | None:
|
|
187
|
-
if self._enn is None:
|
|
188
|
-
return None
|
|
189
|
-
k = self._config.k if self._config.k is not None else 10
|
|
190
|
-
from enn.enn.enn_params import ENNParams
|
|
191
|
-
|
|
192
|
-
params = (
|
|
193
|
-
self._fitted_params
|
|
194
|
-
if self._fitted_params is not None
|
|
195
|
-
else ENNParams(k=k, epi_var_scale=1.0, ale_homoscedastic_scale=0.0)
|
|
196
|
-
)
|
|
197
|
-
posterior = self._enn.posterior(x_unit, params=params, observation_noise=False)
|
|
198
|
-
# For multi-metric (morbo), return full mu/sigma; for single-metric, return 1D
|
|
199
|
-
if posterior.mu.shape[1] > 1:
|
|
200
|
-
return posterior.mu, posterior.se
|
|
201
|
-
return posterior.mu[:, 0], posterior.se[:, 0]
|
enn/turbo/turbo_mode.py
DELETED
enn/turbo/turbo_mode_impl.py
DELETED
|
@@ -1,76 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from typing import TYPE_CHECKING, Any, Callable, Protocol
|
|
4
|
-
|
|
5
|
-
if TYPE_CHECKING:
|
|
6
|
-
import numpy as np
|
|
7
|
-
from numpy.random import Generator
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class TurboModeImpl(Protocol):
|
|
11
|
-
def get_x_center(
|
|
12
|
-
self,
|
|
13
|
-
x_obs_list: list,
|
|
14
|
-
y_obs_list: list,
|
|
15
|
-
rng: Generator,
|
|
16
|
-
tr_state: Any = None,
|
|
17
|
-
) -> np.ndarray | None: ...
|
|
18
|
-
|
|
19
|
-
def needs_tr_list(self) -> bool: ...
|
|
20
|
-
|
|
21
|
-
def create_trust_region(
|
|
22
|
-
self,
|
|
23
|
-
num_dim: int,
|
|
24
|
-
num_arms: int,
|
|
25
|
-
rng: Generator,
|
|
26
|
-
num_metrics: int | None = None,
|
|
27
|
-
) -> Any: ...
|
|
28
|
-
|
|
29
|
-
def try_early_ask(
|
|
30
|
-
self,
|
|
31
|
-
num_arms: int,
|
|
32
|
-
x_obs_list: list,
|
|
33
|
-
draw_initial_fn: Callable[[int], np.ndarray],
|
|
34
|
-
get_init_lhd_points_fn: Callable[[int], np.ndarray],
|
|
35
|
-
) -> np.ndarray | None: ...
|
|
36
|
-
|
|
37
|
-
def handle_restart(
|
|
38
|
-
self,
|
|
39
|
-
x_obs_list: list,
|
|
40
|
-
y_obs_list: list,
|
|
41
|
-
yvar_obs_list: list,
|
|
42
|
-
init_idx: int,
|
|
43
|
-
num_init: int,
|
|
44
|
-
) -> tuple[bool, int]: ...
|
|
45
|
-
|
|
46
|
-
def prepare_ask(
|
|
47
|
-
self,
|
|
48
|
-
x_obs_list: list,
|
|
49
|
-
y_obs_list: list,
|
|
50
|
-
yvar_obs_list: list,
|
|
51
|
-
num_dim: int,
|
|
52
|
-
gp_num_steps: int,
|
|
53
|
-
rng: Generator | Any | None = None,
|
|
54
|
-
) -> tuple[Any, float | None, float | None, np.ndarray | None]: ...
|
|
55
|
-
|
|
56
|
-
def select_candidates(
|
|
57
|
-
self,
|
|
58
|
-
x_cand: np.ndarray,
|
|
59
|
-
num_arms: int,
|
|
60
|
-
num_dim: int,
|
|
61
|
-
rng: Generator,
|
|
62
|
-
fallback_fn: Callable[[np.ndarray, int], np.ndarray],
|
|
63
|
-
from_unit_fn: Callable[[np.ndarray], np.ndarray],
|
|
64
|
-
tr_state: Any = None,
|
|
65
|
-
) -> np.ndarray: ...
|
|
66
|
-
|
|
67
|
-
def update_trust_region(
|
|
68
|
-
self,
|
|
69
|
-
tr_state: Any,
|
|
70
|
-
x_obs_list: list,
|
|
71
|
-
y_obs_list: list,
|
|
72
|
-
x_center: np.ndarray | None = None,
|
|
73
|
-
k: int | None = None,
|
|
74
|
-
) -> None: ...
|
|
75
|
-
|
|
76
|
-
def estimate_y(self, x_unit: np.ndarray, y_observed: np.ndarray) -> np.ndarray: ...
|