pwact 0.1.28__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. pwact/active_learning/environment.py +13 -11
  2. pwact/active_learning/explore/run_model_md.py +110 -0
  3. pwact/active_learning/explore/select_image.py +10 -5
  4. pwact/active_learning/init_bulk/direct.py +182 -0
  5. pwact/active_learning/init_bulk/duplicate_scale.py +1 -1
  6. pwact/active_learning/init_bulk/explore.py +300 -0
  7. pwact/active_learning/init_bulk/init_bulk_run.py +87 -47
  8. pwact/active_learning/init_bulk/relabel.py +149 -116
  9. pwact/active_learning/label/labeling.py +132 -18
  10. pwact/active_learning/train/train_model.py +13 -3
  11. pwact/active_learning/user_input/init_bulk_input.py +55 -6
  12. pwact/active_learning/user_input/iter_input.py +12 -0
  13. pwact/active_learning/user_input/resource.py +19 -7
  14. pwact/active_learning/user_input/scf_param.py +24 -6
  15. pwact/active_learning/user_input/train_param/nep_param.py +2 -2
  16. pwact/active_learning/user_input/train_param/optimizer_param.py +1 -1
  17. pwact/active_learning/user_input/train_param/work_file_param.py +1 -1
  18. pwact/main.py +18 -9
  19. pwact/utils/app_lib/do_direct_sample.py +145 -0
  20. pwact/utils/app_lib/do_eqv2model.py +41 -0
  21. pwact/utils/app_lib/lammps.py +1 -1
  22. pwact/utils/constant.py +32 -12
  23. pwact/utils/file_operation.py +12 -5
  24. pwact-0.2.1.dist-info/METADATA +17 -0
  25. {pwact-0.1.28.dist-info → pwact-0.2.1.dist-info}/RECORD +29 -25
  26. {pwact-0.1.28.dist-info → pwact-0.2.1.dist-info}/WHEEL +1 -1
  27. pwact-0.1.28.dist-info/METADATA +0 -107
  28. {pwact-0.1.28.dist-info → pwact-0.2.1.dist-info}/LICENSE +0 -0
  29. {pwact-0.1.28.dist-info → pwact-0.2.1.dist-info}/entry_points.txt +0 -0
  30. {pwact-0.1.28.dist-info → pwact-0.2.1.dist-info}/top_level.txt +0 -0
@@ -26,13 +26,15 @@ class InitBulkParam(object):
26
26
  sys_configs = [sys_configs]
27
27
 
28
28
  # set sys_config detail
29
- self.dft_style = get_required_parameter("dft_style", json_dict).lower()
29
+ self.dft_style = get_parameter("dft_style", json_dict, "PWMAT").lower()
30
30
  self.scf_style = get_parameter("scf_style", json_dict, None)
31
31
 
32
32
  self.sys_config:list[Stage] = []
33
33
  self.is_relax = False
34
34
  self.is_aimd = False
35
35
  self.is_scf = False
36
+ self.is_bigmodel=False
37
+ self.is_direct = False
36
38
  for index, config in enumerate(sys_configs):
37
39
  stage = Stage(config, index, sys_config_prefix, self.dft_style)
38
40
  self.sys_config.append(stage)
@@ -42,22 +44,46 @@ class InitBulkParam(object):
42
44
  self.is_aimd = True
43
45
  if stage.scf:
44
46
  self.is_scf = True
47
+ if stage.bigmodel:
48
+ self.is_bigmodel = True
49
+ if stage.direct:
50
+ self.is_direct = True
45
51
 
46
52
  # for PWmat: set etot.input files and pseudo files
47
53
  # for Vasp: set INCAR files and pseudo files
48
- self.dft_input = SCFParam(json_dict=json_dict, is_scf=self.is_scf, is_relax=self.is_relax, is_aimd=self.is_aimd, root_dir=self.root_dir, dft_style=self.dft_style, scf_style=self.scf_style)
54
+ self.dft_input = SCFParam(json_dict=json_dict,
55
+ is_scf=self.is_scf,
56
+ is_relax=self.is_relax,
57
+ is_aimd=self.is_aimd,
58
+ root_dir=self.root_dir,
59
+ dft_style=self.dft_style,
60
+ scf_style=self.scf_style,
61
+ is_bigmodel=self.is_bigmodel,
62
+ is_direct=self.is_direct)
63
+
49
64
  # check and set relax etot.input file
50
65
  for config in self.sys_config:
51
66
  if self.is_relax:
52
67
  if config.relax_input_idx >= len(self.dft_input.relax_input_list):
53
68
  raise Exception("Error! for config '{}' 'relax_input_idx' {} not in 'relax_input'!".format(os.path.basename(config.config_file), config.relax_input_idx))
54
69
  config.set_relax_input_file(self.dft_input.relax_input_list[config.relax_input_idx])
70
+
55
71
  if self.is_scf:
56
72
  if not os.path.exists(self.dft_input.scf_input_list[0].input_file):
57
73
  raise Exception("Error! relabel dft input file {} not exisit!".format(self.dft_input.scf_input_list[0].input_file))
58
74
  config.set_scf_input_file(self.dft_input.scf_input_list[0])
