wandb 0.17.0rc2__py3-none-win_amd64.whl → 0.17.1__py3-none-win_amd64.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -2
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/wandb.py +12 -7
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +213 -79
- wandb/apis/public/artifacts.py +335 -100
- wandb/apis/public/files.py +9 -9
- wandb/apis/public/jobs.py +16 -4
- wandb/apis/public/projects.py +26 -28
- wandb/apis/public/query_generator.py +1 -1
- wandb/apis/public/runs.py +163 -65
- wandb/apis/public/sweeps.py +2 -2
- wandb/apis/reports/__init__.py +1 -7
- wandb/apis/reports/v1/__init__.py +5 -27
- wandb/apis/reports/v2/__init__.py +7 -19
- wandb/apis/workspaces/__init__.py +8 -0
- wandb/beta/workflows.py +8 -3
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +131 -59
- wandb/docker/__init__.py +1 -1
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +5 -107
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/gym/__init__.py +35 -15
- wandb/integration/openai/fine_tuning.py +21 -3
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/jupyter.py +16 -17
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +54 -54
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +54 -54
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_base_pb2.py +30 -0
- wandb/proto/v5/wandb_internal_pb2.py +355 -0
- wandb/proto/v5/wandb_server_pb2.py +63 -0
- wandb/proto/v5/wandb_settings_pb2.py +45 -0
- wandb/proto/v5/wandb_telemetry_pb2.py +41 -0
- wandb/proto/wandb_base_pb2.py +2 -0
- wandb/proto/wandb_deprecated.py +9 -1
- wandb/proto/wandb_generate_deprecated.py +34 -0
- wandb/proto/{wandb_internal_codegen.py → wandb_generate_proto.py} +1 -35
- wandb/proto/wandb_internal_pb2.py +2 -0
- wandb/proto/wandb_server_pb2.py +2 -0
- wandb/proto/wandb_settings_pb2.py +2 -0
- wandb/proto/wandb_telemetry_pb2.py +2 -0
- wandb/sdk/artifacts/artifact.py +68 -22
- wandb/sdk/artifacts/artifact_manifest.py +1 -1
- wandb/sdk/artifacts/artifact_manifest_entry.py +6 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -1
- wandb/sdk/artifacts/artifact_saver.py +1 -10
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +6 -2
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +2 -42
- wandb/sdk/artifacts/storage_policy.py +1 -12
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +4 -2
- wandb/sdk/interface/interface.py +13 -0
- wandb/sdk/interface/interface_shared.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +6 -19
- wandb/sdk/internal/internal_api.py +148 -136
- wandb/sdk/internal/job_builder.py +207 -135
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/sender.py +102 -39
- wandb/sdk/internal/settings_static.py +8 -1
- wandb/sdk/internal/system/assets/trainium.py +3 -3
- wandb/sdk/internal/system/system_info.py +4 -2
- wandb/sdk/internal/update.py +1 -1
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +4 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +184 -224
- wandb/sdk/launch/agent/agent.py +58 -18
- wandb/sdk/launch/agent/config.py +0 -3
- wandb/sdk/launch/builder/abstract.py +67 -0
- wandb/sdk/launch/builder/build.py +165 -576
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +7 -23
- wandb/sdk/launch/builder/kaniko_builder.py +10 -23
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +51 -45
- wandb/sdk/launch/environment/aws_environment.py +26 -1
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +224 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/runner/abstract.py +2 -2
- wandb/sdk/launch/runner/kubernetes_monitor.py +45 -12
- wandb/sdk/launch/runner/kubernetes_runner.py +6 -8
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +20 -14
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +2 -0
- wandb/sdk/launch/sweeps/utils.py +2 -2
- wandb/sdk/launch/utils.py +16 -138
- wandb/sdk/lib/_settings_toposort_generated.py +2 -5
- wandb/sdk/lib/apikey.py +4 -2
- wandb/sdk/lib/config_util.py +3 -3
- wandb/sdk/lib/proto_util.py +22 -1
- wandb/sdk/lib/redirect.py +1 -1
- wandb/sdk/service/service.py +2 -1
- wandb/sdk/service/streams.py +5 -5
- wandb/sdk/wandb_init.py +25 -59
- wandb/sdk/wandb_login.py +28 -25
- wandb/sdk/wandb_run.py +112 -45
- wandb/sdk/wandb_settings.py +33 -64
- wandb/sdk/wandb_watch.py +1 -1
- wandb/sklearn/plot/classifier.py +4 -6
- wandb/sync/sync.py +2 -2
- wandb/testing/relay.py +32 -17
- wandb/util.py +36 -37
- wandb/wandb_agent.py +3 -3
- wandb/wandb_controller.py +3 -2
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/METADATA +7 -9
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/RECORD +125 -147
- wandb/apis/reports/v1/_blocks.py +0 -1406
- wandb/apis/reports/v1/_helpers.py +0 -70
- wandb/apis/reports/v1/_panels.py +0 -1282
- wandb/apis/reports/v1/_templates.py +0 -478
- wandb/apis/reports/v1/blocks.py +0 -27
- wandb/apis/reports/v1/helpers.py +0 -2
- wandb/apis/reports/v1/mutations.py +0 -66
- wandb/apis/reports/v1/panels.py +0 -17
- wandb/apis/reports/v1/report.py +0 -268
- wandb/apis/reports/v1/runset.py +0 -144
- wandb/apis/reports/v1/templates.py +0 -7
- wandb/apis/reports/v1/util.py +0 -406
- wandb/apis/reports/v1/validators.py +0 -131
- wandb/apis/reports/v2/blocks.py +0 -25
- wandb/apis/reports/v2/expr_parsing.py +0 -257
- wandb/apis/reports/v2/gql.py +0 -68
- wandb/apis/reports/v2/interface.py +0 -1911
- wandb/apis/reports/v2/internal.py +0 -867
- wandb/apis/reports/v2/metrics.py +0 -6
- wandb/apis/reports/v2/panels.py +0 -15
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -19
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/WHEEL +0 -0
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/licenses/LICENSE +0 -0
@@ -7,7 +7,7 @@ from typing import Dict, Optional
|
|
7
7
|
from wandb.sdk.launch.errors import LaunchError
|
8
8
|
from wandb.util import get_module
|
9
9
|
|
10
|
-
from ..utils import S3_URI_RE, event_loop_thread_exec
|
10
|
+
from ..utils import ARN_PARTITION_RE, S3_URI_RE, event_loop_thread_exec
|
11
11
|
from .abstract import AbstractEnvironment
|
12
12
|
|
13
13
|
boto3 = get_module(
|
@@ -49,6 +49,7 @@ class AwsEnvironment(AbstractEnvironment):
|
|
49
49
|
self._secret_key = secret_key
|
50
50
|
self._session_token = session_token
|
51
51
|
self._account = None
|
52
|
+
self._partition = None
|
52
53
|
|
53
54
|
@classmethod
|
54
55
|
def from_default(cls, region: Optional[str] = None) -> "AwsEnvironment":
|
@@ -122,6 +123,30 @@ class AwsEnvironment(AbstractEnvironment):
|
|
122
123
|
def region(self, region: str) -> None:
|
123
124
|
self._region = region
|
124
125
|
|
126
|
+
async def get_partition(self) -> str:
    """Return the partition parsed from the caller identity's ARN.

    The method asks STS for the caller identity through the environment's
    session and matches the returned ARN against ``ARN_PARTITION_RE``.
    Nothing is stored on the instance; the value is only returned.

    Returns:
        The first capture group of ``ARN_PARTITION_RE`` for the caller's ARN.

    Raises:
        LaunchError: If the identity has no ARN, the ARN does not match
            ``ARN_PARTITION_RE``, or a botocore ``ClientError`` is raised
            while talking to AWS.
    """
    try:
        session = await self.get_session()
        # boto3 calls are blocking, so they are wrapped to run off the loop.
        client = await event_loop_thread_exec(session.client)("sts")
        get_caller_identity = event_loop_thread_exec(client.get_caller_identity)
        identity = await get_caller_identity()
        arn = identity.get("Arn")
        if not arn:
            raise LaunchError(
                "Could not set partition for AWS environment. ARN not found."
            )
        matched_partition = ARN_PARTITION_RE.match(arn)
        if not matched_partition:
            raise LaunchError(
                f"Could not set partition for AWS environment. ARN {arn} is not valid."
            )
        return matched_partition.group(1)
    except botocore.exceptions.ClientError as e:
        raise LaunchError(
            f"Could not set partition for AWS environment. {e}"
        ) from e
|
149
|
+
|
125
150
|
async def verify(self) -> None:
|
126
151
|
"""Verify that the AWS environment is configured correctly.
|
127
152
|
|
@@ -0,0 +1,148 @@
|
|
1
|
+
import json
|
2
|
+
import os
|
3
|
+
from typing import Any, Dict
|
4
|
+
|
5
|
+
import yaml
|
6
|
+
|
7
|
+
from ..errors import LaunchError
|
8
|
+
|
9
|
+
FILE_OVERRIDE_ENV_VAR = "WANDB_LAUNCH_FILE_OVERRIDES"


class FileOverrides:
    """Singleton that reads file overrides JSON from environment variables.

    The overrides are read either from ``WANDB_LAUNCH_FILE_OVERRIDES`` or,
    when that variable is unset, from the concatenation of the numbered
    variables ``WANDB_LAUNCH_FILE_OVERRIDES_0``, ``_1``, ... in order.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = object.__new__(cls)
            instance.overrides = {}
            instance.load()
            # Publish the singleton only after a successful load; otherwise a
            # failed parse would be cached as an instance with no overrides
            # and every later call would silently ignore the error.
            cls._instance = instance
        return cls._instance

    def load(self) -> None:
        """Load overrides from the environment into ``self.overrides``.

        Raises:
            LaunchError: If the environment value is not valid JSON or does
                not decode to a JSON object.
        """
        overrides = os.environ.get(FILE_OVERRIDE_ENV_VAR)
        if overrides is None:
            # Fall back to the chunked form: _0, _1, ... until the first gap.
            chunks = []
            idx = 0
            while f"{FILE_OVERRIDE_ENV_VAR}_{idx}" in os.environ:
                chunks.append(os.environ[f"{FILE_OVERRIDE_ENV_VAR}_{idx}"])
                idx += 1
            overrides = "".join(chunks)
        if not overrides:
            return
        try:
            contents = json.loads(overrides)
        except json.JSONDecodeError as e:
            raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}") from e
        if not isinstance(contents, dict):
            raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}")
        self.overrides = contents
