wandb 0.17.0rc2__py3-none-any.whl → 0.17.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -2
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/wandb.py +12 -7
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +213 -79
- wandb/apis/public/artifacts.py +335 -100
- wandb/apis/public/files.py +9 -9
- wandb/apis/public/jobs.py +16 -4
- wandb/apis/public/projects.py +26 -28
- wandb/apis/public/query_generator.py +1 -1
- wandb/apis/public/runs.py +163 -65
- wandb/apis/public/sweeps.py +2 -2
- wandb/apis/reports/__init__.py +1 -7
- wandb/apis/reports/v1/__init__.py +5 -27
- wandb/apis/reports/v2/__init__.py +7 -19
- wandb/apis/workspaces/__init__.py +8 -0
- wandb/beta/workflows.py +8 -3
- wandb/cli/cli.py +131 -59
- wandb/docker/__init__.py +1 -1
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +5 -107
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/gym/__init__.py +35 -15
- wandb/integration/openai/fine_tuning.py +21 -3
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/jupyter.py +16 -17
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +54 -54
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +54 -54
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_base_pb2.py +30 -0
- wandb/proto/v5/wandb_internal_pb2.py +355 -0
- wandb/proto/v5/wandb_server_pb2.py +63 -0
- wandb/proto/v5/wandb_settings_pb2.py +45 -0
- wandb/proto/v5/wandb_telemetry_pb2.py +41 -0
- wandb/proto/wandb_base_pb2.py +2 -0
- wandb/proto/wandb_deprecated.py +9 -1
- wandb/proto/wandb_generate_deprecated.py +34 -0
- wandb/proto/{wandb_internal_codegen.py → wandb_generate_proto.py} +1 -35
- wandb/proto/wandb_internal_pb2.py +2 -0
- wandb/proto/wandb_server_pb2.py +2 -0
- wandb/proto/wandb_settings_pb2.py +2 -0
- wandb/proto/wandb_telemetry_pb2.py +2 -0
- wandb/sdk/artifacts/artifact.py +68 -22
- wandb/sdk/artifacts/artifact_manifest.py +1 -1
- wandb/sdk/artifacts/artifact_manifest_entry.py +6 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -1
- wandb/sdk/artifacts/artifact_saver.py +1 -10
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +6 -2
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +2 -42
- wandb/sdk/artifacts/storage_policy.py +1 -12
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +4 -2
- wandb/sdk/interface/interface.py +13 -0
- wandb/sdk/interface/interface_shared.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +6 -19
- wandb/sdk/internal/internal_api.py +148 -136
- wandb/sdk/internal/job_builder.py +207 -135
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/sender.py +102 -39
- wandb/sdk/internal/settings_static.py +8 -1
- wandb/sdk/internal/system/assets/trainium.py +3 -3
- wandb/sdk/internal/system/system_info.py +4 -2
- wandb/sdk/internal/update.py +1 -1
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +4 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +184 -224
- wandb/sdk/launch/agent/agent.py +58 -18
- wandb/sdk/launch/agent/config.py +0 -3
- wandb/sdk/launch/builder/abstract.py +67 -0
- wandb/sdk/launch/builder/build.py +165 -576
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +7 -23
- wandb/sdk/launch/builder/kaniko_builder.py +10 -23
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +51 -45
- wandb/sdk/launch/environment/aws_environment.py +26 -1
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +224 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/runner/abstract.py +2 -2
- wandb/sdk/launch/runner/kubernetes_monitor.py +45 -12
- wandb/sdk/launch/runner/kubernetes_runner.py +6 -8
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +20 -14
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +2 -0
- wandb/sdk/launch/sweeps/utils.py +2 -2
- wandb/sdk/launch/utils.py +16 -138
- wandb/sdk/lib/_settings_toposort_generated.py +2 -5
- wandb/sdk/lib/apikey.py +4 -2
- wandb/sdk/lib/config_util.py +3 -3
- wandb/sdk/lib/proto_util.py +22 -1
- wandb/sdk/lib/redirect.py +1 -1
- wandb/sdk/service/service.py +2 -1
- wandb/sdk/service/streams.py +5 -5
- wandb/sdk/wandb_init.py +25 -59
- wandb/sdk/wandb_login.py +28 -25
- wandb/sdk/wandb_run.py +112 -45
- wandb/sdk/wandb_settings.py +33 -64
- wandb/sdk/wandb_watch.py +1 -1
- wandb/sklearn/plot/classifier.py +4 -6
- wandb/sync/sync.py +2 -2
- wandb/testing/relay.py +32 -17
- wandb/util.py +36 -37
- wandb/wandb_agent.py +3 -3
- wandb/wandb_controller.py +3 -2
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/METADATA +7 -9
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/RECORD +124 -146
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/WHEEL +1 -1
- wandb/apis/reports/v1/_blocks.py +0 -1406
- wandb/apis/reports/v1/_helpers.py +0 -70
- wandb/apis/reports/v1/_panels.py +0 -1282
- wandb/apis/reports/v1/_templates.py +0 -478
- wandb/apis/reports/v1/blocks.py +0 -27
- wandb/apis/reports/v1/helpers.py +0 -2
- wandb/apis/reports/v1/mutations.py +0 -66
- wandb/apis/reports/v1/panels.py +0 -17
- wandb/apis/reports/v1/report.py +0 -268
- wandb/apis/reports/v1/runset.py +0 -144
- wandb/apis/reports/v1/templates.py +0 -7
- wandb/apis/reports/v1/util.py +0 -406
- wandb/apis/reports/v1/validators.py +0 -131
- wandb/apis/reports/v2/blocks.py +0 -25
- wandb/apis/reports/v2/expr_parsing.py +0 -257
- wandb/apis/reports/v2/gql.py +0 -68
- wandb/apis/reports/v2/interface.py +0 -1911
- wandb/apis/reports/v2/internal.py +0 -867
- wandb/apis/reports/v2/metrics.py +0 -6
- wandb/apis/reports/v2/panels.py +0 -15
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -19
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/licenses/LICENSE +0 -0
@@ -12,7 +12,6 @@ from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
|
12
12
|
from wandb.sdk.launch.registry.abstract import AbstractRegistry
|
13
13
|
|
14
14
|
from .._project_spec import LaunchProject
|
15
|
-
from ..builder.build import get_env_vars_dict
|
16
15
|
from ..errors import LaunchError
|
17
16
|
from ..utils import (
|
18
17
|
LOG_PREFIX,
|
@@ -133,8 +132,8 @@ class LocalContainerRunner(AbstractRunner):
|
|
133
132
|
docker_args = self._populate_docker_args(launch_project, image_uri)
|
134
133
|
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
135
134
|
|
136
|
-
env_vars = get_env_vars_dict(
|
137
|
-
|
135
|
+
env_vars = launch_project.get_env_vars_dict(
|
136
|
+
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
138
137
|
)
|
139
138
|
|
140
139
|
# When running against local port, need to swap to local docker host
|
@@ -4,16 +4,12 @@ from typing import Any, List, Optional
|
|
4
4
|
|
5
5
|
import wandb
|
6
6
|
|
7
|
-
from .._project_spec import LaunchProject
|
8
|
-
from ..builder.build import get_env_vars_dict
|
7
|
+
from .._project_spec import LaunchProject
|
9
8
|
from ..errors import LaunchError
|
10
9
|
from ..utils import (
|
11
10
|
LOG_PREFIX,
|
12
11
|
MAX_ENV_LENGTHS,
|
13
12
|
PROJECT_SYNCHRONOUS,
|
14
|
-
_is_wandb_uri,
|
15
|
-
download_wandb_python_deps,
|
16
|
-
parse_wandb_uri,
|
17
13
|
sanitize_wandb_api_key,
|
18
14
|
validate_wandb_python_deps,
|
19
15
|
)
|
@@ -47,8 +43,7 @@ class LocalProcessRunner(AbstractRunner):
|
|
47
43
|
|
48
44
|
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
49
45
|
entry_point = (
|
50
|
-
launch_project.override_entrypoint
|
51
|
-
or launch_project.get_single_entry_point()
|
46
|
+
launch_project.override_entrypoint or launch_project.get_job_entry_point()
|
52
47
|
)
|
53
48
|
|
54
49
|
cmd: List[Any] = []
|
@@ -56,23 +51,7 @@ class LocalProcessRunner(AbstractRunner):
|
|
56
51
|
if launch_project.project_dir is None:
|
57
52
|
raise LaunchError("Launch LocalProcessRunner received empty project dir")
|
58
53
|
|
59
|
-
|
60
|
-
if launch_project.uri and _is_wandb_uri(launch_project.uri):
|
61
|
-
source_entity, source_project, run_name = parse_wandb_uri(
|
62
|
-
launch_project.uri
|
63
|
-
)
|
64
|
-
run_requirements_file = download_wandb_python_deps(
|
65
|
-
source_entity,
|
66
|
-
source_project,
|
67
|
-
run_name,
|
68
|
-
self._api,
|
69
|
-
launch_project.project_dir,
|
70
|
-
)
|
71
|
-
validate_wandb_python_deps(
|
72
|
-
run_requirements_file,
|
73
|
-
launch_project.project_dir,
|
74
|
-
)
|
75
|
-
elif launch_project.job:
|
54
|
+
if launch_project.job:
|
76
55
|
assert launch_project._job_artifact is not None
|
77
56
|
try:
|
78
57
|
validate_wandb_python_deps(
|
@@ -81,14 +60,14 @@ class LocalProcessRunner(AbstractRunner):
|
|
81
60
|
)
|
82
61
|
except Exception:
|
83
62
|
wandb.termwarn("Unable to validate python dependencies")
|
84
|
-
env_vars = get_env_vars_dict(
|
85
|
-
|
63
|
+
env_vars = launch_project.get_env_vars_dict(
|
64
|
+
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
86
65
|
)
|
87
66
|
for env_key, env_value in env_vars.items():
|
88
67
|
cmd += [f"{shlex.quote(env_key)}={shlex.quote(env_value)}"]
|
89
|
-
|
90
|
-
|
91
|
-
cmd +=
|
68
|
+
if entry_point is not None:
|
69
|
+
cmd += entry_point.command
|
70
|
+
cmd += launch_project.override_args
|
92
71
|
|
93
72
|
command_str = " ".join(cmd).strip()
|
94
73
|
_msg = f"{LOG_PREFIX}Launching run as a local-process with command {sanitize_wandb_api_key(command_str)}"
|
@@ -12,8 +12,7 @@ from wandb.apis.internal import Api
|
|
12
12
|
from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
|
13
13
|
from wandb.sdk.launch.errors import LaunchError
|
14
14
|
|
15
|
-
from .._project_spec import EntryPoint, LaunchProject
|
16
|
-
from ..builder.build import get_env_vars_dict
|
15
|
+
from .._project_spec import EntryPoint, LaunchProject
|
17
16
|
from ..registry.abstract import AbstractRegistry
|
18
17
|
from ..utils import (
|
19
18
|
LOG_PREFIX,
|
@@ -68,6 +67,7 @@ class SagemakerSubmittedRun(AbstractRun):
|
|
68
67
|
logGroupName="/aws/sagemaker/TrainingJobs",
|
69
68
|
logStreamName=log_name,
|
70
69
|
)
|
70
|
+
assert "events" in res
|
71
71
|
return "\n".join(
|
72
72
|
[f'{event["timestamp"]}:{event["message"]}' for event in res["events"]]
|
73
73
|
)
|
@@ -179,7 +179,10 @@ class SageMakerRunner(AbstractRunner):
|
|
179
179
|
caller_id = client.get_caller_identity()
|
180
180
|
account_id = caller_id["Account"]
|
181
181
|
_logger.info(f"Using account ID {account_id}")
|
182
|
-
|
182
|
+
partition = await self.environment.get_partition()
|
183
|
+
role_arn = get_role_arn(
|
184
|
+
given_sagemaker_args, self.backend_config, account_id, partition
|
185
|
+
)
|
183
186
|
|
184
187
|
# Create a sagemaker client to launch the job.
|
185
188
|
sagemaker_client = session.client("sagemaker")
|
@@ -221,12 +224,12 @@ class SageMakerRunner(AbstractRunner):
|
|
221
224
|
launch_project.fill_macros(image_uri)
|
222
225
|
_logger.info("Connecting to sagemaker client")
|
223
226
|
entry_point = (
|
224
|
-
launch_project.override_entrypoint
|
225
|
-
or launch_project.get_single_entry_point()
|
226
|
-
)
|
227
|
-
command_args = get_entry_point_command(
|
228
|
-
entry_point, launch_project.override_args
|
227
|
+
launch_project.override_entrypoint or launch_project.get_job_entry_point()
|
229
228
|
)
|
229
|
+
command_args = []
|
230
|
+
if entry_point is not None:
|
231
|
+
command_args += entry_point.command
|
232
|
+
command_args += launch_project.override_args
|
230
233
|
if command_args:
|
231
234
|
command_str = " ".join(command_args)
|
232
235
|
wandb.termlog(
|
@@ -349,18 +352,18 @@ def build_sagemaker_args(
|
|
349
352
|
|
350
353
|
if sagemaker_args.get("ResourceConfig") is None:
|
351
354
|
raise LaunchError(
|
352
|
-
"Sagemaker launcher requires a ResourceConfig
|
355
|
+
"Sagemaker launcher requires a ResourceConfig resource argument"
|
353
356
|
)
|
354
357
|
|
355
358
|
if sagemaker_args.get("StoppingCondition") is None:
|
356
359
|
raise LaunchError(
|
357
|
-
"Sagemaker launcher requires a StoppingCondition
|
360
|
+
"Sagemaker launcher requires a StoppingCondition resource argument"
|
358
361
|
)
|
359
362
|
|
360
363
|
given_env = given_sagemaker_args.get(
|
361
364
|
"Environment", sagemaker_args.get("environment", {})
|
362
365
|
)
|
363
|
-
calced_env = get_env_vars_dict(
|
366
|
+
calced_env = launch_project.get_env_vars_dict(api, max_env_length)
|
364
367
|
total_env = {**calced_env, **given_env}
|
365
368
|
sagemaker_args["Environment"] = total_env
|
366
369
|
|
@@ -405,7 +408,10 @@ async def launch_sagemaker_job(
|
|
405
408
|
|
406
409
|
|
407
410
|
def get_role_arn(
|
408
|
-
sagemaker_args: Dict[str, Any],
|
411
|
+
sagemaker_args: Dict[str, Any],
|
412
|
+
backend_config: Dict[str, Any],
|
413
|
+
account_id: str,
|
414
|
+
partition: str,
|
409
415
|
) -> str:
|
410
416
|
"""Get the role arn from the sagemaker args or the backend config."""
|
411
417
|
role_arn = sagemaker_args.get("RoleArn") or sagemaker_args.get("role_arn")
|
@@ -416,7 +422,7 @@ def get_role_arn(
|
|
416
422
|
"AWS sagemaker require a string RoleArn set this by adding a `RoleArn` key to the sagemaker"
|
417
423
|
"field of resource_args"
|
418
424
|
)
|
419
|
-
if role_arn.startswith("arn:
|
425
|
+
if role_arn.startswith(f"arn:{partition}:iam::"):
|
420
426
|
return role_arn # type: ignore
|
421
427
|
|
422
|
-
return f"arn:
|
428
|
+
return f"arn:{partition}:iam::{account_id}:role/{role_arn}"
|
@@ -8,8 +8,7 @@ if False:
|
|
8
8
|
from wandb.apis.internal import Api
|
9
9
|
from wandb.util import get_module
|
10
10
|
|
11
|
-
from .._project_spec import LaunchProject
|
12
|
-
from ..builder.build import get_env_vars_dict
|
11
|
+
from .._project_spec import LaunchProject
|
13
12
|
from ..environment.gcp_environment import GcpEnvironment
|
14
13
|
from ..errors import LaunchError
|
15
14
|
from ..registry.abstract import AbstractRegistry
|
@@ -113,14 +112,16 @@ class VertexRunner(AbstractRunner):
|
|
113
112
|
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
114
113
|
|
115
114
|
entry_point = (
|
116
|
-
launch_project.override_entrypoint
|
117
|
-
or launch_project.get_single_entry_point()
|
115
|
+
launch_project.override_entrypoint or launch_project.get_job_entry_point()
|
118
116
|
)
|
119
117
|
|
120
118
|
# TODO: Set entrypoint in each container
|
121
|
-
entry_cmd =
|
122
|
-
|
123
|
-
|
119
|
+
entry_cmd = []
|
120
|
+
if entry_point is not None:
|
121
|
+
entry_cmd += entry_point.command
|
122
|
+
entry_cmd += launch_project.override_args
|
123
|
+
|
124
|
+
env_vars = launch_project.get_env_vars_dict(
|
124
125
|
api=self._api,
|
125
126
|
max_env_length=MAX_ENV_LENGTHS[self.__class__.__name__],
|
126
127
|
)
|
@@ -668,6 +668,8 @@ class Scheduler(ABC):
|
|
668
668
|
launch_config = copy.deepcopy(self._wandb_run.config.get("launch", {}))
|
669
669
|
if "overrides" not in launch_config:
|
670
670
|
launch_config["overrides"] = {"run_config": {}}
|
671
|
+
if "run_config" not in launch_config["overrides"]:
|
672
|
+
launch_config["overrides"]["run_config"] = {}
|
671
673
|
launch_config["overrides"]["run_config"].update(args["args_dict"])
|
672
674
|
|
673
675
|
if macro_args: # pipe in hyperparam args as params to launch
|
wandb/sdk/launch/sweeps/utils.py
CHANGED
@@ -211,7 +211,7 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
|
|
211
211
|
|
212
212
|
"""
|
213
213
|
if "args" not in command:
|
214
|
-
raise ValueError('No "args" found in command:
|
214
|
+
raise ValueError('No "args" found in command: {}'.format(command))
|
215
215
|
# four different formats of command args
|
216
216
|
# (1) standard command line flags (e.g. --foo=bar)
|
217
217
|
flags: List[str] = []
|
@@ -228,7 +228,7 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
|
|
228
228
|
try:
|
229
229
|
_value: Any = config["value"]
|
230
230
|
except KeyError:
|
231
|
-
raise ValueError('No "value" found for command["args"]["
|
231
|
+
raise ValueError('No "value" found for command["args"]["{}"]'.format(param))
|
232
232
|
|
233
233
|
_flag: str = f"{param}={_value}"
|
234
234
|
flags.append("--" + _flag)
|
wandb/sdk/launch/utils.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1
|
-
# heavily inspired by https://github.com/mlflow/mlflow/blob/master/mlflow/projects/utils.py
|
2
1
|
import asyncio
|
3
2
|
import json
|
4
3
|
import logging
|
@@ -16,7 +15,6 @@ import wandb
|
|
16
15
|
import wandb.docker as docker
|
17
16
|
from wandb import util
|
18
17
|
from wandb.apis.internal import Api
|
19
|
-
from wandb.errors import CommError
|
20
18
|
from wandb.sdk.launch.errors import LaunchError
|
21
19
|
from wandb.sdk.launch.git_reference import GitReference
|
22
20
|
from wandb.sdk.launch.wandb_reference import WandbReference
|
@@ -32,12 +30,13 @@ FAILED_PACKAGES_REGEX = re.compile(
|
|
32
30
|
)
|
33
31
|
|
34
32
|
if TYPE_CHECKING: # pragma: no cover
|
35
|
-
from wandb.sdk.artifacts.artifact import Artifact
|
36
33
|
from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
|
37
34
|
|
38
35
|
|
39
36
|
# TODO: this should be restricted to just Git repos and not S3 and stuff like that
|
40
|
-
_GIT_URI_REGEX = re.compile(
|
37
|
+
_GIT_URI_REGEX = re.compile(
|
38
|
+
r"^[^/|^~|^\.].*(git|bitbucket|dev\.azure\.com|\.visualstudio\.com)"
|
39
|
+
)
|
41
40
|
_VALID_IP_REGEX = r"^https?://[0-9]+(?:\.[0-9]+){3}(:[0-9]+)?"
|
42
41
|
_VALID_PIP_PACKAGE_REGEX = r"^[a-zA-Z0-9_.-]+$"
|
43
42
|
_VALID_WANDB_REGEX = r"^https?://(api.)?wandb"
|
@@ -75,6 +74,7 @@ AZURE_BLOB_REGEX = re.compile(
|
|
75
74
|
r"^https://([^\.]+)\.blob\.core\.windows\.net/([^/]+)/?(.*)$"
|
76
75
|
)
|
77
76
|
|
77
|
+
ARN_PARTITION_RE = re.compile(r"^arn:([^:]+):[^:]*:[^:]*:[^:]*:[^:]*$")
|
78
78
|
|
79
79
|
PROJECT_SYNCHRONOUS = "SYNCHRONOUS"
|
80
80
|
|
@@ -316,16 +316,13 @@ def construct_launch_spec(
|
|
316
316
|
|
317
317
|
|
318
318
|
def validate_launch_spec_source(launch_spec: Dict[str, Any]) -> None:
|
319
|
-
uri = launch_spec.get("uri")
|
320
319
|
job = launch_spec.get("job")
|
321
320
|
docker_image = launch_spec.get("docker", {}).get("docker_image")
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
elif sum(map(bool, [uri, job, docker_image])) > 1:
|
328
|
-
raise LaunchError("Must specify exactly one of uri, job or image")
|
321
|
+
if bool(job) == bool(docker_image):
|
322
|
+
raise LaunchError(
|
323
|
+
"Exactly one of job or docker_image must be specified in the launch "
|
324
|
+
"spec."
|
325
|
+
)
|
329
326
|
|
330
327
|
|
331
328
|
def parse_wandb_uri(uri: str) -> Tuple[str, str, str]:
|
@@ -336,77 +333,6 @@ def parse_wandb_uri(uri: str) -> Tuple[str, str, str]:
|
|
336
333
|
return (ref.entity, ref.project, ref.run_id)
|
337
334
|
|
338
335
|
|
339
|
-
def is_bare_wandb_uri(uri: str) -> bool:
|
340
|
-
"""Check that a wandb uri is valid.
|
341
|
-
|
342
|
-
URI must be in the format
|
343
|
-
`/<entity>/<project>/runs/<run_name>[other stuff]`
|
344
|
-
or
|
345
|
-
`/<entity>/<project>/artifacts/job/<job_name>[other stuff]`.
|
346
|
-
"""
|
347
|
-
_logger.info(f"Checking if uri {uri} is bare...")
|
348
|
-
return uri.startswith("/") and WandbReference.is_uri_job_or_run(uri)
|
349
|
-
|
350
|
-
|
351
|
-
def fetch_wandb_project_run_info(
|
352
|
-
entity: str, project: str, run_name: str, api: Api
|
353
|
-
) -> Any:
|
354
|
-
_logger.info("Fetching run info...")
|
355
|
-
try:
|
356
|
-
result = api.get_run_info(entity, project, run_name)
|
357
|
-
except CommError:
|
358
|
-
result = None
|
359
|
-
if result is None:
|
360
|
-
raise LaunchError(
|
361
|
-
f"Run info is invalid or doesn't exist for {api.settings('base_url')}/{entity}/{project}/runs/{run_name}"
|
362
|
-
)
|
363
|
-
if result.get("codePath") is None:
|
364
|
-
# TODO: we don't currently expose codePath in the runInfo endpoint, this downloads
|
365
|
-
# it from wandb-metadata.json if we can.
|
366
|
-
metadata = api.download_url(
|
367
|
-
project, "wandb-metadata.json", run=run_name, entity=entity
|
368
|
-
)
|
369
|
-
if metadata is not None:
|
370
|
-
_, response = api.download_file(metadata["url"])
|
371
|
-
data = response.json()
|
372
|
-
result["codePath"] = data.get("codePath")
|
373
|
-
result["cudaVersion"] = data.get("cuda", None)
|
374
|
-
|
375
|
-
return result
|
376
|
-
|
377
|
-
|
378
|
-
def download_entry_point(
|
379
|
-
entity: str, project: str, run_name: str, api: Api, entry_point: str, dir: str
|
380
|
-
) -> bool:
|
381
|
-
metadata = api.download_url(
|
382
|
-
project, f"code/{entry_point}", run=run_name, entity=entity
|
383
|
-
)
|
384
|
-
if metadata is not None:
|
385
|
-
_, response = api.download_file(metadata["url"])
|
386
|
-
with util.fsync_open(os.path.join(dir, entry_point), "wb") as file:
|
387
|
-
for data in response.iter_content(chunk_size=1024):
|
388
|
-
file.write(data)
|
389
|
-
return True
|
390
|
-
return False
|
391
|
-
|
392
|
-
|
393
|
-
def download_wandb_python_deps(
|
394
|
-
entity: str, project: str, run_name: str, api: Api, dir: str
|
395
|
-
) -> Optional[str]:
|
396
|
-
reqs = api.download_url(project, "requirements.txt", run=run_name, entity=entity)
|
397
|
-
if reqs is not None:
|
398
|
-
_logger.info("Downloading python dependencies")
|
399
|
-
_, response = api.download_file(reqs["url"])
|
400
|
-
|
401
|
-
with util.fsync_open(
|
402
|
-
os.path.join(dir, "requirements.frozen.txt"), "wb"
|
403
|
-
) as file:
|
404
|
-
for data in response.iter_content(chunk_size=1024):
|
405
|
-
file.write(data)
|
406
|
-
return "requirements.frozen.txt"
|
407
|
-
return None
|
408
|
-
|
409
|
-
|
410
336
|
def get_local_python_deps(
|
411
337
|
dir: str, filename: str = "requirements.local.txt"
|
412
338
|
) -> Optional[str]:
|
@@ -498,19 +424,6 @@ def validate_wandb_python_deps(
|
|
498
424
|
_logger.warning("Unable to validate local python dependencies")
|
499
425
|
|
500
426
|
|
501
|
-
def fetch_project_diff(
|
502
|
-
entity: str, project: str, run_name: str, api: Api
|
503
|
-
) -> Optional[str]:
|
504
|
-
"""Fetches project diff from wandb servers."""
|
505
|
-
_logger.info("Searching for diff.patch")
|
506
|
-
patch = None
|
507
|
-
try:
|
508
|
-
(_, _, patch, _) = api.run_config(project, run_name, entity)
|
509
|
-
except CommError:
|
510
|
-
pass
|
511
|
-
return patch
|
512
|
-
|
513
|
-
|
514
427
|
def apply_patch(patch_string: str, dst_dir: str) -> None:
|
515
428
|
"""Applies a patch file to a directory."""
|
516
429
|
_logger.info("Applying diff.patch")
|
@@ -531,17 +444,6 @@ def apply_patch(patch_string: str, dst_dir: str) -> None:
|
|
531
444
|
raise wandb.Error("Failed to apply diff.patch associated with run.")
|
532
445
|
|
533
446
|
|
534
|
-
def _make_refspec_from_version(version: Optional[str]) -> List[str]:
|
535
|
-
"""Create a refspec that checks for the existence of origin/main and the version."""
|
536
|
-
if version:
|
537
|
-
return [f"+{version}"]
|
538
|
-
|
539
|
-
return [
|
540
|
-
"+refs/heads/main*:refs/remotes/origin/main*",
|
541
|
-
"+refs/heads/master*:refs/remotes/origin/master*",
|
542
|
-
]
|
543
|
-
|
544
|
-
|
545
447
|
def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> Optional[str]:
|
546
448
|
"""Clones the git repo at ``uri`` into ``dst_dir``.
|
547
449
|
|
@@ -561,13 +463,6 @@ def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> Optional[
|
|
561
463
|
return version
|
562
464
|
|
563
465
|
|
564
|
-
def merge_parameters(
|
565
|
-
higher_priority_params: Dict[str, Any], lower_priority_params: Dict[str, Any]
|
566
|
-
) -> Dict[str, Any]:
|
567
|
-
"""Merge the contents of two dicts, keeping values from higher_priority_params if there are conflicts."""
|
568
|
-
return {**lower_priority_params, **higher_priority_params}
|
569
|
-
|
570
|
-
|
571
466
|
def convert_jupyter_notebook_to_script(fname: str, project_dir: str) -> str:
|
572
467
|
nbconvert = wandb.util.get_module(
|
573
468
|
"nbconvert", "nbformat and nbconvert are required to use launch with notebooks"
|
@@ -597,25 +492,6 @@ def convert_jupyter_notebook_to_script(fname: str, project_dir: str) -> str:
|
|
597
492
|
return new_name
|
598
493
|
|
599
494
|
|
600
|
-
def check_and_download_code_artifacts(
|
601
|
-
entity: str, project: str, run_name: str, internal_api: Api, project_dir: str
|
602
|
-
) -> Optional["Artifact"]:
|
603
|
-
_logger.info("Checking for code artifacts")
|
604
|
-
public_api = wandb.PublicApi(
|
605
|
-
overrides={"base_url": internal_api.settings("base_url")}
|
606
|
-
)
|
607
|
-
|
608
|
-
run = public_api.run(f"{entity}/{project}/{run_name}")
|
609
|
-
run_artifacts = run.logged_artifacts()
|
610
|
-
|
611
|
-
for artifact in run_artifacts:
|
612
|
-
if hasattr(artifact, "type") and artifact.type == "code":
|
613
|
-
artifact.download(project_dir)
|
614
|
-
return artifact # type: ignore
|
615
|
-
|
616
|
-
return None
|
617
|
-
|
618
|
-
|
619
495
|
def to_camel_case(maybe_snake_str: str) -> str:
|
620
496
|
if "_" not in maybe_snake_str:
|
621
497
|
return maybe_snake_str
|
@@ -623,11 +499,6 @@ def to_camel_case(maybe_snake_str: str) -> str:
|
|
623
499
|
return "".join(x.title() if x else "_" for x in components)
|
624
500
|
|
625
501
|
|
626
|
-
def run_shell(args: List[str]) -> Tuple[str, str]:
|
627
|
-
out = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
628
|
-
return out.stdout.decode("utf-8").strip(), out.stderr.decode("utf-8").strip()
|
629
|
-
|
630
|
-
|
631
502
|
def validate_build_and_registry_configs(
|
632
503
|
build_config: Dict[str, Any], registry_config: Dict[str, Any]
|
633
504
|
) -> None:
|
@@ -864,3 +735,10 @@ def get_entrypoint_file(entrypoint: List[str]) -> Optional[str]:
|
|
864
735
|
if len(entrypoint) < 2:
|
865
736
|
return None
|
866
737
|
return entrypoint[1]
|
738
|
+
|
739
|
+
|
740
|
+
def get_current_python_version() -> Tuple[str, str]:
|
741
|
+
full_version = sys.version.split()[0].split(".")
|
742
|
+
major = full_version[0]
|
743
|
+
version = ".".join(full_version[:2]) if len(full_version) >= 2 else major + ".0"
|
744
|
+
return version, major
|
@@ -13,7 +13,6 @@ else:
|
|
13
13
|
_Setting = Literal[
|
14
14
|
"_args",
|
15
15
|
"_aws_lambda",
|
16
|
-
"_async_upload_concurrency_limit",
|
17
16
|
"_cli_only_mode",
|
18
17
|
"_code_path_local",
|
19
18
|
"_colab",
|
@@ -25,7 +24,6 @@ _Setting = Literal[
|
|
25
24
|
"_disable_update_check",
|
26
25
|
"_disable_viewer",
|
27
26
|
"_disable_machine_info",
|
28
|
-
"_except_exit",
|
29
27
|
"_executable",
|
30
28
|
"_extra_http_headers",
|
31
29
|
"_file_stream_retry_max",
|
@@ -92,6 +90,7 @@ _Setting = Literal[
|
|
92
90
|
"colab_url",
|
93
91
|
"config_paths",
|
94
92
|
"console",
|
93
|
+
"console_multipart",
|
95
94
|
"deployment",
|
96
95
|
"disable_code",
|
97
96
|
"disable_git",
|
@@ -104,6 +103,7 @@ _Setting = Literal[
|
|
104
103
|
"files_dir",
|
105
104
|
"force",
|
106
105
|
"fork_from",
|
106
|
+
"resume_from",
|
107
107
|
"git_commit",
|
108
108
|
"git_remote",
|
109
109
|
"git_remote_url",
|
@@ -126,7 +126,6 @@ _Setting = Literal[
|
|
126
126
|
"login_timeout",
|
127
127
|
"mode",
|
128
128
|
"notebook_name",
|
129
|
-
"problem",
|
130
129
|
"program",
|
131
130
|
"program_abspath",
|
132
131
|
"program_relpath",
|
@@ -179,7 +178,6 @@ _Setting = Literal[
|
|
179
178
|
]
|
180
179
|
|
181
180
|
SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
|
182
|
-
"_async_upload_concurrency_limit",
|
183
181
|
"_service_wait",
|
184
182
|
"_stats_sample_rate_seconds",
|
185
183
|
"_stats_samples_to_average",
|
@@ -189,7 +187,6 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
|
|
189
187
|
"console",
|
190
188
|
"job_source",
|
191
189
|
"mode",
|
192
|
-
"problem",
|
193
190
|
"project",
|
194
191
|
"run_id",
|
195
192
|
"start_method",
|
wandb/sdk/lib/apikey.py
CHANGED
@@ -202,7 +202,7 @@ def write_netrc(host: str, entity: str, key: str) -> Optional[bool]:
|
|
202
202
|
elif skip:
|
203
203
|
skip -= 1
|
204
204
|
else:
|
205
|
-
f.write("
|
205
|
+
f.write("{}\n".format(line))
|
206
206
|
f.write(
|
207
207
|
textwrap.dedent(
|
208
208
|
"""\
|
@@ -236,7 +236,9 @@ def write_key(
|
|
236
236
|
_, suffix = key.split("-", 1) if "-" in key else ("", key)
|
237
237
|
|
238
238
|
if len(suffix) != 40:
|
239
|
-
raise ValueError(
|
239
|
+
raise ValueError(
|
240
|
+
"API key must be 40 characters long, yours was {}".format(len(key))
|
241
|
+
)
|
240
242
|
|
241
243
|
if anonymous:
|
242
244
|
api.set_setting("anonymous", "true", globally=True, persist=True)
|
wandb/sdk/lib/config_util.py
CHANGED
@@ -66,13 +66,13 @@ def dict_from_config_file(
|
|
66
66
|
) -> Optional[Dict[str, Any]]:
|
67
67
|
if not os.path.exists(filename):
|
68
68
|
if must_exist:
|
69
|
-
raise ConfigError("config file
|
70
|
-
logger.debug("no default config file found in
|
69
|
+
raise ConfigError("config file {} doesn't exist".format(filename))
|
70
|
+
logger.debug("no default config file found in {}".format(filename))
|
71
71
|
return None
|
72
72
|
try:
|
73
73
|
conf_file = open(filename)
|
74
74
|
except OSError:
|
75
|
-
raise ConfigError("Couldn't read config file:
|
75
|
+
raise ConfigError("Couldn't read config file: {}".format(filename))
|
76
76
|
try:
|
77
77
|
loaded = load_yaml(conf_file)
|
78
78
|
except yaml.parser.ParserError:
|
wandb/sdk/lib/proto_util.py
CHANGED
@@ -12,7 +12,28 @@ if TYPE_CHECKING: # pragma: no cover
|
|
12
12
|
|
13
13
|
|
14
14
|
def dict_from_proto_list(obj_list: "RepeatedCompositeFieldContainer") -> Dict[str, Any]:
|
15
|
-
|
15
|
+
result: Dict[str, Any] = {}
|
16
|
+
|
17
|
+
for item in obj_list:
|
18
|
+
# Start from the root of the result dict
|
19
|
+
current_level = result
|
20
|
+
|
21
|
+
if len(item.nested_key) > 0:
|
22
|
+
keys = list(item.nested_key)
|
23
|
+
else:
|
24
|
+
keys = [item.key]
|
25
|
+
|
26
|
+
for key in keys[:-1]:
|
27
|
+
if key not in current_level:
|
28
|
+
current_level[key] = {}
|
29
|
+
# Move the reference deeper into the nested dictionary
|
30
|
+
current_level = current_level[key]
|
31
|
+
|
32
|
+
# Set the value at the final key location, parsing JSON from the value_json field
|
33
|
+
final_key = keys[-1]
|
34
|
+
current_level[final_key] = json.loads(item.value_json)
|
35
|
+
|
36
|
+
return result
|
16
37
|
|
17
38
|
|
18
39
|
def _result_from_record(record: "pb.Record") -> "pb.Result":
|
wandb/sdk/lib/redirect.py
CHANGED
wandb/sdk/service/service.py
CHANGED
@@ -110,7 +110,8 @@ class _Service:
|
|
110
110
|
f"The wandb service process exited with {proc.returncode}. "
|
111
111
|
"Ensure that `sys.executable` is a valid python interpreter. "
|
112
112
|
"You can override it with the `_executable` setting "
|
113
|
-
"or with the `WANDB__EXECUTABLE` environment variable."
|
113
|
+
"or with the `WANDB__EXECUTABLE` environment variable."
|
114
|
+
f"\n{context}",
|
114
115
|
context=context,
|
115
116
|
)
|
116
117
|
if not os.path.isfile(fname):
|
wandb/sdk/service/streams.py
CHANGED
@@ -319,8 +319,8 @@ class StreamMux:
|
|
319
319
|
# These could be done in parallel in the future
|
320
320
|
for _sid, stream in started_streams.items():
|
321
321
|
# dispatch all our final requests
|
322
|
-
poll_exit_handle = stream.interface.deliver_poll_exit()
|
323
322
|
server_info_handle = stream.interface.deliver_request_server_info()
|
323
|
+
poll_exit_handle = stream.interface.deliver_poll_exit()
|
324
324
|
final_summary_handle = stream.interface.deliver_get_summary()
|
325
325
|
sampled_history_handle = stream.interface.deliver_request_sampled_history()
|
326
326
|
internal_messages_handle = stream.interface.deliver_internal_messages()
|
@@ -330,14 +330,14 @@ class StreamMux:
|
|
330
330
|
internal_messages_response = result.response.internal_messages_response
|
331
331
|
|
332
332
|
# wait for them, it's ok to do this serially but this can be improved
|
333
|
-
result = poll_exit_handle.wait(timeout=-1)
|
334
|
-
assert result
|
335
|
-
poll_exit_response = result.response.poll_exit_response
|
336
|
-
|
337
333
|
result = server_info_handle.wait(timeout=-1)
|
338
334
|
assert result
|
339
335
|
server_info_response = result.response.server_info_response
|
340
336
|
|
337
|
+
result = poll_exit_handle.wait(timeout=-1)
|
338
|
+
assert result
|
339
|
+
poll_exit_response = result.response.poll_exit_response
|
340
|
+
|
341
341
|
result = sampled_history_handle.wait(timeout=-1)
|
342
342
|
assert result
|
343
343
|
sampled_history = result.response.sampled_history_response
|