pyfemtet 0.4.21__py3-none-any.whl → 0.4.23__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to the public registry. It is provided for informational purposes only.

This version of pyfemtet has been flagged as a potentially problematic release.

Files changed (54)
  1. pyfemtet/__init__.py +1 -1
  2. pyfemtet/_test_util.py +0 -2
  3. pyfemtet/message/messages.py +15 -1
  4. pyfemtet/opt/_femopt.py +233 -199
  5. pyfemtet/opt/_femopt_core.py +116 -47
  6. pyfemtet/opt/femprj_sample/ParametricIF.py +0 -2
  7. pyfemtet/opt/femprj_sample/cad_ex01_NX.py +0 -8
  8. pyfemtet/opt/femprj_sample/cad_ex01_SW.py +0 -8
  9. pyfemtet/opt/femprj_sample/gal_ex58_parametric.py +0 -8
  10. pyfemtet/opt/femprj_sample/gau_ex08_parametric.py +0 -8
  11. pyfemtet/opt/femprj_sample/her_ex40_parametric.py +0 -8
  12. pyfemtet/opt/femprj_sample/paswat_ex1_parametric.py +0 -8
  13. pyfemtet/opt/femprj_sample/paswat_ex1_parametric_parallel.py +0 -8
  14. pyfemtet/opt/femprj_sample/wat_ex14_parametric.py +0 -8
  15. pyfemtet/opt/femprj_sample/wat_ex14_parametric_parallel.py +0 -8
  16. pyfemtet/opt/femprj_sample_jp/ParametricIF_jp.py +0 -2
  17. pyfemtet/opt/femprj_sample_jp/cad_ex01_NX_jp.py +0 -8
  18. pyfemtet/opt/femprj_sample_jp/cad_ex01_SW_jp.py +0 -8
  19. pyfemtet/opt/femprj_sample_jp/gal_ex58_parametric_jp.py +0 -8
  20. pyfemtet/opt/femprj_sample_jp/gau_ex08_parametric_jp.py +0 -8
  21. pyfemtet/opt/femprj_sample_jp/her_ex40_parametric_jp.py +0 -8
  22. pyfemtet/opt/femprj_sample_jp/paswat_ex1_parametric_jp.py +0 -8
  23. pyfemtet/opt/femprj_sample_jp/paswat_ex1_parametric_parallel_jp.py +0 -8
  24. pyfemtet/opt/femprj_sample_jp/wat_ex14_parametric_jp.py +0 -8
  25. pyfemtet/opt/femprj_sample_jp/wat_ex14_parametric_parallel_jp.py +0 -8
  26. pyfemtet/opt/opt/_base.py +4 -4
  27. pyfemtet/opt/opt/_optuna.py +33 -1
  28. pyfemtet/opt/opt/_optuna_botorch_helper.py +209 -0
  29. pyfemtet/opt/visualization/complex_components/main_graph.py +22 -5
  30. pyfemtet/opt/visualization/complex_components/pm_graph.py +77 -25
  31. pyfemtet/opt/visualization/complex_components/pm_graph_creator.py +7 -0
  32. pyfemtet/opt/visualization/process_monitor/application.py +10 -6
  33. pyfemtet/opt/visualization/process_monitor/pages.py +102 -0
  34. pyfemtet/opt/visualization/result_viewer/application.py +6 -0
  35. pyfemtet/opt/visualization/result_viewer/pages.py +1 -1
  36. {pyfemtet-0.4.21.dist-info → pyfemtet-0.4.23.dist-info}/METADATA +2 -4
  37. {pyfemtet-0.4.21.dist-info → pyfemtet-0.4.23.dist-info}/RECORD +40 -53
  38. pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01.femprj +0 -0
  39. pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01.prt +0 -0
  40. pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01.py +0 -118
  41. pyfemtet/FemtetPJTSample/Sldworks_ex01/Sldworks_ex01.SLDPRT +0 -0
  42. pyfemtet/FemtetPJTSample/Sldworks_ex01/Sldworks_ex01.femprj +0 -0
  43. pyfemtet/FemtetPJTSample/Sldworks_ex01/Sldworks_ex01.py +0 -121
  44. pyfemtet/FemtetPJTSample/_her_ex40_parametric.py +0 -148
  45. pyfemtet/FemtetPJTSample/gau_ex08_parametric.femprj +0 -0
  46. pyfemtet/FemtetPJTSample/gau_ex08_parametric.py +0 -58
  47. pyfemtet/FemtetPJTSample/her_ex40_parametric.femprj +0 -0
  48. pyfemtet/FemtetPJTSample/her_ex40_parametric.py +0 -148
  49. pyfemtet/FemtetPJTSample/wat_ex14_parallel_parametric.py +0 -65
  50. pyfemtet/FemtetPJTSample/wat_ex14_parametric.femprj +0 -0
  51. pyfemtet/FemtetPJTSample/wat_ex14_parametric.py +0 -64
  52. {pyfemtet-0.4.21.dist-info → pyfemtet-0.4.23.dist-info}/LICENSE +0 -0
  53. {pyfemtet-0.4.21.dist-info → pyfemtet-0.4.23.dist-info}/WHEEL +0 -0
  54. {pyfemtet-0.4.21.dist-info → pyfemtet-0.4.23.dist-info}/entry_points.txt +0 -0
pyfemtet/opt/opt/_optuna_botorch_helper.py
@@ -0,0 +1,209 @@
+ from typing import Optional, List, Tuple, Callable
+ from functools import partial
+ import inspect
+
+ import numpy as np
+ import optuna.study
+ import torch
+ from torch import Tensor
+ from botorch.optim.initializers import gen_batch_initial_conditions
+ from botorch.utils.transforms import unnormalize
+ from optuna.study import Study
+ from botorch.acquisition import AcquisitionFunction
+
+ from pyfemtet.opt.opt import AbstractOptimizer
+ from pyfemtet.opt.parameter import ExpressionEvaluator
+
+ # module to monkey patch
+ import optuna_integration
+
+
+ # Defines the arguments of optimize_acqf in a partial subclass so that optuna cannot overwrite them once the monkey patch is applied.
+ class NonOverwritablePartial(partial):
+     def __call__(self, /, *args, **keywords):
+         stored_kwargs = self.keywords
+         keywords.update(stored_kwargs)
+         return self.func(*self.args, *args, **keywords)
+
+
+ # Converts a function that takes parameter names as arguments into the form accepted by optimize_acqf's nonlinear_inequality_constraints.
+ class ConvertedConstraintFunction:
+     def __init__(self, fun, prm_args, kwargs, variables: ExpressionEvaluator, study: optuna.study.Study):
+         self.fun = fun
+         self.prm_args = prm_args
+         self.kwargs = kwargs
+         self.variables = variables
+         self.study = study
+
+         self.bounds = None
+         self.prm_name_seq = None
+
+         # If the arguments to use as parameters are not specified, take them from fun's signature
+         if self.prm_args is None:
+             signature = inspect.signature(fun)
+             prm_inputs = set([a.name for a in signature.parameters.values()])
+         else:
+             prm_inputs = set(self.prm_args)
+
+         # Remove the kwargs keys from the set of arguments
+         self.prm_arg_names = prm_inputs - set(kwargs.keys())
+
+         # Make sure no unexpected arguments remain
+         assert all([(arg in variables.get_parameter_names()) for arg in self.prm_arg_names])
+
+     def __call__(self, x: Tensor or np.ndarray):
+         # x: all of the normalized parameters, ordered as sorted by optuna
+
+         if not isinstance(x, Tensor):
+             x = torch.tensor(np.array(x)).double()
+
+         x = unnormalize(x, self.bounds)
+
+         # Collect values only for the parameters used by fun
+         kwargs = self.kwargs
+         kwargs.update(
+             {k: v for k, v in zip(self.prm_name_seq, x) if k in self.prm_arg_names}
+         )
+
+         return self.fun(**kwargs)
+
+
+ # Acquisition function wrapped so that it returns 0 when the given constraints are not satisfied
+ class AcqWithConstraint(AcquisitionFunction):
+
+     # noinspection PyAttributeOutsideInit
+     def set(self, _org_acq_function: AcquisitionFunction, nonlinear_constraints):
+         self._org_acq_function = _org_acq_function
+         self._nonlinear_constraints = nonlinear_constraints
+
+     def forward(self, X: Tensor) -> Tensor:
+         base = self._org_acq_function.forward(X)
+
+         is_feasible = all([cons(X[0][0]) > 0 for cons, _ in self._nonlinear_constraints])
+         if is_feasible:
+             return base
+         else:
+             # penalty = torch.Tensor(size=base.shape)
+             # penalty = torch.fill(penalty, -1e10)
+             # return base * penalty
+             return base * 0.
+
+
+ def remove_infeasible(_ic_batch, nonlinear_constraints):
+     # Remove infeasible candidates
+     remove_indices = []
+     for i, ic in enumerate(_ic_batch):  # ic: 1 x len(params) tensor
+         # cons: Callable[["Tensor"], "Tensor"]
+         is_feasible = all([cons(ic[0]) > 0 for cons, _ in nonlinear_constraints])
+         if not is_feasible:
+             # ic_batch[i] = torch.nan  # this does not make the row be ignored
+             remove_indices.append(i)
+     for i in remove_indices[::-1]:
+         _ic_batch = torch.cat((_ic_batch[:i], _ic_batch[i + 1:]))
+     return _ic_batch
+
+
+ class OptunaBotorchWithParameterConstraintMonkeyPatch:
+
+     def __init__(self, study: Study, opt: AbstractOptimizer):
+         self.num_restarts: int = 20
+         self.raw_samples_additional: int = 512
+         self.eta: float = 2.0
+         self.study = study
+         self.opt = opt
+         self.nonlinear_inequality_constraints = []
+         self.additional_kwargs = dict()
+         self.bounds = None
+         self.prm_name_seq = None
+
+     def add_nonlinear_constraint(self, fun, prm_args, kwargs):
+         f = ConvertedConstraintFunction(
+             fun,
+             prm_args,
+             kwargs,
+             self.opt.variables,
+             self.study,
+         )
+
+         # Initialize
+         self.nonlinear_inequality_constraints = self.nonlinear_inequality_constraints or []
+
+         # Register on self
+         self.nonlinear_inequality_constraints.append((f, True))
+
+         # Add to the arguments passed to optimize_acqf()
+         self.additional_kwargs.update(
+             nonlinear_inequality_constraints=self.nonlinear_inequality_constraints
+         )
+
+     def _detect_prm_seq_if_needed(self):
+         # Reconstruct the distribution information from the study.
+         if self.bounds is None or self.prm_name_seq is None:
+             from optuna._transform import _transform_search_space
+             # This is called after sample_relative, so the last trial should have a search_space
+             search_space: dict = self.study.sampler.infer_relative_search_space(self.study, self.study.trials[-1])
+             self.bounds = _transform_search_space(search_space, False, False)[0].T
+             self.prm_name_seq = list(search_space.keys())
+
+             for cns in self.nonlinear_inequality_constraints:
+                 cns[0].bounds = torch.tensor(self.bounds)
+                 cns[0].prm_name_seq = self.prm_name_seq
+
+     def generate_initial_conditions(self, *args, **kwargs):
+         self._detect_prm_seq_if_needed()
+
+         # Wrap acq_function so that it returns 0 whenever the constraints are violated
+         org_acq_function = kwargs['acq_function']
+         new_acqf = AcqWithConstraint(None)
+         new_acqf.set(org_acq_function, self.nonlinear_inequality_constraints)
+         kwargs['acq_function'] = new_acqf
+
+         # Generate the batch of proposed initial conditions
+         # ic: `num_restarts x q x d` tensor of initial conditions.
+         # q = 1, d = len(params)
+         ic_batch = gen_batch_initial_conditions(*args, **kwargs)
+
+         # Remove candidates that violate the constraints
+         ic_batch = remove_infeasible(ic_batch, self.nonlinear_inequality_constraints)
+
+         # If none are left, generate candidates randomly
+         if len(ic_batch) == 0:
+             print('No combination satisfied the constraints; falling back to random sampling')
+             while len(ic_batch) == 0:
+                 size = ic_batch.shape
+                 ic_batch = torch.rand(size=[100, *size[1:]])  # combinations of normalized variables
+                 ic_batch = remove_infeasible(ic_batch, self.nonlinear_inequality_constraints)
+
+         return ic_batch
+
+     def do_monkey_patch(self):
+         """optuna_integration.botorch provides no way to pass constraints to optimize_acqf, so monkey patch it and pass them that way.
+
+         The monkey patch works no matter when it is applied before the optimization starts,
+         but it must be applied after additional_kwargs has been updated, so this processing
+         is not placed in the constructor. Calling it from each add_constraint might be a good idea.
+         """
+
+         # === reconstruct argument ``options`` for optimize_acqf ===
+         options = dict()  # initialize
+
+         # for nonlinear-constraint
+         options.update(dict(batch_limit=1))
+
+         # for gen_candidates_scipy()
+         # use COBYLA or SLSQP only.
+         options.update(dict(method='SLSQP'))
+
+         # make partial of optimize_acqf used in optuna_integration.botorch and replace to it.
+         original_fun = optuna_integration.botorch.optimize_acqf
+         overwritten_fun = NonOverwritablePartial(
+             original_fun,
+             q=1,  # for nonlinear constraints
+             options=options,
+             num_restarts=20,  # should go to gen_batch_initial_conditions; passed on to self.generate_initial_conditions.
+             raw_samples=512,  # should go to gen_batch_initial_conditions; passed on to self.generate_initial_conditions.
+             nonlinear_inequality_constraints=self.nonlinear_inequality_constraints,
+             ic_generator=self.generate_initial_conditions,
+         )
+         optuna_integration.botorch.optimize_acqf = overwritten_fun
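The load-bearing piece of this new helper is NonOverwritablePartial: optuna_integration calls optimize_acqf with its own keyword arguments, and a plain functools.partial would let those overwrite the pre-bound constraint-related ones. A minimal, self-contained sketch of the difference (the optimize() stub below is illustrative only, not the real optimize_acqf):

from functools import partial


class NonOverwritablePartial(partial):
    """Mirrors the class above: keyword arguments bound at construction win at call time."""
    def __call__(self, /, *args, **keywords):
        keywords.update(self.keywords)  # stored kwargs overwrite the caller's kwargs
        return self.func(*self.args, *args, **keywords)


def optimize(q=4, options=None):
    # Stand-in for optuna_integration.botorch.optimize_acqf.
    return q, options


fixed = NonOverwritablePartial(optimize, q=1, options={'batch_limit': 1})
print(fixed(q=8))                    # (1, {'batch_limit': 1}) -- the caller cannot override q
print(partial(optimize, q=1)(q=8))   # (8, None) -- a plain partial lets the caller win

In do_monkey_patch() the same trick pins q=1, the SLSQP options, the registered nonlinear_inequality_constraints and the custom ic_generator, so whatever keyword arguments optuna_integration later supplies cannot undo the constraint handling.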
pyfemtet/opt/visualization/complex_components/main_graph.py
@@ -184,16 +184,18 @@ class MainGraph(AbstractPage):
              # create component
              title_component = html.H3(f"trial{trial}", style={"color": "darkblue"})
              img_component = self.create_image_content_if_femtet(trial)
-             tbl_component = self.create_formatted_parameter(row)
+             tbl_component_prm = self.create_formatted_parameter(row)
+             tbl_component_obj = self.create_formatted_objective(row)

              # create layout
              description = html.Div([
                  title_component,
-                 tbl_component,
+                 tbl_component_prm,
              ])
              tooltip_layout = html.Div([
                  html.Div(img_component, style={'display': 'inline-block', 'margin-right': '10px', 'vertical-align': 'top'}),
-                 html.Div(description, style={'display': 'inline-block', 'margin-right': '10px'})
+                 html.Div(description, style={'display': 'inline-block', 'margin-right': '10px'}),
+                 html.Div(tbl_component_obj, style={'display': 'inline-block', 'margin-right': '10px'}),
              ])

              return True, bbox, tooltip_layout
@@ -205,7 +207,22 @@ class MainGraph(AbstractPage):
          names = parameters.columns
          values = [f'{value:.3e}' for value in parameters.values.ravel()]
          data = pd.DataFrame(dict(
-             name=names, value=values
+             parameter=names, value=values
+         ))
+         table = dash_table.DataTable(
+             columns=[{'name': col, 'id': col} for col in data.columns],
+             data=data.to_dict('records')
+         )
+         return table
+
+     def create_formatted_objective(self, row) -> Component:
+         metadata = self.application.history.metadata
+         pd.options.display.float_format = '{:.4e}'.format
+         objectives = row.iloc[:, np.where(np.array(metadata) == 'obj')[0]]
+         names = objectives.columns
+         values = [f'{value:.3e}' for value in objectives.values.ravel()]
+         data = pd.DataFrame(dict(
+             objective=names, value=values
          ))
          table = dash_table.DataTable(
              columns=[{'name': col, 'id': col} for col in data.columns],
@@ -258,5 +275,5 @@ class MainGraph(AbstractPage):
          if isinstance(self.application, ProcessMonitorApplication):
              df = self.application.local_data
          else:
-             df = self.application.history.local_data
+             df = self.application.history.get_df()
          return df
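The new create_formatted_objective mirrors create_formatted_parameter, selecting the objective columns of a history row by position wherever the history metadata equals 'obj'. A small standalone illustration of that selection (the column names and the metadata entries other than 'obj' are made up for the example):

import numpy as np
import pandas as pd

# One history row with mixed columns; the metadata list marks the role of each column.
row = pd.DataFrame([[1, 0.02, 35.0, 1.2e6]], columns=['trial', 'width', 'temp', 'stress'])
metadata = ['', 'prm', 'obj', 'obj']  # illustrative values only

objectives = row.iloc[:, np.where(np.array(metadata) == 'obj')[0]]
values = [f'{value:.3e}' for value in objectives.values.ravel()]
print(pd.DataFrame(dict(objective=objectives.columns, value=values)))
# objective column: temp, stress; value column: 3.500e+01, 1.200e+06

The resulting two-column frame is what the tooltip renders as a dash_table.DataTable next to the existing parameter table.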
pyfemtet/opt/visualization/complex_components/pm_graph.py
@@ -129,6 +129,19 @@ class PredictionModelGraph(AbstractPage):
          self.slider_stack_data = html.Data(**{self.slider_stack_data_prop: {}})
          self.slider_container = html.Div()

+         # 2d or 3d
+         self.switch_3d = dbc.Checklist(
+             options=[
+                 dict(
+                     label=Msg.LABEL_SWITCH_PREDICTION_MODEL_3D,
+                     disabled=False,
+                     value=False,
+                 )
+             ],
+             switch=True,
+             value=[],
+         )
+
      def setup_layout(self):
          self.card_header = dbc.CardHeader(self.tabs)
@@ -151,6 +164,7 @@ class PredictionModelGraph(AbstractPage):
                  self.command_manager
              ],
                  direction='horizontal', gap=2),
+             self.switch_3d,
              *dropdown_rows,
              self.slider_container,
              self.slider_stack_data,
@@ -193,13 +207,15 @@ class PredictionModelGraph(AbstractPage):
              Output(self.redraw_graph_button_spinner, 'spinner_style', allow_duplicate=True),
              Output(self.redraw_graph_button, 'disabled', allow_duplicate=True),
              Output(self.command_manager, self.command_manager_prop, allow_duplicate=True),
+             Output(self.switch_3d, 'options', allow_duplicate=True),
              Input(self.fit_rsm_button, 'n_clicks'),
              Input(self.redraw_graph_button, 'n_clicks'),
              State(self.fit_rsm_button_spinner, 'spinner_style'),
              State(self.redraw_graph_button_spinner, 'spinner_style'),
+             State(self.switch_3d, 'options'),
              prevent_initial_call=True,
          )
-         def disable_fit_button(_1, _2, state1, state2):
+         def disable_fit_button(_1, _2, state1, state2, switch_options):
              # spinner visibility
              if 'display' in state1.keys(): state1.pop('display')
              if 'display' in state2.keys(): state2.pop('display')
@@ -210,7 +226,12 @@ class PredictionModelGraph(AbstractPage):
              else:
                  command = self.CommandState.redraw.value

-             return state1, True, state2, True, command
+             # disable switch
+             option = switch_options[0]
+             option.update({'disabled': True})
+             switch_options[0] = option
+
+             return state1, True, state2, True, command, switch_options

          # ===== recreate RSM =====
          @app.callback(
@@ -256,9 +277,10 @@ class PredictionModelGraph(AbstractPage):
              State(self.axis3_obj_dropdown, 'label'),
              State(self.slider_container, 'children'),  # for callback chain
              State({'type': 'prm-slider', 'index': ALL}, 'value'),
+             State(self.switch_3d, 'value'),
              prevent_initial_call=True,
          )
-         def redraw_graph(command, active_tab_id, axis1_label, axis2_label, axis3_label, _2, prm_values):
+         def redraw_graph(command, active_tab_id, axis1_label, axis2_label, axis3_label, _2, prm_values, is_3d):
              # just in case
              if callback_context.triggered_id is None:
                  raise PreventUpdate
@@ -283,6 +305,9 @@
                  logger.error(Msg.ERR_NO_PREDICTION_MODEL)
                  return no_update, self.CommandState.ready.value  # to re-enable buttons, fire callback chain

+             if not is_3d:
+                 axis2_label = None
+
              # get indices to remove
              idx1 = prm_names.index(axis1_label) if axis1_label in prm_names else None
              idx2 = prm_names.index(axis2_label) if axis2_label in prm_names else None
@@ -307,23 +332,31 @@ class PredictionModelGraph(AbstractPage):

              return fig, self.CommandState.ready.value

-         # ===== When the graph is updated, enable buttons =====
+         # ===== re-enable buttons when the graph is updated =====
          @app.callback(
              Output(self.fit_rsm_button, 'disabled', allow_duplicate=True),
              Output(self.fit_rsm_button_spinner, 'spinner_style', allow_duplicate=True),
              Output(self.redraw_graph_button, 'disabled', allow_duplicate=True),
              Output(self.redraw_graph_button_spinner, 'spinner_style', allow_duplicate=True),
+             Output(self.switch_3d, 'options', allow_duplicate=True),
              Input(self.command_manager, self.command_manager_prop),
              State(self.fit_rsm_button_spinner, 'spinner_style'),
              State(self.redraw_graph_button_spinner, 'spinner_style'),
+             State(self.switch_3d, 'options'),
              prevent_initial_call=True,
          )
-         def enable_buttons(command, state1, state2):
+         def enable_buttons(command, state1, state2, switch_options):
              if command != self.CommandState.ready.value:
                  raise PreventUpdate
              state1.update({'display': 'none'})
              state2.update({'display': 'none'})
-             return False, state1, False, state2
+
+             # enable switch
+             option = switch_options[0]
+             option.update({'disabled': False})
+             switch_options[0] = option
+
+             return False, state1, False, state2, switch_options

          # ===== setup dropdown and sliders from history =====
          @app.callback(
@@ -414,6 +447,7 @@ class PredictionModelGraph(AbstractPage):
              Input({'type': 'axis2-dropdown-menu-item', 'index': ALL}, 'n_clicks'),
              Input({'type': 'axis3-dropdown-menu-item', 'index': ALL}, 'n_clicks'),
              Input(self.axis1_prm_dropdown, 'children'),  # for callback chain timing
+             Input(self.switch_3d, 'value'),
              State(self.axis1_prm_dropdown, 'label'),
              State(self.axis2_prm_dropdown, 'label'),
              State(self.axis3_obj_dropdown, 'label'),
@@ -422,10 +456,10 @@ class PredictionModelGraph(AbstractPage):
          )
          def update_controller(*args):
              # argument processing
-             current_ax1_label = args[4]
-             current_ax2_label = args[5]
-             # current_ax3_label = args[6]
-             current_styles: list[dict] = args[7]
+             current_ax1_label = args[5]
+             current_ax2_label = args[6]
+             current_styles: list[dict] = args[8]
+             is_3d = args[4]

              # just in case
              if callback_context.triggered_id is None:
@@ -451,6 +485,10 @@ class PredictionModelGraph(AbstractPage):
              if len(prm_names) < 2:
                  ret[ax2_hidden] = True

+             # ===== hide dropdown of axis 2 if not is_3d =====
+             if not is_3d:
+                 ret[ax2_hidden] = True
+
              # ===== update dropdown label =====

              # by callback chain on loaded after setup_dropdown_and_sliders()
@@ -466,29 +504,43 @@ class PredictionModelGraph(AbstractPage):

              # ax1
              if callback_context.triggered_id['type'] == 'axis1-dropdown-menu-item':
-                 if new_label != current_ax2_label:
-                     ret[ax1_label_key] = new_label
-                 else:
-                     logger.error(Msg.ERR_CANNOT_SELECT_SAME_PARAMETER)
+                 ret[ax1_label_key] = new_label
+                 if new_label == current_ax2_label:
+                     ret[ax2_label_key] = current_ax1_label
+

              # ax2
              elif callback_context.triggered_id['type'] == 'axis2-dropdown-menu-item':
-                 if new_label != current_ax1_label:
-                     ret[ax2_label_key] = new_label
-                 else:
-                     logger.error(Msg.ERR_CANNOT_SELECT_SAME_PARAMETER)
+                 ret[ax2_label_key] = new_label
+                 if new_label == current_ax1_label:
+                     ret[ax1_label_key] = current_ax2_label
+

              # ax3
              elif callback_context.triggered_id['type'] == 'axis3-dropdown-menu-item':
                  ret[ax3_label_key] = new_label

              # ===== update visibility of sliders =====
-             for label_key, current_label in zip((ax1_label_key, ax2_label_key), (current_ax1_label, current_ax2_label)):
-                 # get label of output
-                 label = ret[label_key] if ret[label_key] != no_update else current_label
-                 # update display style of slider
-                 idx = prm_names.index(label) if label in prm_names else None
-                 if idx is not None:
+
+             # hide the slider corresponding to dropdown-1
+             label_key, current_label = ax1_label_key, current_ax1_label
+             # get label of output
+             label = ret[label_key] if ret[label_key] != no_update else current_label
+             # update display style of slider
+             idx = prm_names.index(label) if label in prm_names else None
+             if idx is not None:
+                 current_styles[idx].update({'display': 'none'})
+                 ret[slider_style_list_key][idx] = current_styles[idx]
+
+             # hide the slider corresponding to dropdown-2
+             label_key, current_label = ax2_label_key, current_ax2_label
+             # get label of output
+             label = ret[label_key] if ret[label_key] != no_update else current_label
+             # update display style of slider
+             idx = prm_names.index(label) if label in prm_names else None
+             if idx is not None:
+                 # in 2D, the slider corresponding to dropdown-2 should not be hidden.
+                 if is_3d:
                      current_styles[idx].update({'display': 'none'})
                      ret[slider_style_list_key][idx] = current_styles[idx]

@@ -554,5 +606,5 @@ class PredictionModelGraph(AbstractPage):
          if isinstance(self.application, ProcessMonitorApplication):
              df = self.application.local_data
          else:
-             df = self.application.history.local_data
+             df = self.application.history.get_df()
          return df
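The new 2D/3D toggle above is a one-entry dbc.Checklist rendered as a switch: its value property is an empty list while the switch is off (hence the `if not is_3d` branches), and the callbacks grey it out during a fit or redraw by rewriting the 'disabled' flag inside its options list. A minimal sketch of that pattern outside pyfemtet (the label text is made up):

import dash_bootstrap_components as dbc

switch_3d = dbc.Checklist(
    options=[dict(label='3D plot', disabled=False, value=False)],
    switch=True,
    value=[],  # empty list -> switch is off; non-empty -> on
)


def set_switch_enabled(options, enabled):
    # Same idea as disable_fit_button() / enable_buttons(): mutate the single
    # option dict and return the list through an Output(switch, 'options').
    option = options[0]
    option.update({'disabled': not enabled})
    options[0] = option
    return options


print(set_switch_enabled([dict(label='3D plot', disabled=False, value=False)], enabled=False))
# [{'label': '3D plot', 'disabled': True, 'value': False}]

When the switch is off, redraw_graph() clears axis2_label so the prediction-model plot falls back to a single-parameter (2D) view, and the axis-2 dropdown and its slider remain usable.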
pyfemtet/opt/visualization/complex_components/pm_graph_creator.py
@@ -183,4 +183,11 @@ class PredictionModelCreator:
              )
          )

+         # layout
+         fig.update_layout(
+             title=Msg.GRAPH_TITLE_PREDICTION_MODEL,
+             xaxis_title=prm_name_1,
+             yaxis_title=obj_name,
+         )
+
          return fig
pyfemtet/opt/visualization/process_monitor/application.py
@@ -5,7 +5,7 @@ from threading import Thread
  import pandas as pd

  from pyfemtet.opt.visualization.base import PyFemtetApplicationBase, logger
- from pyfemtet.opt.visualization.process_monitor.pages import HomePage, WorkerPage, PredictionModelPage
+ from pyfemtet.opt.visualization.process_monitor.pages import HomePage, WorkerPage, PredictionModelPage, OptunaVisualizerPage
  from pyfemtet.message import Msg


@@ -67,14 +67,14 @@ class ProcessMonitorApplication(PyFemtetApplicationBase):
          if self._should_get_actor_data:
              return self._df
          else:
-             return self.history.local_data
+             return self.history.get_df()

      @local_data.setter
      def local_data(self, value: pd.DataFrame):
          if self._should_get_actor_data:
              raise NotImplementedError('If should_get_actor_data, ProcessMonitorApplication.local_df is read_only.')
          else:
-             self.history.local_data = value
+             self.history.set_df(value)

      def setup_callback(self, debug=False):
          if not debug:
@@ -112,7 +112,7 @@ class ProcessMonitorApplication(PyFemtetApplicationBase):
                  worker_status.set(OptimizationStatus.INTERRUPTING)

              # reflect status and df from the actor into the application
-             self._df = self.history.actor_data.copy()
+             self._df = self.history.get_df().copy()
              self.local_entire_status_int = self.entire_status.get()
              self.local_worker_status_int_list = [s.get() for s in self.worker_status_list]

@@ -176,11 +176,13 @@ def g_debug():

      g_home_page = HomePage(Msg.PAGE_TITLE_PROGRESS)
      g_rsm_page = PredictionModelPage(Msg.PAGE_TITLE_PREDICTION_MODEL, '/prediction-model', g_application)
+     g_optuna = OptunaVisualizerPage(Msg.PAGE_TITLE_OPTUNA_VISUALIZATION, '/optuna', g_application)
      g_worker_page = WorkerPage(Msg.PAGE_TITLE_WORKERS, '/workers', g_application)

      g_application.add_page(g_home_page, 0)
      g_application.add_page(g_rsm_page, 1)
-     g_application.add_page(g_worker_page, 2)
+     g_application.add_page(g_optuna, 2)
+     g_application.add_page(g_worker_page, 3)
      g_application.setup_callback(debug=False)

      g_application.run(debug=False)
@@ -191,11 +193,13 @@ def main(history, status, worker_addresses, worker_status_list, host=None, port=

      g_home_page = HomePage(Msg.PAGE_TITLE_PROGRESS)
      g_rsm_page = PredictionModelPage(Msg.PAGE_TITLE_PREDICTION_MODEL, '/prediction-model', g_application)
+     g_optuna = OptunaVisualizerPage(Msg.PAGE_TITLE_OPTUNA_VISUALIZATION, '/optuna', g_application)
      g_worker_page = WorkerPage(Msg.PAGE_TITLE_WORKERS, '/workers', g_application)

      g_application.add_page(g_home_page, 0)
      g_application.add_page(g_rsm_page, 1)
-     g_application.add_page(g_worker_page, 2)
+     g_application.add_page(g_optuna, 2)
+     g_application.add_page(g_worker_page, 3)
      g_application.setup_callback()

      g_application.start_server(host, port)
pyfemtet/opt/visualization/process_monitor/pages.py
@@ -1,6 +1,8 @@
  import numpy as np
  import pandas as pd

+ import optuna
+
  from dash import Output, Input, State, callback_context, no_update, ALL
  from dash.exceptions import PreventUpdate

@@ -289,3 +291,103 @@ class PredictionModelPage(AbstractPage):

      def setup_layout(self):
          self.layout = self.rsm_graph.layout
+
+
+ class OptunaVisualizerPage(AbstractPage):
+
+     def __init__(self, title, rel_url, application):
+         from pyfemtet.opt.visualization.process_monitor.application import ProcessMonitorApplication
+         self.application: ProcessMonitorApplication = None
+         super().__init__(title, rel_url, application)
+
+     def setup_component(self):
+         self.location = dcc.Location(id='optuna-page-location', refresh=True)
+         self._layout = html.Div(children=[Msg.DETAIL_PAGE_TEXT_BEFORE_LOADING])
+         self.layout = [self.location, self._layout]
+
+     def _setup_layout(self):
+
+         study = self.application.history.create_optuna_study()
+         prm_names = self.application.history.prm_names
+         obj_names = self.application.history.obj_names
+
+         layout = []
+
+         layout.append(html.H2(Msg.DETAIL_PAGE_HISTORY_HEADER))
+         layout.append(html.H4(Msg.DETAIL_PAGE_HISTORY_DESCRIPTION))
+         for i, obj_name in enumerate(obj_names):
+             fig = optuna.visualization.plot_optimization_history(
+                 study,
+                 target=lambda t: t.values[i],
+                 target_name=obj_name
+             )
+             layout.append(dcc.Graph(figure=fig, style={'height': '70vh'}))
+
+         layout.append(html.H2(Msg.DETAIL_PAGE_PARALLEL_COOR_HEADER))
+         layout.append(html.H4(Msg.DETAIL_PAGE_PARALLEL_COOR_DESCRIPTION))
+         for i, obj_name in enumerate(obj_names):
+             fig = optuna.visualization.plot_parallel_coordinate(
+                 study,
+                 target=lambda t: t.values[i],
+                 target_name=obj_name
+             )
+             layout.append(dcc.Graph(figure=fig, style={'height': '70vh'}))
+
+         layout.append(html.H2(Msg.DETAIL_PAGE_CONTOUR_HEADER))
+         layout.append(html.H4(Msg.DETAIL_PAGE_CONTOUR_DESCRIPTION))
+         for i, obj_name in enumerate(obj_names):
+             fig = optuna.visualization.plot_contour(
+                 study,
+                 target=lambda t: t.values[i],
+                 target_name=obj_name
+             )
+             layout.append(dcc.Graph(figure=fig, style={'height': '90vh'}))
+
+         # import itertools
+         # for (i, j) in itertools.combinations(range(len(obj_names)), 2):
+         #     fig = optuna.visualization.plot_pareto_front(
+         #         study,
+         #         targets=lambda t: (t.values[i], t.values[j]),
+         #         target_names=[obj_names[i], obj_names[j]],
+         #     )
+         #     self.graphs.append(dcc.Graph(figure=fig, style={'height': '50vh'}))
+
+         layout.append(html.H2(Msg.DETAIL_PAGE_SLICE_HEADER))
+         layout.append(html.H4(Msg.DETAIL_PAGE_SLICE_DESCRIPTION))
+         for i, obj_name in enumerate(obj_names):
+             fig = optuna.visualization.plot_slice(
+                 study,
+                 target=lambda t: t.values[i],
+                 target_name=obj_name
+             )
+             layout.append(dcc.Graph(figure=fig, style={'height': '70vh'}))
+
+         return layout
+
+     def setup_callback(self):
+         app = self.application.app
+
+         @app.callback(
+             Output(self._layout, 'children'),
+             Input(self.location, 'pathname'),  # on page load
+         )
+         def update_page(_):
+             if self.application.history is None:
+                 return Msg.ERR_NO_HISTORY_SELECTED
+
+             if len(self.data_accessor()) == 0:
+                 return Msg.ERR_NO_FEM_RESULT
+
+             return self._setup_layout()
+
+     def setup_layout(self):
+         pass
+
+     def data_accessor(self) -> pd.DataFrame:
+         from pyfemtet.opt.visualization.process_monitor.application import ProcessMonitorApplication
+         if isinstance(self.application, ProcessMonitorApplication):
+             df = self.application.local_data
+         else:
+             df = self.application.history.local_data
+         return df
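Each section of the new OptunaVisualizerPage produces one plotly figure per objective by handing `target` and `target_name` to optuna's visualization helpers, as required for multi-objective studies. A standalone sketch of the same pattern on a throwaway two-objective study (the objective names and the study itself are invented for the example; pyfemtet instead reconstructs its study from the optimization history):

import optuna
from optuna.visualization import plot_optimization_history, plot_slice


def objective(trial):
    x = trial.suggest_float('x', -5.0, 5.0)
    y = trial.suggest_float('y', -5.0, 5.0)
    return x ** 2 + y, (x - 2.0) ** 2 - y


study = optuna.create_study(directions=['minimize', 'minimize'])
study.optimize(objective, n_trials=30)

obj_names = ['obj_0', 'obj_1']  # pyfemtet takes these from history.obj_names
figures = []
for i, obj_name in enumerate(obj_names):
    # `target` selects one objective out of the multi-objective trial values;
    # `i=i` pins the index explicitly, although the page above can rely on the
    # lambda being consumed within the same loop iteration.
    figures.append(plot_optimization_history(study, target=lambda t, i=i: t.values[i], target_name=obj_name))
    figures.append(plot_slice(study, target=lambda t, i=i: t.values[i], target_name=obj_name))

figures[0].show()  # in the monitor, each figure is wrapped in dcc.Graph(figure=...) instead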