pwact 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pwact/active_learning/explore/select_image.py +60 -18
- pwact/active_learning/user_input/cmd_infos.py +10 -0
- pwact/main.py +29 -2
- {pwact-0.1.20.dist-info → pwact-0.1.22.dist-info}/METADATA +1 -1
- {pwact-0.1.20.dist-info → pwact-0.1.22.dist-info}/RECORD +9 -9
- {pwact-0.1.20.dist-info → pwact-0.1.22.dist-info}/LICENSE +0 -0
- {pwact-0.1.20.dist-info → pwact-0.1.22.dist-info}/WHEEL +0 -0
- {pwact-0.1.20.dist-info → pwact-0.1.22.dist-info}/entry_points.txt +0 -0
- {pwact-0.1.20.dist-info → pwact-0.1.22.dist-info}/top_level.txt +0 -0
|
@@ -118,7 +118,51 @@ def select_image(
|
|
|
118
118
|
print("Image select result:\n {}\n\n".format(summary_info))
|
|
119
119
|
return summary
|
|
120
120
|
|
|
121
|
-
|
|
121
|
+
def print_select_image(
        md_dir:str,
        save_dir:str,
        devi_name:str,
        lower:float,
        higer:float
):
    """Classify every MD configuration under ``md_dir`` by model deviation and print the result.

    This backs ``pwact filter``: it previews what a given (lower, upper) setting
    would select, without running a full active-learning iteration.

    Args:
        md_dir: directory searched for model deviation files.
        save_dir: passed to ``count_info``; when ``None`` nothing is written to disk.
        devi_name: file name of the model deviation file inside each md sub-system dir.
        lower: lower limit. When ``read_pd_files`` reports base KPU values for a
            system, this is a factor applied to the mean base KPU instead.
        higer: upper limit. In the KPU case, this is a factor applied to the
            scaled lower limit instead.

    Returns:
        The summary string produced by ``count_info``.

    Raises:
        Exception: if no model deviation file matches under ``md_dir``.
    """
    # Large cap on the number of selected candidates so that, presumably, no
    # candidate is dropped at random. NOTE(review): confirm select_pd only trims
    # candidates beyond max_select.
    max_select = 10000000

    # 1. Collect the model deviation files, grouped per md task and per sub system.
    model_deviation_patten = "{}/{}".format(get_sub_md_sys_template_name(), devi_name)
    model_devi_files = sorted(search_files(md_dir, model_deviation_patten))
    if len(model_devi_files) == 0:
        # Fail early with a clear message instead of crashing later on an empty result.
        raise Exception("ERROR! No model deviation file matching '{}' found under {}!".format(
            model_deviation_patten, md_dir))
    md_sys_dict = sort_model_devi_files(model_devi_files)

    # 2. Classify each sub system with its own limits and collect the pieces.
    #    A single concat at the end avoids re-copying the accumulated frames on
    #    every iteration.
    error_parts, accurate_parts, rand_parts, remove_parts = [], [], [], []
    for sys_dict in md_sys_dict.values():
        for devi_files in sys_dict.values():
            devi_pd, base_kpu = read_pd_files(devi_files)
            if len(base_kpu) > 0:
                # KPU deviation: limits are relative to the mean base KPU of this system.
                sys_lower = np.mean(base_kpu) * lower
                sys_higer = sys_lower * higer
            else:
                sys_lower, sys_higer = lower, higer
            error_pd, accurate_pd, rand_candi, remove_candi = select_pd(
                devi_pd, sys_lower, sys_higer, max_select)
            error_parts.append(error_pd)
            accurate_parts.append(accurate_pd)
            rand_parts.append(rand_candi)
            remove_parts.append(remove_candi)

    # 3. Summarize (and optionally save) the merged selection.
    summary_info, summary = count_info(
        save_dir,
        pd.concat(error_parts),
        pd.concat(accurate_parts),
        pd.concat(rand_parts),
        pd.concat(remove_parts),
    )
    print("Image select result (lower {} upper {}):\n {}\n\n".format(lower, higer, summary_info))
    return summary
|
|
164
|
+
|
|
165
|
+
|
|
122
166
|
def select_pd(devi_pd:DataFrame, lower:float, higer:float, max_select:float):
|
|
123
167
|
accurate_pd = devi_pd[devi_pd[EXPLORE_FILE_STRUCTURE.devi_columns[0]] < lower]
|
|
124
168
|
candidate_pd = devi_pd[(devi_pd[EXPLORE_FILE_STRUCTURE.devi_columns[0]] >= lower) & (devi_pd[EXPLORE_FILE_STRUCTURE.devi_columns[0]] < higer)]
|
|
@@ -135,7 +179,7 @@ def select_pd(devi_pd:DataFrame, lower:float, higer:float, max_select:float):
|
|
|
135
179
|
return error_pd, accurate_pd, cand_rand_candi, cand_remove_candi
|
|
136
180
|
|
|
137
181
|
def read_pd_files(model_devi_files:list[str]):
|
|
138
|
-
devi_pd = pd.DataFrame(columns=EXPLORE_FILE_STRUCTURE.devi_columns
|
|
182
|
+
devi_pd = pd.DataFrame()#columns=EXPLORE_FILE_STRUCTURE.devi_columns
|
|
139
183
|
base_force_kpu = []
|
|
140
184
|
if os.path.basename(model_devi_files[0]) == EXPLORE_FILE_STRUCTURE.kpu_model_devi:
|
|
141
185
|
for devi_file in model_devi_files:
|
|
@@ -169,41 +213,39 @@ def read_pd_files(model_devi_files:list[str]):
|
|
|
169
213
|
|
|
170
214
|
def count_info(save_dir, error_pd, accurate_pd, rand_candi, remove_candi):
    """Build the selection summary and optionally save the per-category details.

    Args:
        save_dir: output directory for the CSV files and the summary file; when
            ``None`` nothing is written.
        error_pd: configurations at or above the upper limit.
        accurate_pd: configurations below the lower limit.
        rand_candi: candidate configurations kept by the selection.
        remove_candi: candidate configurations dropped by the selection.

    Returns:
        ``(summary_info, summary)``: the multi-line report and its one-line header.
    """
    def _rate(num, total):
        # Percentage of total; 0 when there is nothing to count (avoids ZeroDivisionError).
        return num / total * 100 if total > 0 else 0.0

    accurate_num = accurate_pd.shape[0]
    error_num = error_pd.shape[0]
    cand_num = rand_candi.shape[0] + remove_candi.shape[0]
    total_num = error_num + accurate_num + cand_num
    summary = "Total structures {} accurate {} rate {:.2f}% selected {} rate {:.2f}% error {} rate {:.2f}%\n"\
        .format(total_num, accurate_num, _rate(accurate_num, total_num),
                cand_num, _rate(cand_num, total_num),
                error_num, _rate(error_num, total_num))

    candi_info = ""
    if remove_candi.shape[0] == 0:
        candi_info += "Candidate configurations: {}\n Select details in file {}\n".format(
            cand_num, EXPLORE_FILE_STRUCTURE.candidate)
    else:
        candi_info += "Candidate configurations: {}, randomly select {}, delete {}\n Select details in file {}\n Delete details in file {}.\n".format(
            cand_num, rand_candi.shape[0], remove_candi.shape[0],
            EXPLORE_FILE_STRUCTURE.candidate, EXPLORE_FILE_STRUCTURE.candidate_delete)

    summary_info = ""
    summary_info += summary
    summary_info += "\nSelect by model deviation force:\n"
    summary_info += "Accurate configurations: {}, details in file {}\n".\
        format(accurate_num, EXPLORE_FILE_STRUCTURE.accurate)
    summary_info += candi_info
    summary_info += "Error configurations: {}, details in file {}\n".\
        format(error_num, EXPLORE_FILE_STRUCTURE.failed)

    # Writing is optional so that `pwact filter` can preview without touching disk.
    if save_dir is not None:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        accurate_pd.to_csv(os.path.join(save_dir, EXPLORE_FILE_STRUCTURE.accurate))
        rand_candi.to_csv(os.path.join(save_dir, EXPLORE_FILE_STRUCTURE.candidate))
        if remove_candi.shape[0] > 0:
            remove_candi.to_csv(os.path.join(save_dir, EXPLORE_FILE_STRUCTURE.candidate_delete))
        error_pd.to_csv(os.path.join(save_dir, EXPLORE_FILE_STRUCTURE.failed))
        write_to_file(os.path.join(save_dir, EXPLORE_FILE_STRUCTURE.select_summary), summary_info, "w")

    return summary_info, summary
|
|
250
|
+
|
|
251
|
+
|
|
@@ -39,6 +39,8 @@ def cmd_infos(cmd_type=None):
|
|
|
39
39
|
cmd_info = cmd_info_run_iter()
|
|
40
40
|
elif cmd_type == "kill":
|
|
41
41
|
cmd_info = cmd_info_kill()
|
|
42
|
+
elif cmd_type == "filter":
|
|
43
|
+
cmd_info = cmd_info_filter()
|
|
42
44
|
print(cmd_info)
|
|
43
45
|
|
|
44
46
|
|
|
@@ -64,3 +66,11 @@ def cmd_info_kill():
|
|
|
64
66
|
cmd_info += "'pwact kill init_bulk' for 'init_bulk' tasks\n"
|
|
65
67
|
cmd_info += "'pwact kill run' for 'run' tasks\n\n"
|
|
66
68
|
return cmd_info
|
|
69
|
+
|
|
70
|
+
def cmd_info_filter():
    """Return the help text shown for the ``pwact filter`` sub-command."""
    help_lines = [
        "filter",
        "you could use this method to test the selection results corresponding to the upper and lower limit settings",
        "example:",
        "'pwact filter -i iter.0000/explore/md -l 0.01 -u 0.02 -s'",
    ]
    # Every line ends with a newline; the example is followed by one extra blank line.
    return "\n".join(help_lines) + "\n\n"
|
pwact/main.py
CHANGED
|
@@ -23,7 +23,7 @@ from pwact.active_learning.init_bulk.init_bulk_run import init_bulk_run, scancel
|
|
|
23
23
|
from pwact.active_learning.environment import check_envs
|
|
24
24
|
|
|
25
25
|
from pwact.data_format.configop import extract_pwdata
|
|
26
|
-
from pwact.active_learning.explore.select_image import select_image
|
|
26
|
+
from pwact.active_learning.explore.select_image import select_image, print_select_image
|
|
27
27
|
from pwact.utils.process_tool import kill_process
|
|
28
28
|
def run_iter():
|
|
29
29
|
system_json = json.load(open(sys.argv[2]))
|
|
@@ -299,6 +299,26 @@ def kill_job():
|
|
|
299
299
|
|
|
300
300
|
# for run iters jobs
|
|
301
301
|
|
|
302
|
+
def filter_test(input_cmds):
    """Run ``pwact filter``: preview which MD configurations a (lower, upper) setting selects.

    Args:
        input_cmds: command-line arguments after the sub-command name.

    Returns:
        The summary string produced by ``print_select_image``. Previously it was
        computed and discarded; returning it lets callers reuse it.

    Raises:
        Exception: if the ``-i/--md_dir`` path does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--md_dir', help="specify the md dir such as 'iter.0000/temp_run_iter_work/explore/md' or iter.0000/explore/md", type=str, required=True)
    parser.add_argument('-l', '--lower', help="specify lower limit value", type=float, required=True)
    parser.add_argument('-u', '--upper', help="specify upper limit value", type=float, required=True)
    parser.add_argument('-s', '--save', action='store_true', help="if '-s' is set, save the detailed information of the selected configs to CSV files")

    args = parser.parse_args(input_cmds)
    if not os.path.exists(args.md_dir):
        raise Exception("ERROR! The input md_dir {} not found!".format(args.md_dir))

    # Details are written under ./filter_test_result only when -s is given.
    save_dir = os.path.join(os.getcwd(), "filter_test_result") if args.save else None
    return print_select_image(
        md_dir=args.md_dir,
        save_dir=save_dir,
        devi_name=EXPLORE_FILE_STRUCTURE.get_devi_name(UNCERTAINTY.committee),
        lower=args.lower,
        higer=args.upper
    )
|
|
321
|
+
|
|
302
322
|
def main():
|
|
303
323
|
environment_check()
|
|
304
324
|
if len(sys.argv) == 1 or "-h".upper() == sys.argv[1].upper() or \
|
|
@@ -343,7 +363,14 @@ def main():
|
|
|
343
363
|
cmd_infos("kill")
|
|
344
364
|
else:
|
|
345
365
|
kill_job()
|
|
346
|
-
|
|
366
|
+
|
|
367
|
+
elif "filter_test".upper() == sys.argv[1].upper() or "filter".upper() == sys.argv[1].upper():
|
|
368
|
+
if len(sys.argv) == 2 or "-h".upper() == sys.argv[2].upper() or \
|
|
369
|
+
"help".upper() == sys.argv[2].upper() or "-help".upper() == sys.argv[2].upper() or "--help".upper() == sys.argv[2].upper():
|
|
370
|
+
cmd_infos("filter")
|
|
371
|
+
else:
|
|
372
|
+
filter_test(sys.argv[2:])
|
|
373
|
+
|
|
347
374
|
else:
|
|
348
375
|
print("ERROR! The input cmd {} can not be recognized, please check.".format(sys.argv[1]))
|
|
349
376
|
print("\n\n\nYou can enter the following command.\n\n\n")
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
pwact/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
pwact/main.py,sha256=
|
|
2
|
+
pwact/main.py,sha256=61QOZIBvxP-ddz2uI9JD35T-LIDU09NuARbWkNOI3pA,16872
|
|
3
3
|
pwact/active_learning/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
4
|
pwact/active_learning/environment.py,sha256=KvyMaOXrM-HMMma4SnoOQFO6fZxDsk0Fsyyy7xqfGCo,684
|
|
5
5
|
pwact/active_learning/explore/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
6
|
pwact/active_learning/explore/run_model_md.py,sha256=uSW-ZaXH7NHexnMDm3Qb7ny4rRhk9KZ3dds2D9AVDeo,19891
|
|
7
|
-
pwact/active_learning/explore/select_image.py,sha256=
|
|
7
|
+
pwact/active_learning/explore/select_image.py,sha256=dmsMoxFwQ7JDPHK2vRFzUfXeEovT86cmUnmEm-qrfwE,12665
|
|
8
8
|
pwact/active_learning/init_bulk/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
9
9
|
pwact/active_learning/init_bulk/aimd.py,sha256=XzDlX2vylaljQKoUnv6nrI2NfiOdHZpq8qr3DenA1F4,10465
|
|
10
10
|
pwact/active_learning/init_bulk/duplicate_scale.py,sha256=Z9mYBASy9gNLTV_db8lqfXpcahHbrQq6emae169P2wM,9468
|
|
@@ -22,7 +22,7 @@ pwact/active_learning/train/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMp
|
|
|
22
22
|
pwact/active_learning/train/dp_kpu.py,sha256=GkGKEGhLmOvPERqgTkf_0_vD9zOEPlBX2N7vuSQG_-c,9317
|
|
23
23
|
pwact/active_learning/train/train_model.py,sha256=NXqTKrl7Lb_UGt2-Lq_gLM9iZDAZFY09nxP_aGtBdrE,10525
|
|
24
24
|
pwact/active_learning/user_input/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
|
-
pwact/active_learning/user_input/cmd_infos.py,sha256=
|
|
25
|
+
pwact/active_learning/user_input/cmd_infos.py,sha256=qyUYxsxnqZS2hxDbbeyyqsB2kps5z5nBCBa2iK0QE68,3372
|
|
26
26
|
pwact/active_learning/user_input/init_bulk_input.py,sha256=NCUAB1xpa67MXHRRWAkmWZjkQbTibE4EjmtlOmr0HdQ,6762
|
|
27
27
|
pwact/active_learning/user_input/iter_input.py,sha256=zwFCI6dDFIEVWitEqBBZP7TtrSvxMbmSUEPiM0Z8ZOk,10732
|
|
28
28
|
pwact/active_learning/user_input/resource.py,sha256=bg1rkdjIzuj7pDUvp6h1yMWFT0PqYzH4BfX5tJ7MZzc,6812
|
|
@@ -52,9 +52,9 @@ pwact/utils/app_lib/cp2k.py,sha256=ljhCCHmZ2kfoXEXn5O7-D56EgTLn2a7H3y_TIkHiasY,1
|
|
|
52
52
|
pwact/utils/app_lib/cp2k_dp.py,sha256=VP4gyPGhLcMAqAjrqCQSUiiGlESNlyYz7Gs3Q4QoUHo,6912
|
|
53
53
|
pwact/utils/app_lib/lammps.py,sha256=2oxHJHdDxfDDWWmnjo0gMNwgGvxABwuDgDlb8kbhgfk,8037
|
|
54
54
|
pwact/utils/app_lib/pwmat.py,sha256=PTRPkG_d00ibGhpCe2-4M7MW3dx2ZuAyb9hT2jl_LAs,18047
|
|
55
|
-
pwact-0.1.
|
|
56
|
-
pwact-0.1.
|
|
57
|
-
pwact-0.1.
|
|
58
|
-
pwact-0.1.
|
|
59
|
-
pwact-0.1.
|
|
60
|
-
pwact-0.1.
|
|
55
|
+
pwact-0.1.22.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
56
|
+
pwact-0.1.22.dist-info/METADATA,sha256=Gd3Y11fJq-d9Pt-06_RMwmZTQj_dSbN1c9VyeZ5YdD4,3659
|
|
57
|
+
pwact-0.1.22.dist-info/WHEEL,sha256=2wepM1nk4DS4eFpYrW1TTqPcoGNfHhhO_i5m4cOimbo,92
|
|
58
|
+
pwact-0.1.22.dist-info/entry_points.txt,sha256=p61auAnpbn8E2WjvHNBA7rb9_NRAOCew4zdcCj33cGc,42
|
|
59
|
+
pwact-0.1.22.dist-info/top_level.txt,sha256=fY1_7sH5Lke4dC9L8MbYM4fT5aat5eCkAmpkIzY1SlM,6
|
|
60
|
+
pwact-0.1.22.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|