oracle-ads 2.13.11__py3-none-any.whl → 2.13.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/app.py +73 -15
- ads/aqua/cli.py +17 -0
- ads/aqua/client/client.py +38 -21
- ads/aqua/client/openai_client.py +20 -10
- ads/aqua/common/entities.py +78 -12
- ads/aqua/common/utils.py +35 -0
- ads/aqua/constants.py +2 -0
- ads/aqua/evaluation/evaluation.py +5 -4
- ads/aqua/extension/common_handler.py +47 -2
- ads/aqua/extension/model_handler.py +51 -9
- ads/aqua/model/constants.py +1 -0
- ads/aqua/model/enums.py +19 -1
- ads/aqua/model/model.py +119 -51
- ads/aqua/model/utils.py +1 -2
- ads/aqua/modeldeployment/config_loader.py +815 -0
- ads/aqua/modeldeployment/constants.py +4 -1
- ads/aqua/modeldeployment/deployment.py +178 -129
- ads/aqua/modeldeployment/entities.py +150 -178
- ads/aqua/modeldeployment/model_group_config.py +233 -0
- ads/aqua/modeldeployment/utils.py +0 -539
- ads/aqua/verify_policies/__init__.py +8 -0
- ads/aqua/verify_policies/constants.py +13 -0
- ads/aqua/verify_policies/entities.py +29 -0
- ads/aqua/verify_policies/messages.py +101 -0
- ads/aqua/verify_policies/utils.py +432 -0
- ads/aqua/verify_policies/verify.py +345 -0
- ads/aqua/version.json +3 -0
- ads/common/oci_logging.py +4 -7
- ads/common/work_request.py +39 -38
- ads/jobs/builders/infrastructure/dsc_job.py +121 -24
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +71 -24
- ads/jobs/builders/runtimes/base.py +7 -5
- ads/jobs/builders/runtimes/pytorch_runtime.py +6 -8
- ads/jobs/templates/driver_pytorch.py +486 -172
- ads/jobs/templates/driver_utils.py +27 -11
- ads/model/deployment/model_deployment.py +51 -38
- ads/model/service/oci_datascience_model_deployment.py +6 -11
- ads/telemetry/client.py +4 -4
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/METADATA +2 -1
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/RECORD +43 -34
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/WHEEL +0 -0
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/entry_points.txt +0 -0
- {oracle_ads-2.13.11.dist-info → oracle_ads-2.13.13.dist-info}/licenses/LICENSE.txt +0 -0
@@ -1,7 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8; -*-
|
3
2
|
|
4
|
-
# Copyright (c) 2021,
|
3
|
+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
from __future__ import annotations
|
7
6
|
|
@@ -21,30 +20,33 @@ import fsspec
|
|
21
20
|
import oci
|
22
21
|
import oci.data_science
|
23
22
|
import oci.util as oci_util
|
23
|
+
import yaml
|
24
|
+
from oci.data_science import models
|
24
25
|
from oci.data_science.models import JobInfrastructureConfigurationDetails
|
25
26
|
from oci.exceptions import ServiceError
|
26
|
-
|
27
|
+
|
27
28
|
from ads.common import utils
|
29
|
+
from ads.common.decorator.utils import class_or_instance_method
|
30
|
+
from ads.common.dsc_file_system import (
|
31
|
+
DSCFileSystemManager,
|
32
|
+
OCIFileStorage,
|
33
|
+
OCIObjectStorage,
|
34
|
+
)
|
28
35
|
from ads.common.oci_datascience import DSCNotebookSession, OCIDataScienceMixin
|
29
36
|
from ads.common.oci_logging import OCILog
|
30
37
|
from ads.common.oci_resource import ResourceNotFoundError
|
31
38
|
from ads.jobs.builders.infrastructure.base import Infrastructure, RunInstance
|
32
39
|
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
|
40
|
+
MULTI_NODE_JOB_SUPPORT,
|
33
41
|
ContainerRuntimeHandler,
|
34
42
|
DataScienceJobRuntimeManager,
|
35
43
|
)
|
36
44
|
from ads.jobs.builders.infrastructure.utils import get_value
|
37
45
|
from ads.jobs.builders.runtimes.artifact import Artifact
|
46
|
+
from ads.jobs.builders.runtimes.base import MultiNodeRuntime
|
38
47
|
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
|
39
48
|
from ads.jobs.builders.runtimes.python_runtime import GitPythonRuntime
|
40
49
|
|
41
|
-
from ads.common.dsc_file_system import (
|
42
|
-
OCIFileStorage,
|
43
|
-
DSCFileSystemManager,
|
44
|
-
OCIObjectStorage,
|
45
|
-
)
|
46
|
-
from ads.common.decorator.utils import class_or_instance_method
|
47
|
-
|
48
50
|
logger = logging.getLogger(__name__)
|
49
51
|
|
50
52
|
SLEEP_INTERVAL = 3
|
@@ -52,6 +54,7 @@ WAIT_SECONDS_AFTER_FINISHED = 90
|
|
52
54
|
MAXIMUM_MOUNT_COUNT = 5
|
53
55
|
FILE_STORAGE_TYPE = "FILE_STORAGE"
|
54
56
|
OBJECT_STORAGE_TYPE = "OBJECT_STORAGE"
|
57
|
+
DEFAULT_NODE_GROUP_NAME = "node-group"
|
55
58
|
|
56
59
|
|
57
60
|
class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
|
@@ -284,11 +287,15 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
|
|
284
287
|
|
285
288
|
def load_defaults(self) -> DSCJob:
|
286
289
|
self.load_properties_from_env()
|
290
|
+
if getattr(self, "job_node_configuration_details", None):
|
291
|
+
return self
|
292
|
+
# Following are for single node job run only
|
287
293
|
if not self.job_infrastructure_configuration_details:
|
288
294
|
self.job_infrastructure_configuration_details = {}
|
295
|
+
|
289
296
|
# Convert the dict to JobInfrastructureConfigurationDetails object
|
290
297
|
if isinstance(self.job_infrastructure_configuration_details, dict):
|
291
|
-
|
298
|
+
|
292
299
|
if not self.job_infrastructure_configuration_details.get(
|
293
300
|
"jobInfrastructureType"
|
294
301
|
):
|
@@ -352,6 +359,7 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
|
|
352
359
|
raise ValueError("Specify compartment ID for data science job.")
|
353
360
|
if not self.project_id:
|
354
361
|
raise ValueError("Specify project ID for data science job.")
|
362
|
+
|
355
363
|
self._create_with_oci_api()
|
356
364
|
return self
|
357
365
|
|
@@ -498,7 +506,9 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
|
|
498
506
|
keys = list(kwargs.keys())
|
499
507
|
for key in keys:
|
500
508
|
if key in config_swagger_types:
|
501
|
-
|
509
|
+
val = kwargs.pop(key)
|
510
|
+
if val is not None:
|
511
|
+
config_kwargs[key] = val
|
502
512
|
elif key in env_config_swagger_types:
|
503
513
|
value = kwargs.pop(key)
|
504
514
|
if key in [
|
@@ -545,6 +555,25 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
|
|
545
555
|
env_config_override
|
546
556
|
)
|
547
557
|
|
558
|
+
if getattr(self, "job_node_configuration_details", None):
|
559
|
+
job_config_override = kwargs.pop("job_configuration_override_details", None)
|
560
|
+
env_config_override = kwargs.pop(
|
561
|
+
"job_environment_configuration_override_details", None
|
562
|
+
)
|
563
|
+
if job_config_override or env_config_override:
|
564
|
+
node_config = {
|
565
|
+
"jobNodeType": "MULTI_NODE",
|
566
|
+
"jobNodeGroupConfigurationDetailsList": [
|
567
|
+
{
|
568
|
+
# Node group name must match the node group name in the job.
|
569
|
+
"name": DEFAULT_NODE_GROUP_NAME,
|
570
|
+
"JobConfigurationDetails": job_config_override,
|
571
|
+
"JobEnvironmentConfigurationDetails": env_config_override,
|
572
|
+
}
|
573
|
+
],
|
574
|
+
}
|
575
|
+
kwargs["job_node_configuration_override_details"] = node_config
|
576
|
+
|
548
577
|
wait = kwargs.pop("wait", False)
|
549
578
|
run = DataScienceJobRun(**kwargs, **self.auth).create()
|
550
579
|
if wait:
|
@@ -756,13 +785,11 @@ class DataScienceJobRun(
|
|
756
785
|
return True
|
757
786
|
# Stop only if time_finished is over 2 minute ago.
|
758
787
|
# This is for the time delay between job run stopped and the logs appear in oci logging.
|
759
|
-
|
788
|
+
return (
|
760
789
|
datetime.datetime.now(self.time_finished.tzinfo)
|
761
790
|
- datetime.timedelta(seconds=wait)
|
762
791
|
> self.time_finished
|
763
|
-
)
|
764
|
-
return True
|
765
|
-
return False
|
792
|
+
)
|
766
793
|
|
767
794
|
if not self.log_id and not self.log_group_id:
|
768
795
|
print(
|
@@ -1471,6 +1498,23 @@ class DataScienceJob(Infrastructure):
|
|
1471
1498
|
}
|
1472
1499
|
self.dsc_job = dsc_job
|
1473
1500
|
|
1501
|
+
# Process multi-node infrastructure config
|
1502
|
+
node_groups = get_value(
|
1503
|
+
dsc_job,
|
1504
|
+
"job_node_configuration_details.job_node_group_configuration_details_list",
|
1505
|
+
)
|
1506
|
+
if node_groups and len(node_groups) == 1:
|
1507
|
+
node_group = node_groups[0]
|
1508
|
+
dsc_job.job_infrastructure_configuration_details = (
|
1509
|
+
node_group.job_infrastructure_configuration_details
|
1510
|
+
)
|
1511
|
+
subnet_id = get_value(
|
1512
|
+
dsc_job,
|
1513
|
+
"job_node_configuration_details.job_network_configuration.subnet_id",
|
1514
|
+
)
|
1515
|
+
if subnet_id:
|
1516
|
+
self.set_spec(self.CONST_SUBNET_ID, subnet_id)
|
1517
|
+
|
1474
1518
|
for infra_attr, dsc_attr in self.payload_attribute_map.items():
|
1475
1519
|
value = get_value(dsc_job, dsc_attr)
|
1476
1520
|
if not value:
|
@@ -1557,10 +1601,13 @@ class DataScienceJob(Infrastructure):
|
|
1557
1601
|
if value:
|
1558
1602
|
dsc_job.job_infrastructure_configuration_details[camel_attr] = value
|
1559
1603
|
|
1560
|
-
|
1561
|
-
|
1562
|
-
|
1563
|
-
"
|
1604
|
+
shape = dsc_job.job_infrastructure_configuration_details.get("shapeName", "")
|
1605
|
+
if (
|
1606
|
+
shape
|
1607
|
+
and not str(shape).endswith("Flex")
|
1608
|
+
and dsc_job.job_infrastructure_configuration_details.get(
|
1609
|
+
"jobShapeConfigDetails"
|
1610
|
+
)
|
1564
1611
|
):
|
1565
1612
|
raise ValueError(
|
1566
1613
|
"Shape config is not required for non flex shape from user end."
|
@@ -1583,7 +1630,6 @@ class DataScienceJob(Infrastructure):
|
|
1583
1630
|
return self
|
1584
1631
|
|
1585
1632
|
def build(self) -> DataScienceJob:
|
1586
|
-
self.dsc_job.load_defaults()
|
1587
1633
|
|
1588
1634
|
try:
|
1589
1635
|
self.dsc_job.load_defaults()
|
@@ -1611,6 +1657,48 @@ class DataScienceJob(Infrastructure):
|
|
1611
1657
|
)
|
1612
1658
|
)
|
1613
1659
|
|
1660
|
+
def _config_multi_node(self, runtime: MultiNodeRuntime):
|
1661
|
+
"""Configure the payload for multi-node job run."""
|
1662
|
+
infra_config: dict = self.dsc_job.job_infrastructure_configuration_details
|
1663
|
+
job_config: models.DefaultJobConfigurationDetails = (
|
1664
|
+
self.dsc_job.job_configuration_details
|
1665
|
+
)
|
1666
|
+
env_config = self.dsc_job.job_environment_configuration_details
|
1667
|
+
# For multi-node jobs,
|
1668
|
+
# the job_infrastructure_configuration_details and job_configuration_details
|
1669
|
+
# should be the special EMPTY class.
|
1670
|
+
# The job_environment_configuration_details should be None.
|
1671
|
+
# The configs will be specified in each node group.
|
1672
|
+
self.dsc_job.job_infrastructure_configuration_details = None
|
1673
|
+
self.dsc_job.job_configuration_details = None
|
1674
|
+
self.dsc_job.job_environment_configuration_details = None
|
1675
|
+
|
1676
|
+
subnet_id = infra_config.pop("subnetId", None)
|
1677
|
+
infra_config["jobInfrastructureType"] = (
|
1678
|
+
models.MultiNodeJobInfrastructureConfigurationDetails.JOB_INFRASTRUCTURE_TYPE_MULTI_NODE
|
1679
|
+
)
|
1680
|
+
|
1681
|
+
if subnet_id:
|
1682
|
+
network_config = models.JobCustomNetworkConfiguration(subnet_id=subnet_id)
|
1683
|
+
else:
|
1684
|
+
network_config = models.JobDefaultNetworkConfiguration()
|
1685
|
+
|
1686
|
+
node_group_config: dict = {
|
1687
|
+
"name": DEFAULT_NODE_GROUP_NAME,
|
1688
|
+
"replicas": runtime.replica,
|
1689
|
+
"minimumSuccessReplicas": runtime.replica,
|
1690
|
+
"jobInfrastructureConfigurationDetails": infra_config,
|
1691
|
+
"jobConfigurationDetails": job_config,
|
1692
|
+
"jobEnvironmentConfigurationDetails": env_config,
|
1693
|
+
}
|
1694
|
+
|
1695
|
+
self.dsc_job.job_node_configuration_details = {
|
1696
|
+
"jobNodeType": "MULTI_NODE",
|
1697
|
+
"startupOrder": "IN_PARALLEL",
|
1698
|
+
"jobNetworkConfiguration": network_config,
|
1699
|
+
"jobNodeGroupConfigurationDetailsList": [node_group_config],
|
1700
|
+
}
|
1701
|
+
|
1614
1702
|
def create(self, runtime, **kwargs) -> DataScienceJob:
|
1615
1703
|
"""Creates a job with runtime.
|
1616
1704
|
|
@@ -1635,9 +1723,7 @@ class DataScienceJob(Infrastructure):
|
|
1635
1723
|
|
1636
1724
|
if self.name:
|
1637
1725
|
display_name = Template(self.name).safe_substitute(runtime.envs)
|
1638
|
-
elif isinstance(runtime, GitPythonRuntime)
|
1639
|
-
runtime, ContainerRuntime
|
1640
|
-
):
|
1726
|
+
elif isinstance(runtime, (GitPythonRuntime, ContainerRuntime)):
|
1641
1727
|
display_name = utils.get_random_name_for_resource()
|
1642
1728
|
else:
|
1643
1729
|
display_name = None
|
@@ -1652,11 +1738,22 @@ class DataScienceJob(Infrastructure):
|
|
1652
1738
|
self.dsc_job = DSCJob(**payload, **self.auth)
|
1653
1739
|
# Set Job infra to user values after DSCJob initialized the defaults
|
1654
1740
|
self._update_job_infra(self.dsc_job)
|
1741
|
+
if self.is_multi_node_job(runtime):
|
1742
|
+
self._config_multi_node(runtime=runtime)
|
1655
1743
|
self.dsc_job.create()
|
1656
1744
|
# Update the model from infra after job creation.
|
1657
1745
|
self._update_from_dsc_model(self.dsc_job)
|
1658
1746
|
return self
|
1659
1747
|
|
1748
|
+
@staticmethod
|
1749
|
+
def is_multi_node_job(runtime):
|
1750
|
+
"""Check if the job is multi-node job."""
|
1751
|
+
return (
|
1752
|
+
MULTI_NODE_JOB_SUPPORT
|
1753
|
+
and isinstance(runtime, MultiNodeRuntime)
|
1754
|
+
and runtime.replica > 1
|
1755
|
+
)
|
1756
|
+
|
1660
1757
|
def run(
|
1661
1758
|
self,
|
1662
1759
|
name=None,
|
@@ -1,7 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8; -*-
|
3
2
|
|
4
|
-
# Copyright (c) 2021,
|
3
|
+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
"""Contains classes for conversion between ADS runtime and OCI Data Science Job implementation.
|
7
6
|
This module is for ADS developers only.
|
@@ -19,29 +18,37 @@ import os
|
|
19
18
|
import shlex
|
20
19
|
from typing import Optional
|
21
20
|
from urllib import parse
|
21
|
+
|
22
|
+
import oci
|
23
|
+
|
22
24
|
from ads.common.utils import extract_region
|
25
|
+
from ads.jobs.builders.infrastructure.utils import get_value
|
26
|
+
from ads.jobs.builders.runtimes.artifact import (
|
27
|
+
GitPythonArtifact,
|
28
|
+
NotebookArtifact,
|
29
|
+
PythonArtifact,
|
30
|
+
ScriptArtifact,
|
31
|
+
)
|
23
32
|
from ads.jobs.builders.runtimes.base import Runtime
|
33
|
+
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
|
24
34
|
from ads.jobs.builders.runtimes.python_runtime import (
|
25
35
|
CondaRuntime,
|
26
|
-
ScriptRuntime,
|
27
|
-
PythonRuntime,
|
28
|
-
NotebookRuntime,
|
29
36
|
GitPythonRuntime,
|
37
|
+
NotebookRuntime,
|
38
|
+
PythonRuntime,
|
39
|
+
ScriptRuntime,
|
30
40
|
)
|
31
|
-
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
|
32
41
|
from ads.jobs.builders.runtimes.pytorch_runtime import (
|
33
|
-
PyTorchDistributedRuntime,
|
34
42
|
PyTorchDistributedArtifact,
|
43
|
+
PyTorchDistributedRuntime,
|
35
44
|
)
|
36
|
-
from ads.jobs.builders.runtimes.artifact import (
|
37
|
-
ScriptArtifact,
|
38
|
-
NotebookArtifact,
|
39
|
-
PythonArtifact,
|
40
|
-
GitPythonArtifact,
|
41
|
-
)
|
42
|
-
from ads.opctl.distributed.common import cluster_config_helper
|
43
|
-
from ads.jobs.builders.infrastructure.utils import get_value
|
44
45
|
from ads.jobs.templates import driver_utils
|
46
|
+
from ads.opctl.distributed.common import cluster_config_helper
|
47
|
+
|
48
|
+
if hasattr(oci.data_science.models, "MultiNodeJobInfrastructureConfigurationDetails"):
|
49
|
+
MULTI_NODE_JOB_SUPPORT = True
|
50
|
+
else:
|
51
|
+
MULTI_NODE_JOB_SUPPORT = False
|
45
52
|
|
46
53
|
|
47
54
|
class IncompatibleRuntime(Exception): # pragma: no cover
|
@@ -77,6 +84,9 @@ class RuntimeHandler:
|
|
77
84
|
# Defines the class of the runtime to be handled.
|
78
85
|
RUNTIME_CLASS = Runtime
|
79
86
|
|
87
|
+
CONST_WORKER_COUNT = "OCI__WORKER_COUNT"
|
88
|
+
CONST_NODE_COUNT = "NODE_COUNT"
|
89
|
+
|
80
90
|
def __init__(self, data_science_job) -> None:
|
81
91
|
"""Initialize the runtime handler.
|
82
92
|
|
@@ -285,7 +295,7 @@ class RuntimeHandler:
|
|
285
295
|
* _extract_artifact()
|
286
296
|
* _extract_runtime_minutes()
|
287
297
|
Each of these method returns a dict for specifying the runtime.
|
288
|
-
The dictionaries are combined before
|
298
|
+
The dictionaries are combined before initializing the runtime.
|
289
299
|
A sub-class can modify one of more of these methods.
|
290
300
|
|
291
301
|
Parameters
|
@@ -349,6 +359,30 @@ class RuntimeHandler:
|
|
349
359
|
return {Runtime.CONST_ARGS: shlex.split(args_string)}
|
350
360
|
return {}
|
351
361
|
|
362
|
+
def _get_node_group(self, dsc_job):
|
363
|
+
"""Gets the node group for multi-node job with single node group."""
|
364
|
+
node_groups = get_value(
|
365
|
+
dsc_job,
|
366
|
+
"job_node_configuration_details.job_node_group_configuration_details_list",
|
367
|
+
)
|
368
|
+
if node_groups and len(node_groups) == 1:
|
369
|
+
return node_groups[0]
|
370
|
+
return None
|
371
|
+
|
372
|
+
def _get_replica(self, dsc_job, envs):
|
373
|
+
node_group = self._get_node_group(dsc_job)
|
374
|
+
if node_group:
|
375
|
+
replica = get_value(node_group, "replicas")
|
376
|
+
elif not envs:
|
377
|
+
replica = None
|
378
|
+
elif self.CONST_WORKER_COUNT in envs:
|
379
|
+
replica = int(envs.pop(self.CONST_WORKER_COUNT)) + 1
|
380
|
+
elif self.CONST_NODE_COUNT in envs:
|
381
|
+
replica = int(envs.pop(self.CONST_NODE_COUNT))
|
382
|
+
else:
|
383
|
+
replica = None
|
384
|
+
return replica
|
385
|
+
|
352
386
|
def _extract_envs(self, dsc_job):
|
353
387
|
"""Extract the environment variables from data science job.
|
354
388
|
|
@@ -362,7 +396,12 @@ class RuntimeHandler:
|
|
362
396
|
dict
|
363
397
|
A runtime specification dictionary for initializing a runtime.
|
364
398
|
"""
|
365
|
-
|
399
|
+
env_attr = "job_configuration_details.environment_variables"
|
400
|
+
node_group = self._get_node_group(dsc_job)
|
401
|
+
if node_group:
|
402
|
+
envs = get_value(node_group, env_attr)
|
403
|
+
else:
|
404
|
+
envs = get_value(dsc_job, env_attr)
|
366
405
|
if envs:
|
367
406
|
return {Runtime.CONST_ENV_VAR: envs}
|
368
407
|
return {}
|
@@ -968,6 +1007,12 @@ class ContainerRuntimeHandler(RuntimeHandler):
|
|
968
1007
|
payload["job_environment_configuration_details"] = job_env_config
|
969
1008
|
return payload
|
970
1009
|
|
1010
|
+
def _translate_env(self, runtime):
|
1011
|
+
envs = super()._translate_env(runtime)
|
1012
|
+
if runtime.replica:
|
1013
|
+
envs[self.CONST_NODE_COUNT] = str(runtime.replica)
|
1014
|
+
return envs
|
1015
|
+
|
971
1016
|
def _translate_artifact(self, runtime: ContainerRuntime):
|
972
1017
|
"""Additional artifact for the container"""
|
973
1018
|
if runtime.artifact_uri:
|
@@ -1049,6 +1094,10 @@ class ContainerRuntimeHandler(RuntimeHandler):
|
|
1049
1094
|
if envs:
|
1050
1095
|
spec[ContainerRuntime.CONST_ENV_VAR] = envs
|
1051
1096
|
|
1097
|
+
replica = self._get_replica(dsc_job=dsc_job, envs=envs)
|
1098
|
+
if replica:
|
1099
|
+
spec[ContainerRuntime.CONST_REPLICA] = replica
|
1100
|
+
|
1052
1101
|
return spec
|
1053
1102
|
|
1054
1103
|
def _extract_properties(self, dsc_job) -> dict:
|
@@ -1081,7 +1130,6 @@ class ContainerRuntimeHandler(RuntimeHandler):
|
|
1081
1130
|
|
1082
1131
|
class PyTorchDistributedRuntimeHandler(PythonRuntimeHandler):
|
1083
1132
|
RUNTIME_CLASS = PyTorchDistributedRuntime
|
1084
|
-
CONST_WORKER_COUNT = "OCI__WORKER_COUNT"
|
1085
1133
|
CONST_COMMAND = "OCI__LAUNCH_CMD"
|
1086
1134
|
CONST_DEEPSPEED = "OCI__DEEPSPEED"
|
1087
1135
|
|
@@ -1105,8 +1153,7 @@ class PyTorchDistributedRuntimeHandler(PythonRuntimeHandler):
|
|
1105
1153
|
def _translate_env(self, runtime: PyTorchDistributedRuntime) -> dict:
|
1106
1154
|
envs = super()._translate_env(runtime)
|
1107
1155
|
replica = runtime.replica if runtime.replica else 1
|
1108
|
-
|
1109
|
-
envs[self.CONST_WORKER_COUNT] = str(replica - 1)
|
1156
|
+
envs[self.CONST_NODE_COUNT] = str(replica)
|
1110
1157
|
envs[self.CONST_JOB_ENTRYPOINT] = PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT
|
1111
1158
|
if runtime.inputs:
|
1112
1159
|
envs[driver_utils.CONST_ENV_INPUT_MAPPINGS] = json.dumps(runtime.inputs)
|
@@ -1131,12 +1178,12 @@ class PyTorchDistributedRuntimeHandler(PythonRuntimeHandler):
|
|
1131
1178
|
def _extract_envs(self, dsc_job) -> dict:
|
1132
1179
|
spec = super()._extract_envs(dsc_job)
|
1133
1180
|
envs = spec.pop(PythonRuntime.CONST_ENV_VAR, {})
|
1134
|
-
|
1181
|
+
replica = self._get_replica(dsc_job, envs=envs)
|
1182
|
+
|
1183
|
+
if not replica:
|
1135
1184
|
raise IncompatibleRuntime()
|
1136
1185
|
# Replicas
|
1137
|
-
spec[PyTorchDistributedRuntime.CONST_REPLICA] =
|
1138
|
-
int(envs.pop(self.CONST_WORKER_COUNT)) + 1
|
1139
|
-
)
|
1186
|
+
spec[PyTorchDistributedRuntime.CONST_REPLICA] = replica
|
1140
1187
|
# Git
|
1141
1188
|
if cluster_config_helper.OCI__RUNTIME_URI in envs:
|
1142
1189
|
git_spec = {}
|
@@ -1,17 +1,16 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8; -*-
|
3
2
|
|
4
|
-
# Copyright (c) 2022,
|
3
|
+
# Copyright (c) 2022, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
from __future__ import annotations
|
6
|
+
|
7
7
|
import re
|
8
8
|
import time
|
9
9
|
import traceback
|
10
|
-
|
11
10
|
from typing import Dict, TypeVar
|
12
|
-
from ads.jobs.builders.base import Builder
|
13
|
-
from ads.jobs import env_var_parser
|
14
11
|
|
12
|
+
from ads.jobs import env_var_parser
|
13
|
+
from ads.jobs.builders.base import Builder
|
15
14
|
|
16
15
|
Self = TypeVar("Self", bound="Runtime")
|
17
16
|
|
@@ -285,6 +284,9 @@ class MultiNodeRuntime(Runtime):
|
|
285
284
|
|
286
285
|
def run(self, dsc_job, **kwargs):
|
287
286
|
"""Starts the job runs"""
|
287
|
+
# For multi-node job, there is no need to create multiple job run.
|
288
|
+
if getattr(dsc_job, "job_node_configuration_details", None):
|
289
|
+
return dsc_job.run(**kwargs)
|
288
290
|
replicas = self.replica if self.replica else 1
|
289
291
|
main_run = None
|
290
292
|
job_runs = []
|
@@ -1,19 +1,19 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8; -*-
|
3
2
|
|
4
|
-
# Copyright (c) 2023,
|
3
|
+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
|
-
from ads.jobs.builders.runtimes.artifact import
|
6
|
+
from ads.jobs.builders.runtimes.artifact import GitPythonArtifact, PythonArtifact
|
8
7
|
from ads.jobs.builders.runtimes.base import MultiNodeRuntime
|
9
8
|
from ads.jobs.builders.runtimes.python_runtime import (
|
10
|
-
PythonRuntime,
|
11
9
|
GitPythonRuntime,
|
10
|
+
PythonRuntime,
|
12
11
|
)
|
13
12
|
|
14
13
|
|
15
14
|
class PyTorchDistributedRuntime(PythonRuntime, MultiNodeRuntime):
|
16
15
|
"""Represents runtime supporting PyTorch Distributed training."""
|
16
|
+
|
17
17
|
CONST_GIT = "git"
|
18
18
|
CONST_INPUT = "inputs"
|
19
19
|
CONST_DEP = "dependencies"
|
@@ -169,13 +169,11 @@ class PyTorchDistributedRuntime(PythonRuntime, MultiNodeRuntime):
|
|
169
169
|
def command(self):
|
170
170
|
"""The command for launching the workload."""
|
171
171
|
return self.get_spec(self.CONST_COMMAND)
|
172
|
-
|
172
|
+
|
173
173
|
@property
|
174
174
|
def use_deepspeed(self):
|
175
175
|
"""Indicate whether whether to configure deepspeed for multi-node workload"""
|
176
|
-
|
177
|
-
return True
|
178
|
-
return False
|
176
|
+
return bool(self.get_spec(self.CONST_DEEPSPEED))
|
179
177
|
|
180
178
|
|
181
179
|
class PyTorchDistributedArtifact(PythonArtifact):
|