wandb 0.16.6__py3-none-any.whl → 0.17.0__py3-none-any.whl
- package_readme.md +95 -0
- wandb/__init__.py +2 -3
- wandb/agents/pyagent.py +0 -1
- wandb/analytics/sentry.py +2 -1
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/internals/protocols.py +30 -56
- wandb/apis/importers/mlflow.py +13 -26
- wandb/apis/importers/wandb.py +8 -14
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +55 -3
- wandb/apis/public/artifacts.py +1 -0
- wandb/apis/public/files.py +1 -0
- wandb/apis/public/history.py +1 -0
- wandb/apis/public/jobs.py +17 -4
- wandb/apis/public/projects.py +1 -0
- wandb/apis/public/reports.py +1 -0
- wandb/apis/public/runs.py +15 -17
- wandb/apis/public/sweeps.py +1 -0
- wandb/apis/public/teams.py +1 -0
- wandb/apis/public/users.py +1 -0
- wandb/apis/reports/v1/_blocks.py +3 -7
- wandb/apis/reports/v2/gql.py +1 -0
- wandb/apis/reports/v2/interface.py +3 -4
- wandb/apis/reports/v2/internal.py +5 -8
- wandb/cli/cli.py +92 -22
- wandb/data_types.py +9 -6
- wandb/docker/__init__.py +1 -1
- wandb/env.py +38 -8
- wandb/errors/__init__.py +5 -0
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +4 -106
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/catboost/catboost.py +1 -1
- wandb/integration/fastai/__init__.py +1 -0
- wandb/integration/huggingface/resolver.py +2 -2
- wandb/integration/keras/__init__.py +1 -0
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/keras.py +7 -7
- wandb/integration/langchain/wandb_tracer.py +1 -0
- wandb/integration/lightning/fabric/logger.py +1 -3
- wandb/integration/metaflow/metaflow.py +41 -6
- wandb/integration/openai/fine_tuning.py +3 -3
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/old/summary.py +1 -1
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +364 -332
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +322 -316
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/wandb_deprecated.py +7 -1
- wandb/proto/wandb_internal_codegen.py +3 -29
- wandb/sdk/artifacts/artifact.py +26 -11
- wandb/sdk/artifacts/artifact_download_logger.py +1 -0
- wandb/sdk/artifacts/artifact_file_cache.py +18 -4
- wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
- wandb/sdk/artifacts/artifact_manifest.py +1 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +7 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
- wandb/sdk/artifacts/artifact_saver.py +2 -8
- wandb/sdk/artifacts/artifact_state.py +1 -0
- wandb/sdk/artifacts/artifact_ttl.py +1 -0
- wandb/sdk/artifacts/exceptions.py +1 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -42
- wandb/sdk/artifacts/storage_policy.py +2 -12
- wandb/sdk/data_types/_dtypes.py +8 -8
- wandb/sdk/data_types/base_types/media.py +3 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/auto_logging.py +5 -6
- wandb/sdk/integration_utils/data_logging.py +10 -6
- wandb/sdk/interface/interface.py +68 -32
- wandb/sdk/interface/interface_shared.py +7 -13
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +5 -18
- wandb/sdk/internal/handler.py +18 -2
- wandb/sdk/internal/internal.py +0 -1
- wandb/sdk/internal/internal_api.py +1 -129
- wandb/sdk/internal/internal_util.py +0 -1
- wandb/sdk/internal/job_builder.py +159 -45
- wandb/sdk/internal/profiler.py +1 -0
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/run.py +1 -0
- wandb/sdk/internal/sender.py +1 -2
- wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
- wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
- wandb/sdk/internal/system/assets/interfaces.py +6 -8
- wandb/sdk/internal/system/assets/open_metrics.py +2 -2
- wandb/sdk/internal/system/assets/trainium.py +1 -3
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +4 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +186 -224
- wandb/sdk/launch/agent/agent.py +37 -13
- wandb/sdk/launch/agent/config.py +72 -14
- wandb/sdk/launch/builder/abstract.py +69 -1
- wandb/sdk/launch/builder/build.py +156 -555
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +8 -23
- wandb/sdk/launch/builder/kaniko_builder.py +12 -25
- wandb/sdk/launch/builder/noop.py +1 -0
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +47 -37
- wandb/sdk/launch/environment/abstract.py +1 -0
- wandb/sdk/launch/environment/gcp_environment.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +1 -0
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +217 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/loader.py +1 -0
- wandb/sdk/launch/registry/abstract.py +1 -0
- wandb/sdk/launch/registry/azure_container_registry.py +1 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
- wandb/sdk/launch/registry/local_registry.py +1 -0
- wandb/sdk/launch/runner/abstract.py +1 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +9 -10
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +21 -20
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +4 -3
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +15 -140
- wandb/sdk/lib/_settings_toposort_generated.py +0 -5
- wandb/sdk/lib/fsm.py +8 -12
- wandb/sdk/lib/gitlib.py +4 -4
- wandb/sdk/lib/import_hooks.py +1 -1
- wandb/sdk/lib/lazyloader.py +0 -1
- wandb/sdk/lib/proto_util.py +23 -2
- wandb/sdk/lib/redirect.py +19 -14
- wandb/sdk/lib/retry.py +3 -2
- wandb/sdk/lib/tracelog.py +1 -1
- wandb/sdk/service/service.py +19 -16
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_init.py +14 -55
- wandb/sdk/wandb_manager.py +2 -2
- wandb/sdk/wandb_require.py +5 -0
- wandb/sdk/wandb_run.py +114 -56
- wandb/sdk/wandb_settings.py +0 -48
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sklearn/__init__.py +1 -0
- wandb/sklearn/plot/__init__.py +1 -0
- wandb/sklearn/plot/classifier.py +11 -12
- wandb/sklearn/plot/clusterer.py +2 -1
- wandb/sklearn/plot/regressor.py +1 -0
- wandb/sklearn/plot/shared.py +1 -0
- wandb/sklearn/utils.py +1 -0
- wandb/testing/relay.py +4 -4
- wandb/trigger.py +1 -0
- wandb/util.py +67 -54
- wandb/wandb_controller.py +2 -3
- wandb/wandb_torch.py +1 -2
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/METADATA +67 -70
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/RECORD +177 -187
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/WHEEL +1 -2
- wandb/bin/apple_gpu_stats +0 -0
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -18
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- wandb-0.16.6.dist-info/top_level.txt +0 -1
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info/licenses}/LICENSE +0 -0
wandb/sdk/launch/runner/sagemaker_runner.py
CHANGED
@@ -1,4 +1,5 @@
 """Implementation of the SageMakerRunner class."""
+
 import asyncio
 import logging
 from typing import Any, Dict, List, Optional, cast
@@ -11,8 +12,7 @@ from wandb.apis.internal import Api
 from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
 from wandb.sdk.launch.errors import LaunchError
 
-from .._project_spec import EntryPoint, LaunchProject
-from ..builder.build import get_env_vars_dict
+from .._project_spec import EntryPoint, LaunchProject
 from ..registry.abstract import AbstractRegistry
 from ..utils import (
     LOG_PREFIX,
@@ -67,6 +67,7 @@ class SagemakerSubmittedRun(AbstractRun):
             logGroupName="/aws/sagemaker/TrainingJobs",
             logStreamName=log_name,
         )
+        assert "events" in res
         return "\n".join(
             [f'{event["timestamp"]}:{event["message"]}' for event in res["events"]]
         )
@@ -220,12 +221,12 @@ class SageMakerRunner(AbstractRunner):
         launch_project.fill_macros(image_uri)
         _logger.info("Connecting to sagemaker client")
         entry_point = (
-            launch_project.override_entrypoint
-            or launch_project.get_single_entry_point()
-        )
-        command_args = get_entry_point_command(
-            entry_point, launch_project.override_args
+            launch_project.override_entrypoint or launch_project.get_job_entry_point()
         )
+        command_args = []
+        if entry_point is not None:
+            command_args += entry_point.command
+        command_args += launch_project.override_args
         if command_args:
             command_str = " ".join(command_args)
             wandb.termlog(
@@ -324,16 +325,16 @@ def build_sagemaker_args(
     sagemaker_args["TrainingJobName"] = training_job_name
     entry_cmd = entry_point.command if entry_point else []
 
-    sagemaker_args[
-        "AlgorithmSpecification"
-    ] = merge_image_uri_with_algorithm_specification(
-        given_sagemaker_args.get(
-            "AlgorithmSpecification",
-            given_sagemaker_args.get("algorithm_specification"),
-        ),
-        image_uri,
-        entry_cmd,
-        args,
+    sagemaker_args["AlgorithmSpecification"] = (
+        merge_image_uri_with_algorithm_specification(
+            given_sagemaker_args.get(
+                "AlgorithmSpecification",
+                given_sagemaker_args.get("algorithm_specification"),
+            ),
+            image_uri,
+            entry_cmd,
+            args,
+        )
     )
 
     sagemaker_args["RoleArn"] = role_arn
@@ -348,18 +349,18 @@ def build_sagemaker_args(
 
     if sagemaker_args.get("ResourceConfig") is None:
         raise LaunchError(
-            "Sagemaker launcher requires a ResourceConfig
+            "Sagemaker launcher requires a ResourceConfig resource argument"
         )
 
     if sagemaker_args.get("StoppingCondition") is None:
         raise LaunchError(
-            "Sagemaker launcher requires a StoppingCondition
+            "Sagemaker launcher requires a StoppingCondition resource argument"
         )
 
     given_env = given_sagemaker_args.get(
         "Environment", sagemaker_args.get("environment", {})
     )
-    calced_env = get_env_vars_dict(
+    calced_env = launch_project.get_env_vars_dict(api, max_env_length)
     total_env = {**calced_env, **given_env}
     sagemaker_args["Environment"] = total_env
 
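The SageMaker runner hunk above (and the Vertex runner hunk that follows) replaces the old get_entry_point_command() call with an explicit list build: start empty, append the job entry point's command if one exists, then always append any override arguments. A minimal sketch of that pattern, using a SimpleNamespace as a stand-in for the entry-point object; the helper name and stand-in are illustrative, not wandb API:

from types import SimpleNamespace
from typing import List, Optional

def assemble_command(entry_point: Optional[SimpleNamespace], override_args: List[str]) -> List[str]:
    # Mirrors the new runner logic shown in the diff above: entry point command
    # first (if any), then the user-supplied override arguments.
    command_args: List[str] = []
    if entry_point is not None:
        command_args += entry_point.command
    command_args += override_args
    return command_args

print(assemble_command(SimpleNamespace(command=["python", "train.py"]), ["--epochs", "10"]))
# -> ['python', 'train.py', '--epochs', '10']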
wandb/sdk/launch/runner/vertex_runner.py
CHANGED
@@ -8,8 +8,7 @@ if False:
     from wandb.apis.internal import Api
     from wandb.util import get_module
 
-    from .._project_spec import LaunchProject
-    from ..builder.build import get_env_vars_dict
+    from .._project_spec import LaunchProject
     from ..environment.gcp_environment import GcpEnvironment
     from ..errors import LaunchError
     from ..registry.abstract import AbstractRegistry
@@ -113,14 +112,16 @@ class VertexRunner(AbstractRunner):
         synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
 
         entry_point = (
-            launch_project.override_entrypoint
-            or launch_project.get_single_entry_point()
+            launch_project.override_entrypoint or launch_project.get_job_entry_point()
         )
 
         # TODO: Set entrypoint in each container
-        entry_cmd =
-
-
+        entry_cmd = []
+        if entry_point is not None:
+            entry_cmd += entry_point.command
+        entry_cmd += launch_project.override_args
+
+        env_vars = launch_project.get_env_vars_dict(
             api=self._api,
             max_env_length=MAX_ENV_LENGTHS[self.__class__.__name__],
         )
wandb/sdk/launch/sweeps/scheduler.py
CHANGED
@@ -1,4 +1,5 @@
 """Abstract Scheduler class."""
+
 import asyncio
 import base64
 import copy
@@ -407,7 +408,7 @@ class Scheduler(ABC):
         return count
 
     def _try_load_executable(self) -> bool:
-        """Check
+        """Check existence of valid executable for a run.
 
        logs and returns False when job is unreachable
        """
@@ -422,7 +423,7 @@ class Scheduler(ABC):
                return False
            return True
        elif self._kwargs.get("image_uri"):
-            # TODO(gst): check docker
+            # TODO(gst): check docker existence? Use registry in launch config?
            return True
        else:
            return False
@@ -610,7 +611,7 @@ class Scheduler(ABC):
                    f"Failed to get runstate for run ({run_id}). Error: {traceback.format_exc()}"
                )
                run_state = RunState.FAILED
-            else:  # first time we get
+            else:  # first time we get unknown state
                run_state = RunState.UNKNOWN
        except (AttributeError, ValueError):
            wandb.termwarn(
wandb/sdk/launch/sweeps/scheduler_sweep.py
CHANGED
@@ -1,4 +1,5 @@
 """Scheduler for classic wandb Sweeps."""
+
 import logging
 from pprint import pformat as pf
 from typing import Any, Dict, List, Optional
@@ -58,7 +59,7 @@ class SweepScheduler(Scheduler):
             return None
 
     def _get_sweep_commands(self, worker_id: int) -> List[Dict[str, Any]]:
-        """Helper to
+        """Helper to receive sweep command from backend."""
         # AgentHeartbeat wants a Dict of runs which are running or queued
         _run_states: Dict[str, bool] = {}
         for run_id, run in self._yield_runs():
wandb/sdk/launch/sweeps/utils.py
CHANGED
@@ -217,7 +217,7 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
     flags: List[str] = []
     # (2) flags without hyphens (e.g. foo=bar)
     flags_no_hyphens: List[str] = []
-    # (3) flags with false booleans
+    # (3) flags with false booleans omitted (e.g. --foo)
     flags_no_booleans: List[str] = []
     # (4) flags as a dictionary (used for constructing a json)
     flags_dict: Dict[str, Any] = {}
@@ -257,7 +257,7 @@ def make_launch_sweep_entrypoint(
     """Use args dict from create_sweep_command_args to construct entrypoint.
 
     If replace is True, remove macros from entrypoint, fill them in with args
-    and then return the args in
+    and then return the args in separate return value.
     """
     if not command:
         return None, None
@@ -296,7 +296,7 @@ def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
 
 
 def get_previous_args(
-    run_spec: Dict[str, Any]
+    run_spec: Dict[str, Any],
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """Parse through previous scheduler run_spec.
 
wandb/sdk/launch/utils.py
CHANGED
@@ -1,4 +1,3 @@
-# heavily inspired by https://github.com/mlflow/mlflow/blob/master/mlflow/projects/utils.py
 import asyncio
 import json
 import logging
@@ -16,7 +15,6 @@ import wandb
 import wandb.docker as docker
 from wandb import util
 from wandb.apis.internal import Api
-from wandb.errors import CommError
 from wandb.sdk.launch.errors import LaunchError
 from wandb.sdk.launch.git_reference import GitReference
 from wandb.sdk.launch.wandb_reference import WandbReference
@@ -32,7 +30,6 @@ FAILED_PACKAGES_REGEX = re.compile(
 )
 
 if TYPE_CHECKING:  # pragma: no cover
-    from wandb.sdk.artifacts.artifact import Artifact
     from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
 
 
@@ -57,15 +54,15 @@ API_KEY_REGEX = r"WANDB_API_KEY=\w+(-\w+)?"
 MACRO_REGEX = re.compile(r"\$\{(\w+)\}")
 
 AZURE_CONTAINER_REGISTRY_URI_REGEX = re.compile(
-    r"(?:https://)?([\w]+)\.azurecr\.io/([\w\-]+):?(
+    r"^(?:https://)?([\w]+)\.azurecr\.io/(?P<repository>[\w\-]+):?(?P<tag>.*)"
 )
 
 ELASTIC_CONTAINER_REGISTRY_URI_REGEX = re.compile(
-    r"^(?P<account
+    r"^(?:https://)?(?P<account>[\w-]+)\.dkr\.ecr\.(?P<region>[\w-]+)\.amazonaws\.com/(?P<repository>[\w-]+):?(?P<tag>.*)$"
 )
 
 GCP_ARTIFACT_REGISTRY_URI_REGEX = re.compile(
-    r"^(?P<region>[\w-]+)-docker\.pkg\.dev/(?P<project>[\w-]+)/(?P<repository>[\w-]+)
+    r"^(?:https://)?(?P<region>[\w-]+)-docker\.pkg\.dev/(?P<project>[\w-]+)/(?P<repository>[\w-]+)/?(?P<image_name>[\w-]+)?(?P<tag>:.*)?$",
     re.IGNORECASE,
 )
 
@@ -316,16 +313,13 @@ def construct_launch_spec(
 
 
 def validate_launch_spec_source(launch_spec: Dict[str, Any]) -> None:
-    uri = launch_spec.get("uri")
     job = launch_spec.get("job")
     docker_image = launch_spec.get("docker", {}).get("docker_image")
-
-
-
-
-
-    elif sum(map(bool, [uri, job, docker_image])) > 1:
-        raise LaunchError("Must specify exactly one of uri, job or image")
+    if bool(job) == bool(docker_image):
+        raise LaunchError(
+            "Exactly one of job or docker_image must be specified in the launch "
+            "spec."
+        )
 
 
 def parse_wandb_uri(uri: str) -> Tuple[str, str, str]:
@@ -336,77 +330,6 @@ def parse_wandb_uri(uri: str) -> Tuple[str, str, str]:
     return (ref.entity, ref.project, ref.run_id)
 
 
-def is_bare_wandb_uri(uri: str) -> bool:
-    """Check that a wandb uri is valid.
-
-    URI must be in the format
-    `/<entity>/<project>/runs/<run_name>[other stuff]`
-    or
-    `/<entity>/<project>/artifacts/job/<job_name>[other stuff]`.
-    """
-    _logger.info(f"Checking if uri {uri} is bare...")
-    return uri.startswith("/") and WandbReference.is_uri_job_or_run(uri)
-
-
-def fetch_wandb_project_run_info(
-    entity: str, project: str, run_name: str, api: Api
-) -> Any:
-    _logger.info("Fetching run info...")
-    try:
-        result = api.get_run_info(entity, project, run_name)
-    except CommError:
-        result = None
-    if result is None:
-        raise LaunchError(
-            f"Run info is invalid or doesn't exist for {api.settings('base_url')}/{entity}/{project}/runs/{run_name}"
-        )
-    if result.get("codePath") is None:
-        # TODO: we don't currently expose codePath in the runInfo endpoint, this downloads
-        # it from wandb-metadata.json if we can.
-        metadata = api.download_url(
-            project, "wandb-metadata.json", run=run_name, entity=entity
-        )
-        if metadata is not None:
-            _, response = api.download_file(metadata["url"])
-            data = response.json()
-            result["codePath"] = data.get("codePath")
-            result["cudaVersion"] = data.get("cuda", None)
-
-    return result
-
-
-def download_entry_point(
-    entity: str, project: str, run_name: str, api: Api, entry_point: str, dir: str
-) -> bool:
-    metadata = api.download_url(
-        project, f"code/{entry_point}", run=run_name, entity=entity
-    )
-    if metadata is not None:
-        _, response = api.download_file(metadata["url"])
-        with util.fsync_open(os.path.join(dir, entry_point), "wb") as file:
-            for data in response.iter_content(chunk_size=1024):
-                file.write(data)
-        return True
-    return False
-
-
-def download_wandb_python_deps(
-    entity: str, project: str, run_name: str, api: Api, dir: str
-) -> Optional[str]:
-    reqs = api.download_url(project, "requirements.txt", run=run_name, entity=entity)
-    if reqs is not None:
-        _logger.info("Downloading python dependencies")
-        _, response = api.download_file(reqs["url"])
-
-        with util.fsync_open(
-            os.path.join(dir, "requirements.frozen.txt"), "wb"
-        ) as file:
-            for data in response.iter_content(chunk_size=1024):
-                file.write(data)
-        return "requirements.frozen.txt"
-    return None
-
-
 def get_local_python_deps(
     dir: str, filename: str = "requirements.local.txt"
 ) -> Optional[str]:
@@ -498,19 +421,6 @@ def validate_wandb_python_deps(
         _logger.warning("Unable to validate local python dependencies")
 
 
-def fetch_project_diff(
-    entity: str, project: str, run_name: str, api: Api
-) -> Optional[str]:
-    """Fetches project diff from wandb servers."""
-    _logger.info("Searching for diff.patch")
-    patch = None
-    try:
-        (_, _, patch, _) = api.run_config(project, run_name, entity)
-    except CommError:
-        pass
-    return patch
-
-
 def apply_patch(patch_string: str, dst_dir: str) -> None:
     """Applies a patch file to a directory."""
     _logger.info("Applying diff.patch")
@@ -531,17 +441,6 @@ def apply_patch(patch_string: str, dst_dir: str) -> None:
         raise wandb.Error("Failed to apply diff.patch associated with run.")
 
 
-def _make_refspec_from_version(version: Optional[str]) -> List[str]:
-    """Create a refspec that checks for the existence of origin/main and the version."""
-    if version:
-        return [f"+{version}"]
-
-    return [
-        "+refs/heads/main*:refs/remotes/origin/main*",
-        "+refs/heads/master*:refs/remotes/origin/master*",
-    ]
-
-
 def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> Optional[str]:
     """Clones the git repo at ``uri`` into ``dst_dir``.
 
@@ -561,13 +460,6 @@ def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> Optional[
     return version
 
 
-def merge_parameters(
-    higher_priority_params: Dict[str, Any], lower_priority_params: Dict[str, Any]
-) -> Dict[str, Any]:
-    """Merge the contents of two dicts, keeping values from higher_priority_params if there are conflicts."""
-    return {**lower_priority_params, **higher_priority_params}
-
-
 def convert_jupyter_notebook_to_script(fname: str, project_dir: str) -> str:
     nbconvert = wandb.util.get_module(
         "nbconvert", "nbformat and nbconvert are required to use launch with notebooks"
@@ -597,25 +489,6 @@ def convert_jupyter_notebook_to_script(fname: str, project_dir: str) -> str:
     return new_name
 
 
-def check_and_download_code_artifacts(
-    entity: str, project: str, run_name: str, internal_api: Api, project_dir: str
-) -> Optional["Artifact"]:
-    _logger.info("Checking for code artifacts")
-    public_api = wandb.PublicApi(
-        overrides={"base_url": internal_api.settings("base_url")}
-    )
-
-    run = public_api.run(f"{entity}/{project}/{run_name}")
-    run_artifacts = run.logged_artifacts()
-
-    for artifact in run_artifacts:
-        if hasattr(artifact, "type") and artifact.type == "code":
-            artifact.download(project_dir)
-            return artifact  # type: ignore
-
-    return None
-
-
 def to_camel_case(maybe_snake_str: str) -> str:
     if "_" not in maybe_snake_str:
         return maybe_snake_str
@@ -623,11 +496,6 @@ def to_camel_case(maybe_snake_str: str) -> str:
     return "".join(x.title() if x else "_" for x in components)
 
 
-def run_shell(args: List[str]) -> Tuple[str, str]:
-    out = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-    return out.stdout.decode("utf-8").strip(), out.stderr.decode("utf-8").strip()
-
-
 def validate_build_and_registry_configs(
     build_config: Dict[str, Any], registry_config: Dict[str, Any]
 ) -> None:
@@ -864,3 +732,10 @@ def get_entrypoint_file(entrypoint: List[str]) -> Optional[str]:
     if len(entrypoint) < 2:
         return None
     return entrypoint[1]
+
+
+def get_current_python_version() -> Tuple[str, str]:
+    full_version = sys.version.split()[0].split(".")
+    major = full_version[0]
+    version = ".".join(full_version[:2]) if len(full_version) >= 2 else major + ".0"
+    return version, major
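The registry URI patterns in this file gain start/end anchors and named groups, which lets the registry handlers pull out account, region, repository, and tag by name instead of by position. A quick check of the new ECR pattern; the regex is copied verbatim from the hunk above, while the sample URI is invented for illustration:

import re

ELASTIC_CONTAINER_REGISTRY_URI_REGEX = re.compile(
    r"^(?:https://)?(?P<account>[\w-]+)\.dkr\.ecr\.(?P<region>[\w-]+)\.amazonaws\.com/(?P<repository>[\w-]+):?(?P<tag>.*)$"
)

match = ELASTIC_CONTAINER_REGISTRY_URI_REGEX.match(
    "123456789012.dkr.ecr.us-east-1.amazonaws.com/my-repo:latest"
)
if match:
    # Named groups keep the downstream registry code self-documenting.
    print(match.group("account"), match.group("region"), match.group("repository"), match.group("tag"))
    # -> 123456789012 us-east-1 my-repo latest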
wandb/sdk/lib/_settings_toposort_generated.py
CHANGED
@@ -13,7 +13,6 @@ else:
 _Setting = Literal[
     "_args",
     "_aws_lambda",
-    "_async_upload_concurrency_limit",
     "_cli_only_mode",
     "_code_path_local",
     "_colab",
@@ -25,7 +24,6 @@ _Setting = Literal[
     "_disable_update_check",
     "_disable_viewer",
     "_disable_machine_info",
-    "_except_exit",
     "_executable",
     "_extra_http_headers",
     "_file_stream_retry_max",
@@ -126,7 +124,6 @@ _Setting = Literal[
     "login_timeout",
     "mode",
     "notebook_name",
-    "problem",
     "program",
     "program_abspath",
     "program_relpath",
@@ -179,7 +176,6 @@ _Setting = Literal[
 ]
 
 SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
-    "_async_upload_concurrency_limit",
     "_service_wait",
     "_stats_sample_rate_seconds",
     "_stats_samples_to_average",
@@ -189,7 +185,6 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
     "console",
     "job_source",
     "mode",
-    "problem",
     "project",
     "run_id",
     "start_method",
wandb/sdk/lib/fsm.py
CHANGED
@@ -52,43 +52,39 @@ T_FsmContext_contra = TypeVar("T_FsmContext_contra", contravariant=True)
 @runtime_checkable
 class FsmStateCheck(Protocol[T_FsmInputs]):
     @abstractmethod
-    def on_check(self, inputs: T_FsmInputs) -> None:
-        ...  # pragma: no cover
+    def on_check(self, inputs: T_FsmInputs) -> None: ...  # pragma: no cover
 
 
 @runtime_checkable
 class FsmStateOutput(Protocol[T_FsmInputs]):
     @abstractmethod
-    def on_state(self, inputs: T_FsmInputs) -> None:
-        ...  # pragma: no cover
+    def on_state(self, inputs: T_FsmInputs) -> None: ...  # pragma: no cover
 
 
 @runtime_checkable
 class FsmStateEnter(Protocol[T_FsmInputs]):
     @abstractmethod
-    def on_enter(self, inputs: T_FsmInputs) -> None:
-        ...  # pragma: no cover
+    def on_enter(self, inputs: T_FsmInputs) -> None: ...  # pragma: no cover
 
 
 @runtime_checkable
 class FsmStateEnterWithContext(Protocol[T_FsmInputs, T_FsmContext_contra]):
     @abstractmethod
-    def on_enter(
-        ...  # pragma: no cover
+    def on_enter(
+        self, inputs: T_FsmInputs, context: T_FsmContext_contra
+    ) -> None: ...  # pragma: no cover
 
 
 @runtime_checkable
 class FsmStateStay(Protocol[T_FsmInputs]):
     @abstractmethod
-    def on_stay(self, inputs: T_FsmInputs) -> None:
-        ...  # pragma: no cover
+    def on_stay(self, inputs: T_FsmInputs) -> None: ...  # pragma: no cover
 
 
 @runtime_checkable
 class FsmStateExit(Protocol[T_FsmInputs, T_FsmContext_cov]):
     @abstractmethod
-    def on_exit(self, inputs: T_FsmInputs) -> T_FsmContext_cov:
-        ...  # pragma: no cover
+    def on_exit(self, inputs: T_FsmInputs) -> T_FsmContext_cov: ...  # pragma: no cover
 
 
 # It would be nice if python provided optional protocol members, but it doesnt as described here:
wandb/sdk/lib/gitlib.py
CHANGED
@@ -14,7 +14,7 @@ try:
         Repo,
     )
 except ImportError:
-    Repo = None
+    Repo = None  # type: ignore
 
 if TYPE_CHECKING:
     from git import Repo
@@ -121,7 +121,7 @@ class GitRepo:
         # TODO: Saw a user getting a Unicode decode error when parsing refs,
         # more details on implementing a real fix in [WB-4064]
         try:
-            if len(self.repo.refs) > 0:
+            if len(self.repo.refs) > 0:  # type: ignore[arg-type]
                 return self.repo.head.commit.hexsha
             else:
                 return self.repo.git.show_ref("--head").split(" ")[0]
@@ -140,7 +140,7 @@ class GitRepo:
         if not self.repo:
             return None
         try:
-            return self.repo.remotes[self.remote_name]
+            return self.repo.remotes[self.remote_name]  # type: ignore[index]
         except IndexError:
             return None
 
@@ -200,7 +200,7 @@ class GitRepo:
                 possible_relatives.append(tracking_branch.commit)
 
         if not possible_relatives:
-            for branch in self.repo.branches:
+            for branch in self.repo.branches:  # type: ignore[attr-defined]
                 tracking_branch = branch.tracking_branch()
                 if tracking_branch is not None:
                     possible_relatives.append(tracking_branch.commit)
wandb/sdk/lib/import_hooks.py
CHANGED
@@ -143,7 +143,7 @@ class _ImportHookChainedLoader:
         # None, so handle None as well. The module may not support attribute
         # assignment, in which case we simply skip it. Note that we also deal
         # with __loader__ not existing at all. This is to future proof things
-        # due to proposal to remove the
+        # due to proposal to remove the attribute as described in the GitHub
         # issue at https://github.com/python/cpython/issues/77458. Also prior
         # to Python 3.3, the __loader__ attribute was only set if a custom
         # module loader was used. It isn't clear whether the attribute still
wandb/sdk/lib/lazyloader.py
CHANGED
wandb/sdk/lib/proto_util.py
CHANGED
@@ -12,7 +12,28 @@ if TYPE_CHECKING:  # pragma: no cover
 
 
 def dict_from_proto_list(obj_list: "RepeatedCompositeFieldContainer") -> Dict[str, Any]:
-    return {item.key: json.loads(item.value_json) for item in obj_list}
+    result: Dict[str, Any] = {}
+
+    for item in obj_list:
+        # Start from the root of the result dict
+        current_level = result
+
+        if len(item.nested_key) > 0:
+            keys = list(item.nested_key)
+        else:
+            keys = [item.key]
+
+        for key in keys[:-1]:
+            if key not in current_level:
+                current_level[key] = {}
+            # Move the reference deeper into the nested dictionary
+            current_level = current_level[key]
+
+        # Set the value at the final key location, parsing JSON from the value_json field
+        final_key = keys[-1]
+        current_level[final_key] = json.loads(item.value_json)
+
+    return result
 
 
 def _result_from_record(record: "pb.Record") -> "pb.Result":
@@ -29,7 +50,7 @@ def _assign_end_offset(record: "pb.Record", end_offset: int) -> None:
 
 
 def proto_encode_to_dict(
-    pb_obj: Union["tpb.TelemetryRecord", "pb.MetricRecord"]
+    pb_obj: Union["tpb.TelemetryRecord", "pb.MetricRecord"],
 ) -> Dict[int, Any]:
     data: Dict[int, Any] = dict()
     fields = pb_obj.ListFields()
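The rewritten dict_from_proto_list above walks each item's nested_key path (falling back to its flat key) and JSON-decodes value_json at the leaf, so nested config entries become nested dictionaries. A minimal sketch of the same nesting behavior, using SimpleNamespace objects as stand-ins for the protobuf items; this is illustrative only, not the wandb API:

import json
from types import SimpleNamespace
from typing import Any, Dict, List

def nest_items(items: List[SimpleNamespace]) -> Dict[str, Any]:
    result: Dict[str, Any] = {}
    for item in items:
        # Use the nested key path when present, otherwise the flat key.
        keys = list(item.nested_key) if item.nested_key else [item.key]
        level = result
        for key in keys[:-1]:
            level = level.setdefault(key, {})  # create/descend intermediate dicts
        level[keys[-1]] = json.loads(item.value_json)  # decode the JSON-encoded leaf value
    return result

items = [
    SimpleNamespace(key="lr", nested_key=[], value_json="0.001"),
    SimpleNamespace(key="", nested_key=["optimizer", "name"], value_json='"adam"'),
]
print(nest_items(items))
# -> {'lr': 0.001, 'optimizer': {'name': 'adam'}}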