pyfemtet 0.1.12__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyfemtet might be problematic. Click here for more details.
- pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01.femprj +0 -0
- pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01.prt +0 -0
- pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01.py +69 -32
- pyfemtet/FemtetPJTSample/gau_ex08_parametric.femprj +0 -0
- pyfemtet/FemtetPJTSample/gau_ex08_parametric.py +37 -25
- pyfemtet/FemtetPJTSample/her_ex40_parametric.py +57 -35
- pyfemtet/FemtetPJTSample/wat_ex14_parallel_parametric.py +62 -0
- pyfemtet/FemtetPJTSample/wat_ex14_parametric.femprj +0 -0
- pyfemtet/FemtetPJTSample/wat_ex14_parametric.py +61 -0
- pyfemtet/__init__.py +1 -1
- pyfemtet/opt/_FemtetWithNX/update_model.py +6 -2
- pyfemtet/opt/__init__.py +1 -1
- pyfemtet/opt/base.py +457 -86
- pyfemtet/opt/core.py +77 -17
- pyfemtet/opt/interface.py +217 -137
- pyfemtet/opt/monitor.py +181 -98
- pyfemtet/opt/{_optuna.py → optimizer.py} +70 -30
- pyfemtet/tools/DispatchUtils.py +46 -44
- {pyfemtet-0.1.12.dist-info → pyfemtet-0.2.1.dist-info}/LICENSE +1 -1
- pyfemtet-0.2.1.dist-info/METADATA +42 -0
- pyfemtet-0.2.1.dist-info/RECORD +31 -0
- pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01 - original.x_t +0 -359
- pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01.x_t +0 -359
- pyfemtet/FemtetPJTSample/fem4 = Femtet(femprj_path=None, model_name=None, connect_method='catch').femprj +0 -0
- pyfemtet/FemtetPJTSample/gal_ex11_parametric.femprj +0 -0
- pyfemtet/FemtetPJTSample/gal_ex11_parametric.py +0 -54
- pyfemtet/FemtetPJTSample/pas_ex1_parametric.femprj +0 -0
- pyfemtet/FemtetPJTSample/pas_ex1_parametric.py +0 -66
- pyfemtet/FemtetPJTSample/pas_ex1_parametric2.py +0 -68
- pyfemtet/tools/FemtetClassConst.py +0 -9
- pyfemtet-0.1.12.dist-info/METADATA +0 -205
- pyfemtet-0.1.12.dist-info/RECORD +0 -37
- {pyfemtet-0.1.12.dist-info → pyfemtet-0.2.1.dist-info}/WHEEL +0 -0
pyfemtet/opt/monitor.py
CHANGED
|
@@ -1,9 +1,20 @@
|
|
|
1
1
|
import webbrowser
|
|
2
2
|
import logging
|
|
3
|
-
from dash import Dash, html, dcc
|
|
4
|
-
from dash.dependencies import Output, Input
|
|
3
|
+
from dash import Dash, html, dcc, ctx, Output, Input
|
|
5
4
|
import dash_bootstrap_components as dbc
|
|
6
5
|
import plotly.graph_objs as go
|
|
6
|
+
import plotly.express as px
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def update_hypervolume_plot(femopt):
|
|
10
|
+
# data setting
|
|
11
|
+
df = femopt.history.data
|
|
12
|
+
|
|
13
|
+
# create figure
|
|
14
|
+
fig = px.line(df, x="trial", y="hypervolume", markers=True)
|
|
15
|
+
|
|
16
|
+
return fig
|
|
17
|
+
|
|
7
18
|
|
|
8
19
|
|
|
9
20
|
def update_scatter_matrix(femopt):
|
|
@@ -13,12 +24,6 @@ def update_scatter_matrix(femopt):
|
|
|
13
24
|
|
|
14
25
|
# create figure
|
|
15
26
|
fig = go.Figure()
|
|
16
|
-
fig.update_layout(
|
|
17
|
-
dict(
|
|
18
|
-
width=800,
|
|
19
|
-
height=600,
|
|
20
|
-
)
|
|
21
|
-
)
|
|
22
27
|
|
|
23
28
|
# graphs setting dependent on n_objectives
|
|
24
29
|
if len(obj_names) == 0:
|
|
@@ -27,14 +32,14 @@ def update_scatter_matrix(femopt):
|
|
|
27
32
|
elif len(obj_names) == 1:
|
|
28
33
|
fig.add_trace(
|
|
29
34
|
go.Scatter(
|
|
30
|
-
x=
|
|
31
|
-
y=data[obj_names[0]],
|
|
35
|
+
x=data['trial'],
|
|
36
|
+
y=data[obj_names[0]].values,
|
|
32
37
|
mode='markers+lines',
|
|
33
38
|
)
|
|
34
39
|
)
|
|
35
40
|
fig.update_layout(
|
|
36
41
|
dict(
|
|
37
|
-
title_text="
|
|
42
|
+
title_text="単目的プロット",
|
|
38
43
|
xaxis_title="解析実行回数(回)",
|
|
39
44
|
yaxis_title=obj_names[0],
|
|
40
45
|
)
|
|
@@ -73,20 +78,74 @@ def update_scatter_matrix(femopt):
|
|
|
73
78
|
return fig
|
|
74
79
|
|
|
75
80
|
|
|
81
|
+
def setup_home():
|
|
82
|
+
# components の設定
|
|
83
|
+
# https://dash-bootstrap-components.opensource.faculty.ai/docs/components/accordion/
|
|
84
|
+
dummy = html.Div('', id='dummy')
|
|
85
|
+
interval = dcc.Interval(
|
|
86
|
+
id='interval-component',
|
|
87
|
+
interval=1*1000, # in milliseconds
|
|
88
|
+
n_intervals=0,
|
|
89
|
+
)
|
|
90
|
+
header = html.H1("最適化の進行状況"),
|
|
91
|
+
graphs = dbc.Card(
|
|
92
|
+
[
|
|
93
|
+
dbc.CardHeader(
|
|
94
|
+
dbc.Tabs(
|
|
95
|
+
[
|
|
96
|
+
dbc.Tab(label="目的プロット", tab_id="tab-1"),
|
|
97
|
+
dbc.Tab(label="Hypervolume", tab_id="tab-2"),
|
|
98
|
+
],
|
|
99
|
+
id="card-tabs",
|
|
100
|
+
active_tab="tab-1",
|
|
101
|
+
)
|
|
102
|
+
),
|
|
103
|
+
dbc.CardBody(html.P(id="card-content", className="card-text")),
|
|
104
|
+
]
|
|
105
|
+
)
|
|
106
|
+
toggle_update_button = dbc.Button('グラフの自動更新の一時停止', id='toggle-update-button')
|
|
107
|
+
interrupt_button = dbc.Button('最適化を中断', id='interrupt-button', color='danger')
|
|
108
|
+
status_text = dcc.Markdown(f'''
|
|
109
|
+
---
|
|
110
|
+
- このページでは、最適化の進捗状況を見ることができます。
|
|
111
|
+
- このページを閉じても最適化は進行します。
|
|
112
|
+
- この機能はブラウザによる状況確認機能ですが、インターネット通信は行いません。
|
|
113
|
+
- 再びこのページを開くには、ブラウザのアドレスバーに __localhost:8080__ と入力してください。
|
|
114
|
+
- ※ 特定のホスト名及びポートを指定するには、OptimizerBase.main() の実行前に
|
|
115
|
+
OptimizerBase.set_monitor_server() を実行してください。
|
|
116
|
+
''')
|
|
117
|
+
|
|
118
|
+
# layout の設定
|
|
119
|
+
layout = dbc.Container([
|
|
120
|
+
dbc.Row([dbc.Col(dummy), dbc.Col(interval)]),
|
|
121
|
+
dbc.Row([dbc.Col(header)]),
|
|
122
|
+
dbc.Row([dbc.Col(graphs)]),
|
|
123
|
+
dbc.Row([dbc.Col(toggle_update_button), dbc.Col(interrupt_button)]),
|
|
124
|
+
dbc.Row([dbc.Col(status_text)]),
|
|
125
|
+
], fluid=True)
|
|
126
|
+
|
|
127
|
+
return layout
|
|
128
|
+
|
|
129
|
+
|
|
76
130
|
class Monitor(object):
|
|
77
131
|
|
|
78
132
|
def __init__(self, femopt):
|
|
79
133
|
|
|
134
|
+
# 引数の処理
|
|
80
135
|
self.femopt = femopt
|
|
81
136
|
|
|
137
|
+
# ログファイルの保存場所
|
|
82
138
|
log_path = self.femopt.history.path.replace('.csv', '.uilog')
|
|
83
139
|
l = logging.getLogger()
|
|
84
140
|
l.addHandler(logging.FileHandler(log_path))
|
|
85
141
|
|
|
142
|
+
# app の立上げ
|
|
86
143
|
self.app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
|
|
87
144
|
|
|
145
|
+
# ページの components と layout の設定
|
|
146
|
+
self.home = setup_home()
|
|
88
147
|
|
|
89
|
-
|
|
148
|
+
# setup sidebar
|
|
90
149
|
# https://dash-bootstrap-components.opensource.faculty.ai/examples/simple-sidebar/
|
|
91
150
|
|
|
92
151
|
# the style arguments for the sidebar. We use position:fixed and a fixed width
|
|
@@ -107,20 +166,17 @@ class Monitor(object):
|
|
|
107
166
|
"margin-right": "2rem",
|
|
108
167
|
"padding": "2rem 1rem",
|
|
109
168
|
}
|
|
110
|
-
|
|
111
|
-
# setup sidebar
|
|
112
169
|
sidebar = html.Div(
|
|
113
170
|
[
|
|
114
|
-
html.H2("
|
|
171
|
+
html.H2("PyFemtet Monitor", className="display-4"),
|
|
115
172
|
html.Hr(),
|
|
116
173
|
html.P(
|
|
117
|
-
"
|
|
174
|
+
"最適化の進捗を可視化します.", className="lead"
|
|
118
175
|
),
|
|
119
176
|
dbc.Nav(
|
|
120
177
|
[
|
|
121
178
|
dbc.NavLink("Home", href="/", active="exact"),
|
|
122
|
-
dbc.NavLink("ペアプロット", href="/page-1", active="exact"),
|
|
123
|
-
# dbc.NavLink("Page 2", href="/page-2", active="exact"),
|
|
179
|
+
# dbc.NavLink("ペアプロット", href="/page-1", active="exact"),
|
|
124
180
|
],
|
|
125
181
|
vertical=True,
|
|
126
182
|
pills=True,
|
|
@@ -128,22 +184,16 @@ class Monitor(object):
|
|
|
128
184
|
],
|
|
129
185
|
style=SIDEBAR_STYLE,
|
|
130
186
|
)
|
|
131
|
-
|
|
132
187
|
content = html.Div(id="page-content", style=CONTENT_STYLE)
|
|
133
188
|
self.app.layout = html.Div([dcc.Location(id="url"), sidebar, content])
|
|
134
189
|
|
|
135
|
-
#### settings for multiobjective pairplot
|
|
136
|
-
self.home = self.setup_home()
|
|
137
|
-
self.multi_pairplot_layout = self.setup_page1()
|
|
138
|
-
|
|
139
|
-
|
|
140
190
|
# sidebar によるページ遷移のための callback
|
|
141
191
|
@self.app.callback(Output("page-content", "children"), [Input("url", "pathname")])
|
|
142
192
|
def render_page_content(pathname):
|
|
143
|
-
if pathname == "/":
|
|
193
|
+
if pathname == "/": # p0
|
|
144
194
|
return self.home
|
|
145
|
-
elif pathname == "/page-1":
|
|
146
|
-
|
|
195
|
+
# elif pathname == "/page-1":
|
|
196
|
+
# return self.multi_pairplot_layout
|
|
147
197
|
# elif pathname == "/page-2":
|
|
148
198
|
# return html.P("Oh cool, this is page 2!")
|
|
149
199
|
# If the user tries to reach a different page, return a 404 message
|
|
@@ -156,109 +206,142 @@ class Monitor(object):
|
|
|
156
206
|
className="p-3 bg-light rounded-3",
|
|
157
207
|
)
|
|
158
208
|
|
|
159
|
-
|
|
160
|
-
#
|
|
209
|
+
# 1. 一定時間ごとに ==> 自動更新が有効なら figure を更新する
|
|
210
|
+
# 2. 中断ボタンを押したら ==> 更新を無効にする and 中断を無効にする
|
|
211
|
+
# 3. メイン処理が中断 or 終了していたら ==> 更新を無効にする and 中断を無効にする
|
|
212
|
+
# 4. toggle_button が押されたら ==> 更新を有効にする or 更新を無効にする
|
|
213
|
+
# 5. タブを押したら ==> グラフの種類を切り替える
|
|
161
214
|
@self.app.callback(
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
215
|
+
[
|
|
216
|
+
Output('interval-component', 'max_intervals'), # 2 3 4
|
|
217
|
+
Output('interrupt-button', 'disabled'), # 2 3
|
|
218
|
+
Output('toggle-update-button', 'disabled'), # 2 3 4
|
|
219
|
+
Output('toggle-update-button', 'children'), # 2 3 4
|
|
220
|
+
Output('card-content', 'children'), # 1 5
|
|
221
|
+
],
|
|
222
|
+
[
|
|
223
|
+
Input('interval-component', 'n_intervals'), # 1 3
|
|
224
|
+
Input('toggle-update-button', 'n_clicks'), # 4
|
|
225
|
+
Input('interrupt-button', 'n_clicks'), # 2
|
|
226
|
+
Input("card-tabs", "active_tab"), # 5
|
|
227
|
+
]
|
|
228
|
+
)
|
|
229
|
+
def control(
|
|
230
|
+
_, # n_intervals
|
|
231
|
+
toggle_n_clicks,
|
|
232
|
+
interrupt_n_clicks,
|
|
233
|
+
active_tab_id,
|
|
234
|
+
):
|
|
235
|
+
# 引数の処理
|
|
236
|
+
toggle_n_clicks = 0 if toggle_n_clicks is None else toggle_n_clicks
|
|
237
|
+
interrupt_n_clicks = 0 if interrupt_n_clicks is None else interrupt_n_clicks
|
|
238
|
+
|
|
239
|
+
# 下記を基本に戻り値を上書きしていく(優先のものほど下に来る)
|
|
240
|
+
max_intervals = -1 # enable
|
|
241
|
+
button_disable = False
|
|
242
|
+
toggle_text = 'グラフの自動更新を一時停止する'
|
|
243
|
+
graph = None
|
|
244
|
+
|
|
245
|
+
# toggle_button が奇数なら interval を disable にする
|
|
246
|
+
if toggle_n_clicks % 2 == 1:
|
|
247
|
+
max_intervals = 0 # disable
|
|
248
|
+
button_disable = False
|
|
249
|
+
toggle_text = 'グラフの自動更新を再開する'
|
|
175
250
|
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
[Output('interval-component', 'max_intervals'),
|
|
179
|
-
Output('interrupt-button', 'disabled'),],
|
|
180
|
-
[Input('interval-component', 'n_intervals'),])
|
|
181
|
-
def stop_interval(_):
|
|
251
|
+
# 中断又は終了なら interval とボタンを disable にする
|
|
252
|
+
should_stop = False
|
|
182
253
|
try:
|
|
183
254
|
state = self.femopt.ipv.get_state()
|
|
184
255
|
should_stop = (state == 'interrupted') or (state == 'terminated')
|
|
185
256
|
except AttributeError:
|
|
186
257
|
should_stop = True
|
|
187
|
-
|
|
188
|
-
|
|
258
|
+
finally:
|
|
259
|
+
if should_stop:
|
|
260
|
+
max_intervals = 0 # disable
|
|
261
|
+
button_disable = True
|
|
262
|
+
toggle_text = 'グラフの更新は行われません'
|
|
263
|
+
|
|
264
|
+
# 中断ボタンが押されたなら interval とボタンを disable にして femopt の状態を set する
|
|
265
|
+
button_id = ctx.triggered_id if not None else 'No clicks yet'
|
|
266
|
+
if button_id == 'interrupt-button':
|
|
267
|
+
max_intervals = 0 # disable
|
|
189
268
|
button_disable = True
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
button_disable = False
|
|
193
|
-
return max_intervals, button_disable
|
|
194
|
-
|
|
195
|
-
def setup_home(self):
|
|
196
|
-
# components の設定
|
|
197
|
-
text = dcc.Markdown('''
|
|
198
|
-
# 最適化の進行状況モニター
|
|
199
|
-
---
|
|
200
|
-
#### 左のサイドバーから、可視化方法を選択してください
|
|
201
|
-
- このページでは、最適化の進捗状況を見ることができます。
|
|
202
|
-
- ブラウザによる進捗状況確認機能ですが、インターネット通信は行いません。
|
|
203
|
-
- このページを閉じても最適化は進行します。再びこのページを開くには、ブラウザのアドレスバーに __localhost:8080__ と入力してください。
|
|
204
|
-
''')
|
|
205
|
-
|
|
206
|
-
return text
|
|
269
|
+
toggle_text = 'グラフの更新は行われません'
|
|
270
|
+
self.femopt.ipv.set_state('interrupted')
|
|
207
271
|
|
|
272
|
+
# グラフを更新する
|
|
273
|
+
if active_tab_id is not None:
|
|
274
|
+
if active_tab_id == "tab-1":
|
|
275
|
+
graph = dcc.Graph(figure=update_scatter_matrix(self.femopt))
|
|
276
|
+
elif active_tab_id == "tab-2":
|
|
277
|
+
graph = dcc.Graph(figure=update_hypervolume_plot(self.femopt))
|
|
208
278
|
|
|
209
|
-
|
|
210
|
-
# components の設定
|
|
211
|
-
# https://dash-bootstrap-components.opensource.faculty.ai/docs/components/accordion/
|
|
212
|
-
dummy = html.Div('', id='dummy')
|
|
213
|
-
interval = dcc.Interval(
|
|
214
|
-
id='interval-component',
|
|
215
|
-
interval=1*1000, # in milliseconds
|
|
216
|
-
n_intervals=0,
|
|
217
|
-
)
|
|
218
|
-
header = html.H1("最適化の進行状況"),
|
|
219
|
-
graph = dcc.Graph(id='scatter-matrix-graph')
|
|
220
|
-
interrupt_button = dbc.Button('最適化を中断', id='interrupt-button', color='danger')
|
|
279
|
+
return max_intervals, button_disable, button_disable, toggle_text, graph
|
|
221
280
|
|
|
222
|
-
|
|
223
|
-
layout = dbc.Container([
|
|
224
|
-
dbc.Row([dbc.Col(dummy), dbc.Col(interval)]),
|
|
225
|
-
dbc.Row([dbc.Col(header)]),
|
|
226
|
-
dbc.Row([dbc.Col(graph)]),
|
|
227
|
-
dbc.Row([dbc.Col(interrupt_button)], justify="center",),
|
|
228
|
-
])
|
|
229
|
-
return layout
|
|
281
|
+
def start_server(self, host='localhost', port=8080):
|
|
230
282
|
|
|
283
|
+
if host is None:
|
|
284
|
+
host = 'localhost'
|
|
285
|
+
if port is None:
|
|
286
|
+
port = 8080
|
|
231
287
|
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
288
|
+
if host == '0.0.0.0':
|
|
289
|
+
webbrowser.open(f'http://localhost:{str(port)}')
|
|
290
|
+
else:
|
|
291
|
+
webbrowser.open(f'http://{host}:{str(port)}')
|
|
292
|
+
self.app.run(debug=False, host=host, port=port)
|
|
235
293
|
|
|
236
294
|
|
|
237
295
|
if __name__ == '__main__':
|
|
296
|
+
import datetime
|
|
297
|
+
from time import sleep
|
|
298
|
+
from threading import Thread
|
|
238
299
|
import numpy as np
|
|
239
300
|
import pandas as pd
|
|
301
|
+
|
|
302
|
+
|
|
240
303
|
class IPV:
|
|
241
304
|
def __init__(self):
|
|
242
305
|
self.state = 'running'
|
|
306
|
+
|
|
243
307
|
def get_state(self):
|
|
244
308
|
return self.state
|
|
309
|
+
|
|
245
310
|
def set_state(self, state):
|
|
246
311
|
self.state = state
|
|
312
|
+
|
|
313
|
+
|
|
247
314
|
class History:
|
|
248
315
|
def __init__(self):
|
|
249
316
|
self.obj_names = 'A B C D E'.split()
|
|
250
|
-
self.data = pd.DataFrame(
|
|
251
|
-
np.random.rand(5, len(self.obj_names)),
|
|
252
|
-
columns=self.obj_names,
|
|
253
|
-
)
|
|
254
317
|
self.path = 'tmp.csv'
|
|
318
|
+
self.data = None
|
|
319
|
+
t = Thread(target=self.update)
|
|
320
|
+
t.start()
|
|
321
|
+
|
|
322
|
+
def update(self):
|
|
323
|
+
|
|
324
|
+
d = dict(
|
|
325
|
+
trial=range(5),
|
|
326
|
+
hypervolume=np.random.rand(5),
|
|
327
|
+
time=[datetime.datetime(year=2000, month=1, day=1, second=s) for s in range(5)]
|
|
328
|
+
)
|
|
329
|
+
for obj_name in self.obj_names:
|
|
330
|
+
d[obj_name] = np.random.rand(5)
|
|
331
|
+
|
|
332
|
+
while True:
|
|
333
|
+
self.data = pd.DataFrame(d)
|
|
334
|
+
sleep(1)
|
|
335
|
+
|
|
336
|
+
|
|
255
337
|
class FEMOPT:
|
|
256
338
|
def __init__(self, history, ipv):
|
|
257
339
|
self.history = history
|
|
258
340
|
self.ipv = ipv
|
|
259
341
|
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
342
|
+
|
|
343
|
+
_ipv = IPV()
|
|
344
|
+
_history = History()
|
|
345
|
+
_femopt = FEMOPT(_history, _ipv)
|
|
346
|
+
monitor = Monitor(_femopt)
|
|
264
347
|
monitor.start_server()
|
|
@@ -17,11 +17,22 @@ from .base import OptimizerBase
|
|
|
17
17
|
warnings.filterwarnings('ignore', category=ExperimentalWarning)
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
def generate_lhs(bounds, seed=None) -> np.ndarray:
|
|
21
|
-
"""
|
|
22
|
-
|
|
23
|
-
|
|
20
|
+
def generate_lhs(bounds: list[list[float]], seed: int | None = None) -> np.ndarray:
|
|
21
|
+
"""Latin Hypercube Sampling from given design parameter bounds.
|
|
22
|
+
|
|
23
|
+
If the number of parameters is d,
|
|
24
|
+
sampler returns (N, d) shape ndarray.
|
|
25
|
+
N equals p**2, p is the minimum prime number over d.
|
|
26
|
+
For example, when d=3, then p=5 and N=25.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
bounds (list[list[float]]): List of [lower_bound, upper_bound] of parameters.
|
|
30
|
+
seed (int | None, optional): Random seed. Defaults to None.
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
np.ndarray: (N, d) shape ndarray.
|
|
24
34
|
"""
|
|
35
|
+
|
|
25
36
|
d = len(bounds)
|
|
26
37
|
|
|
27
38
|
sampler = LatinHypercube(
|
|
@@ -59,6 +70,9 @@ def generate_lhs(bounds, seed=None) -> np.ndarray:
|
|
|
59
70
|
|
|
60
71
|
|
|
61
72
|
class OptimizerOptuna(OptimizerBase):
|
|
73
|
+
"""Optimizer class using Optuna.
|
|
74
|
+
|
|
75
|
+
"""
|
|
62
76
|
|
|
63
77
|
def _objective(self, trial):
|
|
64
78
|
|
|
@@ -77,13 +91,16 @@ class OptimizerOptuna(OptimizerBase):
|
|
|
77
91
|
for i, row in self.parameters.iterrows():
|
|
78
92
|
x.append(trial.suggest_float(row['name'], row['lb'], row['ub']))
|
|
79
93
|
x = np.array(x)
|
|
94
|
+
self.parameters['value'] = x
|
|
80
95
|
|
|
81
96
|
# strict 拘束の計算で Prune することになったとき
|
|
82
97
|
# constraint attr がないとエラーになるのでダミーを置いておく
|
|
83
98
|
trial.set_user_attr("constraint", (1.,)) # 非正が feasible 扱い
|
|
84
99
|
|
|
100
|
+
# GetVariableValue 経由で変数にアクセスするなどの場合
|
|
101
|
+
self.fem.update_parameter(self.parameters)
|
|
102
|
+
|
|
85
103
|
# strict 拘束の計算
|
|
86
|
-
self.parameters['value'] = x
|
|
87
104
|
tmp = [[cns.calc(self.fem), cns.lb, cns.ub, name] for name, cns in self.constraints.items() if cns.strict]
|
|
88
105
|
for val, lb, ub, name in tmp:
|
|
89
106
|
if lb is not None:
|
|
@@ -103,6 +120,9 @@ class OptimizerOptuna(OptimizerBase):
|
|
|
103
120
|
try:
|
|
104
121
|
obj_values = self.f(x, message) # obj_val と cns_val の更新
|
|
105
122
|
except (ModelError, MeshError, SolveError):
|
|
123
|
+
print('FEM 解析に失敗しました。')
|
|
124
|
+
print('変数の組み合わせは以下の通りです。')
|
|
125
|
+
print(self.parameters)
|
|
106
126
|
raise optuna.TrialPruned()
|
|
107
127
|
|
|
108
128
|
# 拘束 attr の更新
|
|
@@ -123,7 +143,15 @@ class OptimizerOptuna(OptimizerBase):
|
|
|
123
143
|
def _constraint_function(self, trial):
|
|
124
144
|
return trial.user_attrs["constraint"]
|
|
125
145
|
|
|
126
|
-
def
|
|
146
|
+
def setup_concrete_main(self, use_lhs_init=True):
|
|
147
|
+
"""Performs the setup for the optimization using Optuna.
|
|
148
|
+
|
|
149
|
+
Do sampler settings, study creation or loading and initial trials settings.
|
|
150
|
+
|
|
151
|
+
Args:
|
|
152
|
+
use_lhs_init (bool, optional): Flag indicating whether to use Latin Hypercube Sampling for initializing trials. Defaults to True.
|
|
153
|
+
|
|
154
|
+
"""
|
|
127
155
|
|
|
128
156
|
# sampler の設定
|
|
129
157
|
self.sampler_kwargs = dict(
|
|
@@ -133,6 +161,11 @@ class OptimizerOptuna(OptimizerBase):
|
|
|
133
161
|
self.sampler_class = optuna.samplers.TPESampler
|
|
134
162
|
if self.method == 'botorch':
|
|
135
163
|
self.sampler_class = optuna.integration.BoTorchSampler
|
|
164
|
+
# self.sampler_kwargs.update(
|
|
165
|
+
# dict(
|
|
166
|
+
# consider_running_trials=True
|
|
167
|
+
# )
|
|
168
|
+
# )
|
|
136
169
|
|
|
137
170
|
# study name の設定
|
|
138
171
|
self.study_name = os.path.splitext(os.path.basename(self.history.path))[0]
|
|
@@ -148,32 +181,35 @@ class OptimizerOptuna(OptimizerBase):
|
|
|
148
181
|
)
|
|
149
182
|
|
|
150
183
|
# 初期値の設定
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
for
|
|
167
|
-
d
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
184
|
+
if len(study.trials) == 0: # リスタートでなければ
|
|
185
|
+
# ユーザーの指定した初期値
|
|
186
|
+
params = self.get_parameter('dict')
|
|
187
|
+
study.enqueue_trial(params, user_attrs={"message": "initial"})
|
|
188
|
+
|
|
189
|
+
# LHS を初期値にする
|
|
190
|
+
if use_lhs_init:
|
|
191
|
+
names = []
|
|
192
|
+
bounds = []
|
|
193
|
+
for i, row in self.parameters.iterrows():
|
|
194
|
+
names.append(row['name'])
|
|
195
|
+
lb = row['lb']
|
|
196
|
+
ub = row['ub']
|
|
197
|
+
bounds.append([lb, ub])
|
|
198
|
+
data = generate_lhs(bounds, seed=self.seed)
|
|
199
|
+
for datum in data:
|
|
200
|
+
d = {}
|
|
201
|
+
for name, v in zip(names, datum):
|
|
202
|
+
d[name] = v
|
|
203
|
+
study.enqueue_trial(d, user_attrs={"message": "initial Latin Hypercube Sampling"})
|
|
204
|
+
|
|
205
|
+
def concrete_main(self, subprocess_idx=None):
|
|
206
|
+
"""Optimization using Optuna."""
|
|
172
207
|
|
|
173
208
|
# 乱数シードをプロセス固有にする
|
|
174
209
|
seed = self.seed
|
|
175
210
|
if seed is not None:
|
|
176
|
-
|
|
211
|
+
if subprocess_idx is not None:
|
|
212
|
+
seed = seed + (1 + subprocess_idx) # main process と subprocess0 が重複する
|
|
177
213
|
|
|
178
214
|
# sampler の restore
|
|
179
215
|
sampler = self.sampler_class(
|
|
@@ -188,10 +224,14 @@ class OptimizerOptuna(OptimizerBase):
|
|
|
188
224
|
sampler=sampler,
|
|
189
225
|
)
|
|
190
226
|
|
|
191
|
-
#
|
|
227
|
+
# 最大実行回数の指定
|
|
192
228
|
callbacks = []
|
|
229
|
+
n_existing_trials = len(self.history.data)
|
|
193
230
|
if self.n_trials is not None:
|
|
194
|
-
|
|
231
|
+
n_trials = n_existing_trials + self.n_trials
|
|
232
|
+
callbacks.append(MaxTrialsCallback(n_trials, states=(TrialState.COMPLETE,)))
|
|
233
|
+
|
|
234
|
+
# run
|
|
195
235
|
study.optimize(
|
|
196
236
|
self._objective,
|
|
197
237
|
timeout=self.timeout,
|