wandb-0.15.3-py3-none-any.whl → wandb-0.15.5-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156)
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/importers/base.py +20 -5
  4. wandb/apis/importers/mlflow.py +7 -1
  5. wandb/apis/internal.py +12 -0
  6. wandb/apis/public.py +247 -1387
  7. wandb/apis/reports/_panels.py +58 -35
  8. wandb/beta/workflows.py +6 -7
  9. wandb/cli/cli.py +130 -60
  10. wandb/data_types.py +3 -1
  11. wandb/filesync/dir_watcher.py +21 -27
  12. wandb/filesync/step_checksum.py +8 -8
  13. wandb/filesync/step_prepare.py +23 -10
  14. wandb/filesync/step_upload.py +13 -13
  15. wandb/filesync/upload_job.py +4 -8
  16. wandb/integration/cohere/__init__.py +3 -0
  17. wandb/integration/cohere/cohere.py +21 -0
  18. wandb/integration/cohere/resolver.py +347 -0
  19. wandb/integration/gym/__init__.py +4 -6
  20. wandb/integration/huggingface/__init__.py +3 -0
  21. wandb/integration/huggingface/huggingface.py +18 -0
  22. wandb/integration/huggingface/resolver.py +213 -0
  23. wandb/integration/langchain/wandb_tracer.py +16 -179
  24. wandb/integration/openai/__init__.py +1 -3
  25. wandb/integration/openai/openai.py +11 -143
  26. wandb/integration/openai/resolver.py +111 -38
  27. wandb/integration/sagemaker/config.py +2 -2
  28. wandb/integration/tensorboard/log.py +4 -4
  29. wandb/old/settings.py +24 -7
  30. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  31. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/wandb_deprecated.py +3 -1
  33. wandb/sdk/__init__.py +1 -1
  34. wandb/sdk/artifacts/__init__.py +0 -0
  35. wandb/sdk/artifacts/artifact.py +2101 -0
  36. wandb/sdk/artifacts/artifact_download_logger.py +42 -0
  37. wandb/sdk/artifacts/artifact_manifest.py +67 -0
  38. wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
  39. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  40. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
  41. wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
  42. wandb/sdk/artifacts/artifact_state.py +10 -0
  43. wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
  44. wandb/sdk/artifacts/exceptions.py +55 -0
  45. wandb/sdk/artifacts/storage_handler.py +59 -0
  46. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  47. wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
  48. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  49. wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
  50. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
  51. wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
  53. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
  54. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
  55. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  56. wandb/sdk/artifacts/storage_layout.py +6 -0
  57. wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
  58. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
  60. wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
  61. wandb/sdk/data_types/_dtypes.py +7 -12
  62. wandb/sdk/data_types/base_types/json_metadata.py +3 -2
  63. wandb/sdk/data_types/base_types/media.py +8 -8
  64. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  65. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
  66. wandb/sdk/data_types/helper_types/classes.py +6 -8
  67. wandb/sdk/data_types/helper_types/image_mask.py +5 -6
  68. wandb/sdk/data_types/histogram.py +4 -3
  69. wandb/sdk/data_types/html.py +3 -4
  70. wandb/sdk/data_types/image.py +11 -9
  71. wandb/sdk/data_types/molecule.py +5 -3
  72. wandb/sdk/data_types/object_3d.py +7 -5
  73. wandb/sdk/data_types/plotly.py +3 -2
  74. wandb/sdk/data_types/saved_model.py +11 -11
  75. wandb/sdk/data_types/trace_tree.py +5 -4
  76. wandb/sdk/data_types/utils.py +3 -5
  77. wandb/sdk/data_types/video.py +5 -4
  78. wandb/sdk/integration_utils/auto_logging.py +215 -0
  79. wandb/sdk/interface/interface.py +15 -15
  80. wandb/sdk/internal/file_pusher.py +8 -16
  81. wandb/sdk/internal/file_stream.py +5 -11
  82. wandb/sdk/internal/handler.py +13 -1
  83. wandb/sdk/internal/internal_api.py +287 -13
  84. wandb/sdk/internal/job_builder.py +119 -30
  85. wandb/sdk/internal/sender.py +6 -26
  86. wandb/sdk/internal/settings_static.py +2 -0
  87. wandb/sdk/internal/system/assets/__init__.py +2 -0
  88. wandb/sdk/internal/system/assets/gpu.py +42 -0
  89. wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
  90. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  91. wandb/sdk/internal/system/system_info.py +3 -3
  92. wandb/sdk/internal/tb_watcher.py +32 -22
  93. wandb/sdk/internal/thread_local_settings.py +18 -0
  94. wandb/sdk/launch/_project_spec.py +57 -11
  95. wandb/sdk/launch/agent/agent.py +147 -65
  96. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  97. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  98. wandb/sdk/launch/builder/abstract.py +5 -1
  99. wandb/sdk/launch/builder/build.py +21 -18
  100. wandb/sdk/launch/builder/docker_builder.py +10 -4
  101. wandb/sdk/launch/builder/kaniko_builder.py +113 -23
  102. wandb/sdk/launch/builder/noop.py +6 -3
  103. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
  104. wandb/sdk/launch/environment/aws_environment.py +3 -2
  105. wandb/sdk/launch/environment/azure_environment.py +124 -0
  106. wandb/sdk/launch/environment/gcp_environment.py +2 -4
  107. wandb/sdk/launch/environment/local_environment.py +1 -1
  108. wandb/sdk/launch/errors.py +19 -0
  109. wandb/sdk/launch/github_reference.py +32 -19
  110. wandb/sdk/launch/launch.py +3 -8
  111. wandb/sdk/launch/launch_add.py +6 -2
  112. wandb/sdk/launch/loader.py +21 -2
  113. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  114. wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
  115. wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
  116. wandb/sdk/launch/registry/local_registry.py +2 -1
  117. wandb/sdk/launch/runner/abstract.py +24 -3
  118. wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
  119. wandb/sdk/launch/runner/local_container.py +103 -51
  120. wandb/sdk/launch/runner/local_process.py +1 -1
  121. wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
  122. wandb/sdk/launch/runner/vertex_runner.py +10 -5
  123. wandb/sdk/launch/sweeps/__init__.py +7 -9
  124. wandb/sdk/launch/sweeps/scheduler.py +307 -77
  125. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  126. wandb/sdk/launch/sweeps/utils.py +82 -35
  127. wandb/sdk/launch/utils.py +89 -75
  128. wandb/sdk/lib/_settings_toposort_generated.py +7 -0
  129. wandb/sdk/lib/capped_dict.py +26 -0
  130. wandb/sdk/lib/{git.py → gitlib.py} +76 -59
  131. wandb/sdk/lib/hashutil.py +12 -4
  132. wandb/sdk/lib/paths.py +96 -8
  133. wandb/sdk/lib/sock_client.py +2 -2
  134. wandb/sdk/lib/timer.py +1 -0
  135. wandb/sdk/service/server.py +22 -9
  136. wandb/sdk/service/server_sock.py +1 -1
  137. wandb/sdk/service/service.py +27 -8
  138. wandb/sdk/verify/verify.py +4 -7
  139. wandb/sdk/wandb_config.py +2 -6
  140. wandb/sdk/wandb_init.py +57 -53
  141. wandb/sdk/wandb_require.py +7 -0
  142. wandb/sdk/wandb_run.py +61 -223
  143. wandb/sdk/wandb_settings.py +28 -4
  144. wandb/testing/relay.py +15 -2
  145. wandb/util.py +74 -36
  146. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
  147. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
  148. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
  149. wandb/integration/langchain/util.py +0 -191
  150. wandb/sdk/interface/artifacts/__init__.py +0 -33
  151. wandb/sdk/interface/artifacts/artifact.py +0 -615
  152. wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
  153. wandb/sdk/wandb_artifacts.py +0 -2226
  154. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  155. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  156. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -39,12 +39,13 @@ from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messa