59
- # check and set aimd etot.input file
60
- for config in self.sys_config:
75
+
76
+ if self.is_bigmodel:
77
+ if config.bigmodel_input_idx >= len(self.dft_input.bigmodel_input_list):
78
+ raise Exception("Error! for script '{}' 'bigmodel_input_idx' {} not in 'bigmodel_input'!".format(os.path.basename(config.config_file), config.bigmodel_input_idx))
79
+ config.set_bigmodel_input_file(self.dft_input.bigmodel_input_list[config.bigmodel_input_idx])
80
+
81
+ if self.is_direct:
82
+ if config.direct_input_idx >= len(self.dft_input.direct_input_list):
83
+ raise Exception("Error! for script '{}' 'direct_input_idx' {} not in 'direct_input'!".format(os.path.basename(config.config_file), config.direct_input_idx))
84
+ config.set_direct_input_file(self.dft_input.direct_input_list[config.direct_input_idx])
85
+
86
+ # check and set aimd etot.input file
61
87
  if self.is_aimd:
62
88
  if config.aimd_input_idx >= len(self.dft_input.aimd_input_list):
63
89
  raise Exception("Error! for config '{}' 'aimd_input_idx' {} not in 'aimd_input'!".format(os.path.basename(config.config_file), config.aimd_input_idx))
@@ -77,16 +103,29 @@ class Stage(object):
77
103
  self.format = get_parameter("format", json_dict, PWDATA.pwmat_config).lower()
78
104
  self.pbc = get_parameter("pbc", json_dict, [1,1,1])
79
105
  # extract config file to Config object, then use it
80
- self.relax = get_parameter("relax", json_dict, True)
106
+ self.relax = get_parameter("relax", json_dict, False)
81
107
  self.relax_input_idx = get_parameter("relax_input_idx", json_dict, 0)
82
108
  self.relax_input_file = None
83
109
 
84
- self.aimd = get_parameter("aimd", json_dict, True)
110
+ self.aimd = get_parameter("aimd", json_dict, False)
85
111
  self.aimd_input_idx = get_parameter("aimd_input_idx", json_dict, 0)
86
112
  self.aimd_input_file = None
87
113
 
88
114
  self.scf = get_parameter("scf", json_dict, False)
115
+ self.scf_input_idx = get_parameter("scf_input_idx", json_dict, 0)
116
+ self.scf_input_file = None
89
117
 
118
+ self.bigmodel = get_parameter("bigmodel", json_dict, False)
119
+ self.bigmodel_input_idx = get_parameter("bigmodel_input_idx", json_dict, 0)
120
+ self.bigmodel_script = None
121
+
122
+ self.direct = get_parameter("direct", json_dict, False)
123
+ self.direct_input_idx = get_parameter("direct_input_idx", json_dict, 0)
124
+ self.direct_script = None
125
+
126
+ if self.bigmodel and self.aimd:
127
+ raise Exception("ERROR! The 'aimd' and 'bigmodel' cannot be set simultaneously!")
128
+
90
129
  super_cell = get_parameter("super_cell", json_dict, [])
91
130
  super_cell = str_list_format(super_cell)
92
131
  if len(super_cell) > 0:
@@ -131,3 +170,13 @@ class Stage(object):
131
170
  self.aimd_flag_symm = input_file.flag_symm
132
171
  self.use_dftb = input_file.use_dftb
133
172
  self.use_skf = input_file.use_skf
173
+
174
def set_bigmodel_input_file(self, input_file:DFTInput):
    """
    Record the big-model input settings on this stage.

    Copies ``input_file``, ``kspacing`` and ``flag_symm`` from the given
    input object onto the matching ``bigmodel_*`` attributes of ``self``.
    """
    # Copied one attribute at a time, in this order, like plain assignments.
    for field in ("input_file", "kspacing", "flag_symm"):
        setattr(self, "bigmodel_" + field, getattr(input_file, field))
178
+
179
def set_direct_input_file(self, input_file:DFTInput):
    """
    Record the direct-sampling input settings on this stage.

    Copies ``input_file``, ``kspacing`` and ``flag_symm`` from the given
    input object onto the matching ``direct_*`` attributes of ``self``.
    """
    # Copied one attribute at a time, in this order, like plain assignments.
    for field in ("input_file", "kspacing", "flag_symm"):
        setattr(self, "direct_" + field, getattr(input_file, field))
@@ -105,6 +105,18 @@ class StrategyParam(object):
105
105
  if self.compress:
106
106
  error_log = "Error! the kpu uncertainty does not support compress, please set the 'compress' in strategy dict to be false!"
107
107
  raise Exception(error_log)
108
+
109
+ self.direct = get_parameter("direct", json_dict, False)
110
+ if self.direct:
111
+ self.direct_script = get_parameter("direct_script", json_dict, None)
112
+ if self.direct_script is not None:
113
+ self.direct_script = os.path.abspath(self.direct_script)
114
+ if not os.path.exists(self.direct_script):
115
+ raise Exception("ERROR! The direct script {} does not exist!".format(self.direct_script))
116
+ else:
117
+ raise Exception("ERROR! The direct script does not exist!")
118
+ else:
119
+ self.direct_script = None
108
120
 
