wandb 0.17.0rc1__py3-none-win32.whl → 0.17.1__py3-none-win32.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -2
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/wandb.py +12 -7
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +213 -79
- wandb/apis/public/artifacts.py +335 -100
- wandb/apis/public/files.py +9 -9
- wandb/apis/public/jobs.py +16 -4
- wandb/apis/public/projects.py +26 -28
- wandb/apis/public/query_generator.py +1 -1
- wandb/apis/public/runs.py +163 -65
- wandb/apis/public/sweeps.py +2 -2
- wandb/apis/reports/__init__.py +1 -7
- wandb/apis/reports/v1/__init__.py +5 -27
- wandb/apis/reports/v2/__init__.py +7 -19
- wandb/apis/workspaces/__init__.py +8 -0
- wandb/beta/workflows.py +8 -3
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +131 -59
- wandb/data_types.py +6 -3
- wandb/docker/__init__.py +2 -2
- wandb/env.py +3 -3
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +5 -107
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/gym/__init__.py +35 -15
- wandb/integration/huggingface/resolver.py +2 -2
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/keras.py +1 -1
- wandb/integration/openai/fine_tuning.py +21 -3
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/jupyter.py +16 -17
- wandb/old/summary.py +1 -1
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +54 -54
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +54 -54
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_base_pb2.py +30 -0
- wandb/proto/v5/wandb_internal_pb2.py +355 -0
- wandb/proto/v5/wandb_server_pb2.py +63 -0
- wandb/proto/v5/wandb_settings_pb2.py +45 -0
- wandb/proto/v5/wandb_telemetry_pb2.py +41 -0
- wandb/proto/wandb_base_pb2.py +2 -0
- wandb/proto/wandb_deprecated.py +9 -1
- wandb/proto/wandb_generate_deprecated.py +34 -0
- wandb/proto/{wandb_internal_codegen.py → wandb_generate_proto.py} +1 -35
- wandb/proto/wandb_internal_pb2.py +2 -0
- wandb/proto/wandb_server_pb2.py +2 -0
- wandb/proto/wandb_settings_pb2.py +2 -0
- wandb/proto/wandb_telemetry_pb2.py +2 -0
- wandb/sdk/artifacts/artifact.py +68 -22
- wandb/sdk/artifacts/artifact_manifest.py +1 -1
- wandb/sdk/artifacts/artifact_manifest_entry.py +6 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -1
- wandb/sdk/artifacts/artifact_saver.py +1 -10
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +6 -2
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +2 -42
- wandb/sdk/artifacts/storage_policy.py +1 -12
- wandb/sdk/data_types/_dtypes.py +8 -8
- wandb/sdk/data_types/image.py +2 -2
- wandb/sdk/data_types/video.py +5 -3
- wandb/sdk/integration_utils/data_logging.py +5 -5
- wandb/sdk/interface/interface.py +14 -1
- wandb/sdk/interface/interface_shared.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +6 -19
- wandb/sdk/internal/internal_api.py +148 -136
- wandb/sdk/internal/job_builder.py +208 -136
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/sender.py +102 -39
- wandb/sdk/internal/settings_static.py +8 -1
- wandb/sdk/internal/system/assets/trainium.py +3 -3
- wandb/sdk/internal/system/system_info.py +4 -2
- wandb/sdk/internal/update.py +1 -1
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +4 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +187 -225
- wandb/sdk/launch/agent/agent.py +59 -19
- wandb/sdk/launch/agent/config.py +0 -3
- wandb/sdk/launch/builder/abstract.py +68 -1
- wandb/sdk/launch/builder/build.py +165 -576
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +7 -23
- wandb/sdk/launch/builder/kaniko_builder.py +12 -25
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +51 -45
- wandb/sdk/launch/environment/aws_environment.py +26 -1
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +224 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +1 -1
- wandb/sdk/launch/runner/abstract.py +2 -2
- wandb/sdk/launch/runner/kubernetes_monitor.py +45 -12
- wandb/sdk/launch/runner/kubernetes_runner.py +6 -8
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +20 -14
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +5 -3
- wandb/sdk/launch/sweeps/scheduler_sweep.py +1 -1
- wandb/sdk/launch/sweeps/utils.py +4 -4
- wandb/sdk/launch/utils.py +16 -138
- wandb/sdk/lib/_settings_toposort_generated.py +2 -5
- wandb/sdk/lib/apikey.py +4 -2
- wandb/sdk/lib/config_util.py +3 -3
- wandb/sdk/lib/import_hooks.py +1 -1
- wandb/sdk/lib/proto_util.py +22 -1
- wandb/sdk/lib/redirect.py +20 -15
- wandb/sdk/lib/tracelog.py +1 -1
- wandb/sdk/service/service.py +2 -1
- wandb/sdk/service/streams.py +5 -5
- wandb/sdk/wandb_init.py +25 -59
- wandb/sdk/wandb_login.py +28 -25
- wandb/sdk/wandb_run.py +123 -53
- wandb/sdk/wandb_settings.py +33 -64
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sdk/wandb_watch.py +1 -1
- wandb/sklearn/plot/classifier.py +10 -12
- wandb/sklearn/plot/clusterer.py +1 -1
- wandb/sync/sync.py +2 -2
- wandb/testing/relay.py +32 -17
- wandb/util.py +36 -37
- wandb/wandb_agent.py +3 -3
- wandb/wandb_controller.py +5 -4
- {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/METADATA +8 -10
- {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/RECORD +140 -162
- wandb/apis/reports/v1/_blocks.py +0 -1406
- wandb/apis/reports/v1/_helpers.py +0 -70
- wandb/apis/reports/v1/_panels.py +0 -1282
- wandb/apis/reports/v1/_templates.py +0 -478
- wandb/apis/reports/v1/blocks.py +0 -27
- wandb/apis/reports/v1/helpers.py +0 -2
- wandb/apis/reports/v1/mutations.py +0 -66
- wandb/apis/reports/v1/panels.py +0 -17
- wandb/apis/reports/v1/report.py +0 -268
- wandb/apis/reports/v1/runset.py +0 -144
- wandb/apis/reports/v1/templates.py +0 -7
- wandb/apis/reports/v1/util.py +0 -406
- wandb/apis/reports/v1/validators.py +0 -131
- wandb/apis/reports/v2/blocks.py +0 -25
- wandb/apis/reports/v2/expr_parsing.py +0 -257
- wandb/apis/reports/v2/gql.py +0 -68
- wandb/apis/reports/v2/interface.py +0 -1911
- wandb/apis/reports/v2/internal.py +0 -867
- wandb/apis/reports/v2/metrics.py +0 -6
- wandb/apis/reports/v2/panels.py +0 -15
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -19
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/WHEEL +0 -0
- {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,95 @@
|
|
1
|
+
"""Functions for declaring overridable configuration for launch jobs."""
|
2
|
+
|
3
|
+
from typing import List, Optional
|
4
|
+
|
5
|
+
|
6
|
+
def manage_config_file(
    path: str,
    include: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
):
    r"""Declare an overridable configuration file for a launch job.

    When a new job version is created from the active run, the configuration
    file is registered as one of the job's inputs. When the job is later
    launched with overrides for that file, this function picks the overrides
    up from the environment and rewrites the configuration file on disk.
    Overrides are only applied inside ephemeral containers.

    `include` and `exclude` are lists of dot-separated paths within the
    configuration. They select which subtrees of the file become part of the
    job's inputs.

    For example, given the following configuration file:

    ```yaml
    model:
      name: resnet
      layers: 18
    training:
      epochs: 10
      batch_size: 32
    ```

    Passing `include=['model']` will only include the `model` subtree in the
    job's inputs. Passing `exclude=['model.layers']` will exclude the `layers`
    key from the `model` subtree. `exclude` takes precedence over `include`.

    `.` separates nested keys. A key that itself contains a `.` must have it
    escaped with a backslash, e.g. `include=[r'model\.layers']`. Note the `r`
    prefix: it makes the literal a raw string so the backslash is kept.

    Args:
        path (str): The path to the configuration file. The path must be
            relative and must not contain backwards traversal, i.e. `..`.
        include (Optional[List[str]]): Keys of the configuration to include in
            the job's inputs.
        exclude (Optional[List[str]]): Keys of the configuration to exclude
            from the job's inputs.

    Raises:
        LaunchError: If the path is not valid, or if there is no active run.
    """
    # Deferred, call-local import as in the original (presumably to keep
    # module import light / avoid a cycle — not verifiable from here).
    from .internal import handle_config_file_input as _register_config_file

    return _register_config_file(path, include, exclude)
|
53
|
+
|
54
|
+
|
55
|
+
def manage_wandb_config(
    include: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
):
    r"""Declare wandb.config as an overridable configuration for a launch job.

    When a new job version is created from the active run, the run config
    (wandb.config) becomes an overridable input of the job. When the job is
    later launched with overrides for the run config, those overrides are
    applied to the run config when `wandb.init` is called.

    `include` and `exclude` are lists of dot-separated paths within the run
    config. They select which subtrees of the run config become part of the
    job's inputs.

    For example, given the following run config contents:

    ```yaml
    model:
      name: resnet
      layers: 18
    training:
      epochs: 10
      batch_size: 32
    ```

    Passing `include=['model']` will only include the `model` subtree in the
    job's inputs. Passing `exclude=['model.layers']` will exclude the `layers`
    key from the `model` subtree. `exclude` takes precedence over `include`.

    `.` separates nested keys. A key that itself contains a `.` must have it
    escaped with a backslash, e.g. `include=[r'model\.layers']`. Note the `r`
    prefix: it makes the literal a raw string so the backslash is kept.

    Args:
        include (Optional[List[str]]): Subtrees of the run config to include
            in the job's inputs.
        exclude (Optional[List[str]]): Subtrees of the run config to exclude
            from the job's inputs.

    Raises:
        LaunchError: If there is no active run.
    """
    # Deferred, call-local import as in the original.
    from .internal import handle_run_config_input as _register_run_config

    _register_run_config(include, exclude)
|
@@ -211,7 +211,7 @@ class GoogleArtifactRegistry(AbstractRegistry):
|
|
211
211
|
for image in await list_images(request={"parent": parent}):
|
212
212
|
if tag in image.tags:
|
213
213
|
return True
|
214
|
-
except google.api_core.exceptions.NotFound as e:
|
214
|
+
except google.api_core.exceptions.NotFound as e: # type: ignore[attr-defined]
|
215
215
|
raise LaunchError(
|
216
216
|
f"The Google Artifact Registry repository {self.repository} "
|
217
217
|
f"does not exist. Please create it or modify your registry configuration."
|
@@ -40,9 +40,9 @@ State = Literal[
|
|
40
40
|
|
41
41
|
|
42
42
|
class Status:
|
43
|
-
def __init__(self, state: "State" = "unknown",
|
43
|
+
def __init__(self, state: "State" = "unknown", messages: List[str] = None): # type: ignore
|
44
44
|
self.state = state
|
45
|
-
self.
|
45
|
+
self.messages = messages or []
|
46
46
|
|
47
47
|
def __repr__(self) -> "State":
|
48
48
|
return self.state
|
@@ -14,6 +14,7 @@ from kubernetes_asyncio.client import ( # type: ignore # noqa: F401
|
|
14
14
|
BatchV1Api,
|
15
15
|
CoreV1Api,
|
16
16
|
CustomObjectsApi,
|
17
|
+
V1Pod,
|
17
18
|
V1PodStatus,
|
18
19
|
)
|
19
20
|
|
@@ -118,6 +119,27 @@ def _is_container_creating(status: "V1PodStatus") -> bool:
|
|
118
119
|
return False
|
119
120
|
|
120
121
|
|
122
|
+
def _is_pod_unschedulable(status: "V1PodStatus") -> Tuple[bool, str]:
|
123
|
+
"""Return whether the pod is unschedulable along with the reason message."""
|
124
|
+
if not status.conditions:
|
125
|
+
return False, ""
|
126
|
+
for condition in status.conditions:
|
127
|
+
if (
|
128
|
+
condition.type == "PodScheduled"
|
129
|
+
and condition.status == "False"
|
130
|
+
and condition.reason == "Unschedulable"
|
131
|
+
):
|
132
|
+
return True, condition.message
|
133
|
+
return False, ""
|
134
|
+
|
135
|
+
|
136
|
+
def _get_crd_job_name(object: "V1Pod") -> Optional[str]:
|
137
|
+
refs = object.metadata.owner_references
|
138
|
+
if refs:
|
139
|
+
return refs[0].name
|
140
|
+
return None
|
141
|
+
|
142
|
+
|
121
143
|
def _state_from_conditions(conditions: List[Dict[str, Any]]) -> Optional[State]:
|
122
144
|
"""Get the status from the pod conditions."""
|
123
145
|
true_conditions = [
|
@@ -298,10 +320,18 @@ class LaunchKubernetesMonitor:
|
|
298
320
|
counts[state] += 1
|
299
321
|
return counts
|
300
322
|
|
301
|
-
def
|
323
|
+
def _set_status_state(self, job_name: str, state: State) -> None:
|
302
324
|
"""Set the status of the run."""
|
303
|
-
if self._job_states
|
304
|
-
self._job_states[job_name] =
|
325
|
+
if job_name not in self._job_states:
|
326
|
+
self._job_states[job_name] = Status(state)
|
327
|
+
elif self._job_states[job_name].state != state:
|
328
|
+
self._job_states[job_name].state = state
|
329
|
+
|
330
|
+
def _add_status_message(self, job_name: str, message: str) -> None:
    """Record a Kubernetes warning message for a job and surface it once.

    Creates an ``"unknown"`` Status entry for the job if none exists yet,
    then warns on the terminal and appends ``message`` to that status's
    ``messages`` list.

    The pod watch invokes this on every event for an unschedulable pod, so
    the same reason typically arrives many times in a row. A message equal to
    the most recently recorded one is therefore ignored, which avoids
    repeated terminal warnings and unbounded growth of ``messages``.

    Args:
        job_name: Name of the job the message belongs to.
        message: The warning text reported by Kubernetes.
    """
    if job_name not in self._job_states:
        self._job_states[job_name] = Status("unknown")
    messages = self._job_states[job_name].messages
    # Skip consecutive duplicates only, so a reason that changes and later
    # changes back is still reported.
    if messages and messages[-1] == message:
        return
    wandb.termwarn(f"Warning from Kubernetes for job {job_name}: {message}")
    messages.append(message)
|
305
335
|
|
306
336
|
async def _monitor_pods(self, namespace: str) -> None:
|
307
337
|
"""Monitor a namespace for changes."""
|
@@ -312,15 +342,19 @@ class LaunchKubernetesMonitor:
|
|
312
342
|
label_selector=self._label_selector,
|
313
343
|
):
|
314
344
|
obj = event.get("object")
|
315
|
-
job_name = obj.metadata.labels.get("job-name")
|
345
|
+
job_name = obj.metadata.labels.get("job-name") or _get_crd_job_name(obj)
|
316
346
|
if job_name is None or not hasattr(obj, "status"):
|
317
347
|
continue
|
318
348
|
if self.__get_status(job_name) in ["finished", "failed"]:
|
319
349
|
continue
|
350
|
+
|
351
|
+
is_unschedulable, reason = _is_pod_unschedulable(obj.status)
|
352
|
+
if is_unschedulable:
|
353
|
+
self._add_status_message(job_name, reason)
|
320
354
|
if obj.status.phase == "Running" or _is_container_creating(obj.status):
|
321
|
-
self.
|
355
|
+
self._set_status_state(job_name, "running")
|
322
356
|
elif _is_preempted(obj.status):
|
323
|
-
self.
|
357
|
+
self._set_status_state(job_name, "preempted")
|
324
358
|
|
325
359
|
async def _monitor_jobs(self, namespace: str) -> None:
|
326
360
|
"""Monitor a namespace for changes."""
|
@@ -334,15 +368,15 @@ class LaunchKubernetesMonitor:
|
|
334
368
|
job_name = obj.metadata.name
|
335
369
|
|
336
370
|
if obj.status.succeeded == 1:
|
337
|
-
self.
|
371
|
+
self._set_status_state(job_name, "finished")
|
338
372
|
elif obj.status.failed is not None and obj.status.failed >= 1:
|
339
|
-
self.
|
373
|
+
self._set_status_state(job_name, "failed")
|
340
374
|
|
341
375
|
# If the job is deleted and we haven't seen a terminal state
|
342
376
|
# then we will consider the job failed.
|
343
377
|
if event.get("type") == "DELETED":
|
344
378
|
if self._job_states.get(job_name) != Status("finished"):
|
345
|
-
self.
|
379
|
+
self._set_status_state(job_name, "failed")
|
346
380
|
|
347
381
|
async def _monitor_crd(
|
348
382
|
self, namespace: str, custom_resource: CustomResource
|
@@ -355,7 +389,7 @@ class LaunchKubernetesMonitor:
|
|
355
389
|
plural=custom_resource.plural,
|
356
390
|
group=custom_resource.group,
|
357
391
|
version=custom_resource.version,
|
358
|
-
label_selector=self._label_selector,
|
392
|
+
label_selector=self._label_selector,
|
359
393
|
):
|
360
394
|
object = event.get("object")
|
361
395
|
name = object.get("metadata", dict()).get("name")
|
@@ -383,8 +417,7 @@ class LaunchKubernetesMonitor:
|
|
383
417
|
)
|
384
418
|
if state is None:
|
385
419
|
continue
|
386
|
-
|
387
|
-
self._set_status(name, status)
|
420
|
+
self._set_status_state(name, state)
|
388
421
|
|
389
422
|
|
390
423
|
class SafeWatch:
|
@@ -29,7 +29,6 @@ from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
|
|
29
29
|
from wandb.util import get_module
|
30
30
|
|
31
31
|
from .._project_spec import EntryPoint, LaunchProject
|
32
|
-
from ..builder.build import get_env_vars_dict
|
33
32
|
from ..errors import LaunchError
|
34
33
|
from ..utils import (
|
35
34
|
LOG_PREFIX,
|
@@ -374,8 +373,7 @@ class KubernetesRunner(AbstractRunner):
|
|
374
373
|
}
|
375
374
|
|
376
375
|
entry_point = (
|
377
|
-
launch_project.override_entrypoint
|
378
|
-
or launch_project.get_single_entry_point()
|
376
|
+
launch_project.override_entrypoint or launch_project.get_job_entry_point()
|
379
377
|
)
|
380
378
|
if launch_project.docker_image:
|
381
379
|
# dont specify run id if user provided image, could have multiple runs
|
@@ -401,8 +399,8 @@ class KubernetesRunner(AbstractRunner):
|
|
401
399
|
launch_project.override_entrypoint is not None,
|
402
400
|
)
|
403
401
|
|
404
|
-
env_vars = get_env_vars_dict(
|
405
|
-
|
402
|
+
env_vars = launch_project.get_env_vars_dict(
|
403
|
+
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
406
404
|
)
|
407
405
|
api_key_secret = None
|
408
406
|
for cont in containers:
|
@@ -511,8 +509,8 @@ class KubernetesRunner(AbstractRunner):
|
|
511
509
|
api_version = resource_args.get("apiVersion", "batch/v1")
|
512
510
|
|
513
511
|
if api_version not in ["batch/v1", "batch/v1beta1"]:
|
514
|
-
env_vars = get_env_vars_dict(
|
515
|
-
|
512
|
+
env_vars = launch_project.get_env_vars_dict(
|
513
|
+
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
516
514
|
)
|
517
515
|
# Crawl the resource args and add our env vars to the containers.
|
518
516
|
add_wandb_env(resource_args, env_vars)
|
@@ -537,7 +535,7 @@ class KubernetesRunner(AbstractRunner):
|
|
537
535
|
if LaunchAgent.initialized():
|
538
536
|
add_label_to_pods(
|
539
537
|
resource_args,
|
540
|
-
|
538
|
+
WANDB_K8S_LABEL_AGENT,
|
541
539
|
LaunchAgent.name(),
|
542
540
|
)
|
543
541
|
resource_args["metadata"]["labels"][WANDB_K8S_LABEL_AGENT] = (
|
@@ -12,7 +12,6 @@ from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
|
12
12
|
from wandb.sdk.launch.registry.abstract import AbstractRegistry
|
13
13
|
|
14
14
|
from .._project_spec import LaunchProject
|
15
|
-
from ..builder.build import get_env_vars_dict
|
16
15
|
from ..errors import LaunchError
|
17
16
|
from ..utils import (
|
18
17
|
LOG_PREFIX,
|
@@ -133,8 +132,8 @@ class LocalContainerRunner(AbstractRunner):
|
|
133
132
|
docker_args = self._populate_docker_args(launch_project, image_uri)
|
134
133
|
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
135
134
|
|
136
|
-
env_vars = get_env_vars_dict(
|
137
|
-
|
135
|
+
env_vars = launch_project.get_env_vars_dict(
|
136
|
+
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
138
137
|
)
|
139
138
|
|
140
139
|
# When running against local port, need to swap to local docker host
|
@@ -4,16 +4,12 @@ from typing import Any, List, Optional
|
|
4
4
|
|
5
5
|
import wandb
|
6
6
|
|
7
|
-
from .._project_spec import LaunchProject
|
8
|
-
from ..builder.build import get_env_vars_dict
|
7
|
+
from .._project_spec import LaunchProject
|
9
8
|
from ..errors import LaunchError
|
10
9
|
from ..utils import (
|
11
10
|
LOG_PREFIX,
|
12
11
|
MAX_ENV_LENGTHS,
|
13
12
|
PROJECT_SYNCHRONOUS,
|
14
|
-
_is_wandb_uri,
|
15
|
-
download_wandb_python_deps,
|
16
|
-
parse_wandb_uri,
|
17
13
|
sanitize_wandb_api_key,
|
18
14
|
validate_wandb_python_deps,
|
19
15
|
)
|
@@ -47,8 +43,7 @@ class LocalProcessRunner(AbstractRunner):
|
|
47
43
|
|
48
44
|
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
49
45
|
entry_point = (
|
50
|
-
launch_project.override_entrypoint
|
51
|
-
or launch_project.get_single_entry_point()
|
46
|
+
launch_project.override_entrypoint or launch_project.get_job_entry_point()
|
52
47
|
)
|
53
48
|
|
54
49
|
cmd: List[Any] = []
|
@@ -56,23 +51,7 @@ class LocalProcessRunner(AbstractRunner):
|
|
56
51
|
if launch_project.project_dir is None:
|
57
52
|
raise LaunchError("Launch LocalProcessRunner received empty project dir")
|
58
53
|
|
59
|
-
|
60
|
-
if launch_project.uri and _is_wandb_uri(launch_project.uri):
|
61
|
-
source_entity, source_project, run_name = parse_wandb_uri(
|
62
|
-
launch_project.uri
|
63
|
-
)
|
64
|
-
run_requirements_file = download_wandb_python_deps(
|
65
|
-
source_entity,
|
66
|
-
source_project,
|
67
|
-
run_name,
|
68
|
-
self._api,
|
69
|
-
launch_project.project_dir,
|
70
|
-
)
|
71
|
-
validate_wandb_python_deps(
|
72
|
-
run_requirements_file,
|
73
|
-
launch_project.project_dir,
|
74
|
-
)
|
75
|
-
elif launch_project.job:
|
54
|
+
if launch_project.job:
|
76
55
|
assert launch_project._job_artifact is not None
|
77
56
|
try:
|
78
57
|
validate_wandb_python_deps(
|
@@ -81,14 +60,14 @@ class LocalProcessRunner(AbstractRunner):
|
|
81
60
|
)
|
82
61
|
except Exception:
|
83
62
|
wandb.termwarn("Unable to validate python dependencies")
|
84
|
-
env_vars = get_env_vars_dict(
|
85
|
-
|
63
|
+
env_vars = launch_project.get_env_vars_dict(
|
64
|
+
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
86
65
|
)
|
87
66
|
for env_key, env_value in env_vars.items():
|
88
67
|
cmd += [f"{shlex.quote(env_key)}={shlex.quote(env_value)}"]
|
89
|
-
|
90
|
-
|
91
|
-
cmd +=
|
68
|
+
if entry_point is not None:
|
69
|
+
cmd += entry_point.command
|
70
|
+
cmd += launch_project.override_args
|
92
71
|
|
93
72
|
command_str = " ".join(cmd).strip()
|
94
73
|
_msg = f"{LOG_PREFIX}Launching run as a local-process with command {sanitize_wandb_api_key(command_str)}"
|
@@ -12,8 +12,7 @@ from wandb.apis.internal import Api
|
|
12
12
|
from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
|
13
13
|
from wandb.sdk.launch.errors import LaunchError
|
14
14
|
|
15
|
-
from .._project_spec import EntryPoint, LaunchProject
|
16
|
-
from ..builder.build import get_env_vars_dict
|
15
|
+
from .._project_spec import EntryPoint, LaunchProject
|
17
16
|
from ..registry.abstract import AbstractRegistry
|
18
17
|
from ..utils import (
|
19
18
|
LOG_PREFIX,
|
@@ -68,6 +67,7 @@ class SagemakerSubmittedRun(AbstractRun):
|
|
68
67
|
logGroupName="/aws/sagemaker/TrainingJobs",
|
69
68
|
logStreamName=log_name,
|
70
69
|
)
|
70
|
+
assert "events" in res
|
71
71
|
return "\n".join(
|
72
72
|
[f'{event["timestamp"]}:{event["message"]}' for event in res["events"]]
|
73
73
|
)
|
@@ -179,7 +179,10 @@ class SageMakerRunner(AbstractRunner):
|
|
179
179
|
caller_id = client.get_caller_identity()
|
180
180
|
account_id = caller_id["Account"]
|
181
181
|
_logger.info(f"Using account ID {account_id}")
|
182
|
-
|
182
|
+
partition = await self.environment.get_partition()
|
183
|
+
role_arn = get_role_arn(
|
184
|
+
given_sagemaker_args, self.backend_config, account_id, partition
|
185
|
+
)
|
183
186
|
|
184
187
|
# Create a sagemaker client to launch the job.
|
185
188
|
sagemaker_client = session.client("sagemaker")
|
@@ -221,12 +224,12 @@ class SageMakerRunner(AbstractRunner):
|
|
221
224
|
launch_project.fill_macros(image_uri)
|
222
225
|
_logger.info("Connecting to sagemaker client")
|
223
226
|
entry_point = (
|
224
|
-
launch_project.override_entrypoint
|
225
|
-
or launch_project.get_single_entry_point()
|
226
|
-
)
|
227
|
-
command_args = get_entry_point_command(
|
228
|
-
entry_point, launch_project.override_args
|
227
|
+
launch_project.override_entrypoint or launch_project.get_job_entry_point()
|
229
228
|
)
|
229
|
+
command_args = []
|
230
|
+
if entry_point is not None:
|
231
|
+
command_args += entry_point.command
|
232
|
+
command_args += launch_project.override_args
|
230
233
|
if command_args:
|
231
234
|
command_str = " ".join(command_args)
|
232
235
|
wandb.termlog(
|
@@ -349,18 +352,18 @@ def build_sagemaker_args(
|
|
349
352
|
|
350
353
|
if sagemaker_args.get("ResourceConfig") is None:
|
351
354
|
raise LaunchError(
|
352
|
-
"Sagemaker launcher requires a ResourceConfig
|
355
|
+
"Sagemaker launcher requires a ResourceConfig resource argument"
|
353
356
|
)
|
354
357
|
|
355
358
|
if sagemaker_args.get("StoppingCondition") is None:
|
356
359
|
raise LaunchError(
|
357
|
-
"Sagemaker launcher requires a StoppingCondition
|
360
|
+
"Sagemaker launcher requires a StoppingCondition resource argument"
|
358
361
|
)
|
359
362
|
|
360
363
|
given_env = given_sagemaker_args.get(
|
361
364
|
"Environment", sagemaker_args.get("environment", {})
|
362
365
|
)
|
363
|
-
calced_env = get_env_vars_dict(
|
366
|
+
calced_env = launch_project.get_env_vars_dict(api, max_env_length)
|
364
367
|
total_env = {**calced_env, **given_env}
|
365
368
|
sagemaker_args["Environment"] = total_env
|
366
369
|
|
@@ -405,7 +408,10 @@ async def launch_sagemaker_job(
|
|
405
408
|
|
406
409
|
|
407
410
|
def get_role_arn(
|
408
|
-
sagemaker_args: Dict[str, Any],
|
411
|
+
sagemaker_args: Dict[str, Any],
|
412
|
+
backend_config: Dict[str, Any],
|
413
|
+
account_id: str,
|
414
|
+
partition: str,
|
409
415
|
) -> str:
|
410
416
|
"""Get the role arn from the sagemaker args or the backend config."""
|
411
417
|
role_arn = sagemaker_args.get("RoleArn") or sagemaker_args.get("role_arn")
|
@@ -416,7 +422,7 @@ def get_role_arn(
|
|
416
422
|
"AWS sagemaker require a string RoleArn set this by adding a `RoleArn` key to the sagemaker"
|
417
423
|
"field of resource_args"
|
418
424
|
)
|
419
|
-
if role_arn.startswith("arn:
|
425
|
+
if role_arn.startswith(f"arn:{partition}:iam::"):
|
420
426
|
return role_arn # type: ignore
|
421
427
|
|
422
|
-
return f"arn:
|
428
|
+
return f"arn:{partition}:iam::{account_id}:role/{role_arn}"
|
@@ -8,8 +8,7 @@ if False:
|
|
8
8
|
from wandb.apis.internal import Api
|
9
9
|
from wandb.util import get_module
|
10
10
|
|
11
|
-
from .._project_spec import LaunchProject
|
12
|
-
from ..builder.build import get_env_vars_dict
|
11
|
+
from .._project_spec import LaunchProject
|
13
12
|
from ..environment.gcp_environment import GcpEnvironment
|
14
13
|
from ..errors import LaunchError
|
15
14
|
from ..registry.abstract import AbstractRegistry
|
@@ -113,14 +112,16 @@ class VertexRunner(AbstractRunner):
|
|
113
112
|
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
114
113
|
|
115
114
|
entry_point = (
|
116
|
-
launch_project.override_entrypoint
|
117
|
-
or launch_project.get_single_entry_point()
|
115
|
+
launch_project.override_entrypoint or launch_project.get_job_entry_point()
|
118
116
|
)
|
119
117
|
|
120
118
|
# TODO: Set entrypoint in each container
|
121
|
-
entry_cmd =
|
122
|
-
|
123
|
-
|
119
|
+
entry_cmd = []
|
120
|
+
if entry_point is not None:
|
121
|
+
entry_cmd += entry_point.command
|
122
|
+
entry_cmd += launch_project.override_args
|
123
|
+
|
124
|
+
env_vars = launch_project.get_env_vars_dict(
|
124
125
|
api=self._api,
|
125
126
|
max_env_length=MAX_ENV_LENGTHS[self.__class__.__name__],
|
126
127
|
)
|
@@ -408,7 +408,7 @@ class Scheduler(ABC):
|
|
408
408
|
return count
|
409
409
|
|
410
410
|
def _try_load_executable(self) -> bool:
|
411
|
-
"""Check
|
411
|
+
"""Check existence of valid executable for a run.
|
412
412
|
|
413
413
|
logs and returns False when job is unreachable
|
414
414
|
"""
|
@@ -423,7 +423,7 @@ class Scheduler(ABC):
|
|
423
423
|
return False
|
424
424
|
return True
|
425
425
|
elif self._kwargs.get("image_uri"):
|
426
|
-
# TODO(gst): check docker
|
426
|
+
# TODO(gst): check docker existence? Use registry in launch config?
|
427
427
|
return True
|
428
428
|
else:
|
429
429
|
return False
|
@@ -611,7 +611,7 @@ class Scheduler(ABC):
|
|
611
611
|
f"Failed to get runstate for run ({run_id}). Error: {traceback.format_exc()}"
|
612
612
|
)
|
613
613
|
run_state = RunState.FAILED
|
614
|
-
else: # first time we get
|
614
|
+
else: # first time we get unknown state
|
615
615
|
run_state = RunState.UNKNOWN
|
616
616
|
except (AttributeError, ValueError):
|
617
617
|
wandb.termwarn(
|
@@ -668,6 +668,8 @@ class Scheduler(ABC):
|
|
668
668
|
launch_config = copy.deepcopy(self._wandb_run.config.get("launch", {}))
|
669
669
|
if "overrides" not in launch_config:
|
670
670
|
launch_config["overrides"] = {"run_config": {}}
|
671
|
+
if "run_config" not in launch_config["overrides"]:
|
672
|
+
launch_config["overrides"]["run_config"] = {}
|
671
673
|
launch_config["overrides"]["run_config"].update(args["args_dict"])
|
672
674
|
|
673
675
|
if macro_args: # pipe in hyperparam args as params to launch
|
@@ -59,7 +59,7 @@ class SweepScheduler(Scheduler):
|
|
59
59
|
return None
|
60
60
|
|
61
61
|
def _get_sweep_commands(self, worker_id: int) -> List[Dict[str, Any]]:
|
62
|
-
"""Helper to
|
62
|
+
"""Helper to receive sweep command from backend."""
|
63
63
|
# AgentHeartbeat wants a Dict of runs which are running or queued
|
64
64
|
_run_states: Dict[str, bool] = {}
|
65
65
|
for run_id, run in self._yield_runs():
|
wandb/sdk/launch/sweeps/utils.py
CHANGED
@@ -211,13 +211,13 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
|
|
211
211
|
|
212
212
|
"""
|
213
213
|
if "args" not in command:
|
214
|
-
raise ValueError('No "args" found in command:
|
214
|
+
raise ValueError('No "args" found in command: {}'.format(command))
|
215
215
|
# four different formats of command args
|
216
216
|
# (1) standard command line flags (e.g. --foo=bar)
|
217
217
|
flags: List[str] = []
|
218
218
|
# (2) flags without hyphens (e.g. foo=bar)
|
219
219
|
flags_no_hyphens: List[str] = []
|
220
|
-
# (3) flags with false booleans
|
220
|
+
# (3) flags with false booleans omitted (e.g. --foo)
|
221
221
|
flags_no_booleans: List[str] = []
|
222
222
|
# (4) flags as a dictionary (used for constructing a json)
|
223
223
|
flags_dict: Dict[str, Any] = {}
|
@@ -228,7 +228,7 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
|
|
228
228
|
try:
|
229
229
|
_value: Any = config["value"]
|
230
230
|
except KeyError:
|
231
|
-
raise ValueError('No "value" found for command["args"]["
|
231
|
+
raise ValueError('No "value" found for command["args"]["{}"]'.format(param))
|
232
232
|
|
233
233
|
_flag: str = f"{param}={_value}"
|
234
234
|
flags.append("--" + _flag)
|
@@ -257,7 +257,7 @@ def make_launch_sweep_entrypoint(
|
|
257
257
|
"""Use args dict from create_sweep_command_args to construct entrypoint.
|
258
258
|
|
259
259
|
If replace is True, remove macros from entrypoint, fill them in with args
|
260
|
-
and then return the args in
|
260
|
+
and then return the args in separate return value.
|
261
261
|
"""
|
262
262
|
if not command:
|
263
263
|
return None, None
|