oracle-ads 2.10.1__py3-none-any.whl → 2.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. ads/aqua/__init__.py +12 -0
  2. ads/aqua/base.py +324 -0
  3. ads/aqua/cli.py +19 -0
  4. ads/aqua/config/deployment_config_defaults.json +9 -0
  5. ads/aqua/config/resource_limit_names.json +7 -0
  6. ads/aqua/constants.py +45 -0
  7. ads/aqua/data.py +40 -0
  8. ads/aqua/decorator.py +101 -0
  9. ads/aqua/deployment.py +643 -0
  10. ads/aqua/dummy_data/icon.txt +1 -0
  11. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  12. ads/aqua/dummy_data/oci_models.json +1 -0
  13. ads/aqua/dummy_data/readme.md +26 -0
  14. ads/aqua/evaluation.py +1751 -0
  15. ads/aqua/exception.py +82 -0
  16. ads/aqua/extension/__init__.py +40 -0
  17. ads/aqua/extension/base_handler.py +138 -0
  18. ads/aqua/extension/common_handler.py +21 -0
  19. ads/aqua/extension/deployment_handler.py +202 -0
  20. ads/aqua/extension/evaluation_handler.py +135 -0
  21. ads/aqua/extension/finetune_handler.py +66 -0
  22. ads/aqua/extension/model_handler.py +59 -0
  23. ads/aqua/extension/ui_handler.py +201 -0
  24. ads/aqua/extension/utils.py +23 -0
  25. ads/aqua/finetune.py +579 -0
  26. ads/aqua/job.py +29 -0
  27. ads/aqua/model.py +819 -0
  28. ads/aqua/training/__init__.py +4 -0
  29. ads/aqua/training/exceptions.py +459 -0
  30. ads/aqua/ui.py +453 -0
  31. ads/aqua/utils.py +715 -0
  32. ads/cli.py +37 -6
  33. ads/common/decorator/__init__.py +7 -3
  34. ads/common/decorator/require_nonempty_arg.py +65 -0
  35. ads/common/object_storage_details.py +166 -7
  36. ads/common/oci_client.py +18 -1
  37. ads/common/oci_logging.py +2 -2
  38. ads/common/oci_mixin.py +4 -5
  39. ads/common/serializer.py +34 -5
  40. ads/common/utils.py +75 -10
  41. ads/config.py +40 -1
  42. ads/jobs/ads_job.py +43 -25
  43. ads/jobs/builders/infrastructure/base.py +4 -2
  44. ads/jobs/builders/infrastructure/dsc_job.py +49 -39
  45. ads/jobs/builders/runtimes/base.py +71 -1
  46. ads/jobs/builders/runtimes/container_runtime.py +4 -4
  47. ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
  48. ads/jobs/templates/driver_pytorch.py +27 -10
  49. ads/model/artifact_downloader.py +84 -14
  50. ads/model/artifact_uploader.py +25 -23
  51. ads/model/datascience_model.py +388 -38
  52. ads/model/deployment/model_deployment.py +10 -2
  53. ads/model/generic_model.py +8 -0
  54. ads/model/model_file_description_schema.json +68 -0
  55. ads/model/model_metadata.py +1 -1
  56. ads/model/service/oci_datascience_model.py +34 -5
  57. ads/opctl/operator/lowcode/anomaly/README.md +2 -1
  58. ads/opctl/operator/lowcode/anomaly/__main__.py +10 -4
  59. ads/opctl/operator/lowcode/anomaly/environment.yaml +2 -1
  60. ads/opctl/operator/lowcode/anomaly/model/automlx.py +12 -6
  61. ads/opctl/operator/lowcode/forecast/README.md +3 -2
  62. ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
  63. ads/opctl/operator/lowcode/forecast/model/automlx.py +12 -23
  64. ads/telemetry/base.py +62 -0
  65. ads/telemetry/client.py +105 -0
  66. ads/telemetry/telemetry.py +6 -3
  67. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +37 -7
  68. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +71 -36
  69. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
  70. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
  71. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