109
121
  def to_dict(self):
110
122
  res = {}
@@ -20,10 +20,22 @@ class Resource(object):
20
20
  if "-in" in self.explore_resource.command:
21
21
  self.explore_resource.command = self.explore_resource.command.split('-in')[0].strip()
22
22
  self.explore_resource.command = "{} -in {} > {}".format(self.explore_resource.command, LAMMPS.input_lammps, SLURM_OUT.md_out)
23
-
23
+ else:
24
+ if "explore" in json_dict.keys():
25
+ self.explore_resource = self.get_resource(get_required_parameter("explore", json_dict))
26
+ else:
27
+ self.explore_resource = None
24
28
  # check dft resource
25
- self.dft_resource = self.get_resource(get_required_parameter("dft", json_dict))
26
-
29
+ if "dft" in json_dict.keys():
30
+ self.dft_resource = self.get_resource(get_required_parameter("dft", json_dict))
31
+ else:
32
+ self.dft_resource = ResourceDetail("mpirun -np 1 PWmat", 1, 1, 1, 1, 1, None, None, None)
33
+
34
+ if "direct" in json_dict.keys():
35
+ self.direct_resource = self.get_resource(get_required_parameter("direct", json_dict))
36
+ else:
37
+ self.direct_resource = None
38
+
27
39
  if "scf" in json_dict.keys():
28
40
  self.scf_resource = self.get_resource(get_parameter("scf", json_dict, None))
29
41
  else:
@@ -33,11 +45,11 @@ class Resource(object):
33
45
  # self.dft_resource.dftb_command = "{} > {}".format(dftb_command, SLURM_OUT.dft_out)
34
46
  self.dft_style = dft_style
35
47
  self.scf_style = scf_style
36
- if DFT_STYLE.vasp.lower() == dft_style.lower():
48
+ if DFT_STYLE.vasp.lower() == dft_style:
37
49
  self.dft_resource.command = "{} > {}".format(self.dft_resource.command, SLURM_OUT.dft_out)
38
- elif DFT_STYLE.pwmat.lower() == dft_style.lower():
50
+ elif DFT_STYLE.pwmat.lower() == dft_style:
39
51
  self.dft_resource.command = "{} > {}".format(self.dft_resource.command, SLURM_OUT.dft_out)
40
- elif DFT_STYLE.cp2k.lower() == dft_style.lower():
52
+ elif DFT_STYLE.cp2k.lower() == dft_style:
41
53
  self.dft_resource.command = "{} {} > {}".format(self.dft_resource.command, CP2K.cp2k_inp, SLURM_OUT.dft_out)
42
54
 
43
55
  if self.scf_resource is not None and scf_style is not None:
@@ -114,7 +126,7 @@ class ResourceDetail(object):
114
126
  if self.gpu_per_node is None and self.cpu_per_node is None:
115
127
  raise Exception("ERROR! Both CPU and GPU resources are not specified!")
116
128
  # check param
117
- if "$SLURM_NTASKS".lower() in command.lower():
129
+ if "$SLURM".lower() in command.lower():
118
130
  pass
119
131
  else:
120
132
  if "mpirun -np" in command:
@@ -10,7 +10,9 @@ class SCFParam(object):
10
10
  is_scf:bool=False,
11
11
  root_dir:str=None,
12
12
  dft_style:str=None,
13
- scf_style:str=None) -> None:# for scf relabel in init_bulk
13
+ scf_style:str=None,
14
+ is_bigmodel:bool=False,
15
+ is_direct:bool=False) -> None:# for scf relabel in init_bulk
14
16
 
15
17
  self.dft_style = dft_style
16
18
  self.root_dir = root_dir
@@ -24,12 +26,18 @@ class SCFParam(object):
24
26
 
25
27
  if is_scf:
26
28
  if "scf_input" in json_dict.keys(): # for init_bulk relabel
27
- json_scf = get_required_parameter("scf_input", json_dict)
28
- self.scf_input_list = self.set_input(json_scf, flag_symm=0)
29
+ if dft_style == DFT_STYLE.bigmodel:
30
+ self.bigmodel_script = get_required_parameter("bigmodel_script", json_dict)
31
+ else:
32
+ json_scf = get_required_parameter("scf_input", json_dict)
33
+ self.scf_input_list = self.set_input(json_scf, flag_symm=0)
29
34
  else: # for run_iter
30
- self.scf_input_list = self.set_input(json_dict, flag_symm=0)
31
- if self.scf_input_list[0].use_dftb:
32
- self.use_dftb = True
35
+ if dft_style == DFT_STYLE.bigmodel:
36
+ self.bigmodel_script = get_required_parameter("bigmodel_script", json_dict)
37
+ else:
38
+ self.scf_input_list = self.set_input(json_dict, flag_symm=0)
39
+ if self.scf_input_list[0].use_dftb:
40
+ self.use_dftb = True
33
41
  if is_aimd:
34
42
  json_aimd = get_required_parameter("aimd_input", json_dict)
35
43
  self.aimd_input_list = self.set_input(json_aimd, flag_symm=0)