|
42
|
+
|
43
|
+
|
44
|
+
def config_path_is_valid(path: str) -> None:
    """Validate a config file path.

    This function checks if a given config file path is valid. A valid path
    should meet the following criteria:

    - The path must be expressed as a relative path without any upwards path
      traversal, e.g. `../config.json`.
    - The file specified by the path must exist.
    - The file must have a supported extension (`.json`, `.yaml`, or `.yml`).

    Args:
        path (str): The path to validate.

    Raises:
        LaunchError: If the path is not valid.
    """
    if os.path.isabs(path):
        raise LaunchError(
            f"Invalid config path: {path}. Please provide a relative path."
        )
    # Compare whole path components rather than substrings, so a file name
    # that merely contains two dots (e.g. `v1..2.yaml`) is not rejected.
    # Both separator styles are split on to stay conservative on every OS.
    if ".." in path.replace("\\", "/").split("/"):
        raise LaunchError(
            f"Invalid config path: {path}. Please provide a relative path "
            "without any upward path traversal, e.g. `../config.json`."
        )
    path = os.path.normpath(path)
    if not os.path.exists(path):
        raise LaunchError(f"Invalid config path: {path}. File does not exist.")
    if not path.endswith((".json", ".yaml", ".yml")):
        raise LaunchError(
            f"Invalid config path: {path}. Only JSON and YAML files are supported."
        )
