pwact 0.1.9__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pwact/active_learning/environment.py +13 -1
- pwact/active_learning/explore/run_model_md.py +11 -8
- pwact/active_learning/slurm.py +27 -25
- pwact/active_learning/train/train_model.py +14 -12
- pwact/active_learning/user_input/iter_input.py +3 -2
- pwact/active_learning/user_input/resource.py +4 -1
- pwact/active_learning/user_input/scf_param.py +1 -1
- pwact/active_learning/user_input/train_param/model_param.py +6 -3
- pwact/active_learning/user_input/train_param/nep_param.py +319 -118
- pwact/active_learning/user_input/train_param/nn_feature_type.py +28 -5
- pwact/active_learning/user_input/train_param/optimizer_param.py +167 -67
- pwact/active_learning/user_input/train_param/train_param.py +84 -35
- pwact/active_learning/user_input/train_param/work_file_param.py +54 -48
- pwact/utils/constant.py +14 -3
- {pwact-0.1.9.dist-info → pwact-0.1.10.dist-info}/METADATA +1 -1
- {pwact-0.1.9.dist-info → pwact-0.1.10.dist-info}/RECORD +20 -20
- {pwact-0.1.9.dist-info → pwact-0.1.10.dist-info}/LICENSE +0 -0
- {pwact-0.1.9.dist-info → pwact-0.1.10.dist-info}/WHEEL +0 -0
- {pwact-0.1.9.dist-info → pwact-0.1.10.dist-info}/entry_points.txt +0 -0
- {pwact-0.1.9.dist-info → pwact-0.1.10.dist-info}/top_level.txt +0 -0
pwact/active_learning/environment.py
CHANGED

@@ -1,4 +1,16 @@
 import subprocess
+import pkg_resources
 def check_envs():
     # for pwmat
-
+    pass
+    # check pwdata
+    # try:
+    #     package_version = pkg_resources.get_distribution('pwdata').version
+    #     if pkg_resources.parse_version(min_version) <= pkg_resources.parse_version(package_version) <= pkg_resources.parse_version(max_version):
+    #         print(f"{package_name} version {package_version} is within the required range [{min_version}, {max_version}].")
+    #         return True
+    #     else:
+    #         print(f"{package_name} version {package_version} is NOT within the required range [{min_version}, {max_version}].")
+    #         return False
+
+    # check PWMLFF???
pwact/active_learning/explore/run_model_md.py
CHANGED

@@ -21,7 +21,7 @@ from pwact.active_learning.explore.select_image import select_image
 from pwact.active_learning.user_input.resource import Resource
 from pwact.active_learning.user_input.iter_input import InputParam, MdDetail
 from pwact.utils.constant import AL_STRUCTURE, TEMP_STRUCTURE, EXPLORE_FILE_STRUCTURE, TRAIN_FILE_STRUCTUR, \
-    FORCEFILED, ENSEMBLE, LAMMPS, LAMMPS_CMD, UNCERTAINTY, DFT_STYLE, SLURM_OUT, SLURM_JOB_TYPE, PWDATA
+    FORCEFILED, ENSEMBLE, LAMMPS, LAMMPS_CMD, UNCERTAINTY, DFT_STYLE, SLURM_OUT, SLURM_JOB_TYPE, PWDATA, MODEL_TYPE
 
 from pwact.utils.format_input_output import get_iter_from_iter_name, get_sub_md_sys_template_name,\
     make_md_sys_name, get_md_sys_template_name, make_temp_press_name, make_temp_name, make_train_name

@@ -247,13 +247,16 @@ class Explore(object):
     def set_forcefiled_file(self, md_dir:str):
         model_name = ""
         md_model_paths = []
-        if self.input_param.
-            model_name += TRAIN_FILE_STRUCTUR.
-        elif self.input_param.
-            if self.input_param.strategy.
-
-
-
+        if self.input_param.train.model_type == MODEL_TYPE.nep:
+            model_name += "{}/{}".format(TRAIN_FILE_STRUCTUR.model_record, TRAIN_FILE_STRUCTUR.nep_model_lmps)
+        elif self.input_param.train.model_type == MODEL_TYPE.dp:
+            if self.input_param.strategy.md_type == FORCEFILED.libtorch_lmps:
+                model_name += "{}/{}".format(TRAIN_FILE_STRUCTUR.model_record, TRAIN_FILE_STRUCTUR.script_dp_name)
+            elif self.input_param.strategy.md_type == FORCEFILED.fortran_lmps:
+                if self.input_param.strategy.compress:
+                    raise Exception("ERROR! The compress model does not support fortran lammps md! Please change the 'md_type' to 2!")
+                else:
+                    model_name += "{}/{}".format(TRAIN_FILE_STRUCTUR.fortran_dp, TRAIN_FILE_STRUCTUR.fortran_dp_name)
 
         for model_index in range(self.input_param.strategy.model_num):
             model_name_i = "{}/{}".format(make_train_name(model_index), model_name)
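For readers following the new branch in set_forcefiled_file: the force-field file handed to LAMMPS is now chosen from the trained model type (nep vs dp) and, for dp, from the md_type and compress settings. A standalone sketch of that selection with plain-string stand-ins for the MODEL_TYPE / FORCEFILED / TRAIN_FILE_STRUCTUR constants (their real values live in pwact.utils.constant and are not reproduced here):

# Plain-string stand-ins; only torch_script_module.pt is taken from a comment later in this diff.
MODEL_RECORD = "model_record"
NEP_LMPS_FILE = "nep_model_for_lammps"     # stands in for TRAIN_FILE_STRUCTUR.nep_model_lmps
SCRIPT_DP_FILE = "torch_script_module.pt"  # per the script step comment in train_model.py
FORTRAN_DP_DIR = "fortran_dp_dir"          # stands in for TRAIN_FILE_STRUCTUR.fortran_dp
FORTRAN_DP_FILE = "fortran_dp_model"       # stands in for TRAIN_FILE_STRUCTUR.fortran_dp_name

def pick_forcefield_file(model_type: str, md_type: str, compress: bool) -> str:
    # Mirrors the branch order of set_forcefiled_file: nep first, then dp by md_type.
    if model_type == "nep":
        return "{}/{}".format(MODEL_RECORD, NEP_LMPS_FILE)
    if model_type == "dp":
        if md_type == "libtorch_lmps":
            return "{}/{}".format(MODEL_RECORD, SCRIPT_DP_FILE)
        if md_type == "fortran_lmps":
            if compress:
                raise Exception("ERROR! The compress model does not support fortran lammps md! "
                                "Please change the 'md_type' to 2!")
            return "{}/{}".format(FORTRAN_DP_DIR, FORTRAN_DP_FILE)
    raise ValueError("unsupported model_type/md_type: {}/{}".format(model_type, md_type))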
pwact/active_learning/slurm.py
CHANGED
@@ -154,31 +154,33 @@ class SlurmJob(object):
     def check_lammps_out_file(self):
         # read last line of md.log file
         md_dirs = self.get_slurm_works_dir()
-
-
-
-
-
-
-
-
-        with open(md_log, "rb") as file:
-            file.seek(-2, 2)  # seek to two bytes before the end of the file
-            while file.read(1) != b'\n':  # scan backwards byte by byte for the previous newline
-                file.seek(-2, 1)  # step back two bytes
-            last_line = file.readline().decode().strip()  # read the last line and strip the newline and whitespace
-        if "ERROR: there are two atoms" in last_line:
-            with open(tag_md_file, 'w') as wf:
-                wf.writelines("ERROR: there are two atoms too close")
-            return True
-        elif "Total wall time" in last_line:
-            with open(tag_md_file, 'w') as wf:
-                wf.writelines("Job Done!")
-            return True
-        else:
-            return False
+        try:
+            for md_dir in md_dirs:
+                tag_md_file = os.path.join(md_dir, "tag.md.success")
+                md_log = os.path.join(md_dir, "md.log")
+                if os.path.exists(tag_md_file):
+                    continue
+                if not os.path.exists(md_log):
+                    return False
 
-
+                with open(md_log, "rb") as file:
+                    file.seek(-2, 2)  # seek to two bytes before the end of the file
+                    while file.read(1) != b'\n':  # scan backwards byte by byte for the previous newline
+                        file.seek(-2, 1)  # step back two bytes
+                    last_line = file.readline().decode().strip()  # read the last line and strip the newline and whitespace
+                if "ERROR: there are two atoms" in last_line:
+                    with open(tag_md_file, 'w') as wf:
+                        wf.writelines("ERROR: there are two atoms too close")
+                    return True
+                elif "Total wall time" in last_line:
+                    with open(tag_md_file, 'w') as wf:
+                        wf.writelines("Job Done!")
+                    return True
+                else:
+                    return False
+            return True
+        except Exception as e:
+            return False
 
 
 class Mission(object):
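The rewritten check_lammps_out_file loops over all MD work directories inside a single try/except, skips directories already tagged as finished, and inspects only the last line of each md.log by seeking backwards from the end of the file. A self-contained sketch of that tail-reading technique (the helper name and the example file name are illustrative):

def read_last_line(path: str) -> str:
    # Seek to the end and walk backwards byte by byte until the previous newline,
    # then read the final line; this mirrors the tail read in check_lammps_out_file.
    with open(path, "rb") as f:
        try:
            f.seek(-2, 2)                  # two bytes before EOF
            while f.read(1) != b"\n":      # scan backwards for the previous newline
                f.seek(-2, 1)
        except OSError:                    # very short file or no newline at all
            f.seek(0)
        return f.readline().decode().strip()

# last = read_last_line("md.log")
# "Total wall time" in last      -> LAMMPS finished normally
# "ERROR: there are two atoms"   -> overlapping atoms, the job gets tagged as failed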
@@ -302,7 +304,7 @@ class Mission(object):
         for job in self.job_list:
             if job.status == JobStatus.terminated:
                 if job.submit_num <= JobStatus.submit_limit.value:
-                    print("resubmit job: {}, the time is {}\n".format(job.submit_cmd, job.submit_num))
+                    print("resubmit job {}: {}, the time is {}\n".format(job.jobid, job.submit_cmd, job.submit_num))
                     job.submit()
                 else:
                     job.status = JobStatus.resubmit_failed
pwact/active_learning/train/train_model.py
CHANGED

@@ -9,7 +9,7 @@ from pwact.active_learning.user_input.iter_input import InputParam
 
 from pwact.utils.format_input_output import make_train_name, get_seed_by_time, get_iter_from_iter_name, make_iter_name
 from pwact.utils.constant import AL_STRUCTURE, UNCERTAINTY, TEMP_STRUCTURE, MODEL_CMD, \
-    TRAIN_INPUT_PARAM, TRAIN_FILE_STRUCTUR, FORCEFILED, LABEL_FILE_STRUCTURE, SLURM_OUT
+    TRAIN_INPUT_PARAM, TRAIN_FILE_STRUCTUR, FORCEFILED, LABEL_FILE_STRUCTURE, SLURM_OUT, MODEL_TYPE
 
 from pwact.utils.file_operation import save_json_file, write_to_file, del_dir, search_files, add_postfix_dir, mv_file, copy_dir, del_file_list
 '''
@@ -106,18 +106,20 @@ class ModelTrian(object):
         script = ""
         pwmlff = self.resource.train_resource.command
         script += "{} {} {} >> {}\n\n".format(pwmlff, MODEL_CMD.train, TRAIN_FILE_STRUCTUR.train_json, SLURM_OUT.train_out)
-
-
-
-
+
+        # do nothing for nep model
+        if self.input_param.train.model_type == MODEL_TYPE.dp:
+            if self.input_param.strategy.compress:
+                script += " {} {} {} -d {} -o {} -s {}/{} >> {}\n\n".format(pwmlff, MODEL_CMD.compress, model_path, \
+                    self.input_param.strategy.compress_dx, self.input_param.strategy.compress_order, TRAIN_FILE_STRUCTUR.model_record, TRAIN_FILE_STRUCTUR.compree_dp_name, SLURM_OUT.train_out)
+                cmp_model_path = "{}/{}".format(TRAIN_FILE_STRUCTUR.model_record, TRAIN_FILE_STRUCTUR.compree_dp_name)
 
-
-
-
-
-
-
-            script += " {} {} {} >> {}\n\n".format(pwmlff, MODEL_CMD.script, cmp_model_path, SLURM_OUT.train_out)
+            if self.input_param.strategy.md_type == FORCEFILED.libtorch_lmps:
+                if cmp_model_path is None:
+                    # script model_record/dp_model.ckpt the torch_script_module.pt will in model_record dir
+                    script += " {} {} {} {}/{} >> {}\n".format(pwmlff, MODEL_CMD.script, model_path, TRAIN_FILE_STRUCTUR.model_record, TRAIN_FILE_STRUCTUR.script_dp_name, SLURM_OUT.train_out)
+                else:
+                    script += " {} {} {} {}/{} >> {}\n\n".format(pwmlff, MODEL_CMD.script, cmp_model_path, TRAIN_FILE_STRUCTUR.model_record, TRAIN_FILE_STRUCTUR.script_dp_name, SLURM_OUT.train_out)
         return script
 
 '''
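The reworked DP branch assembles a small shell script: always a train command, then an optional compress step whose output becomes cmp_model_path, and finally a script/export step that writes the TorchScript module LAMMPS loads when md_type is libtorch_lmps. A rough sketch of how such a script could be assembled; the literal command words, the file names train.json / train.log / cmp_dp_model, and the -d/-o values are placeholders for the MODEL_CMD, TRAIN_FILE_STRUCTUR, SLURM_OUT, and strategy settings (only dp_model.ckpt and torch_script_module.pt come from the diff's own comment):

def build_train_script(pwmlff: str, model_path: str, compress: bool, use_libtorch_lmps: bool) -> str:
    # Illustrative reconstruction of the DP branch; names marked above are assumptions.
    script = "{} train train.json >> train.log\n\n".format(pwmlff)
    cmp_model_path = None
    if compress:
        # Compress the trained checkpoint; the compressed model becomes the export source.
        script += "    {} compress {} -d 0.01 -o 3 -s model_record/cmp_dp_model >> train.log\n\n".format(
            pwmlff, model_path)
        cmp_model_path = "model_record/cmp_dp_model"
    if use_libtorch_lmps:
        # Export either the raw checkpoint or the compressed model to a TorchScript module.
        source = model_path if cmp_model_path is None else cmp_model_path
        script += "    {} script {} model_record/torch_script_module.pt >> train.log\n\n".format(
            pwmlff, source)
    return script

# Example: print(build_train_script("PWMLFF", "model_record/dp_model.ckpt",
#                                   compress=False, use_libtorch_lmps=True))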
pwact/active_learning/user_input/iter_input.py
CHANGED

@@ -3,7 +3,7 @@ import glob
 
 from pwact.utils.json_operation import get_parameter, get_required_parameter
 from pwact.utils.constant import MODEL_CMD, FORCEFILED, UNCERTAINTY, PWDATA
-from pwact.active_learning.user_input.train_param.train_param import TrainParam
+from pwact.active_learning.user_input.train_param.train_param import InputParam as TrainParam
 from pwact.active_learning.user_input.scf_param import SCFParam
 class InputParam(object):
     # _instance = None
@@ -12,7 +12,8 @@ class InputParam(object):
         if not os.path.isabs(self.root_dir):
             self.root_dir = os.path.realpath(self.root_dir)
         self.record_file = get_parameter("record_file", json_dict, "al.record")
-
+        if "record_file" not in json_dict.keys():
+            print("Warning! record_file not provided, automatically set to {}! ".format(self.record_file))
 
         self.reserve_work = get_parameter("reserve_work", json_dict, False)
         # self.reserve_feature = get_parameter("reserve_feature", json_dict, False)
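The record_file change keeps the old default ("al.record") but now prints a warning only when the key was genuinely absent from the input JSON. A minimal stand-alone version of that default-plus-warning pattern, with get_parameter reduced to a simple dict lookup (the real helper lives in pwact.utils.json_operation and may do more):

def get_parameter(key, json_dict, default):
    # Simplified stand-in: return the value if present, otherwise fall back to the default.
    return json_dict.get(key, default)

def read_record_file(json_dict: dict) -> str:
    record_file = get_parameter("record_file", json_dict, "al.record")
    if "record_file" not in json_dict:
        print("Warning! record_file not provided, automatically set to {}! ".format(record_file))
    return record_file

# read_record_file({})                         -> warns, returns "al.record"
# read_record_file({"record_file": "my.rec"})  -> silent, returns "my.rec"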
pwact/active_learning/user_input/resource.py
CHANGED

@@ -69,7 +69,10 @@ class Resource(object):
         env_script = ""
         if len(source_list) > 0:
             for source in source_list:
-                if "source" != source.split()[0].lower()
+                if "source" != source.split()[0].lower() and \
+                    "export" != source.split()[0].lower() and \
+                    "module" != source.split()[0].lower() and \
+                    "conda" != source.split()[0].lower():
                     tmp_source = "source {}\n".format(source)
                 else:
                     tmp_source = "{}\n".format(source)
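With the widened condition, a line from the user's environment list is prefixed with "source" only when it does not already start with source, export, module, or conda; otherwise it is emitted verbatim. A compact sketch of that normalization (the function name and the example inputs are illustrative):

def normalize_env_lines(source_list):
    # Prefix bare script paths with "source"; leave lines that already start with
    # source / export / module / conda untouched, matching the widened condition above.
    no_prefix = {"source", "export", "module", "conda"}
    env_script = ""
    for source in source_list:
        first_word = source.split()[0].lower()
        if first_word in no_prefix:
            env_script += "{}\n".format(source)
        else:
            env_script += "source {}\n".format(source)
    return env_script

# normalize_env_lines(["/opt/intel/oneapi/setvars.sh",
#                      "module load cuda/11.8",
#                      "export OMP_NUM_THREADS=4"])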
pwact/active_learning/user_input/scf_param.py
CHANGED

@@ -165,7 +165,7 @@ class DFTInput(object):
 
         if "MP_N123" in key_values and self.kspacing is not None:
             error_info = "ERROR! The 'kspacing' in DFT/input/{} dict and 'MP_N123' in {} file cannot coexist.\n".format(os.path.basename(self.input_file), os.path.basename(self.input_file))
-            error_info += "If 'MP_N123' is not indicated in DFT/input/{},
+            error_info += "If 'MP_N123' is not indicated in DFT/input/{}, the 'kspacing' param will be used to generate the 'MP_N123' parameter\n".format(os.path.basename(self.input_file))
             raise Exception(error_info)
         elif "MP_N123" not in key_values and self.kspacing is None:
             self.kspacing = 0.5
pwact/active_learning/user_input/train_param/model_param.py
CHANGED

@@ -31,7 +31,8 @@ class NetParam(object):
         if "type_" in self.net_type:
             dicts["physical_property"] = self.physical_property
         #dicts["bias"] = self.bias,
-        #
+        # if self.resnet_dt is False:
+        #     dicts["resnet_dt"] = self.resnet_dt
         #dicts["activation"] = self.activation
         return dicts
 
@@ -40,6 +41,7 @@ class ModelParam(object):
         self.type_embedding_net = None
         self.embedding_net = None
         self.fitting_net = None
+        self.nep_param:NetParam = None
 
     '''
     description:
@@ -75,7 +77,9 @@ class ModelParam(object):
     def set_nn_fitting_net(self, fitting_net_dict:dict):
         # fitting_net_dict = get_parameter("fitting_net",json_input, {})
         network_size = get_parameter("network_size", fitting_net_dict,[15,15,1])
-        if network_size
+        if not isinstance(network_size, list):
+            network_size = [network_size]
+        if len(network_size) > 1 and network_size[-1] != 1:
             raise Exception("Error: The last layer of the fitting network should have a size of 1 for etot energy, but the input size is {}!".format(network_size[-1]))
         bias = True # get_parameter("bias", fitting_net_dict, True)
         resnet_dt = False # get_parameter("resnet_dt", fitting_net_dict, False)
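set_nn_fitting_net now tolerates a scalar network_size by wrapping it in a list before enforcing that a multi-layer fitting network ends in a single output node. In isolation, the validation reads as follows (the helper name is illustrative):

def validate_fitting_network_size(network_size):
    # Accept a bare int as shorthand for a single-layer spec, then enforce that a
    # multi-layer fitting network ends in one output node (the total-energy head).
    if not isinstance(network_size, list):
        network_size = [network_size]
    if len(network_size) > 1 and network_size[-1] != 1:
        raise Exception("Error: The last layer of the fitting network should have a size of 1 "
                        "for etot energy, but the input size is {}!".format(network_size[-1]))
    return network_size

# validate_fitting_network_size([15, 15, 1])  -> returns [15, 15, 1]
# validate_fitting_network_size(50)           -> returns [50]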
@@ -91,4 +95,3 @@ class ModelParam(object):
     # # dicts[self.fitting_net.net_type] = self.fitting_net.to_dict()
     # return self.fitting_net.to_dict_std()
 
-