pwact 0.2.0__py3-none-any.whl → 0.2.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pwact/active_learning/explore/run_model_md.py +110 -0
- pwact/active_learning/explore/select_image.py +1 -1
- pwact/active_learning/init_bulk/direct.py +182 -0
- pwact/active_learning/init_bulk/duplicate_scale.py +1 -1
- pwact/active_learning/init_bulk/explore.py +301 -0
- pwact/active_learning/init_bulk/init_bulk_run.py +78 -48
- pwact/active_learning/init_bulk/relabel.py +149 -120
- pwact/active_learning/label/labeling.py +125 -11
- pwact/active_learning/user_input/init_bulk_input.py +55 -6
- pwact/active_learning/user_input/iter_input.py +12 -0
- pwact/active_learning/user_input/resource.py +18 -6
- pwact/active_learning/user_input/scf_param.py +24 -6
- pwact/active_learning/user_input/train_param/optimizer_param.py +1 -1
- pwact/main.py +17 -7
- pwact/utils/app_lib/do_direct_sample.py +145 -0
- pwact/utils/app_lib/do_eqv2model.py +41 -0
- pwact/utils/constant.py +31 -11
- pwact/utils/file_operation.py +12 -5
- {pwact-0.2.0.dist-info → pwact-0.2.2.dev0.dist-info}/METADATA +1 -1
- {pwact-0.2.0.dist-info → pwact-0.2.2.dev0.dist-info}/RECORD +24 -20
- {pwact-0.2.0.dist-info → pwact-0.2.2.dev0.dist-info}/LICENSE +0 -0
- {pwact-0.2.0.dist-info → pwact-0.2.2.dev0.dist-info}/WHEEL +0 -0
- {pwact-0.2.0.dist-info → pwact-0.2.2.dev0.dist-info}/entry_points.txt +0 -0
- {pwact-0.2.0.dist-info → pwact-0.2.2.dev0.dist-info}/top_level.txt +0 -0
|
@@ -35,6 +35,7 @@ from pwact.utils.file_operation import write_to_file, copy_file, copy_dir, searc
|
|
|
35
35
|
from pwact.utils.app_lib.common import link_pseudo_by_atom, set_input_script
|
|
36
36
|
|
|
37
37
|
from pwact.data_format.configop import extract_pwdata, save_config, get_atom_type
|
|
38
|
+
from pwdata import Config
|
|
38
39
|
class Labeling(object):
|
|
39
40
|
@staticmethod
|
|
40
41
|
def kill_job(root_dir:str, itername:str):
|
|
@@ -59,9 +60,10 @@ class Labeling(object):
|
|
|
59
60
|
self.real_explore_dir = os.path.join(self.input_param.root_dir, itername, AL_STRUCTURE.explore)
|
|
60
61
|
self.md_dir = os.path.join(self.explore_dir, EXPLORE_FILE_STRUCTURE.md)
|
|
61
62
|
self.select_dir = os.path.join(self.explore_dir, EXPLORE_FILE_STRUCTURE.select)
|
|
63
|
+
self.direct_dir = os.path.join(self.explore_dir, EXPLORE_FILE_STRUCTURE.direct)
|
|
62
64
|
self.real_md_dir = os.path.join(self.real_explore_dir, EXPLORE_FILE_STRUCTURE.md)
|
|
63
65
|
self.real_select_dir = os.path.join(self.real_explore_dir, EXPLORE_FILE_STRUCTURE.select)
|
|
64
|
-
|
|
66
|
+
self.real_direct_dir = os.path.join(self.real_explore_dir, EXPLORE_FILE_STRUCTURE.direct)
|
|
65
67
|
# labed work dir
|
|
66
68
|
self.label_dir = os.path.join(self.input_param.root_dir, itername, TEMP_STRUCTURE.tmp_run_iter_dir, AL_STRUCTURE.labeling)
|
|
67
69
|
self.scf_dir = os.path.join(self.label_dir, LABEL_FILE_STRUCTURE.scf)
|
|
@@ -71,6 +73,9 @@ class Labeling(object):
|
|
|
71
73
|
self.real_scf_dir = os.path.join(self.real_label_dir, LABEL_FILE_STRUCTURE.scf)
|
|
72
74
|
self.real_result_dir = os.path.join(self.real_label_dir, LABEL_FILE_STRUCTURE.result)
|
|
73
75
|
|
|
76
|
+
self.bigmodel_dir = os.path.join(self.label_dir, LABEL_FILE_STRUCTURE.bigmodel)
|
|
77
|
+
self.real_bigmodel_dir = os.path.join(self.real_label_dir, LABEL_FILE_STRUCTURE.bigmodel)
|
|
78
|
+
|
|
74
79
|
'''
|
|
75
80
|
description:
|
|
76
81
|
the scf work dir file structure is as follow.
|
|
@@ -86,9 +91,8 @@ class Labeling(object):
|
|
|
86
91
|
return {*}
|
|
87
92
|
author: wuxingxing
|
|
88
93
|
'''
|
|
94
|
+
|
|
89
95
|
def make_scf_work(self):
|
|
90
|
-
# read select info, and make scf
|
|
91
|
-
# ["devi_force", "file_path", "config_index"]
|
|
92
96
|
candidate = pd.read_csv(os.path.join(self.select_dir, EXPLORE_FILE_STRUCTURE.candidate))
|
|
93
97
|
# make scf work dir
|
|
94
98
|
scf_dir_list = []
|
|
@@ -108,14 +112,51 @@ class Labeling(object):
|
|
|
108
112
|
atom_names = line.split()
|
|
109
113
|
self.make_scf_file(scf_sub_md_sys_path, tarj_lmp, atom_names)
|
|
110
114
|
scf_dir_list.append(scf_sub_md_sys_path)
|
|
111
|
-
|
|
115
|
+
|
|
112
116
|
self.make_scf_slurm_job_files(scf_dir_list)
|
|
113
117
|
|
|
118
|
+
def make_bigmodel_work(self):
    """
    Prepare ``self.bigmodel_dir`` for big-model labeling of this iteration.

    The structures to label are written to ``EXPLORE_FILE_STRUCTURE.select_xyz`` inside
    ``self.bigmodel_dir``:
    - with the DIRECT strategy enabled, the file is copied from ``self.real_direct_dir``;
    - otherwise every row of the candidate CSV in ``self.select_dir`` names an MD system dir
      (``file_path``) and a frame (``config_index``); each frame's LAMMPS dump is read and all
      frames are written together as one extxyz file.
    Afterwards the configured big-model script is copied next to it and the slurm job
    script is generated.

    Raises:
        Exception: when the candidate CSV lists no structures, so there is nothing to label.
    """
    select_xyz = EXPLORE_FILE_STRUCTURE.select_xyz
    if self.input_param.strategy.direct:
        # DIRECT sampling already produced the extxyz selection; reuse it as-is
        copy_file(os.path.join(self.real_direct_dir, select_xyz),
                  os.path.join(self.bigmodel_dir, select_xyz))
    else:
        candidate_csv = os.path.join(self.select_dir, EXPLORE_FILE_STRUCTURE.candidate)
        candidate = pd.read_csv(candidate_csv)
        # many candidate frames share one MD system dir; read its atom-type file only once
        atom_names_cache = {}
        image_list = None
        for _, row in candidate.iterrows():
            config_index = int(row["config_index"])
            sub_md_sys_path = row["file_path"]
            if sub_md_sys_path not in atom_names_cache:
                with open(os.path.join(sub_md_sys_path, LAMMPS.atom_type_file), 'r') as rf:
                    atom_names_cache[sub_md_sys_path] = rf.readline().split()
            atom_names = atom_names_cache[sub_md_sys_path]
            traj_file = os.path.join(sub_md_sys_path, EXPLORE_FILE_STRUCTURE.traj,
                                     "{}{}".format(config_index, LAMMPS.traj_postfix))
            image = Config(data_path=traj_file, format=PWDATA.lammps_dump, atom_names=atom_names)
            if image_list is None:
                image_list = image
            else:
                image_list.append(image)
        if image_list is None:
            # without this check the conversion below fails with an opaque AttributeError on None
            raise Exception("ERROR! No candidate structures found in {}, nothing to label with the big model!".format(candidate_csv))
        # convert the collected lammps dump frames into a single extxyz file
        image_list.to(data_path=self.bigmodel_dir, format=PWDATA.extxyz, data_name="{}".format(select_xyz))
    # copy the user's big-model script into the work dir
    bigmodel_script = self.input_param.scf.bigmodel_script
    copy_file(bigmodel_script, os.path.join(self.bigmodel_dir, os.path.basename(bigmodel_script)))
    # write the slurm job script that runs it
    self.make_bigmodel_slurm_job_files([self.bigmodel_dir])