|
77
|
+
|
78
|
+
|
79
|
+
def override_file(path: str) -> None:
    """Check for file overrides in the environment and apply them if found.

    If the `FileOverrides` singleton holds an entry for `path`, the config
    file is read, the overrides are merged into it recursively, and the
    result is written back to the same path.

    Args:
        path (str): The path to the config file, used both as the lookup key
            in the overrides and as the file to rewrite.
    """
    overrides = FileOverrides().overrides.get(path)
    if overrides is None:
        return
    config = _read_config_file(path)
    if config is None:
        # An empty YAML document parses to None; start from an empty mapping
        # so the overrides can still be applied.
        config = {}
    _update_dict(config, overrides)
    _write_config_file(path, config)
|
88
|
+
|
89
|
+
|
90
|
+
def _write_config_file(path: str, config: Any) -> None:
|
91
|
+
"""Write a config file to disk.
|
92
|
+
|
93
|
+
Args:
|
94
|
+
path (str): The path to the config file.
|
95
|
+
config (Any): The contents of the config file as a Python object.
|
96
|
+
|
97
|
+
Raises:
|
98
|
+
LaunchError: If the file extension is not supported.
|
99
|
+
"""
|
100
|
+
_, ext = os.path.splitext(path)
|
101
|
+
if ext == ".json":
|
102
|
+
with open(path, "w") as f:
|
103
|
+
json.dump(config, f, indent=2)
|
104
|
+
elif ext in [".yaml", ".yml"]:
|
105
|
+
with open(path, "w") as f:
|
106
|
+
yaml.safe_dump(config, f)
|
107
|
+
else:
|
108
|
+
raise LaunchError(f"Unsupported file extension: {ext}")
|
109
|
+
|
110
|
+
|
111
|
+
def _read_config_file(path: str) -> Any:
|
112
|
+
"""Read a config file from disk.
|
113
|
+
|
114
|
+
Args:
|
115
|
+
path (str): The path to the config file.
|
116
|
+
|
117
|
+
Returns:
|
118
|
+
Any: The contents of the config file as a Python object.
|
119
|
+
"""
|
120
|
+
_, ext = os.path.splitext(path)
|
121
|
+
if ext == ".json":
|
122
|
+
with open(
|
123
|
+
path,
|
124
|
+
) as f:
|
125
|
+
return json.load(f)
|
126
|
+
elif ext in [".yaml", ".yml"]:
|
127
|
+
with open(
|
128
|
+
path,
|
129
|
+
) as f:
|
130
|
+
return yaml.safe_load(f)
|
131
|
+
else:
|
132
|
+
raise LaunchError(f"Unsupported file extension: {ext}")
|
133
|
+
|
134
|
+
|
135
|
+
def _update_dict(target: Dict, source: Dict) -> None:
|
136
|
+
"""Update a dictionary with the contents of another dictionary.
|
137
|
+
|
138
|
+
Args:
|
139
|
+
target (Dict): The dictionary to update.
|
140
|
+
source (Dict): The dictionary to update from.
|
141
|
+
"""
|
142
|
+
for key, value in source.items():
|
143
|
+
if isinstance(value, dict):
|
144
|
+
if key not in target:
|
145
|
+
target[key] = {}
|
146
|
+
_update_dict(target[key], value)
|
147
|
+
else:
|
148
|
+
target[key] = value
|
@@ -0,0 +1,224 @@
|
|
1
|
+
"""The layer between launch sdk user code and the wandb internal process.
|
2
|
+
|
3
|
+
If there is an active run this communication is done through the wandb run's
|
4
|
+
backend interface.
|
5
|
+
|
6
|
+
If there is no active run, the messages are staged on the StagedLaunchInputs
|
7
|
+
singleton and sent when a run is created.
|
8
|
+
"""
|
9
|
+
|
10
|
+
import os
|
11
|
+
import pathlib
|
12
|
+
import shutil
|
13
|
+
import tempfile
|
14
|
+
from typing import List, Optional
|
15
|
+
|
16
|
+
import wandb
|
17
|
+
import wandb.data_types
|
18
|
+
from wandb.sdk.launch.errors import LaunchError
|
19
|
+
from wandb.sdk.wandb_run import Run
|
20
|
+
|
21
|
+
from .files import config_path_is_valid, override_file
|
22
|
+
|
23
|
+
PERIOD = "."
|
24
|
+
BACKSLASH = "\\"
|
25
|
+
LAUNCH_MANAGED_CONFIGS_DIR = "_wandb_configs"
|
26
|
+
|
27
|
+
|
28
|
+
class ConfigTmpDir:
    """Singleton for managing temporary directories for configuration files.

    Any configuration files designated as inputs to a launch job are copied to
    a temporary directory. This singleton manages the temporary directory and
    provides paths to the configuration files.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = object.__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ConfigTmpDir() call even though __new__
        # returns the same object, so only create the directories once.
        if not hasattr(self, "_tmp_dir"):
            self._tmp_dir = tempfile.mkdtemp()
            self._configs_dir = os.path.join(self._tmp_dir, LAUNCH_MANAGED_CONFIGS_DIR)
            os.mkdir(self._configs_dir)

    @property
    def tmp_dir(self) -> pathlib.Path:
        """Path of the temporary directory created for this process."""
        return pathlib.Path(self._tmp_dir)

    @property
    def configs_dir(self) -> pathlib.Path:
        """Path of the managed configs subdirectory inside `tmp_dir`."""
        return pathlib.Path(self._configs_dir)
|
56
|
+
|
57
|
+
|
58
|
+
class JobInputArguments:
    """Arguments for the publish_job_input of Interface."""

    def __init__(
        self,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        file_path: Optional[str] = None,
        run_config: Optional[bool] = None,
    ):
        # Dot-separated paths to include in / exclude from the job inputs.
        self.include = include
        self.exclude = exclude
        # Relative path of the config file; unset for run-config inputs.
        self.file_path = file_path
        # Whether this input refers to the run config rather than a file.
        self.run_config = run_config
|
72
|
+
|
73
|
+
|
74
|
+
class StagedLaunchInputs:
    """Singleton holding job inputs declared before a run exists.

    Inputs are appended with `add_staged_input` and published to a run with
    `apply`.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = object.__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every StagedLaunchInputs() call; keep the list
        # that was created the first time.
        if not hasattr(self, "_staged_inputs"):
            self._staged_inputs: List[JobInputArguments] = []

    def add_staged_input(
        self,
        input_arguments: JobInputArguments,
    ):
        """Stage the given input arguments for a later `apply`."""
        self._staged_inputs.append(input_arguments)

    def apply(self, run: "Run"):
        """Apply the staged inputs to the given run."""
        for staged_input in self._staged_inputs:
            _publish_job_input(staged_input, run)
