pyfemtet 0.4.20__py3-none-any.whl → 0.4.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyfemtet might be problematic. Click here for more details.

Files changed (61)
  1. pyfemtet/__init__.py +1 -1
  2. pyfemtet/_test_util.py +0 -2
  3. pyfemtet/message/locales/ja/LC_MESSAGES/messages.mo +0 -0
  4. pyfemtet/message/locales/ja/LC_MESSAGES/messages.po +107 -96
  5. pyfemtet/message/locales/messages.pot +104 -96
  6. pyfemtet/message/messages.py +15 -1
  7. pyfemtet/opt/_femopt.py +289 -230
  8. pyfemtet/opt/_femopt_core.py +118 -49
  9. pyfemtet/opt/femprj_sample/ParametricIF.py +0 -2
  10. pyfemtet/opt/femprj_sample/cad_ex01_NX.py +0 -8
  11. pyfemtet/opt/femprj_sample/cad_ex01_SW.py +0 -8
  12. pyfemtet/opt/femprj_sample/gal_ex58_parametric.py +0 -8
  13. pyfemtet/opt/femprj_sample/gau_ex08_parametric.py +0 -8
  14. pyfemtet/opt/femprj_sample/her_ex40_parametric.py +0 -8
  15. pyfemtet/opt/femprj_sample/paswat_ex1_parametric.py +0 -8
  16. pyfemtet/opt/femprj_sample/paswat_ex1_parametric_parallel.py +0 -8
  17. pyfemtet/opt/femprj_sample/wat_ex14_parametric.py +0 -8
  18. pyfemtet/opt/femprj_sample/wat_ex14_parametric_parallel.py +0 -8
  19. pyfemtet/opt/femprj_sample_jp/ParametricIF_jp.py +0 -2
  20. pyfemtet/opt/femprj_sample_jp/cad_ex01_NX_jp.py +0 -8
  21. pyfemtet/opt/femprj_sample_jp/cad_ex01_SW_jp.py +0 -8
  22. pyfemtet/opt/femprj_sample_jp/gal_ex58_parametric_jp.py +0 -8
  23. pyfemtet/opt/femprj_sample_jp/gau_ex08_parametric_jp.py +0 -8
  24. pyfemtet/opt/femprj_sample_jp/her_ex40_parametric_jp.py +0 -8
  25. pyfemtet/opt/femprj_sample_jp/paswat_ex1_parametric_jp.py +0 -8
  26. pyfemtet/opt/femprj_sample_jp/paswat_ex1_parametric_parallel_jp.py +0 -8
  27. pyfemtet/opt/femprj_sample_jp/wat_ex14_parametric_jp.py +0 -8
  28. pyfemtet/opt/femprj_sample_jp/wat_ex14_parametric_parallel_jp.py +0 -8
  29. pyfemtet/opt/interface/_femtet.py +77 -24
  30. pyfemtet/opt/opt/_base.py +25 -18
  31. pyfemtet/opt/opt/_optuna.py +53 -14
  32. pyfemtet/opt/opt/_optuna_botorch_helper.py +209 -0
  33. pyfemtet/opt/opt/_scipy.py +1 -1
  34. pyfemtet/opt/opt/_scipy_scalar.py +1 -1
  35. pyfemtet/opt/parameter.py +113 -0
  36. pyfemtet/opt/visualization/complex_components/main_graph.py +22 -5
  37. pyfemtet/opt/visualization/complex_components/pm_graph.py +77 -25
  38. pyfemtet/opt/visualization/complex_components/pm_graph_creator.py +7 -0
  39. pyfemtet/opt/visualization/process_monitor/application.py +10 -6
  40. pyfemtet/opt/visualization/process_monitor/pages.py +102 -0
  41. pyfemtet/opt/visualization/result_viewer/application.py +6 -0
  42. pyfemtet/opt/visualization/result_viewer/pages.py +1 -1
  43. {pyfemtet-0.4.20.dist-info → pyfemtet-0.4.23.dist-info}/METADATA +3 -4
  44. {pyfemtet-0.4.20.dist-info → pyfemtet-0.4.23.dist-info}/RECORD +47 -59
  45. pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01.femprj +0 -0
  46. pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01.prt +0 -0
  47. pyfemtet/FemtetPJTSample/NX_ex01/NX_ex01.py +0 -118
  48. pyfemtet/FemtetPJTSample/Sldworks_ex01/Sldworks_ex01.SLDPRT +0 -0
  49. pyfemtet/FemtetPJTSample/Sldworks_ex01/Sldworks_ex01.femprj +0 -0
  50. pyfemtet/FemtetPJTSample/Sldworks_ex01/Sldworks_ex01.py +0 -121
  51. pyfemtet/FemtetPJTSample/_her_ex40_parametric.py +0 -148
  52. pyfemtet/FemtetPJTSample/gau_ex08_parametric.femprj +0 -0
  53. pyfemtet/FemtetPJTSample/gau_ex08_parametric.py +0 -58
  54. pyfemtet/FemtetPJTSample/her_ex40_parametric.femprj +0 -0
  55. pyfemtet/FemtetPJTSample/her_ex40_parametric.py +0 -148
  56. pyfemtet/FemtetPJTSample/wat_ex14_parallel_parametric.py +0 -65
  57. pyfemtet/FemtetPJTSample/wat_ex14_parametric.femprj +0 -0
  58. pyfemtet/FemtetPJTSample/wat_ex14_parametric.py +0 -64
  59. {pyfemtet-0.4.20.dist-info → pyfemtet-0.4.23.dist-info}/LICENSE +0 -0
  60. {pyfemtet-0.4.20.dist-info → pyfemtet-0.4.23.dist-info}/WHEEL +0 -0
  61. {pyfemtet-0.4.20.dist-info → pyfemtet-0.4.23.dist-info}/entry_points.txt +0 -0