|
|
147
|
+
|
|
114
148
|
def back_label(self):
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
149
|
+
if self.input_param.scf.dft_style == DFT_STYLE.bigmodel:
|
|
150
|
+
slurm_remain, slurm_success = get_slurm_job_run_info(self.real_bigmodel_dir, \
|
|
151
|
+
job_patten="*-{}".format(LABEL_FILE_STRUCTURE.bigmodel_job), \
|
|
152
|
+
tag_patten="*-{}".format(LABEL_FILE_STRUCTURE.bigmodel_tag))
|
|
153
|
+
slurm_done = True if len(slurm_remain) == 0 and len(slurm_success) > 0 else False
|
|
154
|
+
else:
|
|
155
|
+
slurm_remain, slurm_success = get_slurm_job_run_info(self.real_scf_dir, \
|
|
156
|
+
job_patten="*-{}".format(LABEL_FILE_STRUCTURE.scf_job), \
|
|
157
|
+
tag_patten="*-{}".format(LABEL_FILE_STRUCTURE.scf_tag))
|
|
158
|
+
slurm_done = True if len(slurm_remain) == 0 and len(slurm_success) > 0 else False
|
|
159
|
+
|
|
119
160
|
if slurm_done:
|
|
120
161
|
# bk and do new job
|
|
121
162
|
target_bk_file = add_postfix_dir(self.real_label_dir, postfix_str="bk")
|
|
@@ -147,7 +188,31 @@ class Labeling(object):
|
|
|
147
188
|
mission.commit_jobs()
|
|
148
189
|
mission.check_running_job()
|
|
149
190
|
mission.all_job_finished(error_type=SLURM_OUT.dft_out)
|
|
150
|
-
|
|
191
|
+
|
|
192
|
+
def do_bigmodel_jobs(self):
    """
    Submit the big-model slurm jobs in ``self.bigmodel_dir`` that are not done yet, then wait.

    Job state comes from get_slurm_job_run_info using the ``*-<bigmodel_job>`` script pattern
    and the ``*-<bigmodel_tag>`` tag pattern. If nothing remains and at least one job
    succeeded, nothing is submitted.
    """
    job_pattern = "*-{}".format(LABEL_FILE_STRUCTURE.bigmodel_job)
    tag_pattern = "*-{}".format(LABEL_FILE_STRUCTURE.bigmodel_tag)
    mission = Mission()
    remaining, succeeded = get_slurm_job_run_info(self.bigmodel_dir,
                                                  job_patten=job_pattern,
                                                  tag_patten=tag_pattern)
    # all work already finished: nothing to (re)submit
    if len(remaining) == 0 and len(succeeded) > 0:
        return

    if len(remaining) > 0:
        print("Run bigModel Job:\n")
        print(remaining)
        for script_path in remaining:
            job = SlurmJob()
            # script "<group>-<bigmodel_job>" gets the tag "<group>-<bigmodel_tag>" in the same dir
            group_id = os.path.basename(script_path).split('-')[0].strip()
            job.set_tag(os.path.join(os.path.dirname(script_path),
                                     "{}-{}".format(group_id, LABEL_FILE_STRUCTURE.bigmodel_tag)))
            job.set_cmd(script_path)
            mission.add_job(job)

    if len(mission.job_list) > 0:
        mission.commit_jobs()
        mission.check_running_job()
        mission.all_job_finished()