ads/config.py CHANGED
@@ -1,7 +1,7 @@
1
1
  #!/usr/bin/env python
2
2
  # -*- coding: utf-8; -*-
3
3
 
4
- # Copyright (c) 2020, 2023 Oracle and/or its affiliates.
4
+ # Copyright (c) 2020, 2024 Oracle and/or its affiliates.
5
5
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
 
7
7
  import contextlib
@@ -33,8 +33,10 @@ CONDA_BUCKET_NS = os.environ.get("CONDA_BUCKET_NS", "id19sfcrra6z")
33
33
  OCI_RESOURCE_PRINCIPAL_RPT_ENDPOINT = os.environ.get(
34
34
  "OCI_RESOURCE_PRINCIPAL_RPT_ENDPOINT"
35
35
  )
36
+ PROJECT_COMPARTMENT_OCID = os.environ.get("PROJECT_COMPARTMENT_OCID")
36
37
  COMPARTMENT_OCID = (
37
38
  NB_SESSION_COMPARTMENT_OCID
39
+ or PROJECT_COMPARTMENT_OCID
38
40
  or JOB_RUN_COMPARTMENT_OCID
39
41
  or PIPELINE_RUN_COMPARTMENT_OCID
40
42
  )
@@ -47,6 +49,43 @@ RESOURCE_OCID = (
47
49
  NO_CONTAINER = os.environ.get("NO_CONTAINER")
48
50
  TMPDIR = os.environ.get("TMPDIR")
49
51
 
52
+ ODSC_MODEL_COMPARTMENT_OCID = os.environ.get("ODSC_MODEL_COMPARTMENT_OCID")
53
+ AQUA_MODEL_DEPLOYMENT_IMAGE = os.environ.get("AQUA_MODEL_DEPLOYMENT_IMAGE")
54
+ AQUA_MODEL_DEPLOYMENT_CONFIG = os.environ.get(
55
+ "AQUA_DEPLOYMENT_CONFIG", "deployment_config.json"
56
+ )
57
+ AQUA_MODEL_FINETUNING_CONFIG = os.environ.get(
58
+ "AQUA_MODEL_FINETUNING_CONFIG", "ft_config.json"
59
+ )
60
+ AQUA_CONTAINER_INDEX_CONFIG = os.environ.get(
61
+ "AQUA_CONTAINER_INDEX_CONFIG", "container_index.json"
62
+ )
63
+ AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS = os.environ.get(
64
+ "AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS", "deployment_config_defaults.json"
65
+ )
66
+ AQUA_RESOURCE_LIMIT_NAMES_CONFIG = os.environ.get(
67
+ "AQUA_RESOURCE_LIMIT_NAMES_CONFIG", "resource_limit_names.json"
68
+ )
69
+ AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME = "deployment-container"
70
+ AQUA_FINETUNING_CONTAINER_METADATA_NAME = "finetune-container"
71
+ AQUA_EVALUATION_CONTAINER_METADATA_NAME = "evaluation-container"
72
+ AQUA_MODEL_DEPLOYMENT_FOLDER = "/opt/ds/model/deployed_model/"
73
+ AQUA_SERVED_MODEL_NAME = "odsc-llm"
74
+ AQUA_CONFIG_FOLDER = os.path.join(
75
+ os.path.dirname(os.path.realpath(__file__)), "aqua/config/"
76
+ )
77
+ AQUA_JOB_SUBNET_ID = os.environ.get("AQUA_JOB_SUBNET_ID", None)
78
+ AQUA_SERVICE_MODELS_BUCKET = os.environ.get(
79
+ "AQUA_SERVICE_MODELS_BUCKET", "service-managed-models"
80
+ )
81
+ AQUA_TELEMETRY_BUCKET = os.environ.get(
82
+ "AQUA_TELEMETRY_BUCKET", "service-managed-models"
83
+ )
84
+ AQUA_TELEMETRY_BUCKET_NS = os.environ.get("AQUA_TELEMETRY_BUCKET_NS", CONDA_BUCKET_NS)
85
+ DEBUG_TELEMETRY = os.environ.get("DEBUG_TELEMETRY", None)
86
+ AQUA_SERVICE_NAME = "aqua"
87
+ DATA_SCIENCE_SERVICE_NAME = "data-science"
88
+
50
89
 
51
90
  def export(
52
91
  uri: Optional[str] = DEFAULT_CONFIG_PATH,
ads/jobs/ads_job.py CHANGED
@@ -1,8 +1,9 @@
1
1
  #!/usr/bin/env python
2
2
  # -*- coding: utf-8; -*-
3
3
 
4
- # Copyright (c) 2021, 2023 Oracle and/or its affiliates.
4
+ # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
5
5
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
+ import inspect
6
7
  import time
7
8
  from typing import List, Union, Dict
8
9
  from urllib.parse import urlparse
@@ -10,12 +11,13 @@ from urllib.parse import urlparse
10
11
  import fsspec
11
12
  import oci
12
13
  from ads.common.auth import default_signer
14
+ from ads.common.decorator.utils import class_or_instance_method
13
15
  from ads.jobs.builders.base import Builder
14
16
  from ads.jobs.builders.infrastructure.dataflow import DataFlow, DataFlowRun
15
17
  from ads.jobs.builders.infrastructure.dsc_job import (
16
- DataScienceJob,
17
- DataScienceJobRun,
18
- SLEEP_INTERVAL
18
+ DataScienceJob,
19
+ DataScienceJobRun,
20
+ SLEEP_INTERVAL,
19
21
  )
20
22
  from ads.jobs.builders.runtimes.pytorch_runtime import PyTorchDistributedRuntime
21
23
  from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
@@ -143,8 +145,10 @@ class Job(Builder):
143
145
  ]
144
146
  }
