oracle-ads 2.13.11__py3-none-any.whl → 2.13.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,26 +1,27 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
2
 
4
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
- """This module requires oracle-ads>=2.6.8
7
- """
5
+ """This module requires oracle-ads>=2.6.8 and python>=3.8"""
8
6
  import getpass
9
7
  import ipaddress
8
+ import json
10
9
  import logging
11
10
  import multiprocessing
12
11
  import os
13
- import time
14
12
  import shlex
15
13
  import socket
16
14
  import sys
15
+ import time
17
16
  import traceback
18
17
 
18
+ import fsspec
19
19
  import oci
20
20
  import psutil
21
21
  import torch
22
+
22
23
  from ads import set_auth
23
- from ads.jobs import DataScienceJobRun
24
+ from ads.jobs import DataScienceJob, DataScienceJobRun
24
25
  from ads.jobs.builders.infrastructure.dsc_job_runtime import (
25
26
  PythonRuntimeHandler,
26
27
  )
@@ -29,13 +30,13 @@ from ads.opctl.distributed.common import cluster_config_helper
29
30
  try:
30
31
  # This is used by ADS and testing
31
32
  from . import driver_utils
32
- from .driver_oci import GitSSHKey, GitManager
33
- from .oci_metrics import collect_metrics, METRIC_NAMESPACE
33
+ from .driver_oci import GitManager, GitSSHKey
34
+ from .oci_metrics import METRIC_NAMESPACE, collect_metrics
34
35
  except ImportError:
35
36
  # This is used when the script is in a job run.
36
37
  import driver_utils
37
- from driver_oci import GitSSHKey, GitManager
38
- from oci_metrics import collect_metrics, METRIC_NAMESPACE
38
+ from driver_oci import GitManager, GitSSHKey
39
+ from oci_metrics import METRIC_NAMESPACE, collect_metrics
39
40
 
40
41
  logger = logging.getLogger(__name__)
41
42
  logger = driver_utils.set_log_level(logger)
@@ -50,21 +51,36 @@ CONST_ENV_NODE_RANK = "NODE_RANK"
50
51
  CONST_ENV_NODE_COUNT = "NODE_COUNT"
51
52
  CONST_ENV_LAUNCH_CMD = "OCI__LAUNCH_CMD"
52
53
  CONST_ENV_DEEPSPEED = "OCI__DEEPSPEED"
54
+ CONST_ENV_LOG_OUTPUT = "OCI__LOG_OUTPUT"
53
55
  # Envs set by this module
54
56
  CONST_ENV_WORLD_SIZE = "WORLD_SIZE"
55
57
  CONST_ENV_LD_PRELOAD = "LD_PRELOAD"
56
58
  # Envs for debugging only
59
+ CONST_ENV_SET_SOCKET_IFNAME = "SET_SOCKET_IFNAME"
57
60
  # OCI_ODSC_SERVICE_ENDPOINT is used for all processes in the job run
58
61
  CONST_ENV_ODSC_SERVICE_ENDPOINT = "OCI_ODSC_SERVICE_ENDPOINT"
59
62
  # OCI_DS_SERVICE_ENDPOINT is used only by the training process
60
63
  CONST_ENV_DS_SERVICE_ENDPOINT = "OCI_DS_SERVICE_ENDPOINT"
61
64
 
65
+ # DTv2 environment variables
66
+ CONST_ENV_INITIAL_CLUSTER_SIZE = "INITIAL_CLUSTER_SIZE"
67
+ CONST_ENV_META_FILE = "CLUSTER_NODES_METADATA_FILE"
68
+ # DTv2 metadata variables
69
+ CONST_IP_ADDRESS = "IPAddress"
70
+ CONST_RANK = "Rank"
71
+
72
+
73
+ CONST_ENCODING = "utf-8"
74
+
62
75
  # Constants used in logs
63
76
  LOG_PREFIX_HOST_IP = "Distributed Training HOST IP: "
64
77
  LOG_PREFIX_NODE_IP = "Node IP: "
65
78
  LOG_PREFIX_PUBLIC_KEY = "HOST PUBLIC KEY: "
79
+ LOG_PREFIX_HOST_KEY_RSA = "NODE HOST KEY RSA: "
80
+ LOG_PREFIX_HOST_KEY_ECDSA = "NODE HOST KEY ECDSA: "
66
81
  # Other constants used within this script
67
- # Other constants used within this script
82
+ HOST_KEY_PATH_RSA = "/etc/ssh/ssh_host_rsa_key.pub"
83
+ HOST_KEY_PATH_ECDSA = "/etc/ssh/ssh_host_ecdsa_key.pub"
68
84
  USER_HOME = os.environ.get("HOME", f"/home/{getpass.getuser()}")
69
85
  SSH_DIR = os.environ.get("OCI__SSH_DIR", os.path.join(USER_HOME, ".ssh"))
70
86
  DEFAULT_LAUNCHER = "torchrun"
@@ -122,42 +138,78 @@ class Runner(driver_utils.JobRunner):
122
138
  super().__init__(code_dir)
123
139
  self.launch_cmd = os.environ.get(CONST_ENV_LAUNCH_CMD, "")
124
140
 
125
- self.ds_client = driver_utils.OCIHelper.init_oci_client(
126
- oci.data_science.DataScienceClient
127
- )
128
- self.ip = self.find_self_ip()
129
- # IP address of other nodes as a list
130
- self.node_ip_list = []
131
- # DataScienceJobRun objects of other nodes as a list
132
- self.node_runs = []
133
-
134
- if CONST_ENV_HOST_JOB_RUN_OCID in os.environ:
135
- # Print the node IP address to logs so that it can be obtained by the host.
136
- print(f"{LOG_PREFIX_NODE_IP}{self.ip}")
137
- self.host_ocid = os.environ[CONST_ENV_HOST_JOB_RUN_OCID]
138
- logger.debug("Host job run OCID: %s", self.host_ocid)
139
- self.host_ip = None
140
- self.is_host = False
141
- else:
142
- # Print the host IP address to logs so that it can be obtained by the nodes.
143
- print(f"{LOG_PREFIX_HOST_IP}{self.ip}")
144
- self.host_ocid = os.environ.get(CONST_ENV_JOB_RUN_OCID)
145
- self.host_ip = self.ip
146
- self.is_host = True
141
+ logger.debug(os.environ)
147
142
 
