oracle-ads 2.13.10rc0__py3-none-any.whl → 2.13.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. ads/aqua/app.py +13 -7
  2. ads/aqua/cli.py +15 -0
  3. ads/aqua/common/entities.py +31 -5
  4. ads/aqua/common/utils.py +35 -0
  5. ads/aqua/config/container_config.py +0 -1
  6. ads/aqua/evaluation/evaluation.py +5 -4
  7. ads/aqua/extension/deployment_handler.py +4 -1
  8. ads/aqua/extension/model_handler.py +1 -1
  9. ads/aqua/model/enums.py +19 -1
  10. ads/aqua/model/model.py +45 -36
  11. ads/aqua/model/utils.py +1 -2
  12. ads/aqua/modeldeployment/config_loader.py +815 -0
  13. ads/aqua/modeldeployment/constants.py +4 -1
  14. ads/aqua/modeldeployment/deployment.py +100 -124
  15. ads/aqua/modeldeployment/entities.py +4 -178
  16. ads/aqua/modeldeployment/model_group_config.py +240 -0
  17. ads/aqua/modeldeployment/utils.py +0 -539
  18. ads/common/work_request.py +39 -38
  19. ads/jobs/builders/infrastructure/dsc_job.py +121 -24
  20. ads/jobs/builders/infrastructure/dsc_job_runtime.py +71 -24
  21. ads/jobs/builders/runtimes/base.py +7 -5
  22. ads/jobs/builders/runtimes/pytorch_runtime.py +6 -8
  23. ads/jobs/templates/driver_pytorch.py +486 -172
  24. ads/jobs/templates/driver_utils.py +27 -11
  25. ads/model/service/oci_datascience_model_deployment.py +6 -11
  26. ads/telemetry/client.py +4 -4
  27. {oracle_ads-2.13.10rc0.dist-info → oracle_ads-2.13.12.dist-info}/METADATA +2 -2
  28. {oracle_ads-2.13.10rc0.dist-info → oracle_ads-2.13.12.dist-info}/RECORD +31 -29
  29. {oracle_ads-2.13.10rc0.dist-info → oracle_ads-2.13.12.dist-info}/WHEEL +0 -0
  30. {oracle_ads-2.13.10rc0.dist-info → oracle_ads-2.13.12.dist-info}/entry_points.txt +0 -0
  31. {oracle_ads-2.13.10rc0.dist-info → oracle_ads-2.13.12.dist-info}/licenses/LICENSE.txt +0 -0
@@ -1,7 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
-
4
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
5
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
4
 
7
5
  import logging
@@ -12,6 +10,7 @@ from typing import Callable
12
10
  import oci
13
11
  from oci import Signer
14
12
  from tqdm.auto import tqdm
13
+
15
14
  from ads.common.oci_datascience import OCIDataScienceMixin
16
15
 
17
16
  logger = logging.getLogger(__name__)
@@ -20,10 +19,10 @@ WORK_REQUEST_STOP_STATE = ("SUCCEEDED", "FAILED", "CANCELED")
20
19
  DEFAULT_WAIT_TIME = 1200
21
20
  DEFAULT_POLL_INTERVAL = 10
22
21
  WORK_REQUEST_PERCENTAGE = 100
23
- # default tqdm progress bar format:
22
+ # default tqdm progress bar format:
24
23
  # {l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]
25
24
  # customize the bar format to remove the {n_fmt}/{total_fmt} from the right side
26
- DEFAULT_BAR_FORMAT = '{l_bar}{bar}| [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]'
25
+ DEFAULT_BAR_FORMAT = "{l_bar}{bar}| [{elapsed}<{remaining}, " "{rate_fmt}{postfix}]"
27
26
 
28
27
 
29
28
  class DataScienceWorkRequest(OCIDataScienceMixin):
@@ -32,13 +31,13 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
32
31
  """
33
32
 
34
33
  def __init__(
35
- self,
36
- id: str,
34
+ self,
35
+ id: str,
37
36
  description: str = "Processing",
38
- config: dict = None,
39
- signer: Signer = None,
40
- client_kwargs: dict = None,
41
- **kwargs
37
+ config: dict = None,
38
+ signer: Signer = None,
39
+ client_kwargs: dict = None,
40
+ **kwargs,
42
41
  ) -> None:
43
42
  """Initializes DataScienceWorkRequest object.
44
43
 
@@ -49,41 +48,43 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
49
48
  description: str
50
49
  Progress bar initial step description (Defaults to `Processing`).
51
50
  config : dict, optional
52
- OCI API key config dictionary to initialize
51
+ OCI API key config dictionary to initialize
53
52
  oci.data_science.DataScienceClient (Defaults to None).
54
53
  signer : oci.signer.Signer, optional
55
- OCI authentication signer to initialize
54
+ OCI authentication signer to initialize
56
55
  oci.data_science.DataScienceClient (Defaults to None).
57
56
  client_kwargs : dict, optional
58
- Additional client keyword arguments to initialize
57
+ Additional client keyword arguments to initialize
59
58
  oci.data_science.DataScienceClient (Defaults to None).
60
59
  kwargs:
61
- Additional keyword arguments to initialize
60
+ Additional keyword arguments to initialize
62
61
  oci.data_science.DataScienceClient.
63
62
  """
64
63
  self.id = id
65
64
  self._description = description
66
65
  self._percentage = 0
67
66
  self._status = None
67
+ self._error_message = ""
68
68
  super().__init__(config, signer, client_kwargs, **kwargs)
69
-
70
69
 
71
70
  def _sync(self):
72
71
  """Fetches the latest work request information into this DataScienceWorkRequest object."""
73
72
  work_request = self.client.get_work_request(self.id).data
74
- work_request_logs = self.client.list_work_request_logs(
75
- self.id
76
- ).data
73
+ work_request_logs = self.client.list_work_request_logs(self.id).data
77
74
 
78
- self._percentage= work_request.percent_complete
75
+ self._percentage = work_request.percent_complete
79
76
  self._status = work_request.status
80
- self._description = work_request_logs[-1].message if work_request_logs else "Processing"
77
+ self._description = (
78
+ work_request_logs[-1].message if work_request_logs else "Processing"
79
+ )
80
+ if work_request.status == "FAILED":
81
+ self._error_message = self.client.list_work_request_errors(self.id).data
81
82
 
82
83
  def watch(
83
- self,
84
+ self,
84
85
  progress_callback: Callable,
85
- max_wait_time: int=DEFAULT_WAIT_TIME,
86
- poll_interval: int=DEFAULT_POLL_INTERVAL,
86
+ max_wait_time: int = DEFAULT_WAIT_TIME,
87
+ poll_interval: int = DEFAULT_POLL_INTERVAL,
87
88
  ):
88
89
  """Updates the progress bar with realtime message and percentage until the process is completed.
89
90
 
@@ -92,10 +93,10 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
92
93
  progress_callback: Callable
93
94
  Progress bar callback function.
94
95
  It must accept `(percent_change, description)` where `percent_change` is the change in the
95
- work request percent complete and `description` is the latest work request log message.
96
+ work request percent complete and `description` is the latest work request log message.
96
97
  max_wait_time: int
97
98
  Maximum amount of time to wait in seconds (Defaults to 1200).
98
- Negative implies infinite wait time.
99
+ Negative implies infinite wait time.
99
100
  poll_interval: int
100
101
  Poll interval in seconds (Defaults to 10).
101
102
 
@@ -107,7 +108,6 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
107
108
 
108
109
  start_time = time.time()
109
110
  while self._percentage < 100:
110
-
111
111
  seconds_since = time.time() - start_time
112
112
  if max_wait_time > 0 and seconds_since >= max_wait_time:
113
113
  logger.error(f"Exceeded max wait time of {max_wait_time} seconds.")
@@ -124,12 +124,14 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
124
124
  percent_change = self._percentage - previous_percent_complete
125
125
  previous_percent_complete = self._percentage
126
126
  progress_callback(
127
- percent_change=percent_change,
128
- description=self._description
127
+ percent_change=percent_change, description=self._description
129
128
  )
130
129
 
131
130
  if self._status in WORK_REQUEST_STOP_STATE:
132
- if self._status != oci.work_requests.models.WorkRequest.STATUS_SUCCEEDED:
131
+ if (
132
+ self._status
133
+ != oci.work_requests.models.WorkRequest.STATUS_SUCCEEDED
134
+ ):
133
135
  if self._description:
134
136
  raise Exception(self._description)
135
137
  else:
@@ -145,12 +147,12 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
145
147
 
146
148
  def wait_work_request(
147
149
  self,
148
- progress_bar_description: str="Processing",
149
- max_wait_time: int=DEFAULT_WAIT_TIME,
150
- poll_interval: int=DEFAULT_POLL_INTERVAL
150
+ progress_bar_description: str = "Processing",
151
+ max_wait_time: int = DEFAULT_WAIT_TIME,
152
+ poll_interval: int = DEFAULT_POLL_INTERVAL,
151
153
  ):
152
154
  """Waits for the work request progress bar to be completed.
153
-
155
+
154
156
  Parameters
155
157
  ----------
156
158
  progress_bar_description: str
@@ -160,7 +162,7 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
160
162
  Negative implies infinite wait time.
161
163
  poll_interval: int
162
164
  Poll interval in seconds (Defaults to 10).
163
-
165
+
164
166
  Returns
165
167
  -------
166
168
  None
@@ -172,7 +174,7 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
172
174
  mininterval=0,
173
175
  file=sys.stdout,
174
176
  desc=progress_bar_description,
175
- bar_format=DEFAULT_BAR_FORMAT
177
+ bar_format=DEFAULT_BAR_FORMAT,
176
178
  ) as pbar:
177
179
 
178
180
  def progress_callback(percent_change, description):
@@ -184,6 +186,5 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
184
186
  self.watch(
185
187
  progress_callback=progress_callback,
186
188
  max_wait_time=max_wait_time,
187
- poll_interval=poll_interval
189
+ poll_interval=poll_interval,
188
190
  )
189
-
@@ -1,7 +1,6 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
2
 
4
- # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2021, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
  from __future__ import annotations
7
6
 
@@ -21,30 +20,33 @@ import fsspec
21
20
  import oci
22
21
  import oci.data_science
23
22
  import oci.util as oci_util
23
+ import yaml
24
+ from oci.data_science import models
24
25
  from oci.data_science.models import JobInfrastructureConfigurationDetails
25
26
  from oci.exceptions import ServiceError
26
- import yaml
27
+
27
28
  from ads.common import utils
29
+ from ads.common.decorator.utils import class_or_instance_method
30
+ from ads.common.dsc_file_system import (
31
+ DSCFileSystemManager,
32
+ OCIFileStorage,
33
+ OCIObjectStorage,
34
+ )
28
35
  from ads.common.oci_datascience import DSCNotebookSession, OCIDataScienceMixin
29
36
  from ads.common.oci_logging import OCILog
30
37
  from ads.common.oci_resource import ResourceNotFoundError
31
38
  from ads.jobs.builders.infrastructure.base import Infrastructure, RunInstance
32
39
  from ads.jobs.builders.infrastructure.dsc_job_runtime import (
40
+ MULTI_NODE_JOB_SUPPORT,
33
41
  ContainerRuntimeHandler,
34
42
  DataScienceJobRuntimeManager,
35
43
  )
36
44
  from ads.jobs.builders.infrastructure.utils import get_value
37
45
  from ads.jobs.builders.runtimes.artifact import Artifact
46
+ from ads.jobs.builders.runtimes.base import MultiNodeRuntime
38
47
  from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
39
48
  from ads.jobs.builders.runtimes.python_runtime import GitPythonRuntime
40
49
 
41
- from ads.common.dsc_file_system import (
42
- OCIFileStorage,
43
- DSCFileSystemManager,
44
- OCIObjectStorage,
45
- )
46
- from ads.common.decorator.utils import class_or_instance_method
47
-
48
50
  logger = logging.getLogger(__name__)
49
51
 
50
52
  SLEEP_INTERVAL = 3
@@ -52,6 +54,7 @@ WAIT_SECONDS_AFTER_FINISHED = 90
52
54
  MAXIMUM_MOUNT_COUNT = 5
53
55
  FILE_STORAGE_TYPE = "FILE_STORAGE"
54
56
  OBJECT_STORAGE_TYPE = "OBJECT_STORAGE"
57
+ DEFAULT_NODE_GROUP_NAME = "node-group"
55
58
 
56
59
 
57
60
  class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
@@ -284,11 +287,15 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
284
287
 
285
288
  def load_defaults(self) -> DSCJob:
286
289
  self.load_properties_from_env()
290
+ if getattr(self, "job_node_configuration_details", None):
291
+ return self
292
+ # The following defaults apply to single-node job runs only.
287
293
  if not self.job_infrastructure_configuration_details:
288
294
  self.job_infrastructure_configuration_details = {}
295
+
289
296
  # Convert the dict to JobInfrastructureConfigurationDetails object
290
297
  if isinstance(self.job_infrastructure_configuration_details, dict):
291
- # Default networking
298
+
292
299
  if not self.job_infrastructure_configuration_details.get(
293
300
  "jobInfrastructureType"
294
301
  ):
@@ -352,6 +359,7 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
352
359
  raise ValueError("Specify compartment ID for data science job.")
353
360
  if not self.project_id:
354
361
  raise ValueError("Specify project ID for data science job.")
362
+
355
363
  self._create_with_oci_api()
356
364
  return self
357
365
 
@@ -498,7 +506,9 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
498
506
  keys = list(kwargs.keys())
499
507
  for key in keys:
500
508
  if key in config_swagger_types:
501
- config_kwargs[key] = kwargs.pop(key)
509
+ val = kwargs.pop(key)
510
+ if val is not None:
511
+ config_kwargs[key] = val
502
512
  elif key in env_config_swagger_types:
503
513
  value = kwargs.pop(key)
504
514
  if key in [
@@ -545,6 +555,25 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
545
555
  env_config_override
546
556
  )
547
557
 
558
+ if getattr(self, "job_node_configuration_details", None):
559
+ job_config_override = kwargs.pop("job_configuration_override_details", None)
560
+ env_config_override = kwargs.pop(
561
+ "job_environment_configuration_override_details", None
562
+ )
563
+ if job_config_override or env_config_override:
564
+ node_config = {
565
+ "jobNodeType": "MULTI_NODE",
566
+ "jobNodeGroupConfigurationDetailsList": [
567
+ {
568
+ # Node group name must match the node group name in the job.
569
+ "name": DEFAULT_NODE_GROUP_NAME,
570
+ "JobConfigurationDetails": job_config_override,
571
+ "JobEnvironmentConfigurationDetails": env_config_override,
572
+ }
573
+ ],
574
+ }
575
+ kwargs["job_node_configuration_override_details"] = node_config
576
+
548
577
  wait = kwargs.pop("wait", False)
549
578
  run = DataScienceJobRun(**kwargs, **self.auth).create()
550
579
  if wait:
@@ -756,13 +785,11 @@ class DataScienceJobRun(
756
785
  return True
757
786
  # Stop only if time_finished is more than `wait` seconds ago.
758
787
  # This allows for the delay between the job run stopping and its logs appearing in OCI Logging.
759
- if (
788
+ return (
760
789
  datetime.datetime.now(self.time_finished.tzinfo)
761
790
  - datetime.timedelta(seconds=wait)
762
791
  > self.time_finished
763
- ):
764
- return True
765
- return False
792
+ )
766
793
 
767
794
  if not self.log_id and not self.log_group_id:
768
795
  print(
@@ -1471,6 +1498,23 @@ class DataScienceJob(Infrastructure):
1471
1498
  }
1472
1499
  self.dsc_job = dsc_job
1473
1500
 
1501
+ # Process multi-node infrastructure config
1502
+ node_groups = get_value(
1503
+ dsc_job,
1504
+ "job_node_configuration_details.job_node_group_configuration_details_list",
1505
+ )
1506
+ if node_groups and len(node_groups) == 1:
1507
+ node_group = node_groups[0]
1508
+ dsc_job.job_infrastructure_configuration_details = (
1509
+ node_group.job_infrastructure_configuration_details
1510
+ )
1511
+ subnet_id = get_value(
1512
+ dsc_job,
1513
+ "job_node_configuration_details.job_network_configuration.subnet_id",
1514
+ )
1515
+ if subnet_id:
1516
+ self.set_spec(self.CONST_SUBNET_ID, subnet_id)
1517
+
1474
1518
  for infra_attr, dsc_attr in self.payload_attribute_map.items():
1475
1519
  value = get_value(dsc_job, dsc_attr)
1476
1520
  if not value:
@@ -1557,10 +1601,13 @@ class DataScienceJob(Infrastructure):
1557
1601
  if value:
1558
1602
  dsc_job.job_infrastructure_configuration_details[camel_attr] = value
1559
1603
 
1560
- if not dsc_job.job_infrastructure_configuration_details.get(
1561
- "shapeName", ""
1562
- ).endswith("Flex") and dsc_job.job_infrastructure_configuration_details.get(
1563
- "jobShapeConfigDetails"
1604
+ shape = dsc_job.job_infrastructure_configuration_details.get("shapeName", "")
1605
+ if (
1606
+ shape
1607
+ and not str(shape).endswith("Flex")
1608
+ and dsc_job.job_infrastructure_configuration_details.get(
1609
+ "jobShapeConfigDetails"
1610
+ )
1564
1611
  ):
1565
1612
  raise ValueError(
1566
1613
  "Shape config is not required for non flex shape from user end."
@@ -1583,7 +1630,6 @@ class DataScienceJob(Infrastructure):
1583
1630
  return self
1584
1631
 
1585
1632
  def build(self) -> DataScienceJob:
1586
- self.dsc_job.load_defaults()
1587
1633
 
1588
1634
  try:
1589
1635
  self.dsc_job.load_defaults()
@@ -1611,6 +1657,48 @@ class DataScienceJob(Infrastructure):
1611
1657
  )
1612
1658
  )
1613
1659
 
1660
+ def _config_multi_node(self, runtime: MultiNodeRuntime):
1661
+ """Configure the payload for multi-node job run."""
1662
+ infra_config: dict = self.dsc_job.job_infrastructure_configuration_details
1663
+ job_config: models.DefaultJobConfigurationDetails = (
1664
+ self.dsc_job.job_configuration_details
1665
+ )
1666
+ env_config = self.dsc_job.job_environment_configuration_details
1667
+ # For multi-node jobs,
1668
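+ # the job-level job_infrastructure_configuration_details, job_configuration_details
1669
+ # and job_environment_configuration_details are all reset to None below.
1670
+ # NOTE(review): the service may expect the special EMPTY class for the first two; None is what is assigned here - confirm.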
1671
+ # The configs will be specified in each node group.
1672
+ self.dsc_job.job_infrastructure_configuration_details = None
1673
+ self.dsc_job.job_configuration_details = None
1674
+ self.dsc_job.job_environment_configuration_details = None
1675
+
1676
+ subnet_id = infra_config.pop("subnetId", None)
1677
+ infra_config["jobInfrastructureType"] = (
1678
+ models.MultiNodeJobInfrastructureConfigurationDetails.JOB_INFRASTRUCTURE_TYPE_MULTI_NODE
1679
+ )
1680
+
1681
+ if subnet_id:
1682
+ network_config = models.JobCustomNetworkConfiguration(subnet_id=subnet_id)
1683
+ else:
1684
+ network_config = models.JobDefaultNetworkConfiguration()
1685
+
1686
+ node_group_config: dict = {
1687
+ "name": DEFAULT_NODE_GROUP_NAME,
1688
+ "replicas": runtime.replica,
1689
+ "minimumSuccessReplicas": runtime.replica,
1690
+ "jobInfrastructureConfigurationDetails": infra_config,
1691
+ "jobConfigurationDetails": job_config,
1692
+ "jobEnvironmentConfigurationDetails": env_config,
1693
+ }
1694
+
1695
+ self.dsc_job.job_node_configuration_details = {
1696
+ "jobNodeType": "MULTI_NODE",
1697
+ "startupOrder": "IN_PARALLEL",
1698
+ "jobNetworkConfiguration": network_config,
1699
+ "jobNodeGroupConfigurationDetailsList": [node_group_config],
1700
+ }
1701
+
1614
1702
  def create(self, runtime, **kwargs) -> DataScienceJob:
1615
1703
  """Creates a job with runtime.
1616
1704
 
@@ -1635,9 +1723,7 @@ class DataScienceJob(Infrastructure):
1635
1723
 
1636
1724
  if self.name:
1637
1725
  display_name = Template(self.name).safe_substitute(runtime.envs)
1638
- elif isinstance(runtime, GitPythonRuntime) or isinstance(
1639
- runtime, ContainerRuntime
1640
- ):
1726
+ elif isinstance(runtime, (GitPythonRuntime, ContainerRuntime)):
1641
1727
  display_name = utils.get_random_name_for_resource()
1642
1728
  else:
1643
1729
  display_name = None
@@ -1652,11 +1738,22 @@ class DataScienceJob(Infrastructure):
1652
1738
  self.dsc_job = DSCJob(**payload, **self.auth)
1653
1739
  # Set Job infra to user values after DSCJob initialized the defaults
1654
1740
  self._update_job_infra(self.dsc_job)
1741
+ if self.is_multi_node_job(runtime):
1742
+ self._config_multi_node(runtime=runtime)
1655
1743
  self.dsc_job.create()
1656
1744
  # Update the model from infra after job creation.
1657
1745
  self._update_from_dsc_model(self.dsc_job)
1658
1746
  return self
1659
1747
 
1748
+ @staticmethod
1749
+ def is_multi_node_job(runtime):
1750
+ """Check if the job is multi-node job."""
1751
+ return (
1752
+ MULTI_NODE_JOB_SUPPORT
1753
+ and isinstance(runtime, MultiNodeRuntime)
1754
+ and runtime.replica > 1
1755
+ )
1756
+
1660
1757
  def run(
1661
1758
  self,
1662
1759
  name=None,
@@ -1,7 +1,6 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
2
 
4
- # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2021, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
  """Contains classes for conversion between ADS runtime and OCI Data Science Job implementation.
7
6
  This module is for ADS developers only.
@@ -19,29 +18,37 @@ import os
19
18
  import shlex
20
19
  from typing import Optional
21
20
  from urllib import parse
21
+
22
+ import oci
23
+
22
24
  from ads.common.utils import extract_region
25
+ from ads.jobs.builders.infrastructure.utils import get_value
26
+ from ads.jobs.builders.runtimes.artifact import (
27
+ GitPythonArtifact,
28
+ NotebookArtifact,
29
+ PythonArtifact,
30
+ ScriptArtifact,
31
+ )
23
32
  from ads.jobs.builders.runtimes.base import Runtime
33
+ from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
24
34
  from ads.jobs.builders.runtimes.python_runtime import (
25
35
  CondaRuntime,
26
- ScriptRuntime,
27
- PythonRuntime,
28
- NotebookRuntime,
29
36
  GitPythonRuntime,
37
+ NotebookRuntime,
38
+ PythonRuntime,
39
+ ScriptRuntime,
30
40
  )
31
- from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
32
41
  from ads.jobs.builders.runtimes.pytorch_runtime import (
33
- PyTorchDistributedRuntime,
34
42
  PyTorchDistributedArtifact,
43
+ PyTorchDistributedRuntime,
35
44
  )
36
- from ads.jobs.builders.runtimes.artifact import (
37
- ScriptArtifact,
38
- NotebookArtifact,
39
- PythonArtifact,
40
- GitPythonArtifact,
41
- )
42
- from ads.opctl.distributed.common import cluster_config_helper
43
- from ads.jobs.builders.infrastructure.utils import get_value
44
45
  from ads.jobs.templates import driver_utils
46
+ from ads.opctl.distributed.common import cluster_config_helper
47
+
48
+ if hasattr(oci.data_science.models, "MultiNodeJobInfrastructureConfigurationDetails"):
49
+ MULTI_NODE_JOB_SUPPORT = True
50
+ else:
51
+ MULTI_NODE_JOB_SUPPORT = False
45
52
 
46
53
 
47
54
  class IncompatibleRuntime(Exception): # pragma: no cover
@@ -77,6 +84,9 @@ class RuntimeHandler:
77
84
  # Defines the class of the runtime to be handled.
78
85
  RUNTIME_CLASS = Runtime
79
86
 
87
+ CONST_WORKER_COUNT = "OCI__WORKER_COUNT"
88
+ CONST_NODE_COUNT = "NODE_COUNT"
89
+
80
90
  def __init__(self, data_science_job) -> None:
81
91
  """Initialize the runtime handler.
82
92
 
@@ -285,7 +295,7 @@ class RuntimeHandler:
285
295
  * _extract_artifact()
286
296
  * _extract_runtime_minutes()
287
297
  Each of these method returns a dict for specifying the runtime.
288
- The dictionaries are combined before initalizing the runtime.
298
+ The dictionaries are combined before initializing the runtime.
289
299
  A sub-class can modify one or more of these methods.
290
300
 
291
301
  Parameters
@@ -349,6 +359,30 @@ class RuntimeHandler:
349
359
  return {Runtime.CONST_ARGS: shlex.split(args_string)}
350
360
  return {}
351
361
 
362
+ def _get_node_group(self, dsc_job):
363
+ """Gets the node group for multi-node job with single node group."""
364
+ node_groups = get_value(
365
+ dsc_job,
366
+ "job_node_configuration_details.job_node_group_configuration_details_list",
367
+ )
368
+ if node_groups and len(node_groups) == 1:
369
+ return node_groups[0]
370
+ return None
371
+
372
+ def _get_replica(self, dsc_job, envs):
373
+ node_group = self._get_node_group(dsc_job)
374
+ if node_group:
375
+ replica = get_value(node_group, "replicas")
376
+ elif not envs:
377
+ replica = None
378
+ elif self.CONST_WORKER_COUNT in envs:
379
+ replica = int(envs.pop(self.CONST_WORKER_COUNT)) + 1
380
+ elif self.CONST_NODE_COUNT in envs:
381
+ replica = int(envs.pop(self.CONST_NODE_COUNT))
382
+ else:
383
+ replica = None
384
+ return replica
385
+
352
386
  def _extract_envs(self, dsc_job):
353
387
  """Extract the environment variables from data science job.
354
388
 
@@ -362,7 +396,12 @@ class RuntimeHandler:
362
396
  dict
363
397
  A runtime specification dictionary for initializing a runtime.
364
398
  """
365
- envs = get_value(dsc_job, "job_configuration_details.environment_variables")
399
+ env_attr = "job_configuration_details.environment_variables"
400
+ node_group = self._get_node_group(dsc_job)
401
+ if node_group:
402
+ envs = get_value(node_group, env_attr)
403
+ else:
404
+ envs = get_value(dsc_job, env_attr)
366
405
  if envs:
367
406
  return {Runtime.CONST_ENV_VAR: envs}
368
407
  return {}
@@ -968,6 +1007,12 @@ class ContainerRuntimeHandler(RuntimeHandler):
968
1007
  payload["job_environment_configuration_details"] = job_env_config
969
1008
  return payload
970
1009
 
1010
+ def _translate_env(self, runtime):
1011
+ envs = super()._translate_env(runtime)
1012
+ if runtime.replica:
1013
+ envs[self.CONST_NODE_COUNT] = str(runtime.replica)
1014
+ return envs
1015
+
971
1016
  def _translate_artifact(self, runtime: ContainerRuntime):
972
1017
  """Additional artifact for the container"""
973
1018
  if runtime.artifact_uri:
@@ -1049,6 +1094,10 @@ class ContainerRuntimeHandler(RuntimeHandler):
1049
1094
  if envs:
1050
1095
  spec[ContainerRuntime.CONST_ENV_VAR] = envs
1051
1096
 
1097
+ replica = self._get_replica(dsc_job=dsc_job, envs=envs)
1098
+ if replica:
1099
+ spec[ContainerRuntime.CONST_REPLICA] = replica
1100
+
1052
1101
  return spec
1053
1102
 
1054
1103
  def _extract_properties(self, dsc_job) -> dict:
@@ -1081,7 +1130,6 @@ class ContainerRuntimeHandler(RuntimeHandler):
1081
1130
 
1082
1131
  class PyTorchDistributedRuntimeHandler(PythonRuntimeHandler):
1083
1132
  RUNTIME_CLASS = PyTorchDistributedRuntime
1084
- CONST_WORKER_COUNT = "OCI__WORKER_COUNT"
1085
1133
  CONST_COMMAND = "OCI__LAUNCH_CMD"
1086
1134
  CONST_DEEPSPEED = "OCI__DEEPSPEED"
1087
1135
 
@@ -1105,8 +1153,7 @@ class PyTorchDistributedRuntimeHandler(PythonRuntimeHandler):
1105
1153
  def _translate_env(self, runtime: PyTorchDistributedRuntime) -> dict:
1106
1154
  envs = super()._translate_env(runtime)
1107
1155
  replica = runtime.replica if runtime.replica else 1
1108
- # WORKER_COUNT = REPLICA - 1 so that it will be same as distributed training
1109
- envs[self.CONST_WORKER_COUNT] = str(replica - 1)
1156
+ envs[self.CONST_NODE_COUNT] = str(replica)
1110
1157
  envs[self.CONST_JOB_ENTRYPOINT] = PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT
1111
1158
  if runtime.inputs:
1112
1159
  envs[driver_utils.CONST_ENV_INPUT_MAPPINGS] = json.dumps(runtime.inputs)
@@ -1131,12 +1178,12 @@ class PyTorchDistributedRuntimeHandler(PythonRuntimeHandler):
1131
1178
  def _extract_envs(self, dsc_job) -> dict:
1132
1179
  spec = super()._extract_envs(dsc_job)
1133
1180
  envs = spec.pop(PythonRuntime.CONST_ENV_VAR, {})
1134
- if self.CONST_WORKER_COUNT not in envs:
1181
+ replica = self._get_replica(dsc_job, envs=envs)
1182
+
1183
+ if not replica:
1135
1184
  raise IncompatibleRuntime()
1136
1185
  # Replicas
1137
- spec[PyTorchDistributedRuntime.CONST_REPLICA] = (
1138
- int(envs.pop(self.CONST_WORKER_COUNT)) + 1
1139
- )
1186
+ spec[PyTorchDistributedRuntime.CONST_REPLICA] = replica
1140
1187
  # Git
1141
1188
  if cluster_config_helper.OCI__RUNTIME_URI in envs:
1142
1189
  git_spec = {}
@@ -1,17 +1,16 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
2
 
4
- # Copyright (c) 2022, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2022, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
  from __future__ import annotations
6
+
7
7
  import re
8
8
  import time
9
9
  import traceback
10
-
11
10
  from typing import Dict, TypeVar
12
- from ads.jobs.builders.base import Builder
13
- from ads.jobs import env_var_parser
14
11
 
12
+ from ads.jobs import env_var_parser
13
+ from ads.jobs.builders.base import Builder
15
14
 
16
15
  Self = TypeVar("Self", bound="Runtime")
17
16
 
@@ -285,6 +284,9 @@ class MultiNodeRuntime(Runtime):
285
284
 
286
285
  def run(self, dsc_job, **kwargs):
287
286
  """Starts the job runs"""
287
+ # For a multi-node job, there is no need to create multiple job runs.
288
+ if getattr(dsc_job, "job_node_configuration_details", None):
289
+ return dsc_job.run(**kwargs)
288
290
  replicas = self.replica if self.replica else 1
289
291
  main_run = None
290
292
  job_runs = []