pyfemtet 0.8.4__py3-none-any.whl → 0.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyfemtet might be problematic. Click here for more details.

@@ -1,4 +1,5 @@
1
1
  # typing
2
+ import json
2
3
  from typing import List, TYPE_CHECKING
3
4
 
4
5
  # built-in
@@ -524,7 +525,6 @@ class History:
524
525
  obj_names (List[str], optional): The names of objectives. Defaults to None.
525
526
  cns_names (List[str], optional): The names of constraints. Defaults to None.
526
527
  client (dask.distributed.Client): Dask client.
527
- additional_metadata (str, optional): metadata of optimization process.
528
528
  hv_reference (str or list[float or np.ndarray, optional):
529
529
  The method to calculate hypervolume or
530
530
  the reference point itself.
@@ -550,7 +550,6 @@ class History:
550
550
  obj_names=None,
551
551
  cns_names=None,
552
552
  client=None,
553
- additional_metadata=None,
554
553
  hv_reference=None,
555
554
  ):
556
555
  # hypervolume 計算メソッド
@@ -561,7 +560,8 @@ class History:
561
560
  self.prm_names = prm_names
562
561
  self.obj_names = obj_names
563
562
  self.cns_names = cns_names
564
- self.additional_metadata = additional_metadata or ''
563
+ self.extra_data = dict()
564
+ self.meta_columns = None
565
565
  self.__scheduler_address = client.scheduler.address if client is not None else None
566
566
 
567
567
  # 最適化実行中かどうか
@@ -575,15 +575,15 @@ class History:
575
575
 
576
576
  # 続きからなら df を読み込んで df にコピー
577
577
  if self.is_restart:
578
- self.load()
578
+ self.load() # 中で meta_columns を読む
579
579
 
580
580
  # そうでなければ df を初期化
581
581
  else:
582
- columns, metadata = self.create_df_columns()
582
+ columns, meta_columns = self.create_df_columns()
583
583
  df = pd.DataFrame()
584
584
  for c in columns:
585
585
  df[c] = None
586
- self.metadata = metadata
586
+ self.meta_columns = meta_columns
587
587
  self.set_df(df)
588
588
 
589
589
  # 一時ファイルに書き込みを試み、UnicodeEncodeError が出ないかチェック
@@ -609,16 +609,16 @@ class History:
609
609
  # df を読み込む
610
610
  df = pd.read_csv(self.path, encoding=self.ENCODING, header=self.HEADER_ROW)
611
611
 
612
- # metadata を読み込む
612
+ # meta_columns を読み込む
613
613
  with open(self.path, mode='r', encoding=self.ENCODING, newline='\n') as f:
614
614
  reader = csv.reader(f, delimiter=',')
615
- self.metadata = reader.__next__()
615
+ self.meta_columns = reader.__next__()
616
616
 
617
617
  # 最適化問題を読み込む
618
618
  columns = df.columns
619
- prm_names = [column for i, column in enumerate(columns) if self.metadata[i] == 'prm']
620
- obj_names = [column for i, column in enumerate(columns) if self.metadata[i] == 'obj']
621
- cns_names = [column for i, column in enumerate(columns) if self.metadata[i] == 'cns']
619
+ prm_names = [column for i, column in enumerate(columns) if self.meta_columns[i] == 'prm']
620
+ obj_names = [column for i, column in enumerate(columns) if self.meta_columns[i] == 'obj']
621
+ cns_names = [column for i, column in enumerate(columns) if self.meta_columns[i] == 'cns']
622
622
 
623
623
  # is_restart の場合、読み込んだ names と引数の names が一致するか確認しておく
624
624
  if self.is_restart:
@@ -662,7 +662,7 @@ class History:
662
662
  logger.debug('Access df of History before it is initialized.')
663
663
  return pd.DataFrame()
664
664
  except OSError:
665
- logger.error('Scheduler is already dead. Most frequent reasen to show this message is that the pyfemtet monitor UI is not refreshed even if the main optimization process is terminated.')
665
+ logger.error('Scheduler is already dead. Most frequent reason to show this message is that the pyfemtet monitor UI is not refreshed even if the main optimization process is terminated.')
666
666
  return pd.DataFrame()
667
667
 
668
668
  def set_df(self, df: pd.DataFrame):
@@ -687,46 +687,46 @@ class History:
687
687
  columns = list()
688
688
 
