wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -39,12 +39,13 @@ from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messa
|
|
39
39
|
from wandb.errors import CommError, UsageError
|
40
40
|
from wandb.integration.sagemaker import parse_sm_secrets
|
41
41
|
from wandb.old.settings import Settings
|
42
|
+
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
|
42
43
|
from wandb.sdk.lib.gql_request import GraphQLSession
|
43
44
|
from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
|
44
45
|
|
45
46
|
from ..lib import retry
|
46
47
|
from ..lib.filenames import DIFF_FNAME, METADATA_FNAME
|
47
|
-
from ..lib.
|
48
|
+
from ..lib.gitlib import GitRepo
|
48
49
|
from . import context
|
49
50
|
from .progress import AsyncProgress, Progress
|
50
51
|
|
@@ -66,6 +67,7 @@ if TYPE_CHECKING:
|
|
66
67
|
md5: str
|
67
68
|
mimetype: Optional[str]
|
68
69
|
artifactManifestID: Optional[str] # noqa: N815
|
70
|
+
uploadPartsInput: Optional[List[Dict[str, object]]] # noqa: N815
|
69
71
|
|
70
72
|
class CreateArtifactFilesResponseFile(TypedDict):
|
71
73
|
id: str
|
@@ -73,11 +75,34 @@ if TYPE_CHECKING:
|
|
73
75
|
displayName: str # noqa: N815
|
74
76
|
uploadUrl: Optional[str] # noqa: N815
|
75
77
|
uploadHeaders: Sequence[str] # noqa: N815
|
78
|
+
uploadMultipartUrls: "UploadPartsResponse" # noqa: N815
|
79
|
+
storagePath: str # noqa: N815
|
76
80
|
artifact: "CreateArtifactFilesResponseFileNode"
|
77
81
|
|
78
82
|
class CreateArtifactFilesResponseFileNode(TypedDict):
|
79
83
|
id: str
|
80
84
|
|
85
|
+
class UploadPartsResponse(TypedDict):
|
86
|
+
uploadUrlParts: List["UploadUrlParts"] # noqa: N815
|
87
|
+
uploadID: str # noqa: N815
|
88
|
+
|
89
|
+
class UploadUrlParts(TypedDict):
|
90
|
+
partNumber: int # noqa: N815
|
91
|
+
uploadUrl: str # noqa: N815
|
92
|
+
|
93
|
+
class CompleteMultipartUploadArtifactInput(TypedDict):
|
94
|
+
"""Corresponds to `type CompleteMultipartUploadArtifactInput` in schema.graphql."""
|
95
|
+
|
96
|
+
completeMultipartAction: str # noqa: N815
|
97
|
+
completedParts: Dict[int, str] # noqa: N815
|
98
|
+
artifactID: str # noqa: N815
|
99
|
+
storagePath: str # noqa: N815
|
100
|
+
uploadID: str # noqa: N815
|
101
|
+
md5: str
|
102
|
+
|
103
|
+
class CompleteMultipartUploadArtifactResponse(TypedDict):
|
104
|
+
digest: str
|
105
|
+
|
81
106
|
class DefaultSettings(TypedDict):
|
82
107
|
section: str
|
83
108
|
git_remote: str
|
@@ -205,6 +230,10 @@ class Api:
|
|
205
230
|
self._environ.get("WANDB__EXTRA_HTTP_HEADERS", {})
|
206
231
|
)
|
207
232
|
|
233
|
+
auth = None
|
234
|
+
if _thread_local_api_settings.cookies is None:
|
235
|
+
auth = ("api", self.api_key or "")
|
236
|
+
extra_http_headers.update(_thread_local_api_settings.headers or {})
|
208
237
|
self.client = Client(
|
209
238
|
transport=GraphQLSession(
|
210
239
|
headers={
|
@@ -217,8 +246,9 @@ class Api:
|
|
217
246
|
# this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
|
218
247
|
# https://bugs.python.org/issue22889
|
219
248
|
timeout=self.HTTP_TIMEOUT,
|
220
|
-
auth=
|
249
|
+
auth=auth,
|
221
250
|
url=f"{self.settings('base_url')}/graphql",
|
251
|
+
cookies=_thread_local_api_settings.cookies,
|
222
252
|
)
|
223
253
|
)
|
224
254
|
|
@@ -242,6 +272,11 @@ class Api:
|
|
242
272
|
self.upload_file_retry = normalize_exceptions(
|
243
273
|
retry.retriable(retry_timedelta=retry_timedelta)(self.upload_file)
|
244
274
|
)
|
275
|
+
self.upload_multipart_file_chunk_retry = normalize_exceptions(
|
276
|
+
retry.retriable(retry_timedelta=retry_timedelta)(
|
277
|
+
self.upload_multipart_file_chunk
|
278
|
+
)
|
279
|
+
)
|
245
280
|
self._client_id_mapping: Dict[str, str] = {}
|
246
281
|
# Large file uploads to azure can optionally use their SDK
|
247
282
|
self._azure_blob_module = util.get_module("azure.storage.blob")
|
@@ -252,6 +287,7 @@ class Api:
|
|
252
287
|
self.server_use_artifact_input_info: Optional[List[str]] = None
|
253
288
|
self._max_cli_version: Optional[str] = None
|
254
289
|
self._server_settings_type: Optional[List[str]] = None
|
290
|
+
self.fail_run_queue_item_input_info: Optional[List[str]] = None
|
255
291
|
|
256
292
|
def gql(self, *args: Any, **kwargs: Any) -> Any:
|
257
293
|
ret = self._retry_gql(
|
@@ -273,7 +309,7 @@ class Api:
|
|
273
309
|
|
274
310
|
def reauth(self) -> None:
|
275
311
|
"""Ensure the current api key is set in the transport."""
|
276
|
-
self.client.transport.auth = ("api", self.api_key or "")
|
312
|
+
self.client.transport.session.auth = ("api", self.api_key or "")
|
277
313
|
|
278
314
|
def relocate(self) -> None:
|
279
315
|
"""Ensure the current api points to the right server."""
|
@@ -307,6 +343,8 @@ class Api:
|
|
307
343
|
|
308
344
|
@property
|
309
345
|
def api_key(self) -> Optional[str]:
|
346
|
+
if _thread_local_api_settings.api_key:
|
347
|
+
return _thread_local_api_settings.api_key
|
310
348
|
auth = requests.utils.get_netrc_auth(self.api_url)
|
311
349
|
key = None
|
312
350
|
if auth:
|
@@ -549,13 +587,103 @@ class Api:
|
|
549
587
|
return "failRunQueueItem" in mutations
|
550
588
|
|
551
589
|
@normalize_exceptions
|
552
|
-
def
|
590
|
+
def fail_run_queue_item_fields_introspection(self) -> List:
|
591
|
+
if self.fail_run_queue_item_input_info:
|
592
|
+
return self.fail_run_queue_item_input_info
|
593
|
+
query_string = """
|
594
|
+
query ProbeServerFailRunQueueItemInput {
|
595
|
+
FailRunQueueItemInputInfoType: __type(name:"FailRunQueueItemInput") {
|
596
|
+
inputFields{
|
597
|
+
name
|
598
|
+
}
|
599
|
+
}
|
600
|
+
}
|
601
|
+
"""
|
602
|
+
|
603
|
+
query = gql(query_string)
|
604
|
+
res = self.gql(query)
|
605
|
+
|
606
|
+
self.fail_run_queue_item_input_info = [
|
607
|
+
field.get("name", "")
|
608
|
+
for field in res.get("FailRunQueueItemInputInfoType", {}).get(
|
609
|
+
"inputFields", [{}]
|
610
|
+
)
|
611
|
+
]
|
612
|
+
return self.fail_run_queue_item_input_info
|
613
|
+
|
614
|
+
@normalize_exceptions
|
615
|
+
def fail_run_queue_item(
|
616
|
+
self,
|
617
|
+
run_queue_item_id: str,
|
618
|
+
message: str,
|
619
|
+
stage: str,
|
620
|
+
file_paths: Optional[List[str]] = None,
|
621
|
+
) -> bool:
|
622
|
+
if not self.fail_run_queue_item_introspection():
|
623
|
+
return False
|
624
|
+
variable_values: Dict[str, Union[str, Optional[List[str]]]] = {
|
625
|
+
"runQueueItemId": run_queue_item_id,
|
626
|
+
}
|
627
|
+
if "message" in self.fail_run_queue_item_fields_introspection():
|
628
|
+
variable_values.update({"message": message, "stage": stage})
|
629
|
+
if file_paths is not None:
|
630
|
+
variable_values["filePaths"] = file_paths
|
631
|
+
mutation_string = """
|
632
|
+
mutation failRunQueueItem($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
|
633
|
+
failRunQueueItem(
|
634
|
+
input: {
|
635
|
+
runQueueItemId: $runQueueItemId
|
636
|
+
message: $message
|
637
|
+
stage: $stage
|
638
|
+
filePaths: $filePaths
|
639
|
+
}
|
640
|
+
) {
|
641
|
+
success
|
642
|
+
}
|
643
|
+
}
|
644
|
+
"""
|
645
|
+
else:
|
646
|
+
mutation_string = """
|
647
|
+
mutation failRunQueueItem($runQueueItemId: ID!) {
|
648
|
+
failRunQueueItem(
|
649
|
+
input: {
|
650
|
+
runQueueItemId: $runQueueItemId
|
651
|
+
}
|
652
|
+
) {
|
653
|
+
success
|
654
|
+
}
|
655
|
+
}
|
656
|
+
"""
|
657
|
+
|
658
|
+
mutation = gql(mutation_string)
|
659
|
+
response = self.gql(mutation, variable_values=variable_values)
|
660
|
+
result: bool = response["failRunQueueItem"]["success"]
|
661
|
+
return result
|
662
|
+
|
663
|
+
@normalize_exceptions
|
664
|
+
def update_run_queue_item_warning_introspection(self) -> bool:
|
665
|
+
_, _, mutations = self.server_info_introspection()
|
666
|
+
return "updateRunQueueItemWarning" in mutations
|
667
|
+
|
668
|
+
@normalize_exceptions
|
669
|
+
def update_run_queue_item_warning(
|
670
|
+
self,
|
671
|
+
run_queue_item_id: str,
|
672
|
+
message: str,
|
673
|
+
stage: str,
|
674
|
+
file_paths: Optional[List[str]] = None,
|
675
|
+
) -> bool:
|
676
|
+
if not self.update_run_queue_item_warning_introspection():
|
677
|
+
return False
|
553
678
|
mutation = gql(
|
554
679
|
"""
|
555
|
-
mutation
|
556
|
-
|
680
|
+
mutation updateRunQueueItemWarning($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
|
681
|
+
updateRunQueueItemWarning(
|
557
682
|
input: {
|
558
683
|
runQueueItemId: $runQueueItemId
|
684
|
+
message: $message
|
685
|
+
stage: $stage
|
686
|
+
filePaths: $filePaths
|
559
687
|
}
|
560
688
|
) {
|
561
689
|
success
|
@@ -567,9 +695,12 @@ class Api:
|
|
567
695
|
mutation,
|
568
696
|
variable_values={
|
569
697
|
"runQueueItemId": run_queue_item_id,
|
698
|
+
"message": message,
|
699
|
+
"stage": stage,
|
700
|
+
"filePaths": file_paths,
|
570
701
|
},
|
571
702
|
)
|
572
|
-
result: bool = response["
|
703
|
+
result: bool = response["updateRunQueueItemWarning"]["success"]
|
573
704
|
return result
|
574
705
|
|
575
706
|
@normalize_exceptions
|
@@ -580,6 +711,7 @@ class Api:
|
|
580
711
|
viewer {
|
581
712
|
id
|
582
713
|
entity
|
714
|
+
username
|
583
715
|
flags
|
584
716
|
teams {
|
585
717
|
edges {
|
@@ -1992,7 +2124,16 @@ class Api:
|
|
1992
2124
|
Returns:
|
1993
2125
|
A tuple of the content length and the streaming response
|
1994
2126
|
"""
|
1995
|
-
|
2127
|
+
auth = None
|
2128
|
+
if _thread_local_api_settings.cookies is None:
|
2129
|
+
auth = ("user", self.api_key or "")
|
2130
|
+
response = requests.get(
|
2131
|
+
url,
|
2132
|
+
auth=auth,
|
2133
|
+
cookies=_thread_local_api_settings.cookies or {},
|
2134
|
+
headers=_thread_local_api_settings.headers or {},
|
2135
|
+
stream=True,
|
2136
|
+
)
|
1996
2137
|
response.raise_for_status()
|
1997
2138
|
return int(response.headers.get("content-length", 0)), response
|
1998
2139
|
|
@@ -2060,6 +2201,53 @@ class Api:
|
|
2060
2201
|
else:
|
2061
2202
|
raise requests.exceptions.ConnectionError(e.message)
|
2062
2203
|
|
2204
|
+
def upload_multipart_file_chunk(
|
2205
|
+
self,
|
2206
|
+
url: str,
|
2207
|
+
upload_chunk: bytes,
|
2208
|
+
extra_headers: Optional[Dict[str, str]] = None,
|
2209
|
+
) -> Optional[requests.Response]:
|
2210
|
+
"""Upload a file chunk to S3 with failure resumption.
|
2211
|
+
|
2212
|
+
Arguments:
|
2213
|
+
url: The url to download
|
2214
|
+
upload_chunk: The path to the file you want to upload
|
2215
|
+
extra_headers: A dictionary of extra headers to send with the request
|
2216
|
+
|
2217
|
+
Returns:
|
2218
|
+
The `requests` library response object
|
2219
|
+
"""
|
2220
|
+
try:
|
2221
|
+
response = self._upload_file_session.put(
|
2222
|
+
url, data=upload_chunk, headers=extra_headers
|
2223
|
+
)
|
2224
|
+
response.raise_for_status()
|
2225
|
+
except requests.exceptions.RequestException as e:
|
2226
|
+
logger.error(f"upload_file exception {url}: {e}")
|
2227
|
+
request_headers = e.request.headers if e.request is not None else ""
|
2228
|
+
logger.error(f"upload_file request headers: {request_headers}")
|
2229
|
+
response_content = e.response.content if e.response is not None else ""
|
2230
|
+
logger.error(f"upload_file response body: {response_content}")
|
2231
|
+
status_code = e.response.status_code if e.response is not None else 0
|
2232
|
+
# S3 reports retryable request timeouts out-of-band
|
2233
|
+
is_aws_retryable = status_code == 400 and "RequestTimeout" in str(
|
2234
|
+
response_content
|
2235
|
+
)
|
2236
|
+
# Retry errors from cloud storage or local network issues
|
2237
|
+
if (
|
2238
|
+
status_code in (308, 408, 409, 429, 500, 502, 503, 504)
|
2239
|
+
or isinstance(
|
2240
|
+
e,
|
2241
|
+
(requests.exceptions.Timeout, requests.exceptions.ConnectionError),
|
2242
|
+
)
|
2243
|
+
or is_aws_retryable
|
2244
|
+
):
|
2245
|
+
_e = retry.TransientError(exc=e)
|
2246
|
+
raise _e.with_traceback(sys.exc_info()[2])
|
2247
|
+
else:
|
2248
|
+
wandb._sentry.reraise(e)
|
2249
|
+
return response
|
2250
|
+
|
2063
2251
|
def upload_file(
|
2064
2252
|
self,
|
2065
2253
|
url: str,
|
@@ -2353,7 +2541,8 @@ class Api:
|
|
2353
2541
|
config = dict(config)
|
2354
2542
|
|
2355
2543
|
if "parameters" not in config:
|
2356
|
-
|
2544
|
+
# still shows an anaconda warning, but doesn't error
|
2545
|
+
return config
|
2357
2546
|
|
2358
2547
|
for parameter_name in config["parameters"]:
|
2359
2548
|
parameter = config["parameters"][parameter_name]
|
@@ -3014,6 +3203,50 @@ class Api:
|
|
3014
3203
|
)
|
3015
3204
|
return response
|
3016
3205
|
|
3206
|
+
def complete_multipart_upload_artifact(
|
3207
|
+
self,
|
3208
|
+
artifact_id: str,
|
3209
|
+
storage_path: str,
|
3210
|
+
completed_parts: List[Dict[str, Any]],
|
3211
|
+
upload_id: str,
|
3212
|
+
complete_multipart_action: str = "Complete",
|
3213
|
+
) -> Optional[str]:
|
3214
|
+
mutation = gql(
|
3215
|
+
"""
|
3216
|
+
mutation CompleteMultipartUploadArtifact(
|
3217
|
+
$completeMultipartAction: CompleteMultipartAction!,
|
3218
|
+
$completedParts: [UploadPartsInput!]!,
|
3219
|
+
$artifactID: ID!
|
3220
|
+
$storagePath: String!
|
3221
|
+
$uploadID: String!
|
3222
|
+
) {
|
3223
|
+
completeMultipartUploadArtifact(
|
3224
|
+
input: {
|
3225
|
+
completeMultipartAction: $completeMultipartAction,
|
3226
|
+
completedParts: $completedParts,
|
3227
|
+
artifactID: $artifactID,
|
3228
|
+
storagePath: $storagePath
|
3229
|
+
uploadID: $uploadID
|
3230
|
+
}
|
3231
|
+
) {
|
3232
|
+
digest
|
3233
|
+
}
|
3234
|
+
}
|
3235
|
+
"""
|
3236
|
+
)
|
3237
|
+
response = self.gql(
|
3238
|
+
mutation,
|
3239
|
+
variable_values={
|
3240
|
+
"completeMultipartAction": complete_multipart_action,
|
3241
|
+
"artifactID": artifact_id,
|
3242
|
+
"storagePath": storage_path,
|
3243
|
+
"completedParts": completed_parts,
|
3244
|
+
"uploadID": upload_id,
|
3245
|
+
},
|
3246
|
+
)
|
3247
|
+
digest: Optional[str] = response["completeMultipartUploadArtifact"]["digest"]
|
3248
|
+
return digest
|
3249
|
+
|
3017
3250
|
def create_artifact_manifest(
|
3018
3251
|
self,
|
3019
3252
|
name: str,
|
@@ -3171,19 +3404,39 @@ class Api:
|
|
3171
3404
|
self._client_id_mapping[client_id] = server_id
|
3172
3405
|
return server_id
|
3173
3406
|
|
3407
|
+
def server_create_artifact_file_spec_input_introspection(self) -> List:
|
3408
|
+
query_string = """
|
3409
|
+
query ProbeServerCreateArtifactFileSpecInput {
|
3410
|
+
CreateArtifactFileSpecInputInfoType: __type(name:"CreateArtifactFileSpecInput") {
|
3411
|
+
inputFields{
|
3412
|
+
name
|
3413
|
+
}
|
3414
|
+
}
|
3415
|
+
}
|
3416
|
+
"""
|
3417
|
+
|
3418
|
+
query = gql(query_string)
|
3419
|
+
res = self.gql(query)
|
3420
|
+
create_artifact_file_spec_input_info = [
|
3421
|
+
field.get("name", "")
|
3422
|
+
for field in res.get("CreateArtifactFileSpecInputInfoType", {}).get(
|
3423
|
+
"inputFields", [{}]
|
3424
|
+
)
|
3425
|
+
]
|
3426
|
+
return create_artifact_file_spec_input_info
|
3427
|
+
|
3174
3428
|
@normalize_exceptions
|
3175
3429
|
def create_artifact_files(
|
3176
3430
|
self, artifact_files: Iterable["CreateArtifactFileSpecInput"]
|
3177
3431
|
) -> Mapping[str, "CreateArtifactFilesResponseFile"]:
|
3178
|
-
|
3179
|
-
"""
|
3432
|
+
query_template = """
|
3180
3433
|
mutation CreateArtifactFiles(
|
3181
3434
|
$storageLayout: ArtifactStorageLayout!
|
3182
3435
|
$artifactFiles: [CreateArtifactFileSpecInput!]!
|
3183
3436
|
) {
|
3184
3437
|
createArtifactFiles(input: {
|
3185
3438
|
artifactFiles: $artifactFiles,
|
3186
|
-
storageLayout: $storageLayout
|
3439
|
+
storageLayout: $storageLayout,
|
3187
3440
|
}) {
|
3188
3441
|
files {
|
3189
3442
|
edges {
|
@@ -3193,6 +3446,7 @@ class Api:
|
|
3193
3446
|
displayName
|
3194
3447
|
uploadUrl
|
3195
3448
|
uploadHeaders
|
3449
|
+
_MULTIPART_UPLOAD_FIELDS_
|
3196
3450
|
artifact {
|
3197
3451
|
id
|
3198
3452
|
}
|
@@ -3202,7 +3456,16 @@ class Api:
|
|
3202
3456
|
}
|
3203
3457
|
}
|
3204
3458
|
"""
|
3205
|
-
|
3459
|
+
multipart_upload_url_query = """
|
3460
|
+
storagePath
|
3461
|
+
uploadMultipartUrls {
|
3462
|
+
uploadID
|
3463
|
+
uploadUrlParts {
|
3464
|
+
partNumber
|
3465
|
+
uploadUrl
|
3466
|
+
}
|
3467
|
+
}
|
3468
|
+
"""
|
3206
3469
|
|
3207
3470
|
# TODO: we should use constants here from interface/artifacts.py
|
3208
3471
|
# but probably don't want the dependency. We're going to remove
|
@@ -3211,6 +3474,17 @@ class Api:
|
|
3211
3474
|
if env.get_use_v1_artifacts():
|
3212
3475
|
storage_layout = "V1"
|
3213
3476
|
|
3477
|
+
create_artifact_file_spec_input_fields = (
|
3478
|
+
self.server_create_artifact_file_spec_input_introspection()
|
3479
|
+
)
|
3480
|
+
if "uploadPartsInput" in create_artifact_file_spec_input_fields:
|
3481
|
+
query_template = query_template.replace(
|
3482
|
+
"_MULTIPART_UPLOAD_FIELDS_", multipart_upload_url_query
|
3483
|
+
)
|
3484
|
+
else:
|
3485
|
+
query_template = query_template.replace("_MULTIPART_UPLOAD_FIELDS_", "")
|
3486
|
+
|
3487
|
+
mutation = gql(query_template)
|
3214
3488
|
response = self.gql(
|
3215
3489
|
mutation,
|
3216
3490
|
variable_values={
|