39
39
  from wandb.errors import CommError, UsageError
40
40
  from wandb.integration.sagemaker import parse_sm_secrets
41
41
  from wandb.old.settings import Settings
42
+ from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
42
43
  from wandb.sdk.lib.gql_request import GraphQLSession
43
44
  from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
44
45
 
45
46
  from ..lib import retry
46
47
  from ..lib.filenames import DIFF_FNAME, METADATA_FNAME
47
- from ..lib.git import GitRepo
48
+ from ..lib.gitlib import GitRepo
48
49
  from . import context
49
50
  from .progress import AsyncProgress, Progress
50
51
 
@@ -66,6 +67,7 @@ if TYPE_CHECKING:
66
67
  md5: str
67
68
  mimetype: Optional[str]
68
69
  artifactManifestID: Optional[str] # noqa: N815
70
+ uploadPartsInput: Optional[List[Dict[str, object]]] # noqa: N815
69
71
 
70
72
  class CreateArtifactFilesResponseFile(TypedDict):
71
73
  id: str
@@ -73,11 +75,34 @@ if TYPE_CHECKING:
73
75
  displayName: str # noqa: N815
74
76
  uploadUrl: Optional[str] # noqa: N815
75
77
  uploadHeaders: Sequence[str] # noqa: N815
78
+ uploadMultipartUrls: "UploadPartsResponse" # noqa: N815
79
+ storagePath: str # noqa: N815
76
80
  artifact: "CreateArtifactFilesResponseFileNode"
