pwact 0.2.0__py3-none-any.whl → 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,6 +35,7 @@ from pwact.utils.file_operation import write_to_file, copy_file, copy_dir, searc
35
35
  from pwact.utils.app_lib.common import link_pseudo_by_atom, set_input_script
36
36
 
37
37
  from pwact.data_format.configop import extract_pwdata, save_config, get_atom_type
38
+ from pwdata import Config
38
39
  class Labeling(object):
39
40
  @staticmethod
40
41
  def kill_job(root_dir:str, itername:str):
@@ -59,9 +60,10 @@ class Labeling(object):
59
60
  self.real_explore_dir = os.path.join(self.input_param.root_dir, itername, AL_STRUCTURE.explore)
60
61
  self.md_dir = os.path.join(self.explore_dir, EXPLORE_FILE_STRUCTURE.md)
61
62
  self.select_dir = os.path.join(self.explore_dir, EXPLORE_FILE_STRUCTURE.select)
63
+ self.direct_dir = os.path.join(self.explore_dir, EXPLORE_FILE_STRUCTURE.direct)
62
64
  self.real_md_dir = os.path.join(self.real_explore_dir, EXPLORE_FILE_STRUCTURE.md)
63
65
  self.real_select_dir = os.path.join(self.real_explore_dir, EXPLORE_FILE_STRUCTURE.select)
64
-
66
+ self.real_direct_dir = os.path.join(self.real_explore_dir, EXPLORE_FILE_STRUCTURE.direct)
65
67
  # labed work dir
66
68
  self.label_dir = os.path.join(self.input_param.root_dir, itername, TEMP_STRUCTURE.tmp_run_iter_dir, AL_STRUCTURE.labeling)
67
69
  self.scf_dir = os.path.join(self.label_dir, LABEL_FILE_STRUCTURE.scf)
@@ -71,6 +73,9 @@ class Labeling(object):
71
73
  self.real_scf_dir = os.path.join(self.real_label_dir, LABEL_FILE_STRUCTURE.scf)
72
74
  self.real_result_dir = os.path.join(self.real_label_dir, LABEL_FILE_STRUCTURE.result)
73
75
 
76
+ self.bigmodel_dir = os.path.join(self.label_dir, LABEL_FILE_STRUCTURE.bigmodel)
77
+ self.real_bigmodel_dir = os.path.join(self.real_label_dir, LABEL_FILE_STRUCTURE.bigmodel)
78
+
74
79
  '''
75
80
  description:
76
81
  the scf work dir file structure is as follow.
@@ -86,9 +91,8 @@ class Labeling(object):
86
91
  return {*}
87
92
  author: wuxingxing
88
93
  '''
94
+
89
95
  def make_scf_work(self):
90
- # read select info, and make scf
91
- # ["devi_force", "file_path", "config_index"]
92
96
  candidate = pd.read_csv(os.path.join(self.select_dir, EXPLORE_FILE_STRUCTURE.candidate))
93
97
  # make scf work dir
94
98
  scf_dir_list = []
@@ -108,14 +112,51 @@ class Labeling(object):
108
112
  atom_names = line.split()
109
113
  self.make_scf_file(scf_sub_md_sys_path, tarj_lmp, atom_names)
110
114
  scf_dir_list.append(scf_sub_md_sys_path)
111
-
115
+
112
116
  self.make_scf_slurm_job_files(scf_dir_list)
113
117
 
118
def make_bigmodel_work(self):
    """Stage the structures to label with the big model and write its slurm job.

    When the 'direct' strategy is enabled, the select xyz file found under the
    real direct dir is copied into the bigmodel work dir. Otherwise every row
    of the candidate csv in the select dir is loaded from its lammps dump
    trajectory, the images are merged into a single Config, and that Config is
    written to the bigmodel work dir in extxyz format. In both cases the
    bigmodel script is then copied into the work dir and the slurm job file
    is generated.

    Raises:
        Exception: if the candidate csv has no rows (non-direct mode), since
            there would be nothing to convert or label.
    """
    select_xyz = EXPLORE_FILE_STRUCTURE.select_xyz
    if self.input_param.strategy.direct:
        # copy from realdir/direct/select.xyz
        copy_file(os.path.join(self.real_direct_dir, select_xyz),
                  os.path.join(self.bigmodel_dir, select_xyz))
    else:
        # collect the candidate trajs and convert them to one xyz file
        candidate_file = os.path.join(self.select_dir, EXPLORE_FILE_STRUCTURE.candidate)
        candidate = pd.read_csv(candidate_file)
        image_list = None
        for _, row in candidate.iterrows():
            config_index = int(row["config_index"])
            sub_md_sys_path = row["file_path"]
            # the first line of the atom type file lists the element names
            with open(os.path.join(sub_md_sys_path, LAMMPS.atom_type_file), 'r') as rf:
                atom_names = rf.readline().split()
            image = Config(data_path=os.path.join(sub_md_sys_path, EXPLORE_FILE_STRUCTURE.traj, "{}{}".format(config_index, LAMMPS.traj_postfix)),
                           format=PWDATA.lammps_dump, atom_names=atom_names)
            if image_list is None:
                image_list = image
            else:
                image_list.append(image)
        if image_list is None:
            # fail with a clear message instead of an AttributeError on None
            raise Exception("ERROR! No candidate configs found in {}, nothing to label with the bigmodel!".format(candidate_file))
        # cvt lammps dump to extxyz
        image_list.to(data_path=self.bigmodel_dir, format=PWDATA.extxyz, data_name="{}".format(select_xyz))
    # copy bigmodel script
    copy_file(self.input_param.scf.bigmodel_script, os.path.join(self.bigmodel_dir, os.path.basename(self.input_param.scf.bigmodel_script)))
    # make slurm file
    self.make_bigmodel_slurm_job_files([self.bigmodel_dir])
