dpdispatcher 0.6.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. dpdispatcher/_version.py +22 -4
  2. dpdispatcher/base_context.py +60 -1
  3. dpdispatcher/contexts/__init__.py +1 -0
  4. dpdispatcher/contexts/dp_cloud_server_context.py +8 -1
  5. dpdispatcher/contexts/hdfs_context.py +16 -11
  6. dpdispatcher/contexts/lazy_local_context.py +2 -19
  7. dpdispatcher/contexts/local_context.py +77 -43
  8. dpdispatcher/contexts/openapi_context.py +78 -14
  9. dpdispatcher/contexts/ssh_context.py +117 -98
  10. dpdispatcher/dlog.py +9 -5
  11. dpdispatcher/dpcloudserver/__init__.py +0 -0
  12. dpdispatcher/dpcloudserver/client.py +7 -0
  13. dpdispatcher/dpdisp.py +21 -0
  14. dpdispatcher/entrypoints/run.py +9 -0
  15. dpdispatcher/entrypoints/submission.py +21 -1
  16. dpdispatcher/machine.py +15 -4
  17. dpdispatcher/machines/JH_UniScheduler.py +171 -0
  18. dpdispatcher/machines/__init__.py +1 -0
  19. dpdispatcher/machines/distributed_shell.py +6 -10
  20. dpdispatcher/machines/fugaku.py +9 -12
  21. dpdispatcher/machines/lsf.py +3 -9
  22. dpdispatcher/machines/openapi.py +48 -15
  23. dpdispatcher/machines/pbs.py +183 -20
  24. dpdispatcher/machines/shell.py +7 -16
  25. dpdispatcher/machines/slurm.py +30 -42
  26. dpdispatcher/run.py +172 -0
  27. dpdispatcher/submission.py +5 -14
  28. dpdispatcher/utils/dpcloudserver/client.py +10 -6
  29. dpdispatcher/utils/hdfs_cli.py +10 -19
  30. dpdispatcher/utils/utils.py +21 -7
  31. {dpdispatcher-0.6.1.dist-info → dpdispatcher-1.0.0.dist-info}/METADATA +35 -29
  32. dpdispatcher-1.0.0.dist-info/RECORD +49 -0
  33. {dpdispatcher-0.6.1.dist-info → dpdispatcher-1.0.0.dist-info}/WHEEL +1 -1
  34. dpdispatcher-0.6.1.dist-info/RECORD +0 -44
  35. {dpdispatcher-0.6.1.dist-info → dpdispatcher-1.0.0.dist-info}/entry_points.txt +0 -0
  36. {dpdispatcher-0.6.1.dist-info → dpdispatcher-1.0.0.dist-info/licenses}/LICENSE +0 -0
  37. {dpdispatcher-0.6.1.dist-info → dpdispatcher-1.0.0.dist-info}/top_level.txt +0 -0
@@ -44,6 +44,8 @@ class SSHSession:
         totp_secret=None,
         tar_compress=True,
         look_for_keys=True,
+        execute_command=None,
+        proxy_command=None,
     ):
         self.hostname = hostname
         self.username = username
@@ -56,6 +58,8 @@ class SSHSession:
         self.ssh = None
         self.tar_compress = tar_compress
         self.look_for_keys = look_for_keys
+        self.execute_command = execute_command
+        self.proxy_command = proxy_command
         self._keyboard_interactive_auth = False
         self._setup_ssh()
 
@@ -88,8 +92,7 @@ class SSHSession:
         while not self._check_alive():
             if count == max_check:
                 raise RuntimeError(
-                    "cannot connect ssh after %d failures at interval %d s"
-                    % (max_check, sleep_time)
+                    f"cannot connect ssh after {max_check} failures at interval {sleep_time} s"
                 )
             dlog.info("connection check failed, try to reconnect to " + self.hostname)
             self._setup_ssh()
@@ -141,7 +144,12 @@ class SSHSession:
         # transport = self.ssh.get_transport()
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.settimeout(self.timeout)
-        sock.connect((self.hostname, self.port))
+
+        # Use ProxyCommand if configured (either directly or via jump host parameters)
+        if self.proxy_command is not None:
+            sock = paramiko.ProxyCommand(self.proxy_command)
+        else:
+            sock.connect((self.hostname, self.port))
 
         # Make a Paramiko Transport object using the socket
         ts = paramiko.Transport(sock)
@@ -162,7 +170,6 @@ class SSHSession:
         if os.path.exists(key_path):
             for pkey_class in (
                 paramiko.RSAKey,
-                paramiko.DSSKey,
                 paramiko.ECDSAKey,
                 paramiko.Ed25519Key,
             ):
@@ -180,7 +187,6 @@ class SSHSession:
         elif self.look_for_keys:
             for keytype, name in [
                 (paramiko.RSAKey, "rsa"),
-                (paramiko.DSSKey, "dsa"),
                 (paramiko.ECDSAKey, "ecdsa"),
                 (paramiko.Ed25519Key, "ed25519"),
             ]:
@@ -237,6 +243,8 @@ class SSHSession:
         self.ssh._transport = ts  # type: ignore
         # reset sftp
         self._sftp = None
+        if self.execute_command is not None:
+            self.exec_command(self.execute_command)
 
     def inter_handler(self, title, instructions, prompt_list):
         """inter_handler: the callback for paramiko.transport.auth_interactive.
@@ -295,12 +303,16 @@ class SSHSession:
         assert self.ssh is not None
         try:
             return self.ssh.exec_command(cmd)
-        except (paramiko.ssh_exception.SSHException, socket.timeout) as e:
+        except (
+            paramiko.ssh_exception.SSHException,
+            socket.timeout,
+            EOFError,
+        ) as e:
             # SSH session not active
             # retry for up to 3 times
             # ensure alive
             self.ensure_alive()
-            raise RetrySignal("SSH session not active in calling %s" % cmd) from e
+            raise RetrySignal(f"SSH session not active in calling {cmd}") from e
 
     @property
     def sftp(self):
@@ -334,6 +346,10 @@ class SSHSession:
         doc_look_for_keys = (
            "enable searching for discoverable private key files in ~/.ssh/"
        )
+        doc_execute_command = "execute command after ssh connection is established."
+        doc_proxy_command = (
+            "ProxyCommand to use for SSH connection through intermediate servers."
+        )
         ssh_remote_profile_args = [
             Argument("hostname", str, optional=False, doc=doc_hostname),
             Argument("username", str, optional=False, doc=doc_username),
@@ -355,10 +371,18 @@ class SSHSession:
             ),
             Argument("timeout", int, optional=True, default=10, doc=doc_timeout),
             Argument(
-                "totp_secret", str, optional=True, default=None, doc=doc_totp_secret
+                "totp_secret",
+                str,
+                optional=True,
+                default=None,
+                doc=doc_totp_secret,
             ),
             Argument(
-                "tar_compress", bool, optional=True, default=True, doc=doc_tar_compress
+                "tar_compress",
+                bool,
+                optional=True,
+                default=True,
+                doc=doc_tar_compress,
             ),
             Argument(
                 "look_for_keys",
@@ -367,6 +391,20 @@ class SSHSession:
                 default=True,
                 doc=doc_look_for_keys,
             ),
+            Argument(
+                "execute_command",
+                str,
+                optional=True,
+                default=None,
+                doc=doc_execute_command,
+            ),
+            Argument(
+                "proxy_command",
+                [str, type(None)],
+                optional=True,
+                default=None,
+                doc=doc_proxy_command,
+            ),
         ]
         ssh_remote_profile_format = Argument(
             "ssh_session", dict, ssh_remote_profile_args
@@ -375,23 +413,37 @@ class SSHSession:
 
     def put(self, from_f, to_f):
         if self.rsync_available:
+            # For rsync, we need to use %h:%p placeholders for target host/port
+            proxy_cmd_rsync = None
+            if self.proxy_command is not None:
+                proxy_cmd_rsync = self.proxy_command.replace(
+                    f"{self.hostname}:{self.port}", "%h:%p"
+                )
             return rsync(
                 from_f,
                 self.remote + ":" + to_f,
                 port=self.port,
                 key_filename=self.key_filename,
                 timeout=self.timeout,
+                proxy_command=proxy_cmd_rsync,
             )
         return self.sftp.put(from_f, to_f)
 
     def get(self, from_f, to_f):
         if self.rsync_available:
+            # For rsync, we need to use %h:%p placeholders for target host/port
+            proxy_cmd_rsync = None
+            if self.proxy_command is not None:
+                proxy_cmd_rsync = self.proxy_command.replace(
+                    f"{self.hostname}:{self.port}", "%h:%p"
+                )
             return rsync(
                 self.remote + ":" + from_f,
                 to_f,
                 port=self.port,
                 key_filename=self.key_filename,
                 timeout=self.timeout,
+                proxy_command=proxy_cmd_rsync,
             )
         return self.sftp.get(from_f, to_f)
 
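Note: the %h:%p rewrite above assumes the literal "host:port" of the target appears inside the configured ProxyCommand string; rsync then substitutes its own placeholders when it spawns ssh. A worked example under that assumption:

    # How put()/get() adapt the paramiko-style ProxyCommand for rsync.
    proxy_command = "ssh -W cluster.example.com:22 jump.example.com"
    proxy_cmd_rsync = proxy_command.replace("cluster.example.com:22", "%h:%p")
    assert proxy_cmd_rsync == "ssh -W %h:%p jump.example.com"
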
@@ -426,7 +478,9 @@ class SSHContext(BaseContext):
         self.init_local_root = local_root
         self.init_remote_root = remote_root
         self.temp_local_root = os.path.abspath(local_root)
-        assert os.path.isabs(remote_root), "remote_root must be a abspath"
+        assert os.path.isabs(os.path.realpath(remote_root)), (
+            "remote_root must be a abspath"
+        )
         self.temp_remote_root = remote_root
         self.remote_profile = remote_profile
         self.remote_root = None
@@ -569,7 +623,7 @@ class SSHContext(BaseContext):
                 rel_file_list, work_path, file_list, directory_list
             )
         else:
-            raise RuntimeError(f"cannot find upload file {work_path} {jj}")
+            raise FileNotFoundError(f"cannot find upload file {work_path} {jj}")
 
     def upload(
         self,
@@ -603,7 +657,10 @@ class SSHContext(BaseContext):
             directory_list,
         )
         self._walk_directory(
-            submission.forward_common_files, self.local_root, file_list, directory_list
+            submission.forward_common_files,
+            self.local_root,
+            file_list,
+            directory_list,
         )
 
         # convert to relative path to local_root
@@ -621,15 +678,14 @@ class SSHContext(BaseContext):
             ).as_posix()
             sha256_list.append(f"{sha256} {jj_rel}")
         # write to remote
-        sha256_file = os.path.join(
-            self.remote_root, ".tmp.sha256." + str(uuid.uuid4())
-        )
+        sha256_file = pathlib.PurePath(
+            os.path.join(self.remote_root, ".tmp.sha256." + str(uuid.uuid4()))
+        ).as_posix()
         self.write_file(sha256_file, "\n".join(sha256_list))
         # check sha256
         # `:` means pass: https://stackoverflow.com/a/2421592/9567349
         _, stdout, _ = self.block_checkcall(
-            "sha256sum -c %s --quiet >.sha256sum_stdout 2>/dev/null || :"
-            % shlex.quote(sha256_file)
+            f"sha256sum -c {shlex.quote(sha256_file)} --quiet >.sha256sum_stdout 2>/dev/null || :"
        )
         self.sftp.remove(sha256_file)
         # regenerate file list
@@ -708,7 +764,7 @@ class SSHContext(BaseContext):
                 os.path.join(
                     self.local_root,
                     ii.task_work_path,
-                    "tag_failure_download_%s" % jj,
+                    f"tag_failure_download_{jj}",
                 ),
                 "w",
             ) as fp:
@@ -737,49 +793,15 @@ class SSHContext(BaseContext):
         file_list.extend(submission.backward_common_files)
         if len(file_list) > 0:
             self._get_files(
-                file_list, tar_compress=self.remote_profile.get("tar_compress", None)
-            )
-
-    def block_checkcall(self, cmd, asynchronously=False, stderr_whitelist=None):
-        """Run command with arguments. Wait for command to complete. If the return code
-        was zero then return, otherwise raise RuntimeError.
-
-        Parameters
-        ----------
-        cmd : str
-            The command to run.
-        asynchronously : bool, optional, default=False
-            Run command asynchronously. If True, `nohup` will be used to run the command.
-        stderr_whitelist : list of str, optional, default=None
-            If not None, the stderr will be checked against the whitelist. If the stderr
-            contains any of the strings in the whitelist, the command will be considered
-            successful.
-        """
-        assert self.remote_root is not None
-        self.ssh_session.ensure_alive()
-        if asynchronously:
-            cmd = "nohup %s >/dev/null &" % cmd
-        stdin, stdout, stderr = self.ssh_session.exec_command(
-            ("cd %s ;" % shlex.quote(self.remote_root)) + cmd
-        )
-        exit_status = stdout.channel.recv_exit_status()
-        if exit_status != 0:
-            raise RuntimeError(
-                "Get error code %d in calling %s through ssh with job: %s . message: %s"
-                % (
-                    exit_status,
-                    cmd,
-                    self.submission.submission_hash,
-                    stderr.read().decode("utf-8"),
-                )
+                file_list,
+                tar_compress=self.remote_profile.get("tar_compress", None),
             )
-        return stdin, stdout, stderr
 
     def block_call(self, cmd):
         assert self.remote_root is not None
         self.ssh_session.ensure_alive()
         stdin, stdout, stderr = self.ssh_session.exec_command(
-            ("cd %s ;" % shlex.quote(self.remote_root)) + cmd
+            (f"cd {shlex.quote(self.remote_root)} ;") + cmd
         )
         exit_status = stdout.channel.recv_exit_status()
         return exit_status, stdin, stdout, stderr
@@ -794,18 +816,23 @@ class SSHContext(BaseContext):
         fname = pathlib.PurePath(os.path.join(self.remote_root, fname)).as_posix()
         # to prevent old file from being overwritten but cancelled, create a temporary file first
         # when it is fully written, rename it to the original file name
-        with self.sftp.open(fname + "~", "w") as fp:
-            fp.write(write_str)
+        temp_fname = fname + "_tmp"
+        try:
+            with self.sftp.open(temp_fname, "w") as fp:
+                fp.write(write_str)
+            # Rename the temporary file
+            self.block_checkcall(f"mv {shlex.quote(temp_fname)} {shlex.quote(fname)}")
         # sftp.rename may throw OSError
-        self.block_checkcall(
-            "mv {} {}".format(shlex.quote(fname + "~"), shlex.quote(fname))
-        )
+        except OSError as e:
+            dlog.exception(f"Error writing to file {fname}")
+            raise e
 
     def read_file(self, fname):
         assert self.remote_root is not None
         self.ssh_session.ensure_alive()
         with self.sftp.open(
-            pathlib.PurePath(os.path.join(self.remote_root, fname)).as_posix(), "r"
+            pathlib.PurePath(os.path.join(self.remote_root, fname)).as_posix(),
+            "r",
         ) as fp:
             ret = fp.read().decode("utf-8")
         return ret
@@ -829,8 +856,8 @@ class SSHContext(BaseContext):
         # print(pid)
         return {"stdin": stdin, "stdout": stdout, "stderr": stderr}
 
-    def check_finish(self, cmd_pipes):
-        return cmd_pipes["stdout"].channel.exit_status_ready()
+    def check_finish(self, proc):
+        return proc["stdout"].channel.exit_status_ready()
 
     def get_return(self, cmd_pipes):
         if not self.check_finish(cmd_pipes):
@@ -846,12 +873,12 @@ class SSHContext(BaseContext):
         # Thus, it's better to use system's `rm` to remove a directory, which may
         # save a lot of time.
         if verbose:
-            dlog.info("removing %s" % remotepath)
+            dlog.info(f"removing {remotepath}")
         # In some supercomputers, it's very slow to remove large numbers of files
         # (e.g. directory containing trajectory) due to bad I/O performance.
         # So an asynchronously option is provided.
         self.block_checkcall(
-            "rm -rf %s" % shlex.quote(remotepath),
+            f"rm -rf {shlex.quote(remotepath)}",
             asynchronously=self.clean_asynchronously,
         )
 
@@ -892,11 +919,11 @@ class SSHContext(BaseContext):
         # local tar
         if os.path.isfile(os.path.join(self.local_root, of)):
             os.remove(os.path.join(self.local_root, of))
-        with tarfile.open(
+        with tarfile.open(  # type: ignore[reportCallIssue, reportArgumentType]
             os.path.join(self.local_root, of),
-            tarfile_mode,
-            dereference=dereference,
-            **kwargs,
+            mode=tarfile_mode,  # type: ignore[reportArgumentType]
+            dereference=dereference,  # type: ignore[reportArgumentType]
+            **kwargs,  # type: ignore[reportArgumentType]
         ) as tar:
             # avoid compressing duplicated files or directories
             for ii in set(files):
@@ -921,7 +948,7 @@ class SSHContext(BaseContext):
                 f"from {from_f} to {self.ssh_session.username} @ {self.ssh_session.hostname} : {to_f} Error!"
             )
         # remote extract
-        self.block_checkcall("tar xf %s" % of)
+        self.block_checkcall(f"tar xf {of}")
         # clean up
         os.remove(from_f)
         self.sftp.remove(to_f)
@@ -946,36 +973,28 @@ class SSHContext(BaseContext):
         per_nfile = 100
         ntar = len(files) // per_nfile + 1
         if ntar <= 1:
-            try:
-                self.block_checkcall(
-                    "tar {} {} {}".format(
-                        tar_command,
-                        shlex.quote(of),
-                        " ".join([shlex.quote(file) for file in files]),
-                    )
-                )
-            except RuntimeError as e:
-                if "No such file or directory" in str(e):
-                    raise FileNotFoundError(
-                        "Any of the backward files does not exist in the remote directory."
-                    ) from e
-                raise e
+            file_list = " ".join([shlex.quote(file) for file in files])
+            tar_cmd = f"tar {tar_command} {shlex.quote(of)} {file_list}"
         else:
-            file_list_file = os.path.join(
-                self.remote_root, ".tmp.tar." + str(uuid.uuid4())
-            )
+            file_list_file = pathlib.PurePath(
+                os.path.join(self.remote_root, f".tmp_tar_{uuid.uuid4()}")
+            ).as_posix()
             self.write_file(file_list_file, "\n".join(files))
-            try:
-                self.block_checkcall(
-                    f"tar {tar_command} {shlex.quote(of)} -T {shlex.quote(file_list_file)}"
-                )
-            except RuntimeError as e:
-                if "No such file or directory" in str(e):
-                    raise FileNotFoundError(
-                        "Any of the backward files does not exist in the remote directory."
-                    ) from e
-                raise e
-        # trans
+            tar_cmd = (
+                f"tar {tar_command} {shlex.quote(of)} -T {shlex.quote(file_list_file)}"
+            )
+
+        # Execute the tar command remotely
+        try:
+            self.block_checkcall(tar_cmd)
+        except RuntimeError as e:
+            if "No such file or directory" in str(e):
+                raise FileNotFoundError(
+                    "Backward files do not exist in the remote directory."
+                ) from e
+            raise e
+
+        # Transfer the archive from remote to local
         from_f = pathlib.PurePath(os.path.join(self.remote_root, of)).as_posix()
         to_f = pathlib.PurePath(os.path.join(self.local_root, of)).as_posix()
         if os.path.isfile(to_f):
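
Note: for reference, a worked example of the command string assembled above for the small-file-count branch; "czf" merely stands in for whatever tar_command holds, which this hunk does not show:

    import shlex

    files = ["task_0/log", "task_1/log"]  # illustrative file list
    of = "backup.tar.gz"                  # illustrative archive name
    file_list = " ".join(shlex.quote(f) for f in files)
    tar_cmd = f"tar czf {shlex.quote(of)} {file_list}"
    # -> "tar czf backup.tar.gz task_0/log task_1/log"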
dpdispatcher/dlog.py CHANGED
@@ -6,21 +6,25 @@ import warnings
 dlog = logging.getLogger("dpdispatcher")
 dlog.propagate = False
 dlog.setLevel(logging.INFO)
+cwd_logfile_path = os.path.join(os.getcwd(), "dpdispatcher.log")
+dlogf = logging.FileHandler(cwd_logfile_path, delay=True)
 try:
-    dlogf = logging.FileHandler(
-        os.getcwd() + os.sep + "dpdispatcher" + ".log", delay=True
-    )
+    dlog.addHandler(dlogf)
+    dlog.info(f"LOG INIT:dpdispatcher log direct to {cwd_logfile_path}")
 except PermissionError:
+    dlog.removeHandler(dlogf)
     warnings.warn(
-        "dpdispatcher.log meet permission error. redirect the log to ~/dpdispatcher.log"
+        f"dump logfile dpdispatcher.log to {cwd_logfile_path} meet permission error. redirect the log to ~/dpdispatcher.log"
     )
     dlogf = logging.FileHandler(
         os.path.join(os.path.expanduser("~"), "dpdispatcher.log"), delay=True
     )
+    dlog.addHandler(dlogf)
+    dlog.info("LOG INIT:dpdispatcher log init at ~/dpdispatcher.log")
 
 dlogf_formatter = logging.Formatter("%(asctime)s - %(levelname)s : %(message)s")
 dlogf.setFormatter(dlogf_formatter)
-dlog.addHandler(dlogf)
+# dlog.addHandler(dlogf)
 
 dlog_stdout = logging.StreamHandler(sys.stdout)
 dlog_stdout.setFormatter(dlogf_formatter)
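
Note: because the handler is constructed with delay=True, the log file is only opened on the first record; the dlog.info() call inside the try block forces that open, so a PermissionError in the working directory now reliably triggers the fallback to ~/dpdispatcher.log. A minimal usage sketch:

    from dpdispatcher.dlog import dlog

    dlog.info("hello")  # writes ./dpdispatcher.log, or ~/dpdispatcher.log on fallback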
dpdispatcher/dpcloudserver/__init__.py ADDED (empty file, no changes shown)
dpdispatcher/dpcloudserver/client.py ADDED
@@ -0,0 +1,7 @@
+"""Provide backward compatbility with dflow."""
+
+from dpdispatcher.utils.dpcloudserver.client import RequestInfoException
+
+__all__ = [
+    "RequestInfoException",
+]
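
Note: the shim simply re-exports the class from its post-0.6.1 location, so dflow-era imports keep resolving to the same object:

    from dpdispatcher.dpcloudserver.client import RequestInfoException as old_path
    from dpdispatcher.utils.dpcloudserver.client import RequestInfoException as new_path

    assert old_path is new_path  # one class, two import paths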
dpdispatcher/dpdisp.py CHANGED
@@ -3,6 +3,7 @@ import argparse
 from typing import List, Optional
 
 from dpdispatcher.entrypoints.gui import start_dpgui
+from dpdispatcher.entrypoints.run import run
 from dpdispatcher.entrypoints.submission import handle_submission
 
 
@@ -54,6 +55,11 @@ def main_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Clean submission.",
     )
+    parser_submission_action.add_argument(
+        "--reset-fail-count",
+        action="store_true",
+        help="Reset fail count of all jobs to zero.",
+    )
     ##########################################
     # gui
     parser_gui = subparsers.add_parser(
@@ -76,6 +82,18 @@ def main_parser() -> argparse.ArgumentParser:
             "to the network on both IPv4 and IPv6 (where available)."
         ),
     )
+    ##########################################
+    # run
+    parser_run = subparsers.add_parser(
+        "run",
+        help="Run a Python script.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser_run.add_argument(
+        "filename",
+        type=str,
+        help="Python script to run. PEP 723 metadata should be contained in this file.",
+    )
     return parser
 
 
@@ -105,12 +123,15 @@ def main():
             download_terminated_log=args.download_terminated_log,
             download_finished_task=args.download_finished_task,
             clean=args.clean,
+            reset_fail_count=args.reset_fail_count,
         )
     elif args.command == "gui":
         start_dpgui(
             port=args.port,
             bind_all=args.bind_all,
         )
+    elif args.command == "run":
+        run(filename=args.filename)
     elif args.command is None:
         pass
     else:
dpdispatcher/entrypoints/run.py ADDED
@@ -0,0 +1,9 @@
+"""Run PEP 723 script."""
+
+from dpdispatcher.run import run_pep723
+
+
+def run(*, filename: str):
+    with open(filename) as f:
+        script = f.read()
+    run_pep723(script)
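
Note: a sketch of a script accepted by the new subcommand, invoked as `dpdisp run script.py`. The inline block follows PEP 723; which metadata keys run_pep723 actually consumes is defined in the new dpdispatcher/run.py (not shown in this section), so the tool table below is illustrative only:

    # script.py
    # /// script
    # dependencies = []
    # [tool.dpdispatcher]
    # # machine/resources settings read by run_pep723 (assumed, see dpdispatcher/run.py)
    # ///

    print("dispatched by dpdispatcher")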
dpdispatcher/entrypoints/submission.py CHANGED
@@ -12,6 +12,7 @@ def handle_submission(
     download_terminated_log: bool = False,
     download_finished_task: bool = False,
     clean: bool = False,
+    reset_fail_count: bool = False,
 ):
     """Handle terminated submission.
 
@@ -25,13 +26,21 @@ def handle_submission(
         Download finished tasks.
     clean : bool, optional
         Clean submission.
+    reset_fail_count : bool, optional
+        Reset fail count of all jobs to zero.
 
     Raises
     ------
     ValueError
         At least one action should be specified.
     """
-    if int(download_terminated_log) + int(download_finished_task) + int(clean) == 0:
+    if (
+        int(download_terminated_log)
+        + int(download_finished_task)
+        + int(clean)
+        + int(reset_fail_count)
+        == 0
+    ):
         raise ValueError("At least one action should be specified.")
 
     submission_file = record.get_submission(submission_hash)
@@ -42,7 +51,18 @@ def handle_submission(
     # TODO: for unclear reason, the submission_hash may be changed
     submission.submission_hash = submission_hash
     submission.machine.context.bind_submission(submission)
+    if reset_fail_count:
+        for job in submission.belonging_jobs:
+            job.fail_count = 0
+        # save to remote and local
+        submission.submission_to_json()
+        record.write(submission)
+    if int(download_terminated_log) + int(download_finished_task) + int(clean) == 0:
+        # if only reset_fail_count, no need to update submission state (expensive)
+        return
     submission.update_submission_state()
+    submission.submission_to_json()
+    record.write(submission)
 
     terminated_tasks = []
     finished_tasks = []
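
Note: the new flag maps one-to-one onto the keyword argument, so the following are equivalent ways to clear fail counts for a recorded submission (the hash value is a placeholder):

    # CLI: dpdisp submission 0123abcd --reset-fail-count
    from dpdispatcher.entrypoints.submission import handle_submission

    handle_submission(submission_hash="0123abcd", reset_fail_count=True)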
dpdispatcher/machine.py CHANGED
@@ -161,6 +161,9 @@ class Machine(metaclass=ABCMeta):
             machine_dict["remote_profile"] = self.context.remote_profile
         else:
             machine_dict["remote_profile"] = {}
+        # normalize the dict
+        base = self.arginfo()
+        machine_dict = base.normalize_value(machine_dict, trim_pattern="_*")
         return machine_dict
 
     def __eq__(self, other):
@@ -224,7 +227,7 @@ class Machine(metaclass=ABCMeta):
         return if_recover
 
     @abstractmethod
-    def check_finish_tag(self, **kwargs):
+    def check_finish_tag(self, job):
         raise NotImplementedError(
             "abstract method check_finish_tag should be implemented by derived class"
         )
@@ -261,11 +264,19 @@ class Machine(metaclass=ABCMeta):
 
         source_list = job.resources.source_list
         for ii in source_list:
-            line = "{ source %s; } \n" % ii
-            source_files_part += line
+            source_files_part += f"source {ii}\n"
         export_envs_part = ""
         envs = job.resources.envs
 
+        envs = {
+            # export resources information to the environment variables
+            "DPDISPATCHER_NUMBER_NODE": job.resources.number_node,
+            "DPDISPATCHER_CPU_PER_NODE": job.resources.cpu_per_node,
+            "DPDISPATCHER_GPU_PER_NODE": job.resources.gpu_per_node,
+            "DPDISPATCHER_QUEUE_NAME": job.resources.queue_name,
+            "DPDISPATCHER_GROUP_SIZE": job.resources.group_size,
+            **envs,
+        }
         for k, v in envs.items():
             if isinstance(v, list):
                 for each_value in v:
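
Note: these variables are exported by the generated job script before the task command runs, so a task can read its allotted resources from the environment:

    import os

    # available to every task dispatched by dpdispatcher >= 1.0.0
    number_node = int(os.environ["DPDISPATCHER_NUMBER_NODE"])
    cpu_per_node = int(os.environ["DPDISPATCHER_CPU_PER_NODE"])
    queue_name = os.environ["DPDISPATCHER_QUEUE_NAME"]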
@@ -466,7 +477,7 @@ class Machine(metaclass=ABCMeta):
         job : Job
             job
         """
-        dlog.warning("Job %s should be manually killed" % job.job_id)
+        dlog.warning(f"Job {job.job_id} should be manually killed")
 
     def get_exit_code(self, job):
         """Get exit code of the job.