ion-CSP 2.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,487 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import time
5
+ import logging
6
+ import getpass
7
+ import paramiko
8
+ from stat import S_ISDIR
9
+ from collections import deque
10
+ from typing import List, Dict
11
+
12
+
13
class SSHBatchJob:
    """Upload, submit and download batch jobs on a remote machine over SSH/SFTP.

    The connection is opened in ``__init__`` according to ``machine_type``:

    * ``"ssh_direct"`` -- connect straight to ``remote_profile`` with its password.
    * ``"jumper"``     -- connect to ``jumper_profile`` (key file) first and tunnel
      to ``remote_profile`` through a ``direct-tcpip`` channel.
    * ``"2FA"``        -- connect directly, using ``remote_profile["password"]``
      followed by a one-time code typed interactively.

    Local input files live in ``<work_dir>/<upload_subfolder>``. They are mirrored
    to ``<remote_root>/<work_dir name>/<upload_subfolder>`` on the remote side.
    """

    # Number of attempts for remote commands and single-file transfers.
    RETRY_ATTEMPTS = 3
    # Seconds to wait between attempts of a remote command.
    COMMAND_RETRY_DELAY = 5
    # Seconds to wait between attempts of a single-file transfer.
    TRANSFER_RETRY_DELAY = 2
    # Seconds passed as ``timeout`` to every ``SSHClient.connect`` call.
    CONNECT_TIMEOUT = 10

    def __init__(
        self,
        work_dir: str,
        machine_json: str,
        machine_type: str = "ssh_direct",
        upload_subfolder: str = "3_for_vasp_opt",
        download_subfolder: str = "4_vasp_optimized",
    ):
        """Load the machine configuration and open the SSH/SFTP connection.

        Args:
            work_dir: local job directory; the process ``chdir``s into it.
            machine_json: path to the JSON machine configuration. It is opened
                *after* the ``chdir``, so a relative path is resolved against
                ``work_dir``.
            machine_type: one of ``"ssh_direct"``, ``"jumper"`` or ``"2FA"``.
            upload_subfolder: sub-directory of ``work_dir`` holding the files to upload.
            download_subfolder: sub-directory of ``work_dir`` used by
                :meth:`download_entire_folder` by default.

        Raises:
            ValueError: if ``machine_type`` is not recognised.
            Exception: whatever paramiko raises when the connection fails (re-raised).
        """
        # Resolve once, before chdir. A relative work_dir would otherwise point
        # elsewhere after the cwd changes, and a trailing separator would make
        # os.path.dirname() return the directory itself.
        abs_work_dir = os.path.normpath(os.path.abspath(work_dir))
        self.base_dir = abs_work_dir
        os.chdir(self.base_dir)
        self.folder_name = abs_work_dir.split(os.sep)[-1]
        self.upload_folder = f"{self.folder_name}/{upload_subfolder}"
        self.download_folder = download_subfolder
        # Local directory whose files are uploaded, i.e. <work_dir>/<upload_subfolder>.
        self.local_folder_dir = f"{os.path.dirname(abs_work_dir)}/{self.upload_folder}"
        # Load the configuration file.
        with open(machine_json, "r") as mf:
            self.machine_config = json.load(mf)
        self.remote_dir = self.machine_config["remote_root"]
        self.remote_task_dir = f"{self.remote_dir}/{self.upload_folder}"
        remote_profile = self.machine_config["remote_profile"]
        label = machine_json.split("_machine.json")[0]

        self.client = None
        self.sftp = None
        self.jumper_client = None
        if machine_type in ("ssh_direct", "2FA"):
            self._connect_direct(remote_profile, label, two_factor=machine_type == "2FA")
        elif machine_type == "jumper":
            self._connect_via_jumper(remote_profile, label)
        else:
            # Fail fast: previously an unknown type produced an object with no client.
            raise ValueError(
                f"Unknown machine_type {machine_type!r}; "
                "expected 'ssh_direct', 'jumper' or '2FA'."
            )

    # ------------------------------------------------------------------ #
    # Connection helpers
    # ------------------------------------------------------------------ #
    def _open_client(self, remote_profile: dict, password: str, sock=None) -> None:
        """Connect ``self.client`` to *remote_profile* (optionally over *sock*) and open SFTP."""
        self.client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy trusts unknown host keys (MITM risk); kept
        # unchanged for compatibility with existing deployments.
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.client.connect(
            hostname=remote_profile["hostname"],
            username=remote_profile["username"],
            password=password,
            port=remote_profile["port"],
            sock=sock,
            look_for_keys=remote_profile["look_for_keys"],
            timeout=self.CONNECT_TIMEOUT,
        )
        self.sftp = self.client.open_sftp()

    def _connect_direct(self, remote_profile: dict, label: str, two_factor: bool = False) -> None:
        """Open a direct connection; with *two_factor*, append an interactively typed code to the password."""
        try:
            # Fixed part of the password comes from machine.json.
            password = remote_profile["password"]
            if two_factor:
                dynamic_code = getpass.getpass(
                    prompt="请输入Authentifactor中的动态验证码: "
                )
                password = f"{password}{dynamic_code}"
            self._open_client(remote_profile, password)
            message = f"Direct SSH connection with {label} established successfully."
            print(message)
            logging.info(message)
        except Exception as e:
            logging.error(f"Failed to establish direct SSH connection with {label}: {e}")
            # Don't leak a half-open transport when the constructor fails.
            self.close_connection()
            raise

    def _connect_via_jumper(self, remote_profile: dict, label: str) -> None:
        """Connect to the jump host, then tunnel a second SSH session to the remote host."""
        # Kept on self so close_connection() can shut the tunnel down as well.
        self.jumper_client = paramiko.SSHClient()
        self.jumper_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            jumper_profile = self.machine_config["jumper_profile"]
            self.jumper_client.connect(
                hostname=jumper_profile["hostname"],
                username=jumper_profile["username"],
                port=jumper_profile["port"],
                key_filename=jumper_profile["key_filename"],
                timeout=self.CONNECT_TIMEOUT,
            )
            # Open a proxy channel from the jump host to the target machine.
            jumper_channel = self.jumper_client.get_transport().open_channel(
                kind="direct-tcpip",
                dest_addr=(remote_profile["hostname"], remote_profile["port"]),
                src_addr=(jumper_profile["hostname"], jumper_profile["port"]),
            )
            print("Jumper connection established successfully")
            logging.info("Jumper connection established successfully")
            self._open_client(remote_profile, remote_profile["password"], sock=jumper_channel)
            message = f"SSH jumper connection with {label} established successfully."
            print(message)
            logging.info(message)
        except Exception as e:
            logging.error(f"Failed to establish SSH connection: {e}")
            self.close_connection()
            raise

    @staticmethod
    def _remote_join(*parts: str) -> str:
        """Join remote path components with '/' (the remote side is POSIX, whatever the local OS)."""
        import posixpath

        return posixpath.join(*parts)

    @staticmethod
    def _load_json(*paths: str):
        """Return the parsed content of the first existing file in *paths*, or None if none exists."""
        for path in paths:
            try:
                with open(path, "r") as json_file:
                    return json.load(json_file)
            except FileNotFoundError:
                continue
        return None

    @staticmethod
    def _forward_stems(forward_json) -> List[str]:
        """Extract bare file names from the forward JSON.

        Accepts the list of ``[stem, affix]`` pairs written by
        :meth:`prepare_and_submit`, or a legacy ``{stem: affix}`` dict.
        """
        if isinstance(forward_json, dict):
            return list(forward_json)
        return [entry[0] for entry in forward_json or []]

    # ------------------------------------------------------------------ #
    # Remote commands and transfers
    # ------------------------------------------------------------------ #
    def _execute_command(self, command: str):
        """Run *command* on the remote host, retrying failed attempts.

        An attempt fails when ``exec_command`` raises or the remote exit status
        is non-zero. Output on stderr with exit status 0 (e.g. scheduler
        warnings) is logged but NOT retried: re-running a successful
        ``sbatch``/``bsub`` would submit the job again.

        Returns:
            ``(stdout, stderr)`` decoded text of the last attempt, or
            ``(None, None)`` if every attempt raised before producing output.
        """
        output, error = None, None
        for attempt in range(1, self.RETRY_ATTEMPTS + 1):
            try:
                _, stdout, stderr = self.client.exec_command(command)
                output, error = stdout.read().decode(), stderr.read().decode()
                # Read the streams first, then the exit status.
                exit_status = stdout.channel.recv_exit_status()
                logging.info(output)
                print(output)
                if error:
                    logging.error(error)
                if exit_status != 0:
                    raise Exception(f"Error executing command: {error}")
                break
            except Exception as e:
                if attempt < self.RETRY_ATTEMPTS:
                    print(f"Error executing command: {e}. Retrying...")
                    time.sleep(self.COMMAND_RETRY_DELAY)
                else:
                    message = (
                        f"Error executing command: {e}. "
                        f"Giving up after {self.RETRY_ATTEMPTS} attempts."
                    )
                    print(message)
                    logging.error(message)
        return output, error

    def _upload_files(self, local_dir: str, local_files: List[str], remote_dir: str):
        """Upload each of *local_files* from *local_dir* into *remote_dir*, with per-file retries.

        *remote_dir* is created (one level only) if it does not exist. A file
        that still fails after all attempts is logged and skipped.
        """
        if not local_files:
            return
        try:
            self.sftp.stat(remote_dir)
        except FileNotFoundError:
            self.sftp.mkdir(remote_dir)
        for local_file in local_files:
            local_path = os.path.join(local_dir, local_file)
            remote_path = self._remote_join(remote_dir, local_file)
            for attempt in range(1, self.RETRY_ATTEMPTS + 1):
                try:
                    self.sftp.put(local_path, remote_path)
                    print(f"Uploaded successful: from {local_path} to {remote_path}")
                    logging.info(f"Uploaded successful: from {local_path} to {remote_path}")
                    break
                except Exception as e:
                    print(f"Error uploading {local_path}: {e}. Retrying...")
                    logging.error(f"Error uploading {local_path}: {e}. Retrying...")
                    if attempt < self.RETRY_ATTEMPTS:
                        time.sleep(self.TRANSFER_RETRY_DELAY)
            else:
                message = f"Failed to upload {local_path} after {self.RETRY_ATTEMPTS} attempts."
                print(message)
                logging.error(message)

    def _batch_prepare(self, file_config: Dict[str, list[str]]):
        """Extend the upload/download lists from file-name prefixes and suffixes.

        Requires ``self.forward_files`` and ``self.backward_files`` to be lists
        (set by :meth:`prepare_and_submit`). It (re)creates
        ``self.batch_forward_json`` as a list of ``[stem, affix]`` pairs, one per
        matched upload file.

        Example Parameter:
            file_config = {
                'upload_prefixes': ['POSCAR_'],
                'upload_suffixes': ['.gjf'],
                'download_prefixes': ['CONTCAR_'],
                'download_suffixes': ['.log', 'fchk']
            }

        A download name is the stem of an upload file (its name with the upload
        affix removed) combined with each download affix of the same kind.
        For example, ``POSCAR_a`` -> ``CONTCAR_a`` and ``x.gjf`` -> ``x.log``.
        """
        upload_prefixes = file_config.get("upload_prefixes", [])
        upload_suffixes = file_config.get("upload_suffixes", [])
        download_prefixes = file_config.get("download_prefixes", [])
        download_suffixes = file_config.get("download_suffixes", [])
        self.batch_forward_json = []
        # Only touch the disk when there is something to match.
        local_entries = (
            sorted(os.listdir(self.local_folder_dir))
            if upload_prefixes or upload_suffixes
            else []
        )

        # Select upload files by prefix and derive the matching download names.
        for upload_prefix in upload_prefixes:
            matched = [f for f in local_entries if f.startswith(upload_prefix)]
            stems = [f[len(upload_prefix):] for f in matched]
            self.forward_files.extend(matched)
            self.batch_forward_json.extend([stem, upload_prefix] for stem in stems)
            for download_prefix in download_prefixes:
                self.backward_files.extend(f"{download_prefix}{stem}" for stem in stems)

        # Select upload files by suffix and derive the matching download names.
        for upload_suffix in upload_suffixes:
            matched = [f for f in local_entries if f.endswith(upload_suffix)]
            # len()-based slice stays correct for an empty suffix (f[:-0] would be "").
            stems = [f[: len(f) - len(upload_suffix)] for f in matched]
            self.forward_files.extend(matched)
            self.batch_forward_json.extend([stem, upload_suffix] for stem in stems)
            for download_suffix in download_suffixes:
                self.backward_files.extend(f"{stem}{download_suffix}" for stem in stems)

    def prepare_and_submit(
        self,
        command: str,
        forward_common_files: List[str] = None,
        upload_files: List[str] = None,
        download_files: List[str] = None,
        batch_config: Dict[str, list[str]] = None,
    ):
        """Upload inputs, record the file lists locally, and run the submit *command* remotely.

        Args:
            command: shell command executed as ``cd <remote_root>; <command>``.
            forward_common_files: files next to this module, uploaded to ``remote_root``.
            upload_files: explicit file names from ``local_folder_dir`` to upload.
            download_files: explicit file names to record for a later download.
            batch_config: prefix/suffix rules, see :meth:`_batch_prepare`.

        Side effects: writes ``forward_batch_files.json``, ``backward_batch_files.json``
        (when there is something to download) and ``submitted_job_ids.json`` (when
        job ids are found) into ``local_folder_dir``.

        Raises:
            TypeError: if ``forward_common_files`` is not a list.
        """
        if forward_common_files is None:
            forward_common_files = []
        # The file-name arguments must be lists of strings.
        if not isinstance(forward_common_files, list):
            raise TypeError(
                f"Expected a list of strings, but received: {type(forward_common_files).__name__}"
            )
        # Create the task directory on the remote server.
        self._execute_command(f"mkdir -p {self.remote_task_dir}")
        if forward_common_files:
            self._upload_files(
                os.path.dirname(__file__), list(forward_common_files), self.remote_dir
            )

        # Copy the caller's lists: _batch_prepare extends them in place.
        self.forward_files = list(upload_files or [])
        self.backward_files = list(download_files or [])
        self.batch_forward_json = []
        if batch_config:
            self._batch_prepare(batch_config)
        # De-duplicate while keeping order (a file may match several rules).
        self.forward_files = list(dict.fromkeys(self.forward_files))
        self.backward_files = list(dict.fromkeys(self.backward_files))
        print(f"Forward_files: {self.forward_files}")
        print(f"Backward_files: {self.backward_files}")
        logging.info(f"Forward_files: {self.forward_files}")
        logging.info(f"Backward_files: {self.backward_files}")
        # Persist the lists so results can be fetched after the SSH session ends.
        with open(f"{self.local_folder_dir}/forward_batch_files.json", "w") as json_file:
            # List of [stem, affix] pairs.
            json.dump(self.batch_forward_json, json_file, indent=4)
        if self.backward_files:
            with open(f"{self.local_folder_dir}/backward_batch_files.json", "w") as json_file:
                # Complete file names.
                json.dump(self.backward_files, json_file, indent=4)

        # Upload the job inputs.
        self._upload_files(self.local_folder_dir, list(self.forward_files), self.remote_task_dir)
        try:
            output, _ = self._execute_command(f"cd {self.remote_dir}; {command}")
            output = output or ""
            # Capture job ids for Slurm and for LSF (any queue name).
            matches_slurm = re.findall(r"Submitted batch job (\d+)", output)
            matches_lsf = re.findall(r"Job <(\d+)> is submitted to queue <[^>]*>", output)
            job_ids = matches_slurm + matches_lsf
            if job_ids:
                print(f"Captured Job IDs: {job_ids}")
                logging.info(f"Captured Job IDs: {job_ids}")
                with open(f"{self.local_folder_dir}/submitted_job_ids.json", "w") as json_file:
                    json.dump(job_ids, json_file, indent=4)
            else:
                print("No Job IDs found in command output.")
                logging.warning("No Job IDs found in command output.")
        except Exception as e:
            print(f"Error executing command: {e}")
            logging.error(f"Error executing command: {e}")

    def upload_entire_folder(self, local_folder: str, remote_folder: str):
        """Recursively upload ``<base_dir>/<local_folder>`` to ``<remote_root>/<remote_folder>``.

        The directory structure is preserved. Remote directories that already
        exist are reused.
        """
        local_dir = os.path.join(self.base_dir, local_folder)
        remote_dir = self._remote_join(self.remote_dir, remote_folder)
        try:
            self.sftp.mkdir(remote_dir)
        except IOError:  # the directory may already exist
            pass

        # Breadth-first walk; each entry pairs a local dir with its remote twin.
        pending = deque([(local_dir, remote_dir)])
        while pending:
            current_local, current_remote = pending.popleft()
            for item in sorted(os.listdir(current_local)):
                local_path = os.path.join(current_local, item)
                remote_path = self._remote_join(current_remote, item)
                if os.path.isdir(local_path):
                    try:
                        self.sftp.mkdir(remote_path)
                    except IOError:  # the directory may already exist
                        pass
                    pending.append((local_path, remote_path))
                else:
                    print(f"Uploading from {local_path} to {remote_path}")
                    self.sftp.put(local_path, remote_path)

    def download_entire_folder(self, remote_folder: str = None, local_folder: str = None):
        """Recursively download ``<remote_root>/<remote_folder>`` into ``<base_dir>/<local_folder>``.

        Defaults: ``remote_folder`` is the upload folder and ``local_folder`` is
        the download sub-folder given at construction.
        """
        if not remote_folder:
            remote_folder = self.upload_folder
        if not local_folder:
            local_folder = self.download_folder
        local_dir = os.path.join(self.base_dir, local_folder)
        os.makedirs(local_dir, exist_ok=True)
        remote_dir = self._remote_join(self.remote_dir, remote_folder)

        # Breadth-first walk; each entry pairs a remote dir with its local twin.
        pending = deque([(remote_dir, local_dir)])
        while pending:
            current_remote, current_local = pending.popleft()
            for item in self.sftp.listdir_attr(current_remote):
                remote_path = self._remote_join(current_remote, item.filename)
                local_path = os.path.join(current_local, item.filename)
                if item.st_mode is not None and S_ISDIR(item.st_mode):
                    os.makedirs(local_path, exist_ok=True)
                    pending.append((remote_path, local_path))
                else:
                    print(f"Downloading {remote_path} to {local_path}")
                    self.sftp.get(remote_path, local_path)

    def download_from_json(
        self,
        download_files: List[str] = None,
        download_prefixes: List[str] = None,
        download_suffixes: List[str] = None,
    ):
        """Download result files into ``<local_folder_dir>/results``.

        The download list is built from three sources, in order:

        1. the explicit *download_files*;
        2. the names stored in ``backward_batch_files.json``;
        3. each stem from ``forward_batch_files.json`` combined with every
           *download_prefixes* / *download_suffixes* entry.

        Legacy ``backward_files.json`` / ``forward_files.json`` are read as a
        fallback. Missing remote files are reported and skipped.
        """
        results_dir = f"{self.local_folder_dir}/results"
        os.makedirs(results_dir, exist_ok=True)
        # Copy: never mutate the caller's list.
        backward_files = list(download_files or [])

        backward_json = self._load_json(
            f"{self.local_folder_dir}/backward_batch_files.json",
            f"{self.local_folder_dir}/backward_files.json",
        )
        if backward_json is None:
            logging.warning(f"No backward file list found in {self.local_folder_dir}")
        else:
            backward_files.extend(backward_json)

        if download_prefixes or download_suffixes:
            forward_json = self._load_json(
                f"{self.local_folder_dir}/forward_batch_files.json",
                f"{self.local_folder_dir}/forward_files.json",
            )
            if forward_json is None:
                logging.error(f"No forward file list found in {self.local_folder_dir}")
            stems = self._forward_stems(forward_json)
            for download_prefix in download_prefixes or []:
                backward_files.extend(f"{download_prefix}{stem}" for stem in stems)
            for download_suffix in download_suffixes or []:
                backward_files.extend(f"{stem}{download_suffix}" for stem in stems)

        for remote_file in dict.fromkeys(backward_files):
            local_file = os.path.join(results_dir, os.path.basename(remote_file))
            remote_file_path = self._remote_join(self.remote_task_dir, remote_file)
            for attempt in range(1, self.RETRY_ATTEMPTS + 1):
                try:
                    self.sftp.stat(remote_file_path)
                    self.sftp.get(remote_file_path, local_file)
                    print(f"Downloaded {remote_file} from {self.remote_task_dir} to {local_file}")
                    logging.info(
                        f"Downloaded {remote_file} from {self.remote_task_dir} to {local_file}"
                    )
                    break
                except FileNotFoundError:
                    message = (
                        f"File {remote_file} not found in {self.remote_task_dir} on remote server."
                    )
                    print(message)
                    logging.error(message)
                    break  # retrying cannot make a missing file appear
                except Exception as e:
                    print(f"Error downloading {remote_file}: {e}. Retrying...")
                    logging.error(f"Error downloading {remote_file}: {e}. Retrying...")
                    if attempt < self.RETRY_ATTEMPTS:
                        time.sleep(self.TRANSFER_RETRY_DELAY)
            else:
                message = f"Failed to download {remote_file} after {self.RETRY_ATTEMPTS} attempts."
                print(message)
                logging.error(message)

    def download_from_condition(self, prefixes: List[str] = None, suffixes: List[str] = None):
        """Download every file in the remote task directory matching any prefix or suffix.

        Each matching file is downloaded once into ``local_folder_dir``.
        Prefixes and suffixes that matched nothing are reported.

        Raises:
            Exception: if neither prefixes nor suffixes are given.
        """
        prefixes = list(dict.fromkeys(prefixes or []))
        suffixes = list(dict.fromkeys(suffixes or []))
        if not prefixes and not suffixes:
            logging.error("No prefixes or suffixes provided.")
            raise Exception("No prefixes or suffixes provided.")
        os.makedirs(self.local_folder_dir, exist_ok=True)
        remote_files = self.sftp.listdir(self.remote_task_dir)
        matched_prefixes, matched_suffixes = set(), set()
        for file_name in remote_files:
            hit_prefixes = [p for p in prefixes if file_name.startswith(p)]
            hit_suffixes = [s for s in suffixes if file_name.endswith(s)]
            if not (hit_prefixes or hit_suffixes):
                continue
            remote_file_path = self._remote_join(self.remote_task_dir, file_name)
            local_file_path = os.path.join(self.local_folder_dir, file_name)
            self.sftp.get(remote_file_path, local_file_path)
            print(f"Downloaded: {remote_file_path} to {local_file_path}")
            logging.info(f"Downloaded: {remote_file_path} to {local_file_path}")
            matched_prefixes.update(hit_prefixes)
            matched_suffixes.update(hit_suffixes)
        # Report affixes that matched nothing.
        for prefix in prefixes:
            if prefix not in matched_prefixes:
                print(f"Error: No files matched the given prefix: {prefix}")
                logging.error(f"Error: No files matched the given prefix: {prefix}")
        for suffix in suffixes:
            if suffix not in matched_suffixes:
                print(f"Error: No files matched the given suffix: {suffix}")
                logging.error(f"Error: No files matched the given suffix: {suffix}")
        if not (matched_prefixes or matched_suffixes):
            print("Error: No files matched the given prefixes or suffixes.")
            logging.error("Error: No files matched the given prefixes or suffixes.")

    def close_connection(self):
        """Close the SFTP session, the SSH client and the jump-host client, whichever are open."""
        for handle in (
            getattr(self, "sftp", None),
            getattr(self, "client", None),
            getattr(self, "jumper_client", None),
        ):
            if handle is not None:
                handle.close()
487
+