|
96
|
+
|
97
|
+
|
98
|
+
def _publish_job_input(
    input: JobInputArguments,
    run: Run,
) -> None:
    """Publish a job input to the backend interface of the given run.

    When the input refers to a config file, the copy under the managed
    configs directory is first saved to the run with the temporary directory
    as base path.

    Arguments:
        input (JobInputArguments): The arguments for the job input.
        run (Run): The run to publish the job input to.
    """
    assert run._backend is not None
    assert run._backend.interface is not None
    assert input.run_config is not None

    interface = run._backend.interface
    if input.file_path:
        config_dir = ConfigTmpDir()
        dest = os.path.join(config_dir.configs_dir, input.file_path)
        run.save(dest, base_path=config_dir.tmp_dir)
    interface.publish_job_input(
        # Each dot-separated path is split into its list of keys.
        include_paths=[_split_on_unesc_dot(path) for path in input.include or []],
        exclude_paths=[_split_on_unesc_dot(path) for path in input.exclude or []],
        run_config=input.run_config,
        file_path=input.file_path or "",
    )
|
127
|
+
|
128
|
+
|
129
|
+
def handle_config_file_input(
    path: str,
    include: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
):
    """Declare an overridable configuration file for a launch job.

    The configuration file is copied to a temporary directory and the path to
    the copy is sent to the backend interface of the active run and used to
    configure the job builder.

    If there is no active run, the configuration file is staged and sent when a
    run is created.

    Args:
        path (str): Relative path to the configuration file.
        include (List[str]): Dot-separated paths to include in the job inputs.
        exclude (List[str]): Dot-separated paths to exclude from the job inputs.

    Raises:
        LaunchError: If `path` fails validation.
    """
    config_path_is_valid(path)
    # Apply any overrides from the environment before the file is copied.
    override_file(path)
    tmp_dir = ConfigTmpDir()
    dest = os.path.join(tmp_dir.configs_dir, path)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    shutil.copy(path, dest)
    arguments = JobInputArguments(
        include=include,
        exclude=exclude,
        file_path=path,
        run_config=False,
    )
    if wandb.run is not None:
        _publish_job_input(arguments, wandb.run)
    else:
        staged_inputs = StagedLaunchInputs()
        staged_inputs.add_staged_input(arguments)