145
147
 
146
- @staticmethod
147
- def from_datascience_job(job_id) -> "Job":
148
+ auth = {}
149
+
150
+ @class_or_instance_method
151
+ def from_datascience_job(cls, job_id) -> "Job":
148
152
  """Loads a data science job from OCI.
149
153
 
150
154
  Parameters
@@ -158,16 +162,18 @@ class Job(Builder):
158
162
  A job instance.
159
163
 
160
164
  """
161
- dsc_infra = DataScienceJob.from_id(job_id)
165
+ dsc_infra = DataScienceJob(**cls.auth).from_id(job_id)
162
166
  job = (
163
- Job(name=dsc_infra.name)
167
+ Job(name=dsc_infra.name, **cls.auth)
164
168
  .with_infrastructure(dsc_infra)
165
169
  .with_runtime(dsc_infra.runtime)
166
170
  )
167
171
  return job
168
172
 
169
- @staticmethod
170
- def datascience_job(compartment_id: str = None, **kwargs) -> List["DataScienceJob"]:
173
+ @class_or_instance_method
174
+ def datascience_job(
175
+ cls, compartment_id: str = None, **kwargs
176
+ ) -> List["DataScienceJob"]:
171
177
  """Lists the existing data science jobs in the compartment.
172
178
 
173
179
  Parameters
@@ -183,10 +189,12 @@ class Job(Builder):
183
189
  A list of Job objects.
184
190
  """
185
191
  return [
186
- Job(name=dsc_job.name)
192
+ Job(name=dsc_job.name, **cls.auth)
187
193
  .with_infrastructure(dsc_job)
188
194
  .with_runtime(dsc_job.runtime)
189
- for dsc_job in DataScienceJob.list_jobs(compartment_id, **kwargs)
195
+ for dsc_job in DataScienceJob(**cls.auth).list_jobs(
196
+ compartment_id, **kwargs
197
+ )
190
198
  ]
191
199
 
192
200
  @staticmethod