689
689
  # columns のメタデータを作成
690
- metadata = list()
690
+ meta_columns = list()
691
691
 
692
692
  # trial
693
693
  columns.append('trial') # index
694
- metadata.append(self.additional_metadata)
694
+ meta_columns.append('') # extra_data. save 時に中身を記入する。
695
695
 
696
696
  # parameter
697
697
  for prm_name in self.prm_names:
698
698
  columns.extend([prm_name, prm_name + '_lower_bound', prm_name + '_upper_bound'])
699
- metadata.extend(['prm', 'prm_lb', 'prm_ub'])
699
+ meta_columns.extend(['prm', 'prm_lb', 'prm_ub'])
700
700
 
701
701
  # objective relative
702
702
  for name in self.obj_names:
703
703
  columns.append(name)
704
- metadata.append('obj')
704
+ meta_columns.append('obj')
705
705
  columns.append(name + '_direction')
706
- metadata.append('obj_direction')
706
+ meta_columns.append('obj_direction')
707
707
  columns.append('non_domi')
708
- metadata.append('')
708
+ meta_columns.append('')
709
709
 
710
710
  # constraint relative
711
711
  for name in self.cns_names:
712
712
  columns.append(name)
713
- metadata.append('cns')
713
+ meta_columns.append('cns')
714
714
  columns.append(name + '_lower_bound')
715
- metadata.append('cns_lb')
715
+ meta_columns.append('cns_lb')
716
716
  columns.append(name + '_upper_bound')
717
- metadata.append('cns_ub')
717
+ meta_columns.append('cns_ub')
718
718
  columns.append('feasible')
719
- metadata.append('')
719
+ meta_columns.append('')
720
720
 
721
721
  # the others
722
722
  columns.append('hypervolume')
723
- metadata.append('')
723
+ meta_columns.append('')
724
724
  columns.append('message')
725
- metadata.append('')
725
+ meta_columns.append('')
726
726
  columns.append('time')
727
- metadata.append('')
727
+ meta_columns.append('')
728
728
 
729
- return columns, metadata
729
+ return columns, meta_columns
730
730
 