148
- self.host_job_run = DataScienceJobRun.from_ocid(self.host_ocid)
149
- self.entrypoint_env = PythonRuntimeHandler.CONST_CODE_ENTRYPOINT
150
- # The total number of nodes is OCI__WORKER_COUNT + 1
151
- if CONST_ENV_NODE_COUNT in os.environ:
143
+ # Node count
144
+ if CONST_ENV_INITIAL_CLUSTER_SIZE in os.environ:
145
+ self.node_count = int(os.environ[CONST_ENV_INITIAL_CLUSTER_SIZE])
146
+ elif CONST_ENV_NODE_COUNT in os.environ:
152
147
  self.node_count = int(os.environ[CONST_ENV_NODE_COUNT])
153
148
  else:
149
+ # The total number of nodes is OCI__WORKER_COUNT + 1
154
150
  self.node_count = int(os.environ.get(OCI__WORKER_COUNT, 0)) + 1
155
151
  logger.debug("Node count: %s", self.node_count)
152
+
156
153
  self.gpu_count = torch.cuda.device_count()
157
154
  logger.debug("GPU count on this node: %s", self.gpu_count)
155
+ if self.gpu_count > 0:
156
+ logger.debug("GPU name: %s", torch.cuda.get_device_name())
157
+
158
+ # IP address of other nodes as a list
159
+ self.node_ip_list = []
160
+ # For DTv2, node_runs should not be used.
161
+ self.node_runs = None
162
+ self.host_ocid = None
163
+ self.host_job_run = None
164
+
165
+ self.node_rank = int(os.environ.get(CONST_ENV_NODE_RANK, 0))
166
+
167
+ hostname = socket.gethostname()
168
+ logger.debug("Hostname: %s", hostname)
169
+ logger.debug(
170
+ "Get Host by Addr: %s", LazyEvaluate(socket.gethostbyaddr, hostname)
171
+ )
172
+ logger.debug("FQDN: %s", LazyEvaluate(socket.getfqdn, hostname))
173
+
174
+ # Read metadata file for DTv2
175
+ self.rank_to_ip = self.read_metadata()
176
+ if self.rank_to_ip:
177
+ logger.debug(self.rank_to_ip)
178
+ # DTv2
179
+ self.ip = self.rank_to_ip[self.node_rank]
180
+ self.host_ip = self.rank_to_ip[0]
181
+ self.is_host = self.node_rank == 0
182
+ self.node_ip_list = list(self.rank_to_ip.values())
183
+ self._set_socket_interface(self._get_interface_name())
184
+ # DeepSpeed worker will check job logs to determine the public SSH key.
185
+ self.host_ocid = os.environ.get(CONST_ENV_JOB_RUN_OCID)
186
+ else:
187
+ # DTv1
188
+ self.ip = self.find_self_ip()
189
+ if CONST_ENV_HOST_JOB_RUN_OCID in os.environ:
190
+ # Print the node IP address to logs so that it can be obtained by the host.
191
+ print(f"{LOG_PREFIX_NODE_IP}{self.ip}", flush=True)
192
+ self.host_ocid = os.environ[CONST_ENV_HOST_JOB_RUN_OCID]
193
+ logger.debug("Host job run OCID: %s", self.host_ocid)
194
+ self.host_ip = None
195
+ self.is_host = False
196
+ else:
197
+ # Print the host IP address to logs so that it can be obtained by the nodes.
198
+ print(f"{LOG_PREFIX_HOST_IP}{self.ip}", flush=True)
199
+ self.host_ocid = os.environ.get(CONST_ENV_JOB_RUN_OCID)
200
+ self.host_ip = self.ip
201
+ self.is_host = True
202
+
203
+ # host_job_run is needed for DTv1 to fetch the IP addresses from logs.
204
+ if self.host_ocid and self.node_count > 1:
205
+ self.host_job_run = DataScienceJobRun.from_ocid(self.host_ocid)
206
+ self.entrypoint_env = PythonRuntimeHandler.CONST_CODE_ENTRYPOINT
158
207
 
159
208
  logger.debug("Runner initialized.")
160
209
 
210
+ def is_dtv2(self):
211
+ return CONST_ENV_META_FILE in os.environ
212
+
161
213
  def launch_cmd_contains(self, arg) -> bool:
162
214
  """Checks if the cmd for launching the training contains specific keyword argument."""
163
215
  return f"--{arg}" in self.launch_cmd
@@ -204,7 +256,7 @@ class Runner(driver_utils.JobRunner):
204
256
  logger.info("IP of %s: %s", job_run.id[-6:], ip_address)
205
257
  return ip_address
206
258
 
207
- def wait_for_log(self, job_run, log_prefix, timeout=15 * 60) -> str:
259
+ def wait_for_log(self, job_run, log_prefix, timeout=15 * 60, limit=1) -> str:
208
260
  """Waits until a log message with specific prefix is found in the logs of a job run.
209
261
 
210
262
  Parameters
@@ -223,27 +275,33 @@ class Runner(driver_utils.JobRunner):
223
275
 
224
276
  Raises
225
277
  ------
226
- TimeoutError
278
+ Exception
227
279
  Failed to obtain the log message within the specific timeout.
228
280
  """
229
281
  logger.debug(
230
282
  "Waiting for logs with prefix '%s' from %s.", log_prefix, job_run.id
231
283
  )
232
284
  second_started = time.time()
233
- log = None
234
- while not log:
235
- log = self.check_job_run_logs(job_run=job_run, log_prefix=log_prefix)
236
- if log:
285
+ logs = None
286
+ while True:
287
+ logs = self.check_job_run_logs(job_run=job_run, log_prefix=log_prefix)
288
+ if logs and len(logs) >= limit:
289
+ logs = logs[:limit]
237
290
  break
238
291
  if time.time() - second_started > timeout:
239
- raise TimeoutError(
240
- f"Failed to obtain log with prefix {log_prefix} for {job_run.id} in {timeout} seconds."
292
+ logs = job_run.logs()
293
+ last_log = logs[-1]["message"] if len(logs) > 0 else ""
294
+ raise Exception(
295
+ f"Failed to obtain log with prefix {log_prefix} for {job_run.id} in {timeout} seconds.\n"
296
+ f"Last log obtained: {last_log}"
241
297
  )
242
298
  time.sleep(60)
243
- return log
299
+ if limit == 1:
300
+ return logs[0]
301
+ return logs
244
302
 
245
303
  @staticmethod
246
- def check_job_run_logs(job_run, log_prefix: str) -> str:
304
+ def check_job_run_logs(job_run, log_prefix: str) -> list:
247
305
  """Checks the logs of a specific job run and find the log message with specific prefix.
248
306
 
249
307
  Parameters
@@ -260,45 +318,111 @@ class Runner(driver_utils.JobRunner):
260
318
  """
261
319
  logger.debug("Checking logs for job run %s", job_run.id)
262
320
  logs = job_run.logs()
263
- for log in logs:
264
- if log["message"].startswith(log_prefix):
265
- return log["message"][len(log_prefix) :]
266
- return None
321
+ logs = [
322
+ log["message"][len(log_prefix) :]
323
+ for log in logs
324
+ if log["message"].startswith(log_prefix)
325
+ ]
326
+ return logs
267
327
 
268
328
  def find_self_ip(self):
269
329
  """
270
330
  Identify the node IP address by finding which of the host's IP addresses falls within the CIDR block of the subnet
271
331
  associated with the JOB_OCID
272
332
  """
273
- hostname = socket.gethostname()
274
- logger.debug("Hostname: %s", hostname)
275
- logger.debug(
276
- "Get Host by Addr: %s", LazyEvaluate(socket.gethostbyaddr, hostname)
277
- )
278
- logger.debug("FQDN: %s", LazyEvaluate(socket.getfqdn, hostname))
279
- if os.environ.get("JOB_OCID"):
280
- subnet_id = self.ds_client.get_job(
281
- os.environ["JOB_OCID"]
282
- ).data.job_infrastructure_configuration_details.subnet_id
333
+ if os.environ.get("JOB_OCID") and self.node_count > 1:
334
+ subnet_id = DataScienceJob.from_id(os.environ["JOB_OCID"]).subnet_id
283
335
  core_client = driver_utils.OCIHelper.init_oci_client(
284
336
  oci.core.VirtualNetworkClient
285
337
  )
286
338
  cidr = core_client.get_subnet(subnet_id).data.cidr_block
287
339
 
340
+ self_ip = None
288
341
  for interface, snics in psutil.net_if_addrs().items():
289
342
  ip = snics[0].address
343
+ logger.debug("IFNAME: %s, IP: %s", interface, ip)
290
344
  if ipaddress.ip_address(ip) in ipaddress.ip_network(cidr):
345
+ self_ip = ip
291
346
  logger.info("Node IP address: %s", ip)
292
- # Specify the network interface for NCCL/GLOO
293
- os.environ["GLOO_SOCKET_IFNAME"] = interface
294
- os.environ["NCCL_SOCKET_IFNAME"] = interface
295
- return ip
296
- raise EnvironmentError("Unable to determine node IP address.")
347
+
348
+ self._set_socket_interface(interface)
349
+ if self_ip:
350
+ return self_ip
351
+ raise OSError("Unable to determine node IP address.")
297
352
  else:
298
- ip = socket.gethostbyname(hostname)
353
+ ip = socket.gethostbyname(socket.gethostname())
299
354
  logger.info("Node IP address: %s", ip)
300
355
  return ip
301
356
 
357
+ def _set_socket_interface(self, interface: str):
358
+ """Sets the socket interface environment variables,
359
+ NCCL_SOCKET_IFNAME and GLOO_SOCKET_IFNAME.
360
+
361
+ When `SET_SOCKET_IFNAME` is found in env var and the value is not empty,
362
+ the value will be used and the `interface` argument will be ignored.
363
+
364
+ NCCL/GLOO will match the interface using prefix.
365
+ https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-socket-ifname
366
+
367
+ """
368
+ # Specify the network interface for NCCL/GLOO
369
+ if os.environ.get(CONST_ENV_SET_SOCKET_IFNAME):
370
+ interface = os.environ[CONST_ENV_SET_SOCKET_IFNAME]
371
+
372
+ # Set the env vars only if user has not set it already
373
+ if not os.environ.get("GLOO_SOCKET_IFNAME"):
374
+ logger.debug("Setting GLOO_SOCKET_IFNAME to %s", interface)
375
+ os.environ["GLOO_SOCKET_IFNAME"] = interface
376
+ if not os.environ.get("NCCL_SOCKET_IFNAME"):
377
+ logger.debug("Setting NCCL_SOCKET_IFNAME to %s", interface)
378
+ os.environ["NCCL_SOCKET_IFNAME"] = interface
379
+
380
+ def _get_interface_name(self):
381
+ node_interface = None
382
+ for interface, snics in psutil.net_if_addrs().items():
383
+ ip = snics[0].address
384
+ logger.debug("IFNAME: %s, IP: %s", interface, ip)
385
+ if ip == self.ip:
386
+ node_interface = interface
387
+ return node_interface
388
+
389
+ def read_metadata(self):
390
+ """Reads the metadata for DTv2 to get the rank and IP address mapping."""
391
+ if CONST_ENV_META_FILE not in os.environ:
392
+ return None
393
+ metadata_file = os.environ.get(CONST_ENV_META_FILE)
394
+ error_count = 0
395
+ while True:
396
+ if not os.path.exists(metadata_file):
397
+ logger.debug("Waiting for file %s to be available...", metadata_file)
398
+ time.sleep(20)
399
+ continue
400
+ logger.debug("Reading %s...", metadata_file)
401
+ with open(metadata_file, encoding=CONST_ENCODING) as f:
402
+ try:
403
+ node_list = json.load(f)
404
+ except Exception as ex:
405
+ # log the content of the file for debugging purpose.
406
+ logger.debug("Error occurred when reading metadata file:")
407
+ f.seek(0)
408
+ logger.debug(f.read())
409
+ error_count += 1
410
+ node_list = []
411
+ if error_count > 3:
412
+ raise ex
413
+
414
+ if len(node_list) < self.node_count:
415
+ logger.debug(
416
+ "Waiting for nodes... found %s of %s",
417
+ len(node_list),
418
+ self.node_count,
419
+ )
420
+ time.sleep(20)
421
+ continue
422
+ logger.debug("All nodes are found in metadata file.")
423
+ logger.debug(node_list)
424
+ return {int(meta[CONST_RANK]): meta[CONST_IP_ADDRESS] for meta in node_list}
425
+
302
426
  def fetch_code(self):
303
427
  """Fetches source code from Git if repo uri is specified."""
304
428
  if cluster_config_helper.OCI__RUNTIME_URI in os.environ:
@@ -370,10 +494,7 @@ class Runner(driver_utils.JobRunner):
370
494
  else:
371
495
  launch_args.append(self.get_cmd_with_entrypoint_and_args())
372
496
 
373
- if prefix:
374
- launcher = f"{prefix} {self.LAUNCHER}"
375
- else:
376
- launcher = self.LAUNCHER
497
+ launcher = f"{prefix} {self.LAUNCHER}" if prefix else self.LAUNCHER
377
498
 
378
499
  return f"{launcher} {' '.join(launch_args)}"
379
500
 
@@ -383,8 +504,16 @@ class Runner(driver_utils.JobRunner):
383
504
  self.run_command("pwd", level=logging.DEBUG)
384
505
  # Show all environment variables
385
506
  self.run_command("printenv", level=logging.DEBUG)
507
+ if CONST_ENV_DS_SERVICE_ENDPOINT in os.environ:
508
+ envs = {
509
+ CONST_ENV_ODSC_SERVICE_ENDPOINT: os.environ[
510
+ CONST_ENV_DS_SERVICE_ENDPOINT
511
+ ]
512
+ }
513
+ else:
514
+ envs = None
386
515
  training_start_time = time.time()
387
- self.run_command(cmd, conda_prefix=self.conda_prefix, check=True)
516
+ self.run_command(cmd, conda_prefix=self.conda_prefix, check=True, envs=envs)
388
517
  logger.info("Time: %s seconds.", time.time() - training_start_time)
389
518
 
390
519
  def run(self):
@@ -397,6 +526,7 @@ class TorchRunner(Runner):
397
526
 
398
527
  def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
399
528
  super().__init__(code_dir)
529
+ logger.debug("Initializing Torch Runner...")
400
530
  self.build_c_library()
401
531
 
402
532
  def build_c_library(self):
@@ -442,10 +572,7 @@ class TorchRunner(Runner):
442
572
  return rdzv_conf
443
573
 
444
574
  def run(self):
445
- if self.gpu_count > 0:
446
- nproc_per_node = self.gpu_count
447
- else:
448
- nproc_per_node = 1
575
+ nproc_per_node = self.gpu_count if self.gpu_count > 0 else 1
449
576
 
450
577
  launch_args = []
451
578
  # Add nnode, nproc_per_node and rdzv args only if they are not specified by the user.
@@ -471,24 +598,119 @@ class DeepSpeedRunner(Runner):
471
598
  HOST_FILE = "/home/datascience/hostfile"
472
599
  ENV_FILE = os.path.expanduser("~/.deepspeed_env")
473
600
  LAUNCHER = "deepspeed"
601
+ TMPDIR = os.environ.get("TMPDIR")
474
602
 
475
603
  def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
476
604
  super().__init__(code_dir)
477
- self.update_os()
478
-
479
- def update_os(self):
480
- # Generate SSH host keys for SSH server
481
- self.run_command("sudo ssh-keygen -A", level=logging.DEBUG, check=True)
482
- # Install SSH server to accept SSH connections
483
- # DeepSpeed uses "hostname -I" to determine the IP address
484
- # pdsh is required for default multi node training
485
- # torch cpp extension uses which command to find compiler
486
- # DeepSpeed async_io requires libaio-devel
605
+ logger.debug("Initializing DeepSpeed Runner...")
606
+ # Set up DeepSpeed if it is used.
607
+ if self.use_deepspeed():
608
+ self.host_key = None
609
+ self.deepspeed_setup()
610
+
611
+ def use_deepspeed(self):
612
+ """Indicate if DeepSpeed is used."""
613
+ # Accelerate support using DeepSpeed by adding the "--use_deepspeed" argument.
614
+ return bool(
615
+ os.environ.get(CONST_ENV_DEEPSPEED)
616
+ or self.launch_cmd_contains("use_deepspeed")
617
+ or self.launch_cmd_contains("deepspeed")
618
+ )
619
+
620
+ def deepspeed_setup(self):
621
+ """Setup for DeepSpeed."""
622
+ self.host_key = HOST_KEY_PATH_RSA if os.path.exists(HOST_KEY_PATH_RSA) else None
623
+ # Create the temp dir if one does not exist.
624
+ # This is needed for JIT
625
+ if self.TMPDIR and not os.path.isdir(self.TMPDIR):
626
+ logger.info("Creating temp directory: %s", self.TMPDIR)
627
+ os.makedirs(self.TMPDIR, exist_ok=True)
628
+ self.install_deepspeed_dependencies()
629
+ # host_job_run is needed for DeepSpeed to fetch the public SSH key from the logs.
630
+ if self.host_ocid and self.node_count > 1:
631
+ self.host_job_run = DataScienceJobRun.from_ocid(self.host_ocid)
632
+
633
+ def install_epel(self):
634
+ """Installs oracle-epel-release."""
635
+ for ol_version in ["8", "9"]:
636
+ if (
637
+ self.run_command(
638
+ f'cat /etc/oracle-release | grep "release {ol_version}"',
639
+ level=logging.DEBUG,
640
+ )
641
+ == 0
642
+ ):
643
+ self.run_command(
644
+ f"sudo --preserve-env microdnf install -y oracle-epel-release-el{ol_version}"
645
+ )
646
+ break
647
+
648
+ def _print_host_key(self, host_key_path, prefix):
649
+ with open(host_key_path, encoding=CONST_ENCODING) as f:
650
+ public_key = f.read()
651
+ print(f"{prefix}{self.ip}-{public_key}")
652
+
653
+ def _add_known_hosts_from_file(self, ip_addr, key_file):
654
+ if not os.path.exists(key_file):
655
+ logger.warning(
656
+ "Unable to add host key %s to known_hosts: key file not found.",
657
+ key_file,
658
+ )
659
+ return
487
660
  self.run_command(
488
- "sudo --preserve-env yum install -y openssh-server hostname pdsh which libaio-devel",
661
+ f"echo -n '{ip_addr} ' | " f"cat - {key_file} >> {SSH_DIR}/known_hosts",
489
662
  level=logging.DEBUG,
490
663
  check=True,
491
664
  )
665
+
666
+ def _add_known_hosts_from_log(self, job_run, prefix, ip_address=None):
667
+ ip_key = self.wait_for_log(job_run, f"{prefix}")
668
+ ip_addr, public_key = ip_key.split("-", 1)
669
+ if ip_address:
670
+ ip_addr = ip_address
671
+ with open(f"{SSH_DIR}/known_hosts", "a+", encoding=CONST_ENCODING) as f:
672
+ line = f"{ip_addr} {public_key}"
673
+ f.write(f"{line}\n")
674
+ logger.debug("Added host key: %s", line)
675
+
676
+ def install_deepspeed_dependencies(self):
677
+ """Installs extra dependencies and start SSH service."""
678
+ if self.node_count == 1:
679
+ logger.debug(
680
+ "Skipped installing extra dependencies for single node training."
681
+ )
682
+ return
683
+
684
+ # Check if host keys exist
685
+ if self.host_key:
686
+ logger.debug(
687
+ "Skipped SSH host key generation.\nHost keys found: %s", self.host_key
688
+ )
689
+ else:
690
+ # Generate SSH host keys for SSH server
691
+ self.run_command("sudo ssh-keygen -A", level=logging.DEBUG, check=True)
692
+ self._print_host_key(HOST_KEY_PATH_RSA, LOG_PREFIX_HOST_KEY_RSA)
693
+ self._print_host_key(HOST_KEY_PATH_ECDSA, LOG_PREFIX_HOST_KEY_ECDSA)
694
+
695
+ if self.run_command("which pdsh", level=logging.DEBUG) != 0:
696
+ # Install "openssh-server" to accept SSH connections
697
+ # DeepSpeed uses "hostname -I" to determine the IP address
698
+ # "pdsh" is required for default multi node training
699
+ # torch cpp extension uses "which" command to find compiler
700
+ # DeepSpeed async_io requires "libaio-devel"
701
+ if self.run_command("which microdnf", level=logging.DEBUG) == 0:
702
+ self.install_epel()
703
+ self.run_command(
704
+ "sudo --preserve-env microdnf install -y openssh-server hostname pdsh pdsh-rcmd-ssh libaio-devel",
705
+ level=logging.DEBUG,
706
+ check=True,
707
+ )
708
+ elif self.run_command("which yum", level=logging.DEBUG) == 0:
709
+ self.run_command(
710
+ "sudo --preserve-env yum install -y openssh-server hostname pdsh which libaio-devel",
711
+ level=logging.DEBUG,
712
+ check=True,
713
+ )
492
714
  # Start SSH service
493
715
  self.run_command("sudo /usr/sbin/sshd", level=logging.DEBUG, check=True)
494
716
 
@@ -496,15 +718,13 @@ class DeepSpeedRunner(Runner):
496
718
  self.run_command(
497
719
  "ssh-keygen -q -t rsa -N '' <<< $'\ny'", level=logging.DEBUG, check=True
498
720
  )
499
- with open(os.path.join(SSH_DIR, "id_rsa.pub"), "r", encoding="utf-8") as f:
721
+ with open(os.path.join(SSH_DIR, "id_rsa.pub"), encoding=CONST_ENCODING) as f:
500
722
  public_key = f.read()
501
- print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}")
502
- self.add_authoried_key(public_key)
503
- self.run_command(
504
- f"ssh-keyscan -H {self.host_ip} >> {SSH_DIR}/known_hosts",
505
- level=logging.DEBUG,
506
- check=True,
507
- )
723
+ print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}", flush=True)
724
+ self._add_authoried_key(public_key)
725
+ # Add host key to known hosts
726
+ self._add_known_hosts_from_file(self.host_ip, HOST_KEY_PATH_RSA)
727
+ self._add_known_hosts_from_file(self.host_ip, HOST_KEY_PATH_ECDSA)
508
728
  self.test_ssh_connection(self.host_ip)