@@ -229,7 +237,9 @@ class Job(Builder):
229
237
  for df in DataFlow.list_jobs(compartment_id, **kwargs)
230
238
  ]
231
239
 
232
- def __init__(self, name: str = None, infrastructure=None, runtime=None) -> None:
240
+ def __init__(
241
+ self, name: str = None, infrastructure=None, runtime=None, **kwargs
242
+ ) -> None:
233
243
  """Initializes a job.
234
244
 
235
245
  The infrastructure and runtime can be configured when initializing the job,
@@ -253,6 +263,9 @@ class Job(Builder):
253
263
  Job runtime, by default None.
254
264
 
255
265
  """
266
+ for key in ["config", "signer", "client_kwargs"]:
267
+ if kwargs.get(key):
268
+ self.auth[key] = kwargs.pop(key)
256
269
  super().__init__()
257
270
  if name:
258
271
  self.set_spec("name", name)
@@ -398,7 +411,7 @@ class Job(Builder):
398
411
  freeform_tags=None,
399
412
  defined_tags=None,
400
413
  wait=False,
401
- **kwargs
414
+ **kwargs,
402
415
  ) -> Union[DataScienceJobRun, DataFlowRun]:
403
416
  """Runs the job.
404
417
 
@@ -454,7 +467,7 @@ class Job(Builder):
454
467
  freeform_tags=freeform_tags,
455
468
  defined_tags=defined_tags,
456
469
  wait=wait,
457
- **kwargs
470
+ **kwargs,
458
471
  )
459
472
 
460
473
  def run_list(self, **kwargs) -> list:
@@ -466,7 +479,7 @@ class Job(Builder):
466
479
  A list of job run instances, the actual object type depends on the infrastructure.
467
480
  """
468
481
  return self.infrastructure.run_list(**kwargs)
469
-
482
+
470
483
  def cancel(self, wait_for_completion: bool = True) -> None:
471
484
  """Cancels the runs of the job.
472
485
 
@@ -479,16 +492,16 @@ class Job(Builder):
479
492
  runs = self.run_list()
480
493
  for run in runs:
481
494
  run.cancel(wait_for_completion=False)
482
-
495
+
483
496
  if wait_for_completion:
484
497
  for run in runs:
485
498
  while (
486
- run.lifecycle_state !=
487
- oci.data_science.models.JobRun.LIFECYCLE_STATE_CANCELED
499
+ run.lifecycle_state
500
+ != oci.data_science.models.JobRun.LIFECYCLE_STATE_CANCELED
488
501
  ):
489
502
  run.sync()
490
503
  time.sleep(SLEEP_INTERVAL)
491
-
504
+
492
505
  def delete(self) -> None:
493
506
  """Deletes the job from the infrastructure."""
494
507
  self.infrastructure.delete()
@@ -532,7 +545,7 @@ class Job(Builder):
532
545
  "spec": spec,
533
546
  }
534
547
 
535
- @classmethod
548
+ @class_or_instance_method
536
549
  def from_dict(cls, config: dict) -> "Job":
537
550
  """Initializes a job from a dictionary containing the configurations.
538
551
 
@@ -559,7 +572,10 @@ class Job(Builder):
559
572
  "infrastructure": cls._INFRASTRUCTURE_MAPPING,
560
573
  "runtime": cls._RUNTIME_MAPPING,
561
574
  }
562
- job = cls()
575
+ if inspect.isclass(cls):
576
+ job = cls()
577
+ else:
578
+ job = cls.__class__()
563
579
 
564
580
  for key, value in spec.items():
565
581
  if key in mappings:
@@ -569,9 +585,11 @@ class Job(Builder):
569
585
  raise NotImplementedError(
570
586
  f"{key.title()} type: {child_config.get('type')} is not supported."
571
587
  )
572
- job.set_spec(
573
- key, mapping[child_config.get("type")].from_dict(child_config)
574
- )
588
+ spec_class = mapping[child_config.get("type")]
589
+ if spec_class == DataScienceJob:
590
+ spec_class = DataScienceJob(**job.auth)
591
+
592
+ job.set_spec(key, spec_class.from_dict(child_config))
575
593
  else:
576
594
  job.set_spec(key, value)
577
595
 
@@ -1,10 +1,11 @@
1
1
  #!/usr/bin/env python
2
2
  # -*- coding: utf-8; -*-
3
3
 
4
- # Copyright (c) 2021, 2023 Oracle and/or its affiliates.
4
+ # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
5
5
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
 
7
7
  from ads.common import utils as common_utils
8
+ from ads.common.decorator.utils import class_or_instance_method
8
9
  from ads.jobs.builders.base import Builder
9
10
  from ads.jobs.builders.runtimes.base import Runtime
10
11
  import logging
@@ -97,7 +98,8 @@ class Infrastructure(Builder):
97
98
  """
98
99
  raise NotImplementedError()
99
100
 
100
- def list_jobs(self, **kwargs) -> list:
101
+ @class_or_instance_method
102
+ def list_jobs(cls, **kwargs) -> list:
101
103
  """
102
104
  List jobs from the infrastructure.
103
105
 
@@ -1,7 +1,7 @@
1
1
  #!/usr/bin/env python
2
2
  # -*- coding: utf-8; -*-
3
3
 
4
- # Copyright (c) 2021, 2023 Oracle and/or its affiliates.
4
+ # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
5
5
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
  from __future__ import annotations
7
7
 
@@ -40,6 +40,7 @@ from ads.common.dsc_file_system import (
40
40
  DSCFileSystemManager,
41
41
  OCIObjectStorage,
42
42
  )
43
+ from ads.common.decorator.utils import class_or_instance_method
43
44
 
44
45
  logger = logging.getLogger(__name__)
45
46
 
@@ -171,10 +172,12 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
171
172
  )
172
173
  subnet_id = infra.get(
173
174
  "subnetId",
174
- nb_config.subnet_id
175
- if infra_type
176
- != JobInfrastructureConfigurationDetails.JOB_INFRASTRUCTURE_TYPE_ME_STANDALONE
177
- else None,
175
+ (
176
+ nb_config.subnet_id
177
+ if infra_type
178
+ != JobInfrastructureConfigurationDetails.JOB_INFRASTRUCTURE_TYPE_ME_STANDALONE
179
+ else None
180
+ ),
178
181
  )
179
182
  job_shape_config_details = infra.get("jobShapeConfigDetails", {})
