oracle-ads 2.11.16__py3-none-any.whl → 2.11.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- ads/aqua/app.py +5 -6
- ads/aqua/common/enums.py +9 -0
- ads/aqua/common/utils.py +128 -1
- ads/aqua/constants.py +1 -0
- ads/aqua/evaluation/evaluation.py +1 -1
- ads/aqua/extension/common_handler.py +75 -5
- ads/aqua/extension/deployment_handler.py +2 -0
- ads/aqua/extension/model_handler.py +113 -12
- ads/aqua/model/entities.py +20 -2
- ads/aqua/model/model.py +417 -172
- ads/aqua/modeldeployment/deployment.py +69 -55
- ads/common/auth.py +4 -4
- ads/jobs/builders/infrastructure/dsc_job.py +23 -14
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +12 -25
- ads/jobs/builders/runtimes/artifact.py +0 -5
- ads/jobs/builders/runtimes/container_runtime.py +26 -3
- ads/opctl/conda/cmds.py +100 -42
- ads/opctl/conda/pack.py +3 -2
- ads/opctl/operator/lowcode/anomaly/const.py +1 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +58 -37
- ads/opctl/operator/lowcode/anomaly/model/factory.py +2 -0
- ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +116 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +1 -0
- ads/opctl/operator/lowcode/forecast/const.py +1 -1
- ads/opctl/operator/lowcode/forecast/model/arima.py +9 -3
- ads/opctl/operator/lowcode/forecast/model/automlx.py +6 -1
- ads/opctl/operator/lowcode/forecast/model/autots.py +3 -1
- ads/opctl/operator/lowcode/forecast/model/factory.py +1 -1
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +24 -15
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +6 -1
- ads/opctl/operator/lowcode/forecast/model/prophet.py +3 -1
- ads/opctl/operator/lowcode/forecast/schema.yaml +1 -1
- {oracle_ads-2.11.16.dist-info → oracle_ads-2.11.18.dist-info}/METADATA +5 -1
- {oracle_ads-2.11.16.dist-info → oracle_ads-2.11.18.dist-info}/RECORD +37 -36
- {oracle_ads-2.11.16.dist-info → oracle_ads-2.11.18.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.11.16.dist-info → oracle_ads-2.11.18.dist-info}/WHEEL +0 -0
- {oracle_ads-2.11.16.dist-info → oracle_ads-2.11.18.dist-info}/entry_points.txt +0 -0
ads/aqua/modeldeployment/deployment.py
CHANGED
@@ -87,25 +87,26 @@ class AquaDeploymentApp(AquaApp):
 
     @telemetry(entry_point="plugin=deployment&action=create", name="aqua")
     def create(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        model_id: str,
+        instance_shape: str,
+        display_name: str,
+        instance_count: int = None,
+        log_group_id: str = None,
+        access_log_id: str = None,
+        predict_log_id: str = None,
+        compartment_id: str = None,
+        project_id: str = None,
+        description: str = None,
+        bandwidth_mbps: int = None,
+        web_concurrency: int = None,
+        server_port: int = None,
+        health_check_port: int = None,
+        env_var: Dict = None,
+        container_family: str = None,
+        memory_in_gbs: Optional[float] = None,
+        ocpus: Optional[float] = None,
+        model_file: Optional[str] = None,
     ) -> "AquaDeployment":
         """
         Creates a new Aqua deployment
@@ -150,6 +151,8 @@ class AquaDeploymentApp(AquaApp):
             The memory in gbs for the shape selected.
         ocpus: float
             The ocpu count for the shape selected.
+        model_file: str
+            The file used for model deployment.
         Returns
         -------
         AquaDeployment
@@ -212,21 +215,6 @@ class AquaDeploymentApp(AquaApp):
 
         env_var.update({"BASE_MODEL": f"{model_path_prefix}"})
 
-        model_format = aqua_model.freeform_tags.get(
-            Tags.MODEL_FORMAT, ModelFormat.SAFETENSORS.value
-        ).upper()
-        if model_format == ModelFormat.GGUF.value:
-            try:
-                model_file = aqua_model.custom_metadata_list.get(
-                    AQUA_MODEL_ARTIFACT_FILE
-                ).value
-            except ValueError as err:
-                raise AquaValueError(
-                    f"{AQUA_MODEL_ARTIFACT_FILE} key is not available in the custom metadata field "
-                    f"for model {aqua_model.id}."
-                ) from err
-            env_var.update({"BASE_MODEL_FILE": f"{model_file}"})
-
         if is_fine_tuned_model:
             _, fine_tune_output_path = get_model_by_reference_paths(
                 aqua_model.model_file_description
@@ -262,7 +250,7 @@ class AquaDeploymentApp(AquaApp):
         try:
             # Check if the container override flag is set. If set, then the user has chosen custom image
             if aqua_model.custom_metadata_list.get(
-
+                AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME
             ).value:
                 is_custom_container = True
         except Exception:
@@ -281,6 +269,32 @@ class AquaDeploymentApp(AquaApp):
             f"Aqua Image used for deploying {aqua_model.id} : {container_image}"
         )
 
+        model_formats_str = aqua_model.freeform_tags.get(
+            Tags.MODEL_FORMAT, ModelFormat.SAFETENSORS.value
+        ).upper()
+        model_format = model_formats_str.split(",")
+
+        # Figure out a better way to handle this in future release
+        if ModelFormat.GGUF.value in model_format and container_type_key.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY:
+            if model_file is not None:
+                logger.info(
+                    f"Overriding {model_file} as model_file for model {aqua_model.id}."
+                )
+            else:
+                try:
+                    model_file = aqua_model.custom_metadata_list.get(
+                        AQUA_MODEL_ARTIFACT_FILE
+                    ).value
+                except ValueError as err:
+                    raise AquaValueError(
+                        f"{AQUA_MODEL_ARTIFACT_FILE} key is not available in the custom metadata field "
+                        f"for model {aqua_model.id}. Either register the model with a default model_file or pass "
+                        f"as a parameter when creating a deployment."
+                    ) from err
+
+            env_var.update({"BASE_MODEL_FILE": f"{model_file}"})
+            tags.update({Tags.MODEL_ARTIFACT_FILE: model_file})
+
         # todo: use AquaContainerConfig.from_container_index_json instead.
         # Fetch the startup cli command for the container
         # container_index.json will have "containerSpec" section which will provide the cli params for
@@ -312,8 +326,8 @@ class AquaDeploymentApp(AquaApp):
         if user_params:
             # todo: remove this check in the future version, logic to be moved to container_index
             if (
-
-
+                container_type_key.lower()
+                == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
             ):
                 # AQUA_LLAMA_CPP_CONTAINER_FAMILY container uses uvicorn that required model/server params
                 # to be set as env vars
@@ -446,8 +460,8 @@ class AquaDeploymentApp(AquaApp):
         for model_deployment in model_deployments:
             oci_aqua = (
                 (
-
-
+                    Tags.AQUA_TAG in model_deployment.freeform_tags
+                    or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
                 )
                 if model_deployment.freeform_tags
                 else False
@@ -501,8 +515,8 @@ class AquaDeploymentApp(AquaApp):
 
         oci_aqua = (
             (
-
-
+                Tags.AQUA_TAG in model_deployment.freeform_tags
+                or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
             )
             if model_deployment.freeform_tags
            else False
@@ -519,8 +533,8 @@ class AquaDeploymentApp(AquaApp):
         log_group_name = ""
 
         logs = (
-
-
+            model_deployment.category_log_details.access
+            or model_deployment.category_log_details.predict
         )
         if logs:
             log_id = logs.log_id
@@ -575,9 +589,9 @@ class AquaDeploymentApp(AquaApp):
         return config
 
     def get_deployment_default_params(
-
-
-
+        self,
+        model_id: str,
+        instance_shape: str,
     ) -> List[str]:
         """Gets the default params set in the deployment configs for the given model and instance shape.
 
@@ -609,8 +623,8 @@ class AquaDeploymentApp(AquaApp):
         )
 
         if (
-
-
+            container_type_key
+            and container_type_key in InferenceContainerTypeFamily.values()
         ):
             deployment_config = self.get_deployment_config(model_id)
             config_params = (
@@ -633,10 +647,10 @@ class AquaDeploymentApp(AquaApp):
         return default_params
 
     def validate_deployment_params(
-
-
-
-
+        self,
+        model_id: str,
+        params: List[str] = None,
+        container_family: str = None,
    ) -> Dict:
        """Validate if the deployment parameters passed by the user can be overridden. Parameter values are not
        validated, only param keys are validated.
@@ -695,9 +709,9 @@ class AquaDeploymentApp(AquaApp):
 
     @staticmethod
     def _find_restricted_params(
-
-
-
+        default_params: Union[str, List[str]],
+        user_params: Union[str, List[str]],
+        container_family: str,
     ) -> List[str]:
        """Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
        The default parameters coming from the container index json file cannot be overridden.
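Note: the hunks above add an optional model_file argument to AquaDeploymentApp.create() and move the GGUF handling so that a user-supplied file overrides the AQUA_MODEL_ARTIFACT_FILE custom metadata. A minimal usage sketch, assuming hypothetical OCID, shape, and file name values (not taken from this diff):

```python
from ads.aqua.modeldeployment.deployment import AquaDeploymentApp

# All values below are placeholders for illustration only.
deployment = AquaDeploymentApp().create(
    model_id="ocid1.datasciencemodel.oc1..<unique_id>",
    instance_shape="VM.Standard.E4.Flex",
    display_name="llama-cpp-gguf-deployment",
    # New optional argument: for GGUF models on the llama.cpp container family,
    # this overrides the AQUA_MODEL_ARTIFACT_FILE value recorded at registration.
    model_file="model-q4_k_m.gguf",
)
```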
ads/common/auth.py
CHANGED
@@ -73,7 +73,7 @@ class AuthState(metaclass=SingletonMeta):
         self.oci_key_profile = self.oci_key_profile or os.environ.get(
             "OCI_CONFIG_PROFILE", DEFAULT_PROFILE
         )
-        self.oci_config = self.oci_config or {}
+        self.oci_config = self.oci_config or {"region": os.environ["OCI_RESOURCE_REGION"]} if os.environ.get("OCI_RESOURCE_REGION") else {}
         self.oci_signer_kwargs = self.oci_signer_kwargs or {}
         self.oci_client_kwargs = self.oci_client_kwargs or {}
 
@@ -82,7 +82,7 @@ def set_auth(
     auth: Optional[str] = AuthType.API_KEY,
     oci_config_location: Optional[str] = DEFAULT_LOCATION,
     profile: Optional[str] = DEFAULT_PROFILE,
-    config: Optional[Dict] = {},
+    config: Optional[Dict] = {"region": os.environ["OCI_RESOURCE_REGION"]} if os.environ.get("OCI_RESOURCE_REGION") else {},
     signer: Optional[Any] = None,
     signer_callable: Optional[Callable] = None,
     signer_kwargs: Optional[Dict] = {},
@@ -678,7 +678,7 @@ class ResourcePrincipal(AuthSignerGenerator):
         >>> signer_generator = AuthFactory().signerGenerator(AuthType.RESOURCE_PRINCIPAL)
         >>> signer_generator(signer_args).create_signer()
         """
-        configuration = ads.telemetry.update_oci_client_config()
+        configuration = ads.telemetry.update_oci_client_config(AuthState().oci_config)
         signer_dict = {
             "config": configuration,
             "signer": oci.auth.signers.get_resource_principals_signer(),
@@ -739,7 +739,7 @@ class InstancePrincipal(AuthSignerGenerator):
         >>> signer_generator = AuthFactory().signerGenerator(AuthType.INSTANCE_PRINCIPAL)
         >>> signer_generator(signer_args).create_signer()
         """
-        configuration = ads.telemetry.update_oci_client_config()
+        configuration = ads.telemetry.update_oci_client_config(AuthState().oci_config)
         signer_dict = {
             "config": configuration,
             "signer": oci.auth.signers.InstancePrincipalsSecurityTokenSigner(
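The auth.py changes seed the default config with the region from OCI_RESOURCE_REGION (when set) and pass AuthState().oci_config through to the resource- and instance-principal signers. A brief sketch of the effect, using a placeholder region:

```python
import os
import ads

# Placeholder region; inside an OCI Data Science job run or notebook session
# this variable is typically set by the platform.
os.environ["OCI_RESOURCE_REGION"] = "us-ashburn-1"

# With this change, the default config becomes {"region": "us-ashburn-1"} instead
# of {}, so resource-principal clients are created against that region.
ads.set_auth(auth="resource_principal")
```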
ads/jobs/builders/infrastructure/dsc_job.py
CHANGED
@@ -312,7 +312,7 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
         logger.debug(oci_model)
         res = self.client.create_job(oci_model)
         self.update_from_oci_model(res.data)
-        if self.
+        if not self.artifact:
             return
         try:
             if issubclass(self.artifact.__class__, Artifact):
@@ -377,13 +377,12 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
         """
         runs = self.run_list()
         for run in runs:
-            if force_delete
-
-
-
-
-
-                run.cancel(wait_for_completion=True)
+            if force_delete and run.lifecycle_state in [
+                DataScienceJobRun.LIFECYCLE_STATE_ACCEPTED,
+                DataScienceJobRun.LIFECYCLE_STATE_IN_PROGRESS,
+                DataScienceJobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
+            ]:
+                run.cancel(wait_for_completion=True)
             run.delete()
         self.client.delete_job(self.id)
         return self
@@ -488,7 +487,9 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
             oci.data_science.models.DefaultJobConfigurationDetails().swagger_types.keys()
         )
         env_config_swagger_types = {}
-        if hasattr(
+        if hasattr(
+            oci.data_science.models, "OcirContainerJobEnvironmentConfigurationDetails"
+        ):
             env_config_swagger_types = (
                 oci.data_science.models.OcirContainerJobEnvironmentConfigurationDetails().swagger_types.keys()
             )
@@ -502,7 +503,7 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
             value = kwargs.pop(key)
             if key in [
                 ContainerRuntime.CONST_CMD,
-                ContainerRuntime.CONST_ENTRYPOINT
+                ContainerRuntime.CONST_ENTRYPOINT,
             ] and isinstance(value, str):
                 value = ContainerRuntimeHandler.split_args(value)
             env_config_kwargs[key] = value
@@ -536,9 +537,13 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
 
         if env_config_kwargs:
             env_config_kwargs["jobEnvironmentType"] = "OCIR_CONTAINER"
-            env_config_override = kwargs.get(
+            env_config_override = kwargs.get(
+                "job_environment_configuration_override_details", {}
+            )
             env_config_override.update(env_config_kwargs)
-            kwargs["job_environment_configuration_override_details"] =
+            kwargs["job_environment_configuration_override_details"] = (
+                env_config_override
+            )
 
         wait = kwargs.pop("wait", False)
         run = DataScienceJobRun(**kwargs, **self.auth).create()
@@ -894,10 +899,14 @@ class DataScienceJobRun(
         return self
 
     def delete(self, force_delete: bool = False):
-        if force_delete
+        if force_delete and self.status in [
+            DataScienceJobRun.LIFECYCLE_STATE_ACCEPTED,
+            DataScienceJobRun.LIFECYCLE_STATE_IN_PROGRESS,
+            DataScienceJobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
+        ]:
             self.cancel(wait_for_completion=True)
         super().delete()
-        return
+        return self
 
 
 # This is for backward compatibility
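With the dsc_job.py changes above, force_delete only cancels runs that are still active (ACCEPTED, IN_PROGRESS, or NEEDS_ATTENTION), and DataScienceJobRun.delete() now returns the run. A hedged sketch of how a caller might exercise this, assuming a placeholder OCID and the from_ocid loader available on ADS OCI resource mixins:

```python
from ads.jobs.builders.infrastructure.dsc_job import DataScienceJobRun

# Placeholder OCID for illustration.
run = DataScienceJobRun.from_ocid("ocid1.datasciencejobrun.oc1..<unique_id>")

# The run is cancelled first only if it is still in an active lifecycle state,
# then deleted; the method now returns the run object itself.
run.delete(force_delete=True)
```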
ads/jobs/builders/infrastructure/dsc_job_runtime.py
CHANGED
@@ -181,9 +181,9 @@ class RuntimeHandler:
             "jobType": self.data_science_job.job_type,
         }
         if runtime.maximum_runtime_in_minutes:
-            job_configuration_details[
-
-
+            job_configuration_details["maximum_runtime_in_minutes"] = (
+                runtime.maximum_runtime_in_minutes
+            )
         job_configuration_details["environment_variables"] = self._translate_env(
             runtime
         )
@@ -310,7 +310,7 @@ class RuntimeHandler:
         for extraction in extractions:
             runtime_spec.update(extraction(dsc_job))
         return self.RUNTIME_CLASS(self._format_env_var(runtime_spec))
-
+
     def _extract_properties(self, dsc_job) -> dict:
         """Extract the job runtime properties from data science job.
 
@@ -968,23 +968,10 @@ class ContainerRuntimeHandler(RuntimeHandler):
         payload["job_environment_configuration_details"] = job_env_config
         return payload
 
-    def _translate_artifact(self, runtime:
-        """
-        runtime
-
-        Parameters
-        ----------
-        runtime : Runtime
-            This is not used.
-
-        Returns
-        -------
-        str
-            Path to the dummy script.
-        """
-        return os.path.join(
-            os.path.dirname(__file__), "../../templates", "container.py"
-        )
+    def _translate_artifact(self, runtime: ContainerRuntime):
+        """Additional artifact for the container"""
+        if runtime.artifact_uri:
+            return ScriptArtifact(runtime.artifact_uri, runtime)
 
     def _translate_env_config(self, runtime: Runtime) -> dict:
         """Converts runtime properties to ``OcirContainerJobEnvironmentConfigurationDetails`` payload required by OCI Data Science job.
@@ -1007,7 +994,7 @@ class ContainerRuntimeHandler(RuntimeHandler):
             property = runtime.get_spec(key, None)
             if key in [
                 ContainerRuntime.CONST_CMD,
-                ContainerRuntime.CONST_ENTRYPOINT
+                ContainerRuntime.CONST_ENTRYPOINT,
             ] and isinstance(property, str):
                 property = self.split_args(property)
             if property is not None:
@@ -1063,7 +1050,7 @@ class ContainerRuntimeHandler(RuntimeHandler):
         spec[ContainerRuntime.CONST_ENV_VAR] = envs
 
         return spec
-
+
     def _extract_properties(self, dsc_job) -> dict:
         """Extract the runtime properties from data science job.
 
@@ -1078,10 +1065,10 @@ class ContainerRuntimeHandler(RuntimeHandler):
             A runtime specification dictionary for initializing a runtime.
         """
         spec = super()._extract_envs(dsc_job)
-
+
         job_env_config = getattr(dsc_job, "job_environment_configuration_details", None)
         job_env_type = getattr(job_env_config, "job_environment_type", None)
-
+
         if not (job_env_config and job_env_type == "OCIR_CONTAINER"):
             raise IncompatibleRuntime()
 
|
|
183
183
|
if os.path.isdir(source):
|
184
184
|
basename = os.path.basename(str(source).rstrip("/"))
|
185
185
|
source = str(source).rstrip("/")
|
186
|
-
# Runtime must have entrypoint if the source is a directory
|
187
|
-
if self.runtime and not self.runtime.entrypoint:
|
188
|
-
raise ValueError(
|
189
|
-
"Please specify entrypoint when script source is a directory."
|
190
|
-
)
|
191
186
|
output = os.path.join(self.temp_dir.name, basename)
|
192
187
|
shutil.make_archive(
|
193
188
|
output, "zip", os.path.dirname(source), base_dir=basename
|
ads/jobs/builders/runtimes/container_runtime.py
CHANGED
@@ -56,6 +56,7 @@ class ContainerRuntime(MultiNodeRuntime):
     CONST_CMD = "cmd"
     CONST_IMAGE_DIGEST = "imageDigest"
     CONST_IMAGE_SIGNATURE_ID = "imageSignatureId"
+    CONST_SCRIPT_PATH = "scriptPathURI"
     attribute_map = {
         CONST_IMAGE: CONST_IMAGE,
         CONST_ENTRYPOINT: CONST_ENTRYPOINT,
@@ -121,7 +122,7 @@ class ContainerRuntime(MultiNodeRuntime):
     def image_digest(self) -> str:
         """The container image digest."""
         return self.get_spec(self.CONST_IMAGE_DIGEST)
-
+
     def with_image_digest(self, image_digest: str) -> "ContainerRuntime":
         """Sets the digest of custom image.
 
@@ -136,12 +137,12 @@ class ContainerRuntime(MultiNodeRuntime):
             The runtime instance.
         """
         return self.set_spec(self.CONST_IMAGE_DIGEST, image_digest)
-
+
     @property
     def image_signature_id(self) -> str:
         """The container image signature id."""
         return self.get_spec(self.CONST_IMAGE_SIGNATURE_ID)
-
+
     def with_image_signature_id(self, image_signature_id: str) -> "ContainerRuntime":
         """Sets the signature id of custom image.
 
@@ -217,3 +218,25 @@ class ContainerRuntime(MultiNodeRuntime):
             entrypoint=["bash", "--login", "-c"],
             cmd="{Container CMD. For MLflow and Operator will be auto generated}",
         )
+
+    @property
+    def artifact_uri(self) -> str:
+        """The URI of the source code"""
+        return self.get_spec(self.CONST_SCRIPT_PATH)
+
+    def with_artifact(self, uri: str):
+        """Specifies the artifact to be added to the container.
+
+        Parameters
+        ----------
+        uri : str
+            URI to the source code script, which can be any URI supported by fsspec,
+            including http://, https:// and OCI object storage.
+            For example: oci://your_bucket@your_namespace/path/to/script.py
+
+        Returns
+        -------
+        self
+            The runtime instance.
+        """
+        return self.set_spec(self.CONST_SCRIPT_PATH, uri)