731
731
  def record(
732
732
  self,
@@ -973,13 +973,16 @@ class History:
973
973
 
974
974
  df = self.get_df()
975
975
 
976
+ # extra_data の更新
977
+ self.meta_columns[0] = json.dumps(self.extra_data)
978
+
976
979
  if _f is None:
977
980
  # save df with columns with prefix
978
981
  with open(self.path, 'w', encoding=self.ENCODING) as f:
979
982
  writer = csv.writer(f, delimiter=',', lineterminator="\n")
980
- writer.writerow(self.metadata)
983
+ writer.writerow(self.meta_columns)
981
984
  for i in range(self.HEADER_ROW-1):
982
- writer.writerow([''] * len(self.metadata))
985
+ writer.writerow([''] * len(self.meta_columns))
983
986
  df.to_csv(f, index=None, encoding=self.ENCODING, lineterminator='\n')
984
987
  else: # test
985
988
  df.to_csv(_f, index=None, encoding=self.ENCODING, lineterminator='\n')
@@ -1088,3 +1091,35 @@ class OptimizationStatus:
1088
1091
  def get_text(self) -> str:
1089
1092
  """Get optimization status message."""
1090
1093
  return self._actor.status
1094
+
1095
+
1096
+ class _MonitorHostRecordActor:
1097
+ host = None
1098
+ port = None
1099
+
1100
+ def set(self, host, port):
1101
+ self.host = host
1102
+ self.port = port
1103
+
1104
+
1105
+ class MonitorHostRecord:
1106
+
1107
+ def __init__(self, client, worker_name):
1108
+ self._future = client.submit(
1109
+ _MonitorHostRecordActor,
1110
+ actor=True,
1111
+ workers=(worker_name,),
1112
+ allow_other_workers=False,
1113
+ )
1114
+ self._actor = self._future.result()
1115
+
1116
+ def set(self, host, port):
1117
+ self._actor.set(host, port).result()
1118
+
1119
+ def get(self):
1120
+ host = self._actor.host
1121
+ port = self._actor.port
1122
+ if host is None and port is None:
1123
+ return dict()
1124
+ else:
1125
+ return dict(host=host, port=port)
@@ -10,7 +10,7 @@ from pyfemtet.opt import FEMOpt
10
10
  from pyfemtet._message import encoding as ENCODING
11
11
 
12
12
 
13
- def remove_femprj_metadata_from_csv(csv_path, encoding=ENCODING):
13
+ def remove_extra_data_from_csv(csv_path, encoding=ENCODING):
14
14
 
15
15
  with open(csv_path, mode="r", encoding=encoding, newline="\n") as f:
16
16
  reader = csv.reader(f, delimiter=",")
@@ -59,32 +59,39 @@ def _get_obj_from_csv(csv_path, encoding=ENCODING):
59
59
  reader = csv.reader(f, delimiter=",")
60
60
  meta = reader.__next__()
61
61
  obj_indices = np.where(np.array(meta) == "obj")[0]
62
- out = df.iloc[:, obj_indices]
62
+ out: pd.DataFrame = df.iloc[:, obj_indices]
63
+ out = out.dropna(axis=0)
63
64
  return out, columns
64
65
 
65
66
 
66
- def is_equal_result(ref_path, dif_path, log_path):
67
+ def is_equal_result(ref_path, dif_path, log_path=None, threashold=0.05):
67
68
  """Check the equality of two result csv files."""
68
69
  ref_df, ref_columns = _get_obj_from_csv(ref_path)
69
70
  dif_df, dif_columns = _get_obj_from_csv(dif_path)
70
71
 
71
- with open(log_path, "a", newline="\n", encoding=ENCODING) as f:
72
- f.write("\n\n===== 結果の分析 =====\n\n")
73
- f.write(f" \tref\tdif\n")
74
- f.write(f"---------------------\n")
75
- f.write(f"len(col)\t{len(ref_columns)}\t{len(dif_columns)}\n")
76
- f.write(f"len(df) \t{len(ref_df)}\t{len(dif_df)}\n")
77
- try:
78
- difference = (
72
+ if log_path is not None:
73
+ with open(log_path, "a", newline="\n", encoding=ENCODING) as f:
74
+ f.write("\n\n===== 結果の分析 =====\n\n")
75
+ f.write(f" \tref\tdif\n")
76
+ f.write(f"---------------------\n")
77
+ f.write(f"len(col)\t{len(ref_columns)}\t{len(dif_columns)}\n")
78
+ f.write(f"len(df) \t{len(ref_df)}\t{len(dif_df)}\n")
79
+ try:
80
+ difference = (
81
+ np.abs(ref_df.values - dif_df.values) / np.abs(dif_df.values)
82
+ ).mean()
83
+ f.write(f"diff \t{int(difference*100)}%\n")
84
+ except Exception:
85
+ f.write(f"diff \tcannot calc\n")
86
+
87
+ else:
88
+ difference = (
79
89
  np.abs(ref_df.values - dif_df.values) / np.abs(dif_df.values)
80
- ).mean()
81
- f.write(f"diff \t{int(difference*100)}%\n")
82
- except Exception:
83
- f.write(f"diff \tcannot calc\n")
90
+ ).mean()
84
91
 
85
92
  assert len(ref_columns) == len(dif_columns), "結果 csv の column 数が異なります。"
86
93
  assert len(ref_df) == len(dif_df), "結果 csv の row 数が異なります。"
87
- assert difference <= 0.05, "前回の結果との平均差異が 5% を超えています。"
94
+ assert difference <= threashold*100, f"前回の結果との平均差異が {int(difference)}% で {int(threashold*100)}% を超えています。"
88
95
 
89
96
 
90
97
  def _get_simplified_df_values(csv_path, exclude_columns=None):
@@ -21,7 +21,9 @@ else:
21
21
  FemtetWithNXInterface = type('FemtetWithNXInterface', (FemtetInterface,), {})
22
22
  ExcelInterface = type('FemtetInterface', (NotAvailableForWindows,), {})
23
23
 
24
- from pyfemtet.opt.interface._surrogate import PoFBoTorchInterface
24
+ from pyfemtet.opt.interface._surrogate._base import SurrogateModelInterfaceBase
25
+ from pyfemtet.opt.interface._surrogate._singletaskgp import PoFBoTorchInterface
26
+
25
27
 
26
28
  __all__ =[
27
29
  'FEMInterface',
@@ -30,5 +32,6 @@ __all__ =[
30
32
  'FemtetWithSolidworksInterface',
31
33
  'FemtetWithNXInterface',
32
34
  'ExcelInterface',
35
+ 'SurrogateModelInterfaceBase',
33
36
  'PoFBoTorchInterface',
34
37
  ]
@@ -1,3 +1,4 @@
1
+ import warnings
1
2
  from typing import Optional, List, Final
2
3
 
3
4
  import os
@@ -90,6 +91,19 @@ class FemtetInterface(FEMInterface):
90
91
  it will be None and no parametric outputs are used
91
92
  as objectives.
92
93
 
94
+
95
+ Note:
96
+ Indexes start at 0, but the parametric analysis
97
+ output settings in the Femtet dialog box indicate
98
+ setting numbers starting at 1.
99
+
100
+
101
+ Warning:
102
+ **Setting this argument deletes the parametric
103
+ analysis swept table set in the femprj file.**
104
+ If you do not want to delete the swept table,
105
+ make a copy of the original file.
106
+
93
107
  **kwargs: Additional arguments from inherited classes.
94
108
 
95
109
  Warning:
@@ -117,6 +131,9 @@ class FemtetInterface(FEMInterface):
117
131
  parametric_output_indexes_use_as_objective: dict[int, str or float] = None,
118
132
  **kwargs # 継承されたクラスからの引数
119
133
  ):
134
+ # warning
135
+ if parametric_output_indexes_use_as_objective is not None:
136
+ warnings.warn('解析モデルに設定された既存のスイープテーブルは削除されます。')
120
137
 
121
138
  # win32com の初期化
122
139
  CoInitialize()
@@ -1,12 +1,147 @@
1
+ import os
2
+ import ctypes
3
+ # from ctypes import wintypes
1
4
  import logging
5
+ import warnings
6
+ from time import sleep, time
7
+ from packaging.version import Version
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ from femtetutils import util
12
+ from femtetutils import logger as util_logger
2
13
 
3
- from femtetutils import util, logger
4
14
  from pyfemtet.dispatch_extensions import _get_pid
15
+ from pyfemtet.core import SolveError
16
+ from pyfemtet._message.messages import encoding, Message
17
+ from pyfemtet.logger import get_module_logger
5
18
 
6
- import ctypes
19
+ logger = get_module_logger('opt.fem.ParametricIF', __name__)
20
+
21
+ util_logger.setLevel(logging.ERROR)
22
+
23
+ # singleton pattern
24
+ _P_CSV: 'ParametricResultCSVProcessor' = None
25
+
26
+
27
+ def get_csv_processor(Femtet):
28
+ global _P_CSV
29
+ if _P_CSV is None:
30
+ _P_CSV = ParametricResultCSVProcessor(Femtet)
31
+ return _P_CSV
32
+
33
+
34
+ class ParametricResultCSVProcessor:
35
+
36
+ def __init__(self, Femtet):
37
+ self.Femtet = Femtet
38
+
39
+ def refresh_csv(self):
40
+ # 存在するならば削除する
41
+ csv_paths = self.get_csv_paths()
42
+ for path in csv_paths:
43
+ if os.path.exists(path):
44
+ os.remove(path)
45
+
46
+ def get_csv_paths(self):
47
+ # 結果フォルダを取得
48
+ path: str = self.Femtet.Project
49
+ res_dir_path = path.removesuffix('.femprj') + '.Results'
50
+
51
+ # csv を取得
52
+ model_name = self.Femtet.AnalysisModelName
53
+ csv_path = os.path.join(res_dir_path, f'{model_name}.csv')
54
+ table_csv_path = os.path.join(res_dir_path, f'{model_name}_table.csv')
55
+
56
+ return csv_path, table_csv_path
57
+
58
+ def check_csv_after_succeeded_PrmCalcExecute(self):
59
+ """Parametric Solve の後に呼ぶこと。"""
60
+
61
+ csv_path, table_csv_path = self.get_csv_paths()
62
+
63
+ # csv が生成されているか
64
+ start = time()
65
+ while not os.path.exists(csv_path):
66
+ # solve が succeeded であるにもかかわらず
67
+ # 数秒経過しても csv が存在しないのはおかしい
68
+ if time() - start > 3.:
69
+ return False
70
+ sleep(0.25)
71
+
72
+ # csv は存在するが、Femtet が古いと
73
+ # table は生成されない
74
+ if not os.path.exists(table_csv_path):
75
+ warnings.warn('テーブル形式 csv が生成されていないため、'
76
+ '結果出力エラーチェックが行われません。'
77
+ 'そのため、結果出力にエラーがある場合は'
78
+ '目的関数が 0 と記録される場合があります。'
79
+ '結果出力エラーチェック機能を利用するためには、'
80
+ 'Femtet を最新バージョンにアップデートして'
81
+ 'ください。')
82
+
83
+ return True
84
+
85
+ def is_succeeded(self, parametric_output_index):
86
+
87
+ # まず csv 保存が成功しているかどうか。通常あるはず。
88
+ if not self.check_csv_after_succeeded_PrmCalcExecute():
89
+ return False, 'Reason: output csv not found.'
90
+
91
+ # 成功しているならば table があるかどうか
92
+ csv_path, table_csv_path = self.get_csv_paths()
93
+
94
+ # なければエラーチェックできないので
95
+ # エラーなしとみなす (warning の記載通り)
96
+ if not os.path.exists(table_csv_path):
97
+ return True, None
98
+
99
+ # table があれば読み込む
100
+ df = pd.read_csv(table_csv_path, encoding=encoding)
7
101
 
102
+ # 結果出力用に行番号を付記する
103
+ df['row_num'] = range(2, len(df) + 2) # row=0 は header, excel は 1 始まり
8
104
 
9
- logger.setLevel(logging.ERROR)
105
+ # 「結果出力設定番号」カラムが存在するか
106
+ if '結果出力設定番号' in df.columns: # TODO: 英語版対応
107
+
108
+ # 与えられた output_number に関連する行だけ抜き出し
109
+ # エラーがあるかどうかチェックする
110
+ pdf = df['結果出力設定番号'] == parametric_output_index + 1
111
+
112
+ # 結果出力設定番号 カラムが存在しない
113
+ else:
114
+ # output_number に関係なくエラーがあればエラーにする
115
+ pdf = df
116
+
117
+ # エラーの有無を確認
118
+ if 'エラー' in pdf.columns: # TODO: 英語版対応
119
+ is_no_error = np.all(pdf['エラー'].isna().values)
120
+
121
+ if not is_no_error:
122
+ error_message_row_numbers = pdf['row_num'][~pdf['エラー'].isna()].values.astype(str)
123
+ error_messages = pdf['エラー'][~pdf['エラー'].isna()].values.astype(str)
124
+
125
+ def add_st_or_nd_or_th(n_: int):
126
+ if n_ == 1:
127
+ return f'{n_}st'
128
+ elif n_ == 2:
129
+ return f'{n_}nd'
130
+ elif n_ == 3:
131
+ return f'{n_}rd'
132
+ else:
133
+ return f'{n_}th'
134
+
135
+ error_msg = f'Error message(s) from {os.path.basename(table_csv_path)}: ' + ', '.join(
136
+ [f'({add_st_or_nd_or_th(row)} row) {message}' for row, message in zip(error_message_row_numbers, error_messages)])
137
+ else:
138
+ error_msg = None
139
+
140
+ else:
141
+ raise RuntimeError('Internal Error! Parametric Analysis '
142
+ 'output csv has no error column.')
143
+
144
+ return is_no_error, error_msg
10
145
 
11
146
 
12
147
  def _get_dll():
@@ -46,7 +181,6 @@ def add_parametric_results_as_objectives(femopt, indexes, directions) -> bool:
46
181
 
47
182
  # get objective names
48
183
  dll.GetPrmnResult.restype = ctypes.c_int
49
- n = dll.GetPrmnResult()
50
184
  for i, direction in zip(indexes, directions):
51
185
  # objective name
52
186
  dll.GetPrmResultName.restype = ctypes.c_char_p
@@ -58,6 +192,15 @@ def add_parametric_results_as_objectives(femopt, indexes, directions) -> bool:
58
192
 
59
193
 
60
194
  def _parametric_objective(Femtet, parametric_result_index):
195
+ # csv から結果取得エラーの有無を確認する
196
+ # (解析自体は成功していないと objective は呼ばれないはず)
197
+ csv_processor = get_csv_processor(Femtet)
198
+ succeeded, error_msg = csv_processor.is_succeeded(parametric_result_index)
199
+ if not succeeded:
200
+ logger.error(Message.ERR_PARAMETRIC_CSV_CONTAINS_ERROR)
201
+ logger.error(error_msg)
202
+ raise SolveError
203
+
61
204
  # load dll and set target femtet
62
205
  dll = _get_dll_with_set_femtet(Femtet)
63
206
  dll.GetPrmResult.restype = ctypes.c_double # 複素数の場合は実部しか取らない
@@ -65,9 +208,65 @@ def _parametric_objective(Femtet, parametric_result_index):
65
208
 
66
209
 
67
210
  def solve_via_parametric_dll(Femtet) -> bool:
211
+ csv_processor = get_csv_processor(Femtet)
212
+
213
+ # remove previous csv if exists
214
+ # 消さなくても解析はできるが
215
+ # エラーハンドリングのため
216
+ csv_processor.refresh_csv()
217
+
68
218
  # load dll and set target femtet
69
219
  dll = _get_dll_with_set_femtet(Femtet)
220
+
221
+ # reset existing sweep table
222
+ dll.ClearPrmSweepTable.restype = ctypes.c_bool
223
+ succeed = dll.ClearPrmSweepTable()
224
+ if not succeed:
225
+ logger.error('Failed to remove existing sweep table!') # 通常ありえないので error
226
+ return False
227
+
70
228
  # solve
71
229
  dll.PrmCalcExecute.restype = ctypes.c_bool
72
230
  succeed = dll.PrmCalcExecute()
231
+ if not succeed:
232
+ logger.warning('Failed to solve!') # 通常起こりえるので warn
233
+ return False
234
+
235
+ # Check post-processing error
236
+ # 現時点では table csv に index の情報がないので、
237
+ # エラーがどの番号のものかわからない。
238
+ # ただし、エラーがそのまま出力されるよりマシなので
239
+ # 安全目に引っ掛けることにする
240
+ succeed = csv_processor.check_csv_after_succeeded_PrmCalcExecute()
241
+ if not succeed:
242
+ logger.error('Failed to save parametric result csv!')
243
+ return False # 通常ありえないので error
244
+
73
245
  return succeed # 成功した場合はTRUE、失敗した場合はFALSEを返す
246
+
247
+
248
+ if __name__ == '__main__':
249
+ from win32com.client import Dispatch
250
+
251
+ g_Femtet = Dispatch('FemtetMacro.Femtet')
252
+ g_dll = _get_dll_with_set_femtet(g_Femtet)
253
+
254
+ # solve
255
+ g_succeeded = solve_via_parametric_dll(g_Femtet)
256
+ if not g_succeeded:
257
+ g_dll.GetLastErrorMsg.restype = ctypes.c_char_p # or wintypes.LPCSTR
258
+ g_error_msg: bytes = g_dll.GetLastErrorMsg()
259
+ g_error_msg: str = g_error_msg.decode(encoding='932')
260
+
261
+ # 結果取得:内部的にはエラーになっているはず
262
+ g_parametric_result_index = 1
263
+ g_dll = _get_dll_with_set_femtet(g_Femtet)
264
+ g_dll.GetPrmResult.restype = ctypes.c_double # 複素数やベクトルの場合は実部や第一成分しか取らない PIF の仕様
265
+ g_output = g_dll.GetPrmResult(g_parametric_result_index)
266
+
267
+ # ... だが、下記のコードでそれは出てこない。
268
+ # 値が実際に 0 である場合と切り分けられないので、
269
+ # csv を見てエラーがあるかどうか判断せざるを得ない。
270
+ g_dll.GetLastErrorMsg.restype = ctypes.c_char_p # or wintypes.LPCSTR
271
+ g_error_msg: bytes = g_dll.GetLastErrorMsg()
272
+ g_error_msg: str = g_error_msg.decode(encoding='932')
@@ -17,6 +17,7 @@ class SurrogateModelInterfaceBase(FEMInterface, ABC):
17
17
  self,
18
18
  history_path: str = None,
19
19
  history: History = None,
20
+ override_objective: bool = True,
20
21
  ):
