wandb 0.16.6__py3-none-any.whl → 0.17.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (193) hide show
  1. package_readme.md +95 -0
  2. wandb/__init__.py +2 -3
  3. wandb/agents/pyagent.py +0 -1
  4. wandb/analytics/sentry.py +2 -1
  5. wandb/apis/importers/internals/internal.py +0 -1
  6. wandb/apis/importers/internals/protocols.py +30 -56
  7. wandb/apis/importers/mlflow.py +13 -26
  8. wandb/apis/importers/wandb.py +8 -14
  9. wandb/apis/internal.py +0 -3
  10. wandb/apis/public/api.py +55 -3
  11. wandb/apis/public/artifacts.py +1 -0
  12. wandb/apis/public/files.py +1 -0
  13. wandb/apis/public/history.py +1 -0
  14. wandb/apis/public/jobs.py +17 -4
  15. wandb/apis/public/projects.py +1 -0
  16. wandb/apis/public/reports.py +1 -0
  17. wandb/apis/public/runs.py +15 -17
  18. wandb/apis/public/sweeps.py +1 -0
  19. wandb/apis/public/teams.py +1 -0
  20. wandb/apis/public/users.py +1 -0
  21. wandb/apis/reports/v1/_blocks.py +3 -7
  22. wandb/apis/reports/v2/gql.py +1 -0
  23. wandb/apis/reports/v2/interface.py +3 -4
  24. wandb/apis/reports/v2/internal.py +5 -8
  25. wandb/cli/cli.py +92 -22
  26. wandb/data_types.py +9 -6
  27. wandb/docker/__init__.py +1 -1
  28. wandb/env.py +38 -8
  29. wandb/errors/__init__.py +5 -0
  30. wandb/errors/term.py +10 -2
  31. wandb/filesync/step_checksum.py +1 -4
  32. wandb/filesync/step_prepare.py +4 -24
  33. wandb/filesync/step_upload.py +4 -106
  34. wandb/filesync/upload_job.py +0 -76
  35. wandb/integration/catboost/catboost.py +1 -1
  36. wandb/integration/fastai/__init__.py +1 -0
  37. wandb/integration/huggingface/resolver.py +2 -2
  38. wandb/integration/keras/__init__.py +1 -0
  39. wandb/integration/keras/callbacks/metrics_logger.py +1 -1
  40. wandb/integration/keras/keras.py +7 -7
  41. wandb/integration/langchain/wandb_tracer.py +1 -0
  42. wandb/integration/lightning/fabric/logger.py +1 -3
  43. wandb/integration/metaflow/metaflow.py +41 -6
  44. wandb/integration/openai/fine_tuning.py +3 -3
  45. wandb/integration/prodigy/prodigy.py +1 -1
  46. wandb/old/summary.py +1 -1
  47. wandb/plot/confusion_matrix.py +1 -1
  48. wandb/plot/pr_curve.py +2 -1
  49. wandb/plot/roc_curve.py +2 -1
  50. wandb/{plots → plot}/utils.py +13 -25
  51. wandb/proto/v3/wandb_internal_pb2.py +364 -332
  52. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  53. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  54. wandb/proto/v4/wandb_internal_pb2.py +322 -316
  55. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  56. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  57. wandb/proto/wandb_deprecated.py +7 -1
  58. wandb/proto/wandb_internal_codegen.py +3 -29
  59. wandb/sdk/artifacts/artifact.py +26 -11
  60. wandb/sdk/artifacts/artifact_download_logger.py +1 -0
  61. wandb/sdk/artifacts/artifact_file_cache.py +18 -4
  62. wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
  63. wandb/sdk/artifacts/artifact_manifest.py +1 -0
  64. wandb/sdk/artifacts/artifact_manifest_entry.py +7 -3
  65. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  66. wandb/sdk/artifacts/artifact_saver.py +2 -8
  67. wandb/sdk/artifacts/artifact_state.py +1 -0
  68. wandb/sdk/artifacts/artifact_ttl.py +1 -0
  69. wandb/sdk/artifacts/exceptions.py +1 -0
  70. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  71. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
  72. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
  73. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
  74. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
  75. wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
  76. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
  77. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
  78. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
  79. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -42
  80. wandb/sdk/artifacts/storage_policy.py +2 -12
  81. wandb/sdk/data_types/_dtypes.py +8 -8
  82. wandb/sdk/data_types/base_types/media.py +3 -6
  83. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
  84. wandb/sdk/data_types/image.py +1 -1
  85. wandb/sdk/data_types/video.py +1 -1
  86. wandb/sdk/integration_utils/auto_logging.py +5 -6
  87. wandb/sdk/integration_utils/data_logging.py +10 -6
  88. wandb/sdk/interface/interface.py +68 -32
  89. wandb/sdk/interface/interface_shared.py +7 -13
  90. wandb/sdk/internal/datastore.py +1 -1
  91. wandb/sdk/internal/file_pusher.py +2 -5
  92. wandb/sdk/internal/file_stream.py +5 -18
  93. wandb/sdk/internal/handler.py +18 -2
  94. wandb/sdk/internal/internal.py +0 -1
  95. wandb/sdk/internal/internal_api.py +1 -129
  96. wandb/sdk/internal/internal_util.py +0 -1
  97. wandb/sdk/internal/job_builder.py +159 -45
  98. wandb/sdk/internal/profiler.py +1 -0
  99. wandb/sdk/internal/progress.py +0 -28
  100. wandb/sdk/internal/run.py +1 -0
  101. wandb/sdk/internal/sender.py +1 -2
  102. wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
  103. wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
  104. wandb/sdk/internal/system/assets/interfaces.py +6 -8
  105. wandb/sdk/internal/system/assets/open_metrics.py +2 -2
  106. wandb/sdk/internal/system/assets/trainium.py +1 -3
  107. wandb/sdk/launch/__init__.py +9 -1
  108. wandb/sdk/launch/_launch.py +4 -24
  109. wandb/sdk/launch/_launch_add.py +1 -3
  110. wandb/sdk/launch/_project_spec.py +186 -224
  111. wandb/sdk/launch/agent/agent.py +37 -13
  112. wandb/sdk/launch/agent/config.py +72 -14
  113. wandb/sdk/launch/builder/abstract.py +69 -1
  114. wandb/sdk/launch/builder/build.py +156 -555
  115. wandb/sdk/launch/builder/context_manager.py +235 -0
  116. wandb/sdk/launch/builder/docker_builder.py +8 -23
  117. wandb/sdk/launch/builder/kaniko_builder.py +12 -25
  118. wandb/sdk/launch/builder/noop.py +1 -0
  119. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  120. wandb/sdk/launch/create_job.py +47 -37
  121. wandb/sdk/launch/environment/abstract.py +1 -0
  122. wandb/sdk/launch/environment/gcp_environment.py +1 -0
  123. wandb/sdk/launch/environment/local_environment.py +1 -0
  124. wandb/sdk/launch/inputs/files.py +148 -0
  125. wandb/sdk/launch/inputs/internal.py +217 -0
  126. wandb/sdk/launch/inputs/manage.py +95 -0
  127. wandb/sdk/launch/loader.py +1 -0
  128. wandb/sdk/launch/registry/abstract.py +1 -0
  129. wandb/sdk/launch/registry/azure_container_registry.py +1 -0
  130. wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
  131. wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
  132. wandb/sdk/launch/registry/local_registry.py +1 -0
  133. wandb/sdk/launch/runner/abstract.py +1 -0
  134. wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
  135. wandb/sdk/launch/runner/kubernetes_runner.py +9 -10
  136. wandb/sdk/launch/runner/local_container.py +2 -3
  137. wandb/sdk/launch/runner/local_process.py +8 -29
  138. wandb/sdk/launch/runner/sagemaker_runner.py +21 -20
  139. wandb/sdk/launch/runner/vertex_runner.py +8 -7
  140. wandb/sdk/launch/sweeps/scheduler.py +4 -3
  141. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  142. wandb/sdk/launch/sweeps/utils.py +3 -3
  143. wandb/sdk/launch/utils.py +15 -140
  144. wandb/sdk/lib/_settings_toposort_generated.py +0 -5
  145. wandb/sdk/lib/fsm.py +8 -12
  146. wandb/sdk/lib/gitlib.py +4 -4
  147. wandb/sdk/lib/import_hooks.py +1 -1
  148. wandb/sdk/lib/lazyloader.py +0 -1
  149. wandb/sdk/lib/proto_util.py +23 -2
  150. wandb/sdk/lib/redirect.py +19 -14
  151. wandb/sdk/lib/retry.py +3 -2
  152. wandb/sdk/lib/tracelog.py +1 -1
  153. wandb/sdk/service/service.py +19 -16
  154. wandb/sdk/verify/verify.py +2 -1
  155. wandb/sdk/wandb_init.py +14 -55
  156. wandb/sdk/wandb_manager.py +2 -2
  157. wandb/sdk/wandb_require.py +5 -0
  158. wandb/sdk/wandb_run.py +114 -56
  159. wandb/sdk/wandb_settings.py +0 -48
  160. wandb/sdk/wandb_setup.py +1 -1
  161. wandb/sklearn/__init__.py +1 -0
  162. wandb/sklearn/plot/__init__.py +1 -0
  163. wandb/sklearn/plot/classifier.py +11 -12
  164. wandb/sklearn/plot/clusterer.py +2 -1
  165. wandb/sklearn/plot/regressor.py +1 -0
  166. wandb/sklearn/plot/shared.py +1 -0
  167. wandb/sklearn/utils.py +1 -0
  168. wandb/testing/relay.py +4 -4
  169. wandb/trigger.py +1 -0
  170. wandb/util.py +67 -54
  171. wandb/wandb_controller.py +2 -3
  172. wandb/wandb_torch.py +1 -2
  173. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/METADATA +67 -70
  174. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/RECORD +177 -187
  175. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/WHEEL +1 -2
  176. wandb/bin/apple_gpu_stats +0 -0
  177. wandb/catboost/__init__.py +0 -9
  178. wandb/fastai/__init__.py +0 -9
  179. wandb/keras/__init__.py +0 -18
  180. wandb/lightgbm/__init__.py +0 -9
  181. wandb/plots/__init__.py +0 -6
  182. wandb/plots/explain_text.py +0 -36
  183. wandb/plots/heatmap.py +0 -81
  184. wandb/plots/named_entity.py +0 -43
  185. wandb/plots/part_of_speech.py +0 -50
  186. wandb/plots/plot_definitions.py +0 -768
  187. wandb/plots/precision_recall.py +0 -121
  188. wandb/plots/roc.py +0 -103
  189. wandb/sacred/__init__.py +0 -3
  190. wandb/xgboost/__init__.py +0 -9
  191. wandb-0.16.6.dist-info/top_level.txt +0 -1
  192. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/entry_points.txt +0 -0
  193. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info/licenses}/LICENSE +0 -0
