dpdispatcher 0.6.6__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dpdispatcher/_version.py CHANGED
@@ -1,16 +1,34 @@
-# file generated by setuptools_scm
+# file generated by setuptools-scm
 # don't change, don't track in version control
+
+__all__ = [
+    "__version__",
+    "__version_tuple__",
+    "version",
+    "version_tuple",
+    "__commit_id__",
+    "commit_id",
+]
+
 TYPE_CHECKING = False
 if TYPE_CHECKING:
-    from typing import Tuple, Union
+    from typing import Tuple
+    from typing import Union
+
     VERSION_TUPLE = Tuple[Union[int, str], ...]
+    COMMIT_ID = Union[str, None]
 else:
     VERSION_TUPLE = object
+    COMMIT_ID = object
 
 version: str
 __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
+commit_id: COMMIT_ID
+__commit_id__: COMMIT_ID
+
+__version__ = version = '1.0.0'
+__version_tuple__ = version_tuple = (1, 0, 0)
 
-__version__ = version = '0.6.6'
-__version_tuple__ = version_tuple = (0, 6, 6)
+__commit_id__ = commit_id = None
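
Beyond the version bump, the regenerated template now exports commit metadata. A minimal sketch of reading the new attributes from an installed 1.0.0 wheel (printed values assume a release build, where the commit id is None):

    from dpdispatcher._version import __commit_id__, __version__, __version_tuple__

    print(__version__)        # "1.0.0"
    print(__version_tuple__)  # (1, 0, 0)
    print(__commit_id__)      # None for release builds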
dpdispatcher/base_context.py CHANGED
@@ -1,5 +1,5 @@
 from abc import ABCMeta, abstractmethod
-from typing import List, Tuple
+from typing import Any, List, Tuple
 
 from dargs import Argument
 
@@ -73,6 +73,65 @@ class BaseContext(metaclass=ABCMeta):
     def check_finish(self, proc):
         raise NotImplementedError("abstract method")
 
+    def block_checkcall(self, cmd, asynchronously=False) -> Tuple[Any, Any, Any]:
+        """Run command with arguments. Wait for command to complete.
+
+        Parameters
+        ----------
+        cmd : str
+            The command to run.
+        asynchronously : bool, optional, default=False
+            Run command asynchronously. If True, `nohup` will be used to run the command.
+
+        Returns
+        -------
+        stdin
+            standard input
+        stdout
+            standard output
+        stderr
+            standard error
+
+        Raises
+        ------
+        RuntimeError
+            when the return code is not zero
+        """
+        if asynchronously:
+            cmd = f"nohup {cmd} >/dev/null &"
+        exit_status, stdin, stdout, stderr = self.block_call(cmd)
+        if exit_status != 0:
+            raise RuntimeError(
+                "Get error code {} in calling {} with job: {} . message: {}".format(
+                    exit_status,
+                    cmd,
+                    self.submission.submission_hash,
+                    stderr.read().decode("utf-8"),
+                )
+            )
+        return stdin, stdout, stderr
+
+    @abstractmethod
+    def block_call(self, cmd) -> Tuple[int, Any, Any, Any]:
+        """Run command with arguments. Wait for command to complete.
+
+        Parameters
+        ----------
+        cmd : str
+            The command to run.
+
+        Returns
+        -------
+        exit_status
+            exit code
+        stdin
+            standard input
+        stdout
+            standard output
+        stderr
+            standard error
+        """
+
     @classmethod
     def machine_arginfo(cls) -> Argument:
         """Generate the machine arginfo.
dpdispatcher/contexts/dp_cloud_server_context.py CHANGED
@@ -161,7 +161,9 @@ class BohriumContext(BaseContext):
         # return oss_task_zip
         # api.upload(self.oss_task_dir, zip_task_file)
 
-    def download(self, submission):
+    def download(
+        self, submission, check_exists=False, mark_failure=True, back_error=False
+    ):
         jobs = submission.belonging_jobs
         job_hashs = {}
         job_infos = {}
@@ -335,6 +337,11 @@ class BohriumContext(BaseContext):
             )
         ]
 
+    def block_call(self, cmd):
+        raise RuntimeError(
+            "Unsupported method. You may use an unsupported combination of the machine and the context."
+        )
+
 
 DpCloudServerContext = BohriumContext
 LebesgueContext = BohriumContext
dpdispatcher/contexts/hdfs_context.py CHANGED
@@ -244,3 +244,8 @@ class HDFSContext(BaseContext):
 
     def read_file(self, fname):
         return HDFS.read_hdfs_file(os.path.join(self.remote_root, fname))
+
+    def block_call(self, cmd):
+        raise RuntimeError(
+            "Unsupported method. You may use an unsupported combination of the machine and the context."
+        )
dpdispatcher/contexts/lazy_local_context.py CHANGED
@@ -83,7 +83,7 @@ class LazyLocalContext(BaseContext):
 
     def upload(
         self,
-        jobs,
+        submission,
         # local_up_files,
         dereference=True,
     ):
@@ -91,7 +91,7 @@ class LazyLocalContext(BaseContext):
 
     def download(
         self,
-        jobs,
+        submission,
         # remote_down_files,
         check_exists=False,
         mark_failure=True,
@@ -112,23 +112,6 @@ class LazyLocalContext(BaseContext):
         # else:
         #     raise RuntimeError('do not find download file ' + fname)
 
-    def block_checkcall(self, cmd):
-        # script_dir = os.path.join(self.local_root, self.submission.work_base)
-        # os.chdir(script_dir)
-        proc = sp.Popen(
-            cmd, cwd=self.local_root, shell=True, stdout=sp.PIPE, stderr=sp.PIPE
-        )
-        o, e = proc.communicate()
-        stdout = SPRetObj(o)
-        stderr = SPRetObj(e)
-        code = proc.returncode
-        if code != 0:
-            raise RuntimeError(
-                "Get error code %d in locally calling %s with job: %s ",
-                (code, cmd, self.submission.submission_hash),
-            )
-        return None, stdout, stderr
-
     def block_call(self, cmd):
         proc = sp.Popen(
             cmd, cwd=self.local_root, shell=True, stdout=sp.PIPE, stderr=sp.PIPE
dpdispatcher/contexts/local_context.py CHANGED
@@ -3,6 +3,9 @@ import shutil
 import subprocess as sp
 from glob import glob
 from subprocess import TimeoutExpired
+from typing import List
+
+from dargs import Argument
 
 from dpdispatcher.base_context import BaseContext
 from dpdispatcher.dlog import dlog
@@ -60,6 +63,7 @@ class LocalContext(BaseContext):
         self.temp_local_root = os.path.abspath(local_root)
         self.temp_remote_root = os.path.abspath(remote_root)
         self.remote_profile = remote_profile
+        self.symlink = remote_profile.get("symlink", True)
 
     @classmethod
     def load_from_dict(cls, context_dict):
@@ -83,6 +87,25 @@ class LocalContext(BaseContext):
             self.temp_remote_root, submission.submission_hash
         )
 
+    def _copy_from_local_to_remote(self, local_path, remote_path):
+        if not os.path.exists(local_path):
+            raise FileNotFoundError(
+                f"cannot find uploaded file {os.path.join(local_path)}"
+            )
+        if os.path.exists(remote_path):
+            os.remove(remote_path)
+        _check_file_path(remote_path)
+
+        if self.symlink:
+            # ensure the file exists
+            os.symlink(local_path, remote_path)
+        elif os.path.isfile(local_path):
+            shutil.copyfile(local_path, remote_path)
+        elif os.path.isdir(local_path):
+            shutil.copytree(local_path, remote_path)
+        else:
+            raise ValueError(f"Unknown file type: {local_path}")
+
     def upload(self, submission):
         os.makedirs(self.remote_root, exist_ok=True)
         for ii in submission.belonging_tasks:
@@ -103,14 +126,9 @@
         file_list.extend(rel_file_list)
 
         for jj in file_list:
-            if not os.path.exists(os.path.join(local_job, jj)):
-                raise FileNotFoundError(
-                    "cannot find upload file " + os.path.join(local_job, jj)
-                )
-            if os.path.exists(os.path.join(remote_job, jj)):
-                os.remove(os.path.join(remote_job, jj))
-            _check_file_path(os.path.join(remote_job, jj))
-            os.symlink(os.path.join(local_job, jj), os.path.join(remote_job, jj))
+            self._copy_from_local_to_remote(
+                os.path.join(local_job, jj), os.path.join(remote_job, jj)
+            )
 
         local_job = self.local_root
         remote_job = self.remote_root
@@ -128,14 +146,9 @@ class LocalContext(BaseContext):
         file_list.extend(rel_file_list)
 
         for jj in file_list:
-            if not os.path.exists(os.path.join(local_job, jj)):
-                raise FileNotFoundError(
-                    "cannot find upload file " + os.path.join(local_job, jj)
-                )
-            if os.path.exists(os.path.join(remote_job, jj)):
-                os.remove(os.path.join(remote_job, jj))
-            _check_file_path(os.path.join(remote_job, jj))
-            os.symlink(os.path.join(local_job, jj), os.path.join(remote_job, jj))
+            self._copy_from_local_to_remote(
+                os.path.join(local_job, jj), os.path.join(remote_job, jj)
+            )
 
     def download(
         self, submission, check_exists=False, mark_failure=True, back_error=False
@@ -288,21 +301,6 @@ class LocalContext(BaseContext):
         # do nothing in the case of linked files
         pass
 
-    def block_checkcall(self, cmd):
-        proc = sp.Popen(
-            cmd, cwd=self.remote_root, shell=True, stdout=sp.PIPE, stderr=sp.PIPE
-        )
-        o, e = proc.communicate()
-        stdout = SPRetObj(o)
-        stderr = SPRetObj(e)
-        code = proc.returncode
-        if code != 0:
-            raise RuntimeError(
-                f"Get error code {code} in locally calling {cmd} with job: {self.submission.submission_hash}"
-                f"\nStandard error: {stderr}"
-            )
-        return None, stdout, stderr
-
     def block_call(self, cmd):
         proc = sp.Popen(
             cmd, cwd=self.remote_root, shell=True, stdout=sp.PIPE, stderr=sp.PIPE
@@ -351,3 +349,31 @@
         stdout = None
         stderr = None
         return ret, stdout, stderr
+
+    @classmethod
+    def machine_subfields(cls) -> List[Argument]:
+        """Generate the machine subfields.
+
+        Returns
+        -------
+        list[Argument]
+            machine subfields
+        """
+        doc_remote_profile = "The information used to maintain the local machine."
+        return [
+            Argument(
+                "remote_profile",
+                dict,
+                optional=True,
+                doc=doc_remote_profile,
+                sub_fields=[
+                    Argument(
+                        "symlink",
+                        bool,
+                        optional=True,
+                        default=True,
+                        doc="Whether to use symbolic links to replace copy. This option should be turned off if the local directory is not accessible on the Batch system.",
+                    ),
+                ],
+            )
+        ]
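
Together with _copy_from_local_to_remote above, this lets a machine definition opt out of symlinking when the batch system cannot see the local filesystem. A hypothetical machine dict (all values are placeholders), consumed by Machine.load_from_dict in the usual way, would look like:

    machine_dict = {
        "batch_type": "Shell",
        "context_type": "LocalContext",
        "local_root": "./work",
        "remote_root": "/tmp/dpdispatcher_remote",
        # copy files instead of symlinking them
        "remote_profile": {"symlink": False},
    }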
dpdispatcher/contexts/openapi_context.py CHANGED
@@ -1,18 +1,20 @@
+import glob
 import os
 import shutil
 import uuid
+from zipfile import ZipFile
 
 import tqdm
 
 try:
-    from bohriumsdk.client import Client
-    from bohriumsdk.job import Job
-    from bohriumsdk.storage import Storage
-    from bohriumsdk.util import Util
-except ModuleNotFoundError:
+    from bohrium import Bohrium
+    from bohrium.resources import Job, Tiefblue
+except ModuleNotFoundError as e:
     found_bohriumsdk = False
+    import_bohrium_error = e
 else:
     found_bohriumsdk = True
+    import_bohrium_error = None
 
 from dpdispatcher.base_context import BaseContext
 from dpdispatcher.dlog import dlog
@@ -23,6 +25,36 @@ DP_CLOUD_SERVER_HOME_DIR = os.path.join(
 )
 
 
+def unzip_file(zip_file, out_dir="./"):
+    obj = ZipFile(zip_file, "r")
+    for item in obj.namelist():
+        obj.extract(item, out_dir)
+
+
+def zip_file_list(root_path, zip_filename, file_list=[]):
+    out_zip_file = os.path.join(root_path, zip_filename)
+    # print('debug: file_list', file_list)
+    zip_obj = ZipFile(out_zip_file, "w")
+    for f in file_list:
+        matched_files = os.path.join(root_path, f)
+        for ii in glob.glob(matched_files):
+            # print('debug: matched_files:ii', ii)
+            if os.path.isdir(ii):
+                arcname = os.path.relpath(ii, start=root_path)
+                zip_obj.write(ii, arcname)
+                for root, dirs, files in os.walk(ii):
+                    for file in files:
+                        filename = os.path.join(root, file)
+                        arcname = os.path.relpath(filename, start=root_path)
+                        # print('debug: filename:arcname:root_path', filename, arcname, root_path)
+                        zip_obj.write(filename, arcname)
+            else:
+                arcname = os.path.relpath(ii, start=root_path)
+                zip_obj.write(ii, arcname)
+    zip_obj.close()
+    return out_zip_file
+
+
 class OpenAPIContext(BaseContext):
     def __init__(
         self,
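
These vendored helpers replace the bohriumsdk Util.zip_file_list and Util.unzip_file calls rewritten in the hunks below. A round-trip sketch with placeholder paths:

    # pack matching files under /tmp/job into payload.zip, then restore them
    archive = zip_file_list("/tmp/job", "payload.zip", file_list=["*.json", "inputs"])
    unzip_file(archive, out_dir="/tmp/job_restored")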
@@ -35,16 +67,41 @@ class OpenAPIContext(BaseContext):
         if not found_bohriumsdk:
             raise ModuleNotFoundError(
                 "bohriumsdk not installed. Install dpdispatcher with `pip install dpdispatcher[bohrium]`"
-            )
+            ) from import_bohrium_error
         self.init_local_root = local_root
         self.init_remote_root = remote_root
         self.temp_local_root = os.path.abspath(local_root)
         self.remote_profile = remote_profile
-        self.client = Client()
-        self.storage = Storage(client=self.client)
-        self.job = Job(client=self.client)
-        self.util = Util()
+        access_key = (
+            remote_profile.get("access_key", None)
+            or os.getenv("BOHRIUM_ACCESS_KEY", None)
+            or os.getenv("ACCESS_KEY", None)
+        )
+        project_id = (
+            remote_profile.get("project_id", None)
+            or os.getenv("BOHRIUM_PROJECT_ID", None)
+            or os.getenv("PROJECT_ID", None)
+        )
+        app_key = (
+            remote_profile.get("app_key", None)
+            or os.getenv("BOHRIUM_APP_KEY", None)
+            or os.getenv("APP_KEY", None)
+        )
+        if access_key is None:
+            raise ValueError(
+                "remote_profile must contain 'access_key' or set environment variable 'BOHRIUM_ACCESS_KEY'"
+            )
+        if project_id is None:
+            raise ValueError(
+                "remote_profile must contain 'project_id' or set environment variable 'BOHRIUM_PROJECT_ID'"
+            )
+        self.client = Bohrium(  # type: ignore[reportPossiblyUnboundVariable]
+            access_key=access_key, project_id=project_id, app_key=app_key
+        )
+        self.storage = Tiefblue()  # type: ignore[reportPossiblyUnboundVariable]
+        self.job = Job(client=self.client)  # type: ignore[reportPossiblyUnboundVariable]
         self.jgid = None
+        os.makedirs(DP_CLOUD_SERVER_HOME_DIR, exist_ok=True)
 
     @classmethod
     def load_from_dict(cls, context_dict):
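
The rewritten constructor resolves each credential from remote_profile first, then the BOHRIUM_-prefixed environment variable, then the bare name. A sketch with placeholder values that relies entirely on the environment (keyword names follow the hunk above):

    import os

    os.environ["BOHRIUM_ACCESS_KEY"] = "<access-key>"   # placeholder
    os.environ["BOHRIUM_PROJECT_ID"] = "<project-id>"   # placeholder
    # app_key remains optional, so an empty remote_profile now suffices
    context = OpenAPIContext(local_root=".", remote_root=".", remote_profile={})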
@@ -97,7 +154,7 @@ class OpenAPIContext(BaseContext):
         for file in task.forward_files:
             upload_file_list.append(os.path.join(task.task_work_path, file))
 
-        upload_zip = Util.zip_file_list(
+        upload_zip = zip_file_list(
             self.local_root, zip_task_file, file_list=upload_file_list
         )
         project_id = self.remote_profile.get("project_id", 0)
@@ -113,7 +170,7 @@ class OpenAPIContext(BaseContext):
         object_key = os.path.join(data["storePath"], zip_filename)  # type: ignore
         job.upload_path = object_key
         job.job_id = data["jobId"]  # type: ignore
-        job.jgid = data["jobGroupId"]  # type: ignore
+        job.jgid = data.get("jobGroupId", "")  # type: ignore
         self.storage.upload_From_file_multi_part(
             object_key=object_key, file_path=upload_zip, token=token
         )
@@ -149,7 +206,9 @@ class OpenAPIContext(BaseContext):
         # return oss_task_zip
         # api.upload(self.oss_task_dir, zip_task_file)
 
-    def download(self, submission):
+    def download(
+        self, submission, check_exists=False, mark_failure=True, back_error=False
+    ):
         jobs = submission.belonging_jobs
         job_hashs = {}
         job_infos = {}
@@ -189,7 +248,7 @@ class OpenAPIContext(BaseContext):
             ):
                 continue
             self.storage.download_from_url(info["resultUrl"], target_result_zip)
-            Util.unzip_file(target_result_zip, out_dir=self.local_root)
+            unzip_file(target_result_zip, out_dir=self.local_root)
             self._backup(self.local_root, target_result_zip)
             self._clean_backup(
                 self.local_root, keep_backup=self.remote_profile.get("keep_backup", True)
@@ -258,3 +317,8 @@ class OpenAPIContext(BaseContext):
         dir_to_be_removed = os.path.join(local_root, "backup")
         if os.path.exists(dir_to_be_removed):
             shutil.rmtree(dir_to_be_removed)
+
+    def block_call(self, cmd):
+        raise RuntimeError(
+            "Unsupported method. You may use an unsupported combination of the machine and the context."
+        )
dpdispatcher/contexts/ssh_context.py CHANGED
@@ -44,6 +44,8 @@ class SSHSession:
         totp_secret=None,
         tar_compress=True,
         look_for_keys=True,
+        execute_command=None,
+        proxy_command=None,
     ):
         self.hostname = hostname
         self.username = username
@@ -56,6 +58,8 @@ class SSHSession:
         self.ssh = None
         self.tar_compress = tar_compress
         self.look_for_keys = look_for_keys
+        self.execute_command = execute_command
+        self.proxy_command = proxy_command
         self._keyboard_interactive_auth = False
         self._setup_ssh()
 
@@ -88,8 +92,7 @@ class SSHSession:
         while not self._check_alive():
             if count == max_check:
                 raise RuntimeError(
-                    "cannot connect ssh after %d failures at interval %d s"
-                    % (max_check, sleep_time)
+                    f"cannot connect ssh after {max_check} failures at interval {sleep_time} s"
                 )
             dlog.info("connection check failed, try to reconnect to " + self.hostname)
             self._setup_ssh()
@@ -141,7 +144,12 @@ class SSHSession:
         # transport = self.ssh.get_transport()
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.settimeout(self.timeout)
-        sock.connect((self.hostname, self.port))
+
+        # Use ProxyCommand if configured (either directly or via jump host parameters)
+        if self.proxy_command is not None:
+            sock = paramiko.ProxyCommand(self.proxy_command)
+        else:
+            sock.connect((self.hostname, self.port))
 
         # Make a Paramiko Transport object using the socket
         ts = paramiko.Transport(sock)
@@ -162,7 +170,6 @@ class SSHSession:
         if os.path.exists(key_path):
             for pkey_class in (
                 paramiko.RSAKey,
-                paramiko.DSSKey,
                 paramiko.ECDSAKey,
                 paramiko.Ed25519Key,
             ):
@@ -180,7 +187,6 @@ class SSHSession:
         elif self.look_for_keys:
             for keytype, name in [
                 (paramiko.RSAKey, "rsa"),
-                (paramiko.DSSKey, "dsa"),
                 (paramiko.ECDSAKey, "ecdsa"),
                 (paramiko.Ed25519Key, "ed25519"),
             ]:
@@ -237,6 +243,8 @@ class SSHSession:
         self.ssh._transport = ts  # type: ignore
         # reset sftp
         self._sftp = None
+        if self.execute_command is not None:
+            self.exec_command(self.execute_command)
 
     def inter_handler(self, title, instructions, prompt_list):
         """inter_handler: the callback for paramiko.transport.auth_interactive.
@@ -338,6 +346,10 @@ class SSHSession:
         doc_look_for_keys = (
             "enable searching for discoverable private key files in ~/.ssh/"
         )
+        doc_execute_command = "execute command after ssh connection is established."
+        doc_proxy_command = (
+            "ProxyCommand to use for SSH connection through intermediate servers."
+        )
         ssh_remote_profile_args = [
             Argument("hostname", str, optional=False, doc=doc_hostname),
             Argument("username", str, optional=False, doc=doc_username),
@@ -379,6 +391,20 @@ class SSHSession:
                 default=True,
                 doc=doc_look_for_keys,
             ),
+            Argument(
+                "execute_command",
+                str,
+                optional=True,
+                default=None,
+                doc=doc_execute_command,
+            ),
+            Argument(
+                "proxy_command",
+                [str, type(None)],
+                optional=True,
+                default=None,
+                doc=doc_proxy_command,
+            ),
         ]
         ssh_remote_profile_format = Argument(
             "ssh_session", dict, ssh_remote_profile_args
@@ -387,23 +413,37 @@
 
     def put(self, from_f, to_f):
         if self.rsync_available:
+            # For rsync, we need to use %h:%p placeholders for target host/port
+            proxy_cmd_rsync = None
+            if self.proxy_command is not None:
+                proxy_cmd_rsync = self.proxy_command.replace(
+                    f"{self.hostname}:{self.port}", "%h:%p"
+                )
             return rsync(
                 from_f,
                 self.remote + ":" + to_f,
                 port=self.port,
                 key_filename=self.key_filename,
                 timeout=self.timeout,
+                proxy_command=proxy_cmd_rsync,
             )
         return self.sftp.put(from_f, to_f)
 
     def get(self, from_f, to_f):
         if self.rsync_available:
+            # For rsync, we need to use %h:%p placeholders for target host/port
+            proxy_cmd_rsync = None
+            if self.proxy_command is not None:
+                proxy_cmd_rsync = self.proxy_command.replace(
+                    f"{self.hostname}:{self.port}", "%h:%p"
+                )
             return rsync(
                 self.remote + ":" + from_f,
                 to_f,
                 port=self.port,
                 key_filename=self.key_filename,
                 timeout=self.timeout,
+                proxy_command=proxy_cmd_rsync,
             )
         return self.sftp.get(from_f, to_f)
 
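
A hypothetical ssh remote_profile exercising both new options (hosts and commands are placeholders). Since put and get above rewrite the literal hostname:port to rsync's %h:%p placeholders, spelling the target out explicitly in proxy_command keeps the paramiko and rsync code paths consistent:

    remote_profile = {
        "hostname": "cluster.example.com",
        "username": "user",
        "port": 22,
        # tunnel through a jump host; the target host:port appears literally
        "proxy_command": "ssh -W cluster.example.com:22 user@jump.example.com",
        # run once after each connection is established
        "execute_command": "source /etc/profile",
    }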
@@ -438,7 +478,9 @@ class SSHContext(BaseContext):
         self.init_local_root = local_root
         self.init_remote_root = remote_root
         self.temp_local_root = os.path.abspath(local_root)
-        assert os.path.isabs(remote_root), "remote_root must be a abspath"
+        assert os.path.isabs(os.path.realpath(remote_root)), (
+            "remote_root must be a abspath"
+        )
         self.temp_remote_root = remote_root
         self.remote_profile = remote_profile
         self.remote_root = None
@@ -755,41 +797,6 @@ class SSHContext(BaseContext):
             tar_compress=self.remote_profile.get("tar_compress", None),
         )
 
-    def block_checkcall(self, cmd, asynchronously=False, stderr_whitelist=None):
-        """Run command with arguments. Wait for command to complete. If the return code
-        was zero then return, otherwise raise RuntimeError.
-
-        Parameters
-        ----------
-        cmd : str
-            The command to run.
-        asynchronously : bool, optional, default=False
-            Run command asynchronously. If True, `nohup` will be used to run the command.
-        stderr_whitelist : list of str, optional, default=None
-            If not None, the stderr will be checked against the whitelist. If the stderr
-            contains any of the strings in the whitelist, the command will be considered
-            successful.
-        """
-        assert self.remote_root is not None
-        self.ssh_session.ensure_alive()
-        if asynchronously:
-            cmd = f"nohup {cmd} >/dev/null &"
-        stdin, stdout, stderr = self.ssh_session.exec_command(
-            (f"cd {shlex.quote(self.remote_root)} ;") + cmd
-        )
-        exit_status = stdout.channel.recv_exit_status()
-        if exit_status != 0:
-            raise RuntimeError(
-                "Get error code %d in calling %s through ssh with job: %s . message: %s"
-                % (
-                    exit_status,
-                    cmd,
-                    self.submission.submission_hash,
-                    stderr.read().decode("utf-8"),
-                )
-            )
-        return stdin, stdout, stderr
-
     def block_call(self, cmd):
         assert self.remote_root is not None
         self.ssh_session.ensure_alive()
@@ -849,8 +856,8 @@
         # print(pid)
         return {"stdin": stdin, "stdout": stdout, "stderr": stderr}
 
-    def check_finish(self, cmd_pipes):
-        return cmd_pipes["stdout"].channel.exit_status_ready()
+    def check_finish(self, proc):
+        return proc["stdout"].channel.exit_status_ready()
 
     def get_return(self, cmd_pipes):
         if not self.check_finish(cmd_pipes):
@@ -912,11 +919,11 @@
         # local tar
         if os.path.isfile(os.path.join(self.local_root, of)):
             os.remove(os.path.join(self.local_root, of))
-        with tarfile.open(
+        with tarfile.open(  # type: ignore[reportCallIssue, reportArgumentType]
             os.path.join(self.local_root, of),
-            tarfile_mode,
-            dereference=dereference,
-            **kwargs,
+            mode=tarfile_mode,  # type: ignore[reportArgumentType]
+            dereference=dereference,  # type: ignore[reportArgumentType]
+            **kwargs,  # type: ignore[reportArgumentType]
         ) as tar:
             # avoid compressing duplicated files or directories
             for ii in set(files):