509
729
  # Check DeepSpeed compatibility
510
730
  self.run_command(
@@ -512,64 +732,70 @@ class DeepSpeedRunner(Runner):
512
732
  )
513
733
  return self
514
734
 
515
- @staticmethod
516
- def add_authoried_key(public_key):
735
+ def _add_authoried_key(self, public_key):
517
736
  auth_keys_file = os.path.join(SSH_DIR, "authorized_keys")
518
737
  os.makedirs(SSH_DIR, exist_ok=True)
519
- with open(auth_keys_file, "a+", encoding="utf-8") as f:
738
+ with open(auth_keys_file, "a+", encoding=CONST_ENCODING) as f:
520
739
  f.write(public_key)
521
740
  f.write("\n")
522
- logger.debug("Public key saved to %s", auth_keys_file)
741
+ logger.debug("Public key saved to %s:%s", self.ip, auth_keys_file)
523
742
 
524
743
  def fetch_host_public_key(self):
525
744
  public_key = self.wait_for_log(self.host_job_run, LOG_PREFIX_PUBLIC_KEY)
526
- print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}")
527
- # logger.debug("%s", LOG_PREFIX_PUBLIC_KEY + public_key)
528
- self.add_authoried_key(public_key)
745
+ print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}", flush=True)
746
+ self._add_authoried_key(public_key)
529
747
 