@@ -40,6 +48,16 @@ class SCFParam(object):
40
48
  self.relax_input_list = self.set_input(json_relax, flag_symm=3)
41
49
  if self.relax_input_list[0].use_dftb:
42
50
  self.use_dftb = True
51
+
52
+ if is_bigmodel: # init_bulk
53
+ json_bigmodel = get_required_parameter("bigmodel_input", json_dict)
54
+ self.bigmodel_input_list = self.set_input(json_bigmodel, flag_symm=3)
55
+
56
+ if is_direct: # init_bulk
57
+ json_direct = get_required_parameter("direct_input", json_dict)
58
+ self.direct_input_list = self.set_input(json_direct, flag_symm=3)
59
+
60
+ self.scf_max_num = get_parameter("scf_max_num", json_dict, None)
43
61
  # for pwmat, use 'pseudo' key
44
62
  # for vasp is INCAR file, use 'pseudo' key
45
63
  pseudo = get_parameter("pseudo", json_dict, [])
@@ -250,11 +250,11 @@ class NepParam(object):
250
250
  error_log = "the input 'l_max' should has 3 values. The values should be [4, 0, 0] (only use three body features), [4, 2, 0] (use 3 and 4 body features) or [4, 2, 1] (use 3,4,5 body features).\n"
251
251
  raise Exception(error_log)
252
252
  if "fitting_net" in model_dict.keys():
253
- self.neuron = self.get_parameter("network_size", model_dict["fitting_net"], [100]) # number of neurons in the hidden layer
253
+ self.neuron = self.get_parameter("network_size", model_dict["fitting_net"], [40]) # number of neurons in the hidden layer
254
254
  if not isinstance(self.neuron, list):
255
255
  self.neuron = [self.neuron]
256
256
  else:
257
- self.neuron = [100]
257
+ self.neuron = [40]
258
258
  if self.neuron[-1] != 1:
259
259
  self.neuron.append(1) # output layer of fitting net
260
260
  self.set_feature_params()
@@ -6,7 +6,7 @@ class OptimizerParam(object):
6
6
 
7
7
  def set_optimizer(self, json_source:dict, nep_param:NepParam=None):
8
8
  optimizer_dict = get_parameter("optimizer", json_source, {})
9
- self.opt_name = get_parameter("optimizer", optimizer_dict, "LKF")
9
+ self.opt_name = get_parameter("optimizer", optimizer_dict, "ADAM")
10
10
  self.batch_size = get_parameter("batch_size", optimizer_dict, 1)
11
11
  self.epochs = get_parameter("epochs", optimizer_dict, 30)
12
12
  self.print_freq = get_parameter("print_freq", optimizer_dict, 10)
@@ -231,7 +231,7 @@ class WorkFileStructure(object):
231
231
  # self._set_data_file_paths(trainSetDir, dRFeatureInputDir, dRFeatureOutputDir, trainDataPath, validDataPath)
232
232
 
233
233
  def set_nep_native_file_paths(self):
234
- self.nep_model_file = "nep_to_lmps.txt"
234
+ self.nep_model_file = "nep5.txt"
235
235
 
236
236
  def get_data_file_structure(self):
237
237
  file_dict = {}
pwact/main.py CHANGED
@@ -5,7 +5,7 @@ import glob
5
5
  import sys
6
6
  import json
7
7
  import argparse
8
- from pwact.utils.constant import TEMP_STRUCTURE, UNCERTAINTY, AL_WORK, AL_STRUCTURE, LABEL_FILE_STRUCTURE, EXPLORE_FILE_STRUCTURE
8
+ from pwact.utils.constant import TEMP_STRUCTURE, UNCERTAINTY, AL_WORK, AL_STRUCTURE, LABEL_FILE_STRUCTURE, EXPLORE_FILE_STRUCTURE, DFT_STYLE
9
9
  from pwact.utils.format_input_output import make_iter_name
10
10
  from pwact.utils.file_operation import write_to_file, del_file_list, search_files, del_dir, copy_dir
11
11
  from pwact.utils.json_operation import convert_keys_to_lowercase
@@ -79,18 +79,23 @@ def run_iter():
79
79
  if jj == 2 and not input_param.reserve_work: # delete temp_work_dir under current iteration after the labeling done
80
80
  del_file_list([os.path.join(input_param.root_dir, iter_name, TEMP_STRUCTURE.tmp_run_iter_dir)])
81
81
 
82
- print("Active learning done! \nYou could use cmd 'al_pwmlff gather_pwdata' to collect all datas sampled from iterations.")
82
+ print("Active learning done! \nYou could use cmd 'pwact gather_pwdata' to collect all datas sampled from iterations.")
83
83
 
84
84
  def run_fp(itername:str, resource : Resource, input_param: InputParam):
85
85
  lab = Labeling(itername, resource, input_param)
86
86
  #1. if the label work done before, back up and do new work
87
87
  lab.back_label()
88
88
  #2. make scf work