180
183
  memory_in_gbs = job_shape_config_details.get(
@@ -302,9 +305,9 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
302
305
  return self
303
306
 
304
307
  def _create_with_oci_api(self) -> None:
305
- res = self.client.create_job(
306
- self.to_oci_model(oci.data_science.models.CreateJobDetails)
307
- )
308
+ oci_model = self.to_oci_model(oci.data_science.models.CreateJobDetails)
309
+ logger.debug(oci_model)
310
+ res = self.client.create_job(oci_model)
308
311
  self.update_from_oci_model(res.data)
309
312
  if self.lifecycle_state == "ACTIVE":
310
313
  return
@@ -353,9 +356,16 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
353
356
  """Updates the Data Science Job."""
354
357
  raise NotImplementedError("Updating Job is not supported at the moment.")
355
358
 
356
- def delete(self) -> DSCJob:
359
+ def delete(self, force_delete: bool = False) -> DSCJob:
357
360
  """Deletes the job and the corresponding job runs.
358
361
 
362
+ Parameters
363
+ ----------
364
+ force_delete : bool, optional, defaults to False
365
+ The deletion fails when associated job runs are in progress. If force_delete is set to True,
366
+ the in-progress job runs will be canceled first and then deleted. In this case, deleting the job
367
+ has to wait until the job runs have been canceled.
368
+
359
369
  Returns
360
370
  -------
361
371
  DSCJob
@@ -364,6 +374,12 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
364
374
  """
365
375
  runs = self.run_list()
366
376
  for run in runs:
377
+ if run.lifecycle_state in [
378
+ DataScienceJobRun.LIFECYCLE_STATE_ACCEPTED,
379
+ DataScienceJobRun.LIFECYCLE_STATE_IN_PROGRESS,
380
+ DataScienceJobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
381
+ ]:
382
+ run.cancel(wait_for_completion=True)
367
383
  run.delete()
368
384
  self.client.delete_job(self.id)
369
385
  return self
@@ -430,7 +446,7 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
430
446
  items = oci.pagination.list_call_get_all_results(
431
447
  self.client.list_job_runs, self.compartment_id, job_id=self.id, **kwargs
432
448
  ).data
433
- return [DataScienceJobRun.from_oci_model(item) for item in items]
449
+ return [DataScienceJobRun(**self.auth).from_oci_model(item) for item in items]
434
450
 
435
451
  def run(self, **kwargs) -> DataScienceJobRun:
436
452
  """Runs the job
@@ -496,28 +512,11 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
496
512
  kwargs["job_configuration_override_details"] = config_override
497
513
 
498
514
  wait = kwargs.pop("wait", False)
499
- run = DataScienceJobRun(**kwargs).create()
515
+ run = DataScienceJobRun(**kwargs, **self.auth).create()
500
516
  if wait:
501
517
  return run.watch()
502
518
  return run
503
519
 
504
- @classmethod
505
- def from_ocid(cls, ocid) -> DSCJob:
506
- """Gets a job by OCID
507
-
508
- Parameters
509
- ----------
510
- ocid : str
511
- The OCID of the job.
512
-
513
- Returns
514
- -------
515
- DSCJob
516
- An instance of DSCJob.
517
-
518
- """
519
- return super().from_ocid(ocid)
520
-
521
520
 
522
521
  class DataScienceJobRun(
523
522
  OCIDataScienceMixin, oci.data_science.models.JobRun, RunInstance
@@ -575,7 +574,12 @@ class DataScienceJobRun(
575
574
  if not self.log_id:
576
575
  raise ValueError("Log OCID is not specified for this job run.")
577
576
  # Specifying log group ID when initializing OCILog can reduce the number of API calls.
578
- return OCILog(id=self.log_id, log_group_id=self.log_details.log_group_id)
577
+ auth = self.auth
578
+ if "client_kwargs" in auth and isinstance(auth["client_kwargs"], dict):
579
+ auth["client_kwargs"].pop("service_endpoint", None)
580
+ return OCILog(
581
+ id=self.log_id, log_group_id=self.log_details.log_group_id, **auth
582
+ )
579
583
 
580
584
  @staticmethod
581
585
  def _format_log(message: str, date_time: datetime.datetime) -> dict:
@@ -928,6 +932,8 @@ class DataScienceJob(Infrastructure):
928
932
  OBJECT_STORAGE_TYPE: OCIObjectStorage,
929
933
  }
930
934
 
935
+ auth = {}
936
+
931
937
  @staticmethod
932
938
  def standardize_spec(spec):
933
939
  if not spec:
@@ -965,12 +971,16 @@ class DataScienceJob(Infrastructure):
965
971
  Specification as keyword arguments.
966
972
  If spec contains the same key as the one in kwargs, the value from kwargs will be used.
967
973
  """
974
+ for key in ["config", "signer", "client_kwargs"]:
975
+ if kwargs.get(key):
976
+ self.auth[key] = kwargs.pop(key)
977
+
968
978
  self.standardize_spec(spec)
969
979
  self.standardize_spec(kwargs)
970
980
  super().__init__(spec=spec, **kwargs)
971
981
  if not self.job_type:
972
982
  self.with_job_type("DEFAULT")
973
- self.dsc_job = DSCJob()
983
+ self.dsc_job = DSCJob(**self.auth)
974
984
  self.runtime = None
975
985
  self._name = None
976
986
 
@@ -1557,7 +1567,7 @@ class DataScienceJob(Infrastructure):
1557
1567
  if not payload.get("defined_tags"):
1558
1568
  payload["defined_tags"] = self.defined_tags
1559
1569
 
1560
- self.dsc_job = DSCJob(**payload)
1570
+ self.dsc_job = DSCJob(**payload, **self.auth)
1561
1571
  # Set Job infra to user values after DSCJob initialized the defaults
1562
1572
  self._update_job_infra(self.dsc_job)
1563
1573
  self.dsc_job.create()
@@ -1682,7 +1692,7 @@ class DataScienceJob(Infrastructure):
1682
1692
  instance.runtime = DataScienceJobRuntimeManager(instance).extract(dsc_job)
1683
1693
  return instance
1684
1694
 
1685
- @classmethod
1695
+ @class_or_instance_method
1686
1696
  def from_id(cls, job_id: str) -> DataScienceJob:
1687
1697
  """Gets an existing job using Job OCID
1688
1698
 
@@ -1698,9 +1708,9 @@ class DataScienceJob(Infrastructure):
1698
1708
  An instance of DataScienceJob
1699
1709
 
1700
1710
  """
1701
- return cls.from_dsc_job(DSCJob.from_ocid(job_id))
1711
+ return cls.from_dsc_job(DSCJob(**cls.auth).from_ocid(job_id))
1702
1712
 
1703
- @classmethod
1713
+ @class_or_instance_method
1704
1714
  def list_jobs(cls, compartment_id: str = None, **kwargs) -> List[DataScienceJob]:
1705
1715
  """Lists all jobs in a compartment.
1706
1716
 
@@ -1721,10 +1731,10 @@ class DataScienceJob(Infrastructure):
1721
1731
  """
1722
1732
  return [
1723
1733
  cls.from_dsc_job(job)
1724
- for job in DSCJob.list_resource(compartment_id, **kwargs)
1734
+ for job in DSCJob(**cls.auth).list_resource(compartment_id, **kwargs)
1725
1735
  ]
1726
1736
 
1727
- @classmethod
1737
+ @class_or_instance_method
1728
1738
  def instance_shapes(cls, compartment_id: str = None, **kwargs) -> list:
1729
1739
  """Lists the supported shapes for running jobs in a compartment.
1730
1740
 
@@ -1752,13 +1762,13 @@ class DataScienceJob(Infrastructure):
1752
1762
 
1753
1763
  """
1754
1764
  shapes = oci.pagination.list_call_get_all_results(
1755
- DSCJob.init_client().list_job_shapes,
1765
+ DSCJob(**cls.auth).init_client().list_job_shapes,
1756
1766
  DSCJob.check_compartment_id(compartment_id),
1757
1767
  **kwargs,
1758
1768
  ).data
1759
1769
  return shapes
1760
1770
 
1761
- @classmethod
1771
+ @class_or_instance_method
1762
1772
  def fast_launch_shapes(cls, compartment_id: str = None, **kwargs) -> list:
1763
1773
  """Lists the supported fast launch shapes for running jobs in a compartment.
1764
1774
 
@@ -1786,7 +1796,7 @@ class DataScienceJob(Infrastructure):
1786
1796
 
1787
1797
  """
1788
1798
  shapes = oci.pagination.list_call_get_all_results(
1789
- DSCJob.init_client().list_fast_launch_job_configs,
1799
+ DSCJob(**cls.auth).init_client().list_fast_launch_job_configs,
1790
1800
  DSCJob.check_compartment_id(compartment_id),
1791
1801
  **kwargs,
1792
1802
  ).data
@@ -1,10 +1,13 @@
1
1
  #!/usr/bin/env python
2
2
  # -*- coding: utf-8; -*-
3
3
 
4
- # Copyright (c) 2022, 2023 Oracle and/or its affiliates.
4
+ # Copyright (c) 2022, 2024 Oracle and/or its affiliates.
5
5
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
  from __future__ import annotations
7
7
  import re
8
+ import time
9
+ import traceback
10
+
8
11
  from typing import Dict, TypeVar
9
12
  from ads.jobs.builders.base import Builder
10
13
  from ads.jobs import env_var_parser
@@ -253,3 +256,70 @@ class Runtime(Builder):
253
256
  )