21
22
 
22
23
  self.history: History
@@ -25,6 +26,7 @@ class SurrogateModelInterfaceBase(FEMInterface, ABC):
25
26
  self.obj: dict[str, float] = dict()
26
27
  self.df_prm: pd.DataFrame
27
28
  self.df_obj: pd.DataFrame
29
+ self.override_objective: bool = override_objective
28
30
 
29
31
  # history_path が与えられた場合、history をコンストラクトする
30
32
  if history_path is not None:
@@ -57,27 +59,25 @@ class SurrogateModelInterfaceBase(FEMInterface, ABC):
57
59
  FEMInterface.__init__(
58
60
  self,
59
61
  history=history, # コンストラクト済み history を渡せば並列計算時も何もしなくてよい
62
+ override_objective=self.override_objective
60
63
  )
61
64
 
62
65
  def filter_feasible(self, x: np.ndarray, y: np.ndarray, return_feasibility=False):
63
66
  feasible_idx = np.where(~np.isnan(y.sum(axis=1)))
64
67
  if return_feasibility:
65
68
  # calculated or not
66
- y = np.zeros_like(y)
67
- y[feasible_idx] = 1.
68
- # satisfy weak feasibility or not
69
- infeasible_idx = np.where(~self.history.get_df()['feasible'].values)
70
- y[infeasible_idx] = .0
71
- return x, y.reshape((-1, 1))
69
+ feas = np.zeros((len(y), 1), dtype=float)
70
+ feas[feasible_idx] = 1.
71
+ return x, feas
72
72
  else:
73
73
  return x[feasible_idx], y[feasible_idx]
74
74
 
75
75
  def _setup_after_parallel(self, *args, **kwargs):
76
-
77
- opt: AbstractOptimizer = kwargs['opt']
78
- obj: Objective
79
- for obj_name, obj in opt.objectives.items():
80
- obj.fun = lambda: self.obj[obj_name]
76
+ if self.override_objective:
77
+ opt: AbstractOptimizer = kwargs['opt']
78
+ obj: Objective
79
+ for obj_name, obj in opt.objectives.items():
80
+ obj.fun = lambda obj_name_=obj_name: self.obj[obj_name_]
81
81
 
82
82
  def update_parameter(self, parameters: pd.DataFrame, with_warning=False) -> Optional[List[str]]:
83
83
  for i, row in parameters.iterrows():
@@ -27,8 +27,8 @@ class PoFBoTorchInterface(SurrogateModelInterfaceBase):
27
27
  def train_f(self):
28
28
  # df そのまま用いて training する
29
29
  x, y = self.filter_feasible(self.df_prm.values, self.df_obj.values, return_feasibility=True)
30
- if y.min() == 1:
31
- self.model_f.predict = lambda *args, **kwargs: (1., 0.001)
30
+ if y.min() == 1: # feasible values only
31
+ self.model_f.predict = lambda *args, **kwargs: (1., 0.001) # mean, std
32
32
  self.model_f.fit(x, y)
