wandb 0.15.4__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/internal.py +3 -0
- wandb/apis/public.py +18 -20
- wandb/beta/workflows.py +5 -6
- wandb/cli/cli.py +27 -27
- wandb/data_types.py +2 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -4
- wandb/sdk/artifacts/__init__.py +0 -14
- wandb/sdk/artifacts/artifact.py +1757 -277
- wandb/sdk/artifacts/artifact_manifest_entry.py +26 -6
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/artifacts/artifacts_cache.py +7 -8
- wandb/sdk/artifacts/exceptions.py +4 -4
- wandb/sdk/artifacts/storage_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -6
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/http_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +35 -32
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +5 -9
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -2
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +2 -2
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +24 -16
- wandb/sdk/artifacts/storage_policy.py +3 -3
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +2 -2
- wandb/sdk/data_types/base_types/media.py +5 -6
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +4 -5
- wandb/sdk/data_types/helper_types/classes.py +5 -8
- wandb/sdk/data_types/helper_types/image_mask.py +4 -5
- wandb/sdk/data_types/histogram.py +3 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +4 -5
- wandb/sdk/data_types/molecule.py +2 -2
- wandb/sdk/data_types/object_3d.py +3 -3
- wandb/sdk/data_types/plotly.py +2 -2
- wandb/sdk/data_types/saved_model.py +7 -8
- wandb/sdk/data_types/trace_tree.py +4 -4
- wandb/sdk/data_types/video.py +4 -4
- wandb/sdk/interface/interface.py +8 -10
- wandb/sdk/internal/file_stream.py +2 -3
- wandb/sdk/internal/internal_api.py +99 -4
- wandb/sdk/internal/job_builder.py +15 -7
- wandb/sdk/internal/sender.py +4 -0
- wandb/sdk/internal/settings_static.py +1 -0
- wandb/sdk/launch/_project_spec.py +9 -7
- wandb/sdk/launch/agent/agent.py +115 -58
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +16 -10
- wandb/sdk/launch/builder/docker_builder.py +9 -2
- wandb/sdk/launch/builder/kaniko_builder.py +108 -22
- wandb/sdk/launch/builder/noop.py +3 -1
- wandb/sdk/launch/environment/aws_environment.py +2 -1
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/github_reference.py +30 -18
- wandb/sdk/launch/launch.py +1 -1
- wandb/sdk/launch/loader.py +15 -0
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +38 -4
- wandb/sdk/launch/registry/google_artifact_registry.py +46 -7
- wandb/sdk/launch/runner/abstract.py +19 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +111 -47
- wandb/sdk/launch/runner/local_container.py +101 -48
- wandb/sdk/launch/runner/sagemaker_runner.py +59 -9
- wandb/sdk/launch/runner/vertex_runner.py +8 -4
- wandb/sdk/launch/sweeps/scheduler.py +102 -27
- wandb/sdk/launch/sweeps/utils.py +21 -0
- wandb/sdk/launch/utils.py +19 -7
- wandb/sdk/lib/_settings_toposort_generated.py +3 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +6 -9
- wandb/sdk/wandb_config.py +2 -4
- wandb/sdk/wandb_init.py +2 -0
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +32 -35
- wandb/sdk/wandb_settings.py +10 -3
- wandb/testing/relay.py +15 -2
- wandb/util.py +55 -23
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/METADATA +11 -8
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/RECORD +97 -97
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/artifacts/invalid_artifact.py +0 -23
- wandb/sdk/artifacts/lazy_artifact.py +0 -162
- wandb/sdk/artifacts/local_artifact.py +0 -719
- wandb/sdk/artifacts/public_artifact.py +0 -1188
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +0 -0
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
wandb/sdk/interface/interface.py
CHANGED
@@ -20,8 +20,6 @@ from typing import TYPE_CHECKING, Any, Iterable, NewType, Optional, Tuple, Union
|
|
20
20
|
from wandb.proto import wandb_internal_pb2 as pb
|
21
21
|
from wandb.proto import wandb_telemetry_pb2 as tpb
|
22
22
|
from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
|
23
|
-
from wandb.sdk.artifacts.local_artifact import Artifact as LocalArtifact
|
24
|
-
from wandb.sdk.artifacts.public_artifact import Artifact as PublicArtifact
|
25
23
|
from wandb.util import (
|
26
24
|
WandBJSONEncoderOld,
|
27
25
|
get_h5_typename,
|
@@ -40,6 +38,8 @@ from .message_future import MessageFuture
|
|
40
38
|
GlobStr = NewType("GlobStr", str)
|
41
39
|
|
42
40
|
if TYPE_CHECKING:
|
41
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
42
|
+
|
43
43
|
from ..wandb_run import Run
|
44
44
|
|
45
45
|
if sys.version_info >= (3, 8):
|
@@ -373,7 +373,7 @@ class InterfaceBase:
|
|
373
373
|
def _publish_files(self, files: pb.FilesRecord) -> None:
|
374
374
|
raise NotImplementedError
|
375
375
|
|
376
|
-
def _make_artifact(self, artifact:
|
376
|
+
def _make_artifact(self, artifact: "Artifact") -> pb.ArtifactRecord:
|
377
377
|
proto_artifact = pb.ArtifactRecord()
|
378
378
|
proto_artifact.type = artifact.type
|
379
379
|
proto_artifact.name = artifact.name
|
@@ -425,14 +425,14 @@ class InterfaceBase:
|
|
425
425
|
def publish_link_artifact(
|
426
426
|
self,
|
427
427
|
run: "Run",
|
428
|
-
artifact:
|
428
|
+
artifact: "Artifact",
|
429
429
|
portfolio_name: str,
|
430
430
|
aliases: Iterable[str],
|
431
431
|
entity: Optional[str] = None,
|
432
432
|
project: Optional[str] = None,
|
433
433
|
) -> None:
|
434
434
|
link_artifact = pb.LinkArtifactRecord()
|
435
|
-
if
|
435
|
+
if artifact.is_draft():
|
436
436
|
link_artifact.client_id = artifact._client_id
|
437
437
|
else:
|
438
438
|
link_artifact.server_id = artifact.id if artifact.id else ""
|
@@ -449,10 +449,8 @@ class InterfaceBase:
|
|
449
449
|
|
450
450
|
def publish_use_artifact(
|
451
451
|
self,
|
452
|
-
artifact:
|
452
|
+
artifact: "Artifact",
|
453
453
|
) -> None:
|
454
|
-
# use_artifact is either a public.Artifact or a wandb.Artifact that has been
|
455
|
-
# waited on and has an id
|
456
454
|
assert artifact.id is not None, "Artifact must have an id"
|
457
455
|
use_artifact = pb.UseArtifactRecord(
|
458
456
|
id=artifact.id, type=artifact.type, name=artifact.name
|
@@ -467,7 +465,7 @@ class InterfaceBase:
|
|
467
465
|
def communicate_artifact(
|
468
466
|
self,
|
469
467
|
run: "Run",
|
470
|
-
artifact:
|
468
|
+
artifact: "Artifact",
|
471
469
|
aliases: Iterable[str],
|
472
470
|
history_step: Optional[int] = None,
|
473
471
|
is_user_created: bool = False,
|
@@ -517,7 +515,7 @@ class InterfaceBase:
|
|
517
515
|
def publish_artifact(
|
518
516
|
self,
|
519
517
|
run: "Run",
|
520
|
-
artifact:
|
518
|
+
artifact: "Artifact",
|
521
519
|
aliases: Iterable[str],
|
522
520
|
is_user_created: bool = False,
|
523
521
|
use_after_commit: bool = False,
|
@@ -653,9 +653,8 @@ def request_with_retry(
|
|
653
653
|
e.response is not None and e.response.status_code == 429
|
654
654
|
):
|
655
655
|
err_str = (
|
656
|
-
"Filestream rate limit exceeded,
|
657
|
-
|
658
|
-
)
|
656
|
+
"Filestream rate limit exceeded, "
|
657
|
+
f"retrying in {delay:.1f} seconds. "
|
659
658
|
)
|
660
659
|
if retry_callback:
|
661
660
|
retry_callback(e.response.status_code, err_str)
|
@@ -287,6 +287,7 @@ class Api:
|
|
287
287
|
self.server_use_artifact_input_info: Optional[List[str]] = None
|
288
288
|
self._max_cli_version: Optional[str] = None
|
289
289
|
self._server_settings_type: Optional[List[str]] = None
|
290
|
+
self.fail_run_queue_item_input_info: Optional[List[str]] = None
|
290
291
|
|
291
292
|
def gql(self, *args: Any, **kwargs: Any) -> Any:
|
292
293
|
ret = self._retry_gql(
|
@@ -586,13 +587,103 @@ class Api:
|
|
586
587
|
return "failRunQueueItem" in mutations
|
587
588
|
|
588
589
|
@normalize_exceptions
|
589
|
-
def
|
590
|
+
def fail_run_queue_item_fields_introspection(self) -> List:
|
591
|
+
if self.fail_run_queue_item_input_info:
|
592
|
+
return self.fail_run_queue_item_input_info
|
593
|
+
query_string = """
|
594
|
+
query ProbeServerFailRunQueueItemInput {
|
595
|
+
FailRunQueueItemInputInfoType: __type(name:"FailRunQueueItemInput") {
|
596
|
+
inputFields{
|
597
|
+
name
|
598
|
+
}
|
599
|
+
}
|
600
|
+
}
|
601
|
+
"""
|
602
|
+
|
603
|
+
query = gql(query_string)
|
604
|
+
res = self.gql(query)
|
605
|
+
|
606
|
+
self.fail_run_queue_item_input_info = [
|
607
|
+
field.get("name", "")
|
608
|
+
for field in res.get("FailRunQueueItemInputInfoType", {}).get(
|
609
|
+
"inputFields", [{}]
|
610
|
+
)
|
611
|
+
]
|
612
|
+
return self.fail_run_queue_item_input_info
|
613
|
+
|
614
|
+
@normalize_exceptions
|
615
|
+
def fail_run_queue_item(
|
616
|
+
self,
|
617
|
+
run_queue_item_id: str,
|
618
|
+
message: str,
|
619
|
+
stage: str,
|
620
|
+
file_paths: Optional[List[str]] = None,
|
621
|
+
) -> bool:
|
622
|
+
if not self.fail_run_queue_item_introspection():
|
623
|
+
return False
|
624
|
+
variable_values: Dict[str, Union[str, Optional[List[str]]]] = {
|
625
|
+
"runQueueItemId": run_queue_item_id,
|
626
|
+
}
|
627
|
+
if "message" in self.fail_run_queue_item_fields_introspection():
|
628
|
+
variable_values.update({"message": message, "stage": stage})
|
629
|
+
if file_paths is not None:
|
630
|
+
variable_values["filePaths"] = file_paths
|
631
|
+
mutation_string = """
|
632
|
+
mutation failRunQueueItem($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
|
633
|
+
failRunQueueItem(
|
634
|
+
input: {
|
635
|
+
runQueueItemId: $runQueueItemId
|
636
|
+
message: $message
|
637
|
+
stage: $stage
|
638
|
+
filePaths: $filePaths
|
639
|
+
}
|
640
|
+
) {
|
641
|
+
success
|
642
|
+
}
|
643
|
+
}
|
644
|
+
"""
|
645
|
+
else:
|
646
|
+
mutation_string = """
|
647
|
+
mutation failRunQueueItem($runQueueItemId: ID!) {
|
648
|
+
failRunQueueItem(
|
649
|
+
input: {
|
650
|
+
runQueueItemId: $runQueueItemId
|
651
|
+
}
|
652
|
+
) {
|
653
|
+
success
|
654
|
+
}
|
655
|
+
}
|
656
|
+
"""
|
657
|
+
|
658
|
+
mutation = gql(mutation_string)
|
659
|
+
response = self.gql(mutation, variable_values=variable_values)
|
660
|
+
result: bool = response["failRunQueueItem"]["success"]
|
661
|
+
return result
|
662
|
+
|
663
|
+
@normalize_exceptions
|
664
|
+
def update_run_queue_item_warning_introspection(self) -> bool:
|
665
|
+
_, _, mutations = self.server_info_introspection()
|
666
|
+
return "updateRunQueueItemWarning" in mutations
|
667
|
+
|
668
|
+
@normalize_exceptions
|
669
|
+
def update_run_queue_item_warning(
|
670
|
+
self,
|
671
|
+
run_queue_item_id: str,
|
672
|
+
message: str,
|
673
|
+
stage: str,
|
674
|
+
file_paths: Optional[List[str]] = None,
|
675
|
+
) -> bool:
|
676
|
+
if not self.update_run_queue_item_warning_introspection():
|
677
|
+
return False
|
590
678
|
mutation = gql(
|
591
679
|
"""
|
592
|
-
mutation
|
593
|
-
|
680
|
+
mutation updateRunQueueItemWarning($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
|
681
|
+
updateRunQueueItemWarning(
|
594
682
|
input: {
|
595
683
|
runQueueItemId: $runQueueItemId
|
684
|
+
message: $message
|
685
|
+
stage: $stage
|
686
|
+
filePaths: $filePaths
|
596
687
|
}
|
597
688
|
) {
|
598
689
|
success
|
@@ -604,9 +695,12 @@ class Api:
|
|
604
695
|
mutation,
|
605
696
|
variable_values={
|
606
697
|
"runQueueItemId": run_queue_item_id,
|
698
|
+
"message": message,
|
699
|
+
"stage": stage,
|
700
|
+
"filePaths": file_paths,
|
607
701
|
},
|
608
702
|
)
|
609
|
-
result: bool = response["
|
703
|
+
result: bool = response["updateRunQueueItemWarning"]["success"]
|
610
704
|
return result
|
611
705
|
|
612
706
|
@normalize_exceptions
|
@@ -617,6 +711,7 @@ class Api:
|
|
617
711
|
viewer {
|
618
712
|
id
|
619
713
|
entity
|
714
|
+
username
|
620
715
|
flags
|
621
716
|
teams {
|
622
717
|
edges {
|
@@ -5,7 +5,7 @@ import os
|
|
5
5
|
import sys
|
6
6
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
7
7
|
|
8
|
-
from wandb.sdk.artifacts.
|
8
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
9
9
|
from wandb.sdk.data_types._dtypes import TypeRegistry
|
10
10
|
from wandb.sdk.lib.filenames import DIFF_FNAME, METADATA_FNAME, REQUIREMENTS_FNAME
|
11
11
|
from wandb.util import make_artifact_name_safe
|
@@ -62,7 +62,7 @@ class ArtifactInfoForJob(TypedDict):
|
|
62
62
|
name: str
|
63
63
|
|
64
64
|
|
65
|
-
class JobArtifact(
|
65
|
+
class JobArtifact(Artifact):
|
66
66
|
def __init__(self, name: str, *args: Any, **kwargs: Any):
|
67
67
|
super().__init__(name, "placeholder", *args, **kwargs)
|
68
68
|
self._type = JOB_ARTIFACT_TYPE # Get around type restriction.
|
@@ -76,6 +76,7 @@ class JobBuilder:
|
|
76
76
|
_summary: Optional[Dict[str, Any]]
|
77
77
|
_logged_code_artifact: Optional[ArtifactInfoForJob]
|
78
78
|
_disable: bool
|
79
|
+
_aliases: List[str]
|
79
80
|
|
80
81
|
def __init__(self, settings: SettingsStatic):
|
81
82
|
self._settings = settings
|
@@ -85,6 +86,7 @@ class JobBuilder:
|
|
85
86
|
self._summary = None
|
86
87
|
self._logged_code_artifact = None
|
87
88
|
self._disable = settings.disable_job_creation
|
89
|
+
self._aliases = []
|
88
90
|
self._source_type: Optional[
|
89
91
|
Literal["repo", "artifact", "image"]
|
90
92
|
] = settings.get("job_source")
|
@@ -116,7 +118,7 @@ class JobBuilder:
|
|
116
118
|
|
117
119
|
def _build_repo_job(
|
118
120
|
self, metadata: Dict[str, Any], program_relpath: str, root: Optional[str]
|
119
|
-
) -> Tuple[Optional[
|
121
|
+
) -> Tuple[Optional[Artifact], Optional[GitSourceDict]]:
|
120
122
|
git_info: Dict[str, str] = metadata.get("git", {})
|
121
123
|
remote = git_info.get("remote")
|
122
124
|
commit = git_info.get("commit")
|
@@ -177,7 +179,7 @@ class JobBuilder:
|
|
177
179
|
|
178
180
|
def _build_artifact_job(
|
179
181
|
self, metadata: Dict[str, Any], program_relpath: str
|
180
|
-
) -> Tuple[Optional[
|
182
|
+
) -> Tuple[Optional[Artifact], Optional[ArtifactSourceDict]]:
|
181
183
|
assert isinstance(self._logged_code_artifact, dict)
|
182
184
|
# TODO: should we just always exit early if the path doesn't exist?
|
183
185
|
if self._is_notebook_run() and not self._is_colab_run():
|
@@ -212,10 +214,16 @@ class JobBuilder:
|
|
212
214
|
|
213
215
|
def _build_image_job(
|
214
216
|
self, metadata: Dict[str, Any]
|
215
|
-
) -> Tuple[
|
217
|
+
) -> Tuple[Artifact, ImageSourceDict]:
|
216
218
|
image_name = metadata.get("docker")
|
217
219
|
assert isinstance(image_name, str)
|
218
|
-
|
220
|
+
|
221
|
+
raw_image_name = image_name
|
222
|
+
if ":" in image_name:
|
223
|
+
raw_image_name, tag = image_name.split(":")
|
224
|
+
self._aliases += [tag]
|
225
|
+
|
226
|
+
name = make_artifact_name_safe(f"job-{raw_image_name}")
|
219
227
|
artifact = JobArtifact(name)
|
220
228
|
source: ImageSourceDict = {
|
221
229
|
"image": image_name,
|
@@ -228,7 +236,7 @@ class JobBuilder:
|
|
228
236
|
def _is_colab_run(self) -> bool:
|
229
237
|
return hasattr(self._settings, "_colab") and bool(self._settings._colab)
|
230
238
|
|
231
|
-
def build(self) -> Optional[
|
239
|
+
def build(self) -> Optional[Artifact]:
|
232
240
|
_logger.info("Attempting to build job artifact")
|
233
241
|
if not os.path.exists(
|
234
242
|
os.path.join(self._settings.files_dir, REQUIREMENTS_FNAME)
|
wandb/sdk/internal/sender.py
CHANGED
@@ -1608,6 +1608,10 @@ class SendManager:
|
|
1608
1608
|
# TODO: this should be removed when the latest tag is handled
|
1609
1609
|
# by the backend (WB-12116)
|
1610
1610
|
proto_artifact.aliases.append("latest")
|
1611
|
+
# add docker image tag
|
1612
|
+
for alias in self._job_builder._aliases:
|
1613
|
+
proto_artifact.aliases.append(alias)
|
1614
|
+
|
1611
1615
|
proto_artifact.user_created = True
|
1612
1616
|
proto_artifact.use_after_commit = True
|
1613
1617
|
proto_artifact.finalize = True
|
@@ -7,19 +7,21 @@ import json
|
|
7
7
|
import logging
|
8
8
|
import os
|
9
9
|
import tempfile
|
10
|
-
from typing import Any, Dict, List, Optional
|
10
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
11
11
|
|
12
12
|
import wandb
|
13
13
|
import wandb.docker as docker
|
14
14
|
from wandb.apis.internal import Api
|
15
15
|
from wandb.errors import CommError
|
16
|
-
from wandb.sdk.artifacts.public_artifact import Artifact as PublicArtifact
|
17
16
|
from wandb.sdk.launch import utils
|
18
17
|
from wandb.sdk.lib.runid import generate_id
|
19
18
|
|
20
19
|
from .errors import LaunchError
|
21
20
|
from .utils import LOG_PREFIX, recursive_macro_sub
|
22
21
|
|
22
|
+
if TYPE_CHECKING:
|
23
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
24
|
+
|
23
25
|
_logger = logging.getLogger(__name__)
|
24
26
|
|
25
27
|
DEFAULT_LAUNCH_METADATA_PATH = "launch_metadata.json"
|
@@ -69,7 +71,7 @@ class LaunchProject:
|
|
69
71
|
self.job = job
|
70
72
|
if job is not None:
|
71
73
|
wandb.termlog(f"{LOG_PREFIX}Launching job: {job}")
|
72
|
-
self._job_artifact: Optional[
|
74
|
+
self._job_artifact: Optional["Artifact"] = None
|
73
75
|
self.api = api
|
74
76
|
self.launch_spec = launch_spec
|
75
77
|
self.target_entity = target_entity
|
@@ -80,12 +82,12 @@ class LaunchProject:
|
|
80
82
|
# runner, so we need to pop the builder key out
|
81
83
|
resource_args_build = resource_args.get(resource, {}).pop("builder", {})
|
82
84
|
self.resource = resource
|
83
|
-
self.resource_args = resource_args
|
85
|
+
self.resource_args = resource_args.copy()
|
84
86
|
self.sweep_id = sweep_id
|
85
87
|
self.python_version: Optional[str] = launch_spec.get("python_version")
|
86
|
-
self.
|
87
|
-
"
|
88
|
-
)
|
88
|
+
self.accelerator_base_image: Optional[str] = resource_args_build.get(
|
89
|
+
"accelerator", {}
|
90
|
+
).get("base_image") or resource_args_build.get("cuda", {}).get("base_image")
|
89
91
|
self._base_image: Optional[str] = launch_spec.get("base_image")
|
90
92
|
self.docker_image: Optional[str] = docker_config.get(
|
91
93
|
"docker_image"
|