pyfemtet 0.4.21__py3-none-any.whl → 0.4.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyfemtet might be problematic. Click here for more details.
- pyfemtet/__init__.py +1 -1
- pyfemtet/_test_util.py +0 -2
- pyfemtet/message/messages.py +15 -1
- pyfemtet/opt/_femopt.py +233 -199
- pyfemtet/opt/_femopt_core.py +116 -47
- pyfemtet/opt/femprj_sample/ParametricIF.py +0 -2
- pyfemtet/opt/femprj_sample/cad_ex01_NX.py +0 -8
- pyfemtet/opt/femprj_sample/cad_ex01_SW.py +0 -8
- pyfemtet/opt/femprj_sample/gal_ex58_parametric.py +0 -8
- pyfemtet/opt/femprj_sample/gau_ex08_parametric.py +0 -8
- pyfemtet/opt/femprj_sample/her_ex40_parametric.py +0 -8
- pyfemtet/opt/femprj_sample/paswat_ex1_parametric.py +0 -8
- pyfemtet/opt/femprj_sample/paswat_ex1_parametric_parallel.py +0 -8
- pyfemtet/opt/femprj_sample/wat_ex14_parametric.py +0 -8
- pyfemtet/opt/femprj_sample/wat_ex14_parametric_parallel.py +0 -8
- pyfemtet/opt/femprj_sample_jp/ParametricIF_jp.py +0 -2
- pyfemtet/opt/femprj_sample_jp/cad_ex01_NX_jp.py +0 -8
- pyfemtet/opt/femprj_sample_jp/cad_ex01_SW_jp.py +0 -8
- pyfemtet/opt/femprj_sample_jp/gal_ex58_parametric_jp.py +0 -8
- pyfemtet/opt/femprj_sample_jp/gau_ex08_parametric_jp.py +0 -8
- pyfemtet/opt/femprj_sample_jp/her_ex40_parametric_jp.py +0 -8
- pyfemtet/opt/femprj_sample_jp/paswat_ex1_parametric_jp.py +0 -8
- pyfemtet/opt/femprj_sample_jp/paswat_ex1_parametric_parallel_jp.py +0 -8
- pyfemtet/opt/femprj_sample_jp/wat_ex14_parametric_jp.py +0 -8
- pyfemtet/opt/femprj_sample_jp/wat_ex14_parametric_parallel_jp.py +0 -8
- pyfemtet/opt/opt/_base.py +4 -4
- pyfemtet/opt/opt/_optuna.py +33 -1
- pyfemtet/opt/opt/_optuna_botorch_helper.py +209 -0
- pyfemtet/opt/visualization/complex_components/main_graph.py +22 -5
- pyfemtet/opt/visualization/complex_components/pm_graph.py +77 -25
- pyfemtet/opt/visualization/complex_components/pm_graph_creator.py +7 -0
- pyfemtet/opt/visualization/process_monitor/application.py +10 -6
- pyfemtet/opt/visualization/process_monitor/pages.py +102 -0
- pyfemtet/opt/visualization/result_viewer/application.py +6 -0
- pyfemtet/opt/visualization/result_viewer/pages.py +1 -1
- {pyfemtet-0.4.21.dist-info → pyfemtet-0.4.23.dist-info}/METADATA +2 -4
- {pyfemtet-0.4.21.dist-info → pyfemtet-0.4.23.dist-info}/RECORD +40 -53
- pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01.femprj +0 -0
- pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01.prt +0 -0
- pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01.py +0 -118
- pyfemtet/FemtetPJTSample/Sldworks_ex01/Sldworks_ex01.SLDPRT +0 -0
- pyfemtet/FemtetPJTSample/Sldworks_ex01/Sldworks_ex01.femprj +0 -0
- pyfemtet/FemtetPJTSample/Sldworks_ex01/Sldworks_ex01.py +0 -121
- pyfemtet/FemtetPJTSample/_her_ex40_parametric.py +0 -148
- pyfemtet/FemtetPJTSample/gau_ex08_parametric.femprj +0 -0
- pyfemtet/FemtetPJTSample/gau_ex08_parametric.py +0 -58
- pyfemtet/FemtetPJTSample/her_ex40_parametric.femprj +0 -0
- pyfemtet/FemtetPJTSample/her_ex40_parametric.py +0 -148
- pyfemtet/FemtetPJTSample/wat_ex14_parallel_parametric.py +0 -65
- pyfemtet/FemtetPJTSample/wat_ex14_parametric.femprj +0 -0
- pyfemtet/FemtetPJTSample/wat_ex14_parametric.py +0 -64
- {pyfemtet-0.4.21.dist-info → pyfemtet-0.4.23.dist-info}/LICENSE +0 -0
- {pyfemtet-0.4.21.dist-info → pyfemtet-0.4.23.dist-info}/WHEEL +0 -0
- {pyfemtet-0.4.21.dist-info → pyfemtet-0.4.23.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
from typing import Optional, List, Tuple, Callable
|
|
2
|
+
from functools import partial
|
|
3
|
+
import inspect
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import optuna.study
|
|
7
|
+
import torch
|
|
8
|
+
from torch import Tensor
|
|
9
|
+
from botorch.optim.initializers import gen_batch_initial_conditions
|
|
10
|
+
from botorch.utils.transforms import unnormalize
|
|
11
|
+
from optuna.study import Study
|
|
12
|
+
from botorch.acquisition import AcquisitionFunction
|
|
13
|
+
|
|
14
|
+
from pyfemtet.opt.opt import AbstractOptimizer
|
|
15
|
+
from pyfemtet.opt.parameter import ExpressionEvaluator
|
|
16
|
+
|
|
17
|
+
# module to monkey patch
|
|
18
|
+
import optuna_integration
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# モンキーパッチを実行するため、optimize_acqf の引数を MonkeyPatch クラスで定義し optuna に上書きされないようにするためのクラス
|
|
22
|
+
class NonOverwritablePartial(partial):
    """A ``functools.partial`` whose bound keyword arguments cannot be overridden.

    With a plain ``partial``, keyword arguments given at call time win over
    the ones bound at construction. Here the precedence is reversed: a keyword
    bound at construction replaces a same-named keyword passed by the caller.
    Positional arguments behave exactly as in ``partial`` (bound ones first,
    then the call-time ones).
    """

    def __call__(self, /, *args, **keywords):
        # Right-hand side wins, so the bound keywords take precedence.
        merged = {**keywords, **self.keywords}
        return self.func(*self.args, *args, **merged)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# prm_name を引数に取る関数を optimize_acqf の nonlinear_inequality_constraints に入れられる形に変換する関数
|
|
30
|
+
class ConvertedConstraintFunction:
    """Adapt a user constraint function to botorch's ``nonlinear_inequality_constraints`` form.

    The user function takes design parameters by name (keyword arguments),
    while the optimizer calls the constraint with a single vector ``x`` that
    holds all normalized parameters in optuna's search-space order. This
    wrapper unnormalizes ``x`` with ``bounds``, selects the parameters ``fun``
    uses and calls ``fun`` with them plus the fixed ``kwargs``.

    ``bounds`` and ``prm_name_seq`` are ``None`` after construction and must be
    assigned before the object is called.
    """

    def __init__(self, fun, prm_args, kwargs, variables: ExpressionEvaluator, study: optuna.study.Study):
        """
        Args:
            fun: User constraint function taking parameters as keyword arguments.
            prm_args: Names of ``fun`` arguments that are design parameters, or
                ``None`` to use every argument in ``fun``'s signature.
            kwargs: Fixed keyword arguments for ``fun`` (``None`` means none).
                Their names are excluded from the parameter arguments.
            variables: Provides ``get_parameter_names()`` used for validation.
            study: The optuna study (stored only).

        Raises:
            AssertionError: If a parameter argument name is not a known parameter.
        """
        self.fun = fun
        self.prm_args = prm_args
        self.kwargs = {} if kwargs is None else kwargs
        self.variables = variables
        self.study = study

        # Assigned later, once optuna's search space is known.
        self.bounds = None
        self.prm_name_seq = None

        # Without an explicit list, every argument of ``fun`` is a candidate parameter.
        if self.prm_args is None:
            prm_inputs = {p.name for p in inspect.signature(fun).parameters.values()}
        else:
            prm_inputs = set(self.prm_args)

        # Arguments supplied through ``kwargs`` are fixed values, not parameters.
        self.prm_arg_names = prm_inputs - set(self.kwargs)

        # Every remaining argument must be a known optimization parameter.
        known_names = set(variables.get_parameter_names())
        unknown_names = self.prm_arg_names - known_names
        assert not unknown_names, (
            f'Unknown parameter name(s) in constraint function arguments: {sorted(unknown_names)}'
        )

    def __call__(self, x: 'Tensor | np.ndarray'):
        """Evaluate the constraint at one normalized parameter vector.

        Args:
            x: All normalized parameters, ordered as ``prm_name_seq``.
                Non-tensor input is converted to a double tensor.

        Returns:
            Whatever ``fun`` returns.
        """
        if not isinstance(x, Tensor):
            x = torch.tensor(np.array(x)).double()

        x = unnormalize(x, self.bounds)

        # Build a fresh dict per call. Updating ``self.kwargs`` in place would
        # leak parameter values into the caller's dict (and into any other
        # constraint sharing it) across calls.
        call_kwargs = dict(self.kwargs)
        call_kwargs.update(
            {name: value for name, value in zip(self.prm_name_seq, x) if name in self.prm_arg_names}
        )

        return self.fun(**call_kwargs)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# 与えられた獲得関数に拘束を満たさない場合 0 を返すよう加工された獲得関数
|
|
72
|
+
class AcqWithConstraint(AcquisitionFunction):
    """Acquisition-function wrapper that scores infeasible candidates as 0.

    For each batch element, the wrapped acquisition value is kept when every
    candidate point satisfies every nonlinear constraint (``cons(point) > 0``).
    Otherwise it is multiplied by 0. In this module it is used while
    generating initial conditions.
    """

    # noinspection PyAttributeOutsideInit
    def set(self, _org_acq_function: AcquisitionFunction, nonlinear_constraints):
        """Attach the acquisition function to wrap and the constraints to enforce.

        ``nonlinear_constraints`` is a list of ``(callable, flag)`` tuples in the
        ``nonlinear_inequality_constraints`` format; only the callable is used.
        """
        self._org_acq_function = _org_acq_function
        self._nonlinear_constraints = nonlinear_constraints

    def _is_feasible(self, points: Tensor) -> bool:
        # points: ``q x d``; feasible only if all points satisfy all constraints.
        return all(
            cons(point) > 0
            for point in points
            for cons, _ in self._nonlinear_constraints
        )

    def forward(self, X: Tensor) -> Tensor:
        """Evaluate the wrapped acquisition function, zeroing infeasible entries.

        ``X`` is assumed to be ``batch_shape x q x d`` and the wrapped function to
        return a ``batch_shape`` tensor; the result is multiplied element-wise
        by a feasibility mask (1 feasible, 0 infeasible).
        """
        base = self._org_acq_function.forward(X)

        # Judge each batch element on its own. Checking only X[0][0] let the
        # first candidate decide feasibility for the whole batch.
        flat = X.reshape(-1, *X.shape[-2:])
        feasible = [self._is_feasible(points) for points in flat]
        mask = torch.tensor(feasible, dtype=base.dtype, device=base.device).reshape(X.shape[:-2])
        return base * mask
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def remove_infeasible(_ic_batch, nonlinear_constraints):
    """Drop initial conditions that violate any nonlinear constraint.

    Args:
        _ic_batch: ``n x q x d`` tensor of candidate initial conditions.
        nonlinear_constraints: List of ``(callable, flag)`` tuples in the
            ``nonlinear_inequality_constraints`` format; only the callable is
            used. A candidate is kept only if ``callable(point) > 0`` for all of
            its ``q`` points and all constraints.

    Returns:
        A ``k x q x d`` tensor (``k <= n``) of the feasible candidates, in their
        original order.
    """
    keep = [
        all(cons(point) > 0 for point in ic for cons, _ in nonlinear_constraints)
        for ic in _ic_batch
    ]
    # A single boolean-mask selection instead of one torch.cat per removed row,
    # which copied the whole batch each time.
    mask = torch.tensor(keep, dtype=torch.bool, device=_ic_batch.device)
    return _ic_batch[mask]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class OptunaBotorchWithParameterConstraintMonkeyPatch:
    """Inject parameter constraints into optuna_integration's BoTorch sampler.

    ``optuna_integration.botorch`` offers no way to pass constraints to
    ``optimize_acqf``. This helper collects user constraint functions and, in
    ``do_monkey_patch``, replaces ``optuna_integration.botorch.optimize_acqf``
    with a partial that always passes them, together with an
    initial-condition generator that respects them.
    """

    def __init__(self, study: Study, opt: AbstractOptimizer):
        # Forwarded to optimize_acqf by do_monkey_patch() (read when it runs).
        self.num_restarts: int = 20
        self.raw_samples_additional: int = 512
        self.eta: float = 2.0  # NOTE(review): not read anywhere in this module.
        self.study = study
        self.opt = opt
        # ``(constraint_callable, True)`` tuples. This list object is bound by
        # reference into the patched optimize_acqf, so it is never rebound.
        self.nonlinear_inequality_constraints = []
        self.additional_kwargs = dict()
        # Filled in lazily from the study's search space.
        self.bounds = None
        self.prm_name_seq = None

    def add_nonlinear_constraint(self, fun, prm_args, kwargs):
        """Register a user constraint function.

        ``fun``, ``prm_args`` and ``kwargs`` are passed to
        ``ConvertedConstraintFunction``; the wrapped function is appended to
        ``nonlinear_inequality_constraints`` as ``(f, True)``.
        """
        f = ConvertedConstraintFunction(
            fun,
            prm_args,
            kwargs,
            self.opt.variables,
            self.study,
        )

        # Initialize only if missing. ``x or []`` would replace an *empty* list
        # with a new object, detaching it from a partial that an earlier
        # do_monkey_patch() call already holds.
        if self.nonlinear_inequality_constraints is None:
            self.nonlinear_inequality_constraints = []

        self.nonlinear_inequality_constraints.append((f, True))

        # Keep the extra optimize_acqf() arguments in sync.
        self.additional_kwargs.update(
            nonlinear_inequality_constraints=self.nonlinear_inequality_constraints
        )

    def _detect_prm_seq_if_needed(self):
        """Recover bounds and parameter order from the study and hand them to the constraints."""
        if self.bounds is None or self.prm_name_seq is None:
            from optuna._transform import _transform_search_space
            # Called after sample_relative, so the last trial should have a search space.
            search_space: dict = self.study.sampler.infer_relative_search_space(self.study, self.study.trials[-1])
            self.bounds = _transform_search_space(search_space, False, False)[0].T
            self.prm_name_seq = list(search_space.keys())

        # Also covers constraints registered after the first detection.
        for cns, _ in self.nonlinear_inequality_constraints:
            if cns.bounds is None or cns.prm_name_seq is None:
                cns.bounds = torch.tensor(self.bounds)
                cns.prm_name_seq = self.prm_name_seq

    def generate_initial_conditions(self, *args, **kwargs):
        """Initial-condition generator passed to optimize_acqf as ``ic_generator``.

        Calls ``gen_batch_initial_conditions`` with an acquisition function that
        scores constraint-violating points as 0, then drops infeasible results.
        If none remain, uniform random points in ``[0, 1)`` (normalized
        parameters) are drawn 100 at a time until at least one is feasible.

        Returns:
            Tensor of feasible initial conditions (``k x q x d``).
        """
        self._detect_prm_seq_if_needed()

        # Wrap acq_function so that it returns 0 where the constraints are violated.
        org_acq_function = kwargs['acq_function']
        new_acqf = AcqWithConstraint(None)
        new_acqf.set(org_acq_function, self.nonlinear_inequality_constraints)
        kwargs['acq_function'] = new_acqf

        # ic_batch: ``num_restarts x q x d`` tensor of initial conditions
        # (q = 1, d = len(params)).
        ic_batch = gen_batch_initial_conditions(*args, **kwargs)

        # Drop candidates that violate the constraints.
        ic_batch = remove_infeasible(ic_batch, self.nonlinear_inequality_constraints)

        # Nothing left: fall back to random sampling.
        if len(ic_batch) == 0:
            print('拘束を満たす組み合わせがなかったのでランダムサンプリングします')
            shape = [100, *ic_batch.shape[1:]]
            # Match botorch's dtype/device; torch.rand defaults to float32 on CPU,
            # which would mismatch a double-precision model.
            dtype, device = ic_batch.dtype, ic_batch.device
            # NOTE(review): loops forever if the constraints cannot be satisfied.
            while len(ic_batch) == 0:
                ic_batch = torch.rand(size=shape, dtype=dtype, device=device)  # normalized parameters
                ic_batch = remove_infeasible(ic_batch, self.nonlinear_inequality_constraints)

        return ic_batch

    def do_monkey_patch(self):
        """Patch ``optuna_integration.botorch.optimize_acqf`` so it receives the constraints.

        optuna_integration.botorch provides no way to pass constraints to
        optimize_acqf. This method therefore replaces the module attribute with
        a ``NonOverwritablePartial``, whose bound arguments the caller cannot
        override.

        The patch works whenever it is applied before optimization starts.
        The constraint list is bound by reference, so constraints added later
        are still passed. ``num_restarts`` and ``raw_samples_additional`` are
        read when this method runs. The patch is kept out of the constructor
        so the object can be configured first.
        """

        # === reconstruct argument ``options`` for optimize_acqf ===
        options = dict(
            batch_limit=1,  # for nonlinear constraints
            method='SLSQP',  # for gen_candidates_scipy(): use COBYLA or SLSQP only
        )

        # Make a partial of the optimize_acqf used in optuna_integration.botorch and swap it in.
        original_fun = optuna_integration.botorch.optimize_acqf
        overwritten_fun = NonOverwritablePartial(
            original_fun,
            q=1,  # for nonlinear constraints
            options=options,
            # num_restarts / raw_samples are gen_batch_initial_conditions arguments;
            # they reach self.generate_initial_conditions.
            num_restarts=self.num_restarts,
            raw_samples=self.raw_samples_additional,
            nonlinear_inequality_constraints=self.nonlinear_inequality_constraints,
            ic_generator=self.generate_initial_conditions,
        )
        optuna_integration.botorch.optimize_acqf = overwritten_fun
|
|
@@ -184,16 +184,18 @@ class MainGraph(AbstractPage):
|
|
|
184
184
|
# create component
|
|
185
185
|
title_component = html.H3(f"trial{trial}", style={"color": "darkblue"})
|
|
186
186
|
img_component = self.create_image_content_if_femtet(trial)
|
|
187
|
-
|
|
187
|
+
tbl_component_prm = self.create_formatted_parameter(row)
|
|
188
|
+
tbl_component_obj = self.create_formatted_objective(row)
|
|
188
189
|
|
|
189
190
|
# create layout
|
|
190
191
|
description = html.Div([
|
|
191
192
|
title_component,
|
|
192
|
-
|
|
193
|
+
tbl_component_prm,
|
|
193
194
|
])
|
|
194
195
|
tooltip_layout = html.Div([
|
|
195
196
|
html.Div(img_component, style={'display': 'inline-block', 'margin-right': '10px', 'vertical-align': 'top'}),
|
|
196
|
-
html.Div(description, style={'display': 'inline-block', 'margin-right': '10px'})
|
|
197
|
+
html.Div(description, style={'display': 'inline-block', 'margin-right': '10px'}),
|
|
198
|
+
html.Div(tbl_component_obj, style={'display': 'inline-block', 'margin-right': '10px'}),
|
|
197
199
|
])
|
|
198
200
|
|
|
199
201
|
return True, bbox, tooltip_layout
|
|
@@ -205,7 +207,22 @@ class MainGraph(AbstractPage):
|
|
|
205
207
|
names = parameters.columns
|
|
206
208
|
values = [f'{value:.3e}' for value in parameters.values.ravel()]
|
|
207
209
|
data = pd.DataFrame(dict(
|
|
208
|
-
|
|
210
|
+
parameter=names, value=values
|
|
211
|
+
))
|
|
212
|
+
table = dash_table.DataTable(
|
|
213
|
+
columns=[{'name': col, 'id': col} for col in data.columns],
|
|
214
|
+
data=data.to_dict('records')
|
|
215
|
+
)
|
|
216
|
+
return table
|
|
217
|
+
|
|
218
|
+
def create_formatted_objective(self, row) -> Component:
|
|
219
|
+
metadata = self.application.history.metadata
|
|
220
|
+
pd.options.display.float_format = '{:.4e}'.format
|
|
221
|
+
objectives = row.iloc[:, np.where(np.array(metadata) == 'obj')[0]]
|
|
222
|
+
names = objectives.columns
|
|
223
|
+
values = [f'{value:.3e}' for value in objectives.values.ravel()]
|
|
224
|
+
data = pd.DataFrame(dict(
|
|
225
|
+
objective=names, value=values
|
|
209
226
|
))
|
|
210
227
|
table = dash_table.DataTable(
|
|
211
228
|
columns=[{'name': col, 'id': col} for col in data.columns],
|
|
@@ -258,5 +275,5 @@ class MainGraph(AbstractPage):
|
|
|
258
275
|
if isinstance(self.application, ProcessMonitorApplication):
|
|
259
276
|
df = self.application.local_data
|
|
260
277
|
else:
|
|
261
|
-
df = self.application.history.
|
|
278
|
+
df = self.application.history.get_df()
|
|
262
279
|
return df
|
|
@@ -129,6 +129,19 @@ class PredictionModelGraph(AbstractPage):
|
|
|
129
129
|
self.slider_stack_data = html.Data(**{self.slider_stack_data_prop: {}})
|
|
130
130
|
self.slider_container = html.Div()
|
|
131
131
|
|
|
132
|
+
# 2d or 3d
|
|
133
|
+
self.switch_3d = dbc.Checklist(
|
|
134
|
+
options=[
|
|
135
|
+
dict(
|
|
136
|
+
label=Msg.LABEL_SWITCH_PREDICTION_MODEL_3D,
|
|
137
|
+
disabled=False,
|
|
138
|
+
value=False,
|
|
139
|
+
)
|
|
140
|
+
],
|
|
141
|
+
switch=True,
|
|
142
|
+
value=[],
|
|
143
|
+
)
|
|
144
|
+
|
|
132
145
|
def setup_layout(self):
|
|
133
146
|
self.card_header = dbc.CardHeader(self.tabs)
|
|
134
147
|
|
|
@@ -151,6 +164,7 @@ class PredictionModelGraph(AbstractPage):
|
|
|
151
164
|
self.command_manager
|
|
152
165
|
],
|
|
153
166
|
direction='horizontal', gap=2),
|
|
167
|
+
self.switch_3d,
|
|
154
168
|
*dropdown_rows,
|
|
155
169
|
self.slider_container,
|
|
156
170
|
self.slider_stack_data,
|
|
@@ -193,13 +207,15 @@ class PredictionModelGraph(AbstractPage):
|
|
|
193
207
|
Output(self.redraw_graph_button_spinner, 'spinner_style', allow_duplicate=True),
|
|
194
208
|
Output(self.redraw_graph_button, 'disabled', allow_duplicate=True),
|
|
195
209
|
Output(self.command_manager, self.command_manager_prop, allow_duplicate=True),
|
|
210
|
+
Output(self.switch_3d, 'options', allow_duplicate=True),
|
|
196
211
|
Input(self.fit_rsm_button, 'n_clicks'),
|
|
197
212
|
Input(self.redraw_graph_button, 'n_clicks'),
|
|
198
213
|
State(self.fit_rsm_button_spinner, 'spinner_style'),
|
|
199
214
|
State(self.redraw_graph_button_spinner, 'spinner_style'),
|
|
215
|
+
State(self.switch_3d, 'options'),
|
|
200
216
|
prevent_initial_call=True,
|
|
201
217
|
)
|
|
202
|
-
def disable_fit_button(_1, _2, state1, state2):
|
|
218
|
+
def disable_fit_button(_1, _2, state1, state2, switch_options):
|
|
203
219
|
# spinner visibility
|
|
204
220
|
if 'display' in state1.keys(): state1.pop('display')
|
|
205
221
|
if 'display' in state2.keys(): state2.pop('display')
|
|
@@ -210,7 +226,12 @@ class PredictionModelGraph(AbstractPage):
|
|
|
210
226
|
else:
|
|
211
227
|
command = self.CommandState.redraw.value
|
|
212
228
|
|
|
213
|
-
|
|
229
|
+
# disable switch
|
|
230
|
+
option = switch_options[0]
|
|
231
|
+
option.update({'disabled': True})
|
|
232
|
+
switch_options[0] = option
|
|
233
|
+
|
|
234
|
+
return state1, True, state2, True, command, switch_options
|
|
214
235
|
|
|
215
236
|
# ===== recreate RSM =====
|
|
216
237
|
@app.callback(
|
|
@@ -256,9 +277,10 @@ class PredictionModelGraph(AbstractPage):
|
|
|
256
277
|
State(self.axis3_obj_dropdown, 'label'),
|
|
257
278
|
State(self.slider_container, 'children'), # for callback chain
|
|
258
279
|
State({'type': 'prm-slider', 'index': ALL}, 'value'),
|
|
280
|
+
State(self.switch_3d, 'value'),
|
|
259
281
|
prevent_initial_call=True,
|
|
260
282
|
)
|
|
261
|
-
def redraw_graph(command, active_tab_id, axis1_label, axis2_label, axis3_label, _2, prm_values):
|
|
283
|
+
def redraw_graph(command, active_tab_id, axis1_label, axis2_label, axis3_label, _2, prm_values, is_3d):
|
|
262
284
|
# just in case
|
|
263
285
|
if callback_context.triggered_id is None:
|
|
264
286
|
raise PreventUpdate
|
|
@@ -283,6 +305,9 @@ class PredictionModelGraph(AbstractPage):
|
|
|
283
305
|
logger.error(Msg.ERR_NO_PREDICTION_MODEL)
|
|
284
306
|
return no_update, self.CommandState.ready.value # to re-enable buttons, fire callback chain
|
|
285
307
|
|
|
308
|
+
if not is_3d:
|
|
309
|
+
axis2_label = None
|
|
310
|
+
|
|
286
311
|
# get indices to remove
|
|
287
312
|
idx1 = prm_names.index(axis1_label) if axis1_label in prm_names else None
|
|
288
313
|
idx2 = prm_names.index(axis2_label) if axis2_label in prm_names else None
|
|
@@ -307,23 +332,31 @@ class PredictionModelGraph(AbstractPage):
|
|
|
307
332
|
|
|
308
333
|
return fig, self.CommandState.ready.value
|
|
309
334
|
|
|
310
|
-
# =====
|
|
335
|
+
# ===== re-enable buttons when the graph is updated, =====
|
|
311
336
|
@app.callback(
|
|
312
337
|
Output(self.fit_rsm_button, 'disabled', allow_duplicate=True),
|
|
313
338
|
Output(self.fit_rsm_button_spinner, 'spinner_style', allow_duplicate=True),
|
|
314
339
|
Output(self.redraw_graph_button, 'disabled', allow_duplicate=True),
|
|
315
340
|
Output(self.redraw_graph_button_spinner, 'spinner_style', allow_duplicate=True),
|
|
341
|
+
Output(self.switch_3d, 'options', allow_duplicate=True),
|
|
316
342
|
Input(self.command_manager, self.command_manager_prop),
|
|
317
343
|
State(self.fit_rsm_button_spinner, 'spinner_style'),
|
|
318
344
|
State(self.redraw_graph_button_spinner, 'spinner_style'),
|
|
345
|
+
State(self.switch_3d, 'options'),
|
|
319
346
|
prevent_initial_call=True,
|
|
320
347
|
)
|
|
321
|
-
def enable_buttons(command, state1, state2):
|
|
348
|
+
def enable_buttons(command, state1, state2, switch_options):
|
|
322
349
|
if command != self.CommandState.ready.value:
|
|
323
350
|
raise PreventUpdate
|
|
324
351
|
state1.update({'display': 'none'})
|
|
325
352
|
state2.update({'display': 'none'})
|
|
326
|
-
|
|
353
|
+
|
|
354
|
+
# enable switch
|
|
355
|
+
option = switch_options[0]
|
|
356
|
+
option.update({'disabled': False})
|
|
357
|
+
switch_options[0] = option
|
|
358
|
+
|
|
359
|
+
return False, state1, False, state2, switch_options
|
|
327
360
|
|
|
328
361
|
# ===== setup dropdown and sliders from history =====
|
|
329
362
|
@app.callback(
|
|
@@ -414,6 +447,7 @@ class PredictionModelGraph(AbstractPage):
|
|
|
414
447
|
Input({'type': 'axis2-dropdown-menu-item', 'index': ALL}, 'n_clicks'),
|
|
415
448
|
Input({'type': 'axis3-dropdown-menu-item', 'index': ALL}, 'n_clicks'),
|
|
416
449
|
Input(self.axis1_prm_dropdown, 'children'), # for callback chain timing
|
|
450
|
+
Input(self.switch_3d, 'value'),
|
|
417
451
|
State(self.axis1_prm_dropdown, 'label'),
|
|
418
452
|
State(self.axis2_prm_dropdown, 'label'),
|
|
419
453
|
State(self.axis3_obj_dropdown, 'label'),
|
|
@@ -422,10 +456,10 @@ class PredictionModelGraph(AbstractPage):
|
|
|
422
456
|
)
|
|
423
457
|
def update_controller(*args):
|
|
424
458
|
# argument processing
|
|
425
|
-
current_ax1_label = args[
|
|
426
|
-
current_ax2_label = args[
|
|
427
|
-
|
|
428
|
-
|
|
459
|
+
current_ax1_label = args[5]
|
|
460
|
+
current_ax2_label = args[6]
|
|
461
|
+
current_styles: list[dict] = args[8]
|
|
462
|
+
is_3d = args[4]
|
|
429
463
|
|
|
430
464
|
# just in case
|
|
431
465
|
if callback_context.triggered_id is None:
|
|
@@ -451,6 +485,10 @@ class PredictionModelGraph(AbstractPage):
|
|
|
451
485
|
if len(prm_names) < 2:
|
|
452
486
|
ret[ax2_hidden] = True
|
|
453
487
|
|
|
488
|
+
# ===== hide dropdown of axis 2 if not is_3d =====
|
|
489
|
+
if not is_3d:
|
|
490
|
+
ret[ax2_hidden] = True
|
|
491
|
+
|
|
454
492
|
# ===== update dropdown label =====
|
|
455
493
|
|
|
456
494
|
# by callback chain on loaded after setup_dropdown_and_sliders()
|
|
@@ -466,29 +504,43 @@ class PredictionModelGraph(AbstractPage):
|
|
|
466
504
|
|
|
467
505
|
# ax1
|
|
468
506
|
if callback_context.triggered_id['type'] == 'axis1-dropdown-menu-item':
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
507
|
+
ret[ax1_label_key] = new_label
|
|
508
|
+
if new_label == current_ax2_label:
|
|
509
|
+
ret[ax2_label_key] = current_ax1_label
|
|
510
|
+
|
|
473
511
|
|
|
474
512
|
# ax2
|
|
475
513
|
elif callback_context.triggered_id['type'] == 'axis2-dropdown-menu-item':
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
514
|
+
ret[ax2_label_key] = new_label
|
|
515
|
+
if new_label == current_ax1_label:
|
|
516
|
+
ret[ax1_label_key] = current_ax2_label
|
|
517
|
+
|
|
480
518
|
|
|
481
519
|
# ax3
|
|
482
520
|
elif callback_context.triggered_id['type'] == 'axis3-dropdown-menu-item':
|
|
483
521
|
ret[ax3_label_key] = new_label
|
|
484
522
|
|
|
485
523
|
# ===== update visibility of sliders =====
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
524
|
+
|
|
525
|
+
# invisible the slider correspond to the dropdown-1
|
|
526
|
+
label_key, current_label = ax1_label_key, current_ax1_label
|
|
527
|
+
# get label of output
|
|
528
|
+
label = ret[label_key] if ret[label_key] != no_update else current_label
|
|
529
|
+
# update display style of slider
|
|
530
|
+
idx = prm_names.index(label) if label in prm_names else None
|
|
531
|
+
if idx is not None:
|
|
532
|
+
current_styles[idx].update({'display': 'none'})
|
|
533
|
+
ret[slider_style_list_key][idx] = current_styles[idx]
|
|
534
|
+
|
|
535
|
+
# invisible the slider correspond to the dropdown-2
|
|
536
|
+
label_key, current_label = ax2_label_key, current_ax2_label
|
|
537
|
+
# get label of output
|
|
538
|
+
label = ret[label_key] if ret[label_key] != no_update else current_label
|
|
539
|
+
# update display style of slider
|
|
540
|
+
idx = prm_names.index(label) if label in prm_names else None
|
|
541
|
+
if idx is not None:
|
|
542
|
+
# if 2d, should not disable the slider correspond to dropdown-2.
|
|
543
|
+
if is_3d:
|
|
492
544
|
current_styles[idx].update({'display': 'none'})
|
|
493
545
|
ret[slider_style_list_key][idx] = current_styles[idx]
|
|
494
546
|
|
|
@@ -554,5 +606,5 @@ class PredictionModelGraph(AbstractPage):
|
|
|
554
606
|
if isinstance(self.application, ProcessMonitorApplication):
|
|
555
607
|
df = self.application.local_data
|
|
556
608
|
else:
|
|
557
|
-
df = self.application.history.
|
|
609
|
+
df = self.application.history.get_df()
|
|
558
610
|
return df
|
|
@@ -5,7 +5,7 @@ from threading import Thread
|
|
|
5
5
|
import pandas as pd
|
|
6
6
|
|
|
7
7
|
from pyfemtet.opt.visualization.base import PyFemtetApplicationBase, logger
|
|
8
|
-
from pyfemtet.opt.visualization.process_monitor.pages import HomePage, WorkerPage, PredictionModelPage
|
|
8
|
+
from pyfemtet.opt.visualization.process_monitor.pages import HomePage, WorkerPage, PredictionModelPage, OptunaVisualizerPage
|
|
9
9
|
from pyfemtet.message import Msg
|
|
10
10
|
|
|
11
11
|
|
|
@@ -67,14 +67,14 @@ class ProcessMonitorApplication(PyFemtetApplicationBase):
|
|
|
67
67
|
if self._should_get_actor_data:
|
|
68
68
|
return self._df
|
|
69
69
|
else:
|
|
70
|
-
return self.history.
|
|
70
|
+
return self.history.get_df()
|
|
71
71
|
|
|
72
72
|
@local_data.setter
|
|
73
73
|
def local_data(self, value: pd.DataFrame):
|
|
74
74
|
if self._should_get_actor_data:
|
|
75
75
|
raise NotImplementedError('If should_get_actor_data, ProcessMonitorApplication.local_df is read_only.')
|
|
76
76
|
else:
|
|
77
|
-
self.history.
|
|
77
|
+
self.history.set_df(value)
|
|
78
78
|
|
|
79
79
|
def setup_callback(self, debug=False):
|
|
80
80
|
if not debug:
|
|
@@ -112,7 +112,7 @@ class ProcessMonitorApplication(PyFemtetApplicationBase):
|
|
|
112
112
|
worker_status.set(OptimizationStatus.INTERRUPTING)
|
|
113
113
|
|
|
114
114
|
# status と df を actor から application に反映する
|
|
115
|
-
self._df = self.history.
|
|
115
|
+
self._df = self.history.get_df().copy()
|
|
116
116
|
self.local_entire_status_int = self.entire_status.get()
|
|
117
117
|
self.local_worker_status_int_list = [s.get() for s in self.worker_status_list]
|
|
118
118
|
|
|
@@ -176,11 +176,13 @@ def g_debug():
|
|
|
176
176
|
|
|
177
177
|
g_home_page = HomePage(Msg.PAGE_TITLE_PROGRESS)
|
|
178
178
|
g_rsm_page = PredictionModelPage(Msg.PAGE_TITLE_PREDICTION_MODEL, '/prediction-model', g_application)
|
|
179
|
+
g_optuna = OptunaVisualizerPage(Msg.PAGE_TITLE_OPTUNA_VISUALIZATION, '/optuna', g_application)
|
|
179
180
|
g_worker_page = WorkerPage(Msg.PAGE_TITLE_WORKERS, '/workers', g_application)
|
|
180
181
|
|
|
181
182
|
g_application.add_page(g_home_page, 0)
|
|
182
183
|
g_application.add_page(g_rsm_page, 1)
|
|
183
|
-
g_application.add_page(
|
|
184
|
+
g_application.add_page(g_optuna, 2)
|
|
185
|
+
g_application.add_page(g_worker_page, 3)
|
|
184
186
|
g_application.setup_callback(debug=False)
|
|
185
187
|
|
|
186
188
|
g_application.run(debug=False)
|
|
@@ -191,11 +193,13 @@ def main(history, status, worker_addresses, worker_status_list, host=None, port=
|
|
|
191
193
|
|
|
192
194
|
g_home_page = HomePage(Msg.PAGE_TITLE_PROGRESS)
|
|
193
195
|
g_rsm_page = PredictionModelPage(Msg.PAGE_TITLE_PREDICTION_MODEL, '/prediction-model', g_application)
|
|
196
|
+
g_optuna = OptunaVisualizerPage(Msg.PAGE_TITLE_OPTUNA_VISUALIZATION, '/optuna', g_application)
|
|
194
197
|
g_worker_page = WorkerPage(Msg.PAGE_TITLE_WORKERS, '/workers', g_application)
|
|
195
198
|
|
|
196
199
|
g_application.add_page(g_home_page, 0)
|
|
197
200
|
g_application.add_page(g_rsm_page, 1)
|
|
198
|
-
g_application.add_page(
|
|
201
|
+
g_application.add_page(g_optuna, 2)
|
|
202
|
+
g_application.add_page(g_worker_page, 3)
|
|
199
203
|
g_application.setup_callback()
|
|
200
204
|
|
|
201
205
|
g_application.start_server(host, port)
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
import pandas as pd
|
|
3
3
|
|
|
4
|
+
import optuna
|
|
5
|
+
|
|
4
6
|
from dash import Output, Input, State, callback_context, no_update, ALL
|
|
5
7
|
from dash.exceptions import PreventUpdate
|
|
6
8
|
|
|
@@ -289,3 +291,103 @@ class PredictionModelPage(AbstractPage):
|
|
|
289
291
|
|
|
290
292
|
def setup_layout(self):
|
|
291
293
|
self.layout = self.rsm_graph.layout
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
class OptunaVisualizerPage(AbstractPage):
    """Page showing optuna's built-in visualizations of the optimization history.

    The figures are built on page load (``dcc.Location`` pathname callback)
    from an optuna study recreated from the history. Each plot type is drawn
    once per objective.
    """

    def __init__(self, title, rel_url, application):
        from pyfemtet.opt.visualization.process_monitor.application import ProcessMonitorApplication
        self.application: ProcessMonitorApplication = None
        super().__init__(title, rel_url, application)

    def setup_component(self):
        # The location fires the page-load callback; the Div is replaced by the graphs.
        self.location = dcc.Location(id='optuna-page-location', refresh=True)
        self._layout = html.Div(children=[Msg.DETAIL_PAGE_TEXT_BEFORE_LOADING])
        self.layout = [self.location, self._layout]

    @staticmethod
    def _plot_for_each_objective(plot_fun, study, obj_names, height):
        """Return one ``dcc.Graph`` per objective, drawn with the optuna ``plot_fun``."""
        graphs = []
        for i, obj_name in enumerate(obj_names):
            fig = plot_fun(
                study,
                # Bind ``i`` now so the target never depends on loop state.
                target=lambda t, i=i: t.values[i],
                target_name=obj_name,
            )
            graphs.append(dcc.Graph(figure=fig, style={'height': height}))
        return graphs

    def _setup_layout(self):
        """Build the page content: a header, a description and per-objective graphs for each plot type."""
        study = self.application.history.create_optuna_study()
        obj_names = self.application.history.obj_names

        sections = [
            (Msg.DETAIL_PAGE_HISTORY_HEADER, Msg.DETAIL_PAGE_HISTORY_DESCRIPTION,
             optuna.visualization.plot_optimization_history, '70vh'),
            (Msg.DETAIL_PAGE_PARALLEL_COOR_HEADER, Msg.DETAIL_PAGE_PARALLEL_COOR_DESCRIPTION,
             optuna.visualization.plot_parallel_coordinate, '70vh'),
            (Msg.DETAIL_PAGE_CONTOUR_HEADER, Msg.DETAIL_PAGE_CONTOUR_DESCRIPTION,
             optuna.visualization.plot_contour, '90vh'),
            (Msg.DETAIL_PAGE_SLICE_HEADER, Msg.DETAIL_PAGE_SLICE_DESCRIPTION,
             optuna.visualization.plot_slice, '70vh'),
        ]

        layout = []
        for header, description, plot_fun, height in sections:
            layout.append(html.H2(header))
            layout.append(html.H4(description))
            layout.extend(self._plot_for_each_objective(plot_fun, study, obj_names, height))

        # TODO: pareto-front plots for each pair of objectives
        # (optuna.visualization.plot_pareto_front) are not shown yet.
        return layout

    def setup_callback(self):
        app = self.application.app

        @app.callback(
            Output(self._layout, 'children'),
            Input(self.location, 'pathname'),  # on page load
        )
        def update_page(_):
            # Rebuild the graphs each time the page is loaded.
            if self.application.history is None:
                return Msg.ERR_NO_HISTORY_SELECTED

            if len(self.data_accessor()) == 0:
                return Msg.ERR_NO_FEM_RESULT

            return self._setup_layout()

    def setup_layout(self):
        pass

    def data_accessor(self) -> pd.DataFrame:
        """Return the history DataFrame the page should display."""
        from pyfemtet.opt.visualization.process_monitor.application import ProcessMonitorApplication
        if isinstance(self.application, ProcessMonitorApplication):
            df = self.application.local_data
        else:
            # The history exposes its DataFrame through get_df().
            df = self.application.history.get_df()
        return df
|