|
|
215
|
+
|
|
151
216
|
def make_scf_file(self, scf_dir:str, tarj_lmp:str, atom_names:list[str]=None):
|
|
152
217
|
config_index = os.path.basename(tarj_lmp).split('.')[0]
|
|
153
218
|
if DFT_STYLE.vasp == self.resource.dft_style: # when do scf, the vasp input file name is 'POSCAR'
|
|
@@ -230,6 +295,42 @@ class Labeling(object):
|
|
|
230
295
|
slurm_job_file = os.path.join(self.scf_dir, slurm_script_name)
|
|
231
296
|
write_to_file(slurm_job_file, group_slurm_script, "w")
|
|
232
297
|
|
|
298
|
+
|
|
299
|
+
def make_bigmodel_slurm_job_files(self, scf_sub_list:list[str]):
    """
    Write the slurm job scripts for big-model labeling into ``self.bigmodel_dir``.

    Job scripts left by an earlier call are deleted first. The work dirs in ``scf_sub_list``
    are grouped with ``split_job_for_group(1, scf_sub_list, 1)``, and each group gets one
    script named ``<group_index>-<bigmodel_job>`` with tag ``<group_index>-<bigmodel_tag>``.
    The script runs ``self.resource.dft_resource.command`` using the dft resource settings.

    Args:
        scf_sub_list: work directories in which the big-model script is run.
    """
    # Clean up stale big-model job scripts. These files are named "<group>-<bigmodel_job>";
    # the previous pattern used scf_job, which never matches the files written below.
    del_file_list_by_patten(self.bigmodel_dir, "*{}".format(LABEL_FILE_STRUCTURE.bigmodel_job))
    group_list = split_job_for_group(1, scf_sub_list, 1)
    dft_resource = self.resource.dft_resource

    for group_index, group in enumerate(group_list):
        # groups whose first entry is "NONE" are skipped (presumably padding from split_job_for_group)
        if group[0] == "NONE":
            continue

        jobname = "bigmodel{}".format(group_index)
        tag_name = "{}-{}".format(group_index, LABEL_FILE_STRUCTURE.bigmodel_tag)
        tag = os.path.join(self.bigmodel_dir, tag_name)
        group_slurm_script = set_slurm_script_content(gpu_per_node=dft_resource.gpu_per_node,
            number_node = dft_resource.number_node,
            cpu_per_node = dft_resource.cpu_per_node,
            queue_name = dft_resource.queue_name,
            custom_flags = dft_resource.custom_flags,
            env_script = dft_resource.env_script,
            job_name = jobname,
            run_cmd_template = dft_resource.command,
            group = group,
            job_tag = tag,
            task_tag = LABEL_FILE_STRUCTURE.bigmodel_tag,
            task_tag_faild = LABEL_FILE_STRUCTURE.bigmodel_tag_failed,
            parallel_num=dft_resource.parallel_num,
            check_type=self.resource.dft_style
            )
        slurm_script_name = "{}-{}".format(group_index, LABEL_FILE_STRUCTURE.bigmodel_job)
        slurm_job_file = os.path.join(self.bigmodel_dir, slurm_script_name)
        write_to_file(slurm_job_file, group_slurm_script, "w")