@@ -1,4 +1,5 @@
1
1
  """Implementation of the SageMakerRunner class."""
2
+
2
3
  import asyncio
3
4
  import logging
4
5
  from typing import Any, Dict, List, Optional, cast
@@ -11,8 +12,7 @@ from wandb.apis.internal import Api
11
12
  from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
12
13
  from wandb.sdk.launch.errors import LaunchError
13
14
 
14
- from .._project_spec import EntryPoint, LaunchProject, get_entry_point_command
15
- from ..builder.build import get_env_vars_dict
15
+ from .._project_spec import EntryPoint, LaunchProject
16
16
  from ..registry.abstract import AbstractRegistry
17
17
  from ..utils import (
18
18
  LOG_PREFIX,
@@ -67,6 +67,7 @@ class SagemakerSubmittedRun(AbstractRun):
67
67
  logGroupName="/aws/sagemaker/TrainingJobs",
68
68
  logStreamName=log_name,
69
69
  )
70
+ assert "events" in res
70
71
  return "\n".join(
71
72
  [f'{event["timestamp"]}:{event["message"]}' for event in res["events"]]
72
73
  )
@@ -220,12 +221,12 @@ class SageMakerRunner(AbstractRunner):
220
221
  launch_project.fill_macros(image_uri)
221
222
  _logger.info("Connecting to sagemaker client")