254
257
  .with_argument(**kwargs.get(self.attribute_map[self.CONST_ARGS], {}))
255
258
  )
259
+
260
+
261
+ class MultiNodeRuntime(Runtime):
262
+ """Represents runtime supporting multi-node jobs."""
263
+
264
+ CONST_REPLICA = "replicas"
265
+
266
+ def with_replica(self, count: int):
267
+ """Specifies the number of nodes (job runs) for the job.
268
+
269
+ Parameters
270
+ ----------
271
+ count : int
272
+ Number of nodes (job runs)
273
+
274
+ Returns
275
+ -------
276
+ self
277
+ The runtime instance.
278
+ """
279
+ return self.set_spec(self.CONST_REPLICA, count)
280
+
281
+ @property
282
+ def replica(self) -> int:
283
+ """The number of nodes (job runs)."""
284
+ return self.get_spec(self.CONST_REPLICA)
285
+
286
+ def run(self, dsc_job, **kwargs):
287
+ """Starts the job runs"""
288
+ replicas = self.replica if self.replica else 1
289
+ main_run = None
290
+ job_runs = []
291
+ try:
292
+ for i in range(replicas):
293
+ replica_kwargs = kwargs.copy()
294
+
295
+ # Only update display name and env vars if replica is specified (not None).
296
+ if self.replica is not None:
297
+ envs = replica_kwargs.get("environment_variables")
298
+ if not envs:
299
+ envs = {}
300
+ # HuggingFace accelerate requires machine rank
301
+ # Here we use NODE_RANK to store the machine rank
302
+ envs["NODE_RANK"] = str(i)
303
+ envs["NODE_COUNT"] = str(replicas)
304
+ if main_run:
305
+ envs["MAIN_JOB_RUN_OCID"] = main_run.id
306
+ name = replica_kwargs.get("display_name")
307
+ if not name:
308
+ name = dsc_job.display_name
309
+
310
+ replica_kwargs["display_name"] = f"{name}-{str(i)}"
311
+ replica_kwargs["environment_variables"] = envs
312
+ run = dsc_job.run(**replica_kwargs)
313
+ job_runs.append(run)
314
+ if i == 0:
315
+ main_run = run
316
+ except Exception as ex:
317
+ traceback.print_exc()
318
+ # Wait a few seconds to avoid the job run being in a transient state.
319
+ time.sleep(2)
320
+ # If there is any error when creating the job runs
321
+ # cancel all the job runs.
322
+ for run in job_runs:
323
+ run.cancel()
324
+ raise ex
325
+ return main_run
@@ -1,13 +1,13 @@
1
1
  #!/usr/bin/env python
2
2
  # -*- coding: utf-8; -*-
3
3
 
4
- # Copyright (c) 2021, 2023 Oracle and/or its affiliates.
4
+ # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
5
5
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
  from typing import Union
7
- from ads.jobs.builders.runtimes.base import Runtime
7
+ from ads.jobs.builders.runtimes.base import MultiNodeRuntime
8
8
 
9
9
 
10
- class ContainerRuntime(Runtime):
10
+ class ContainerRuntime(MultiNodeRuntime):
11
11
  """Represents a container job runtime
12
12
 
13
13
  To define container runtime:
@@ -51,7 +51,7 @@ class ContainerRuntime(Runtime):
51
51
  CONST_ENTRYPOINT: CONST_ENTRYPOINT,
52
52
  CONST_CMD: CONST_CMD,
53
53
  }
54
- attribute_map.update(Runtime.attribute_map)
54
+ attribute_map.update(MultiNodeRuntime.attribute_map)
55
55
 
56
56
  @property
57
57
  def image(self) -> str: