pyfemtet 1.0.4__py3-none-any.whl → 1.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyfemtet might be problematic. Click here for more details.
- pyfemtet/_i18n/locales/ja/LC_MESSAGES/messages.mo +0 -0
- pyfemtet/_i18n/locales/ja/LC_MESSAGES/messages.po +2 -2
- pyfemtet/opt/femopt.py +9 -0
- pyfemtet/opt/history/_history.py +108 -11
- pyfemtet/opt/optimizer/_base_optimizer.py +50 -8
- pyfemtet/opt/optimizer/optuna_optimizer/_pof_botorch/enable_nonlinear_constraint.py +4 -2
- pyfemtet/opt/optimizer/optuna_optimizer/_pof_botorch/pof_botorch_sampler.py +63 -11
- pyfemtet/opt/prediction/_botorch_utils.py +59 -2
- pyfemtet/opt/prediction/_model.py +2 -2
- pyfemtet/opt/problem/problem.py +9 -0
- pyfemtet/opt/problem/variable_manager/_variable_manager.py +1 -1
- pyfemtet/opt/visualization/history_viewer/_complex_components/detail_graphs.py +551 -0
- pyfemtet/opt/visualization/history_viewer/_detail_page.py +106 -0
- pyfemtet/opt/visualization/history_viewer/_process_monitor/_application.py +3 -2
- pyfemtet/opt/visualization/history_viewer/result_viewer/_application.py +3 -2
- pyfemtet/opt/visualization/history_viewer/result_viewer/_pages.py +1 -0
- pyfemtet/opt/visualization/plotter/contour_creator.py +105 -0
- pyfemtet/opt/visualization/plotter/parallel_plot_creator.py +33 -0
- pyfemtet/opt/visualization/plotter/pm_graph_creator.py +7 -7
- {pyfemtet-1.0.4.dist-info → pyfemtet-1.0.6.dist-info}/METADATA +1 -1
- {pyfemtet-1.0.4.dist-info → pyfemtet-1.0.6.dist-info}/RECORD +25 -21
- {pyfemtet-1.0.4.dist-info → pyfemtet-1.0.6.dist-info}/LICENSE +0 -0
- {pyfemtet-1.0.4.dist-info → pyfemtet-1.0.6.dist-info}/LICENSE_THIRD_PARTY.txt +0 -0
- {pyfemtet-1.0.4.dist-info → pyfemtet-1.0.6.dist-info}/WHEEL +0 -0
- {pyfemtet-1.0.4.dist-info → pyfemtet-1.0.6.dist-info}/entry_points.txt +0 -0
|
@@ -3,15 +3,17 @@ from __future__ import annotations
|
|
|
3
3
|
|
|
4
4
|
from packaging import version
|
|
5
5
|
|
|
6
|
+
from tqdm import tqdm
|
|
6
7
|
import torch
|
|
7
8
|
|
|
8
9
|
from gpytorch.mlls import ExactMarginalLogLikelihood
|
|
9
10
|
from gpytorch.kernels import MaternKernel, ScaleKernel # , RBFKernel
|
|
10
11
|
from gpytorch.priors.torch_priors import GammaPrior # , LogNormalPrior
|
|
11
|
-
|
|
12
|
+
from gpytorch.constraints.constraints import GreaterThan
|
|
12
13
|
|
|
13
14
|
from botorch.models import SingleTaskGP
|
|
14
15
|
from botorch.models.transforms import Standardize, Normalize
|
|
16
|
+
from botorch.exceptions import ModelFittingError
|
|
15
17
|
|
|
16
18
|
# import fit_gpytorch_mll
|
|
17
19
|
import botorch.version
|
|
@@ -22,6 +24,9 @@ if version.parse(botorch.version.version) < version.parse("0.8.0"):
|
|
|
22
24
|
else:
|
|
23
25
|
from botorch.fit import fit_gpytorch_mll
|
|
24
26
|
|
|
27
|
+
from pyfemtet.logger import get_module_logger
|
|
28
|
+
|
|
29
|
+
logger = get_module_logger('opt.botorch_util')
|
|
25
30
|
|
|
26
31
|
__all__ = [
|
|
27
32
|
'get_standardizer_and_no_noise_train_yvar',
|
|
@@ -128,6 +133,58 @@ def setup_gp(X, Y, bounds, observation_noise, lh_class=None, covar_module=None):
|
|
|
128
133
|
)
|
|
129
134
|
|
|
130
135
|
mll_ = lh_class(model_.likelihood, model_)
|
|
131
|
-
|
|
136
|
+
|
|
137
|
+
try:
|
|
138
|
+
fit_gpytorch_mll(mll_)
|
|
139
|
+
|
|
140
|
+
except ModelFittingError as e:
|
|
141
|
+
logger.warning(f'{type(e).__name__} is raised '
|
|
142
|
+
f'during fit_gpytorch_mll()! '
|
|
143
|
+
f'The original message is: `{",".join(e.args)}`.')
|
|
144
|
+
|
|
145
|
+
# retry with noise level setting
|
|
146
|
+
try:
|
|
147
|
+
logger.info('Attempt to retrying...')
|
|
148
|
+
|
|
149
|
+
# noinspection PyBroadException
|
|
150
|
+
try:
|
|
151
|
+
model_.likelihood.noise_covar.register_constraint(
|
|
152
|
+
"raw_noise", GreaterThan(1e-5))
|
|
153
|
+
logger.warning('Set the raw_noise constraint to 1e-5.')
|
|
154
|
+
except Exception:
|
|
155
|
+
logger.info('Failed to set the raw_noise constraint. '
|
|
156
|
+
'Retrying simply.')
|
|
157
|
+
|
|
158
|
+
fit_gpytorch_mll(mll_)
|
|
159
|
+
|
|
160
|
+
except ModelFittingError:
|
|
161
|
+
logger.warning(f'{type(e).__name__} is raised '
|
|
162
|
+
f'during *second* fit_gpytorch_mll()! '
|
|
163
|
+
f'The original message is: `{",".join(e.args)}`.')
|
|
164
|
+
|
|
165
|
+
# retry with another way
|
|
166
|
+
logger.warning('Attempt to retrying.')
|
|
167
|
+
logger.warning('Try to use SGD algorithm...')
|
|
168
|
+
|
|
169
|
+
from torch.optim import SGD
|
|
170
|
+
|
|
171
|
+
NUM_EPOCHS = 150
|
|
172
|
+
|
|
173
|
+
optimizer = SGD([{"params": model_.parameters()}], lr=0.025)
|
|
174
|
+
|
|
175
|
+
model_.train()
|
|
176
|
+
|
|
177
|
+
for epoch in tqdm(range(NUM_EPOCHS), desc='fit with SGD'):
|
|
178
|
+
# clear gradients
|
|
179
|
+
optimizer.zero_grad()
|
|
180
|
+
# forward pass through the model to obtain the output MultivariateNormal
|
|
181
|
+
output = model_(X)
|
|
182
|
+
# Compute negative marginal log likelihood
|
|
183
|
+
loss = -mll_(output, model_.train_targets)
|
|
184
|
+
# back prop gradients
|
|
185
|
+
loss.backward()
|
|
186
|
+
optimizer.step()
|
|
187
|
+
|
|
188
|
+
model_.eval()
|
|
132
189
|
|
|
133
190
|
return model_
|
|
@@ -73,8 +73,8 @@ class SingleTaskGPModel(AbstractModel):
|
|
|
73
73
|
X = torch.tensor(x, **self.KWARGS)
|
|
74
74
|
post = self.gp.posterior(X)
|
|
75
75
|
with torch.no_grad():
|
|
76
|
-
mean = post.mean.numpy()
|
|
77
|
-
std = post.variance.sqrt().numpy()
|
|
76
|
+
mean = post.mean.cpu().numpy()
|
|
77
|
+
std = post.variance.sqrt().cpu().numpy()
|
|
78
78
|
return mean, std
|
|
79
79
|
|
|
80
80
|
|
pyfemtet/opt/problem/problem.py
CHANGED
|
@@ -28,7 +28,9 @@ __all__ = [
|
|
|
28
28
|
'TrialInput',
|
|
29
29
|
'TrialOutput',
|
|
30
30
|
'TrialConstraintOutput',
|
|
31
|
+
'TrialFunctionOutput',
|
|
31
32
|
'Function',
|
|
33
|
+
'FunctionResult',
|
|
32
34
|
'Functions',
|
|
33
35
|
'Objective',
|
|
34
36
|
'ObjectiveResult',
|
|
@@ -138,6 +140,12 @@ class Objective(Function):
|
|
|
138
140
|
return self._convert(value, self.direction)
|
|
139
141
|
|
|
140
142
|
|
|
143
|
+
class FunctionResult:
    """Holds the value obtained by evaluating a ``Function`` on a FEM interface.

    The evaluation happens once, inside the constructor; afterwards the
    result is available as ``self.value``.
    """

    def __init__(self, func: Function, fem: AbstractFEMInterface):
        # Evaluate eagerly; the stored value does not change afterwards.
        evaluated: float = func.eval(fem)
        self.value: float = evaluated
|
|
147
|
+
|
|
148
|
+
|
|
141
149
|
class ObjectiveResult:
|
|
142
150
|
|
|
143
151
|
def __init__(self, obj: Objective, fem: AbstractFEMInterface, obj_value: float = None):
|
|
@@ -302,3 +310,4 @@ SubSampling: TypeAlias = int
|
|
|
302
310
|
# Per-trial mappings from a string key to an input variable or to an
# evaluated result object.
# NOTE(review): keys are presumably the variable / objective / constraint /
# function names — confirm against the callers that build these dicts.
TrialInput: TypeAlias = dict[str, Variable]
TrialOutput: TypeAlias = dict[str, ObjectiveResult]
TrialConstraintOutput: TypeAlias = dict[str, ConstraintResult]
# Mapping to ``FunctionResult`` objects (generic ``Function`` evaluations).
TrialFunctionOutput: TypeAlias = dict[str, FunctionResult]
|
|
@@ -304,7 +304,7 @@ class VariableManager:
|
|
|
304
304
|
filter: (Literal['pass_to_fem', 'parameter']
|
|
305
305
|
| tuple[Literal['pass_to_fem', 'parameter']]
|
|
306
306
|
| None) = None, # 'pass_to_fem' and 'parameter' (OR filter)
|
|
307
|
-
format:
|
|
307
|
+
format: Literal['dict', 'values', 'raw'] | None = None, # Defaults to 'raw'
|
|
308
308
|
) -> (
|
|
309
309
|
dict[str, Variable]
|
|
310
310
|
| dict[str, Parameter]
|
|
@@ -0,0 +1,551 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import optuna
|
|
7
|
+
|
|
8
|
+
# dash components
|
|
9
|
+
from pyfemtet.opt.visualization.history_viewer._wrapped_components import dcc, dbc, html
|
|
10
|
+
|
|
11
|
+
# dash callback
|
|
12
|
+
from dash import Output, Input, callback_context, no_update
|
|
13
|
+
from dash.exceptions import PreventUpdate
|
|
14
|
+
|
|
15
|
+
from pyfemtet.logger import get_module_logger
|
|
16
|
+
from pyfemtet._i18n import _
|
|
17
|
+
|
|
18
|
+
from pyfemtet.opt.history import History, MAIN_FILTER
|
|
19
|
+
from pyfemtet.opt.visualization.history_viewer._base_application import AbstractPage
|
|
20
|
+
from pyfemtet.opt.visualization.plotter.parallel_plot_creator import parallel_plot
|
|
21
|
+
from pyfemtet.opt.visualization.plotter.contour_creator import contour_creator
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SelectablePlot(AbstractPage):
    """Page showing one graph together with input/output item selectors.

    On page load the selectors are filled with names taken from the
    ``History`` (see ``input_item_kind`` / ``output_item_kind``). Whenever the
    selection changes, the graph is redrawn through ``create_plot``.
    Subclasses must provide ``plot_title`` and ``create_plot`` (or override
    ``setup_update_plot_graph_callback``).
    """

    location: dcc.Location
    graph: dcc.Graph
    input_items: dcc.Checklist | dcc.RadioItems | html.Div
    output_items: dcc.Checklist | dcc.RadioItems | html.Div
    # Component classes used for the selectors. ``None`` replaces the
    # selector with an empty ``html.Div`` (see ``setup_component``).
    InputItemsClass = dcc.Checklist
    OutputItemsClass = dcc.Checklist
    alerts: html.Div
    # Which kinds of history names are offered in each selector.
    input_item_kind: set[Literal['all', 'prm', 'obj', 'cns']] = {'prm'}
    output_item_kind: set[Literal['all', 'prm', 'obj', 'cns']] = {'obj', 'cns'}
    description_markdown: str = ''

    def __init__(self, title='base-page', rel_url='/', application=None,
                 location=None):
        # Stored before ``super().__init__``; if still ``None`` when
        # ``setup_component`` runs, a dedicated ``dcc.Location`` is created.
        self.location = location
        super().__init__(title, rel_url, application)

    @property
    def plot_title(self) -> str:
        """Heading shown above the graph. Subclasses must override it."""
        raise NotImplementedError

    def setup_layout(self):
        """Arrange description, graph, selectors and alert area."""

        self.layout = dbc.Container([
            # ----- hidden -----
            dbc.Row([self.location]),

            # ----- visible -----
            dbc.Row([html.H2(self.plot_title)]),
            dbc.Row([dbc.Col(dcc.Markdown(self.description_markdown))]),
            dbc.Row(
                [
                    dbc.Col(dbc.Spinner(self.graph)),
                    dbc.Col(
                        [
                            dbc.Row(html.H3(_('Choices:', '選択肢:'))),
                            dbc.Row(html.Hr()),
                            dbc.Row(self.input_items),
                            dbc.Row(self.output_items),
                        ],
                        md=2
                    ),
                ],
            ),
            dbc.Row([self.alerts]),
            dbc.Row(html.Hr()),
        ])

    def setup_component(self):
        """Create the graph, the selectors, the alert area and the location."""

        if self.location is None:
            self.location = dcc.Location(id='selectable-plot-location', refresh=True)

        # graph
        self.graph = dcc.Graph(style={'height': '85vh'})

        # selectors; a ``None`` class means this page has no such selector
        self.input_items = self.InputItemsClass(options=[]) if self.InputItemsClass is not None else html.Div()
        self.output_items = self.OutputItemsClass(options=[]) if self.OutputItemsClass is not None else html.Div()

        # alert
        self.alerts = html.Div()

    def _check_precondition(self, logger) -> tuple[History, pd.DataFrame, pd.DataFrame]:
        """Return ``(history, df, main_df)`` or raise ``PreventUpdate``.

        Raises ``PreventUpdate`` when the callback was not triggered, when
        there is no application or history, or when the history DataFrame is
        empty. ``main_df`` (filtered with ``MAIN_FILTER``) is NOT checked here
        and may be empty.
        """

        if callback_context.triggered_id is None:
            logger.debug('PreventUpdate. No trigger.')
            raise PreventUpdate

        if self.application is None:
            logger.debug('PreventUpdate. No application.')
            raise PreventUpdate

        if self.application.history is None:
            logger.debug('PreventUpdate. No history.')
            raise PreventUpdate

        history = self.application.history

        df = self.application.get_df()
        main_df = self.application.get_df(MAIN_FILTER)
        if len(df) == 0:
            logger.debug('PreventUpdate. No df.')
            raise PreventUpdate

        return history, df, main_df

    @staticmethod
    def _return_checklist_options_and_value(history, types) -> tuple[list[dict], list[str]]:
        """Build selector options and the default selection for ``types``.

        Args:
            history: Object providing ``prm_names``, ``obj_names`` and
                ``cns_names``.
            types: Collection of ``'all'``, ``'prm'``, ``'obj'``, ``'cns'``.

        Returns:
            ``(options, keys)``: one ``{'label', 'value'}`` dict per name, and
            the list of names (all selected by default).
        """

        keys = []

        if 'all' in types:
            keys.extend(history.prm_names)
            keys.extend(history.obj_names)
            keys.extend(history.cns_names)

        if 'prm' in types:
            keys.extend(history.prm_names)
        if 'obj' in types:
            keys.extend(history.obj_names)
        if 'cns' in types:
            keys.extend(history.cns_names)

        # ``'all'`` may be combined with a specific kind; keep only the first
        # occurrence of each name so no item is offered twice.
        keys = list(dict.fromkeys(keys))

        return [dict(label=key, value=key) for key in keys], keys

    def _return_input_checklist_options_and_value(self, history):
        """Options/default selection for the input selector."""
        return self._return_checklist_options_and_value(history, self.input_item_kind)

    def _return_output_checklist_options_and_value(self, history):
        """Options/default selection for the output selector."""
        return self._return_checklist_options_and_value(history, self.output_item_kind)

    def setup_update_plot_input_checklist_callback(self):
        """Register the callback filling the input selector on page load."""

        @self.application.app.callback(
            Output(self.input_items, 'options'),
            Output(self.input_items, 'value'),
            Input(self.location, 'pathname'),  # on page load
        )
        def update_plot_input_checklist(_pathname):

            logger_name = f'opt.{type(self).__name__}.update_plot_input_checklist()'

            logger = get_module_logger(
                logger_name,
                debug=False,
            )

            logger.debug('callback fired!')

            # ----- preconditions -----

            history = self._check_precondition(logger)[0]

            # ----- main -----
            options, value = self._return_input_checklist_options_and_value(history)

            # RadioItems take a single value; select the first item if any.
            if isinstance(self.input_items, dcc.RadioItems):
                value = value[0] if value else None

            return options, value

    def setup_update_plot_output_checklist_callback(self):
        """Register the callback filling the output selector on page load."""

        @self.application.app.callback(
            Output(self.output_items, 'options'),
            Output(self.output_items, 'value'),
            Input(self.location, 'pathname'),  # on page load
        )
        def update_plot_output_checklist(_pathname):

            logger_name = f'opt.{type(self).__name__}.update_plot_output_checklist()'

            logger = get_module_logger(
                logger_name,
                debug=False,
            )

            logger.debug('callback fired!')

            # ----- preconditions -----

            history = self._check_precondition(logger)[0]

            # ----- main -----
            options, value = self._return_output_checklist_options_and_value(history)

            logger.debug(value)
            # RadioItems take a single value; select the first item if any.
            if isinstance(self.output_items, dcc.RadioItems):
                value = value[0] if value else None
                logger.debug(value)

            return options, value

    def setup_update_plot_graph_callback(self):
        """Register the callback redrawing the graph on selection change."""

        @self.application.app.callback(
            # graph output
            Output(self.graph, 'figure'),
            Output(self.alerts, 'children'),
            # checklist input
            inputs=dict(
                selected_input_values=Input(self.input_items, 'value'),
                selected_output_values=Input(self.output_items, 'value'),
            ),
        )
        def update_plot_graph(
                selected_input_values: list[str] | str,
                selected_output_values: list[str] | str,
        ):

            logger_name = f'opt.{type(self).__name__}.update_plot_graph()'

            logger = get_module_logger(
                logger_name,
                debug=False,
            )

            logger.debug('callback fired!')

            # ----- preconditions -----

            history, df, main_df = self._check_precondition(logger)

            # null selected values
            if selected_input_values is None:
                logger.debug('No input items.')
                return no_update, [dbc.Alert('No input items.', color='danger')]

            if selected_output_values is None:
                logger.debug('No output items.')
                return no_update, [dbc.Alert('No output items.', color='danger')]

            # type correction: RadioItems deliver a single string
            if isinstance(selected_input_values, str):
                selected_input_values = [selected_input_values]
            if isinstance(selected_output_values, str):
                selected_output_values = [selected_output_values]

            # nothing selected
            if len(selected_input_values) == 0:
                logger.debug('No input items are selected.')
                return no_update, [dbc.Alert('No input items are selected.', color='danger')]
            if len(selected_output_values) == 0:
                logger.debug('No output items are selected.')
                return no_update, [dbc.Alert('No output items are selected.', color='danger')]

            # ----- main -----
            used_df = self.make_used_df(history, df, main_df, selected_input_values, selected_output_values)

            # ``main_df`` is not checked by ``_check_precondition`` and a stale
            # selection may match no column, so report an empty result instead
            # of failing (an ``assert`` here would also vanish under -O).
            if len(used_df) == 0 or len(used_df.columns) == 0:
                logger.debug('No data to plot.')
                return no_update, [dbc.Alert('No data to plot.', color='danger')]

            fig_or_err = self.create_plot(used_df)

            # ``create_plot`` returns an error message string on failure.
            if isinstance(fig_or_err, str):
                return no_update, [dbc.Alert(fig_or_err, color='danger')]

            return fig_or_err, []

    def setup_callback(self):
        """Register all callbacks of this page."""
        self.setup_update_plot_input_checklist_callback()
        self.setup_update_plot_output_checklist_callback()
        self.setup_update_plot_graph_callback()

    # noinspection PyUnusedLocal
    @staticmethod
    def make_used_df(history, df, main_df, selected_input_values, selected_output_values):
        """Return the columns of ``main_df`` that are currently selected.

        Could be left abstract (``NotImplementedError``), but this selection
        is generic enough to serve as the default. Column order follows
        ``history.prm_names + history.all_output_names``.
        """

        selected = set(selected_input_values) | set(selected_output_values)

        columns = [
            col for col in history.prm_names + history.all_output_names
            if col in selected
        ]

        return main_df[columns]

    @staticmethod
    def create_plot(used_df):
        """Return a figure for ``used_df`` or an error message string."""
        raise NotImplementedError
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class ParallelPlot(SelectablePlot):
    """Parallel coordinate plot of the selected inputs and outputs."""

    plot_title = _('parallel coordinate plot', '平行座標プロット')

    description_markdown: str = _(
        en_message=(
            'Visualize the relationships between input and output values in multiple dimensions. '
            'You can intuitively grasp trends and the magnitude of influence between variables for specific output values.\n\n'
            '**Tips: You can rearrange the axes and select ranges.**'
        ),
        jp_message=(
            '各入力値と出力値の関係を多次元で可視化。'
            '特定の出力値に対する変数間の傾向や影響の大きさを'
            '直観的に把握できます。\n\n'
            '**Tips: 軸は順番を入れ替えることができ、範囲選択することができます。**'
        ),
    )

    @staticmethod
    def create_plot(used_df):
        """Delegate drawing of ``used_df`` to ``parallel_plot``."""
        figure = parallel_plot(used_df)
        return figure
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
class ContourPlot(SelectablePlot):
    """Contour plot of the selected inputs against a single chosen output."""

    plot_title = _('contour plot', 'コンタープロット')

    # Only one output can be drawn as a contour at a time.
    OutputItemsClass = dcc.RadioItems

    description_markdown: str = _(
        en_message=(
            'Visualize the correlation between input variables and changes in output using contour plots. '
            'You can identify combinations of variables that have a strong influence.\n\n'
            '**Tips: You can hide the scatter plot.**'
        ),
        jp_message=(
            '入力変数間の相関と、出力の変化をコンターで可視化。'
            '影響の強い変数の組合せを確認できます。\n\n'
            '**Tips: 点プロットは非表示にできます。**'
        ),
    )

    @staticmethod
    def create_plot(used_df):
        """Delegate drawing of ``used_df`` to ``contour_creator``."""
        figure = contour_creator(used_df)
        return figure
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
class SelectableOptunaPlot(SelectablePlot):
    """Selectable plot whose figure is drawn by an optuna visualization.

    The history is converted to an optuna study and passed, together with
    the selected parameter names, output names and output indices, to
    ``create_optuna_plot``, which subclasses implement.
    """

    def setup_update_plot_graph_callback(self):
        """Register the callback redrawing the graph on selection change."""

        @self.application.app.callback(
            # graph output
            Output(self.graph, 'figure'),
            Output(self.alerts, 'children'),
            # checklist input
            inputs=dict(
                selected_input_values=Input(self.input_items, 'value'),
                selected_output_values=Input(self.output_items, 'value'),
            ),
        )
        def update_plot_graph(
                selected_input_values: list[str] | str,
                selected_output_values: list[str] | str,
        ):

            logger_name = f'opt.{type(self).__name__}.update_plot_graph()'

            logger = get_module_logger(
                logger_name,
                debug=False,
            )

            logger.debug('callback fired!')

            # ----- preconditions -----

            history = self._check_precondition(logger)[0]

            # null selected values
            if selected_input_values is None:
                logger.debug('No input items.')
                return no_update, [dbc.Alert('No input items.', color='danger')]

            if selected_output_values is None:
                logger.debug('No output items.')
                return no_update, [dbc.Alert('No output items.', color='danger')]

            # type correction: RadioItems deliver a single string
            if isinstance(selected_input_values, str):
                selected_input_values = [selected_input_values]
            if isinstance(selected_output_values, str):
                selected_output_values = [selected_output_values]

            # nothing selected
            if len(selected_input_values) == 0:
                logger.debug('No input items are selected.')
                return no_update, [dbc.Alert('No input items are selected.', color='danger')]
            if len(selected_output_values) == 0:
                logger.debug('No output items are selected.')
                return no_update, [dbc.Alert('No output items are selected.', color='danger')]

            # ----- main -----
            # Optuna visualizations raise on data they cannot handle (e.g.
            # too few completed trials) and ``.index`` raises on a stale
            # selection; show such failures as an alert, like the
            # ``fig_or_err`` path of the base class, instead of crashing.
            try:
                fig = self.create_optuna_plot(
                    history._create_optuna_study_for_visualization(),
                    selected_input_values,
                    selected_output_values,
                    [history.all_output_names.index(v) for v in selected_output_values],
                )
            except Exception as e:  # callback boundary: log and report
                logger.warning(f'Failed to create the plot: {e!r}')
                return no_update, [dbc.Alert(f'Failed to create the plot: {e}', color='danger')]

            return fig, []

    @staticmethod
    def create_optuna_plot(
            study,
            prm_names: list[str],
            obj_name: list[str],
            obj_indices: list[int],
    ):
        """Return a figure for ``study``; ``obj_indices`` index ``trial.values``."""

        raise NotImplementedError
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
class SelectableOptunaPlotAllInput(SelectablePlot):
    """Optuna-based plot using every parameter and one selected output.

    There is no input selector (``InputItemsClass = None``); the output is
    chosen with radio items. Subclasses implement ``create_optuna_plot``.
    """

    InputItemsClass = None
    OutputItemsClass = dcc.RadioItems

    def setup_update_plot_graph_callback(self):
        """Register the callback redrawing the graph when the output changes."""

        @self.application.app.callback(
            # graph output
            Output(self.graph, 'figure'),
            Output(self.alerts, 'children'),
            # checklist input
            inputs=dict(
                selected_output_value=Input(self.output_items, 'value'),
            ),
        )
        def update_plot_graph(
                selected_output_value: str,
        ):

            logger_name = f'opt.{type(self).__name__}.update_plot_graph()'

            logger = get_module_logger(
                logger_name,
                debug=False,
            )

            logger.debug('callback fired!')

            # ----- preconditions -----

            history = self._check_precondition(logger)[0]

            # null selected values
            if selected_output_value is None:
                logger.debug('No output items.')
                return no_update, [dbc.Alert('No output items.', color='danger')]

            # ----- main -----
            # Optuna visualizations raise on data they cannot handle (e.g.
            # parameter importances with too few completed trials) and
            # ``.index`` raises on a stale selection; report such failures
            # as an alert instead of letting the callback crash.
            try:
                fig = self.create_optuna_plot(
                    history._create_optuna_study_for_visualization(),
                    selected_output_value,
                    history.all_output_names.index(selected_output_value)
                )
            except Exception as e:  # callback boundary: log and report
                logger.warning(f'Failed to create the plot: {e!r}')
                return no_update, [dbc.Alert(f'Failed to create the plot: {e}', color='danger')]

            return fig, []

    def setup_callback(self):
        """Register callbacks; there is no input selector to populate."""
        self.setup_update_plot_output_checklist_callback()
        self.setup_update_plot_graph_callback()

    @staticmethod
    def create_optuna_plot(
            study, obj_name, obj_index,
    ):
        """Return a figure for ``study``; ``obj_index`` indexes ``trial.values``."""

        raise NotImplementedError
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
class ImportancePlot(SelectableOptunaPlotAllInput):
    """Parameter-importance bar chart for one selected output."""

    plot_title = _('importance plot', '重要度プロット')

    description_markdown: str = _(
        en_message=(
            'Evaluate the importance of each input variable for the output using fANOVA. '
            'You can quantitatively understand which inputs are important.'
        ),
        jp_message=(
            '出力に対する各入力変数の重要度を fANOVA で評価。'
            '重要な入力を定量的に把握できます。'
        ),
    )

    @staticmethod
    def create_optuna_plot(
            study, obj_name, obj_index,
    ):
        """Draw optuna's parameter importances for ``trial.values[obj_index]``."""

        def pick_target(trial):
            # Select the chosen output among the trial's values.
            return trial.values[obj_index]

        figure = optuna.visualization.plot_param_importances(
            study,
            target=pick_target,
            target_name=obj_name,
        )
        figure.update_layout(title=f'Normalized importance of {obj_name}')
        return figure
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
class HistoryPlot(SelectableOptunaPlotAllInput):
    """Optimization-history plot for one selected output."""

    plot_title = _('optimization history plot', '最適化履歴プロット')

    description_markdown: str = _(
        en_message=(
            'Display the history of outputs generated during optimization. '
            'You can check the progress of improvements and the variability of the search.'
        ),
        jp_message=(
            '最適化中に生成された出力の履歴を表示。'
            '改善の進行や探索のばらつきを確認できます。'
        ),
    )

    @staticmethod
    def create_optuna_plot(
            study, obj_name, obj_index,
    ):
        """Draw optuna's optimization history of ``trial.values[obj_index]``."""

        def pick_target(trial):
            # Select the chosen output among the trial's values.
            return trial.values[obj_index]

        figure = optuna.visualization.plot_optimization_history(
            study,
            target=pick_target,
            target_name=obj_name,
        )
        figure.update_layout(title=f'Optimization history of {obj_name}')
        return figure
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
class SlicePlot(SelectableOptunaPlot):
    """Slice plot of the selected parameters against one chosen output."""

    plot_title = _('slice plot', 'スライスプロット')

    # A slice plot shows a single target, so the output is a radio choice.
    OutputItemsClass = dcc.RadioItems

    description_markdown: str = _(
        en_message=(
            'Displays the output response to a specific input. '
            'You can intuitively see the univariate effect, ignoring other variables.'
        ),
        jp_message=(
            '特定の入力に対する出力の応答を表示。'
            '他変数を無視した単変量の影響を'
            '直観的に確認できます。'
        ),
    )

    @staticmethod
    def create_optuna_plot(
            study,
            prm_names: list[str],
            obj_names: list[str],
            obj_indices: list[int],
    ):
        """Draw optuna's slice plot of ``prm_names`` for the single output."""

        assert len(obj_names) == len(obj_indices) == 1

        def pick_target(trial):
            # The only selected output among the trial's values.
            return trial.values[obj_indices[0]]

        figure = optuna.visualization.plot_slice(
            study,
            params=prm_names,
            target=pick_target,
            target_name=obj_names[0],
        )
        return figure
|