wandb 0.16.5__py3-none-any.whl → 0.17.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- package_readme.md +95 -0
- wandb/__init__.py +2 -3
- wandb/agents/pyagent.py +0 -1
- wandb/analytics/sentry.py +2 -1
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/internals/protocols.py +30 -56
- wandb/apis/importers/mlflow.py +13 -26
- wandb/apis/importers/wandb.py +8 -14
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +55 -3
- wandb/apis/public/artifacts.py +1 -0
- wandb/apis/public/files.py +1 -0
- wandb/apis/public/history.py +1 -0
- wandb/apis/public/jobs.py +17 -4
- wandb/apis/public/projects.py +1 -0
- wandb/apis/public/reports.py +1 -0
- wandb/apis/public/runs.py +15 -17
- wandb/apis/public/sweeps.py +1 -0
- wandb/apis/public/teams.py +1 -0
- wandb/apis/public/users.py +1 -0
- wandb/apis/reports/v1/_blocks.py +3 -7
- wandb/apis/reports/v2/gql.py +1 -0
- wandb/apis/reports/v2/interface.py +3 -4
- wandb/apis/reports/v2/internal.py +5 -8
- wandb/cli/cli.py +95 -22
- wandb/data_types.py +9 -6
- wandb/docker/__init__.py +1 -1
- wandb/env.py +38 -8
- wandb/errors/__init__.py +5 -0
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +4 -106
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/catboost/catboost.py +1 -1
- wandb/integration/fastai/__init__.py +1 -0
- wandb/integration/huggingface/resolver.py +2 -2
- wandb/integration/keras/__init__.py +1 -0
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/keras.py +7 -7
- wandb/integration/langchain/wandb_tracer.py +1 -0
- wandb/integration/lightning/fabric/logger.py +1 -3
- wandb/integration/metaflow/metaflow.py +41 -6
- wandb/integration/openai/fine_tuning.py +77 -40
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/old/summary.py +1 -1
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +364 -332
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +322 -316
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/wandb_deprecated.py +7 -1
- wandb/proto/wandb_internal_codegen.py +3 -29
- wandb/sdk/artifacts/artifact.py +51 -20
- wandb/sdk/artifacts/artifact_download_logger.py +1 -0
- wandb/sdk/artifacts/artifact_file_cache.py +18 -4
- wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
- wandb/sdk/artifacts/artifact_manifest.py +1 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +7 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
- wandb/sdk/artifacts/artifact_saver.py +18 -27
- wandb/sdk/artifacts/artifact_state.py +1 -0
- wandb/sdk/artifacts/artifact_ttl.py +1 -0
- wandb/sdk/artifacts/exceptions.py +1 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -42
- wandb/sdk/artifacts/storage_policy.py +2 -12
- wandb/sdk/data_types/_dtypes.py +8 -8
- wandb/sdk/data_types/base_types/media.py +3 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/auto_logging.py +5 -6
- wandb/sdk/integration_utils/data_logging.py +10 -6
- wandb/sdk/interface/interface.py +86 -38
- wandb/sdk/interface/interface_shared.py +7 -13
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +5 -18
- wandb/sdk/internal/handler.py +18 -2
- wandb/sdk/internal/internal.py +0 -1
- wandb/sdk/internal/internal_api.py +1 -129
- wandb/sdk/internal/internal_util.py +0 -1
- wandb/sdk/internal/job_builder.py +159 -45
- wandb/sdk/internal/profiler.py +1 -0
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/run.py +1 -0
- wandb/sdk/internal/sender.py +1 -2
- wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
- wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
- wandb/sdk/internal/system/assets/interfaces.py +6 -8
- wandb/sdk/internal/system/assets/open_metrics.py +2 -2
- wandb/sdk/internal/system/assets/trainium.py +1 -3
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +9 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +188 -241
- wandb/sdk/launch/agent/agent.py +115 -48
- wandb/sdk/launch/agent/config.py +80 -14
- wandb/sdk/launch/builder/abstract.py +69 -1
- wandb/sdk/launch/builder/build.py +156 -555
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +8 -23
- wandb/sdk/launch/builder/kaniko_builder.py +161 -159
- wandb/sdk/launch/builder/noop.py +1 -0
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +68 -63
- wandb/sdk/launch/environment/abstract.py +1 -0
- wandb/sdk/launch/environment/gcp_environment.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +1 -0
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +217 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/loader.py +1 -0
- wandb/sdk/launch/registry/abstract.py +1 -0
- wandb/sdk/launch/registry/azure_container_registry.py +1 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
- wandb/sdk/launch/registry/local_registry.py +1 -0
- wandb/sdk/launch/runner/abstract.py +1 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +4 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +9 -10
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +21 -20
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +7 -4
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +33 -140
- wandb/sdk/lib/_settings_toposort_generated.py +1 -5
- wandb/sdk/lib/fsm.py +8 -12
- wandb/sdk/lib/gitlib.py +4 -4
- wandb/sdk/lib/import_hooks.py +1 -1
- wandb/sdk/lib/lazyloader.py +0 -1
- wandb/sdk/lib/proto_util.py +23 -2
- wandb/sdk/lib/redirect.py +19 -14
- wandb/sdk/lib/retry.py +3 -2
- wandb/sdk/lib/run_moment.py +7 -1
- wandb/sdk/lib/tracelog.py +1 -1
- wandb/sdk/service/service.py +19 -16
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_init.py +16 -63
- wandb/sdk/wandb_manager.py +2 -2
- wandb/sdk/wandb_require.py +5 -0
- wandb/sdk/wandb_run.py +164 -90
- wandb/sdk/wandb_settings.py +2 -48
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sklearn/__init__.py +1 -0
- wandb/sklearn/plot/__init__.py +1 -0
- wandb/sklearn/plot/classifier.py +11 -12
- wandb/sklearn/plot/clusterer.py +2 -1
- wandb/sklearn/plot/regressor.py +1 -0
- wandb/sklearn/plot/shared.py +1 -0
- wandb/sklearn/utils.py +1 -0
- wandb/testing/relay.py +4 -4
- wandb/trigger.py +1 -0
- wandb/util.py +67 -54
- wandb/wandb_controller.py +2 -3
- wandb/wandb_torch.py +1 -2
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/METADATA +67 -70
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/RECORD +178 -188
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/WHEEL +1 -2
- wandb/bin/apple_gpu_stats +0 -0
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -18
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- wandb-0.16.5.dist-info/top_level.txt +0 -1
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info/licenses}/LICENSE +0 -0
@@ -0,0 +1,95 @@
|
|
1
|
+
"""Functions for declaring overridable configuration for launch jobs."""
|
2
|
+
|
3
|
+
from typing import List, Optional
|
4
|
+
|
5
|
+
|
6
|
+
def manage_config_file(
|
7
|
+
path: str,
|
8
|
+
include: Optional[List[str]] = None,
|
9
|
+
exclude: Optional[List[str]] = None,
|
10
|
+
):
|
11
|
+
r"""Declare an overridable configuration file for a launch job.
|
12
|
+
|
13
|
+
If a new job version is created from the active run, the configuration file
|
14
|
+
will be added to the job's inputs. If the job is launched and overrides
|
15
|
+
have been provided for the configuration file, this function will detect
|
16
|
+
the overrides from the environment and update the configuration file on disk.
|
17
|
+
Note that these overrides will only be applied in ephemeral containers.
|
18
|
+
`include` and `exclude` are lists of dot-separated paths within the config.
|
19
|
+
The paths are used to filter subtrees of the configuration file out of the
|
20
|
+
job's inputs.
|
21
|
+
|
22
|
+
For example, given the following configuration file:
|
23
|
+
```yaml
|
24
|
+
model:
|
25
|
+
name: resnet
|
26
|
+
layers: 18
|
27
|
+
training:
|
28
|
+
epochs: 10
|
29
|
+
batch_size: 32
|
30
|
+
```
|
31
|
+
|
32
|
+
Passing `include=['model']` will only include the `model` subtree in the
|
33
|
+
job's inputs. Passing `exclude=['model.layers']` will exclude the `layers`
|
34
|
+
key from the `model` subtree. Note that `exclude` takes precedence over
|
35
|
+
`include`.
|
36
|
+
|
37
|
+
`.` is used as a separator for nested keys. If a key contains a `.`, it
|
38
|
+
should be escaped with a backslash, e.g. `include=[r'model\.layers']`. Note
|
39
|
+
the use of `r` to denote a raw string when using escape chars.
|
40
|
+
|
41
|
+
Args:
|
42
|
+
path (str): The path to the configuration file. This path must be
|
43
|
+
relative and must not contain backwards traversal, i.e. `..`.
|
44
|
+
include (List[str]): A list of keys to include in the configuration file.
|
45
|
+
exclude (List[str]): A list of keys to exclude from the configuration file.
|
46
|
+
|
47
|
+
Raises:
|
48
|
+
LaunchError: If the path is not valid, or if there is no active run.
|
49
|
+
"""
|
50
|
+
from .internal import handle_config_file_input
|
51
|
+
|
52
|
+
return handle_config_file_input(path, include, exclude)
|
53
|
+
|
54
|
+
|
55
|
+
def manage_wandb_config(
|
56
|
+
include: Optional[List[str]] = None,
|
57
|
+
exclude: Optional[List[str]] = None,
|
58
|
+
):
|
59
|
+
r"""Declare wandb.config as an overridable configuration for a launch job.
|
60
|
+
|
61
|
+
If a new job version is created from the active run, the run config
|
62
|
+
(wandb.config) will become an overridable input of the job. If the job is
|
63
|
+
launched and overrides have been provided for the run config, the overrides
|
64
|
+
will be applied to the run config when `wandb.init` is called.
|
65
|
+
`include` and `exclude` are lists of dot-separated paths within the config.
|
66
|
+
The paths are used to filter subtrees of the run config out of the
|
67
|
+
job's inputs.
|
68
|
+
|
69
|
+
For example, given the following run config contents:
|
70
|
+
```yaml
|
71
|
+
model:
|
72
|
+
name: resnet
|
73
|
+
layers: 18
|
74
|
+
training:
|
75
|
+
epochs: 10
|
76
|
+
batch_size: 32
|
77
|
+
```
|
78
|
+
Passing `include=['model']` will only include the `model` subtree in the
|
79
|
+
job's inputs. Passing `exclude=['model.layers']` will exclude the `layers`
|
80
|
+
key from the `model` subtree. Note that `exclude` takes precedence over
|
81
|
+
`include`.
|
82
|
+
`.` is used as a separator for nested keys. If a key contains a `.`, it
|
83
|
+
should be escaped with a backslash, e.g. `include=[r'model\.layers']`. Note
|
84
|
+
the use of `r` to denote a raw string when using escape chars.
|
85
|
+
|
86
|
+
Args:
|
87
|
+
include (List[str]): A list of subtrees to include in the configuration.
|
88
|
+
exclude (List[str]): A list of subtrees to exclude from the configuration.
|
89
|
+
|
90
|
+
Raises:
|
91
|
+
LaunchError: If there is no active run.
|
92
|
+
"""
|
93
|
+
from .internal import handle_run_config_input
|
94
|
+
|
95
|
+
handle_run_config_input(include, exclude)
|
wandb/sdk/launch/loader.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Implementation of Google Artifact Registry for wandb launch."""
|
2
|
+
|
2
3
|
import logging
|
3
4
|
from typing import Optional, Tuple
|
4
5
|
|
@@ -210,7 +211,7 @@ class GoogleArtifactRegistry(AbstractRegistry):
|
|
210
211
|
for image in await list_images(request={"parent": parent}):
|
211
212
|
if tag in image.tags:
|
212
213
|
return True
|
213
|
-
except google.api_core.exceptions.NotFound as e:
|
214
|
+
except google.api_core.exceptions.NotFound as e: # type: ignore[attr-defined]
|
214
215
|
raise LaunchError(
|
215
216
|
f"The Google Artifact Registry repository {self.repository} "
|
216
217
|
f"does not exist. Please create it or modify your registry configuration."
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Monitors kubernetes resources managed by the launch agent."""
|
2
|
+
|
2
3
|
import asyncio
|
3
4
|
import logging
|
4
5
|
import sys
|
@@ -433,6 +434,8 @@ class SafeWatch:
|
|
433
434
|
del kwargs["resource_version"]
|
434
435
|
self._last_seen_resource_version = None
|
435
436
|
except Exception as E:
|
437
|
+
exc_type = type(E).__name__
|
438
|
+
stack_trace = traceback.format_exc()
|
436
439
|
wandb.termerror(
|
437
|
-
f"Unknown exception in event stream: {E}, attempting to recover"
|
440
|
+
f"Unknown exception in event stream of type {exc_type}: {E}, attempting to recover. Stack trace: {stack_trace}"
|
438
441
|
)
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Implementation of KubernetesRunner class for wandb launch."""
|
2
|
+
|
2
3
|
import asyncio
|
3
4
|
import base64
|
4
5
|
import datetime
|
@@ -28,7 +29,6 @@ from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
|
|
28
29
|
from wandb.util import get_module
|
29
30
|
|
30
31
|
from .._project_spec import EntryPoint, LaunchProject
|
31
|
-
from ..builder.build import get_env_vars_dict
|
32
32
|
from ..errors import LaunchError
|
33
33
|
from ..utils import (
|
34
34
|
LOG_PREFIX,
|
@@ -373,8 +373,7 @@ class KubernetesRunner(AbstractRunner):
|
|
373
373
|
}
|
374
374
|
|
375
375
|
entry_point = (
|
376
|
-
launch_project.override_entrypoint
|
377
|
-
or launch_project.get_single_entry_point()
|
376
|
+
launch_project.override_entrypoint or launch_project.get_job_entry_point()
|
378
377
|
)
|
379
378
|
if launch_project.docker_image:
|
380
379
|
# don't specify run id if user provided image, could have multiple runs
|
@@ -400,8 +399,8 @@ class KubernetesRunner(AbstractRunner):
|
|
400
399
|
launch_project.override_entrypoint is not None,
|
401
400
|
)
|
402
401
|
|
403
|
-
env_vars = get_env_vars_dict(
|
404
|
-
|
402
|
+
env_vars = launch_project.get_env_vars_dict(
|
403
|
+
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
405
404
|
)
|
406
405
|
api_key_secret = None
|
407
406
|
for cont in containers:
|
@@ -510,8 +509,8 @@ class KubernetesRunner(AbstractRunner):
|
|
510
509
|
api_version = resource_args.get("apiVersion", "batch/v1")
|
511
510
|
|
512
511
|
if api_version not in ["batch/v1", "batch/v1beta1"]:
|
513
|
-
env_vars = get_env_vars_dict(
|
514
|
-
|
512
|
+
env_vars = launch_project.get_env_vars_dict(
|
513
|
+
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
515
514
|
)
|
516
515
|
# Crawl the resource args and add our env vars to the containers.
|
517
516
|
add_wandb_env(resource_args, env_vars)
|
@@ -539,9 +538,9 @@ class KubernetesRunner(AbstractRunner):
|
|
539
538
|
WANDB_K8S_LABEL_MONITOR,
|
540
539
|
LaunchAgent.name(),
|
541
540
|
)
|
542
|
-
resource_args["metadata"]["labels"][
|
543
|
-
|
544
|
-
|
541
|
+
resource_args["metadata"]["labels"][WANDB_K8S_LABEL_AGENT] = (
|
542
|
+
LaunchAgent.name()
|
543
|
+
)
|
545
544
|
|
546
545
|
overrides = {}
|
547
546
|
if launch_project.override_args:
|
@@ -12,7 +12,6 @@ from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
|
12
12
|
from wandb.sdk.launch.registry.abstract import AbstractRegistry
|
13
13
|
|
14
14
|
from .._project_spec import LaunchProject
|
15
|
-
from ..builder.build import get_env_vars_dict
|
16
15
|
from ..errors import LaunchError
|
17
16
|
from ..utils import (
|
18
17
|
LOG_PREFIX,
|
@@ -133,8 +132,8 @@ class LocalContainerRunner(AbstractRunner):
|
|
133
132
|
docker_args = self._populate_docker_args(launch_project, image_uri)
|
134
133
|
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
135
134
|
|
136
|
-
env_vars = get_env_vars_dict(
|
137
|
-
|
135
|
+
env_vars = launch_project.get_env_vars_dict(
|
136
|
+
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
138
137
|
)
|
139
138
|
|
140
139
|
# When running against local port, need to swap to local docker host
|
@@ -4,16 +4,12 @@ from typing import Any, List, Optional
|
|
4
4
|
|
5
5
|
import wandb
|
6
6
|
|
7
|
-
from .._project_spec import LaunchProject
|
8
|
-
from ..builder.build import get_env_vars_dict
|
7
|
+
from .._project_spec import LaunchProject
|
9
8
|
from ..errors import LaunchError
|
10
9
|
from ..utils import (
|
11
10
|
LOG_PREFIX,
|
12
11
|
MAX_ENV_LENGTHS,
|
13
12
|
PROJECT_SYNCHRONOUS,
|
14
|
-
_is_wandb_uri,
|
15
|
-
download_wandb_python_deps,
|
16
|
-
parse_wandb_uri,
|
17
13
|
sanitize_wandb_api_key,
|
18
14
|
validate_wandb_python_deps,
|
19
15
|
)
|
@@ -47,8 +43,7 @@ class LocalProcessRunner(AbstractRunner):
|
|
47
43
|
|
48
44
|
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
49
45
|
entry_point = (
|
50
|
-
launch_project.override_entrypoint
|
51
|
-
or launch_project.get_single_entry_point()
|
46
|
+
launch_project.override_entrypoint or launch_project.get_job_entry_point()
|
52
47
|
)
|
53
48
|
|
54
49
|
cmd: List[Any] = []
|
@@ -56,23 +51,7 @@ class LocalProcessRunner(AbstractRunner):
|
|
56
51
|
if launch_project.project_dir is None:
|
57
52
|
raise LaunchError("Launch LocalProcessRunner received empty project dir")
|
58
53
|
|
59
|
-
|
60
|
-
if launch_project.uri and _is_wandb_uri(launch_project.uri):
|
61
|
-
source_entity, source_project, run_name = parse_wandb_uri(
|
62
|
-
launch_project.uri
|
63
|
-
)
|
64
|
-
run_requirements_file = download_wandb_python_deps(
|
65
|
-
source_entity,
|
66
|
-
source_project,
|
67
|
-
run_name,
|
68
|
-
self._api,
|
69
|
-
launch_project.project_dir,
|
70
|
-
)
|
71
|
-
validate_wandb_python_deps(
|
72
|
-
run_requirements_file,
|
73
|
-
launch_project.project_dir,
|
74
|
-
)
|
75
|
-
elif launch_project.job:
|
54
|
+
if launch_project.job:
|
76
55
|
assert launch_project._job_artifact is not None
|
77
56
|
try:
|
78
57
|
validate_wandb_python_deps(
|
@@ -81,14 +60,14 @@ class LocalProcessRunner(AbstractRunner):
|
|
81
60
|
)
|
82
61
|
except Exception:
|
83
62
|
wandb.termwarn("Unable to validate python dependencies")
|
84
|
-
env_vars = get_env_vars_dict(
|
85
|
-
|
63
|
+
env_vars = launch_project.get_env_vars_dict(
|
64
|
+
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
86
65
|
)
|
87
66
|
for env_key, env_value in env_vars.items():
|
88
67
|
cmd += [f"{shlex.quote(env_key)}={shlex.quote(env_value)}"]
|
89
|
-
|
90
|
-
|
91
|
-
cmd +=
|
68
|
+
if entry_point is not None:
|
69
|
+
cmd += entry_point.command
|
70
|
+
cmd += launch_project.override_args
|
92
71
|
|
93
72
|
command_str = " ".join(cmd).strip()
|
94
73
|
_msg = f"{LOG_PREFIX}Launching run as a local-process with command {sanitize_wandb_api_key(command_str)}"
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Implementation of the SageMakerRunner class."""
|
2
|
+
|
2
3
|
import asyncio
|
3
4
|
import logging
|
4
5
|
from typing import Any, Dict, List, Optional, cast
|
@@ -11,8 +12,7 @@ from wandb.apis.internal import Api
|
|
11
12
|
from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
|
12
13
|
from wandb.sdk.launch.errors import LaunchError
|
13
14
|
|
14
|
-
from .._project_spec import EntryPoint, LaunchProject
|
15
|
-
from ..builder.build import get_env_vars_dict
|
15
|
+
from .._project_spec import EntryPoint, LaunchProject
|
16
16
|
from ..registry.abstract import AbstractRegistry
|
17
17
|
from ..utils import (
|
18
18
|
LOG_PREFIX,
|
@@ -67,6 +67,7 @@ class SagemakerSubmittedRun(AbstractRun):
|
|
67
67
|
logGroupName="/aws/sagemaker/TrainingJobs",
|
68
68
|
logStreamName=log_name,
|
69
69
|
)
|
70
|
+
assert "events" in res
|
70
71
|
return "\n".join(
|
71
72
|
[f'{event["timestamp"]}:{event["message"]}' for event in res["events"]]
|
72
73
|
)
|
@@ -220,12 +221,12 @@ class SageMakerRunner(AbstractRunner):
|
|
220
221
|
launch_project.fill_macros(image_uri)
|
221
222
|
_logger.info("Connecting to sagemaker client")
|
222
223
|
entry_point = (
|
223
|
-
launch_project.override_entrypoint
|
224
|
-
or launch_project.get_single_entry_point()
|
225
|
-
)
|
226
|
-
command_args = get_entry_point_command(
|
227
|
-
entry_point, launch_project.override_args
|
224
|
+
launch_project.override_entrypoint or launch_project.get_job_entry_point()
|
228
225
|
)
|
226
|
+
command_args = []
|
227
|
+
if entry_point is not None:
|
228
|
+
command_args += entry_point.command
|
229
|
+
command_args += launch_project.override_args
|
229
230
|
if command_args:
|
230
231
|
command_str = " ".join(command_args)
|
231
232
|
wandb.termlog(
|
@@ -324,16 +325,16 @@ def build_sagemaker_args(
|
|
324
325
|
sagemaker_args["TrainingJobName"] = training_job_name
|
325
326
|
entry_cmd = entry_point.command if entry_point else []
|
326
327
|
|
327
|
-
sagemaker_args[
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
328
|
+
sagemaker_args["AlgorithmSpecification"] = (
|
329
|
+
merge_image_uri_with_algorithm_specification(
|
330
|
+
given_sagemaker_args.get(
|
331
|
+
"AlgorithmSpecification",
|
332
|
+
given_sagemaker_args.get("algorithm_specification"),
|
333
|
+
),
|
334
|
+
image_uri,
|
335
|
+
entry_cmd,
|
336
|
+
args,
|
337
|
+
)
|
337
338
|
)
|
338
339
|
|
339
340
|
sagemaker_args["RoleArn"] = role_arn
|
@@ -348,18 +349,18 @@ def build_sagemaker_args(
|
|
348
349
|
|
349
350
|
if sagemaker_args.get("ResourceConfig") is None:
|
350
351
|
raise LaunchError(
|
351
|
-
"Sagemaker launcher requires a ResourceConfig
|
352
|
+
"Sagemaker launcher requires a ResourceConfig resource argument"
|
352
353
|
)
|
353
354
|
|
354
355
|
if sagemaker_args.get("StoppingCondition") is None:
|
355
356
|
raise LaunchError(
|
356
|
-
"Sagemaker launcher requires a StoppingCondition
|
357
|
+
"Sagemaker launcher requires a StoppingCondition resource argument"
|
357
358
|
)
|
358
359
|
|
359
360
|
given_env = given_sagemaker_args.get(
|
360
361
|
"Environment", sagemaker_args.get("environment", {})
|
361
362
|
)
|
362
|
-
calced_env = get_env_vars_dict(
|
363
|
+
calced_env = launch_project.get_env_vars_dict(api, max_env_length)
|
363
364
|
total_env = {**calced_env, **given_env}
|
364
365
|
sagemaker_args["Environment"] = total_env
|
365
366
|
|
@@ -8,8 +8,7 @@ if False:
|
|
8
8
|
from wandb.apis.internal import Api
|
9
9
|
from wandb.util import get_module
|
10
10
|
|
11
|
-
from .._project_spec import LaunchProject
|
12
|
-
from ..builder.build import get_env_vars_dict
|
11
|
+
from .._project_spec import LaunchProject
|
13
12
|
from ..environment.gcp_environment import GcpEnvironment
|
14
13
|
from ..errors import LaunchError
|
15
14
|
from ..registry.abstract import AbstractRegistry
|
@@ -113,14 +112,16 @@ class VertexRunner(AbstractRunner):
|
|
113
112
|
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
114
113
|
|
115
114
|
entry_point = (
|
116
|
-
launch_project.override_entrypoint
|
117
|
-
or launch_project.get_single_entry_point()
|
115
|
+
launch_project.override_entrypoint or launch_project.get_job_entry_point()
|
118
116
|
)
|
119
117
|
|
120
118
|
# TODO: Set entrypoint in each container
|
121
|
-
entry_cmd =
|
122
|
-
|
123
|
-
|
119
|
+
entry_cmd = []
|
120
|
+
if entry_point is not None:
|
121
|
+
entry_cmd += entry_point.command
|
122
|
+
entry_cmd += launch_project.override_args
|
123
|
+
|
124
|
+
env_vars = launch_project.get_env_vars_dict(
|
124
125
|
api=self._api,
|
125
126
|
max_env_length=MAX_ENV_LENGTHS[self.__class__.__name__],
|
126
127
|
)
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Abstract Scheduler class."""
|
2
|
+
|
2
3
|
import asyncio
|
3
4
|
import base64
|
4
5
|
import copy
|
@@ -157,7 +158,9 @@ class Scheduler(ABC):
|
|
157
158
|
self._runs: Dict[str, SweepRun] = {}
|
158
159
|
# Threading lock to ensure thread-safe access to the runs dictionary
|
159
160
|
self._threading_lock: threading.Lock = threading.Lock()
|
160
|
-
self._polling_sleep =
|
161
|
+
self._polling_sleep = (
|
162
|
+
polling_sleep if polling_sleep is not None else DEFAULT_POLLING_SLEEP
|
163
|
+
)
|
161
164
|
self._project_queue = project_queue
|
162
165
|
# Optionally run multiple workers in (pseudo-)parallel. Workers do not
|
163
166
|
# actually run training workloads, they simply send heartbeat messages
|
@@ -405,7 +408,7 @@ class Scheduler(ABC):
|
|
405
408
|
return count
|
406
409
|
|
407
410
|
def _try_load_executable(self) -> bool:
|
408
|
-
"""Check
|
411
|
+
"""Check existence of valid executable for a run.
|
409
412
|
|
410
413
|
Logs and returns False when the job is unreachable.
|
411
414
|
"""
|
@@ -420,7 +423,7 @@ class Scheduler(ABC):
|
|
420
423
|
return False
|
421
424
|
return True
|
422
425
|
elif self._kwargs.get("image_uri"):
|
423
|
-
# TODO(gst): check docker
|
426
|
+
# TODO(gst): check docker existence? Use registry in launch config?
|
424
427
|
return True
|
425
428
|
else:
|
426
429
|
return False
|
@@ -608,7 +611,7 @@ class Scheduler(ABC):
|
|
608
611
|
f"Failed to get runstate for run ({run_id}). Error: {traceback.format_exc()}"
|
609
612
|
)
|
610
613
|
run_state = RunState.FAILED
|
611
|
-
else: # first time we get
|
614
|
+
else: # first time we get unknown state
|
612
615
|
run_state = RunState.UNKNOWN
|
613
616
|
except (AttributeError, ValueError):
|
614
617
|
wandb.termwarn(
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Scheduler for classic wandb Sweeps."""
|
2
|
+
|
2
3
|
import logging
|
3
4
|
from pprint import pformat as pf
|
4
5
|
from typing import Any, Dict, List, Optional
|
@@ -58,7 +59,7 @@ class SweepScheduler(Scheduler):
|
|
58
59
|
return None
|
59
60
|
|
60
61
|
def _get_sweep_commands(self, worker_id: int) -> List[Dict[str, Any]]:
|
61
|
-
"""Helper to
|
62
|
+
"""Helper to receive sweep command from backend."""
|
62
63
|
# AgentHeartbeat wants a Dict of runs which are running or queued
|
63
64
|
_run_states: Dict[str, bool] = {}
|
64
65
|
for run_id, run in self._yield_runs():
|
wandb/sdk/launch/sweeps/utils.py
CHANGED
@@ -217,7 +217,7 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
|
|
217
217
|
flags: List[str] = []
|
218
218
|
# (2) flags without hyphens (e.g. foo=bar)
|
219
219
|
flags_no_hyphens: List[str] = []
|
220
|
-
# (3) flags with false booleans
|
220
|
+
# (3) flags with false booleans omitted (e.g. --foo)
|
221
221
|
flags_no_booleans: List[str] = []
|
222
222
|
# (4) flags as a dictionary (used for constructing a json)
|
223
223
|
flags_dict: Dict[str, Any] = {}
|
@@ -257,7 +257,7 @@ def make_launch_sweep_entrypoint(
|
|
257
257
|
"""Use args dict from create_sweep_command_args to construct entrypoint.
|
258
258
|
|
259
259
|
If replace is True, remove macros from entrypoint, fill them in with args
|
260
|
-
and then return the args in
|
260
|
+
and then return the args in separate return value.
|
261
261
|
"""
|
262
262
|
if not command:
|
263
263
|
return None, None
|
@@ -296,7 +296,7 @@ def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
|
|
296
296
|
|
297
297
|
|
298
298
|
def get_previous_args(
|
299
|
-
run_spec: Dict[str, Any]
|
299
|
+
run_spec: Dict[str, Any],
|
300
300
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
301
301
|
"""Parse through previous scheduler run_spec.
|
302
302
|
|