530
748
  def generate_hostfile(self):
531
- runs = self.host_job_run.job.run_list()
532
- self.node_runs = [
533
- run
534
- for run in runs
535
- if run.status in ["ACCEPTED", "IN_PROGRESS"] and run.id != self.host_ocid
536
- ]
537
- self.node_ip_list = [self.wait_for_ip_address(run) for run in self.node_runs]
749
+ if not self.node_ip_list:
750
+ runs = self.host_job_run.job.run_list()
751
+ self.node_runs = [
752
+ run
753
+ for run in runs
754
+ if run.status in ["ACCEPTED", "IN_PROGRESS"]
755
+ and run.id != self.host_ocid
756
+ ]
757
+ self.node_ip_list = [
758
+ self.wait_for_ip_address(run) for run in self.node_runs
759
+ ]
538
760
  logger.info("Node IPs: %s", self.node_ip_list)
539
761
  # Hostfile
540
762
  logger.debug("Writing hostfile to %s", self.HOST_FILE)
541
763
  os.makedirs(os.path.dirname(self.HOST_FILE), exist_ok=True)
542
- host_file_content = [f"{ip} slots={self.gpu_count}" for ip in self.node_ip_list]
543
- with open(self.HOST_FILE, "w", encoding="utf-8") as f:
544
- f.write(f"{self.host_ip} slots={self.gpu_count}\n")
764
+ host_file_content = [
765
+ f"{ip} slots={self.gpu_count}\n" for ip in self.node_ip_list
766
+ ]
767
+ with open(self.HOST_FILE, "w", encoding=CONST_ENCODING) as f:
768
+ if self.host_ip not in self.node_ip_list:
769
+ f.write(f"{self.host_ip} slots={self.gpu_count}\n")
545
770
  f.writelines(host_file_content)