|
|
333
|
+
|
|
233
334
|
'''
|
|
234
335
|
description:
|
|
235
336
|
collecte OUT.MLMD to mvm-
|
|
@@ -274,12 +375,12 @@ class Labeling(object):
|
|
|
274
375
|
for scf_file in scf_files:
|
|
275
376
|
scf_file_path = os.path.join(scf_dir, scf_file)
|
|
276
377
|
if scf_file.lower() in DFT_STYLE.get_scf_reserve_list(self.resource.dft_style) \
|
|
277
|
-
|
|
378
|
+
or "atom.config" in scf_file.lower() :# for the input natom.config
|
|
278
379
|
copy_file(scf_file_path, scf_file_path.replace(TEMP_STRUCTURE.tmp_run_iter_dir, ""))
|
|
279
380
|
|
|
280
381
|
# scf files to pwdata format
|
|
281
382
|
scf_configs = self.collect_scf_configs()
|
|
282
|
-
|
|
383
|
+
|
|
283
384
|
extract_pwdata(input_data_list=scf_configs,
|
|
284
385
|
intput_data_format =DFT_STYLE.get_format_by_postfix(os.path.basename(scf_configs[0])),
|
|
285
386
|
save_data_path =self.result_dir,
|
|
@@ -289,3 +390,16 @@ class Labeling(object):
|
|
|
289
390
|
)
|
|
290
391
|
# copy to main dir
|
|
291
392
|
copy_dir(self.result_dir, self.real_result_dir)
|
|
393
|
+
|
|
394
|
+
def do_post_bigmodel(self):
    """
    Collect the big-model labels and mirror the work into the real iteration dirs.

    ``LABEL_FILE_STRUCTURE.train_xyz`` in ``self.bigmodel_dir`` is placed in ``self.result_dir``.
    It is copied unchanged when ``input_param.data_format`` is extxyz and converted to
    PWDATA.pwmlff_npy otherwise. The big-model and result dirs are then copied to their real
    counterparts, and ``slurm-*`` files are removed from the real big-model dir.
    """
    labeled_xyz = os.path.join(self.bigmodel_dir, LABEL_FILE_STRUCTURE.train_xyz)
    if self.input_param.data_format != PWDATA.extxyz:
        # convert the labeled extxyz into the npy training format
        Config(data_path=labeled_xyz, format=PWDATA.extxyz).to(data_path=self.result_dir, format=PWDATA.pwmlff_npy)
    else:
        copy_file(labeled_xyz, os.path.join(self.result_dir, LABEL_FILE_STRUCTURE.train_xyz))

    # mirror the temporary work dirs into the real iteration dir
    for src_dir, dst_dir in ((self.bigmodel_dir, self.real_bigmodel_dir),
                             (self.result_dir, self.real_result_dir)):
        copy_dir(src_dir, dst_dir)

    # remove slurm log files from the real big-model dir
    del_file_list_by_patten(self.real_bigmodel_dir, "slurm-*")
|
|
@@ -26,13 +26,15 @@ class InitBulkParam(object):
|
|
|
26
26
|
sys_configs = [sys_configs]
|
|
27
27
|
|
|
28
28
|
# set sys_config detail
|
|
29
|
-
self.dft_style =
|
|
29
|
+
self.dft_style = get_parameter("dft_style", json_dict, "PWMAT").lower()
|
|
30
30
|
self.scf_style = get_parameter("scf_style", json_dict, None)
|
|
31
31
|
|
|
32
32
|
self.sys_config:list[Stage] = []
|
|
33
33
|
self.is_relax = False
|
|
34
34
|
self.is_aimd = False
|
|
35
35
|
self.is_scf = False
|
|
36
|
+
self.is_bigmodel=False
|
|
37
|
+
self.is_direct = False
|
|
36
38
|
for index, config in enumerate(sys_configs):
|
|
37
39
|
stage = Stage(config, index, sys_config_prefix, self.dft_style)
|
|
38
40
|
self.sys_config.append(stage)
|
|
@@ -42,22 +44,46 @@ class InitBulkParam(object):
|
|
|
42
44
|
self.is_aimd = True
|
|
43
45
|
if stage.scf:
|
|
44
46
|
self.is_scf = True
|
|
47
|
+
if stage.bigmodel:
|
|
48
|
+
self.is_bigmodel = True
|
|
49
|
+
if stage.direct:
|
|
50
|
+
self.is_direct = True
|
|
45
51
|
|
|
46
52
|
# for PWmat: set etot.input files and persudo files
|
|
47
53
|
# for Vasp: set INCAR files and persudo files
|
|
48
|
-
self.dft_input = SCFParam(json_dict=json_dict,
|
|
54
|
+
self.dft_input = SCFParam(json_dict=json_dict,
|
|
55
|
+
is_scf=self.is_scf,
|
|
56
|
+
is_relax=self.is_relax,
|
|
57
|
+
is_aimd=self.is_aimd,
|
|
58
|
+
root_dir=self.root_dir,
|
|
59
|
+
dft_style=self.dft_style,
|
|
60
|
+
scf_style=self.scf_style,
|
|
61
|
+
is_bigmodel=self.is_bigmodel,
|
|
62
|
+
is_direct=self.is_direct)
|
|
63
|
+
|
|
49
64
|
# check and set relax etot.input file
|
|
50
65
|
for config in self.sys_config:
|
|
51
66
|
if self.is_relax:
|
|
52
67
|
if config.relax_input_idx >= len(self.dft_input.relax_input_list):
|
|
53
68
|
raise Exception("Error! for config '{}' 'relax_input_idx' {} not in 'relax_input'!".format(os.path.basename(config.config_file), config.relax_input_idx))
|
|
54
69
|
config.set_relax_input_file(self.dft_input.relax_input_list[config.relax_input_idx])
|
|
70
|
+
|
|
55
71
|
if self.is_scf:
|
|
56
72
|
if not os.path.exists(self.dft_input.scf_input_list[0].input_file):
|
|
57
73
|
raise Exception("Error! relabel dft input file {} not exisit!".format(self.dft_input.scf_input_list[0].input_file))
|
|
58
74
|
config.set_scf_input_file(self.dft_input.scf_input_list[0])
|
|
59
|
-
|
|
60
|
-
|
|
75
|
+
|
|
76
|
+
if self.is_bigmodel:
|
|
77
|
+
if config.bigmodel_input_idx >= len(self.dft_input.bigmodel_input_list):
|
|
78
|
+
raise Exception("Error! for script '{}' 'bigmodel_input_idx' {} not in 'bigmodel_input'!".format(os.path.basename(config.config_file), config.bigmodel_input_idx))
|
|
79
|
+
config.set_bigmodel_input_file(self.dft_input.bigmodel_input_list[config.bigmodel_input_idx])
|
|
80
|
+
|
|
81
|
+
if self.is_direct:
|
|
82
|
+
if config.direct_input_idx >= len(self.dft_input.direct_input_list):
|
|
83
|
+
raise Exception("Error! for script '{}' 'direct_input_idx' {} not in 'direct_input'!".format(os.path.basename(config.config_file), config.direct_input_idx))
|
|
84
|
+
config.set_direct_input_file(self.dft_input.direct_input_list[config.direct_input_idx])
|
|
85
|
+
|
|
86
|
+
# check and set aimd etot.input file
|
|
61
87
|
if self.is_aimd:
|
|
62
88
|
if config.aimd_input_idx >= len(self.dft_input.aimd_input_list):
|
|
63
89
|
raise Exception("Error! for config '{}' 'aimd_input_idx' {} not in 'aimd_input'!".format(os.path.basename(config.config_file), config.aimd_input_idx))
|
|
@@ -77,16 +103,29 @@ class Stage(object):
|
|
|
77
103
|
self.format = get_parameter("format", json_dict, PWDATA.pwmat_config).lower()
|
|
78
104
|
self.pbc = get_parameter("pbc", json_dict, [1,1,1])
|
|
79
105
|
# extract config file to Config object, then use it
|
|
80
|
-
self.relax = get_parameter("relax", json_dict,
|
|
106
|
+
self.relax = get_parameter("relax", json_dict, False)
|
|
81
107
|
self.relax_input_idx = get_parameter("relax_input_idx", json_dict, 0)
|
|
82
108
|
self.relax_input_file = None
|
|
83
109
|
|
|
84
|
-
self.aimd = get_parameter("aimd", json_dict,
|
|
110
|
+
self.aimd = get_parameter("aimd", json_dict, False)
|
|
85
111
|
self.aimd_input_idx = get_parameter("aimd_input_idx", json_dict, 0)
|
|
86
112
|
self.aimd_input_file = None
|
|
87
113
|
|
|
88
114
|
self.scf = get_parameter("scf", json_dict, False)
|
|
115
|
+
self.scf_input_idx = get_parameter("scf_input_idx", json_dict, 0)
|
|
116
|
+
self.scf_input_file = None
|
|
89
117
|
|
|
118
|
+
self.bigmodel = get_parameter("bigmodel", json_dict, False)
|
|
119
|
+
self.bigmodel_input_idx = get_parameter("bigmodel_input_idx", json_dict, 0)
|
|
120
|
+
self.bigmodel_script = None
|
|
121
|
+
|
|
122
|
+
self.direct = get_parameter("direct", json_dict, False)
|
|
123
|
+
self.direct_input_idx = get_parameter("direct_input_idx", json_dict, 0)
|
|
124
|
+
self.direct_script = None
|
|
125
|
+
|
|
126
|
+
if self.bigmodel and self.aimd:
|
|
127
|
+
raise Exception("ERROR! The 'aimd' and 'bigmodel' cannot be set simultaneously!")
|
|
128
|
+
|
|
90
129
|
super_cell = get_parameter("super_cell", json_dict, [])
|
|
91
130
|
super_cell = str_list_format(super_cell)
|
|
92
131
|
if len(super_cell) > 0:
|
|
@@ -131,3 +170,13 @@ class Stage(object):
|
|
|
131
170
|
self.aimd_flag_symm = input_file.flag_symm
|
|
132
171
|
self.use_dftb = input_file.use_dftb
|
|
133
172
|
self.use_skf = input_file.use_skf
|
|
173
|
+
|
|
174
|
+
def set_bigmodel_input_file(self, input_file:DFTInput):
    """Store the selected big-model input's ``input_file``, ``kspacing`` and ``flag_symm`` on this stage."""
    self.bigmodel_input_file, self.bigmodel_kspacing, self.bigmodel_flag_symm = (
        input_file.input_file,
        input_file.kspacing,
        input_file.flag_symm,
    )
|
|
178
|
+
|
|
179
|
+
def set_direct_input_file(self, input_file:DFTInput):
    """Store the selected DIRECT input's ``input_file``, ``kspacing`` and ``flag_symm`` on this stage."""
    self.direct_input_file, self.direct_kspacing, self.direct_flag_symm = (
        input_file.input_file,
        input_file.kspacing,
        input_file.flag_symm,
    )
|
|
@@ -105,6 +105,18 @@ class StrategyParam(object):
|
|
|
105
105
|
if self.compress:
|
|
106
106
|
error_log = "Error! the kpu uncertainty does not support compress, please set the 'compress' in strategy dict to be false!"
|
|
107
107
|
raise Exception(error_log)
|
|
108
|
+
|
|
109
|
+
self.direct = get_parameter("direct", json_dict, False)
|
|
110
|
+
if self.direct:
|
|
111
|
+
self.direct_script = get_parameter("direct_script", json_dict, None)
|
|
112
|
+
if self.direct_script is not None:
|
|
113
|
+
self.direct_script = os.path.abspath(self.direct_script)
|
|
114
|
+
if not os.path.exists(self.direct_script):
|
|
115
|
+
raise Exception("ERROR! The direct script {} does not exist!".format(self.direct_script))
|
|
116
|
+
else:
|
|
117
|
+
raise Exception("ERROR! The direct script does not exist!")
|
|
118
|
+
else:
|
|
119
|
+
self.direct_script = None
|
|
108
120
|
|
|
109
121
|
def to_dict(self):
|
|
110
122
|
res = {}
|
|
@@ -20,10 +20,22 @@ class Resource(object):
|
|
|
20
20
|
if "-in" in self.explore_resource.command:
|
|
21
21
|
self.explore_resource.command = self.explore_resource.command.split('-in')[0].strip()
|
|
22
22
|
self.explore_resource.command = "{} -in {} > {}".format(self.explore_resource.command, LAMMPS.input_lammps, SLURM_OUT.md_out)
|
|
23
|
-
|
|
23
|
+
else:
|
|
24
|
+
if "explore" in json_dict.keys():
|
|
25
|
+
self.explore_resource = self.get_resource(get_required_parameter("explore", json_dict))
|
|
26
|
+
else:
|
|
27
|
+
self.explore_resource = None
|
|
24
28
|
# check dft resource
|
|
25
|
-
|
|
26
|
-
|
|
29
|
+
if "dft" in json_dict.keys():
|
|
30
|
+
self.dft_resource = self.get_resource(get_required_parameter("dft", json_dict))
|
|
31
|
+
else:
|
|
32
|
+
self.dft_resource = ResourceDetail("mpirun -np 1 PWmat", 1, 1, 1, 1, 1, None, None, None)
|
|
33
|
+
|
|
34
|
+
if "direct" in json_dict.keys():
|
|
35
|
+
self.direct_resource = self.get_resource(get_required_parameter("direct", json_dict))
|
|
36
|
+
else:
|
|
37
|
+
self.direct_resource = None
|
|
38
|
+
|
|
27
39
|
if "scf" in json_dict.keys():
|
|
28
40
|
self.scf_resource = self.get_resource(get_parameter("scf", json_dict, None))
|
|
29
41
|
else:
|
|
@@ -33,11 +45,11 @@ class Resource(object):
|
|
|
33
45
|
# self.dft_resource.dftb_command = "{} > {}".format(dftb_command, SLURM_OUT.dft_out)
|
|
34
46
|
self.dft_style = dft_style
|
|
35
47
|
self.scf_style = scf_style
|
|
36
|
-
if DFT_STYLE.vasp.lower() == dft_style
|
|
48
|
+
if DFT_STYLE.vasp.lower() == dft_style:
|
|
37
49
|
self.dft_resource.command = "{} > {}".format(self.dft_resource.command, SLURM_OUT.dft_out)
|
|
38
|
-
elif DFT_STYLE.pwmat.lower() == dft_style
|
|
50
|
+
elif DFT_STYLE.pwmat.lower() == dft_style:
|
|
39
51
|
self.dft_resource.command = "{} > {}".format(self.dft_resource.command, SLURM_OUT.dft_out)
|
|
40
|
-
elif DFT_STYLE.cp2k.lower() == dft_style
|
|
52
|
+
elif DFT_STYLE.cp2k.lower() == dft_style:
|
|
41
53
|
self.dft_resource.command = "{} {} > {}".format(self.dft_resource.command, CP2K.cp2k_inp, SLURM_OUT.dft_out)
|
|
42
54
|
|
|
43
55
|
if self.scf_resource is not None and scf_style is not None:
|
|
@@ -10,7 +10,9 @@ class SCFParam(object):
|
|
|
10
10
|
is_scf:bool=False,
|
|
11
11
|
root_dir:str=None,
|
|
12
12
|
dft_style:str=None,
|
|
13
|
-
scf_style:str=None
|
|
13
|
+
scf_style:str=None,
|
|
14
|
+
is_bigmodel:bool=False,
|
|
15
|
+
is_direct:bool=False) -> None:# for scf relabel in init_bulk
|
|
14
16
|
|
|
15
17
|
self.dft_style = dft_style
|
|
16
18
|
self.root_dir = root_dir
|
|
@@ -24,12 +26,18 @@ class SCFParam(object):
|
|
|
24
26
|
|
|
25
27
|
if is_scf:
|
|
26
28
|
if "scf_input" in json_dict.keys(): # for init_bulk relabel
|
|
27
|
-
|
|
28
|
-
|
|
29
|
+
if dft_style == DFT_STYLE.bigmodel:
|
|
30
|
+
self.bigmodel_script = get_required_parameter("bigmodel_script", json_dict)
|
|
31
|
+
else:
|
|
32
|
+
json_scf = get_required_parameter("scf_input", json_dict)
|
|
33
|
+
self.scf_input_list = self.set_input(json_scf, flag_symm=0)
|
|
29
34
|
else: # for run_iter
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
35
|
+
if dft_style == DFT_STYLE.bigmodel:
|
|
36
|
+
self.bigmodel_script = get_required_parameter("bigmodel_script", json_dict)
|
|
37
|
+
else:
|
|
38
|
+
self.scf_input_list = self.set_input(json_dict, flag_symm=0)
|
|
39
|
+
if self.scf_input_list[0].use_dftb:
|
|
40
|
+
self.use_dftb = True
|
|
33
41
|
if is_aimd:
|
|
34
42
|
json_aimd = get_required_parameter("aimd_input", json_dict)
|
|
35
43
|
self.aimd_input_list = self.set_input(json_aimd, flag_symm=0)
|
|
@@ -40,6 +48,16 @@ class SCFParam(object):
|
|
|
40
48
|
self.relax_input_list = self.set_input(json_relax, flag_symm=3)
|
|
41
49
|
if self.relax_input_list[0].use_dftb:
|
|
42
50
|
self.use_dftb = True
|
|
51
|
+
|
|
52
|
+
if is_bigmodel: # init_bulk
|
|
53
|
+
json_bigmodel = get_required_parameter("bigmodel_input", json_dict)
|
|
54
|
+
self.bigmodel_input_list = self.set_input(json_bigmodel, flag_symm=3)
|
|
55
|
+
|
|
56
|
+
if is_direct: # init_bulk
|
|
57
|
+
json_direct = get_required_parameter("direct_input", json_dict)
|
|
58
|
+
self.direct_input_list = self.set_input(json_direct, flag_symm=3)
|
|
59
|
+
|
|
60
|
+
self.scf_max_num = get_parameter("scf_max_num", json_dict, None)
|
|
43
61
|
# for pwmat, use 'pseudo' key
|
|
44
62
|
# for vasp is INCAR file, use 'pseudo' key
|
|
45
63
|
pseudo = get_parameter("pseudo", json_dict, [])
|
|
@@ -6,7 +6,7 @@ class OptimizerParam(object):
|
|
|
6
6
|
|
|
7
7
|
def set_optimizer(self, json_source:dict, nep_param:NepParam=None):
|
|
8
8
|
optimizer_dict = get_parameter("optimizer", json_source, {})
|
|
9
|
-
self.opt_name = get_parameter("optimizer", optimizer_dict, "
|
|
9
|
+
self.opt_name = get_parameter("optimizer", optimizer_dict, "ADAM")
|
|
10
10
|
self.batch_size = get_parameter("batch_size", optimizer_dict, 1)
|
|
11
11
|
self.epochs = get_parameter("epochs", optimizer_dict, 30)
|
|
12
12
|
self.print_freq = get_parameter("print_freq", optimizer_dict, 10)
|
pwact/main.py
CHANGED
|
@@ -5,7 +5,7 @@ import glob
|
|
|
5
5
|
import sys
|
|
6
6
|
import json
|
|
7
7
|
import argparse
|
|
8
|
-
from pwact.utils.constant import TEMP_STRUCTURE, UNCERTAINTY, AL_WORK, AL_STRUCTURE, LABEL_FILE_STRUCTURE, EXPLORE_FILE_STRUCTURE
|
|
8
|
+
from pwact.utils.constant import TEMP_STRUCTURE, UNCERTAINTY, AL_WORK, AL_STRUCTURE, LABEL_FILE_STRUCTURE, EXPLORE_FILE_STRUCTURE, DFT_STYLE
|
|
9
9
|
from pwact.utils.format_input_output import make_iter_name
|
|
10
10
|
from pwact.utils.file_operation import write_to_file, del_file_list, search_files, del_dir, copy_dir
|
|
11
11
|
from pwact.utils.json_operation import convert_keys_to_lowercase
|
|
@@ -86,11 +86,16 @@ def run_fp(itername:str, resource : Resource, input_param: InputParam):
|
|
|
86
86
|
#1. if the label work done before, back up and do new work
|
|
87
87
|
lab.back_label()
|
|
88
88
|
#2. make scf work
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
89
|
+
if input_param.dft_style == DFT_STYLE.bigmodel:
|
|
90
|
+
lab.make_bigmodel_work()
|
|
91
|
+
lab.do_bigmodel_jobs()
|
|
92
|
+
lab.do_post_bigmodel()
|
|
93
|
+
else:
|
|
94
|
+
lab.make_scf_work()
|
|
95
|
+
#3. do scf work
|
|
96
|
+
lab.do_scf_jobs()
|
|
97
|
+
#4. collect scf configs outcar or movement, then to pwdata format
|
|
98
|
+
lab.do_post_labeling()
|
|
94
99
|
|
|
95
100
|
def do_training_work(itername:str, resource : Resource, input_param: InputParam):
|
|
96
101
|
mtrain = ModelTrian(itername, resource, input_param)
|
|
@@ -129,8 +134,13 @@ def do_exploring_work(itername:str, resource : Resource, input_param: InputParam
|
|
|
129
134
|
summary = "{} {}\n".format(itername, summary)
|
|
130
135
|
write_to_file(os.path.join(input_param.root_dir, EXPLORE_FILE_STRUCTURE.iter_select_file), summary, mode='a')
|
|
131
136
|
|
|
137
|
+
if input_param.strategy.direct:
|
|
138
|
+
md.make_drct_work()
|
|
139
|
+
md.do_drct_jobs()
|
|
140
|
+
md.post_drct()
|
|
132
141
|
print("config selection done!")
|
|
133
|
-
|
|
142
|
+
|
|
143
|
+
# 5. do post process
|
|
134
144
|
md.post_process_md()
|
|
135
145
|
print("exploring done!")
|
|
136
146
|
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
from maml.sampling.direct import DIRECTSampler, BirchClustering, SelectKFromClusters
|
|
2
|
+
import numpy as np
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
import matplotlib.ticker as mtick
|
|
5
|
+
from ase.io import read
|
|
6
|
+
import subprocess, os, sys
|
|
7
|
+
|
|
8
|
+
# Output extxyz file; each kept frame is appended to it.
write_file = "select.xyz"
# Trajectory files whose frames are the sampling candidates.
filenames = ["candidate.xyz"]
# Number of structures taken from each Birch cluster (SelectKFromClusters k).
k = 1
# Initial Birch clustering threshold (BirchClustering threshold_init).
threshold = .04
# Selected frames with any cell angle (degrees) outside [min_cell_angle, max_cell_angle] are not written.
min_cell_angle = 40
max_cell_angle = 140


def load_ase_MD_traj(filenames: list):
    """
    Read every frame of every trajectory file with ``ase.io.read``.

    Returns:
        structs: flat list of all frames, file after file.
        trajs: per-file lists of frames, in the order of ``filenames``.
        lens: number of frames read from each file.
    """
    structs = []
    trajs = []
    lens = []
    for filename in filenames:
        traj = read(filename, index=":")
        structs.extend(traj)
        trajs.append(traj)
        lens.append(len(traj))
    return structs, trajs, lens


def _plot_explained_variance(explained_variance):
    """Plot the PCA explained-variance ratio per component (in percent) to PCA_variance.png."""
    plt.plot(
        range(1, explained_variance.shape[0] + 1),
        explained_variance * 100,
        "o-",
    )
    # raw string: "\m" is not a valid escape sequence in a normal string literal
    plt.xlabel(r"i$^{\mathrm{th}}$ PC", size=20)
    plt.ylabel("Explained variance", size=20)
    ax = plt.gca()
    ax.yaxis.set_major_formatter(mtick.PercentFormatter())
    plt.tight_layout()
    plt.savefig("PCA_variance.png", dpi=360)
    plt.close()


def plot_PCAfeature_coverage(all_features, selected_indexes, method="DIRECT"):
    """Scatter PC1 vs PC2 of all structures on a new figure and overlay the selected ones."""
    plt.subplots(figsize=(5, 5))
    selected_features = all_features[selected_indexes]
    plt.plot(all_features[:, 0], all_features[:, 1], "*", alpha=0.5, label=f"All {len(all_features):,} structures")
    plt.plot(
        selected_features[:, 0],
        selected_features[:, 1],
        "*",
        alpha=0.5,
        label=f"{method} sampled {len(selected_features):,}",
    )
    plt.legend(frameon=False, fontsize=14, loc="upper left", bbox_to_anchor=(-0.02, 1.02), reverse=True)
    plt.ylabel("PC 2", size=20)
    plt.xlabel("PC 1", size=20)


def calculate_feature_coverage_score(all_features, selected_indexes, n_bins=100):
    """
    Return the fraction of occupied histogram bins of one feature that the selection also occupies.

    Both histograms share ``n_bins`` evenly spaced edges spanning the full range of ``all_features``.
    """
    all_features = np.asarray(all_features)
    selected_features = all_features[selected_indexes]
    # compute the shared bin edges once so both histograms are directly comparable
    bin_edges = np.linspace(np.min(all_features), np.max(all_features), n_bins)
    n_all = np.count_nonzero(np.histogram(all_features, bins=bin_edges)[0])
    n_select = np.count_nonzero(np.histogram(selected_features, bins=bin_edges)[0])
    return n_select / n_all


def calculate_all_FCS(all_features, selected_indexes, b_bins=100):
    """Return the coverage score of every feature column (one score per column of ``all_features``)."""
    return [
        calculate_feature_coverage_score(all_features[:, i], selected_indexes, n_bins=b_bins)
        for i in range(all_features.shape[1])
    ]


def _plot_coverage_scores(scores):
    """Bar-plot the per-PC coverage scores, with their mean in the legend, to Cov_score.png."""
    x = np.arange(len(scores))
    x_ticks = [f"PC {i + 1}" for i in range(len(x))]
    plt.figure(figsize=(15, 4))
    plt.bar(
        x,
        scores,
        width=0.3,
        # raw f-string: keeps the LaTeX backslashes without invalid-escape warnings
        label=rf"DIRECT, $\overline{{\mathrm{{Coverage\ score}}}}$ = {np.mean(scores):.3f}",
    )
    plt.xticks(x, x_ticks, size=16)
    plt.yticks(np.linspace(0, 1.0, 6), size=16)
    plt.ylabel("Coverage score", size=20)
    plt.legend(shadow=True, loc="lower right", fontsize=16)
    plt.tight_layout()
    plt.savefig("Cov_score.png", dpi=360)
    plt.close()


def get2index(num: int, list_lens: list):
    """Map a flat frame index over concatenated trajectories to (trajectory index, frame index in it)."""
    for idx, i in enumerate(list_lens):
        if num >= i:
            num -= i
        else:
            break
    return idx, num


def main():
    """Run DIRECT sampling on ``filenames``, write plots, the kept frames and the selected indices."""
    if os.path.exists(write_file):
        os.remove(write_file)

    structures, trajs, lens = load_ase_MD_traj(filenames)

    direct_sampler = DIRECTSampler(
        clustering=BirchClustering(n=None, threshold_init=threshold),
        select_k_from_clusters=SelectKFromClusters(k=k),
    )
    direct_selection = direct_sampler.fit_transform(structures)
    m = direct_selection["PCAfeatures"].shape[1]

    # undo the explained-variance weighting so every PC is plotted/scored on its own scale
    explained_variance = direct_sampler.pca.pca.explained_variance_ratio_
    direct_selection["PCAfeatures_unweighted"] = direct_selection["PCAfeatures"] / explained_variance[:m]
    _plot_explained_variance(explained_variance)

    all_features = direct_selection["PCAfeatures_unweighted"]
    selected_indexes = direct_selection["selected_indexes"]
    plot_PCAfeature_coverage(all_features, selected_indexes)
    plt.tight_layout()
    plt.savefig("PCA_direct.png", dpi=360)
    plt.close()

    _plot_coverage_scores(calculate_all_FCS(all_features, selected_indexes, b_bins=100))

    for index in selected_indexes:
        traj_idx, frame_idx = get2index(index, lens)
        atoms = trajs[traj_idx][frame_idx]
        angles = atoms.cell.cellpar()[-3:]
        if angles.max() > max_cell_angle or angles.min() < min_cell_angle:
            continue
        atoms.set_scaled_positions(atoms.get_scaled_positions())
        atoms.write(write_file, format="extxyz", append=True)
    # NOTE: lists every DIRECT-selected index, including frames skipped by the cell-angle filter
    np.savetxt("select_idx.dat", np.array(selected_indexes), fmt="%8d")


if __name__ == "__main__":
    main()
|