wandb 0.17.0rc1__py3-none-any.whl → 0.17.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (173) hide show
  1. wandb/__init__.py +1 -2
  2. wandb/apis/importers/internals/internal.py +0 -1
  3. wandb/apis/importers/wandb.py +12 -7
  4. wandb/apis/internal.py +0 -3
  5. wandb/apis/public/api.py +213 -79
  6. wandb/apis/public/artifacts.py +335 -100
  7. wandb/apis/public/files.py +9 -9
  8. wandb/apis/public/jobs.py +16 -4
  9. wandb/apis/public/projects.py +26 -28
  10. wandb/apis/public/query_generator.py +1 -1
  11. wandb/apis/public/runs.py +163 -65
  12. wandb/apis/public/sweeps.py +2 -2
  13. wandb/apis/reports/__init__.py +1 -7
  14. wandb/apis/reports/v1/__init__.py +5 -27
  15. wandb/apis/reports/v2/__init__.py +7 -19
  16. wandb/apis/workspaces/__init__.py +8 -0
  17. wandb/beta/workflows.py +8 -3
  18. wandb/cli/cli.py +131 -59
  19. wandb/data_types.py +6 -3
  20. wandb/docker/__init__.py +2 -2
  21. wandb/env.py +3 -3
  22. wandb/errors/term.py +10 -2
  23. wandb/filesync/step_checksum.py +1 -4
  24. wandb/filesync/step_prepare.py +4 -24
  25. wandb/filesync/step_upload.py +5 -107
  26. wandb/filesync/upload_job.py +0 -76
  27. wandb/integration/gym/__init__.py +35 -15
  28. wandb/integration/huggingface/resolver.py +2 -2
  29. wandb/integration/keras/callbacks/metrics_logger.py +1 -1
  30. wandb/integration/keras/keras.py +1 -1
  31. wandb/integration/openai/fine_tuning.py +21 -3
  32. wandb/integration/prodigy/prodigy.py +1 -1
  33. wandb/jupyter.py +16 -17
  34. wandb/old/summary.py +1 -1
  35. wandb/plot/confusion_matrix.py +1 -1
  36. wandb/plot/pr_curve.py +2 -1
  37. wandb/plot/roc_curve.py +2 -1
  38. wandb/{plots → plot}/utils.py +13 -25
  39. wandb/proto/v3/wandb_internal_pb2.py +54 -54
  40. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  41. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  42. wandb/proto/v4/wandb_internal_pb2.py +54 -54
  43. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  44. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  45. wandb/proto/v5/wandb_base_pb2.py +30 -0
  46. wandb/proto/v5/wandb_internal_pb2.py +355 -0
  47. wandb/proto/v5/wandb_server_pb2.py +63 -0
  48. wandb/proto/v5/wandb_settings_pb2.py +45 -0
  49. wandb/proto/v5/wandb_telemetry_pb2.py +41 -0
  50. wandb/proto/wandb_base_pb2.py +2 -0
  51. wandb/proto/wandb_deprecated.py +9 -1
  52. wandb/proto/wandb_generate_deprecated.py +34 -0
  53. wandb/proto/{wandb_internal_codegen.py → wandb_generate_proto.py} +1 -35
  54. wandb/proto/wandb_internal_pb2.py +2 -0
  55. wandb/proto/wandb_server_pb2.py +2 -0
  56. wandb/proto/wandb_settings_pb2.py +2 -0
  57. wandb/proto/wandb_telemetry_pb2.py +2 -0
  58. wandb/sdk/artifacts/artifact.py +68 -22
  59. wandb/sdk/artifacts/artifact_manifest.py +1 -1
  60. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -3
  61. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -1
  62. wandb/sdk/artifacts/artifact_saver.py +1 -10
  63. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +6 -2
  64. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
  65. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
  66. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +2 -42
  67. wandb/sdk/artifacts/storage_policy.py +1 -12
  68. wandb/sdk/data_types/_dtypes.py +8 -8
  69. wandb/sdk/data_types/image.py +2 -2
  70. wandb/sdk/data_types/video.py +5 -3
  71. wandb/sdk/integration_utils/data_logging.py +5 -5
  72. wandb/sdk/interface/interface.py +14 -1
  73. wandb/sdk/interface/interface_shared.py +1 -1
  74. wandb/sdk/internal/file_pusher.py +2 -5
  75. wandb/sdk/internal/file_stream.py +6 -19
  76. wandb/sdk/internal/internal_api.py +148 -136
  77. wandb/sdk/internal/job_builder.py +208 -136
  78. wandb/sdk/internal/progress.py +0 -28
  79. wandb/sdk/internal/sender.py +102 -39
  80. wandb/sdk/internal/settings_static.py +8 -1
  81. wandb/sdk/internal/system/assets/trainium.py +3 -3
  82. wandb/sdk/internal/system/system_info.py +4 -2
  83. wandb/sdk/internal/update.py +1 -1
  84. wandb/sdk/launch/__init__.py +9 -1
  85. wandb/sdk/launch/_launch.py +4 -24
  86. wandb/sdk/launch/_launch_add.py +1 -3
  87. wandb/sdk/launch/_project_spec.py +187 -225
  88. wandb/sdk/launch/agent/agent.py +59 -19
  89. wandb/sdk/launch/agent/config.py +0 -3
  90. wandb/sdk/launch/builder/abstract.py +68 -1
  91. wandb/sdk/launch/builder/build.py +165 -576
  92. wandb/sdk/launch/builder/context_manager.py +235 -0
  93. wandb/sdk/launch/builder/docker_builder.py +7 -23
  94. wandb/sdk/launch/builder/kaniko_builder.py +12 -25
  95. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  96. wandb/sdk/launch/create_job.py +51 -45
  97. wandb/sdk/launch/environment/aws_environment.py +26 -1
  98. wandb/sdk/launch/inputs/files.py +148 -0
  99. wandb/sdk/launch/inputs/internal.py +224 -0
  100. wandb/sdk/launch/inputs/manage.py +95 -0
  101. wandb/sdk/launch/registry/google_artifact_registry.py +1 -1
  102. wandb/sdk/launch/runner/abstract.py +2 -2
  103. wandb/sdk/launch/runner/kubernetes_monitor.py +45 -12
  104. wandb/sdk/launch/runner/kubernetes_runner.py +6 -8
  105. wandb/sdk/launch/runner/local_container.py +2 -3
  106. wandb/sdk/launch/runner/local_process.py +8 -29
  107. wandb/sdk/launch/runner/sagemaker_runner.py +20 -14
  108. wandb/sdk/launch/runner/vertex_runner.py +8 -7
  109. wandb/sdk/launch/sweeps/scheduler.py +5 -3
  110. wandb/sdk/launch/sweeps/scheduler_sweep.py +1 -1
  111. wandb/sdk/launch/sweeps/utils.py +4 -4
  112. wandb/sdk/launch/utils.py +16 -138
  113. wandb/sdk/lib/_settings_toposort_generated.py +2 -5
  114. wandb/sdk/lib/apikey.py +4 -2
  115. wandb/sdk/lib/config_util.py +3 -3
  116. wandb/sdk/lib/import_hooks.py +1 -1
  117. wandb/sdk/lib/proto_util.py +22 -1
  118. wandb/sdk/lib/redirect.py +20 -15
  119. wandb/sdk/lib/tracelog.py +1 -1
  120. wandb/sdk/service/service.py +2 -1
  121. wandb/sdk/service/streams.py +5 -5
  122. wandb/sdk/wandb_init.py +25 -59
  123. wandb/sdk/wandb_login.py +28 -25
  124. wandb/sdk/wandb_run.py +123 -53
  125. wandb/sdk/wandb_settings.py +33 -64
  126. wandb/sdk/wandb_setup.py +1 -1
  127. wandb/sdk/wandb_watch.py +1 -1
  128. wandb/sklearn/plot/classifier.py +10 -12
  129. wandb/sklearn/plot/clusterer.py +1 -1
  130. wandb/sync/sync.py +2 -2
  131. wandb/testing/relay.py +32 -17
  132. wandb/util.py +36 -37
  133. wandb/wandb_agent.py +3 -3
  134. wandb/wandb_controller.py +5 -4
  135. {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/METADATA +8 -10
  136. {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/RECORD +139 -161
  137. {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/WHEEL +1 -1
  138. wandb/apis/reports/v1/_blocks.py +0 -1406
  139. wandb/apis/reports/v1/_helpers.py +0 -70
  140. wandb/apis/reports/v1/_panels.py +0 -1282
  141. wandb/apis/reports/v1/_templates.py +0 -478
  142. wandb/apis/reports/v1/blocks.py +0 -27
  143. wandb/apis/reports/v1/helpers.py +0 -2
  144. wandb/apis/reports/v1/mutations.py +0 -66
  145. wandb/apis/reports/v1/panels.py +0 -17
  146. wandb/apis/reports/v1/report.py +0 -268
  147. wandb/apis/reports/v1/runset.py +0 -144
  148. wandb/apis/reports/v1/templates.py +0 -7
  149. wandb/apis/reports/v1/util.py +0 -406
  150. wandb/apis/reports/v1/validators.py +0 -131
  151. wandb/apis/reports/v2/blocks.py +0 -25
  152. wandb/apis/reports/v2/expr_parsing.py +0 -257
  153. wandb/apis/reports/v2/gql.py +0 -68
  154. wandb/apis/reports/v2/interface.py +0 -1911
  155. wandb/apis/reports/v2/internal.py +0 -867
  156. wandb/apis/reports/v2/metrics.py +0 -6
  157. wandb/apis/reports/v2/panels.py +0 -15
  158. wandb/catboost/__init__.py +0 -9
  159. wandb/fastai/__init__.py +0 -9
  160. wandb/keras/__init__.py +0 -19
  161. wandb/lightgbm/__init__.py +0 -9
  162. wandb/plots/__init__.py +0 -6
  163. wandb/plots/explain_text.py +0 -36
  164. wandb/plots/heatmap.py +0 -81
  165. wandb/plots/named_entity.py +0 -43
  166. wandb/plots/part_of_speech.py +0 -50
  167. wandb/plots/plot_definitions.py +0 -768
  168. wandb/plots/precision_recall.py +0 -121
  169. wandb/plots/roc.py +0 -103
  170. wandb/sacred/__init__.py +0 -3
  171. wandb/xgboost/__init__.py +0 -9
  172. {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/entry_points.txt +0 -0
  173. {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,95 @@
1
+ """Functions for declaring overridable configuration for launch jobs."""
2
+
3
+ from typing import List, Optional
4
+
5
+
6
+ def manage_config_file(
7
+ path: str,
8
+ include: Optional[List[str]] = None,
9
+ exclude: Optional[List[str]] = None,
10
+ ):
11
+ r"""Declare an overridable configuration file for a launch job.
12
+
13
+ If a new job version is created from the active run, the configuration file
14
+ will be added to the job's inputs. If the job is launched and overrides
15
+ have been provided for the configuration file, this function will detect
16
+ the overrides from the environment and update the configuration file on disk.
17
+ Note that these overrides will only be applied in ephemeral containers.
18
+ `include` and `exclude` are lists of dot-separated paths within the config.
19
+ The paths are used to filter subtrees of the configuration file out of the
20
+ job's inputs.
21
+
22
+ For example, given the following configuration file:
23
+ ```yaml
24
+ model:
25
+ name: resnet
26
+ layers: 18
27
+ training:
28
+ epochs: 10
29
+ batch_size: 32
30
+ ```
31
+
32
+ Passing `include=['model']` will only include the `model` subtree in the
33
+ job's inputs. Passing `exclude=['model.layers']` will exclude the `layers`
34
+ key from the `model` subtree. Note that `exclude` takes precedence over
35
+ `include`.
36
+
37
+ `.` is used as a separator for nested keys. If a key contains a `.`, it
38
+ should be escaped with a backslash, e.g. `include=[r'model\.layers']`. Note
39
+ the use of `r` to denote a raw string when using escape chars.
40
+
41
+ Args:
42
+ path (str): The path to the configuration file. This path must be
43
+ relative and must not contain backwards traversal, i.e. `..`.
44
+ include (List[str]): A list of keys to include in the configuration file.
45
+ exclude (List[str]): A list of keys to exclude from the configuration file.
46
+
47
+ Raises:
48
+ LaunchError: If the path is not valid, or if there is no active run.
49
+ """
50
+ from .internal import handle_config_file_input
51
+
52
+ return handle_config_file_input(path, include, exclude)
53
+
54
+
55
+ def manage_wandb_config(
56
+ include: Optional[List[str]] = None,
57
+ exclude: Optional[List[str]] = None,
58
+ ):
59
+ r"""Declare wandb.config as an overridable configuration for a launch job.
60
+
61
+ If a new job version is created from the active run, the run config
62
+ (wandb.config) will become an overridable input of the job. If the job is
63
+ launched and overrides have been provided for the run config, the overrides
64
+ will be applied to the run config when `wandb.init` is called.
65
+ `include` and `exclude` are lists of dot-separated paths within the config.
66
+ The paths are used to filter subtrees of the run config out of the
67
+ job's inputs.
68
+
69
+ For example, given the following run config contents:
70
+ ```yaml
71
+ model:
72
+ name: resnet
73
+ layers: 18
74
+ training:
75
+ epochs: 10
76
+ batch_size: 32
77
+ ```
78
+ Passing `include=['model']` will only include the `model` subtree in the
79
+ job's inputs. Passing `exclude=['model.layers']` will exclude the `layers`
80
+ key from the `model` subtree. Note that `exclude` takes precedence over
81
+ `include`.
82
+ `.` is used as a separator for nested keys. If a key contains a `.`, it
83
+ should be escaped with a backslash, e.g. `include=[r'model\.layers']`. Note
84
+ the use of `r` to denote a raw string when using escape chars.
85
+
86
+ Args:
87
+ include (List[str]): A list of subtrees to include in the configuration.
88
+ exclude (List[str]): A list of subtrees to exclude from the configuration.
89
+
90
+ Raises:
91
+ LaunchError: If there is no active run.
92
+ """
93
+ from .internal import handle_run_config_input
94
+
95
+ handle_run_config_input(include, exclude)
@@ -211,7 +211,7 @@ class GoogleArtifactRegistry(AbstractRegistry):
211
211
  for image in await list_images(request={"parent": parent}):
212
212
  if tag in image.tags:
213
213
  return True
214
- except google.api_core.exceptions.NotFound as e:
214
+ except google.api_core.exceptions.NotFound as e: # type: ignore[attr-defined]
215
215
  raise LaunchError(
216
216
  f"The Google Artifact Registry repository {self.repository} "
217
217
  f"does not exist. Please create it or modify your registry configuration."
@@ -40,9 +40,9 @@ State = Literal[
40
40
 
41
41
 
42
42
  class Status:
43
- def __init__(self, state: "State" = "unknown", data=None): # type: ignore
43
+ def __init__(self, state: "State" = "unknown", messages: List[str] = None): # type: ignore
44
44
  self.state = state
45
- self.data = data or {}
45
+ self.messages = messages or []
46
46
 
47
47
  def __repr__(self) -> "State":
48
48
  return self.state
@@ -14,6 +14,7 @@ from kubernetes_asyncio.client import ( # type: ignore # noqa: F401
14
14
  BatchV1Api,
15
15
  CoreV1Api,
16
16
  CustomObjectsApi,
17
+ V1Pod,
17
18
  V1PodStatus,
18
19
  )
19
20
 
@@ -118,6 +119,27 @@ def _is_container_creating(status: "V1PodStatus") -> bool:
118
119
  return False
119
120
 
120
121
 
122
+ def _is_pod_unschedulable(status: "V1PodStatus") -> Tuple[bool, str]:
123
+ """Return whether the pod is unschedulable along with the reason message."""
124
+ if not status.conditions:
125
+ return False, ""
126
+ for condition in status.conditions:
127
+ if (
128
+ condition.type == "PodScheduled"
129
+ and condition.status == "False"
130
+ and condition.reason == "Unschedulable"
131
+ ):
132
+ return True, condition.message
133
+ return False, ""
134
+
135
+
136
+ def _get_crd_job_name(object: "V1Pod") -> Optional[str]:
137
+ refs = object.metadata.owner_references
138
+ if refs:
139
+ return refs[0].name
140
+ return None
141
+
142
+
121
143
  def _state_from_conditions(conditions: List[Dict[str, Any]]) -> Optional[State]:
122
144
  """Get the status from the pod conditions."""
123
145
  true_conditions = [
@@ -298,10 +320,18 @@ class LaunchKubernetesMonitor:
298
320
  counts[state] += 1
299
321
  return counts
300
322
 
301
- def _set_status(self, job_name: str, status: Status) -> None:
323
+ def _set_status_state(self, job_name: str, state: State) -> None:
302
324
  """Set the status of the run."""
303
- if self._job_states.get(job_name) != status:
304
- self._job_states[job_name] = status
325
+ if job_name not in self._job_states:
326
+ self._job_states[job_name] = Status(state)
327
+ elif self._job_states[job_name].state != state:
328
+ self._job_states[job_name].state = state
329
+
330
+ def _add_status_message(self, job_name: str, message: str) -> None:
331
+ if job_name not in self._job_states:
332
+ self._job_states[job_name] = Status("unknown")
333
+ wandb.termwarn(f"Warning from Kubernetes for job {job_name}: {message}")
334
+ self._job_states[job_name].messages.append(message)
305
335
 
306
336
  async def _monitor_pods(self, namespace: str) -> None:
307
337
  """Monitor a namespace for changes."""
@@ -312,15 +342,19 @@ class LaunchKubernetesMonitor:
312
342
  label_selector=self._label_selector,
313
343
  ):
314
344
  obj = event.get("object")
315
- job_name = obj.metadata.labels.get("job-name")
345
+ job_name = obj.metadata.labels.get("job-name") or _get_crd_job_name(obj)
316
346
  if job_name is None or not hasattr(obj, "status"):
317
347
  continue
318
348
  if self.__get_status(job_name) in ["finished", "failed"]:
319
349
  continue
350
+
351
+ is_unschedulable, reason = _is_pod_unschedulable(obj.status)
352
+ if is_unschedulable:
353
+ self._add_status_message(job_name, reason)
320
354
  if obj.status.phase == "Running" or _is_container_creating(obj.status):
321
- self._set_status(job_name, Status("running"))
355
+ self._set_status_state(job_name, "running")
322
356
  elif _is_preempted(obj.status):
323
- self._set_status(job_name, Status("preempted"))
357
+ self._set_status_state(job_name, "preempted")
324
358
 
325
359
  async def _monitor_jobs(self, namespace: str) -> None:
326
360
  """Monitor a namespace for changes."""
@@ -334,15 +368,15 @@ class LaunchKubernetesMonitor:
334
368
  job_name = obj.metadata.name
335
369
 
336
370
  if obj.status.succeeded == 1:
337
- self._set_status(job_name, Status("finished"))
371
+ self._set_status_state(job_name, "finished")
338
372
  elif obj.status.failed is not None and obj.status.failed >= 1:
339
- self._set_status(job_name, Status("failed"))
373
+ self._set_status_state(job_name, "failed")
340
374
 
341
375
  # If the job is deleted and we haven't seen a terminal state
342
376
  # then we will consider the job failed.
343
377
  if event.get("type") == "DELETED":
344
378
  if self._job_states.get(job_name) != Status("finished"):
345
- self._set_status(job_name, Status("failed"))
379
+ self._set_status_state(job_name, "failed")
346
380
 
347
381
  async def _monitor_crd(
348
382
  self, namespace: str, custom_resource: CustomResource
@@ -355,7 +389,7 @@ class LaunchKubernetesMonitor:
355
389
  plural=custom_resource.plural,
356
390
  group=custom_resource.group,
357
391
  version=custom_resource.version,
358
- label_selector=self._label_selector, # TODO: Label selector doesn't work for CRDs.
392
+ label_selector=self._label_selector,
359
393
  ):
360
394
  object = event.get("object")
361
395
  name = object.get("metadata", dict()).get("name")
@@ -383,8 +417,7 @@ class LaunchKubernetesMonitor:
383
417
  )
384
418
  if state is None:
385
419
  continue
386
- status = Status(state)
387
- self._set_status(name, status)
420
+ self._set_status_state(name, state)
388
421
 
389
422
 
390
423
  class SafeWatch:
@@ -29,7 +29,6 @@ from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
29
29
  from wandb.util import get_module
30
30
 
31
31
  from .._project_spec import EntryPoint, LaunchProject
32
- from ..builder.build import get_env_vars_dict
33
32
  from ..errors import LaunchError
34
33
  from ..utils import (
35
34
  LOG_PREFIX,
@@ -374,8 +373,7 @@ class KubernetesRunner(AbstractRunner):
374
373
  }
375
374
 
376
375
  entry_point = (
377
- launch_project.override_entrypoint
378
- or launch_project.get_single_entry_point()
376
+ launch_project.override_entrypoint or launch_project.get_job_entry_point()
379
377
  )
380
378
  if launch_project.docker_image:
381
379
  # don't specify run id if the user provided an image; there could be multiple runs
@@ -401,8 +399,8 @@ class KubernetesRunner(AbstractRunner):
401
399
  launch_project.override_entrypoint is not None,
402
400
  )
403
401
 
404
- env_vars = get_env_vars_dict(
405
- launch_project, self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
402
+ env_vars = launch_project.get_env_vars_dict(
403
+ self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
406
404
  )
407
405
  api_key_secret = None
408
406
  for cont in containers:
@@ -511,8 +509,8 @@ class KubernetesRunner(AbstractRunner):
511
509
  api_version = resource_args.get("apiVersion", "batch/v1")
512
510
 
513
511
  if api_version not in ["batch/v1", "batch/v1beta1"]:
514
- env_vars = get_env_vars_dict(
515
- launch_project, self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
512
+ env_vars = launch_project.get_env_vars_dict(
513
+ self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
516
514
  )
517
515
  # Crawl the resource args and add our env vars to the containers.
518
516
  add_wandb_env(resource_args, env_vars)
@@ -537,7 +535,7 @@ class KubernetesRunner(AbstractRunner):
537
535
  if LaunchAgent.initialized():
538
536
  add_label_to_pods(
539
537
  resource_args,
540
- WANDB_K8S_LABEL_MONITOR,
538
+ WANDB_K8S_LABEL_AGENT,
541
539
  LaunchAgent.name(),
542
540
  )
543
541
  resource_args["metadata"]["labels"][WANDB_K8S_LABEL_AGENT] = (
@@ -12,7 +12,6 @@ from wandb.sdk.launch.environment.abstract import AbstractEnvironment
12
12
  from wandb.sdk.launch.registry.abstract import AbstractRegistry
13
13
 
14
14
  from .._project_spec import LaunchProject
15
- from ..builder.build import get_env_vars_dict
16
15
  from ..errors import LaunchError
17
16
  from ..utils import (
18
17
  LOG_PREFIX,
@@ -133,8 +132,8 @@ class LocalContainerRunner(AbstractRunner):
133
132
  docker_args = self._populate_docker_args(launch_project, image_uri)
134
133
  synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
135
134
 
136
- env_vars = get_env_vars_dict(
137
- launch_project, self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
135
+ env_vars = launch_project.get_env_vars_dict(
136
+ self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
138
137
  )
139
138
 
140
139
  # When running against local port, need to swap to local docker host
@@ -4,16 +4,12 @@ from typing import Any, List, Optional
4
4
 
5
5
  import wandb
6
6
 
7
- from .._project_spec import LaunchProject, get_entry_point_command
8
- from ..builder.build import get_env_vars_dict
7
+ from .._project_spec import LaunchProject
9
8
  from ..errors import LaunchError
10
9
  from ..utils import (
11
10
  LOG_PREFIX,
12
11
  MAX_ENV_LENGTHS,
13
12
  PROJECT_SYNCHRONOUS,
14
- _is_wandb_uri,
15
- download_wandb_python_deps,
16
- parse_wandb_uri,
17
13
  sanitize_wandb_api_key,
18
14
  validate_wandb_python_deps,
19
15
  )
@@ -47,8 +43,7 @@ class LocalProcessRunner(AbstractRunner):
47
43
 
48
44
  synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
49
45
  entry_point = (
50
- launch_project.override_entrypoint
51
- or launch_project.get_single_entry_point()
46
+ launch_project.override_entrypoint or launch_project.get_job_entry_point()
52
47
  )
53
48
 
54
49
  cmd: List[Any] = []
@@ -56,23 +51,7 @@ class LocalProcessRunner(AbstractRunner):
56
51
  if launch_project.project_dir is None:
57
52
  raise LaunchError("Launch LocalProcessRunner received empty project dir")
58
53
 
59
- # Check to make sure local python dependencies match run's requirement.txt
60
- if launch_project.uri and _is_wandb_uri(launch_project.uri):
61
- source_entity, source_project, run_name = parse_wandb_uri(
62
- launch_project.uri
63
- )
64
- run_requirements_file = download_wandb_python_deps(
65
- source_entity,
66
- source_project,
67
- run_name,
68
- self._api,
69
- launch_project.project_dir,
70
- )
71
- validate_wandb_python_deps(
72
- run_requirements_file,
73
- launch_project.project_dir,
74
- )
75
- elif launch_project.job:
54
+ if launch_project.job:
76
55
  assert launch_project._job_artifact is not None
77
56
  try:
78
57
  validate_wandb_python_deps(
@@ -81,14 +60,14 @@ class LocalProcessRunner(AbstractRunner):
81
60
  )
82
61
  except Exception:
83
62
  wandb.termwarn("Unable to validate python dependencies")
84
- env_vars = get_env_vars_dict(
85
- launch_project, self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
63
+ env_vars = launch_project.get_env_vars_dict(
64
+ self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
86
65
  )
87
66
  for env_key, env_value in env_vars.items():
88
67
  cmd += [f"{shlex.quote(env_key)}={shlex.quote(env_value)}"]
89
-
90
- entry_cmd = get_entry_point_command(entry_point, launch_project.override_args)
91
- cmd += entry_cmd
68
+ if entry_point is not None:
69
+ cmd += entry_point.command
70
+ cmd += launch_project.override_args
92
71
 
93
72
  command_str = " ".join(cmd).strip()
94
73
  _msg = f"{LOG_PREFIX}Launching run as a local-process with command {sanitize_wandb_api_key(command_str)}"
@@ -12,8 +12,7 @@ from wandb.apis.internal import Api
12
12
  from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
13
13
  from wandb.sdk.launch.errors import LaunchError
14
14
 
15
- from .._project_spec import EntryPoint, LaunchProject, get_entry_point_command
16
- from ..builder.build import get_env_vars_dict
15
+ from .._project_spec import EntryPoint, LaunchProject
17
16
  from ..registry.abstract import AbstractRegistry
18
17
  from ..utils import (
19
18
  LOG_PREFIX,
@@ -68,6 +67,7 @@ class SagemakerSubmittedRun(AbstractRun):
68
67
  logGroupName="/aws/sagemaker/TrainingJobs",
69
68
  logStreamName=log_name,
70
69
  )
70
+ assert "events" in res
71
71
  return "\n".join(
72
72
  [f'{event["timestamp"]}:{event["message"]}' for event in res["events"]]
73
73
  )
@@ -179,7 +179,10 @@ class SageMakerRunner(AbstractRunner):
179
179
  caller_id = client.get_caller_identity()
180
180
  account_id = caller_id["Account"]
181
181
  _logger.info(f"Using account ID {account_id}")
182
- role_arn = get_role_arn(given_sagemaker_args, self.backend_config, account_id)
182
+ partition = await self.environment.get_partition()
183
+ role_arn = get_role_arn(
184
+ given_sagemaker_args, self.backend_config, account_id, partition
185
+ )
183
186
 
184
187
  # Create a sagemaker client to launch the job.
185
188
  sagemaker_client = session.client("sagemaker")
@@ -221,12 +224,12 @@ class SageMakerRunner(AbstractRunner):
221
224
  launch_project.fill_macros(image_uri)
222
225
  _logger.info("Connecting to sagemaker client")
223
226
  entry_point = (
224
- launch_project.override_entrypoint
225
- or launch_project.get_single_entry_point()
226
- )
227
- command_args = get_entry_point_command(
228
- entry_point, launch_project.override_args
227
+ launch_project.override_entrypoint or launch_project.get_job_entry_point()
229
228
  )
229
+ command_args = []
230
+ if entry_point is not None:
231
+ command_args += entry_point.command
232
+ command_args += launch_project.override_args
230
233
  if command_args:
231
234
  command_str = " ".join(command_args)
232
235
  wandb.termlog(
@@ -349,18 +352,18 @@ def build_sagemaker_args(
349
352
 
350
353
  if sagemaker_args.get("ResourceConfig") is None:
351
354
  raise LaunchError(
352
- "Sagemaker launcher requires a ResourceConfig Sagemaker resource argument"
355
+ "Sagemaker launcher requires a ResourceConfig resource argument"
353
356
  )
354
357
 
355
358
  if sagemaker_args.get("StoppingCondition") is None:
356
359
  raise LaunchError(
357
- "Sagemaker launcher requires a StoppingCondition Sagemaker resource argument"
360
+ "Sagemaker launcher requires a StoppingCondition resource argument"
358
361
  )
359
362
 
360
363
  given_env = given_sagemaker_args.get(
361
364
  "Environment", sagemaker_args.get("environment", {})
362
365
  )
363
- calced_env = get_env_vars_dict(launch_project, api, max_env_length)
366
+ calced_env = launch_project.get_env_vars_dict(api, max_env_length)
364
367
  total_env = {**calced_env, **given_env}
365
368
  sagemaker_args["Environment"] = total_env
366
369
 
@@ -405,7 +408,10 @@ async def launch_sagemaker_job(
405
408
 
406
409
 
407
410
  def get_role_arn(
408
- sagemaker_args: Dict[str, Any], backend_config: Dict[str, Any], account_id: str
411
+ sagemaker_args: Dict[str, Any],
412
+ backend_config: Dict[str, Any],
413
+ account_id: str,
414
+ partition: str,
409
415
  ) -> str:
410
416
  """Get the role arn from the sagemaker args or the backend config."""
411
417
  role_arn = sagemaker_args.get("RoleArn") or sagemaker_args.get("role_arn")
@@ -416,7 +422,7 @@ def get_role_arn(
416
422
  "AWS sagemaker require a string RoleArn set this by adding a `RoleArn` key to the sagemaker"
417
423
  "field of resource_args"
418
424
  )
419
- if role_arn.startswith("arn:aws:iam::"):
425
+ if role_arn.startswith(f"arn:{partition}:iam::"):
420
426
  return role_arn # type: ignore
421
427
 
422
- return f"arn:aws:iam::{account_id}:role/{role_arn}"
428
+ return f"arn:{partition}:iam::{account_id}:role/{role_arn}"
@@ -8,8 +8,7 @@ if False:
8
8
  from wandb.apis.internal import Api
9
9
  from wandb.util import get_module
10
10
 
11
- from .._project_spec import LaunchProject, get_entry_point_command
12
- from ..builder.build import get_env_vars_dict
11
+ from .._project_spec import LaunchProject
13
12
  from ..environment.gcp_environment import GcpEnvironment
14
13
  from ..errors import LaunchError
15
14
  from ..registry.abstract import AbstractRegistry
@@ -113,14 +112,16 @@ class VertexRunner(AbstractRunner):
113
112
  synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
114
113
 
115
114
  entry_point = (
116
- launch_project.override_entrypoint
117
- or launch_project.get_single_entry_point()
115
+ launch_project.override_entrypoint or launch_project.get_job_entry_point()
118
116
  )
119
117
 
120
118
  # TODO: Set entrypoint in each container
121
- entry_cmd = get_entry_point_command(entry_point, launch_project.override_args)
122
- env_vars = get_env_vars_dict(
123
- launch_project=launch_project,
119
+ entry_cmd = []
120
+ if entry_point is not None:
121
+ entry_cmd += entry_point.command
122
+ entry_cmd += launch_project.override_args
123
+
124
+ env_vars = launch_project.get_env_vars_dict(
124
125
  api=self._api,
125
126
  max_env_length=MAX_ENV_LENGTHS[self.__class__.__name__],
126
127
  )
@@ -408,7 +408,7 @@ class Scheduler(ABC):
408
408
  return count
409
409
 
410
410
  def _try_load_executable(self) -> bool:
411
- """Check existance of valid executable for a run.
411
+ """Check existence of valid executable for a run.
412
412
 
413
413
  logs and returns False when job is unreachable
414
414
  """
@@ -423,7 +423,7 @@ class Scheduler(ABC):
423
423
  return False
424
424
  return True
425
425
  elif self._kwargs.get("image_uri"):
426
- # TODO(gst): check docker existance? Use registry in launch config?
426
+ # TODO(gst): check docker existence? Use registry in launch config?
427
427
  return True
428
428
  else:
429
429
  return False
@@ -611,7 +611,7 @@ class Scheduler(ABC):
611
611
  f"Failed to get runstate for run ({run_id}). Error: {traceback.format_exc()}"
612
612
  )
613
613
  run_state = RunState.FAILED
614
- else: # first time we get unknwon state
614
+ else: # first time we get unknown state
615
615
  run_state = RunState.UNKNOWN
616
616
  except (AttributeError, ValueError):
617
617
  wandb.termwarn(
@@ -668,6 +668,8 @@ class Scheduler(ABC):
668
668
  launch_config = copy.deepcopy(self._wandb_run.config.get("launch", {}))
669
669
  if "overrides" not in launch_config:
670
670
  launch_config["overrides"] = {"run_config": {}}
671
+ if "run_config" not in launch_config["overrides"]:
672
+ launch_config["overrides"]["run_config"] = {}
671
673
  launch_config["overrides"]["run_config"].update(args["args_dict"])
672
674
 
673
675
  if macro_args: # pipe in hyperparam args as params to launch
@@ -59,7 +59,7 @@ class SweepScheduler(Scheduler):
59
59
  return None
60
60
 
61
61
  def _get_sweep_commands(self, worker_id: int) -> List[Dict[str, Any]]:
62
- """Helper to recieve sweep command from backend."""
62
+ """Helper to receive sweep command from backend."""
63
63
  # AgentHeartbeat wants a Dict of runs which are running or queued
64
64
  _run_states: Dict[str, bool] = {}
65
65
  for run_id, run in self._yield_runs():
@@ -211,13 +211,13 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
211
211
 
212
212
  """
213
213
  if "args" not in command:
214
- raise ValueError('No "args" found in command: %s' % command)
214
+ raise ValueError('No "args" found in command: {}'.format(command))
215
215
  # four different formats of command args
216
216
  # (1) standard command line flags (e.g. --foo=bar)
217
217
  flags: List[str] = []
218
218
  # (2) flags without hyphens (e.g. foo=bar)
219
219
  flags_no_hyphens: List[str] = []
220
- # (3) flags with false booleans ommited (e.g. --foo)
220
+ # (3) flags with false booleans omitted (e.g. --foo)
221
221
  flags_no_booleans: List[str] = []
222
222
  # (4) flags as a dictionary (used for constructing a json)
223
223
  flags_dict: Dict[str, Any] = {}
@@ -228,7 +228,7 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
228
228
  try:
229
229
  _value: Any = config["value"]
230
230
  except KeyError:
231
- raise ValueError('No "value" found for command["args"]["%s"]' % param)
231
+ raise ValueError('No "value" found for command["args"]["{}"]'.format(param))
232
232
 
233
233
  _flag: str = f"{param}={_value}"
234
234
  flags.append("--" + _flag)
@@ -257,7 +257,7 @@ def make_launch_sweep_entrypoint(
257
257
  """Use args dict from create_sweep_command_args to construct entrypoint.
258
258
 
259
259
  If replace is True, remove macros from entrypoint, fill them in with args
260
- and then return the args in seperate return value.
260
+ and then return the args in separate return value.
261
261
  """
262
262
  if not command:
263
263
  return None, None