546
771
  self.run_command(f"cat {self.HOST_FILE}", level=logging.DEBUG)
547
772
  # SSH config
548
773
  ssh_config_path = os.path.join(SSH_DIR, "config")
549
774
  logger.debug("Writing SSH config to %s", ssh_config_path)
550
- with open(ssh_config_path, "w", encoding="utf-8") as f:
775
+ with open(ssh_config_path, "w", encoding=CONST_ENCODING) as f:
551
776
  f.writelines(
552
777
  [
553
- "",
554
- f"Host {self.host_ip}",
555
- "IdentityFile /home/datascience/.ssh/id_rsa",
556
- "User datascience",
778
+ "\n",
779
+ f"Host {self.host_ip}\n",
780
+ "KexAlgorithms diffie-hellman-group-exchange-sha256\n",
557
781
  ]
558
782
  )
559
783
  for node_ip in self.node_ip_list:
784
+ if node_ip == self.host_ip:
785
+ continue
560
786
  f.writelines(
561
787
  [
562
- "",
563
- f"Host {node_ip}",
564
- "IdentityFile /home/datascience/.ssh/id_rsa",
565
- "User datascience",
788
+ "\n",
789
+ f"Host {node_ip}\n",
790
+ "KexAlgorithms diffie-hellman-group-exchange-sha256\n",
566
791
  ]
567
792
  )
793
+ self.run_command(f"cat {ssh_config_path}", level=logging.DEBUG)
568
794
  return self
569
795
 
570
796
  def test_ssh_connection(self, host):
571
797
  ret = self.run_command(
572
- f"ssh -v -o PasswordAuthentication=no {host} hostname -I",
798
+ f"ssh -vvv -o PasswordAuthentication=no {host} hostname -I",
573
799
  level=logging.DEBUG,
574
800
  )
575
801
  if ret == 0:
@@ -582,9 +808,8 @@ class DeepSpeedRunner(Runner):
582
808
  for node_ip in self.node_ip_list:
583
809
  logger.debug("Sending stop file to %s", node_ip)
584
810
  self.run_command(
585
- f"ssh -v {node_ip} 'touch {filename}'",
811
+ f"ssh -v -o PasswordAuthentication=no {node_ip} 'touch {filename}'",
586
812
  level=logging.DEBUG,
587
- check=True,
588
813
  )
589
814
 
590
815
  def save_deepspeed_env(self):
@@ -593,31 +818,55 @@ class DeepSpeedRunner(Runner):
593
818
  the environment variables configured by the job runs are not propagated to the SSH session.
594
819
  DeepSpeed will load the environment variables from file for the SSH sessions.
595
820
  """
596
- with open(self.ENV_FILE, mode="w", encoding="utf-8") as f:
821
+ import deepspeed
822
+
823
+ try:
824
+ version = deepspeed.__version__
825
+ minor_version = int(version.split(".")[1])
826
+ except Exception:
827
+ version = 0
828
+ minor_version = 0
829
+
830
+ with open(self.ENV_FILE, mode="w", encoding=CONST_ENCODING) as f:
597
831
  for k, v in os.environ.items():
598
- # As of deepspeed==0.9.2, empty value or line break will cause parsing error,
832
+ # Empty value or line break may cause parsing error,
599
833
  # as the .deepspeed_env file is parsed line by line.
600
834
  if not v or "\n" in v:
835
+ logger.debug("Skipped saving %s as deepspeed env.", k)
601
836
  continue
602
837
  # Ignore variables that are node specific
603
- # The network interface name for each job run is a unique string, e.g. ens300f0v1604
604
- if k in ["NCCL_SOCKET_IFNAME", "GLOO_SOCKET_IFNAME", "JOB_RUN_OCID"]:
838
+ # The network interface name for each job run could be a unique string, e.g. ens300f0v1604
839
+ # Deepspeed will copy the SOCKET_IFNAME values to all nodes if they are set.
840
+ if k in [
841
+ "NCCL_SOCKET_IFNAME",
842
+ "GLOO_SOCKET_IFNAME",
843
+ "JOB_RUN_OCID",
844
+ "NODE_RANK",
845
+ ]:
846
+ logger.debug("Skipped saving %s as deepspeed env.", k)
605
847
  continue
606
- # Quote the value if it contains space
607
- # Environment variable containing space may not be exported correctly when using pdsh
608
- # https://github.com/microsoft/DeepSpeed/blob/v0.9.2/deepspeed/launcher/multinode_runner.py#L79
609
- if " " in v:
848
+ # For DeepSpeed < 0.15.2, no extra quotes are added by DeepSpeed
849
+ # shelex.quote() will make sure the variable is exported correctly.
850
+ if minor_version < 15 or version in ["0.15.1", "0.15.0"]:
610
851
  v = shlex.quote(v)
611
852
 
853
+ # As of v0.16.4, DeepSpeed wraps the value in double quotes.
854
+ # Escape the double quotes so that they can be exported correctly.
855
+ # This logic may need to be updated with the future version of DeepSpeed.
856
+ # https://github.com/deepspeedai/DeepSpeed/blob/v0.16.4/deepspeed/launcher/multinode_runner.py#L37
857
+ # https://github.com/deepspeedai/DeepSpeed/blob/v0.16.4/deepspeed/launcher/multinode_runner.py#L90
858
+ # https://github.com/deepspeedai/DeepSpeed/pull/5878
859
+ # https://github.com/deepspeedai/DeepSpeed/pull/7071
860
+ elif '"' in v:
861
+ v = v.replace('"', '\\"')
862
+
612
863
  f.write(f"{k}={v}\n")
613
- # The following are required for specifying the network interface to be used by NCCL/GLOO
614
- # The value should be the prefix of the expected network interface name
615
- # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-socket-ifname
616
- f.write("NCCL_SOCKET_IFNAME=ens\n")
617
- f.write("GLOO_SOCKET_IFNAME=ens\n")
618
864
  logger.debug("Environment variables saved to %s", self.ENV_FILE)
619
865
  self.run_command(f"cat {self.ENV_FILE}")
620
866
 
867
+ def wait_for_nodes(self):
868
+ pass
869
+
621
870
  def run_deepspeed_host(self, launch_args=None):
622
871
  """Prepares the host and launch the deepspeed training.