222
223
  entry_point = (
223
- launch_project.override_entrypoint
224
- or launch_project.get_single_entry_point()
225
- )
226
- command_args = get_entry_point_command(
227
- entry_point, launch_project.override_args
224
+ launch_project.override_entrypoint or launch_project.get_job_entry_point()
228
225
  )
226
+ command_args = []
227
+ if entry_point is not None:
228
+ command_args += entry_point.command
229
+ command_args += launch_project.override_args
229
230
  if command_args:
230
231
  command_str = " ".join(command_args)
231
232
  wandb.termlog(
@@ -324,16 +325,16 @@ def build_sagemaker_args(
324
325
  sagemaker_args["TrainingJobName"] = training_job_name
325
326
  entry_cmd = entry_point.command if entry_point else []
326
327
 
327
- sagemaker_args[
328
- "AlgorithmSpecification"
329
- ] = merge_image_uri_with_algorithm_specification(
330
- given_sagemaker_args.get(
331
- "AlgorithmSpecification",
332
- given_sagemaker_args.get("algorithm_specification"),
333
- ),
334
- image_uri,
335
- entry_cmd,
336
- args,
328
+ sagemaker_args["AlgorithmSpecification"] = (
329
+ merge_image_uri_with_algorithm_specification(
330
+ given_sagemaker_args.get(
331
+ "AlgorithmSpecification",
332
+ given_sagemaker_args.get("algorithm_specification"),
333
+ ),
334
+ image_uri,
335
+ entry_cmd,
336
+ args,
337
+ )
337
338
  )
338
339
 
339
340
  sagemaker_args["RoleArn"] = role_arn
@@ -348,18 +349,18 @@ def build_sagemaker_args(
348
349
 
349
350
  if sagemaker_args.get("ResourceConfig") is None:
350
351
  raise LaunchError(
351
- "Sagemaker launcher requires a ResourceConfig Sagemaker resource argument"
352
+ "Sagemaker launcher requires a ResourceConfig resource argument"
352
353
  )
353
354
 
354
355
  if sagemaker_args.get("StoppingCondition") is None:
355
356
  raise LaunchError(
356
- "Sagemaker launcher requires a StoppingCondition Sagemaker resource argument"
357
+ "Sagemaker launcher requires a StoppingCondition resource argument"
357
358
  )
358
359
 
359
360
  given_env = given_sagemaker_args.get(
360
361
  "Environment", sagemaker_args.get("environment", {})
361
362
  )
362
- calced_env = get_env_vars_dict(launch_project, api, max_env_length)
363
+ calced_env = launch_project.get_env_vars_dict(api, max_env_length)
363
364
  total_env = {**calced_env, **given_env}
364
365
  sagemaker_args["Environment"] = total_env
365
366
 
@@ -8,8 +8,7 @@ if False:
8
8
  from wandb.apis.internal import Api
9
9
  from wandb.util import get_module
10
10
 
11
- from .._project_spec import LaunchProject, get_entry_point_command
12
- from ..builder.build import get_env_vars_dict
11
+ from .._project_spec import LaunchProject
13
12
  from ..environment.gcp_environment import GcpEnvironment
14
13
  from ..errors import LaunchError
15
14
  from ..registry.abstract import AbstractRegistry
@@ -113,14 +112,16 @@ class VertexRunner(AbstractRunner):
113
112
  synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
114
113
 
115
114
  entry_point = (
116
- launch_project.override_entrypoint
117
- or launch_project.get_single_entry_point()
115
+ launch_project.override_entrypoint or launch_project.get_job_entry_point()
118
116
  )
119
117
 
120
118
  # TODO: Set entrypoint in each container
121
- entry_cmd = get_entry_point_command(entry_point, launch_project.override_args)
122
- env_vars = get_env_vars_dict(
123
- launch_project=launch_project,
119
+ entry_cmd = []
120
+ if entry_point is not None:
121
+ entry_cmd += entry_point.command
122
+ entry_cmd += launch_project.override_args
123
+
124
+ env_vars = launch_project.get_env_vars_dict(
124
125
  api=self._api,
125
126
  max_env_length=MAX_ENV_LENGTHS[self.__class__.__name__],
126
127
  )
@@ -1,4 +1,5 @@
1
1
  """Abstract Scheduler class."""
2
+
2
3
  import asyncio
3
4
  import base64
4
5
  import copy
@@ -407,7 +408,7 @@ class Scheduler(ABC):
407
408
  return count
408
409
 
409
410
  def _try_load_executable(self) -> bool:
410
- """Check existance of valid executable for a run.
411
+ """Check existence of valid executable for a run.
411
412
 
412
413
  logs and returns False when job is unreachable
413
414
  """
@@ -422,7 +423,7 @@ class Scheduler(ABC):
422
423
  return False
423
424
  return True
424
425
  elif self._kwargs.get("image_uri"):
425
- # TODO(gst): check docker existance? Use registry in launch config?
426
+ # TODO(gst): check docker existence? Use registry in launch config?
426
427
  return True
427
428
  else:
428
429
  return False
@@ -610,7 +611,7 @@ class Scheduler(ABC):
610
611
  f"Failed to get runstate for run ({run_id}). Error: {traceback.format_exc()}"
611
612
  )
612
613
  run_state = RunState.FAILED
613
- else: # first time we get unknwon state
614
+ else: # first time we get unknown state
614
615
  run_state = RunState.UNKNOWN
615
616
  except (AttributeError, ValueError):
616
617
  wandb.termwarn(
@@ -1,4 +1,5 @@
1
1
  """Scheduler for classic wandb Sweeps."""
2
+
2
3
  import logging
3
4
  from pprint import pformat as pf
4
5
  from typing import Any, Dict, List, Optional
@@ -58,7 +59,7 @@ class SweepScheduler(Scheduler):
58
59
  return None
59
60
 
60
61
  def _get_sweep_commands(self, worker_id: int) -> List[Dict[str, Any]]:
61
- """Helper to recieve sweep command from backend."""
62
+ """Helper to receive sweep command from backend."""
62
63
  # AgentHeartbeat wants a Dict of runs which are running or queued
63
64
  _run_states: Dict[str, bool] = {}
64
65
  for run_id, run in self._yield_runs():
@@ -217,7 +217,7 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
217
217
  flags: List[str] = []
218
218
  # (2) flags without hyphens (e.g. foo=bar)
219
219
  flags_no_hyphens: List[str] = []
220
- # (3) flags with false booleans ommited (e.g. --foo)
220
+ # (3) flags with false booleans omitted (e.g. --foo)
221
221
  flags_no_booleans: List[str] = []
222
222
  # (4) flags as a dictionary (used for constructing a json)
223
223
  flags_dict: Dict[str, Any] = {}
@@ -257,7 +257,7 @@ def make_launch_sweep_entrypoint(
257
257
  """Use args dict from create_sweep_command_args to construct entrypoint.
258
258
 
259
259
  If replace is True, remove macros from entrypoint, fill them in with args
260
- and then return the args in seperate return value.
260
+ and then return the args in separate return value.
261
261
  """
262
262
  if not command:
263
263
  return None, None
@@ -296,7 +296,7 @@ def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
296
296
 
297
297
 
298
298
  def get_previous_args(
299
- run_spec: Dict[str, Any]
299
+ run_spec: Dict[str, Any],
300
300
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
301
301
  """Parse through previous scheduler run_spec.
302
302
 
wandb/sdk/launch/utils.py CHANGED
@@ -1,4 +1,3 @@
1
- # heavily inspired by https://github.com/mlflow/mlflow/blob/master/mlflow/projects/utils.py
2
1
  import asyncio
3
2
  import json
4
3
  import logging
@@ -16,7 +15,6 @@ import wandb
16
15
  import wandb.docker as docker
17
16
  from wandb import util
18
17
  from wandb.apis.internal import Api
19
- from wandb.errors import CommError
20
18
  from wandb.sdk.launch.errors import LaunchError
21
19
  from wandb.sdk.launch.git_reference import GitReference
22
20
  from wandb.sdk.launch.wandb_reference import WandbReference
@@ -32,7 +30,6 @@ FAILED_PACKAGES_REGEX = re.compile(
32
30
  )
33
31
 
34
32
  if TYPE_CHECKING: # pragma: no cover
35
- from wandb.sdk.artifacts.artifact import Artifact
36
33
  from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
37
34
 
38
35
 
@@ -57,15 +54,15 @@ API_KEY_REGEX = r"WANDB_API_KEY=\w+(-\w+)?"
57
54
  MACRO_REGEX = re.compile(r"\$\{(\w+)\}")
58
55
 
59
56
  AZURE_CONTAINER_REGISTRY_URI_REGEX = re.compile(
60
- r"(?:https://)?([\w]+)\.azurecr\.io/([\w\-]+):?(.*)"
57
+ r"^(?:https://)?([\w]+)\.azurecr\.io/(?P<repository>[\w\-]+):?(?P<tag>.*)"
61
58
  )
62
59
 
63
60
  ELASTIC_CONTAINER_REGISTRY_URI_REGEX = re.compile(
64
- r"^(?P<account>.*)\.dkr\.ecr\.(?P<region>.*)\.amazonaws\.com/(?P<repository>.*)/?$"
61
+ r"^(?:https://)?(?P<account>[\w-]+)\.dkr\.ecr\.(?P<region>[\w-]+)\.amazonaws\.com/(?P<repository>[\w-]+):?(?P<tag>.*)$"
65
62
  )
66
63
 
67
64
  GCP_ARTIFACT_REGISTRY_URI_REGEX = re.compile(
68
- r"^(?P<region>[\w-]+)-docker\.pkg\.dev/(?P<project>[\w-]+)/(?P<repository>[\w-]+)/(?P<image_name>[\w-]+)$",
65
+ r"^(?:https://)?(?P<region>[\w-]+)-docker\.pkg\.dev/(?P<project>[\w-]+)/(?P<repository>[\w-]+)/?(?P<image_name>[\w-]+)?(?P<tag>:.*)?$",
69
66
  re.IGNORECASE,
70
67
  )
71
68
 
@@ -316,16 +313,13 @@ def construct_launch_spec(
316
313
 
317
314
 
318
315
  def validate_launch_spec_source(launch_spec: Dict[str, Any]) -> None:
319
- uri = launch_spec.get("uri")
320
316
  job = launch_spec.get("job")
321
317
  docker_image = launch_spec.get("docker", {}).get("docker_image")
322
-
323
- if not bool(uri) and not bool(job) and not bool(docker_image):
324
- raise LaunchError("Must specify a uri, job or docker image")
325
- elif bool(uri) and bool(docker_image):
326
- raise LaunchError("Found both uri and docker-image, only one can be set")
327
- elif sum(map(bool, [uri, job, docker_image])) > 1:
328
- raise LaunchError("Must specify exactly one of uri, job or image")
318
+ if bool(job) == bool(docker_image):
319
+ raise LaunchError(
320
+ "Exactly one of job or docker_image must be specified in the launch "
321
+ "spec."
322
+ )
329
323
 
330
324
 
331
325
  def parse_wandb_uri(uri: str) -> Tuple[str, str, str]:
@@ -336,77 +330,6 @@ def parse_wandb_uri(uri: str) -> Tuple[str, str, str]:
336
330
  return (ref.entity, ref.project, ref.run_id)
337
331
 
338
332
 
339
- def is_bare_wandb_uri(uri: str) -> bool:
340
- """Check that a wandb uri is valid.
341
-
342
- URI must be in the format
343
- `/<entity>/<project>/runs/<run_name>[other stuff]`
344
- or
345
- `/<entity>/<project>/artifacts/job/<job_name>[other stuff]`.
346
- """
347
- _logger.info(f"Checking if uri {uri} is bare...")
348
- return uri.startswith("/") and WandbReference.is_uri_job_or_run(uri)
349
-
350
-
351
- def fetch_wandb_project_run_info(
352
- entity: str, project: str, run_name: str, api: Api
353
- ) -> Any:
354
- _logger.info("Fetching run info...")
355
- try:
356
- result = api.get_run_info(entity, project, run_name)
357
- except CommError:
358
- result = None
359
- if result is None:
360
- raise LaunchError(
361
- f"Run info is invalid or doesn't exist for {api.settings('base_url')}/{entity}/{project}/runs/{run_name}"
362
- )
363
- if result.get("codePath") is None:
364
- # TODO: we don't currently expose codePath in the runInfo endpoint, this downloads
365
- # it from wandb-metadata.json if we can.
366
- metadata = api.download_url(
367
- project, "wandb-metadata.json", run=run_name, entity=entity
368
- )
369
- if metadata is not None:
370
- _, response = api.download_file(metadata["url"])
371
- data = response.json()
372
- result["codePath"] = data.get("codePath")
373
- result["cudaVersion"] = data.get("cuda", None)
374
-
375
- return result
376
-
377
-
378
- def download_entry_point(
379
- entity: str, project: str, run_name: str, api: Api, entry_point: str, dir: str
380
- ) -> bool:
381
- metadata = api.download_url(
382
- project, f"code/{entry_point}", run=run_name, entity=entity
383
- )
384
- if metadata is not None:
385
- _, response = api.download_file(metadata["url"])
386
- with util.fsync_open(os.path.join(dir, entry_point), "wb") as file:
387
- for data in response.iter_content(chunk_size=1024):
388
- file.write(data)
389
- return True
390
- return False
391
-
392
-
393
- def download_wandb_python_deps(
394
- entity: str, project: str, run_name: str, api: Api, dir: str
395
- ) -> Optional[str]:
396
- reqs = api.download_url(project, "requirements.txt", run=run_name, entity=entity)
397
- if reqs is not None:
398
- _logger.info("Downloading python dependencies")
399
- _, response = api.download_file(reqs["url"])
400
-
401
- with util.fsync_open(
402
- os.path.join(dir, "requirements.frozen.txt"), "wb"
403
- ) as file:
404
- for data in response.iter_content(chunk_size=1024):
405
- file.write(data)
406
- return "requirements.frozen.txt"
407
- return None
408
-
409
-
410
333
  def get_local_python_deps(
411
334
  dir: str, filename: str = "requirements.local.txt"
412
335
  ) -> Optional[str]:
@@ -498,19 +421,6 @@ def validate_wandb_python_deps(
498
421
  _logger.warning("Unable to validate local python dependencies")
499
422
 
500
423
 
501
- def fetch_project_diff(
502
- entity: str, project: str, run_name: str, api: Api
503
- ) -> Optional[str]:
504
- """Fetches project diff from wandb servers."""
505
- _logger.info("Searching for diff.patch")
506
- patch = None
507
- try:
508
- (_, _, patch, _) = api.run_config(project, run_name, entity)
509
- except CommError:
510
- pass
511
- return patch
512
-
513
-
514
424
  def apply_patch(patch_string: str, dst_dir: str) -> None:
515
425
  """Applies a patch file to a directory."""
516
426
  _logger.info("Applying diff.patch")
@@ -531,17 +441,6 @@ def apply_patch(patch_string: str, dst_dir: str) -> None:
531
441
  raise wandb.Error("Failed to apply diff.patch associated with run.")
532
442
 
533
443
 
534
- def _make_refspec_from_version(version: Optional[str]) -> List[str]:
535
- """Create a refspec that checks for the existence of origin/main and the version."""
536
- if version:
537
- return [f"+{version}"]
538
-
539
- return [
540
- "+refs/heads/main*:refs/remotes/origin/main*",
541
- "+refs/heads/master*:refs/remotes/origin/master*",
542
- ]
543
-
544
-
545
444
  def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> Optional[str]:
546
445
  """Clones the git repo at ``uri`` into ``dst_dir``.
547
446
 
@@ -561,13 +460,6 @@ def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> Optional[
561
460
  return version
562
461
 
563
462
 
564
- def merge_parameters(
565
- higher_priority_params: Dict[str, Any], lower_priority_params: Dict[str, Any]
566
- ) -> Dict[str, Any]:
567
- """Merge the contents of two dicts, keeping values from higher_priority_params if there are conflicts."""
568
- return {**lower_priority_params, **higher_priority_params}
569
-
570
-
571
463
  def convert_jupyter_notebook_to_script(fname: str, project_dir: str) -> str:
572
464
  nbconvert = wandb.util.get_module(
573
465
  "nbconvert", "nbformat and nbconvert are required to use launch with notebooks"
@@ -597,25 +489,6 @@ def convert_jupyter_notebook_to_script(fname: str, project_dir: str) -> str:
597
489
  return new_name
598
490
 
599
491
 
600
- def check_and_download_code_artifacts(
601
- entity: str, project: str, run_name: str, internal_api: Api, project_dir: str
602
- ) -> Optional["Artifact"]:
603
- _logger.info("Checking for code artifacts")
604
- public_api = wandb.PublicApi(
605
- overrides={"base_url": internal_api.settings("base_url")}
606
- )
607
-
608
- run = public_api.run(f"{entity}/{project}/{run_name}")
609
- run_artifacts = run.logged_artifacts()
610
-
611
- for artifact in run_artifacts:
612
- if hasattr(artifact, "type") and artifact.type == "code":
613
- artifact.download(project_dir)
614
- return artifact # type: ignore
615
-
616
- return None
617
-
618
-
619
492
  def to_camel_case(maybe_snake_str: str) -> str:
620
493
  if "_" not in maybe_snake_str:
621
494
  return maybe_snake_str
@@ -623,11 +496,6 @@ def to_camel_case(maybe_snake_str: str) -> str:
623
496
  return "".join(x.title() if x else "_" for x in components)
624
497
 
625
498
 
626
- def run_shell(args: List[str]) -> Tuple[str, str]:
627
- out = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
628
- return out.stdout.decode("utf-8").strip(), out.stderr.decode("utf-8").strip()
629
-
630
-
631
499
  def validate_build_and_registry_configs(
632
500
  build_config: Dict[str, Any], registry_config: Dict[str, Any]
633
501
  ) -> None:
@@ -864,3 +732,10 @@ def get_entrypoint_file(entrypoint: List[str]) -> Optional[str]:
864
732
  if len(entrypoint) < 2:
865
733
  return None
866
734
  return entrypoint[1]
735
+
736
+
737
+ def get_current_python_version() -> Tuple[str, str]:
738
+ full_version = sys.version.split()[0].split(".")
739
+ major = full_version[0]
740
+ version = ".".join(full_version[:2]) if len(full_version) >= 2 else major + ".0"
741
+ return version, major
@@ -13,7 +13,6 @@ else:
13
13
  _Setting = Literal[
14
14
  "_args",
15
15
  "_aws_lambda",
16
- "_async_upload_concurrency_limit",
17
16
  "_cli_only_mode",
18
17
  "_code_path_local",
19
18
  "_colab",
@@ -25,7 +24,6 @@ _Setting = Literal[
25
24
  "_disable_update_check",
26
25
  "_disable_viewer",
27
26
  "_disable_machine_info",
28
- "_except_exit",
29
27
  "_executable",
30
28
  "_extra_http_headers",
31
29
  "_file_stream_retry_max",
@@ -126,7 +124,6 @@ _Setting = Literal[
126
124
  "login_timeout",
127
125
  "mode",
128
126
  "notebook_name",
129
- "problem",
130
127
  "program",
131
128
  "program_abspath",
132
129
  "program_relpath",
@@ -179,7 +176,6 @@ _Setting = Literal[
179
176
  ]
180
177
 
181
178
  SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
182
- "_async_upload_concurrency_limit",
183
179
  "_service_wait",
184
180
  "_stats_sample_rate_seconds",
185
181
  "_stats_samples_to_average",
@@ -189,7 +185,6 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
189
185
  "console",
190
186
  "job_source",
191
187
  "mode",
192
- "problem",
193
188
  "project",
194
189
  "run_id",
195
190
  "start_method",
wandb/sdk/lib/fsm.py CHANGED
@@ -52,43 +52,39 @@ T_FsmContext_contra = TypeVar("T_FsmContext_contra", contravariant=True)
52
52
  @runtime_checkable
53
53
  class FsmStateCheck(Protocol[T_FsmInputs]):
54
54
  @abstractmethod
55
- def on_check(self, inputs: T_FsmInputs) -> None:
56
- ... # pragma: no cover
55
+ def on_check(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
57
56
 
58
57
 
59
58
  @runtime_checkable
60
59
  class FsmStateOutput(Protocol[T_FsmInputs]):
61
60
  @abstractmethod
62
- def on_state(self, inputs: T_FsmInputs) -> None:
63
- ... # pragma: no cover
61
+ def on_state(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
64
62
 
65
63
 
66
64
  @runtime_checkable
67
65
  class FsmStateEnter(Protocol[T_FsmInputs]):
68
66
  @abstractmethod
69
- def on_enter(self, inputs: T_FsmInputs) -> None:
70
- ... # pragma: no cover
67
+ def on_enter(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
71
68
 
72
69
 
73
70
  @runtime_checkable
74
71
  class FsmStateEnterWithContext(Protocol[T_FsmInputs, T_FsmContext_contra]):
75
72
  @abstractmethod
76
- def on_enter(self, inputs: T_FsmInputs, context: T_FsmContext_contra) -> None:
77
- ... # pragma: no cover
73
+ def on_enter(
74
+ self, inputs: T_FsmInputs, context: T_FsmContext_contra
75
+ ) -> None: ... # pragma: no cover
78
76
 
79
77
 
80
78
  @runtime_checkable
81
79
  class FsmStateStay(Protocol[T_FsmInputs]):
82
80
  @abstractmethod
83
- def on_stay(self, inputs: T_FsmInputs) -> None:
84
- ... # pragma: no cover
81
+ def on_stay(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
85
82
 
86
83
 
87
84
  @runtime_checkable
88
85
  class FsmStateExit(Protocol[T_FsmInputs, T_FsmContext_cov]):
89
86
  @abstractmethod
90
- def on_exit(self, inputs: T_FsmInputs) -> T_FsmContext_cov:
91
- ... # pragma: no cover
87
+ def on_exit(self, inputs: T_FsmInputs) -> T_FsmContext_cov: ... # pragma: no cover
92
88
 
93
89
 
94
90
  # It would be nice if python provided optional protocol members, but it doesnt as described here:
wandb/sdk/lib/gitlib.py CHANGED
@@ -14,7 +14,7 @@ try:
14
14
  Repo,
15
15
  )
16
16
  except ImportError:
17
- Repo = None
17
+ Repo = None # type: ignore
18
18
 
19
19
  if TYPE_CHECKING:
20
20
  from git import Repo
@@ -121,7 +121,7 @@ class GitRepo:
121
121
  # TODO: Saw a user getting a Unicode decode error when parsing refs,
122
122
  # more details on implementing a real fix in [WB-4064]
123
123
  try:
124
- if len(self.repo.refs) > 0:
124
+ if len(self.repo.refs) > 0: # type: ignore[arg-type]
125
125
  return self.repo.head.commit.hexsha
126
126
  else:
127
127
  return self.repo.git.show_ref("--head").split(" ")[0]
@@ -140,7 +140,7 @@ class GitRepo:
140
140
  if not self.repo:
141
141
  return None
142
142
  try:
143
- return self.repo.remotes[self.remote_name]
143
+ return self.repo.remotes[self.remote_name] # type: ignore[index]
144
144
  except IndexError:
145
145
  return None
146
146
 
@@ -200,7 +200,7 @@ class GitRepo:
200
200
  possible_relatives.append(tracking_branch.commit)
201
201
 
202
202
  if not possible_relatives:
203
- for branch in self.repo.branches:
203
+ for branch in self.repo.branches: # type: ignore[attr-defined]
204
204
  tracking_branch = branch.tracking_branch()
205
205
  if tracking_branch is not None:
206
206
  possible_relatives.append(tracking_branch.commit)
@@ -143,7 +143,7 @@ class _ImportHookChainedLoader:
143
143
  # None, so handle None as well. The module may not support attribute
144
144
  # assignment, in which case we simply skip it. Note that we also deal
145
145
  # with __loader__ not existing at all. This is to future proof things
146
- # due to proposal to remove the attribue as described in the GitHub
146
+ # due to proposal to remove the attribute as described in the GitHub
147
147
  # issue at https://github.com/python/cpython/issues/77458. Also prior
148
148
  # to Python 3.3, the __loader__ attribute was only set if a custom
149
149
  # module loader was used. It isn't clear whether the attribute still
@@ -1,6 +1,5 @@
1
1
  """module lazyloader."""
2
2
 
3
-
4
3
  import importlib
5
4
  import sys
6
5
  import types
@@ -12,7 +12,28 @@ if TYPE_CHECKING: # pragma: no cover
12
12
 
13
13
 
14
14
  def dict_from_proto_list(obj_list: "RepeatedCompositeFieldContainer") -> Dict[str, Any]:
15
- return {item.key: json.loads(item.value_json) for item in obj_list}
15
+ result: Dict[str, Any] = {}
16
+
17
+ for item in obj_list:
18
+ # Start from the root of the result dict
19
+ current_level = result
20
+
21
+ if len(item.nested_key) > 0:
22
+ keys = list(item.nested_key)
23
+ else:
24
+ keys = [item.key]
25
+
26
+ for key in keys[:-1]:
27
+ if key not in current_level:
28
+ current_level[key] = {}
29
+ # Move the reference deeper into the nested dictionary
30
+ current_level = current_level[key]
31
+
32
+ # Set the value at the final key location, parsing JSON from the value_json field
33
+ final_key = keys[-1]
34
+ current_level[final_key] = json.loads(item.value_json)
35
+
36
+ return result
16
37
 
17
38
 
18
39
  def _result_from_record(record: "pb.Record") -> "pb.Result":
@@ -29,7 +50,7 @@ def _assign_end_offset(record: "pb.Record", end_offset: int) -> None:
29
50
 
30
51
 
31
52
  def proto_encode_to_dict(
32
- pb_obj: Union["tpb.TelemetryRecord", "pb.MetricRecord"]
53
+ pb_obj: Union["tpb.TelemetryRecord", "pb.MetricRecord"],
33
54
  ) -> Dict[int, Any]:
34
55
  data: Dict[int, Any] = dict()
35
56
  fields = pb_obj.ListFields()