147
+
114
148
  def back_label(self):
115
- slurm_remain, slurm_success = get_slurm_job_run_info(self.real_scf_dir, \
116
- job_patten="*-{}".format(LABEL_FILE_STRUCTURE.scf_job), \
117
- tag_patten="*-{}".format(LABEL_FILE_STRUCTURE.scf_tag))
118
- slurm_done = True if len(slurm_remain) == 0 and len(slurm_success) > 0 else False
149
+ if self.input_param.scf.dft_style == DFT_STYLE.bigmodel:
150
+ slurm_remain, slurm_success = get_slurm_job_run_info(self.real_bigmodel_dir, \
151
+ job_patten="*-{}".format(LABEL_FILE_STRUCTURE.bigmodel_job), \
152
+ tag_patten="*-{}".format(LABEL_FILE_STRUCTURE.bigmodel_tag))
153
+ slurm_done = True if len(slurm_remain) == 0 and len(slurm_success) > 0 else False
154
+ else:
155
+ slurm_remain, slurm_success = get_slurm_job_run_info(self.real_scf_dir, \
156
+ job_patten="*-{}".format(LABEL_FILE_STRUCTURE.scf_job), \
157
+ tag_patten="*-{}".format(LABEL_FILE_STRUCTURE.scf_tag))
158
+ slurm_done = True if len(slurm_remain) == 0 and len(slurm_success) > 0 else False
159
+
119
160
  if slurm_done:
120
161
  # bk and do new job
121
162
  target_bk_file = add_postfix_dir(self.real_label_dir, postfix_str="bk")
@@ -147,7 +188,31 @@ class Labeling(object):
147
188
  mission.commit_jobs()
148
189
  mission.check_running_job()
149
190
  mission.all_job_finished(error_type=SLURM_OUT.dft_out)
150
-
191
+
192
def do_bigmodel_jobs(self):
    """Submit the remaining bigmodel slurm scripts and wait for them.

    Queries the bigmodel work dir for job scripts and tags. Nothing is
    submitted when no script remains and at least one has succeeded.
    Otherwise each remaining script becomes a SlurmJob whose tag file sits
    next to the script, and the mission is committed and monitored.
    """
    mission = Mission()
    job_glob = "*-{}".format(LABEL_FILE_STRUCTURE.bigmodel_job)
    tag_glob = "*-{}".format(LABEL_FILE_STRUCTURE.bigmodel_tag)
    slurm_remain, slurm_success = get_slurm_job_run_info(self.bigmodel_dir,
                                                         job_patten=job_glob,
                                                         tag_patten=tag_glob)
    already_done = len(slurm_remain) == 0 and len(slurm_success) > 0
    if already_done:
        return

    # recover slurm jobs
    if len(slurm_remain) > 0:
        print("Run bigModel Job:\n")
        print(slurm_remain)
        for script_path in slurm_remain:
            # the tag shares the leading "<group>-" prefix of the script name
            group_prefix = os.path.basename(script_path).split('-')[0].strip()
            tag_name = "{}-{}".format(group_prefix, LABEL_FILE_STRUCTURE.bigmodel_tag)
            job = SlurmJob()
            job.set_tag(os.path.join(os.path.dirname(script_path), tag_name))
            job.set_cmd(script_path)
            mission.add_job(job)

    if len(mission.job_list) > 0:
        mission.commit_jobs()
        mission.check_running_job()
        mission.all_job_finished()
215
+
151
216
  def make_scf_file(self, scf_dir:str, tarj_lmp:str, atom_names:list[str]=None):
152
217
  config_index = os.path.basename(tarj_lmp).split('.')[0]
153
218
  if DFT_STYLE.vasp == self.resource.dft_style: # when do scf, the vasp input file name is 'POSCAR'
@@ -230,6 +295,42 @@ class Labeling(object):
230
295
  slurm_job_file = os.path.join(self.scf_dir, slurm_script_name)
231
296
  write_to_file(slurm_job_file, group_slurm_script, "w")
232
297
 