89
- lab.make_scf_work()
90
- #3. do scf work
91
- lab.do_scf_jobs()
92
- #4. collect scf configs outcar or movement, then to pwdata format
93
- lab.do_post_labeling()
89
+ if input_param.dft_style == DFT_STYLE.bigmodel:
90
+ lab.make_bigmodel_work()
91
+ lab.do_bigmodel_jobs()
92
+ lab.do_post_bigmodel()
93
+ else:
94
+ lab.make_scf_work()
95
+ #3. do scf work
96
+ lab.do_scf_jobs()
97
+ #4. collect scf configs outcar or movement, then to pwdata format
98
+ lab.do_post_labeling()
94
99
 
95
100
  def do_training_work(itername:str, resource : Resource, input_param: InputParam):
96
101
  mtrain = ModelTrian(itername, resource, input_param)
@@ -129,8 +134,13 @@ def do_exploring_work(itername:str, resource : Resource, input_param: InputParam
129
134
  summary = "{} {}\n".format(itername, summary)
130
135
  write_to_file(os.path.join(input_param.root_dir, EXPLORE_FILE_STRUCTURE.iter_select_file), summary, mode='a')
131
136
 
137
+ if input_param.strategy.direct:
138
+ md.make_drct_work()
139
+ md.do_drct_jobs()
140
+ md.post_drct()
132
141
  print("config selection done!")
133
- # 5. do post process after lammps md running
142
+
143
+ # 5. do post process
134
144
  md.post_process_md()
135
145
  print("exploring done!")
136
146
 
@@ -236,7 +246,6 @@ def kill_job():
236
246
  # system_json = json.load(open(sys.argv[3]))
237
247
  # if "work_dir" in system_json.keys():
238
248
  # os.chdir(system_json["work_dir"])
239
- os.chdir("/data/home/wuxingxing/codespace/dev_pwact/al_dir/si_5_pwmat/init_bulk")
240
249
  try:
241
250
  with open("./PID", 'r') as rf:
242
251
  pid_str_info = rf.readline().split()
@@ -0,0 +1,145 @@
1
+ from maml.sampling.direct import DIRECTSampler, BirchClustering, SelectKFromClusters
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib.ticker as mtick
5
+ from ase.io import read
6
+ import subprocess, os, sys
7
+
8
# File that receives the selected structures. Frames are appended to it
# further down, so a file left over from a previous run is removed first.
write_file = "select.xyz"
try:
    os.remove(write_file)
except FileNotFoundError:
    # Nothing to clean up (e.g. first run); avoids a check-then-remove race.
    pass
# Trajectory files the candidates are read from.
filenames = ["candidate.xyz"]
# k is passed to SelectKFromClusters, threshold to BirchClustering(threshold_init=...).
k = 1
threshold = .04
14
def load_ase_MD_traj(filenames: list):
    """
    Read every frame of each trajectory file with ``ase.io.read``.

    Args:
        filenames: paths of the trajectory files to load.

    Returns:
        A ``(structs, trajs, lens)`` tuple. ``structs`` is one flat list with
        the frames of all files in reading order, ``trajs`` holds the
        per-file frame lists, and ``lens`` the number of frames per file.
    """
    structs = []
    trajs = []
    lens = []
    for filename in filenames:
        # index=":" requests all frames of the file.
        traj = read(filename, index=":")
        structs.extend(traj)
        trajs.append(traj)
        lens.append(len(traj))
    return structs, trajs, lens
27
+
28
# Load all candidate frames. `trajs` and `lens` are used later to map a flat
# frame index back to (file, frame-in-file).
structures, trajs, lens = load_ase_MD_traj(filenames)
n_image = len(structures)

# Sampler configured with Birch clustering and SelectKFromClusters(k=k).
DIRECT_sampler = DIRECTSampler(
    clustering=BirchClustering(n=None, threshold_init=threshold),
    select_k_from_clusters=SelectKFromClusters(k=k),
)

DIRECT_selection = DIRECT_sampler.fit_transform(structures)
n, m = DIRECT_selection["PCAfeatures"].shape

# Divide each of the m feature columns by its explained-variance ratio.
explained_variance = DIRECT_sampler.pca.pca.explained_variance_ratio_
DIRECT_selection["PCAfeatures_unweighted"] = DIRECT_selection["PCAfeatures"] / explained_variance[:m]

# Explained variance (in percent) against the 1-based component number.
plt.plot(
    range(1, explained_variance.shape[0]+1),
    explained_variance * 100,
    "o-",
)
# Raw string: "\m" is not a valid escape sequence in a normal string literal.
plt.xlabel(r"i$^{\mathrm{th}}$ PC", size=20)
plt.ylabel("Explained variance", size=20)
ax = plt.gca()
ax.yaxis.set_major_formatter(mtick.PercentFormatter())
plt.tight_layout()
plt.savefig("PCA_variance.png",dpi=360)
plt.close()
53
+
54
def plot_PCAfeature_coverage(all_features, selected_indexes, method="DIRECT"):
    """
    Scatter the first two feature columns of all and of the selected structures.

    A new 5x5 figure is opened and drawn into; it is neither saved nor closed
    here.

    Args:
        all_features: 2-D array indexed as ``[structure, feature]``; only
            columns 0 and 1 are plotted.
        selected_indexes: row indexes of the selected structures.
        method: label used in the legend entry of the selected points.
    """
    plt.subplots(figsize=(5, 5))
    selected_features = all_features[selected_indexes]
    plt.plot(all_features[:, 0], all_features[:, 1], "*", alpha=0.5, label=f"All {len(all_features):,} structures")
    plt.plot(
        selected_features[:, 0],
        selected_features[:, 1],
        "*",
        alpha=0.5,
        label=f"{method} sampled {len(selected_features):,}",
    )
    plt.legend(frameon=False, fontsize=14, loc="upper left", bbox_to_anchor=(-0.02, 1.02), reverse=True)
    plt.ylabel("PC 2", size=20)
    plt.xlabel("PC 1", size=20)
70
+
71
# Coverage scatter of the selection, saved as PCA_direct.png.
all_features = DIRECT_selection["PCAfeatures_unweighted"]
selected_indexes = DIRECT_selection["selected_indexes"]
plot_PCAfeature_coverage(
    all_features=all_features,
    selected_indexes=selected_indexes,
)
plt.tight_layout()
plt.savefig("PCA_direct.png", dpi=360)
plt.close()
77
+
78
+ #manual_selection_index = np.arange(0, n_image, int(n_image/n))
79
+ #plot_PCAfeature_coverage(all_features, manual_selection_index, "Manually")
80
+ #plt.tight_layout()
81
+ #plt.savefig("PCA_manually.png",dpi=360)
82
+ #plt.close()
83
+
84
def calculate_feature_coverage_score(all_features, selected_indexes, n_bins=100):
    """
    Fraction of occupied histogram bins that the selection also occupies.

    The value range of ``all_features`` is cut into equal-width bins; the
    score is the number of bins containing a selected value divided by the
    number of bins containing any value.

    Args:
        all_features: values of one feature for all structures (1-D array).
        selected_indexes: indexes into ``all_features`` of the selection.
        n_bins: number of bin edges, i.e. ``n_bins - 1`` bins are used.

    Returns:
        The coverage score ``n_select / n_all``.
    """
    selected_features = all_features[selected_indexes]
    # Both histograms share the same edges, so build them once.
    bin_edges = np.linspace(np.min(all_features), np.max(all_features), n_bins)
    n_all = np.count_nonzero(np.histogram(all_features, bins=bin_edges)[0])
    n_select = np.count_nonzero(np.histogram(selected_features, bins=bin_edges)[0])
    return n_select / n_all
93
+
94
def calculate_all_FCS(all_features, selected_indexes, b_bins=100):
    """
    Compute the feature coverage score of every feature column.

    Args:
        all_features: 2-D array indexed as ``[structure, feature]``.
        selected_indexes: row indexes of the selected structures.
        b_bins: forwarded as ``n_bins`` to calculate_feature_coverage_score.

    Returns:
        A list with one score per column, in column order.
    """
    n_columns = all_features.shape[1]
    select_scores = []
    for column in range(n_columns):
        column_values = all_features[:, column]
        select_scores.append(
            calculate_feature_coverage_score(column_values, selected_indexes, n_bins=b_bins)
        )
    return select_scores
100
+
101
# Bar chart of the coverage score of each feature column, saved as Cov_score.png.
all_features = DIRECT_selection["PCAfeatures_unweighted"]
scores_DIRECT = calculate_all_FCS(all_features, DIRECT_selection["selected_indexes"], b_bins=100)
#scores_MS = calculate_all_FCS(all_features, manual_selection_index, b_bins=100)
x = np.arange(len(scores_DIRECT))
x_ticks = [f"PC {n+1}" for n in range(len(x))]

plt.figure(figsize=(15, 4))
plt.bar(
    x,
    scores_DIRECT,
    width=0.3,
    # Raw f-string: "\o", "\m" and "\ " are not valid escapes in a normal literal.
    label=rf"DIRECT, $\overline{{\mathrm{{Coverage\ score}}}}$ = {np.mean(scores_DIRECT):.3f}",
)
#plt.bar(
# x + 0.3, scores_MS, width=0.3, label=f"Manual, $\overline{{\mathrm{{Coverage\ score}}}}$ = {np.mean(scores_MS):.3f}"
#)
plt.xticks(x, x_ticks, size=16)
plt.yticks(np.linspace(0, 1.0, 6), size=16)
plt.ylabel("Coverage score", size=20)
plt.legend(shadow=True, loc="lower right", fontsize=16)
plt.tight_layout()
plt.savefig("Cov_score.png",dpi=360)
plt.close()
124
+
125
def get2index(num: int, list_lens: list):
    """
    Map a flat frame index onto (file index, index within that file).

    Args:
        num: zero-based position in the concatenation of all files' frames.
        list_lens: number of frames of each file, in concatenation order.

    Returns:
        ``(idx, offset)``: frame ``num`` is frame ``offset`` of file ``idx``.

    Raises:
        IndexError: if ``num`` is not smaller than ``sum(list_lens)``; this
            includes an empty ``list_lens``.
    """
    offset = num
    for idx, length in enumerate(list_lens):
        if offset < length:
            return idx, offset
        offset -= length
    # Falling through means num lies past the last frame. Returning the
    # leftover offset would silently point at a wrong frame of the last file.
    raise IndexError(
        "frame index {} out of range for {} frames".format(num, sum(list_lens))
    )
132
+
133
# Append every selected frame with acceptable cell angles to write_file.
indices = DIRECT_selection["selected_indexes"]
select_idx = []
for index in indices:
    idx, num = get2index(index, lens)
    atoms = trajs[idx][num]
    # Last three entries of cellpar() are taken as the cell angles.
    angles = atoms.cell.cellpar()[-3:]
    if angles.max() > 140 or angles.min() < 40:
        # Skip frames whose cell has an angle outside [40, 140].
        continue
    atoms.set_scaled_positions(atoms.get_scaled_positions())
    atoms.write(write_file,format="extxyz",append=True)
    select_idx.append(idx)
# NOTE(review): the file stores ALL sampled indexes, including frames skipped
# by the angle filter above, while `select_idx` (file indexes of the kept
# frames) is never written. Confirm which of the two consumers expect.
np.savetxt("select_idx.dat",np.array(indices),fmt="%8d")
@@ -0,0 +1,41 @@
1
+ from ase.io import read
2
+ from fairchem.core import OCPCalculator
3
+ import os
4
# Frames are read from select.xyz; the labelled frames go to train.xyz.
output_file = 'train.xyz'
traj = read("select.xyz", index=":")
# Calculator built from a fixed checkpoint file, with cpu=False.
checkpoint_path = "/share/public/PWMLFF_test_data/eqv2-models/eqV2_31M_omat.pt"
calc = OCPCalculator(checkpoint_path=checkpoint_path, cpu=False)
10
+
11
def atoms2xyzstr(atoms):
    """
    Format one structure and its computed labels as an extended-XYZ frame.

    The frame consists of the atom count, a header line with lattice, energy
    and virial, and one line per atom with symbol, position and force.

    Args:
        atoms: structure with a calculator attached; forces, potential energy
            and stress are requested from it.

    Returns:
        The frame as a string ending with a newline.
    """
    num_atom = atoms.get_global_number_of_atoms()
    vol = atoms.get_volume()
    pos = atoms.positions
    forces = atoms.get_forces()
    energy = atoms.get_potential_energy()
    cell = atoms.cell
    # Virial as minus the full 3x3 stress tensor times the cell volume.
    virial = -atoms.get_stress(voigt=False) * vol
    symbols = atoms.get_chemical_symbols()

    # NOTE(review): the lattice is written with 3 decimals while positions
    # use 8; confirm this precision is intended.
    xyz_head = (
        'Lattice="%.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f"'
        ' Properties=species:S:1:pos:R:3:forces:R:3 energy=%.8f'
        ' virial="%.8f %.8f %.8f %.8f %.8f %.8f %.8f %.8f %.8f"'
        '\n'
    )
    cell_values = tuple(cell[i, j] for i in range(3) for j in range(3))
    virial_values = tuple(virial[i, j] for i in range(3) for j in range(3))

    # Collect the pieces and join once instead of growing a string per atom.
    parts = ["%d\n" % num_atom, xyz_head % (cell_values + (energy,) + virial_values)]
    for i in range(num_atom):
        parts.append(
            "%2s %14.8f %14.8f %14.8f %14.8f %14.8f %14.8f\n"
            % (symbols[i], pos[i, 0], pos[i, 1], pos[i, 2], forces[i, 0], forces[i, 1], forces[i, 2])
        )
    return "".join(parts)
35
+
36
# Attach the calculator to every frame and write the formatted frames to
# output_file. The context manager closes the file even if a frame fails.
with open(output_file, "w") as f:
    for atoms in traj:
        atoms.calc = calc
        f.write(atoms2xyzstr(atoms))
@@ -24,7 +24,7 @@ def make_pair_style(md_type, forcefiled, atom_type:list[int], dump_info:str):
24
24
  pair_names = ""
25
25
  for fi in forcefiled:
26
26
  pair_names += "{} ".format(os.path.basename(fi))
27
- pair_style = "pair_style pwmlff {} {} {}\n".format(len(forcefiled), pair_names, dump_info)
27
+ pair_style = "pair_style matpl {} {}\n".format(pair_names, dump_info)
28
28
  atom_names = " ".join(map(str, atom_type))
29
29
  pair_style += "pair_coeff * * {}\n".format(atom_names)
30
30
  return pair_style
pwact/utils/constant.py CHANGED
@@ -41,6 +41,7 @@ class SLURM_JOB_TYPE:
41
41
  vasp_scf = "vasp/scf"
42
42
  vasp_aimd = "vasp/aimd"
43
43
  lammps = "lammps"
44
+ direct = "direct"
44
45
 
45
46
  '''
46
47
  description:
@@ -138,6 +139,7 @@ class DFT_STYLE:
138
139
  pwmat = "pwmat"
139
140
  cp2k = "cp2k"
140
141
  lammps = "lammps"
142
+ bigmodel="bigmodel"
141
143
 
142
144
  '''
143
145
  description:
@@ -156,6 +158,8 @@ class DFT_STYLE:
156
158
  return PWDATA.cp2k_scf
157
159
  else:
158
160
  return PWDATA.vasp_poscar
161
+ if dft_style.lower() == DFT_STYLE.bigmodel.lower():
162
+ return PWDATA.extxyz
159
163
 
160
164
  @staticmethod
161
165
  def get_normal_config(dft_style:str): # the input config file name of pwmat vasp and cp2k
@@ -249,17 +253,6 @@ class DFT_STYLE:
249
253
  scf_list = [_.lower() for _ in scf_list]
250
254
  return scf_list
251
255
 
252
- '''
253
- description:
254
- the files in scf does not need reserve
255
- return {*}
256
- author: wuxingxing
257
- '''
258
- @staticmethod
259
- def get_scf_del_list():
260
- del_list = ["final.config"]
261
- return del_list
262
-
263
256
  @staticmethod
264
257
  def get_aimd_config(dft_style:str):
265
258
  if dft_style == DFT_STYLE.pwmat:
@@ -357,6 +350,20 @@ class INIT_BULK:
357
350
  scf_tag = "tag.scf.success"
358
351
  scf_tag_failed ="tag.scf.failed"
359
352
 
353
+ bigmodel="bigmodel"
354
+ bigmodel_job = "bigmodel.job"
355
+ bigmodel_tag = "tag.bigmodel.success"
356
+ bigmodel_tag_failed ="tag.bigmodel.failed"
357
+ bigmodel_traj = "traj.xyz"
358
+
359
+ direct="direct"
360
+ direct_job = "direct.job"
361
+ direct_tag = "tag.direct.success"
362
+ direct_tag_failed ="tag.direct.failed"
363
+ candidate_xyz="candidate.xyz"
364
+ candidate_idx="candidate.json"
365
+ direct_traj = "select.xyz"
366
+
360
367
  collection = "collection"
361
368
  npy_format_save_dir = "PWdata"
362
369
  npy_format_name = "datapath.txt"
@@ -407,17 +414,24 @@ class TRAIN_FILE_STRUCTUR:
407
414
 
408
415
  # nep model
409
416
  nep_model_name ="nep_model.ckpt"
410
- nep_model_lmps = "nep_to_lmps.txt"
417
+ nep_model_lmps = "nep5.txt"
411
418
 
412
419
  class EXPLORE_FILE_STRUCTURE:
413
420
  kpu= "kpu"
414
421
  md = "md"
415
422
  select = "select"
423
+ direct = "direct"
416
424
  md_tag = "tag.md.success"
417
425
  md_tag_faild = "tag.md.error"
418
426
  md_job = "md.job"
427
+ direct_tag = "tag.direct.success"
428
+ direct_tag_faild = "tag.direct.error"
429
+ direct_job = "direct.job"
419
430
  # selected image info file names
420
431
  candidate = "candidate.csv"
432
+ candidate_xyz="candidate.xyz"
433
+ select_idx = "select_idx.dat"
434
+ select_xyz = "select.xyz"
421
435
  # candidate_random = "candidate_random.csv"
422
436
  candidate_delete = "candidate_delete.csv"
423
437
  failed = "fail.csv"
@@ -446,11 +460,17 @@ class EXPLORE_FILE_STRUCTURE:
446
460
 
447
461
 
448
462
  class LABEL_FILE_STRUCTURE:
463
+ bigmodel="bigmodel"
449
464
  scf = "scf"
450
465
  result = "result"
451
466
  scf_tag = "tag.scf.success"
452
467
  scf_tag_failed = "tag.scf.failed"
453
468
  scf_job = "scf.job"
469
+ bigmodel_job="bigmodel.job"
470
+ bigmodel_tag = "tag.bigmodel.success"
471
+ bigmodel_tag_failed = "tag.bigmodel.failed"
472
+ train_xyz = "train.xyz"
473
+
454
474
 
455
475
  class LAMMPS:
456
476
  input_lammps="in.lammps"
@@ -318,8 +318,15 @@ def get_file_extension(file_name:str, split_char = "."):
318
318
  @Author :wuxingxing
319
319
  """
320
320
 
321
- def get_random_nums(start, end, n):
322
- numsArray = set()
323
- while len(numsArray) < n:
324
- numsArray.add(random.randint(start, end-1))
325
- return list(numsArray)
321
def get_random_nums(start, end, n, seed=None):
    """
    Draw ``n`` distinct random integers from ``[start, end)``.

    Args:
        start: smallest value that may be drawn.
        end: exclusive upper bound; the largest possible value is ``end - 1``.
        n: how many distinct values to draw.
        seed: when given, values come from a private ``random.Random(seed)``,
            so the result is reproducible and the module-level generator is
            not used. When None, the module-level generator is used.

    Returns:
        A list of ``n`` distinct integers, in set iteration order (not sorted).

    Raises:
        ValueError: if ``n`` exceeds the number of distinct values available.
    """
    if n > max(end - start, 0):
        # Fewer than n distinct values exist, so the loop could never finish.
        raise ValueError(
            "cannot draw {} distinct values from range [{}, {})".format(n, start, end)
        )
    # Independent random instance when a seed is supplied.
    rng = random if seed is None else random.Random(seed)
    numsArray = set()
    while len(numsArray) < n:
        numsArray.add(rng.randint(start, end-1))
    return list(numsArray)