77
81
 
78
82
  class CreateArtifactFilesResponseFileNode(TypedDict):
79
83
  id: str
80
84
 
85
+ class UploadPartsResponse(TypedDict):
86
+ uploadUrlParts: List["UploadUrlParts"] # noqa: N815
87
+ uploadID: str # noqa: N815
88
+
89
+ class UploadUrlParts(TypedDict):
90
+ partNumber: int # noqa: N815
91
+ uploadUrl: str # noqa: N815
92
+
93
+ class CompleteMultipartUploadArtifactInput(TypedDict):
94
+ """Corresponds to `type CompleteMultipartUploadArtifactInput` in schema.graphql."""
95
+
96
+ completeMultipartAction: str # noqa: N815
97
+ completedParts: Dict[int, str] # noqa: N815
98
+ artifactID: str # noqa: N815
99
+ storagePath: str # noqa: N815
100
+ uploadID: str # noqa: N815
101
+ md5: str
102
+
103
+ class CompleteMultipartUploadArtifactResponse(TypedDict):
104
+ digest: str
105
+
81
106
  class DefaultSettings(TypedDict):
82
107
  section: str
83
108
  git_remote: str
@@ -205,6 +230,10 @@ class Api:
205
230
  self._environ.get("WANDB__EXTRA_HTTP_HEADERS", {})
206
231
  )
207
232
 