298
+
299
def make_bigmodel_slurm_job_files(self, scf_sub_list:list[str]):
    """Write one slurm script per job group into the bigmodel work dir.

    Args:
        scf_sub_list: work dirs to run; they are split into groups by
            split_job_for_group(1, scf_sub_list, 1).

    Each group whose first entry is not "NONE" gets a script named
    "<group_index>-<bigmodel_job>", built from the dft resource settings.
    """
    # Clear scripts left by a previous run. The pattern has to target the
    # bigmodel job files, because those are the files written below.
    del_file_list_by_patten(self.bigmodel_dir, "*{}".format(LABEL_FILE_STRUCTURE.bigmodel_job))
    group_list = split_job_for_group(1, scf_sub_list, 1)
    dft_resource = self.resource.dft_resource

    for group_index, group in enumerate(group_list):
        if group[0] == "NONE":
            continue

        jobname = "bigmodel{}".format(group_index)
        tag_name = "{}-{}".format(group_index, LABEL_FILE_STRUCTURE.bigmodel_tag)
        tag = os.path.join(self.bigmodel_dir, tag_name)
        run_cmd = dft_resource.command
        group_slurm_script = set_slurm_script_content(gpu_per_node=dft_resource.gpu_per_node,
            number_node = dft_resource.number_node,
            cpu_per_node = dft_resource.cpu_per_node,
            queue_name = dft_resource.queue_name,
            custom_flags = dft_resource.custom_flags,
            env_script = dft_resource.env_script,
            job_name = jobname,
            run_cmd_template = run_cmd,
            group = group,
            job_tag = tag,
            task_tag = LABEL_FILE_STRUCTURE.bigmodel_tag,
            task_tag_faild = LABEL_FILE_STRUCTURE.bigmodel_tag_failed,
            parallel_num=dft_resource.parallel_num,
            check_type=self.resource.dft_style
            )
        slurm_script_name = "{}-{}".format(group_index, LABEL_FILE_STRUCTURE.bigmodel_job)
        slurm_job_file = os.path.join(self.bigmodel_dir, slurm_script_name)
        write_to_file(slurm_job_file, group_slurm_script, "w")