|
165
|
+
|
166
|
+
|
167
|
+
def handle_run_config_input(
    include: Optional[List[str]] = None, exclude: Optional[List[str]] = None
):
    """Declare wandb.config as an overridable configuration for a launch job.

    The include and exclude paths are sent to the backend interface of the
    active run and used to configure the job builder.

    If there is no active run, the include and exclude paths are staged and sent
    when a run is created.

    Args:
        include (List[str]): Dot-separated paths to include in the job inputs.
        exclude (List[str]): Dot-separated paths to exclude from the job inputs.
    """
    arguments = JobInputArguments(
        include=include,
        exclude=exclude,
        run_config=True,
        file_path=None,
    )
    if wandb.run is not None:
        _publish_job_input(arguments, wandb.run)
    else:
        staged_inputs = StagedLaunchInputs()
        staged_inputs.add_staged_input(arguments)
|
189
|
+
|
190
|
+
|
191
|
+
def _split_on_unesc_dot(path: str) -> List[str]:
    r"""Split a string on unescaped dots.

    A dot preceded by a backslash (`\.`) is kept as a literal dot in the
    current part. A backslash that is not followed by a dot is kept as a
    literal backslash. A trailing empty part is dropped, so `a.` yields
    `["a"]`, while empty parts in the middle are preserved.

    Arguments:
        path (str): The string to split.

    Raises:
        LaunchError: If the path has a trailing escape character.

    Returns:
        List[str]: The split string.
    """
    parts = []
    part = ""
    i = 0
    while i < len(path):
        char = path[i]
        if char == BACKSLASH:
            if i == len(path) - 1:
                raise LaunchError(
                    f"Invalid config path {path}: trailing {BACKSLASH}.",
                )
            if path[i + 1] == PERIOD:
                part += PERIOD
                i += 2
            else:
                # Not an escape sequence: keep the backslash and always
                # advance, so the loop terminates and no input is lost.
                part += BACKSLASH
                i += 1
        elif char == PERIOD:
            parts.append(part)
            part = ""
            i += 1
        else:
            part += char
            i += 1
    if part:
        parts.append(part)
    return parts