623
872
 
@@ -633,15 +882,41 @@ class DeepSpeedRunner(Runner):
633
882
  self.generate_key_pair().generate_hostfile()
634
883
  self.save_deepspeed_env()
635
884
  # Wait for nodes to be ready
636
- for run in self.node_runs:
637
- self.wait_for_log(run, LOG_PREFIX_PUBLIC_KEY)
638
-
639
- for node_ip in self.node_ip_list:
640
- self.run_command(
641
- f"ssh-keyscan -H {node_ip} >> {SSH_DIR}/known_hosts",
642
- level=logging.DEBUG,
643
- check=True,
885
+ # For DTv2, self.node_runs will be None
886
+ if self.is_dtv2():
887
+ self.wait_for_log(
888
+ self.host_job_run, LOG_PREFIX_PUBLIC_KEY, limit=self.node_count
644
889
  )
890
+ else:
891
+ for run in self.node_runs:
892
+ self.wait_for_log(run, LOG_PREFIX_PUBLIC_KEY)
893
+
894
+ if self.host_key:
895
+ # If host key exists, it should be the same for all nodes.
896
+ for node_ip in self.node_ip_list:
897
+ self._add_known_hosts_from_file(node_ip, HOST_KEY_PATH_RSA)
898
+ self._add_known_hosts_from_file(node_ip, HOST_KEY_PATH_ECDSA)
899
+ elif self.is_dtv2():
900
+ # If the host key does not exist, it is generated on the fly,
901
+ # Each node will have a different key.
902
+ # We will need to check the logs for the public key.
903
+ logger.debug("Adding node host keys to known_hosts...")
904
+ for node_ip in self.node_ip_list:
905
+ self._add_known_hosts_from_log(
906
+ self.host_job_run,
907
+ LOG_PREFIX_HOST_KEY_RSA + node_ip,
908
+ ip_address=node_ip,
909
+ )
910
+ self._add_known_hosts_from_log(
911
+ self.host_job_run,
912
+ LOG_PREFIX_HOST_KEY_ECDSA + node_ip,
913
+ ip_address=node_ip,
914
+ )
915
+ else:
916
+ logger.debug("Adding job run host keys to known_hosts...")
917
+ for run in self.node_runs:
918
+ self._add_known_hosts_from_log(run, LOG_PREFIX_HOST_KEY_RSA)
919
+ self._add_known_hosts_from_log(run, LOG_PREFIX_HOST_KEY_ECDSA)
645
920
 
646
921
  cmd = self.prepare_cmd(launch_args)
647
922
  # For DeepSpeed, we only need to run the cmd on the host
@@ -663,6 +938,9 @@ class DeepSpeedRunner(Runner):
663
938
  if os.path.exists(self.ERROR_FILE):
664
939
  logger.error("There is an error in the host job run.")
665
940
  sys.exit(1)
941
+ # Check host job run only if it is not None
942
+ if self.host_job_run is None:
943
+ continue
666
944
  # Stop the node if the host job run is CANCELLED or in unexpected state.
667
945
  try:
668
946
  self.host_job_run.sync()
@@ -693,23 +971,23 @@ class DeepSpeedRunner(Runner):
693
971
 
694
972
 
695
973
  class GenericRunner(TorchRunner, DeepSpeedRunner):
696
- """Runner for running command other than ``torchrun``, ``deepspeed`` or ``accelerate``."""
974
+ """Runner for running command that may use ``torchrun`` or ``deepspeed``."""
697
975
 
698
976
  LAUNCHER = ""
699
977
 
700
- def use_deepspeed(self) -> bool:
701
- """Indicate if DeepSpeed is used."""
702
- if os.environ.get(CONST_ENV_DEEPSPEED):
703
- return True
704
- return False
978
+ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
979
+ super().__init__(code_dir)
980
+ logger.debug("Initializing Generic Runner...")
705
981
 
706
982
  def set_env_var(self):
707
983
  """Set default environment variables."""
708
984
  defaults = {
709
- "WORLD_SIZE": self.node_count * self.gpu_count,
985
+ CONST_ENV_WORLD_SIZE: self.node_count * self.gpu_count,
710
986
  "MASTER_ADDR": self.host_ip,
711
987
  "MASTER_PORT": self.RDZV_PORT,
712
988
  }
989
+ if self.node_count == 1:
990
+ defaults["RANK"] = 0
713
991
  for k, v in defaults.items():
714
992
  if k not in os.environ:
715
993
  os.environ[k] = str(v)
@@ -734,7 +1012,7 @@ class GenericRunner(TorchRunner, DeepSpeedRunner):
734
1012
  self.time_cmd(cmd=self.prepare_cmd(prefix=self.env_ld_preload()))
735
1013
 
736
1014
 
737
- class AccelerateRunner(TorchRunner, DeepSpeedRunner):
1015
+ class AccelerateRunner(GenericRunner):
738
1016
  """Runner for HuggingFace Accelerate."""
739
1017
 
740
1018
  # accelerate launch will add main_process_port for deepspeed cmd even if it is not needed.
@@ -750,14 +1028,18 @@ class AccelerateRunner(TorchRunner, DeepSpeedRunner):
750
1028
  LAUNCHER = "accelerate launch"
751
1029
 
752
1030
  def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
1031
+ # Here we need to call GenericRunner.__init__() explicitly
1032
+ # to avoid calling the DeepSpeedRunner.__init__().
753
1033
  super().__init__(code_dir)
1034
+ logger.debug("Initializing Accelerate Runner...")
754
1035
  # For "accelerate launch", only one of the following options can be used at one time
755
1036
  # `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`.
756
1037
  # When a config file is not provided,
757
1038
  # --multi_gpu will be set automatically if there is more than 1 GPU
758
1039
  # self.multi_gpu = bool(self.node_count > 1 or self.gpu_count > 1)
759
1040
  self.num_machines = self.node_count
760
- self.machine_rank = os.environ["NODE_RANK"]
1041
+ # Machine rank is needed for accelerate launch to work correctly
1042
+ self.machine_rank = self.node_rank
761
1043
  # Total number of processes across all nodes
762
1044
  # Here we assume all nodes are having the same shape
763
1045
  self.num_processes = (self.gpu_count if self.gpu_count else 1) * self.node_count
@@ -766,15 +1048,6 @@ class AccelerateRunner(TorchRunner, DeepSpeedRunner):
766
1048
  # Host IP is not ready at initialization
767
1049
  self.main_process_ip = None
768
1050
 
769
- def use_deepspeed(self):
770
- """Indicate if DeepSpeed is used."""
771
- # Accelerate support using DeepSpeed by adding the "--use_deepspeed" argument.
772
- if os.environ.get(CONST_ENV_DEEPSPEED) or self.launch_cmd_contains(
773
- "use_deepspeed"
774
- ):
775
- return True
776
- return False
777
-
778
1051
  def accelerate_args(self):
779
1052
  """Gets the default arguments for the accelerate command.
780
1053
  The value of the default arguments are assigned in ``__init__()``.
@@ -785,8 +1058,11 @@ class AccelerateRunner(TorchRunner, DeepSpeedRunner):
785
1058
  logger.debug("%s=%s", arg, arg_val)
786
1059
  if arg_val is True:
787
1060
  args.append(f"--{arg}")
788
- elif arg_val:
1061
+ elif arg_val is not None:
789
1062
  args.extend([f"--{arg}", str(arg_val)])
1063
+ # --use_deepspeed is needed for deepspeed to work on a single GPU
1064
+ if self.use_deepspeed() and not self.launch_cmd_contains("use_deepspeed"):
1065
+ args.append("--use_deepspeed")
790
1066
  return args
791
1067
 
792
1068
  def run_with_torchrun(self):
@@ -822,6 +1098,23 @@ class AccelerateRunner(TorchRunner, DeepSpeedRunner):
822
1098
 
823
1099
 
824
1100
  def main():
1101
+ # Collect GPU metrics only if GPU is available and user defined METRIC_NAMESPACE
1102
+ if METRIC_NAMESPACE and torch.cuda.device_count():
1103
+ p = multiprocessing.Process(target=collect_metrics)
1104
+ p.daemon = True
1105
+ p.start()
1106
+
1107
+ # Merge the CLI Arguments with CMD specified in env var
1108
+ if len(sys.argv) > 1:
1109
+ # Expand the environment variables before shlex.join
1110
+ # as it will quote the arg with single quotes.
1111
+ argv = [os.path.expandvars(arg) for arg in sys.argv[1:]]
1112
+ if os.environ.get(CONST_ENV_LAUNCH_CMD):
1113
+ os.environ[CONST_ENV_LAUNCH_CMD] = (
1114
+ shlex.join(argv) + " " + os.environ.get(CONST_ENV_LAUNCH_CMD)
1115
+ )
1116
+ else:
1117
+ os.environ[CONST_ENV_LAUNCH_CMD] = shlex.join(argv)
825
1118
  launch_cmd = os.environ.get(CONST_ENV_LAUNCH_CMD)
826
1119
  if not launch_cmd or launch_cmd.startswith("torchrun "):
827
1120
  # Use torchrun as default if launch cmd is not provided
@@ -832,21 +1125,42 @@ def main():
832
1125
  runner_class = AccelerateRunner
833
1126
  else:
834
1127
  runner_class = GenericRunner
835
-
1128
+ logger.debug("Using %s", str(runner_class))
836
1129
  runner = runner_class()
1130
+
837
1131
  runner: Runner
838
1132
  runner.fetch_code().set_working_dir().setup_python_path().install_dependencies()
839
1133
 
840
1134
  driver_utils.OCIHelper.copy_inputs()
841
-
842
- runner.wait_for_host_ip_address().run()
1135
+ if not runner.host_ip:
1136
+ runner.wait_for_host_ip_address()
1137
+ runner.run()
843
1138
  driver_utils.OCIHelper.copy_outputs()
1139
+ logger.info("Job finished with exit code 0")
1140
+ sys.exit(0)
1141
+
1142
+
1143
+ def save_job_run_logs(output_uri=os.environ.get(CONST_ENV_LOG_OUTPUT)):
1144
+ """Saves the job run logs to a file in output_uri."""
1145
+ if not output_uri:
1146
+ return
1147
+ if CONST_ENV_HOST_JOB_RUN_OCID not in os.environ:
1148
+ return
1149
+
1150
+ job_run_ocid = os.environ[CONST_ENV_HOST_JOB_RUN_OCID]
1151
+ log_uri = os.path.join(output_uri, job_run_ocid + ".log")
1152
+ # Wait for the job logs to be available in logging service
1153
+ logger.debug("Saving job run logs to %s", log_uri)
1154
+ time.sleep(60)
1155
+ try:
1156
+ job_run = DataScienceJobRun.from_ocid(job_run_ocid)
1157
+ with fsspec.open(log_uri, "w") as f:
1158
+ for log in job_run.logs():
1159
+ f.write(f"{log.get('message', '')}\n")
1160
+ except Exception:
1161
+ logger.error("Failed to save the job run logs to %s", log_uri)
1162
+ logger.debug(traceback.format_exc())
844
1163
 
845
1164
 
846
1165
  if __name__ == "__main__":
847
- # Collect GPU metrics only if GPU is available and user defined METRIC_NAMESPACE
848
- if METRIC_NAMESPACE and torch.cuda.device_count():
849
- p = multiprocessing.Process(target=collect_metrics)
850
- p.daemon = True
851
- p.start()
852
1166
  main()