333
+
233
334
  '''
234
335
  description:
235
336
  collecte OUT.MLMD to mvm-
@@ -274,12 +375,12 @@ class Labeling(object):
274
375
  for scf_file in scf_files:
275
376
  scf_file_path = os.path.join(scf_dir, scf_file)
276
377
  if scf_file.lower() in DFT_STYLE.get_scf_reserve_list(self.resource.dft_style) \
277
- and scf_file.lower() not in DFT_STYLE.get_scf_del_list():# for pwmat final.config
378
+ or "atom.config" in scf_file.lower() :# for the input natom.config
278
379
  copy_file(scf_file_path, scf_file_path.replace(TEMP_STRUCTURE.tmp_run_iter_dir, ""))
279
380
 
280
381
  # scf files to pwdata format
281
382
  scf_configs = self.collect_scf_configs()
282
-
383
+
283
384
  extract_pwdata(input_data_list=scf_configs,
284
385
  intput_data_format =DFT_STYLE.get_format_by_postfix(os.path.basename(scf_configs[0])),
285
386
  save_data_path =self.result_dir,
@@ -289,3 +390,16 @@ class Labeling(object):
289
390
  )
290
391
  # copy to main dir
291
392
  copy_dir(self.result_dir, self.real_result_dir)
393
+
394
def do_post_bigmodel(self):
    """Collect the bigmodel-labeled structures into the result dir.

    The train xyz file in the bigmodel work dir is copied as-is when the
    configured data format is extxyz; otherwise it is loaded as extxyz and
    written to the result dir in pwmlff npy format. The bigmodel and result
    dirs are then copied to their real counterparts and the slurm-* files
    are removed from the real bigmodel dir.
    """
    labeled_xyz = os.path.join(self.bigmodel_dir, LABEL_FILE_STRUCTURE.train_xyz)
    keep_as_xyz = self.input_param.data_format == PWDATA.extxyz
    if not keep_as_xyz:
        # convert the labeled xyz to the training data layout
        labeled_images = Config(data_path=labeled_xyz, format=PWDATA.extxyz)
        labeled_images.to(data_path=self.result_dir, format=PWDATA.pwmlff_npy)
    else:
        # copy the bigmodel labeled xyz to result
        copy_file(labeled_xyz, os.path.join(self.result_dir, LABEL_FILE_STRUCTURE.train_xyz))
    # copy bigmodel dir and results to the real dirs
    copy_dir(self.bigmodel_dir, self.real_bigmodel_dir)
    copy_dir(self.result_dir, self.real_result_dir)
    # del slurm logs and tags
    del_file_list_by_patten(self.real_bigmodel_dir, "slurm-*")
@@ -26,13 +26,15 @@ class InitBulkParam(object):
26
26
  sys_configs = [sys_configs]
27
27
 
28
28
  # set sys_config detail
29
- self.dft_style = get_required_parameter("dft_style", json_dict).lower()
29
+ self.dft_style = get_parameter("dft_style", json_dict, "PWMAT").lower()
30
30
  self.scf_style = get_parameter("scf_style", json_dict, None)
31
31
 
32
32
  self.sys_config:list[Stage] = []
33
33
  self.is_relax = False
34
34
  self.is_aimd = False
35
35
  self.is_scf = False
36
+ self.is_bigmodel=False
37
+ self.is_direct = False
36
38
  for index, config in enumerate(sys_configs):
37
39
  stage = Stage(config, index, sys_config_prefix, self.dft_style)
38
40
  self.sys_config.append(stage)
@@ -42,22 +44,46 @@ class InitBulkParam(object):
42
44
  self.is_aimd = True
43
45
  if stage.scf:
44
46
  self.is_scf = True
47
+ if stage.bigmodel:
48
+ self.is_bigmodel = True
49
+ if stage.direct:
50
+ self.is_direct = True
45
51
 
46
52
  # for PWmat: set etot.input files and persudo files
47
53
  # for Vasp: set INCAR files and persudo files
48
- self.dft_input = SCFParam(json_dict=json_dict, is_scf=self.is_scf, is_relax=self.is_relax, is_aimd=self.is_aimd, root_dir=self.root_dir, dft_style=self.dft_style, scf_style=self.scf_style)
54
+ self.dft_input = SCFParam(json_dict=json_dict,
55
+ is_scf=self.is_scf,
56
+ is_relax=self.is_relax,
57
+ is_aimd=self.is_aimd,
58
+ root_dir=self.root_dir,
59
+ dft_style=self.dft_style,
60
+ scf_style=self.scf_style,
61
+ is_bigmodel=self.is_bigmodel,
62
+ is_direct=self.is_direct)
63
+
49
64
  # check and set relax etot.input file
50
65
  for config in self.sys_config:
51
66
  if self.is_relax:
52
67
  if config.relax_input_idx >= len(self.dft_input.relax_input_list):
53
68
  raise Exception("Error! for config '{}' 'relax_input_idx' {} not in 'relax_input'!".format(os.path.basename(config.config_file), config.relax_input_idx))
54
69
  config.set_relax_input_file(self.dft_input.relax_input_list[config.relax_input_idx])
70
+
55
71
  if self.is_scf:
56
72
  if not os.path.exists(self.dft_input.scf_input_list[0].input_file):
57
73
  raise Exception("Error! relabel dft input file {} not exisit!".format(self.dft_input.scf_input_list[0].input_file))
58
74
  config.set_scf_input_file(self.dft_input.scf_input_list[0])
59
- # check and set aimd etot.input file
60
- for config in self.sys_config:
75
+
76
+ if self.is_bigmodel:
77
+ if config.bigmodel_input_idx >= len(self.dft_input.bigmodel_input_list):
78
+ raise Exception("Error! for script '{}' 'bigmodel_input_idx' {} not in 'bigmodel_input'!".format(os.path.basename(config.config_file), config.bigmodel_input_idx))
79
+ config.set_bigmodel_input_file(self.dft_input.bigmodel_input_list[config.bigmodel_input_idx])
80
+
81
+ if self.is_direct:
82
+ if config.direct_input_idx >= len(self.dft_input.direct_input_list):
83
+ raise Exception("Error! for script '{}' 'direct_input_idx' {} not in 'direct_input'!".format(os.path.basename(config.config_file), config.direct_input_idx))
84
+ config.set_direct_input_file(self.dft_input.direct_input_list[config.direct_input_idx])
85
+
86
+ # check and set aimd etot.input file
61
87
  if self.is_aimd:
62
88
  if config.aimd_input_idx >= len(self.dft_input.aimd_input_list):
63
89
  raise Exception("Error! for config '{}' 'aimd_input_idx' {} not in 'aimd_input'!".format(os.path.basename(config.config_file), config.aimd_input_idx))
@@ -77,16 +103,29 @@ class Stage(object):
77
103
  self.format = get_parameter("format", json_dict, PWDATA.pwmat_config).lower()
78
104
  self.pbc = get_parameter("pbc", json_dict, [1,1,1])
79
105
  # extract config file to Config object, then use it
80
- self.relax = get_parameter("relax", json_dict, True)
106
+ self.relax = get_parameter("relax", json_dict, False)
81
107
  self.relax_input_idx = get_parameter("relax_input_idx", json_dict, 0)
82
108
  self.relax_input_file = None
83
109
 
84
- self.aimd = get_parameter("aimd", json_dict, True)
110
+ self.aimd = get_parameter("aimd", json_dict, False)
85
111
  self.aimd_input_idx = get_parameter("aimd_input_idx", json_dict, 0)
86
112
  self.aimd_input_file = None
87
113
 
88
114
  self.scf = get_parameter("scf", json_dict, False)
115
+ self.scf_input_idx = get_parameter("scf_input_idx", json_dict, 0)
116
+ self.scf_input_file = None
89
117
 
118
+ self.bigmodel = get_parameter("bigmodel", json_dict, False)
119
+ self.bigmodel_input_idx = get_parameter("bigmodel_input_idx", json_dict, 0)
120
+ self.bigmodel_script = None
121
+
122
+ self.direct = get_parameter("direct", json_dict, False)
123
+ self.direct_input_idx = get_parameter("direct_input_idx", json_dict, 0)
124
+ self.direct_script = None
125
+
126
+ if self.bigmodel and self.aimd:
127
+ raise Exception("ERROR! The 'aimd' and 'bigmodel' cannot be set simultaneously!")
128
+
90
129
  super_cell = get_parameter("super_cell", json_dict, [])
91
130
  super_cell = str_list_format(super_cell)
92
131
  if len(super_cell) > 0:
@@ -131,3 +170,13 @@ class Stage(object):
131
170
  self.aimd_flag_symm = input_file.flag_symm
132
171
  self.use_dftb = input_file.use_dftb
133
172
  self.use_skf = input_file.use_skf
173
+
174
def set_bigmodel_input_file(self, input_file:DFTInput):
    """Copy the bigmodel settings from a DFTInput onto this object.

    Sets ``bigmodel_input_file``, ``bigmodel_kspacing`` and
    ``bigmodel_flag_symm`` from the ``input_file``, ``kspacing`` and
    ``flag_symm`` attributes of ``input_file``, in that order.
    """
    for field_name in ("input_file", "kspacing", "flag_symm"):
        setattr(self, "bigmodel_{}".format(field_name), getattr(input_file, field_name))
178
+
179
def set_direct_input_file(self, input_file:DFTInput):
    """Copy the direct settings from a DFTInput onto this object.

    Sets ``direct_input_file``, ``direct_kspacing`` and ``direct_flag_symm``
    from the ``input_file``, ``kspacing`` and ``flag_symm`` attributes of
    ``input_file``, in that order.
    """
    for field_name in ("input_file", "kspacing", "flag_symm"):
        setattr(self, "direct_{}".format(field_name), getattr(input_file, field_name))
@@ -105,6 +105,18 @@ class StrategyParam(object):
105
105
  if self.compress:
106
106
  error_log = "Error! the kpu uncertainty does not support compress, please set the 'compress' in strategy dict to be false!"
107
107
  raise Exception(error_log)
108
+
109
+ self.direct = get_parameter("direct", json_dict, False)
110
+ if self.direct:
111
+ self.direct_script = get_parameter("direct_script", json_dict, None)
112
+ if self.direct_script is not None:
113
+ self.direct_script = os.path.abspath(self.direct_script)
114
+ if not os.path.exists(self.direct_script):
115
+ raise Exception("ERROR! The direct script {} does not exist!".format(self.direct_script))
116
+ else:
117
+ raise Exception("ERROR! The direct script does not exist!")
118
+ else:
119
+ self.direct_script = None
108
120
 
109
121
  def to_dict(self):
110
122
  res = {}
@@ -20,10 +20,22 @@ class Resource(object):
20
20
  if "-in" in self.explore_resource.command:
21
21
  self.explore_resource.command = self.explore_resource.command.split('-in')[0].strip()
22
22
  self.explore_resource.command = "{} -in {} > {}".format(self.explore_resource.command, LAMMPS.input_lammps, SLURM_OUT.md_out)
23
-
23
+ else:
24
+ if "explore" in json_dict.keys():
25
+ self.explore_resource = self.get_resource(get_required_parameter("explore", json_dict))
26
+ else:
27
+ self.explore_resource = None
24
28
  # check dft resource
25
- self.dft_resource = self.get_resource(get_required_parameter("dft", json_dict))
26
-
29
+ if "dft" in json_dict.keys():
30
+ self.dft_resource = self.get_resource(get_required_parameter("dft", json_dict))
31
+ else:
32
+ self.dft_resource = ResourceDetail("mpirun -np 1 PWmat", 1, 1, 1, 1, 1, None, None, None)
33
+
34
+ if "direct" in json_dict.keys():
35
+ self.direct_resource = self.get_resource(get_required_parameter("direct", json_dict))
36
+ else:
37
+ self.direct_resource = None
38
+
27
39
  if "scf" in json_dict.keys():
28
40
  self.scf_resource = self.get_resource(get_parameter("scf", json_dict, None))
29
41
  else:
@@ -33,11 +45,11 @@ class Resource(object):
33
45
  # self.dft_resource.dftb_command = "{} > {}".format(dftb_command, SLURM_OUT.dft_out)
34
46
  self.dft_style = dft_style
35
47
  self.scf_style = scf_style
36
- if DFT_STYLE.vasp.lower() == dft_style.lower():
48
+ if DFT_STYLE.vasp.lower() == dft_style:
37
49
  self.dft_resource.command = "{} > {}".format(self.dft_resource.command, SLURM_OUT.dft_out)
38
- elif DFT_STYLE.pwmat.lower() == dft_style.lower():
50
+ elif DFT_STYLE.pwmat.lower() == dft_style:
39
51
  self.dft_resource.command = "{} > {}".format(self.dft_resource.command, SLURM_OUT.dft_out)
40
- elif DFT_STYLE.cp2k.lower() == dft_style.lower():
52
+ elif DFT_STYLE.cp2k.lower() == dft_style:
41
53
  self.dft_resource.command = "{} {} > {}".format(self.dft_resource.command, CP2K.cp2k_inp, SLURM_OUT.dft_out)
42
54
 
43
55
  if self.scf_resource is not None and scf_style is not None:
@@ -10,7 +10,9 @@ class SCFParam(object):
10
10
  is_scf:bool=False,
11
11
  root_dir:str=None,
12
12
  dft_style:str=None,
13
- scf_style:str=None) -> None:# for scf relabel in init_bulk
13
+ scf_style:str=None,
14
+ is_bigmodel:bool=False,
15
+ is_direct:bool=False) -> None:# for scf relabel in init_bulk
14
16
 
15
17
  self.dft_style = dft_style
16
18
  self.root_dir = root_dir
@@ -24,12 +26,18 @@ class SCFParam(object):
24
26
 
25
27
  if is_scf:
26
28
  if "scf_input" in json_dict.keys(): # for init_bulk relabel
27
- json_scf = get_required_parameter("scf_input", json_dict)
28
- self.scf_input_list = self.set_input(json_scf, flag_symm=0)
29
+ if dft_style == DFT_STYLE.bigmodel:
30
+ self.bigmodel_script = get_required_parameter("bigmodel_script", json_dict)
31
+ else:
32
+ json_scf = get_required_parameter("scf_input", json_dict)
33
+ self.scf_input_list = self.set_input(json_scf, flag_symm=0)
29
34
  else: # for run_iter
30
- self.scf_input_list = self.set_input(json_dict, flag_symm=0)
31
- if self.scf_input_list[0].use_dftb:
32
- self.use_dftb = True
35
+ if dft_style == DFT_STYLE.bigmodel:
36
+ self.bigmodel_script = get_required_parameter("bigmodel_script", json_dict)
37
+ else:
38
+ self.scf_input_list = self.set_input(json_dict, flag_symm=0)
39
+ if self.scf_input_list[0].use_dftb:
40
+ self.use_dftb = True
33
41
  if is_aimd:
34
42
  json_aimd = get_required_parameter("aimd_input", json_dict)
35
43
  self.aimd_input_list = self.set_input(json_aimd, flag_symm=0)
@@ -40,6 +48,16 @@ class SCFParam(object):
40
48
  self.relax_input_list = self.set_input(json_relax, flag_symm=3)
41
49
  if self.relax_input_list[0].use_dftb:
42
50
  self.use_dftb = True
51
+
52
+ if is_bigmodel: # init_bulk
53
+ json_bigmodel = get_required_parameter("bigmodel_input", json_dict)
54
+ self.bigmodel_input_list = self.set_input(json_bigmodel, flag_symm=3)
55
+
56
+ if is_direct: # init_bulk
57
+ json_direct = get_required_parameter("direct_input", json_dict)
58
+ self.direct_input_list = self.set_input(json_direct, flag_symm=3)
59
+
60
+ self.scf_max_num = get_parameter("scf_max_num", json_dict, None)
43
61
  # for pwmat, use 'pseudo' key
44
62
  # for vasp is INCAR file, use 'pseudo' key
45
63
  pseudo = get_parameter("pseudo", json_dict, [])
@@ -6,7 +6,7 @@ class OptimizerParam(object):
6
6
 
7
7
  def set_optimizer(self, json_source:dict, nep_param:NepParam=None):
8
8
  optimizer_dict = get_parameter("optimizer", json_source, {})
9
- self.opt_name = get_parameter("optimizer", optimizer_dict, "LKF")
9
+ self.opt_name = get_parameter("optimizer", optimizer_dict, "ADAM")
10
10
  self.batch_size = get_parameter("batch_size", optimizer_dict, 1)
11
11
  self.epochs = get_parameter("epochs", optimizer_dict, 30)
12
12
  self.print_freq = get_parameter("print_freq", optimizer_dict, 10)
pwact/main.py CHANGED
@@ -5,7 +5,7 @@ import glob
5
5
  import sys
6
6
  import json
7
7
  import argparse
8
- from pwact.utils.constant import TEMP_STRUCTURE, UNCERTAINTY, AL_WORK, AL_STRUCTURE, LABEL_FILE_STRUCTURE, EXPLORE_FILE_STRUCTURE
8
+ from pwact.utils.constant import TEMP_STRUCTURE, UNCERTAINTY, AL_WORK, AL_STRUCTURE, LABEL_FILE_STRUCTURE, EXPLORE_FILE_STRUCTURE, DFT_STYLE
9
9
  from pwact.utils.format_input_output import make_iter_name
10
10
  from pwact.utils.file_operation import write_to_file, del_file_list, search_files, del_dir, copy_dir
11
11
  from pwact.utils.json_operation import convert_keys_to_lowercase
@@ -86,11 +86,16 @@ def run_fp(itername:str, resource : Resource, input_param: InputParam):
86
86
  #1. if the label work done before, back up and do new work
87
87
  lab.back_label()
88
88
  #2. make scf work
89
- lab.make_scf_work()
90
- #3. do scf work
91
- lab.do_scf_jobs()
92
- #4. collect scf configs outcar or movement, then to pwdata format
93
- lab.do_post_labeling()
89
+ if input_param.dft_style == DFT_STYLE.bigmodel:
90
+ lab.make_bigmodel_work()
91
+ lab.do_bigmodel_jobs()
92
+ lab.do_post_bigmodel()
93
+ else:
94
+ lab.make_scf_work()
95
+ #3. do scf work
96
+ lab.do_scf_jobs()
97
+ #4. collect scf configs outcar or movement, then to pwdata format
98
+ lab.do_post_labeling()
94
99
 
95
100
  def do_training_work(itername:str, resource : Resource, input_param: InputParam):
96
101
  mtrain = ModelTrian(itername, resource, input_param)
@@ -129,8 +134,13 @@ def do_exploring_work(itername:str, resource : Resource, input_param: InputParam
129
134
  summary = "{} {}\n".format(itername, summary)
130
135
  write_to_file(os.path.join(input_param.root_dir, EXPLORE_FILE_STRUCTURE.iter_select_file), summary, mode='a')
131
136
 
137
+ if input_param.strategy.direct:
138
+ md.make_drct_work()
139
+ md.do_drct_jobs()
140
+ md.post_drct()
132
141
  print("config selection done!")
133
- # 5. do post process after lammps md running
142
+
143
+ # 5. do post process
134
144
  md.post_process_md()
135
145
  print("exploring done!")
136
146
 
@@ -0,0 +1,145 @@
1
+ from maml.sampling.direct import DIRECTSampler, BirchClustering, SelectKFromClusters
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib.ticker as mtick
5
+ from ase.io import read
6
+ import subprocess, os, sys
7
+
8
# Script settings.
# File the selected structures are appended to; removed first so that a
# rerun does not accumulate structures from a previous run.
write_file = "select.xyz"
if os.path.exists(write_file):
    os.remove(write_file)
# Trajectory files holding the candidate structures.
filenames = ["candidate.xyz"]
# k is handed to SelectKFromClusters, threshold to BirchClustering(threshold_init=...).
k = 1
threshold = 0.04
14
def load_ase_MD_traj(filenames: list):
    """Read every frame of the given trajectory files with ``ase.io.read``.

    Args:
        filenames: paths of the trajectory files to read.

    Returns:
        Tuple ``(structs, trajs, lens)``: ``structs`` is the flat list of all
        frames in file order, ``trajs`` holds one list of frames per file, and
        ``lens`` holds the number of frames of each file.
    """
    structs = []
    trajs = []
    lens = []
    for filename in filenames:
        # index=":" asks for all frames of the file
        traj = read(filename, index=":")
        structs.extend(traj)
        trajs.append(traj)
        lens.append(len(traj))
    return structs, trajs, lens
27
+
28
# Load all candidate frames and run the DIRECT sampling on them.
structures, trajs, lens = load_ase_MD_traj(filenames)
n_image = len(structures)

DIRECT_sampler = DIRECTSampler(
    clustering=BirchClustering(n=None, threshold_init=threshold), select_k_from_clusters=SelectKFromClusters(k=k)
)

DIRECT_selection = DIRECT_sampler.fit_transform(structures)
n, m = DIRECT_selection["PCAfeatures"].shape

# Divide the PCA features by the explained variance ratio of their component.
explained_variance = DIRECT_sampler.pca.pca.explained_variance_ratio_
DIRECT_selection["PCAfeatures_unweighted"] = DIRECT_selection["PCAfeatures"] / explained_variance[:m]

# Plot the explained variance (in percent) of every principal component.
plt.plot(
    range(1, explained_variance.shape[0]+1),
    explained_variance * 100,
    "o-",
)
# raw string: "\m" is not a valid escape sequence in a normal string literal
plt.xlabel(r"i$^{\mathrm{th}}$ PC", size=20)
plt.ylabel("Explained variance", size=20)
ax = plt.gca()
ax.yaxis.set_major_formatter(mtick.PercentFormatter())
plt.tight_layout()
plt.savefig("PCA_variance.png",dpi=360)
plt.close()
53
+
54
def plot_PCAfeature_coverage(all_features, selected_indexes, method="DIRECT"):
    """Plot the first two feature columns of all and of the selected structures.

    A new 5x5 figure is created and left open; saving and closing it is up to
    the caller.

    Args:
        all_features: 2-D array, one row of features per structure.
        selected_indexes: row indexes of the selected structures.
        method: name shown in the legend entry of the selection.
    """
    plt.subplots(figsize=(5, 5))
    chosen = all_features[selected_indexes]
    shared_style = {"alpha": 0.5}
    plt.plot(all_features[:, 0], all_features[:, 1], "*",
             label=f"All {len(all_features):,} structures", **shared_style)
    plt.plot(chosen[:, 0], chosen[:, 1], "*",
             label=f"{method} sampled {len(chosen):,}", **shared_style)
    plt.legend(frameon=False, fontsize=14, loc="upper left", bbox_to_anchor=(-0.02, 1.02), reverse=True)
    plt.ylabel("PC 2", size=20)
    plt.xlabel("PC 1", size=20)
70
+
71
# Plot the coverage of the DIRECT selection in the unweighted PCA features
# and save the figure.
all_features = DIRECT_selection["PCAfeatures_unweighted"]
selected_indexes = DIRECT_selection["selected_indexes"]
plot_PCAfeature_coverage(all_features=all_features, selected_indexes=selected_indexes)
plt.tight_layout()
plt.savefig("PCA_direct.png", dpi=360)
plt.close()
77
+
78
+ #manual_selection_index = np.arange(0, n_image, int(n_image/n))
79
+ #plot_PCAfeature_coverage(all_features, manual_selection_index, "Manually")
80
+ #plt.tight_layout()
81
+ #plt.savefig("PCA_manually.png",dpi=360)
82
+ #plt.close()
83
+
84
def calculate_feature_coverage_score(all_features, selected_indexes, n_bins=100):
    """Return the share of occupied feature bins that the selection also occupies.

    The value range of ``all_features`` is divided by ``n_bins`` equally
    spaced edges (``n_bins - 1`` bins). The score is the number of bins
    containing at least one selected value divided by the number of bins
    containing at least one value overall.

    Args:
        all_features: 1-D array-like with one feature value per structure.
        selected_indexes: indexes into ``all_features`` of the selection.
        n_bins: number of bin edges handed to ``np.linspace``.

    Returns:
        The ratio ``n_select / n_all``.
    """
    all_features = np.asarray(all_features)
    selected_features = all_features[selected_indexes]
    # build the edges once with array reductions and reuse them for both histograms
    bin_edges = np.linspace(all_features.min(), all_features.max(), n_bins)
    n_all = np.count_nonzero(np.histogram(all_features, bins=bin_edges)[0])
    n_select = np.count_nonzero(np.histogram(selected_features, bins=bin_edges)[0])
    return n_select / n_all
93
+
94
def calculate_all_FCS(all_features, selected_indexes, b_bins=100):
    """Compute the feature coverage score of every feature column.

    Args:
        all_features: 2-D array, one row per structure and one column per feature.
        selected_indexes: row indexes of the selected structures.
        b_bins: forwarded as ``n_bins`` to ``calculate_feature_coverage_score``.

    Returns:
        List with one score per column, in column order.
    """
    select_scores = []
    for column in range(all_features.shape[1]):
        column_score = calculate_feature_coverage_score(
            all_features[:, column], selected_indexes, n_bins=b_bins
        )
        select_scores.append(column_score)
    return select_scores
100
+
101
# Bar plot of the coverage score of the DIRECT selection per principal component.
all_features = DIRECT_selection["PCAfeatures_unweighted"]
scores_DIRECT = calculate_all_FCS(all_features, DIRECT_selection["selected_indexes"], b_bins=100)
#scores_MS = calculate_all_FCS(all_features, manual_selection_index, b_bins=100)
x = np.arange(len(scores_DIRECT))
x_ticks = [f"PC {n+1}" for n in range(len(x))]

plt.figure(figsize=(15, 4))
plt.bar(
    x,
    scores_DIRECT,
    width=0.3,
    # raw f-string: "\o", "\m" and "\ " are not valid escapes in a normal literal
    label=rf"DIRECT, $\overline{{\mathrm{{Coverage\ score}}}}$ = {np.mean(scores_DIRECT):.3f}",
)
#plt.bar(
#    x + 0.3, scores_MS, width=0.3, label=rf"Manual, $\overline{{\mathrm{{Coverage\ score}}}}$ = {np.mean(scores_MS):.3f}"
#)
plt.xticks(x, x_ticks, size=16)
plt.yticks(np.linspace(0, 1.0, 6), size=16)
plt.ylabel("Coverage score", size=20)
plt.legend(shadow=True, loc="lower right", fontsize=16)
plt.tight_layout()
plt.savefig("Cov_score.png",dpi=360)
plt.close()
124
+
125
def get2index(num: int, list_lens: list):
    """Map a flat frame index to ``(trajectory index, frame index)``.

    Args:
        num: zero-based index into the concatenation of all trajectories.
        list_lens: number of frames of each trajectory, in concatenation order.

    Returns:
        Tuple ``(idx, num)``: the position of the trajectory in ``list_lens``
        and the offset of the frame inside that trajectory.

    Raises:
        IndexError: if ``num`` is negative or not smaller than the total
            number of frames (this includes an empty ``list_lens``).
    """
    if num < 0:
        raise IndexError("frame index {} is negative".format(num))
    offset = num
    for idx, length in enumerate(list_lens):
        if offset < length:
            return idx, offset
        offset -= length
    # every trajectory was consumed without reaching the requested frame
    raise IndexError("frame index {} is out of range for {} frames".format(num, sum(list_lens)))
132
+
133
# Write the selected structures to write_file and record which ones were kept.
indices = DIRECT_selection["selected_indexes"]
select_idx = []
for index in indices:
    idx, num = get2index(index, lens)
    atoms = trajs[idx][num]
    # the last three cell parameters are the cell angles
    angles = atoms.cell.cellpar()[-3:]
    # skip cells with an angle above 140 or below 40
    if angles.max() > 140 or angles.min() < 40:
        continue
    # NOTE(review): presumably wraps the atoms back into the cell - confirm with the ASE docs
    atoms.set_scaled_positions(atoms.get_scaled_positions())
    atoms.write(write_file, format="extxyz", append=True)
    # record the flat index of the structure that was actually written
    select_idx.append(index)
# Save only the indexes that passed the angle filter, so the file lines up
# with the structures in write_file.
np.savetxt("select_idx.dat", np.array(select_idx, dtype=int), fmt="%8d")