33
33
 
34
34
  def _setup_after_parallel(self, *args, **kwargs):
@@ -64,7 +64,8 @@ class PoFBoTorchInterface(SurrogateModelInterfaceBase):
64
64
  raise SolveError(Msg.INFO_POF_IS_LESS_THAN_THRESHOLD)
65
65
 
66
66
  # 実際の計算(mean は history.obj_names 順)
67
- mean, _ = self.model.predict(np.array([x]))
67
+ _mean, _std = self.model.predict(np.array([x]))
68
+ mean = _mean[0]
68
69
 
69
70
  # 目的関数の更新
70
71
  self.obj = {obj_name: value for obj_name, value in zip(self.history.obj_names, mean)}
@@ -42,16 +42,16 @@ class SingleTaskGPModel(PredictionModelBase):
42
42
  def set_bounds_from_history(self, history, df=None):
43
43
  from pyfemtet.opt._femopt_core import History
44
44
  history: History
45
- metadata: str
45
+ meta_column: str
46
46
 
47
47
  if df is None:
48
48
  df = history.get_df()
49
49
 
50
50
  columns = df.columns
51
- metadata_columns = history.metadata
51
+
52
52
  target_columns = [
53
- col for col, metadata in zip(columns, metadata_columns)
54
- if metadata == 'prm_lb' or metadata == 'prm_ub'
53
+ col for col, meta_column in zip(columns, history.meta_columns)
54
+ if meta_column == 'prm_lb' or meta_column == 'prm_ub'
55
55
  ]
56
56
 
57
57
  bounds_buff = df.iloc[0][target_columns].values # 2*len(prm_names) array
@@ -85,6 +85,8 @@ class SingleTaskGPModel(PredictionModelBase):
85
85
  fit_gpytorch_mll(mll)
86
86
 
87
87
  def predict(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
88
+ assert len(x.shape) >= 2
89
+
88
90
  X = tensor(x)
89
91
 
90
92
  post = self.gp.posterior(X)