|
@@ -0,0 +1,95 @@
|
|
1
|
+
"""Functions for declaring overridable configuration for launch jobs."""
|
2
|
+
|
3
|
+
from typing import List, Optional
|
4
|
+
|
5
|
+
|
6
|
+
def manage_config_file(
    path: str,
    include: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
):
    r"""Declare an overridable configuration file for a launch job.

    If a new job version is created from the active run, the configuration file
    will be added to the job's inputs. If the job is launched and overrides
    have been provided for the configuration file, this function will detect
    the overrides from the environment and update the configuration file on disk.
    Note that these overrides will only be applied in ephemeral containers.
    If there is no active run, the declaration is staged and sent when a run
    is created.

    `include` and `exclude` are lists of dot separated paths within the config.
    The paths are used to filter subtrees of the configuration file out of the
    job's inputs.

    For example, given the following configuration file:
    ```yaml
    model:
        name: resnet
        layers: 18
    training:
        epochs: 10
        batch_size: 32
    ```

    Passing `include=['model']` will only include the `model` subtree in the
    job's inputs. Passing `exclude=['model.layers']` will exclude the `layers`
    key from the `model` subtree. Note that `exclude` takes precedence over
    `include`.

    `.` is used as a separator for nested keys. If a key contains a `.`, it
    should be escaped with a backslash, e.g. `include=[r'model\.layers']`. Note
    the use of `r` to denote a raw string when using escape chars.

    Args:
        path (str): The path to the configuration file. This path must be
            relative and must not contain backwards traversal, i.e. `..`.
        include (List[str]): A list of keys to include in the configuration file.
        exclude (List[str]): A list of keys to exclude from the configuration file.

    Raises:
        LaunchError: If the path is not valid.
    """
    # Imported lazily, inside the function, as in the original module.
    from .internal import handle_config_file_input

    return handle_config_file_input(path, include, exclude)
|
53
|
+
|
54
|
+
|
55
|
+
def manage_wandb_config(
    include: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
):
    r"""Declare wandb.config as an overridable configuration for a launch job.

    If a new job version is created from the active run, the run config
    (wandb.config) will become an overridable input of the job. If the job is
    launched and overrides have been provided for the run config, the overrides
    will be applied to the run config when `wandb.init` is called.
    If there is no active run, the declaration is staged and sent when a run
    is created.

    `include` and `exclude` are lists of dot separated paths within the config.
    The paths are used to filter subtrees of the run config out of the
    job's inputs.

    For example, given the following run config contents:
    ```yaml
    model:
        name: resnet
        layers: 18
    training:
        epochs: 10
        batch_size: 32
    ```
    Passing `include=['model']` will only include the `model` subtree in the
    job's inputs. Passing `exclude=['model.layers']` will exclude the `layers`
    key from the `model` subtree. Note that `exclude` takes precedence over
    `include`.
    `.` is used as a separator for nested keys. If a key contains a `.`, it
    should be escaped with a backslash, e.g. `include=[r'model\.layers']`. Note
    the use of `r` to denote a raw string when using escape chars.

    Args:
        include (List[str]): A list of subtrees to include in the configuration.
        exclude (List[str]): A list of subtrees to exclude from the configuration.

    Raises:
        LaunchError: If an include or exclude path ends with a trailing
            backslash; this is raised when the input is published to a run.
    """
    # Imported lazily, inside the function, as in the original module.
    from .internal import handle_run_config_input

    handle_run_config_input(include, exclude)
