wandb 0.19.12rc1__py3-none-win32.whl → 0.20.1__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. wandb/__init__.py +1 -2
  2. wandb/__init__.pyi +3 -6
  3. wandb/_iterutils.py +26 -7
  4. wandb/_pydantic/__init__.py +2 -1
  5. wandb/_pydantic/utils.py +7 -0
  6. wandb/agents/pyagent.py +9 -15
  7. wandb/analytics/sentry.py +1 -2
  8. wandb/apis/attrs.py +3 -4
  9. wandb/apis/importers/internals/util.py +1 -1
  10. wandb/apis/importers/validation.py +2 -2
  11. wandb/apis/importers/wandb.py +30 -25
  12. wandb/apis/normalize.py +2 -2
  13. wandb/apis/public/__init__.py +1 -0
  14. wandb/apis/public/api.py +37 -33
  15. wandb/apis/public/artifacts.py +103 -72
  16. wandb/apis/public/jobs.py +3 -2
  17. wandb/apis/public/registries/registries_search.py +4 -2
  18. wandb/apis/public/registries/registry.py +1 -1
  19. wandb/apis/public/registries/utils.py +9 -9
  20. wandb/apis/public/runs.py +18 -6
  21. wandb/automations/_filters/expressions.py +1 -1
  22. wandb/automations/_filters/operators.py +1 -1
  23. wandb/automations/_filters/run_metrics.py +1 -1
  24. wandb/beta/workflows.py +6 -5
  25. wandb/bin/gpu_stats.exe +0 -0
  26. wandb/bin/wandb-core +0 -0
  27. wandb/cli/cli.py +54 -73
  28. wandb/docker/__init__.py +21 -74
  29. wandb/docker/names.py +40 -0
  30. wandb/env.py +0 -1
  31. wandb/errors/util.py +1 -1
  32. wandb/filesync/step_checksum.py +1 -1
  33. wandb/filesync/step_upload.py +1 -1
  34. wandb/integration/diffusers/resolvers/multimodal.py +1 -2
  35. wandb/integration/gym/__init__.py +5 -6
  36. wandb/integration/keras/callbacks/model_checkpoint.py +2 -2
  37. wandb/integration/keras/keras.py +13 -19
  38. wandb/integration/kfp/kfp_patch.py +2 -3
  39. wandb/integration/langchain/wandb_tracer.py +1 -1
  40. wandb/integration/metaflow/metaflow.py +13 -13
  41. wandb/integration/openai/fine_tuning.py +3 -2
  42. wandb/integration/sagemaker/auth.py +2 -1
  43. wandb/integration/sklearn/utils.py +2 -1
  44. wandb/integration/tensorboard/__init__.py +1 -1
  45. wandb/integration/tensorboard/log.py +2 -5
  46. wandb/integration/tensorflow/__init__.py +2 -2
  47. wandb/jupyter.py +20 -17
  48. wandb/plot/confusion_matrix.py +1 -1
  49. wandb/plot/utils.py +8 -7
  50. wandb/proto/v3/wandb_internal_pb2.py +355 -335
  51. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  52. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  53. wandb/proto/v4/wandb_internal_pb2.py +339 -335
  54. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  55. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  56. wandb/proto/v5/wandb_internal_pb2.py +339 -335
  57. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  58. wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
  59. wandb/proto/v6/wandb_internal_pb2.py +339 -335
  60. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  61. wandb/proto/v6/wandb_telemetry_pb2.py +12 -12
  62. wandb/proto/wandb_deprecated.py +6 -8
  63. wandb/sdk/artifacts/_internal_artifact.py +43 -0
  64. wandb/sdk/artifacts/_validators.py +55 -35
  65. wandb/sdk/artifacts/artifact.py +117 -115
  66. wandb/sdk/artifacts/artifact_download_logger.py +2 -0
  67. wandb/sdk/artifacts/artifact_saver.py +1 -3
  68. wandb/sdk/artifacts/artifact_state.py +2 -0
  69. wandb/sdk/artifacts/artifact_ttl.py +2 -0
  70. wandb/sdk/artifacts/exceptions.py +14 -0
  71. wandb/sdk/artifacts/staging.py +2 -0
  72. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -6
  73. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
  74. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -6
  75. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -5
  76. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
  77. wandb/sdk/artifacts/storage_layout.py +2 -0
  78. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -3
  79. wandb/sdk/backend/backend.py +11 -182
  80. wandb/sdk/data_types/_dtypes.py +2 -6
  81. wandb/sdk/data_types/audio.py +20 -3
  82. wandb/sdk/data_types/base_types/media.py +12 -7
  83. wandb/sdk/data_types/base_types/wb_value.py +8 -18
  84. wandb/sdk/data_types/bokeh.py +19 -2
  85. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +17 -1
  86. wandb/sdk/data_types/helper_types/image_mask.py +7 -1
  87. wandb/sdk/data_types/html.py +4 -4
  88. wandb/sdk/data_types/image.py +178 -103
  89. wandb/sdk/data_types/molecule.py +6 -6
  90. wandb/sdk/data_types/object_3d.py +10 -5
  91. wandb/sdk/data_types/saved_model.py +11 -6
  92. wandb/sdk/data_types/table.py +313 -83
  93. wandb/sdk/data_types/table_decorators.py +108 -0
  94. wandb/sdk/data_types/utils.py +43 -7
  95. wandb/sdk/data_types/video.py +21 -3
  96. wandb/sdk/interface/interface.py +10 -0
  97. wandb/sdk/internal/datastore.py +2 -6
  98. wandb/sdk/internal/file_pusher.py +1 -5
  99. wandb/sdk/internal/file_stream.py +8 -17
  100. wandb/sdk/internal/handler.py +2 -2
  101. wandb/sdk/internal/incremental_table_util.py +53 -0
  102. wandb/sdk/internal/internal.py +3 -5
  103. wandb/sdk/internal/internal_api.py +66 -89
  104. wandb/sdk/internal/job_builder.py +2 -7
  105. wandb/sdk/internal/profiler.py +2 -2
  106. wandb/sdk/internal/progress.py +1 -3
  107. wandb/sdk/internal/run.py +1 -6
  108. wandb/sdk/internal/sender.py +24 -36
  109. wandb/sdk/internal/system/assets/aggregators.py +1 -7
  110. wandb/sdk/internal/system/assets/disk.py +3 -3
  111. wandb/sdk/internal/system/assets/gpu.py +4 -4
  112. wandb/sdk/internal/system/assets/gpu_amd.py +4 -4
  113. wandb/sdk/internal/system/assets/interfaces.py +6 -6
  114. wandb/sdk/internal/system/assets/tpu.py +1 -1
  115. wandb/sdk/internal/system/assets/trainium.py +6 -6
  116. wandb/sdk/internal/system/system_info.py +5 -7
  117. wandb/sdk/internal/system/system_monitor.py +4 -4
  118. wandb/sdk/internal/tb_watcher.py +5 -7
  119. wandb/sdk/launch/_launch.py +1 -1
  120. wandb/sdk/launch/_project_spec.py +19 -20
  121. wandb/sdk/launch/agent/agent.py +3 -3
  122. wandb/sdk/launch/agent/config.py +1 -1
  123. wandb/sdk/launch/agent/job_status_tracker.py +2 -2
  124. wandb/sdk/launch/builder/build.py +2 -3
  125. wandb/sdk/launch/builder/kaniko_builder.py +5 -4
  126. wandb/sdk/launch/environment/gcp_environment.py +1 -2
  127. wandb/sdk/launch/registry/azure_container_registry.py +2 -2
  128. wandb/sdk/launch/registry/elastic_container_registry.py +2 -2
  129. wandb/sdk/launch/registry/google_artifact_registry.py +3 -3
  130. wandb/sdk/launch/runner/abstract.py +5 -5
  131. wandb/sdk/launch/runner/kubernetes_monitor.py +2 -2
  132. wandb/sdk/launch/runner/kubernetes_runner.py +1 -1
  133. wandb/sdk/launch/runner/sagemaker_runner.py +2 -4
  134. wandb/sdk/launch/runner/vertex_runner.py +2 -7
  135. wandb/sdk/launch/sweeps/__init__.py +1 -1
  136. wandb/sdk/launch/sweeps/scheduler.py +2 -2
  137. wandb/sdk/launch/sweeps/utils.py +3 -3
  138. wandb/sdk/launch/utils.py +3 -4
  139. wandb/sdk/lib/apikey.py +5 -8
  140. wandb/sdk/lib/config_util.py +3 -3
  141. wandb/sdk/lib/fsm.py +3 -18
  142. wandb/sdk/lib/gitlib.py +6 -5
  143. wandb/sdk/lib/ipython.py +2 -2
  144. wandb/sdk/lib/json_util.py +9 -14
  145. wandb/sdk/lib/printer.py +3 -8
  146. wandb/sdk/lib/redirect.py +1 -1
  147. wandb/sdk/lib/retry.py +3 -7
  148. wandb/sdk/lib/run_moment.py +2 -2
  149. wandb/sdk/lib/service_connection.py +3 -1
  150. wandb/sdk/lib/service_token.py +1 -2
  151. wandb/sdk/mailbox/mailbox_handle.py +3 -7
  152. wandb/sdk/mailbox/response_handle.py +2 -6
  153. wandb/sdk/service/streams.py +3 -7
  154. wandb/sdk/verify/verify.py +5 -6
  155. wandb/sdk/wandb_config.py +1 -1
  156. wandb/sdk/wandb_init.py +38 -106
  157. wandb/sdk/wandb_login.py +7 -6
  158. wandb/sdk/wandb_run.py +52 -240
  159. wandb/sdk/wandb_settings.py +71 -60
  160. wandb/sdk/wandb_setup.py +40 -14
  161. wandb/sdk/wandb_watch.py +5 -7
  162. wandb/sync/__init__.py +1 -1
  163. wandb/sync/sync.py +13 -13
  164. wandb/util.py +17 -35
  165. wandb/wandb_agent.py +8 -11
  166. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/METADATA +5 -5
  167. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/RECORD +170 -168
  168. wandb/docker/auth.py +0 -435
  169. wandb/docker/www_authenticate.py +0 -94
  170. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/WHEEL +0 -0
  171. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/entry_points.txt +0 -0
  172. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/internal/internal_api.py CHANGED
@@ -1,4 +1,3 @@
-import ast
 import base64
 import datetime
 import functools
@@ -12,7 +11,6 @@ import sys
 import threading
 from copy import deepcopy
 from pathlib import Path
-from types import MappingProxyType
 from typing import (
     IO,
     TYPE_CHECKING,
@@ -70,42 +68,42 @@ if TYPE_CHECKING:
 class CreateArtifactFileSpecInput(TypedDict, total=False):
     """Corresponds to `type CreateArtifactFileSpecInput` in schema.graphql."""
 
-    artifactID: str  # noqa: N815
+    artifactID: str
     name: str
     md5: str
     mimetype: Optional[str]
-    artifactManifestID: Optional[str]  # noqa: N815
-    uploadPartsInput: Optional[List[Dict[str, object]]]  # noqa: N815
+    artifactManifestID: Optional[str]
+    uploadPartsInput: Optional[List[Dict[str, object]]]
 
 class CreateArtifactFilesResponseFile(TypedDict):
     id: str
     name: str
-    displayName: str  # noqa: N815
-    uploadUrl: Optional[str]  # noqa: N815
-    uploadHeaders: Sequence[str]  # noqa: N815
-    uploadMultipartUrls: "UploadPartsResponse"  # noqa: N815
-    storagePath: str  # noqa: N815
+    displayName: str
+    uploadUrl: Optional[str]
+    uploadHeaders: Sequence[str]
+    uploadMultipartUrls: "UploadPartsResponse"
+    storagePath: str
     artifact: "CreateArtifactFilesResponseFileNode"
 
 class CreateArtifactFilesResponseFileNode(TypedDict):
     id: str
 
 class UploadPartsResponse(TypedDict):
-    uploadUrlParts: List["UploadUrlParts"]  # noqa: N815
-    uploadID: str  # noqa: N815
+    uploadUrlParts: List["UploadUrlParts"]
+    uploadID: str
 
 class UploadUrlParts(TypedDict):
-    partNumber: int  # noqa: N815
-    uploadUrl: str  # noqa: N815
+    partNumber: int
+    uploadUrl: str
 
 class CompleteMultipartUploadArtifactInput(TypedDict):
     """Corresponds to `type CompleteMultipartUploadArtifactInput` in schema.graphql."""
 
-    completeMultipartAction: str  # noqa: N815
-    completedParts: Dict[int, str]  # noqa: N815
-    artifactID: str  # noqa: N815
-    storagePath: str  # noqa: N815
-    uploadID: str  # noqa: N815
+    completeMultipartAction: str
+    completedParts: Dict[int, str]
+    artifactID: str
+    storagePath: str
+    uploadID: str
     md5: str
 
 class CompleteMultipartUploadArtifactResponse(TypedDict):
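
Note: the camelCase keys in these TypedDicts mirror the field names in schema.graphql, per their docstrings. As a minimal sketch of the pattern (the type below is illustrative, not part of wandb), total=False makes every key omittable, matching a GraphQL input whose fields need not all be supplied:

    from typing import Optional, TypedDict

    class ExampleArtifactInput(TypedDict, total=False):  # hypothetical type
        artifactID: str           # camelCase key mirrors the GraphQL schema
        mimetype: Optional[str]   # Optional permits None; total=False permits omission

    payload: ExampleArtifactInput = {"artifactID": "abc123"}  # mimetype omitted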
@@ -238,7 +236,7 @@ class Api:
             ]
         ] = None,
         load_settings: bool = True,
-        retry_timedelta: datetime.timedelta = datetime.timedelta(  # noqa: B008 # okay because it's immutable
+        retry_timedelta: datetime.timedelta = datetime.timedelta(  # okay because it's immutable
             days=7
         ),
         environ: MutableMapping = os.environ,
@@ -364,7 +362,7 @@ class Api:
         self.server_create_run_queue_supports_priority: Optional[bool] = None
         self.server_supports_template_variables: Optional[bool] = None
         self.server_push_to_run_queue_supports_priority: Optional[bool] = None
-        self._server_features_cache: Optional[dict[str, bool]] = None
+        self._server_features_cache: Optional[Dict[str, bool]] = None
 
     def gql(self, *args: Any, **kwargs: Any) -> Any:
         ret = self._retry_gql(
@@ -399,8 +397,7 @@ class Api:
         except requests.exceptions.HTTPError as err:
             response = err.response
             assert response is not None
-            logger.error(f"{response.status_code} response executing GraphQL.")
-            logger.error(response.text)
+            logger.exception("Error executing GraphQL.")
             for error in parse_backend_error_messages(response):
                 wandb.termerror(f"Error while calling W&B API: {error} ({response})")
             raise
@@ -869,45 +866,43 @@ class Api:
         _, _, mutations = self.server_info_introspection()
         return "updateRunQueueItemWarning" in mutations
 
-    def _server_features(self) -> Mapping[str, bool]:
-        """Returns a cached, read-only lookup of current server feature flags."""
-        if self._server_features_cache is None:
-            query = gql(SERVER_FEATURES_QUERY_GQL)
-
-            try:
-                response = self.gql(query)
-            except Exception as e:
-                # Unfortunately we currently have to match on the text of the error message
-                if 'Cannot query field "features" on type "ServerInfo".' in str(e):
-                    self._server_features_cache = {}
-                else:
-                    raise
+    def _server_features(self) -> Dict[str, bool]:
+        # NOTE: Avoid caching via `@cached_property`, due to undocumented
+        # locking behavior before Python 3.12.
+        # See: https://github.com/python/cpython/issues/87634
+        query = gql(SERVER_FEATURES_QUERY_GQL)
+        try:
+            response = self.gql(query)
+        except Exception as e:
+            # Unfortunately we currently have to match on the text of the error message,
+            # as the `gql` client raises `Exception` rather than a more specific error.
+            if 'Cannot query field "features" on type "ServerInfo".' in str(e):
+                self._server_features_cache = {}
             else:
-                info = ServerFeaturesQuery.model_validate(response).server_info
-                if info and (feats := info.features):
-                    self._server_features_cache = {
-                        f.name: f.is_enabled for f in feats if f
-                    }
-                else:
-                    self._server_features_cache = {}
+                raise
+        else:
+            info = ServerFeaturesQuery.model_validate(response).server_info
+            if info and (feats := info.features):
+                self._server_features_cache = {f.name: f.is_enabled for f in feats if f}
+            else:
+                self._server_features_cache = {}
+        return self._server_features_cache
 
-        return MappingProxyType(self._server_features_cache)
+    def _server_supports(self, feature: Union[int, str]) -> bool:
+        """Return whether the current server supports the given feature.
 
-    def _check_server_feature_with_fallback(self, feature_value: ServerFeature) -> bool:
-        """Wrapper around check_server_feature that warns and returns False for older unsupported servers.
+        This also caches the underlying lookup of server feature flags,
+        and it maps {feature_name (str) -> is_enabled (bool)}.
 
         Good to use for features that have a fallback mechanism for older servers.
-
-        Args:
-            feature_value (ServerFeature): The enum value of the feature to check.
-
-        Returns:
-            bool: True if the feature is enabled, False otherwise.
-
-        Exceptions:
-            Exception: If an error other than the server not supporting feature queries occurs.
         """
-        return self._server_features().get(ServerFeature.Name(feature_value), False)
+        # If we're given the protobuf enum value, convert to a string name.
+        # NOTE: We deliberately use names (str) instead of enum values (int)
+        # as the keys here, since:
+        # - the server identifies features by their name, rather than (client-side) enum value
+        # - the defined list of client-side flags may be behind the server-side list of flags
+        key = ServerFeature.Name(feature) if isinstance(feature, int) else feature
+        return self._server_features().get(key) or False
 
     @normalize_exceptions
     def update_run_queue_item_warning(
@@ -2092,9 +2087,7 @@ class Api:
         )
         if default is None or default.get("queueID") is None:
             raise CommError(
-                "Unable to create default queue for {}/{}. No queues for agent to poll".format(
-                    entity, project
-                )
+                f"Unable to create default queue for {entity}/{project}. No queues for agent to poll"
             )
         project_queues = [{"id": default["queueID"], "name": "default"}]
         polling_queue_ids = [
@@ -2571,15 +2564,11 @@ class Api:
         res = self.gql(query, variable_values)
         if res.get("project") is None:
             raise CommError(
-                "Error fetching run info for {}/{}/{}. Check that this project exists and you have access to this entity and project".format(
-                    entity, project, name
-                )
+                f"Error fetching run info for {entity}/{project}/{name}. Check that this project exists and you have access to this entity and project"
             )
         elif res["project"].get("run") is None:
             raise CommError(
-                "Error fetching run info for {}/{}/{}. Check that this run id exists".format(
-                    entity, project, name
-                )
+                f"Error fetching run info for {entity}/{project}/{name}. Check that this run id exists"
             )
         run_info: dict = res["project"]["run"]["runInfo"]
         return run_info
@@ -2993,11 +2982,8 @@ class Api:
             logger.debug("upload_file: %s complete", url)
             response.raise_for_status()
         except requests.exceptions.RequestException as e:
-            logger.error(f"upload_file exception {url}: {e}")
-            request_headers = e.request.headers if e.request is not None else ""
-            logger.error(f"upload_file request headers: {request_headers!r}")
+            logger.exception(f"upload_file exception for {url=}")
             response_content = e.response.content if e.response is not None else ""
-            logger.error(f"upload_file response body: {response_content!r}")
             status_code = e.response.status_code if e.response is not None else 0
             # S3 reports retryable request timeouts out-of-band
             is_aws_retryable = status_code == 400 and "RequestTimeout" in str(
@@ -3059,11 +3045,8 @@ class Api:
             logger.debug("upload_file: %s complete", url)
             response.raise_for_status()
         except requests.exceptions.RequestException as e:
-            logger.error(f"upload_file exception {url}: {e}")
-            request_headers = e.request.headers if e.request is not None else ""
-            logger.error(f"upload_file request headers: {request_headers}")
+            logger.exception(f"upload_file exception for {url=}")
             response_content = e.response.content if e.response is not None else ""
-            logger.error(f"upload_file response body: {response_content!r}")
             status_code = e.response.status_code if e.response is not None else 0
             # S3 reports retryable request timeouts out-of-band
             is_aws_retryable = (
@@ -3190,10 +3173,8 @@ class Api:
                 },
                 timeout=60,
             )
-        except Exception as e:
-            # GQL raises exceptions with stringified python dictionaries :/
-            message = ast.literal_eval(e.args[0])["message"]
-            logger.error("Error communicating with W&B: %s", message)
+        except Exception:
+            logger.exception("Error communicating with W&B.")
             return []
         else:
             result: List[Dict[str, Any]] = json.loads(
@@ -3235,10 +3216,8 @@ class Api:
                 parameter["distribution"] = "uniform"
             else:
                 raise ValueError(
-                    "Parameter {} is ambiguous, please specify bounds as both floats (for a float_"
-                    "uniform distribution) or ints (for an int_uniform distribution).".format(
-                        parameter_name
-                    )
+                    f"Parameter {parameter_name} is ambiguous, please specify bounds as both floats (for a float_"
+                    "uniform distribution) or ints (for an int_uniform distribution)."
                 )
         return config
 
@@ -3387,8 +3366,8 @@ class Api:
                 variable_values=variables,
                 check_retry_fn=util.no_retry_4xx,
             )
-        except UsageError as e:
-            raise e
+        except UsageError:
+            raise
         except Exception as e:
             # graphql schema exception is generic
             err = e
@@ -3783,10 +3762,8 @@ class Api:
             "usedAs": use_as,
         }
 
-        server_allows_entity_project_information = (
-            self._check_server_feature_with_fallback(
-                ServerFeature.USE_ARTIFACT_WITH_ENTITY_AND_PROJECT_INFORMATION  # type: ignore
-            )
+        server_allows_entity_project_information = self._server_supports(
+            ServerFeature.USE_ARTIFACT_WITH_ENTITY_AND_PROJECT_INFORMATION
         )
         if server_allows_entity_project_information:
             query_vars.extend(
@@ -4565,9 +4542,9 @@ class Api:
         s = self.sweep(sweep=sweep, entity=entity, project=project, specs="{}")
         curr_state = s["state"].upper()
         if state == "PAUSED" and curr_state not in ("PAUSED", "RUNNING"):
-            raise Exception("Cannot pause {} sweep.".format(curr_state.lower()))
+            raise Exception(f"Cannot pause {curr_state.lower()} sweep.")
         elif state != "RUNNING" and curr_state not in ("RUNNING", "PAUSED", "PENDING"):
-            raise Exception("Sweep already {}.".format(curr_state.lower()))
+            raise Exception(f"Sweep already {curr_state.lower()}.")
         sweep_id = s["id"]
         mutation = gql(
             """
wandb/sdk/internal/job_builder.py CHANGED
@@ -19,6 +19,7 @@ from typing import (
 )
 
 import wandb
+from wandb.sdk.artifacts._internal_artifact import InternalArtifact
 from wandb.sdk.artifacts.artifact import Artifact
 from wandb.sdk.data_types._dtypes import TypeRegistry
 from wandb.sdk.internal.internal_api import Api
@@ -128,12 +129,6 @@ def get_min_supported_for_source_dict(
     return min_seen
 
 
-class JobArtifact(Artifact):
-    def __init__(self, name: str, *args: Any, **kwargs: Any):
-        super().__init__(name, "placeholder", *args, **kwargs)
-        self._type = JOB_ARTIFACT_TYPE  # Get around type restriction.
-
-
 class JobBuilder:
     _settings: SettingsStatic
     _metadatafile_path: Optional[str]
@@ -552,7 +547,7 @@ class JobBuilder:
         assert source_info is not None
         assert name is not None
 
-        artifact = JobArtifact(name)
+        artifact = InternalArtifact(name, JOB_ARTIFACT_TYPE)
 
         _logger.info("adding wandb-job metadata file")
         with artifact.new_file("wandb-job.json") as f:
wandb/sdk/internal/profiler.py CHANGED
@@ -54,12 +54,12 @@ def torch_trace_handler():
         prof.step()
     ```
     """
-    from wandb.util import parse_version
+    from packaging.version import parse
 
     torch = wandb.util.get_module(PYTORCH_MODULE, required=True)
     torch_profiler = wandb.util.get_module(PYTORCH_PROFILER_MODULE, required=True)
 
-    if parse_version(torch.__version__) < parse_version("1.9.0"):
+    if parse(torch.__version__) < parse("1.9.0"):
         raise Error(
             f"torch version must be at least 1.9 in order to use the PyTorch Profiler API.\
             \nVersion of torch currently installed: {torch.__version__}"
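
Note: this hunk (and matching ones in sender.py below) replaces wandb.util.parse_version with packaging.version.parse. A small sketch of the replacement call, assuming the third-party packaging library is installed:

    from packaging.version import parse

    # parse() returns PEP 440 Version objects that compare numerically,
    # including for pre-releases, where plain string comparison fails.
    assert parse("0.10.9") < parse("0.10.16")  # lexicographic comparison would say otherwise
    assert parse("1.9.0rc1") < parse("1.9.0")  # pre-releases sort before final releases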
wandb/sdk/internal/progress.py CHANGED
@@ -43,9 +43,7 @@ class Progress:
             # files getting truncated while uploading seems like something
             # that shouldn't really be happening anyway.
             raise CommError(
-                "File {} size shrank from {} to {} while it was being uploaded.".format(
-                    self.file.name, self.len, self.bytes_read
-                )
+                f"File {self.file.name} size shrank from {self.len} to {self.bytes_read} while it was being uploaded."
             )
         # Growing files are also likely to be bad, but our code didn't break
         # on those in the past, so it's riskier to make that an error now.
wandb/sdk/internal/run.py CHANGED
@@ -5,12 +5,7 @@ Semi-stubbed run for internal process use.
 
 """
 
-import sys
-
-if sys.version_info >= (3, 12):
-    from typing import override
-else:
-    from typing_extensions import override
+from typing_extensions import override
 
 from wandb.sdk import wandb_run
 
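Note: typing_extensions re-exports typing.override on Python 3.12+, so the version branch above became unnecessary. A minimal usage sketch (the class names are hypothetical, not from wandb):

    from typing_extensions import override

    class Base:
        def finish(self) -> None: ...

    class Stub(Base):
        @override  # a type checker flags this if Base has no finish()
        def finish(self) -> None:
            pass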
wandb/sdk/internal/sender.py CHANGED
@@ -749,14 +749,12 @@ class SendManager:
             self._resume_state.wandb_runtime = new_runtime
             tags = resume_status.get("tags") or []
 
-        except (IndexError, ValueError) as e:
-            logger.error("unable to load resume tails", exc_info=e)
+        except (IndexError, ValueError):
+            logger.exception("unable to load resume tails")
             if self._settings.resume == "must":
                 error = wandb_internal_pb2.ErrorInfo()
                 error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.USAGE
-                error.message = "resume='must' but could not resume ({}) ".format(
-                    run.run_id
-                )
+                error.message = f"resume='must' but could not resume ({run.run_id}) "
                 return error
 
         # TODO: Do we need to restore config / summary?
@@ -772,7 +770,7 @@ class SendManager:
         self._resume_state.summary = summary
         self._resume_state.tags = tags
         self._resume_state.resumed = True
-        logger.info("configured resuming with: {}".format(self._resume_state))
+        logger.info(f"configured resuming with: {self._resume_state}")
         return None
 
     def _telemetry_get_framework(self) -> str:
@@ -816,9 +814,7 @@ class SendManager:
             self._interface.publish_config(
                 key=("_wandb", "spell_url"), val=env.get("SPELL_RUN_URL")
             )
-        url = "{}/{}/{}/runs/{}".format(
-            self._api.app_url, self._run.entity, self._run.project, self._run.run_id
-        )
+        url = f"{self._api.app_url}/{self._run.entity}/{self._run.project}/runs/{self._run.run_id}"
         requests.put(
             env.get("SPELL_API_URL", "https://api.spell.run") + "/wandb_url",
             json={"access_token": env.get("WANDB_ACCESS_TOKEN"), "url": url},
@@ -829,23 +825,22 @@ class SendManager:
         # TODO: do something if sync spell is not successful?
 
     def _setup_fork(self, server_run: dict):
-        assert self._settings.fork_from
-        assert self._settings.fork_from.metric == "_step"
         assert self._run
-        first_step = int(self._settings.fork_from.value) + 1
+        assert self._run.branch_point
+        first_step = int(self._run.branch_point.value) + 1
         self._resume_state.step = first_step
         self._resume_state.history = server_run.get("historyLineCount", 0)
         self._run.forked = True
         self._run.starting_step = first_step
 
     def _load_rewind_state(self, run: "RunRecord"):
-        assert self._settings.resume_from
+        assert run.branch_point
         self._rewind_response = self._api.rewind_run(
             run_name=run.run_id,
             entity=run.entity or None,
             project=run.project or None,
-            metric_name=self._settings.resume_from.metric,
-            metric_value=self._settings.resume_from.value,
+            metric_name=run.branch_point.metric,
+            metric_value=run.branch_point.value,
             program_path=self._settings.program or None,
         )
         self._resume_state.history = self._rewind_response.get("historyLineCount", 0)
@@ -854,12 +849,11 @@ class SendManager:
         )
 
     def _install_rewind_state(self):
-        assert self._settings.resume_from
-        assert self._settings.resume_from.metric == "_step"
         assert self._run
+        assert self._run.branch_point
         assert self._rewind_response
 
-        first_step = int(self._settings.resume_from.value) + 1
+        first_step = int(self._run.branch_point.value) + 1
         self._resume_state.step = first_step
 
         # We set the fork flag here because rewind uses the forking
@@ -903,8 +897,8 @@ class SendManager:
         config_value_dict = self._config_backend_dict()
         self._config_save(config_value_dict)
 
-        do_fork = self._settings.fork_from is not None and is_wandb_init
-        do_rewind = self._settings.resume_from is not None and is_wandb_init
+        do_rewind = run.branch_point.run == run.run_id
+        do_fork = not do_rewind and run.branch_point.run != ""
        do_resume = bool(self._settings.resume)
 
         num_resume_options_set = sum([do_fork, do_rewind, do_resume])
@@ -1188,7 +1182,7 @@ class SendManager:
             try:
                 d[item.key] = json.loads(item.value_json)
             except json.JSONDecodeError:
-                logger.error("error decoding stats json: %s", item.value_json)
+                logger.exception("error decoding stats json: %s", item.value_json)
         row: Dict[str, Any] = dict(system=d)
         self._flatten(row)
         row["_wandb"] = True
@@ -1500,17 +1494,15 @@ class SendManager:
         try:
             res = self._send_artifact(artifact)
             logger.info(f"sent artifact {artifact.name} - {res}")
-        except Exception as e:
-            logger.error(
-                'send_artifact: failed for artifact "{}/{}": {}'.format(
-                    artifact.type, artifact.name, e
-                )
+        except Exception:
+            logger.exception(
+                f'send_artifact: failed for artifact "{artifact.type}/{artifact.name}"'
             )
 
     def _send_artifact(
         self, artifact: "ArtifactRecord", history_step: Optional[int] = None
     ) -> Optional[Dict]:
-        from wandb.util import parse_version
+        from packaging.version import parse
 
         assert self._pusher
         saver = ArtifactSaver(
@@ -1523,9 +1515,7 @@ class SendManager:
 
         if artifact.distributed_id:
             max_cli_version = self._max_cli_version()
-            if max_cli_version is None or parse_version(
-                max_cli_version
-            ) < parse_version("0.10.16"):
+            if max_cli_version is None or parse(max_cli_version) < parse("0.10.16"):
                 logger.warning(
                     "This W&B Server doesn't support distributed artifacts, "
                     "have your administrator install wandb/local >= 0.9.37"
@@ -1561,13 +1551,11 @@ class SendManager:
         return res
 
     def send_alert(self, record: "Record") -> None:
-        from wandb.util import parse_version
+        from packaging.version import parse
 
         alert = record.alert
         max_cli_version = self._max_cli_version()
-        if max_cli_version is None or parse_version(max_cli_version) < parse_version(
-            "0.10.9"
-        ):
+        if max_cli_version is None or parse(max_cli_version) < parse("0.10.9"):
             logger.warning(
                 "This W&B server doesn't support alerts, "
                 "have your administrator install wandb/local >= 0.9.31"
@@ -1580,8 +1568,8 @@ class SendManager:
                 level=alert.level,
                 wait_duration=alert.wait_duration,
             )
-        except Exception as e:
-            logger.error(f"send_alert: failed for alert {alert.title!r}: {e}")
+        except Exception:
+            logger.exception(f"send_alert: failed for alert {alert.title!r}")
 
     def finish(self) -> None:
         logger.info("shutting down sender")
wandb/sdk/internal/system/assets/aggregators.py CHANGED
@@ -1,10 +1,4 @@
-import sys
-from typing import Union
-
-if sys.version_info >= (3, 9):
-    from collections.abc import Sequence
-else:
-    from typing import Sequence
+from typing import Sequence, Union
 
 Number = Union[int, float]
 
wandb/sdk/internal/system/assets/disk.py CHANGED
@@ -33,7 +33,7 @@ class DiskUsagePercent:
             try:
                 psutil.disk_usage(path)
                 self.paths.append(path)
-            except Exception as e:  # noqa
+            except Exception as e:
                 termwarn(f"Could not access disk path {path}: {e}", repeat=False)
 
     def sample(self) -> None:
@@ -74,7 +74,7 @@ class DiskUsage:
             try:
                 psutil.disk_usage(path)
                 self.paths.append(path)
-            except Exception as e:  # noqa
+            except Exception as e:
                 termwarn(f"Could not access disk path {path}: {e}", repeat=False)
 
     def sample(self) -> None:
@@ -198,7 +198,7 @@ class Disk:
                     "total": total,
                     "used": used,
                 }
-            except Exception as e:  # noqa
+            except Exception as e:
                 termwarn(f"Could not access disk path {disk_path}: {e}", repeat=False)
 
         return {self.name: disk_metrics}
wandb/sdk/internal/system/assets/gpu.py CHANGED
@@ -377,8 +377,8 @@ class GPU:
             return True
         except pynvml.NVMLError_LibraryNotFound:  # type: ignore
             return False
-        except Exception as e:
-            logger.error(f"Error initializing NVML: {e}")
+        except Exception:
+            logger.exception("Error initializing NVML.")
             return False
 
     def start(self) -> None:
@@ -410,7 +410,7 @@
 
         except pynvml.NVMLError:
             pass
-        except Exception as e:
-            logger.error(f"Error Probing GPU: {e}")
+        except Exception:
+            logger.exception("Error Probing GPU.")
 
         return info
wandb/sdk/internal/system/assets/gpu_amd.py CHANGED
@@ -104,8 +104,8 @@ class GPUAMDStats:
             if cards:
                 self.samples.append(cards)
 
-        except (OSError, ValueError, TypeError, subprocess.CalledProcessError) as e:
-            logger.exception(f"GPU stats error: {e}")
+        except (OSError, ValueError, TypeError, subprocess.CalledProcessError):
+            logger.exception("GPU stats error")
 
     def clear(self) -> None:
         self.samples.clear()
@@ -228,6 +228,6 @@ class GPUAMD:
                 for key in stats.keys()
                 if key.startswith("card")
             ]
-        except Exception as e:
-            logger.exception(f"GPUAMD probe error: {e}")
+        except Exception:
+            logger.exception("GPUAMD probe error")
         return info
wandb/sdk/internal/system/assets/interfaces.py CHANGED
@@ -136,8 +136,8 @@ class MetricsMonitor:
                     logger.info(f"Process {metric.name} has exited.")
                     self._shutdown_event.set()
                     break
-                except Exception as e:
-                    logger.error(f"Failed to sample metric: {e}")
+                except Exception:
+                    logger.exception("Failed to sample metric.")
             self._shutdown_event.wait(self.sampling_interval)
             if self._shutdown_event.is_set():
                 break
@@ -153,8 +153,8 @@ class MetricsMonitor:
                 # aggregated_metrics = wandb.util.merge_dicts(
                 #     aggregated_metrics, metric.serialize()
                 # )
-            except Exception as e:
-                logger.error(f"Failed to serialize metric: {e}")
+            except Exception:
+                logger.exception("Failed to serialize metric.")
         return aggregated_metrics
 
     def publish(self) -> None:
@@ -165,8 +165,8 @@ class MetricsMonitor:
                 self._interface.publish_stats(aggregated_metrics)
             for metric in self.metrics:
                 metric.clear()
-        except Exception as e:
-            logger.error(f"Failed to publish metrics: {e}")
+        except Exception:
+            logger.exception("Failed to publish metrics.")
 
     def start(self) -> None:
         if (self._process is not None) or self._shutdown_event.is_set():
wandb/sdk/internal/system/assets/tpu.py CHANGED
@@ -91,7 +91,7 @@ class TPU:
     ) -> str:
         if service_addr is not None:
             if tpu_name is not None:
-                logger.warn(
+                logger.warning(
                     "Both service_addr and tpu_name arguments provided. "
                     "Ignoring tpu_name and using service_addr."
                 )