@@ -215,6 +215,24 @@ class FemtetInterface(FEMInterface):
215
215
  if self.Femtet is None:
216
216
  raise RuntimeError(Msg.ERR_FEMTET_CONNECTION_FAILED)
217
217
 
218
+ def _check_gaudi_accessible(self) -> bool:
219
+ try:
220
+ _ = self.Femtet.Gaudi
221
+ except com_error:
222
+ # モデルが開かれていないかFemtetが起動していない
223
+ return False
224
+ return True
225
+
226
+ # noinspection PyMethodMayBeStatic
227
+ def _construct_femtet_api(self, string): # static にしてはいけない
228
+ if isinstance(string, str):
229
+ if string.startswith('self.'):
230
+ return eval(string)
231
+ else:
232
+ return eval('self.' + string)
233
+ else:
234
+ return string # Callable
235
+
218
236
  def _call_femtet_api(
219
237
  self,
220
238
  fun,
@@ -232,7 +250,7 @@ class FemtetInterface(FEMInterface):
232
250
 
233
251
  Parameters
234
252
  ----------
235
- fun : Callable
253
+ fun : Callable or str
236
254
  Femtet API
237
255
  return_value_if_failed : Any
238
256
  API が失敗した時の戻り値
@@ -271,36 +289,62 @@ class FemtetInterface(FEMInterface):
271
289
  # 1. 結果に関わらず戻り値が None で API 実行時に com_error を送出する
272
290
  # 2. API 実行時に成功失敗を示す戻り値を返し、ShowLastError で例外にアクセスできる状態になる
273
291
 
292
+ # 実行する API をデバッグ出力
293
+ if isinstance(fun, str):
294
+ logger.debug(' ' * print_indent + f'Femtet API:{fun}, args:{args}, kwargs:{kwargs}')
295
+ else:
296
+ logger.debug(' ' * print_indent + f'Femtet API:{fun.__name__}, args:{args}, kwargs:{kwargs}')
297
+
274
298
  # Gaudi コマンドなら Gaudi.Activate する
275
- logger.debug(' ' * print_indent + f'Femtet API:{fun.__name__}, args:{args}, kwargs:{kwargs}')
276
299
  if is_Gaudi_method: # Optimizer は Gogh に触らないので全部にこれをつけてもいい気がする
277
300
  try:
278
- self._call_femtet_api(
279
- self.Femtet.Gaudi.Activate,
280
- False, # None 以外なら何でもいい
281
- Exception,
282
- 'Gaudi のオープンに失敗しました',
283
- print_indent=print_indent + 1
284
- )
301
+ # まず Gaudi にアクセスできるか
302
+ gaudi_accessible = self._check_gaudi_accessible()
303
+ if gaudi_accessible:
304
+ # Gaudi にアクセスできるなら Gaudi を Activate する
305
+ fun = self._construct_femtet_api(fun) # (str) -> Callable
306
+ if fun.__name__ != 'Activate':
307
+ # 再帰ループにならないように
308
+ self._call_femtet_api(
309
+ self.Femtet.Gaudi.Activate,
310
+ False, # None 以外なら何でもいい
311
+ Exception,
312
+ 'Gaudi のオープンに失敗しました',
313
+ print_indent=print_indent + 1
314
+ )
315
+
316
+ else:
317
+ # Gaudi にアクセスできないならば次の API 実行でエラーになる
318
+ pass
319
+
285
320
  except com_error:
286
- # Gaudi へのアクセスだけで com_error が生じうる
287
- # そういう場合は次の API 実行で間違いなくエラーになるので放っておく
288
321
  pass
289
322
 
290
323
  # API を実行
291
324
  try:
325
+ # gaudi のメソッドかどうかにかかわらず、gaudi へのアクセスでエラーが出るか
326
+ if not self._check_gaudi_accessible():
327
+ raise com_error
328
+
329
+ # gaudi_accessible なので関数が何であろうが安全にアクセスはできる
330
+ if isinstance(fun, str):
331
+ fun = self._construct_femtet_api(fun) # (str) -> Callable
332
+
292
333
  # 解析結果を開いた状態で Gaudi.Activate して ReExecute する場合、ReExecute の前後にアクティブ化イベントが必要
293
334
  # さらに、プロジェクトツリーが開いていないとアクティブ化イベントも意味がないらしい。
294
335
  if fun.__name__ == 'ReExecute':
295
336
  if self.open_result_with_gui or self.parametric_output_indexes_use_as_objective:
296
337
  post_activate_message(self.Femtet.hWnd)
338
+ # API を実行
297
339
  returns = fun(*args, **kwargs) # can raise pywintypes.error
298
340
  if self.open_result_with_gui or self.parametric_output_indexes_use_as_objective:
299
341
  post_activate_message(self.Femtet.hWnd)
300
342
  else:
301
343
  returns = fun(*args, **kwargs)
344
+
345
+ # API の実行に失敗
302
346
  except (com_error, error):
303
- # パターン 2 エラーが生じたことは確定なのでエラーが起こるよう returns を作る
347
+ # 後続の処理でエラー判定されるように returns を作る
304
348
  # com_error ではなく error の場合はおそらく Femtet が落ちている
305
349
  if ret_for_check_idx is None:
306
350
  returns = return_value_if_failed
@@ -361,7 +405,7 @@ class FemtetInterface(FEMInterface):
361
405
 
362
406
  def femtet_is_alive(self) -> bool:
363
407
  """Returns connected femtet process is existing or not."""
364
- return _get_pid(self.Femtet.hWnd) > 0
408
+ return _get_pid(self.Femtet.hWnd) > 0 # hWnd の値はすでに Femtet が終了している場合は 0
365
409
 
366
410
  def open(self, femprj_path: str, model_name: str or None = None) -> None:
367
411
  """Open specific analysis model with connected Femtet."""
@@ -479,7 +523,7 @@ class FemtetInterface(FEMInterface):
479
523
  # 変数更新のための処理
480
524
  sleep(0.1) # Gaudi がおかしくなる時がある対策
481
525
  self._call_femtet_api(
482
- self.Femtet.Gaudi.Activate,
526
+ 'self.Femtet.Gaudi.Activate',
483
527
  True, # 戻り値を持たないのでここは無意味で None 以外なら何でもいい
484
528
  Exception, # 生きてるのに開けない場合
485
529
  error_message=Msg.NO_ANALYSIS_MODEL_IS_OPEN,
@@ -553,7 +597,7 @@ class FemtetInterface(FEMInterface):
553
597
 
554
598
  # 設計変数に従ってモデルを再構築
555
599
  self._call_femtet_api(
556
- self.Femtet.Gaudi.ReExecute,
600
+ 'self.Femtet.Gaudi.ReExecute',
557
601
  False,
558
602
  ModelError, # 生きてるのに失敗した場合
559
603
  error_message=Msg.ERR_RE_EXECUTE_MODEL_FAILED,
@@ -576,7 +620,7 @@ class FemtetInterface(FEMInterface):
576
620
  """Execute FEM analysis."""
577
621
  # # メッシュを切る
578
622
  self._call_femtet_api(
579
- self.Femtet.Gaudi.Mesh,
623
+ 'self.Femtet.Gaudi.Mesh',
580
624
  0,
581
625
  MeshError,
582
626
  Msg.ERR_MODEL_MESH_FAILED,
@@ -630,6 +674,11 @@ class FemtetInterface(FEMInterface):
630
674
  def quit(self, timeout=1, force=True):
631
675
  """Force to terminate connected Femtet."""
632
676
  major, minor, bugfix = 2024, 0, 1
677
+
678
+ # すでに終了しているならば何もしない
679
+ if not self.femtet_is_alive():
680
+ return
681
+
633
682
  if self._version() >= _version(major, minor, bugfix):
634
683
  # gracefully termination method without save project available from 2024.0.1
635
684
  try:
@@ -715,16 +764,20 @@ class FemtetInterface(FEMInterface):
715
764
  # save to worker space
716
765
  result_dir = self.femprj_path.replace('.femprj', '.Results')
717
766
  pdt_path = os.path.join(result_dir, self.model_name + '.pdt')
718
- succeed = self.Femtet.SavePDT(pdt_path, True)
719
767
 
720
- # convert .pdt to ByteIO
721
- if succeed:
722
- with open(pdt_path, 'rb') as f:
723
- content = f.read()
724
- return content
768
+ self._call_femtet_api(
769
+ fun=self.Femtet.SavePDT,
770
+ args=(pdt_path, True),
771
+ return_value_if_failed=False,
772
+ if_error=SolveError,
773
+ error_message=Msg.ERR_FAILED_TO_SAVE_PDT,
774
+ is_Gaudi_method=False,
775
+ )
725
776
 
726
- else:
727
- raise Exception(Msg.ERR_FAILED_TO_SAVE_PDT)
777
+ # convert .pdt to ByteIO and return it
778
+ with open(pdt_path, 'rb') as f:
779
+ content = f.read()
780
+ return content
728
781
 
729
782
  else:
730
783
  return None
pyfemtet/opt/opt/_base.py CHANGED
@@ -1,5 +1,6 @@
1
1
  # typing
2
2
  from abc import ABC, abstractmethod
3
+ from typing import Optional
3
4
 
4
5
  # built-in
5
6
  import traceback
@@ -13,6 +14,7 @@ import pandas as pd
13
14
  from pyfemtet.opt.interface import FemtetInterface
14
15
  from pyfemtet.opt._femopt_core import OptimizationStatus
15
16
  from pyfemtet.message import Msg
17
+ from pyfemtet.opt.parameter import ExpressionEvaluator
16
18
 
17
19
  # logger
18
20
  import logging
@@ -134,6 +136,7 @@ class AbstractOptimizer(ABC):
134
136
  self.fem_class = None
135
137
  self.fem_kwargs = dict()
136
138
  self.parameters: pd.DataFrame = pd.DataFrame()
139
+ self.variables: ExpressionEvaluator = ExpressionEvaluator()
137
140
  self.objectives: dict = dict()
138
141
  self.constraints: dict = dict()
139
142
  self.entire_status = None # actor
@@ -145,7 +148,7 @@ class AbstractOptimizer(ABC):
145
148
  self.n_trials = None
146
149
  self.is_cluster = False
147
150
  self.subprocess_idx = None
148
- self._is_error_exit = False
151
+ self._exception = None
149
152
  self.method_checker: OptimizationMethodChecker = OptimizationMethodChecker(self)
150
153
 
151
154
  def f(self, x):
@@ -153,14 +156,22 @@ class AbstractOptimizer(ABC):
153
156
  # interruption の実装は具象クラスに任せる
154
157
 
155
158
  # x の更新
156
- self.parameters['value'] = x
159
+ prm_names = self.variables.get_parameter_names()
160
+ for name, value in zip(prm_names, x):
161
+ self.variables.variables[name].value = value
162
+
157
163
  logger.info('---------------------')
158
164
  logger.info(f'input: {x}')
159
165
 
160
166
  # FEM の更新
161
167
  logger.debug('fem.update() start')
162
168
  try:
163
- self.fem.update(self.parameters)
169
+ df_to_fem = self.variables.get_variables(
170
+ format='df',
171
+ filter_pass_to_fem=True
172
+ )
173
+ self.fem.update(df_to_fem)
174
+
164
175
  except Exception as e:
165
176
  logger.info(f'{type(e).__name__} : {e}')
166
177
  logger.info(Msg.INFO_EXCEPTION_DURING_FEM_ANALYSIS)
@@ -178,8 +189,14 @@ class AbstractOptimizer(ABC):
178
189
  c = [cns.calc(self.fem) for cns in self.constraints.values()]
179
190
 
180
191
  logger.debug('history.record start')
192
+
193
+ df_to_opt = self.variables.get_variables(
194
+ format='df',
195
+ filter_parameter=True,
196
+ )
197
+
181
198
  self.history.record(
182
- self.parameters,
199
+ df_to_opt,
183
200
  self.objectives,
184
201
  self.constraints,
185
202
  y,
@@ -220,17 +237,7 @@ class AbstractOptimizer(ABC):
220
237
  ValueError: If an invalid format is provided.
221
238
 
222
239
  """
223
- if format == 'df':
224
- return self.parameters
225
- elif format == 'values' or format == 'value':
226
- return self.parameters.value.values
227
- elif format == 'dict':
228
- ret = {}
229
- for i, row in self.parameters.iterrows():
230
- ret[row['name']] = row.value
231
- return ret
232
- else:
233
- raise ValueError(f'get_parameter() got invalid format: {format}')
240
+ return self.variables.get_variables(format=format)
234
241
 
235
242
  def _check_interruption(self):
236
243
  """"""
@@ -253,7 +260,7 @@ class AbstractOptimizer(ABC):
253
260
  worker_status_list,
254
261
  wait_setup,
255
262
  skip_set_fem=False,
256
- ) -> bool:
263
+ ) -> Optional[Exception]:
257
264
 
258
265
  # 自分の worker_status の取得
259
266
  self.subprocess_idx = subprocess_idx
@@ -296,12 +303,12 @@ class AbstractOptimizer(ABC):
296
303
  logger.error("=================================")
297
304
  logger.error(f'{type(e).__name__}: {e}')
298
305
  traceback.print_exc()
299
- self._is_error_exit = True
306
+ self._exception = e
300
307
  self.worker_status.set(OptimizationStatus.CRASHED)
301
308
  finally:
302
309
  self._finalize()
303
310
 
304
- return self._is_error_exit
311
+ return self._exception
305
312
 
306
313
  @abstractmethod
307
314
  def run(self) -> None:
@@ -54,6 +54,7 @@ class OptunaOptimizer(AbstractOptimizer):
54
54
  self.additional_initial_parameter = []
55
55
  self.additional_initial_methods = add_init_method if hasattr(add_init_method, '__iter__') else [add_init_method]
56
56
  self.method_checker = OptunaMethodChecker(self)
57
+ self.parameter_constraints = []
57
58
 
58
59
  def _objective(self, trial):
59
60
 
@@ -63,19 +64,25 @@ class OptunaOptimizer(AbstractOptimizer):
63
64
  trial.study.stop() # 現在実行中の trial を最後にする
64
65
  return None # set TrialState FAIL
65
66
 
66
- # candidate x
67
- x = []
68
- for i, row in self.parameters.iterrows():
69
- v = trial.suggest_float(row['name'], row['lb'], row['ub'], step=row['step'])
70
- x.append(v)
71
- x = np.array(x).astype(float)
67
+ # candidate x and update parameters
68
+ for prm in self.variables.get_variables(format='raw', filter_parameter=True):
69
+ value = trial.suggest_float(
70
+ name=prm.name,
71
+ low=prm.lower_bound,
72
+ high=prm.upper_bound,
73
+ step=prm.step,
74
+ )
75
+ self.variables.variables[prm.name].value = value
76
+
77
+ # update expressions
78
+ self.variables.evaluate()
72
79
 
73
80
  # message の設定
74
81
  self.message = trial.user_attrs['message'] if 'message' in trial.user_attrs.keys() else ''
75
82
 
76
- # fem や opt 経由で変数を取得して constraint を計算する時のためにアップデート
77
- self.parameters['value'] = x
78
- self.fem.update_parameter(self.parameters)
83
+ # fem 経由で変数を取得して constraint を計算する時のためにアップデート
84
+ df_fem = self.variables.get_variables(format='df', filter_pass_to_fem=True)
85
+ self.fem.update_parameter(df_fem)
79
86
 
80
87
  # strict 拘束
81
88
  strict_constraints = [cns for cns in self.constraints.values() if cns.strict]
@@ -89,10 +96,11 @@ class OptunaOptimizer(AbstractOptimizer):
89
96
  if not feasible:
90
97
  logger.info(Msg.INFO_INFEASIBLE)
91
98
  logger.info(f'Constraint: {cns.name}')
92
- logger.info(self.get_parameter('dict'))
99
+ logger.info(self.variables.get_variables('dict', filter_parameter=True))
93
100
  raise optuna.TrialPruned() # set TrialState PRUNED because FAIL causes similar candidate loop.
94
101
 
95
102
  # 計算
103
+ x = self.variables.get_variables(format='values', filter_parameter=True)
96
104
  try:
97
105
  _, _y, c = self.f(x) # f の中で info は出している
98
106
  except (ModelError, MeshError, SolveError) as e:
@@ -153,7 +161,7 @@ class OptunaOptimizer(AbstractOptimizer):
153
161
 
154
162
  # restart である場合、追加 N 回と見做す
155
163
  if self.history.is_restart:
156
- n_existing_trials = len(self.history.actor_data)
164
+ n_existing_trials = len(self.history.get_df())
157
165
  n_trials += n_existing_trials
158
166
 
159
167
  self.optimize_callbacks.append(MaxTrialsCallback(n_trials, states=(TrialState.COMPLETE,)))
@@ -175,7 +183,7 @@ class OptunaOptimizer(AbstractOptimizer):
175
183
  # 初期値の設定
176
184
  if len(self.study.trials) == 0: # リスタートでなければ
177
185
  # ユーザーの指定した初期値
178
- params = self.get_parameter('dict')
186
+ params = self.variables.get_variables('dict', filter_parameter=True)
179
187
  self.study.enqueue_trial(params, user_attrs={"message": "initial"})
180
188
 
181
189
  # add_initial_parameter で追加された初期値
@@ -197,8 +205,8 @@ class OptunaOptimizer(AbstractOptimizer):
197
205
  bounds = []
198
206
  for i, row in self.parameters.iterrows():
199
207
  names.append(row['name'])
200
- lb = row['lb']
201
- ub = row['ub']
208
+ lb = row['lower_bound']
209
+ ub = row['upper_bound']
202
210
  bounds.append([lb, ub])
203
211
  data = generate_lhs(bounds, seed=self.seed)
204
212
  for datum in data:
@@ -266,9 +274,40 @@ class OptunaOptimizer(AbstractOptimizer):
266
274
  sampler=sampler,
267
275
  )
268
276
 
277
+ # monkey patch
278
+ if len(self.parameter_constraints) > 0:
279
+ assert isinstance(sampler, optuna.integration.BoTorchSampler), Msg.ERR_PARAMETER_CONSTRAINT_ONLY_BOTORCH
280
+
281
+ from pyfemtet.opt.opt._optuna_botorch_helper import OptunaBotorchWithParameterConstraintMonkeyPatch
282
+ mp = OptunaBotorchWithParameterConstraintMonkeyPatch(
283
+ study,
284
+ self,
285
+ )
286
+ for p_cns in self.parameter_constraints:
287
+ fun = p_cns['fun']
288
+ prm_args = p_cns['prm_args']
289
+ kwargs = p_cns['kwargs']
290
+ mp.add_nonlinear_constraint(fun, prm_args, kwargs)
291
+ mp.do_monkey_patch()
292
+
269
293
  # run
270
294
  study.optimize(
271
295
  self._objective,
272
296
  timeout=self.timeout,
273
297
  callbacks=self.optimize_callbacks,
274
298
  )
299
+
300
def add_parameter_constraints(
        self,
        fun,
        prm_args=None,
        kwargs=None
):
    """Register a constraint on the design parameters (BoTorchSampler only).

    Each call appends a dict with keys ``fun``, ``prm_args`` and ``kwargs`` to
    ``self.parameter_constraints``. When the study runs, these entries are
    turned into BoTorch nonlinear constraints. That step asserts that the
    sampler is ``optuna.integration.BoTorchSampler``.

    Parameters
    ----------
    fun : Callable
        Constraint function; parameter values are passed to it by keyword.
    prm_args : Iterable[str] or None
        Names of ``fun``'s arguments that are optimization parameters.
        ``None`` means they are taken from ``fun``'s signature.
    kwargs : dict or None
        Extra fixed keyword arguments for ``fun``. ``None`` means none.
        The mapping is copied, so the caller's dict is never shared.
    """
    # Copy: the stored dict is handed to the constraint wrapper, which fills
    # parameter values into it; that must not leak into the caller's mapping.
    kwargs = {} if kwargs is None else dict(kwargs)
    self.parameter_constraints.append(
        dict(
            fun=fun,
            prm_args=prm_args,
            kwargs=kwargs,
        )
    )
@@ -0,0 +1,209 @@
1
+ from typing import Optional, List, Tuple, Callable
2
+ from functools import partial
3
+ import inspect
4
+
5
+ import numpy as np
6
+ import optuna.study
7
+ import torch
8
+ from torch import Tensor
9
+ from botorch.optim.initializers import gen_batch_initial_conditions
10
+ from botorch.utils.transforms import unnormalize
11
+ from optuna.study import Study
12
+ from botorch.acquisition import AcquisitionFunction
13
+
14
+ from pyfemtet.opt.opt import AbstractOptimizer
15
+ from pyfemtet.opt.parameter import ExpressionEvaluator
16
+
17
+ # module to monkey patch
18
+ import optuna_integration
19
+
20
+
21
# ``functools.partial`` variant used by the monkey patch. Keyword arguments fixed when the
# partial is built (in the MonkeyPatch class) take precedence over the ones optuna passes at
# call time, so optuna cannot overwrite them.
class NonOverwritablePartial(partial):
    """Partial whose stored keyword arguments override call-time keyword arguments."""

    def __call__(self, /, *args, **keywords):
        # Later entries win in a dict display, so the stored keywords beat the caller's.
        merged = {**keywords, **self.keywords}
        return self.func(*self.args, *args, **merged)
27
+
28
+
29
# Adapter that turns a function taking parameter names as keyword arguments into a callable
# that can be put into optimize_acqf's ``nonlinear_inequality_constraints``.
class ConvertedConstraintFunction:
    """Wrap ``fun(**kwargs)`` so it accepts one normalized parameter vector.

    ``bounds`` and ``prm_name_seq`` are ``None`` after construction. They must
    be assigned before the instance is called; in this module,
    ``OptunaBotorchWithParameterConstraintMonkeyPatch`` assigns them from the
    study's search space.
    """

    def __init__(self, fun, prm_args, kwargs, variables: 'ExpressionEvaluator',
                 study: 'optuna.study.Study'):
        """
        Parameters
        ----------
        fun : Callable
            User constraint function; parameter values are passed by keyword.
        prm_args : Iterable[str] or None
            Names of ``fun``'s arguments that are optimization parameters.
            ``None`` means every argument in ``fun``'s signature.
        kwargs : dict or None
            Fixed keyword arguments for ``fun``. Their names are excluded from
            the parameter arguments. ``None`` is treated as ``{}``.
        variables : ExpressionEvaluator
            Used to check that the parameter argument names are known parameters.
        study : optuna.study.Study
            Stored on the instance (not used by this class).

        Raises
        ------
        ValueError
            If a parameter argument name is not a known parameter name.
        """
        self.fun = fun
        self.prm_args = prm_args
        self.kwargs = {} if kwargs is None else kwargs
        self.variables = variables
        self.study = study

        # Assigned later (before the first call) from the study's search space.
        self.bounds = None
        self.prm_name_seq = None

        # Without an explicit prm_args, every argument of fun is a candidate.
        if self.prm_args is None:
            prm_inputs = set(inspect.signature(fun).parameters)
        else:
            prm_inputs = set(self.prm_args)

        # Arguments supplied through kwargs are not optimization parameters.
        self.prm_arg_names = prm_inputs - set(self.kwargs)

        # Reject leftover argument names that are not parameters (raise, not assert,
        # so the check survives ``python -O``).
        unknown = self.prm_arg_names - set(variables.get_parameter_names())
        if unknown:
            raise ValueError(
                f'Constraint function arguments {sorted(unknown)} are not parameter names.'
            )

    def __call__(self, x: 'Tensor | np.ndarray'):
        """Evaluate ``fun`` at one normalized candidate ``x``.

        ``x`` holds all parameters, normalized, in the order of
        ``prm_name_seq`` (the order used by optuna). It is unnormalized with
        ``bounds`` before the values are passed to ``fun``.
        """
        if not isinstance(x, Tensor):
            x = torch.tensor(np.array(x)).double()

        x = unnormalize(x, self.bounds)

        # Work on a copy so parameter values never leak into the stored (caller's) kwargs.
        call_kwargs = dict(self.kwargs)
        call_kwargs.update(
            {k: v for k, v in zip(self.prm_name_seq, x) if k in self.prm_arg_names}
        )

        return self.fun(**call_kwargs)
69
+
70
+
71
# Acquisition function wrapper that returns 0 for candidates violating a constraint.
class AcqWithConstraint(AcquisitionFunction):
    """Wrap an acquisition function and force its value to 0 at infeasible candidates.

    In this module it is built as ``AcqWithConstraint(None)`` and then
    configured with ``set``.
    """

    # noinspection PyAttributeOutsideInit
    def set(self, _org_acq_function: AcquisitionFunction, nonlinear_constraints):
        """Attach the wrapped acquisition function and ``(callable, bool)`` constraint pairs."""
        self._org_acq_function = _org_acq_function
        self._nonlinear_constraints = nonlinear_constraints

    def _is_feasible(self, candidate: Tensor) -> bool:
        """Return True when every point of ``candidate`` satisfies every constraint (> 0)."""
        return all(
            bool(cons(point) > 0)
            for point in candidate
            for cons, _ in self._nonlinear_constraints
        )

    def forward(self, X: Tensor) -> Tensor:
        base = self._org_acq_function.forward(X)

        # Judge feasibility per t-batch element (previously only X[0][0] was checked,
        # so one candidate decided the fate of the whole batch).
        # Assumes X is ``batch_shape x q x d`` (botorch convention) — TODO confirm for other inputs.
        candidates = X.reshape(-1, X.shape[-2], X.shape[-1])
        feasible = torch.tensor(
            [self._is_feasible(c) for c in candidates],
            dtype=torch.bool,
            device=base.device,
        ).reshape(X.shape[:-2])

        # torch.where keeps feasible values unchanged and yields exactly 0 elsewhere
        # (``base * 0.`` would turn an infinite base into nan).
        return torch.where(feasible, base, torch.zeros_like(base))
90
+
91
+
92
def remove_infeasible(_ic_batch, nonlinear_constraints):
    """Drop initial-condition candidates that violate any nonlinear constraint.

    Parameters
    ----------
    _ic_batch : torch.Tensor
        Candidates stacked along dim 0; each row is a ``q x d`` tensor
        (``num_restarts x q x d`` as produced by gen_batch_initial_conditions).
    nonlinear_constraints : Sequence[tuple[Callable, bool]]
        ``(callable, flag)`` pairs. A point is feasible when ``callable(point) > 0``.

    Returns
    -------
    torch.Tensor
        The rows whose every q-point satisfies every constraint, in their
        original order (possibly zero rows).
    """
    # One boolean per row, then a single masked index: linear instead of the
    # previous repeated torch.cat, and every q-point is checked, not only the first.
    keep = [
        all(
            bool(cons(point) > 0)
            for point in ic
            for cons, _ in nonlinear_constraints
        )
        for ic in _ic_batch
    ]
    mask = torch.tensor(keep, dtype=torch.bool, device=_ic_batch.device)
    return _ic_batch[mask]
104
+
105
+
106
class OptunaBotorchWithParameterConstraintMonkeyPatch:
    """Pass parameter constraints to the ``optimize_acqf`` call of optuna_integration.botorch.

    Typical use: construct with the study and the optimizer, call
    ``add_nonlinear_constraint`` for each constraint, then call
    ``do_monkey_patch`` before optimization starts.
    """

    def __init__(self, study: Study, opt: AbstractOptimizer):
        # Forwarded to optimize_acqf as ``num_restarts`` / ``raw_samples``.
        self.num_restarts: int = 20
        self.raw_samples_additional: int = 512
        # NOTE(review): not used anywhere in this class.
        self.eta: float = 2.0
        self.study = study
        self.opt = opt
        # (constraint callable, True) pairs handed to optimize_acqf.
        self.nonlinear_inequality_constraints = []
        self.additional_kwargs = dict()
        # Normalization bounds and parameter order, detected lazily from the study.
        self.bounds = None
        self.prm_name_seq = None

    def add_nonlinear_constraint(self, fun, prm_args, kwargs):
        """Wrap ``fun`` as a ConvertedConstraintFunction and register it."""
        f = ConvertedConstraintFunction(
            fun,
            prm_args,
            kwargs,
            self.opt.variables,
            self.study,
        )

        # Re-create the list only if it was reset to None; an existing (even empty)
        # list is kept so that references to it stay valid.
        if self.nonlinear_inequality_constraints is None:
            self.nonlinear_inequality_constraints = []

        # True is botorch's ``is_intrapoint`` flag (constraint applied to each point).
        # NOTE(review): botorch treats ``callable(x) >= 0`` as feasible, while the
        # filters in this module use ``> 0``.
        self.nonlinear_inequality_constraints.append((f, True))

        # Arguments to be passed to optimize_acqf().
        self.additional_kwargs.update(
            nonlinear_inequality_constraints=self.nonlinear_inequality_constraints
        )

    def _detect_prm_seq_if_needed(self):
        """Restore bounds and parameter order from the study's distributions (once)."""
        if self.bounds is None or self.prm_name_seq is None:
            from optuna._transform import _transform_search_space
            # Called after sample_relative, so the last trial should have a search space.
            search_space: dict = self.study.sampler.infer_relative_search_space(self.study, self.study.trials[-1])
            self.bounds = _transform_search_space(search_space, False, False)[0].T
            self.prm_name_seq = list(search_space.keys())

            for cns in self.nonlinear_inequality_constraints:
                cns[0].bounds = torch.tensor(self.bounds)
                cns[0].prm_name_seq = self.prm_name_seq

    def generate_initial_conditions(self, *args, **kwargs):
        """Constraint-aware replacement for botorch's initial-condition generator.

        Installed as ``ic_generator`` of ``optimize_acqf``. It receives the same
        arguments as ``gen_batch_initial_conditions`` and returns an
        ``n x q x d`` tensor whose rows all satisfy the registered constraints.
        """
        self._detect_prm_seq_if_needed()

        # Replace acq_function so that it returns 0 where the constraints are violated.
        org_acq_function = kwargs['acq_function']
        new_acqf = AcqWithConstraint(None)
        new_acqf.set(org_acq_function, self.nonlinear_inequality_constraints)
        kwargs['acq_function'] = new_acqf

        # ic_batch: ``num_restarts x q x d`` initial conditions (q = 1, d = number of parameters).
        ic_batch = gen_batch_initial_conditions(*args, **kwargs)

        # Drop the candidates that violate the constraints.
        ic_batch = remove_infeasible(ic_batch, self.nonlinear_inequality_constraints)

        # If nothing is left, fall back to uniform random points in the normalized space.
        if len(ic_batch) == 0:
            print('拘束を満たす組み合わせがなかったのでランダムサンプリングします')
            while len(ic_batch) == 0:
                # Keep botorch's dtype/device: a default float32 tensor would not match
                # the (presumably double) inputs used downstream by optimize_acqf.
                ic_batch = torch.rand(
                    size=[100, *ic_batch.shape[1:]],
                    dtype=ic_batch.dtype,
                    device=ic_batch.device,
                )
                ic_batch = remove_infeasible(ic_batch, self.nonlinear_inequality_constraints)
                # NOTE(review): this loops forever if the constraints are unsatisfiable.

        return ic_batch

    def do_monkey_patch(self):
        """Replace ``optuna_integration.botorch.optimize_acqf`` with a constrained version.

        optuna_integration.botorch offers no way to pass constraints to
        optimize_acqf, so the module attribute is monkey patched.

        The patch works whenever it is applied before optimization. It must run
        after ``additional_kwargs`` is updated, which is why it is not done in
        the constructor; doing it in each add_constraint would also work.
        Calling it again replaces the previous patch instead of stacking on it.
        """
        # === reconstruct argument ``options`` for optimize_acqf ===
        options = dict(
            batch_limit=1,  # for nonlinear constraints
            method='SLSQP',  # gen_candidates_scipy(): only COBYLA or SLSQP can be used
        )

        # Unwrap an earlier patch: the inner partial applies its stored keywords last,
        # so stacking would let the old constraints / ic_generator override the new ones.
        original_fun = optuna_integration.botorch.optimize_acqf
        while isinstance(original_fun, NonOverwritablePartial):
            original_fun = original_fun.func

        # Make a partial of optimize_acqf used in optuna_integration.botorch and install it.
        overwritten_fun = NonOverwritablePartial(
            original_fun,
            q=1,  # for nonlinear constraints
            options=options,
            num_restarts=self.num_restarts,  # forwarded to self.generate_initial_conditions
            raw_samples=self.raw_samples_additional,  # forwarded to self.generate_initial_conditions
            nonlinear_inequality_constraints=self.nonlinear_inequality_constraints,
            ic_generator=self.generate_initial_conditions,
        )
        optuna_integration.botorch.optimize_acqf = overwritten_fun
@@ -114,7 +114,7 @@ class ScipyOptimizer(AbstractOptimizer):
114
114
  if 'bounds' not in self.minimize_kwargs.keys():
115
115
  bounds = []
116
116
  for i, row in self.parameters.iterrows():
117
- lb, ub = row['lb'], row['ub']
117
+ lb, ub = row['lower_buond'], row['upper_bound']
118
118
  if lb is None: lb = -np.inf
119
119
  if ub is None: ub = np.inf
120
120
  bounds.append([lb, ub])
@@ -78,7 +78,7 @@ class ScipyScalarOptimizer(AbstractOptimizer):
78
78
  if 'bounds' not in self.minimize_kwargs.keys():
79
79
  bounds = []
80
80
  for i, row in self.parameters.iterrows():
81
- lb, ub = row['lb'], row['ub']
81
+ lb, ub = row['lower_bound'], row['upper_bound']
82
82
  if lb is None: lb = -np.inf
83
83
  if ub is None: ub = np.inf
84
84
  bounds.append([lb, ub])