|
@@ -40,9 +40,9 @@ State = Literal[
|
|
40
40
|
|
41
41
|
|
42
42
|
class Status:
|
43
|
-
def __init__(self, state: "State" = "unknown",
|
43
|
+
def __init__(self, state: "State" = "unknown", messages: List[str] = None): # type: ignore
|
44
44
|
self.state = state
|
45
|
-
self.
|
45
|
+
self.messages = messages or []
|
46
46
|
|
47
47
|
def __repr__(self) -> "State":
|
48
48
|
return self.state
|
@@ -14,6 +14,7 @@ from kubernetes_asyncio.client import ( # type: ignore # noqa: F401
|
|
14
14
|
BatchV1Api,
|
15
15
|
CoreV1Api,
|
16
16
|
CustomObjectsApi,
|
17
|
+
V1Pod,
|
17
18
|
V1PodStatus,
|
18
19
|
)
|
19
20
|
|
@@ -118,6 +119,27 @@ def _is_container_creating(status: "V1PodStatus") -> bool:
|
|
118
119
|
return False
|
119
120
|
|
120
121
|
|
122
|
+
def _is_pod_unschedulable(status: "V1PodStatus") -> Tuple[bool, str]:
|
123
|
+
"""Return whether the pod is unschedulable along with the reason message."""
|
124
|
+
if not status.conditions:
|
125
|
+
return False, ""
|
126
|
+
for condition in status.conditions:
|
127
|
+
if (
|
128
|
+
condition.type == "PodScheduled"
|
129
|
+
and condition.status == "False"
|
130
|
+
and condition.reason == "Unschedulable"
|
131
|
+
):
|
132
|
+
return True, condition.message
|
133
|
+
return False, ""
|
134
|
+
|
135
|
+
|
136
|
+
def _get_crd_job_name(object: "V1Pod") -> Optional[str]:
|
137
|
+
refs = object.metadata.owner_references
|
138
|
+
if refs:
|
139
|
+
return refs[0].name
|
140
|
+
return None
|
141
|
+
|
142
|
+
|
121
143
|
def _state_from_conditions(conditions: List[Dict[str, Any]]) -> Optional[State]:
|
122
144
|
"""Get the status from the pod conditions."""
|
123
145
|
true_conditions = [
|
@@ -298,10 +320,18 @@ class LaunchKubernetesMonitor:
|
|
298
320
|
counts[state] += 1
|
299
321
|
return counts
|
300
322
|
|
301
|
-
def
|
323
|
+
def _set_status_state(self, job_name: str, state: State) -> None:
|
302
324
|
"""Set the status of the run."""
|
303
|
-
if self._job_states
|
304
|
-
self._job_states[job_name] =
|
325
|
+
if job_name not in self._job_states:
|
326
|
+
self._job_states[job_name] = Status(state)
|
327
|
+
elif self._job_states[job_name].state != state:
|
328
|
+
self._job_states[job_name].state = state
|
329
|
+
|
330
|
+
def _add_status_message(self, job_name: str, message: str) -> None:
    """Record a Kubernetes message for a job and print it as a warning.

    A job that has no tracked status yet is first registered with the
    "unknown" state so the message has somewhere to be stored.

    Args:
        job_name: The name of the job the message belongs to.
        message: The message text to record.
    """
    if job_name not in self._job_states:
        self._job_states[job_name] = Status("unknown")
    wandb.termwarn(f"Warning from Kubernetes for job {job_name}: {message}")
    self._job_states[job_name].messages.append(message)
|
305
335
|
|
306
336
|
async def _monitor_pods(self, namespace: str) -> None:
|
307
337
|
"""Monitor a namespace for changes."""
|
@@ -312,15 +342,19 @@ class LaunchKubernetesMonitor:
|
|
312
342
|
label_selector=self._label_selector,
|
313
343
|
):
|
314
344
|
obj = event.get("object")
|
315
|
-
job_name = obj.metadata.labels.get("job-name")
|
345
|
+
job_name = obj.metadata.labels.get("job-name") or _get_crd_job_name(obj)
|
316
346
|
if job_name is None or not hasattr(obj, "status"):
|
317
347
|
continue
|
318
348
|
if self.__get_status(job_name) in ["finished", "failed"]:
|
319
349
|
continue
|
350
|
+
|
351
|
+
is_unschedulable, reason = _is_pod_unschedulable(obj.status)
|
352
|
+
if is_unschedulable:
|
353
|
+
self._add_status_message(job_name, reason)
|
320
354
|
if obj.status.phase == "Running" or _is_container_creating(obj.status):
|
321
|
-
self.
|
355
|
+
self._set_status_state(job_name, "running")
|
322
356
|
elif _is_preempted(obj.status):
|
323
|
-
self.
|
357
|
+
self._set_status_state(job_name, "preempted")
|
324
358
|
|
325
359
|
async def _monitor_jobs(self, namespace: str) -> None:
|
326
360
|
"""Monitor a namespace for changes."""
|
@@ -334,15 +368,15 @@ class LaunchKubernetesMonitor:
|
|
334
368
|
job_name = obj.metadata.name
|
335
369
|
|
336
370
|
if obj.status.succeeded == 1:
|
337
|
-
self.
|
371
|
+
self._set_status_state(job_name, "finished")
|
338
372
|
elif obj.status.failed is not None and obj.status.failed >= 1:
|
339
|
-
self.
|
373
|
+
self._set_status_state(job_name, "failed")
|
340
374
|
|
341
375
|
# If the job is deleted and we haven't seen a terminal state
|
342
376
|
# then we will consider the job failed.
|
343
377
|
if event.get("type") == "DELETED":
|
344
378
|
if self._job_states.get(job_name) != Status("finished"):
|
345
|
-
self.
|
379
|
+
self._set_status_state(job_name, "failed")
|
346
380
|
|
347
381
|
async def _monitor_crd(
|
348
382
|
self, namespace: str, custom_resource: CustomResource
|
@@ -355,7 +389,7 @@ class LaunchKubernetesMonitor:
|
|
355
389
|
plural=custom_resource.plural,
|
356
390
|
group=custom_resource.group,
|
357
391
|
version=custom_resource.version,
|
358
|
-
label_selector=self._label_selector,
|
392
|
+
label_selector=self._label_selector,
|
359
393
|
):
|
360
394
|
object = event.get("object")
|
361
395
|
name = object.get("metadata", dict()).get("name")
|
@@ -383,8 +417,7 @@ class LaunchKubernetesMonitor:
|
|
383
417
|
)
|
384
418
|
if state is None:
|
385
419
|
continue
|
386
|
-
|
387
|
-
self._set_status(name, status)
|
420
|
+
self._set_status_state(name, state)
|
388
421
|
|
389
422
|
|
390
423
|
class SafeWatch:
|
@@ -29,7 +29,6 @@ from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
|
|
29
29
|
from wandb.util import get_module
|
30
30
|
|
31
31
|
from .._project_spec import EntryPoint, LaunchProject
|
32
|
-
from ..builder.build import get_env_vars_dict
|
33
32
|
from ..errors import LaunchError
|
34
33
|
from ..utils import (
|
35
34
|
LOG_PREFIX,
|
@@ -374,8 +373,7 @@ class KubernetesRunner(AbstractRunner):
|
|
374
373
|
}
|
375
374
|
|
376
375
|
entry_point = (
|
377
|
-
launch_project.override_entrypoint
|
378
|
-
or launch_project.get_single_entry_point()
|
376
|
+
launch_project.override_entrypoint or launch_project.get_job_entry_point()
|
379
377
|
)
|
380
378
|
if launch_project.docker_image:
|
381
379
|
# dont specify run id if user provided image, could have multiple runs
|
@@ -401,8 +399,8 @@ class KubernetesRunner(AbstractRunner):
|
|
401
399
|
launch_project.override_entrypoint is not None,
|
402
400
|
)
|
403
401
|
|
404
|
-
env_vars = get_env_vars_dict(
|
405
|
-
|
402
|
+
env_vars = launch_project.get_env_vars_dict(
|
403
|
+
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
406
404
|
)
|
407
405
|
api_key_secret = None
|
408
406
|
for cont in containers:
|
@@ -511,8 +509,8 @@ class KubernetesRunner(AbstractRunner):
|
|
511
509
|
api_version = resource_args.get("apiVersion", "batch/v1")
|
512
510
|
|
513
511
|
if api_version not in ["batch/v1", "batch/v1beta1"]:
|
514
|
-
env_vars = get_env_vars_dict(
|
515
|
-
|
512
|
+
env_vars = launch_project.get_env_vars_dict(
|
513
|
+
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
516
514
|
)
|
517
515
|
# Crawl the resource args and add our env vars to the containers.
|
518
516
|
add_wandb_env(resource_args, env_vars)
|
@@ -537,7 +535,7 @@ class KubernetesRunner(AbstractRunner):
|
|
537
535
|
if LaunchAgent.initialized():
|
538
536
|
add_label_to_pods(
|
539
537
|
resource_args,
|
540
|
-
|
538
|
+
WANDB_K8S_LABEL_AGENT,
|
541
539
|
LaunchAgent.name(),
|
542
540
|
)
|
543
541
|
resource_args["metadata"]["labels"][WANDB_K8S_LABEL_AGENT] = (
|