233
+ auth = None
234
+ if _thread_local_api_settings.cookies is None:
235
+ auth = ("api", self.api_key or "")
236
+ extra_http_headers.update(_thread_local_api_settings.headers or {})
208
237
  self.client = Client(
209
238
  transport=GraphQLSession(
210
239
  headers={
@@ -217,8 +246,9 @@ class Api:
217
246
  # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
218
247
  # https://bugs.python.org/issue22889
219
248
  timeout=self.HTTP_TIMEOUT,
220
- auth=("api", self.api_key or ""),
249
+ auth=auth,
221
250
  url=f"{self.settings('base_url')}/graphql",
251
+ cookies=_thread_local_api_settings.cookies,
222
252
  )
223
253
  )
224
254
 
@@ -242,6 +272,11 @@ class Api:
242
272
  self.upload_file_retry = normalize_exceptions(
243
273
  retry.retriable(retry_timedelta=retry_timedelta)(self.upload_file)
244
274
  )
275
+ self.upload_multipart_file_chunk_retry = normalize_exceptions(
276
+ retry.retriable(retry_timedelta=retry_timedelta)(
277
+ self.upload_multipart_file_chunk
278
+ )
279
+ )
245
280
  self._client_id_mapping: Dict[str, str] = {}
246
281
  # Large file uploads to azure can optionally use their SDK
247
282
  self._azure_blob_module = util.get_module("azure.storage.blob")
@@ -252,6 +287,7 @@ class Api:
252
287
  self.server_use_artifact_input_info: Optional[List[str]] = None
253
288
  self._max_cli_version: Optional[str] = None
254
289
  self._server_settings_type: Optional[List[str]] = None
290
+ self.fail_run_queue_item_input_info: Optional[List[str]] = None
255
291
 
256
292
  def gql(self, *args: Any, **kwargs: Any) -> Any:
257
293
  ret = self._retry_gql(
@@ -273,7 +309,7 @@ class Api:
273
309
 
274
310
  def reauth(self) -> None:
275
311
  """Ensure the current api key is set in the transport."""
276
- self.client.transport.auth = ("api", self.api_key or "")
312
+ self.client.transport.session.auth = ("api", self.api_key or "")
277
313
 
278
314
  def relocate(self) -> None:
279
315
  """Ensure the current api points to the right server."""
@@ -307,6 +343,8 @@ class Api:
307
343
 
308
344
  @property
309
345
  def api_key(self) -> Optional[str]:
346
+ if _thread_local_api_settings.api_key:
347
+ return _thread_local_api_settings.api_key
310
348
  auth = requests.utils.get_netrc_auth(self.api_url)
311
349
  key = None
312
350
  if auth:
@@ -549,13 +587,103 @@ class Api:
549
587
  return "failRunQueueItem" in mutations
550
588
 
551
589
  @normalize_exceptions
552
- def fail_run_queue_item(self, run_queue_item_id: str) -> bool:
590
+ def fail_run_queue_item_fields_introspection(self) -> List:
591
+ if self.fail_run_queue_item_input_info:
592
+ return self.fail_run_queue_item_input_info
593
+ query_string = """
594
+ query ProbeServerFailRunQueueItemInput {
595
+ FailRunQueueItemInputInfoType: __type(name:"FailRunQueueItemInput") {
596
+ inputFields{
597
+ name
598
+ }
599
+ }
600
+ }
601
+ """
602
+
603
+ query = gql(query_string)
604
+ res = self.gql(query)
605
+
606
+ self.fail_run_queue_item_input_info = [
607
+ field.get("name", "")
608
+ for field in res.get("FailRunQueueItemInputInfoType", {}).get(
609
+ "inputFields", [{}]
610
+ )
611
+ ]
612
+ return self.fail_run_queue_item_input_info
613
+
614
+ @normalize_exceptions
615
+ def fail_run_queue_item(
616
+ self,
617
+ run_queue_item_id: str,
618
+ message: str,
619
+ stage: str,
620
+ file_paths: Optional[List[str]] = None,
621
+ ) -> bool:
622
+ if not self.fail_run_queue_item_introspection():
623
+ return False
624
+ variable_values: Dict[str, Union[str, Optional[List[str]]]] = {
625
+ "runQueueItemId": run_queue_item_id,
626
+ }
627
+ if "message" in self.fail_run_queue_item_fields_introspection():
628
+ variable_values.update({"message": message, "stage": stage})
629
+ if file_paths is not None:
630
+ variable_values["filePaths"] = file_paths
631
+ mutation_string = """
632
+ mutation failRunQueueItem($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
633
+ failRunQueueItem(
634
+ input: {
635
+ runQueueItemId: $runQueueItemId
636
+ message: $message
637
+ stage: $stage
638
+ filePaths: $filePaths
639
+ }
640
+ ) {
641
+ success
642
+ }
643
+ }
644
+ """
645
+ else:
646
+ mutation_string = """
647
+ mutation failRunQueueItem($runQueueItemId: ID!) {
648
+ failRunQueueItem(
649
+ input: {
650
+ runQueueItemId: $runQueueItemId
651
+ }
652
+ ) {
653
+ success
654
+ }
655
+ }
656
+ """
657
+
658
+ mutation = gql(mutation_string)
659
+ response = self.gql(mutation, variable_values=variable_values)
660
+ result: bool = response["failRunQueueItem"]["success"]
661
+ return result
662
+
663
+ @normalize_exceptions
664
+ def update_run_queue_item_warning_introspection(self) -> bool:
665
+ _, _, mutations = self.server_info_introspection()
666
+ return "updateRunQueueItemWarning" in mutations
667
+
668
+ @normalize_exceptions
669
+ def update_run_queue_item_warning(
670
+ self,
671
+ run_queue_item_id: str,
672
+ message: str,
673
+ stage: str,
674
+ file_paths: Optional[List[str]] = None,
675
+ ) -> bool:
676
+ if not self.update_run_queue_item_warning_introspection():
677
+ return False
553
678
  mutation = gql(
554
679
  """
555
- mutation failRunQueueItem($runQueueItemId: ID!) {
556
- failRunQueueItem(
680
+ mutation updateRunQueueItemWarning($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
681
+ updateRunQueueItemWarning(
557
682
  input: {
558
683
  runQueueItemId: $runQueueItemId
684
+ message: $message
685
+ stage: $stage
686
+ filePaths: $filePaths
559
687
  }
560
688
  ) {
561
689
  success
@@ -567,9 +695,12 @@ class Api:
567
695
  mutation,
568
696
  variable_values={
569
697
  "runQueueItemId": run_queue_item_id,
698
+ "message": message,
699
+ "stage": stage,
700
+ "filePaths": file_paths,
570
701
  },
571
702
  )
572
- result: bool = response["failRunQueueItem"]["success"]
703
+ result: bool = response["updateRunQueueItemWarning"]["success"]
573
704
  return result
574
705
 
575
706
  @normalize_exceptions
@@ -580,6 +711,7 @@ class Api:
580
711
  viewer {
581
712
  id
582
713
  entity
714
+ username
583
715
  flags
584
716
  teams {
585
717
  edges {
@@ -1992,7 +2124,16 @@ class Api:
1992
2124
  Returns:
1993
2125
  A tuple of the content length and the streaming response
1994
2126
  """
1995
- response = requests.get(url, auth=("user", self.api_key), stream=True) # type: ignore
2127
+ auth = None
2128
+ if _thread_local_api_settings.cookies is None:
2129
+ auth = ("user", self.api_key or "")
2130
+ response = requests.get(
2131
+ url,
2132
+ auth=auth,
2133
+ cookies=_thread_local_api_settings.cookies or {},
2134
+ headers=_thread_local_api_settings.headers or {},
2135
+ stream=True,
2136
+ )
1996
2137
  response.raise_for_status()
1997
2138
  return int(response.headers.get("content-length", 0)), response
1998
2139
 
@@ -2060,6 +2201,53 @@ class Api:
2060
2201
  else:
2061
2202
  raise requests.exceptions.ConnectionError(e.message)
2062
2203
 
2204
+ def upload_multipart_file_chunk(
2205
+ self,
2206
+ url: str,
2207
+ upload_chunk: bytes,
2208
+ extra_headers: Optional[Dict[str, str]] = None,
2209
+ ) -> Optional[requests.Response]:
2210
+ """Upload a file chunk to S3 with failure resumption.
2211
+
2212
+ Arguments:
2213
+ url: The url to download
2214
+ upload_chunk: The path to the file you want to upload
2215
+ extra_headers: A dictionary of extra headers to send with the request
2216
+
2217
+ Returns:
2218
+ The `requests` library response object
2219
+ """
2220
+ try:
2221
+ response = self._upload_file_session.put(
2222
+ url, data=upload_chunk, headers=extra_headers
2223
+ )
2224
+ response.raise_for_status()
2225
+ except requests.exceptions.RequestException as e:
2226
+ logger.error(f"upload_file exception {url}: {e}")
2227
+ request_headers = e.request.headers if e.request is not None else ""
2228
+ logger.error(f"upload_file request headers: {request_headers}")
2229
+ response_content = e.response.content if e.response is not None else ""
2230
+ logger.error(f"upload_file response body: {response_content}")
2231
+ status_code = e.response.status_code if e.response is not None else 0
2232
+ # S3 reports retryable request timeouts out-of-band
2233
+ is_aws_retryable = status_code == 400 and "RequestTimeout" in str(
2234
+ response_content
2235
+ )
2236
+ # Retry errors from cloud storage or local network issues
2237
+ if (
2238
+ status_code in (308, 408, 409, 429, 500, 502, 503, 504)
2239
+ or isinstance(
2240
+ e,
2241
+ (requests.exceptions.Timeout, requests.exceptions.ConnectionError),
2242
+ )
2243
+ or is_aws_retryable
2244
+ ):
2245
+ _e = retry.TransientError(exc=e)
2246
+ raise _e.with_traceback(sys.exc_info()[2])
2247
+ else:
2248
+ wandb._sentry.reraise(e)
2249
+ return response
2250
+
2063
2251
  def upload_file(
2064
2252
  self,
2065
2253
  url: str,
@@ -2353,7 +2541,8 @@ class Api:
2353
2541
  config = dict(config)
2354
2542
 
2355
2543
  if "parameters" not in config:
2356
- raise ValueError("sweep config must have a parameters section")
2544
+ # still shows an anaconda warning, but doesn't error
2545
+ return config
2357
2546
 
2358
2547
  for parameter_name in config["parameters"]:
2359
2548
  parameter = config["parameters"][parameter_name]
@@ -3014,6 +3203,50 @@ class Api:
3014
3203
  )
3015
3204
  return response
3016
3205
 
3206
+ def complete_multipart_upload_artifact(
3207
+ self,
3208
+ artifact_id: str,
3209
+ storage_path: str,
3210
+ completed_parts: List[Dict[str, Any]],
3211
+ upload_id: str,
3212
+ complete_multipart_action: str = "Complete",
3213
+ ) -> Optional[str]:
3214
+ mutation = gql(
3215
+ """
3216
+ mutation CompleteMultipartUploadArtifact(
3217
+ $completeMultipartAction: CompleteMultipartAction!,
3218
+ $completedParts: [UploadPartsInput!]!,
3219
+ $artifactID: ID!
3220
+ $storagePath: String!
3221
+ $uploadID: String!
3222
+ ) {
3223
+ completeMultipartUploadArtifact(
3224
+ input: {
3225
+ completeMultipartAction: $completeMultipartAction,
3226
+ completedParts: $completedParts,
3227
+ artifactID: $artifactID,
3228
+ storagePath: $storagePath
3229
+ uploadID: $uploadID
3230
+ }
3231
+ ) {
3232
+ digest
3233
+ }
3234
+ }
3235
+ """
3236
+ )
3237
+ response = self.gql(
3238
+ mutation,
3239
+ variable_values={
3240
+ "completeMultipartAction": complete_multipart_action,
3241
+ "artifactID": artifact_id,
3242
+ "storagePath": storage_path,
3243
+ "completedParts": completed_parts,
3244
+ "uploadID": upload_id,
3245
+ },
3246
+ )
3247
+ digest: Optional[str] = response["completeMultipartUploadArtifact"]["digest"]
3248
+ return digest
3249
+
3017
3250
  def create_artifact_manifest(
3018
3251
  self,
3019
3252
  name: str,
@@ -3171,19 +3404,39 @@ class Api:
3171
3404
  self._client_id_mapping[client_id] = server_id
3172
3405
  return server_id
3173
3406
 
3407
+ def server_create_artifact_file_spec_input_introspection(self) -> List:
3408
+ query_string = """
3409
+ query ProbeServerCreateArtifactFileSpecInput {
3410
+ CreateArtifactFileSpecInputInfoType: __type(name:"CreateArtifactFileSpecInput") {
3411
+ inputFields{
3412
+ name
3413
+ }
3414
+ }
3415
+ }
3416
+ """
3417
+
3418
+ query = gql(query_string)
3419
+ res = self.gql(query)
3420
+ create_artifact_file_spec_input_info = [
3421
+ field.get("name", "")
3422
+ for field in res.get("CreateArtifactFileSpecInputInfoType", {}).get(
3423
+ "inputFields", [{}]
3424
+ )
3425
+ ]
3426
+ return create_artifact_file_spec_input_info
3427
+
3174
3428
  @normalize_exceptions
3175
3429
  def create_artifact_files(
3176
3430
  self, artifact_files: Iterable["CreateArtifactFileSpecInput"]
3177
3431
  ) -> Mapping[str, "CreateArtifactFilesResponseFile"]:
3178
- mutation = gql(
3179
- """
3432
+ query_template = """
3180
3433
  mutation CreateArtifactFiles(
3181
3434
  $storageLayout: ArtifactStorageLayout!
3182
3435
  $artifactFiles: [CreateArtifactFileSpecInput!]!
3183
3436
  ) {
3184
3437
  createArtifactFiles(input: {
3185
3438
  artifactFiles: $artifactFiles,
3186
- storageLayout: $storageLayout
3439
+ storageLayout: $storageLayout,
3187
3440
  }) {
3188
3441
  files {
3189
3442
  edges {
@@ -3193,6 +3446,7 @@ class Api:
3193
3446
  displayName
3194
3447
  uploadUrl
3195
3448
  uploadHeaders
3449
+ _MULTIPART_UPLOAD_FIELDS_
3196
3450
  artifact {
3197
3451
  id
3198
3452
  }
@@ -3202,7 +3456,16 @@ class Api:
3202
3456
  }
3203
3457
  }
3204
3458
  """
3205
- )
3459
+ multipart_upload_url_query = """
3460
+ storagePath
3461
+ uploadMultipartUrls {
3462
+ uploadID
3463
+ uploadUrlParts {
3464
+ partNumber
3465
+ uploadUrl
3466
+ }
3467
+ }
3468
+ """
3206
3469
 
3207
3470
  # TODO: we should use constants here from interface/artifacts.py
3208
3471
  # but probably don't want the dependency. We're going to remove
@@ -3211,6 +3474,17 @@ class Api:
3211
3474
  if env.get_use_v1_artifacts():
3212
3475
  storage_layout = "V1"
3213
3476
 
3477
+ create_artifact_file_spec_input_fields = (
3478
+ self.server_create_artifact_file_spec_input_introspection()
3479
+ )
3480
+ if "uploadPartsInput" in create_artifact_file_spec_input_fields:
3481
+ query_template = query_template.replace(
3482
+ "_MULTIPART_UPLOAD_FIELDS_", multipart_upload_url_query
3483
+ )
3484
+ else:
3485
+ query_template = query_template.replace("_MULTIPART_UPLOAD_FIELDS_", "")
3486
+
3487
+ mutation = gql(query_template)
3214
3488
  response = self.gql(
3215
3489
  mutation,
3216
3490
  variable_values={