wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
wandb/sdk/internal/tb_watcher.py
CHANGED
@@ -58,12 +58,16 @@ def _link_and_save_file(
|
|
58
58
|
interface.publish_files(dict(files=[(GlobStr(glob.escape(file_name)), "live")]))
|
59
59
|
|
60
60
|
|
61
|
-
def is_tfevents_file_created_by(
|
62
|
-
|
61
|
+
def is_tfevents_file_created_by(
|
62
|
+
path: str, hostname: Optional[str], start_time: Optional[float]
|
63
|
+
) -> bool:
|
64
|
+
"""Check if a path is a tfevents file.
|
65
|
+
|
66
|
+
Optionally checks that it was created by [hostname] after [start_time].
|
63
67
|
|
64
68
|
tensorboard tfevents filename format:
|
65
69
|
https://github.com/tensorflow/tensorboard/blob/f3f26b46981da5bd46a5bb93fcf02d9eb7608bc1/tensorboard/summary/writer/event_file_writer.py#L81
|
66
|
-
tensorflow tfevents
|
70
|
+
tensorflow tfevents filename format:
|
67
71
|
https://github.com/tensorflow/tensorflow/blob/8f597046dc30c14b5413813d02c0e0aed399c177/tensorflow/core/util/events_writer.cc#L68
|
68
72
|
"""
|
69
73
|
if not path:
|
@@ -77,23 +81,27 @@ def is_tfevents_file_created_by(path: str, hostname: str, start_time: float) ->
|
|
77
81
|
except ValueError:
|
78
82
|
return False
|
79
83
|
# check the hostname, which may have dots
|
80
|
-
|
84
|
+
if hostname is not None:
|
85
|
+
for i, part in enumerate(hostname.split(".")):
|
86
|
+
try:
|
87
|
+
fname_component_part = fname_components[tfevents_idx + 2 + i]
|
88
|
+
except IndexError:
|
89
|
+
return False
|
90
|
+
if part != fname_component_part:
|
91
|
+
return False
|
92
|
+
if start_time is not None:
|
81
93
|
try:
|
82
|
-
|
83
|
-
except IndexError:
|
94
|
+
created_time = int(fname_components[tfevents_idx + 1])
|
95
|
+
except (ValueError, IndexError):
|
84
96
|
return False
|
85
|
-
|
97
|
+
# Ensure that the file is newer then our start time, and that it was
|
98
|
+
# created from the same hostname.
|
99
|
+
# TODO: we should also check the PID (also contained in the tfevents
|
100
|
+
# filename). Can we assume that our parent pid is the user process
|
101
|
+
# that wrote these files?
|
102
|
+
if created_time < int(start_time):
|
86
103
|
return False
|
87
|
-
|
88
|
-
created_time = int(fname_components[tfevents_idx + 1])
|
89
|
-
except (ValueError, IndexError):
|
90
|
-
return False
|
91
|
-
# Ensure that the file is newer then our start time, and that it was
|
92
|
-
# created from the same hostname.
|
93
|
-
# TODO: we should also check the PID (also contained in the tfevents
|
94
|
-
# filename). Can we assume that our parent pid is the user process
|
95
|
-
# that wrote these files?
|
96
|
-
return created_time >= int(start_time)
|
104
|
+
return True
|
97
105
|
|
98
106
|
|
99
107
|
class TBWatcher:
|
@@ -136,6 +144,7 @@ class TBWatcher:
|
|
136
144
|
# Note that we strip '/' instead of os.sep, because elsewhere we've
|
137
145
|
# converted paths to forward slash.
|
138
146
|
namespace = logdir.replace(filename, "").replace(rootdir, "").strip("/")
|
147
|
+
|
139
148
|
# TODO: revisit this heuristic, it exists because we don't know the
|
140
149
|
# root log directory until more than one tfevents file is written to
|
141
150
|
if len(dirs) == 1 and namespace not in ["train", "validation"]:
|
@@ -217,12 +226,13 @@ class TBDirWatcher:
|
|
217
226
|
"""Check if a path has been modified since launch and contains tfevents."""
|
218
227
|
if not path:
|
219
228
|
raise ValueError("Path must be a nonempty string")
|
220
|
-
if self._force:
|
221
|
-
return True
|
222
229
|
path = self.tf_compat.tf.compat.as_str_any(path)
|
223
|
-
|
224
|
-
path,
|
225
|
-
|
230
|
+
if self._force:
|
231
|
+
return is_tfevents_file_created_by(path, None, None)
|
232
|
+
else:
|
233
|
+
return is_tfevents_file_created_by(
|
234
|
+
path, self._hostname, self._tbwatcher._settings._start_time
|
235
|
+
)
|
226
236
|
|
227
237
|
def _loader(
|
228
238
|
self, save: bool = True, namespace: Optional[str] = None
|
@@ -0,0 +1,18 @@
|
|
1
|
+
import threading
|
2
|
+
from typing import Dict, Optional
|
3
|
+
|
4
|
+
|
5
|
+
# Context variable for setting API settings (api keys, etc.) for internal and public apis thread-locally
|
6
|
+
# TODO: move this into actual settings
|
7
|
+
class _ThreadLocalApiSettings(threading.local):
|
8
|
+
api_key: Optional[str]
|
9
|
+
cookies: Optional[Dict]
|
10
|
+
headers: Optional[Dict]
|
11
|
+
|
12
|
+
def __init__(self) -> None:
|
13
|
+
self.api_key = None
|
14
|
+
self.cookies = None
|
15
|
+
self.headers = None
|
16
|
+
|
17
|
+
|
18
|
+
_thread_local_api_settings: _ThreadLocalApiSettings = _ThreadLocalApiSettings()
|
@@ -7,17 +7,20 @@ import json
|
|
7
7
|
import logging
|
8
8
|
import os
|
9
9
|
import tempfile
|
10
|
-
from typing import Any, Dict, List, Optional
|
10
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
11
11
|
|
12
12
|
import wandb
|
13
13
|
import wandb.docker as docker
|
14
14
|
from wandb.apis.internal import Api
|
15
|
-
from wandb.apis.public import Artifact as PublicArtifact
|
16
15
|
from wandb.errors import CommError
|
16
|
+
from wandb.sdk.launch import utils
|
17
17
|
from wandb.sdk.lib.runid import generate_id
|
18
18
|
|
19
|
-
from . import
|
20
|
-
from .utils import LOG_PREFIX,
|
19
|
+
from .errors import LaunchError
|
20
|
+
from .utils import LOG_PREFIX, recursive_macro_sub
|
21
|
+
|
22
|
+
if TYPE_CHECKING:
|
23
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
21
24
|
|
22
25
|
_logger = logging.getLogger(__name__)
|
23
26
|
|
@@ -59,6 +62,7 @@ class LaunchProject:
|
|
59
62
|
resource: str,
|
60
63
|
resource_args: Dict[str, Any],
|
61
64
|
run_id: Optional[str],
|
65
|
+
sweep_id: Optional[str] = None,
|
62
66
|
):
|
63
67
|
if uri is not None and utils.is_bare_wandb_uri(uri):
|
64
68
|
uri = api.settings("base_url") + uri
|
@@ -67,7 +71,7 @@ class LaunchProject:
|
|
67
71
|
self.job = job
|
68
72
|
if job is not None:
|
69
73
|
wandb.termlog(f"{LOG_PREFIX}Launching job: {job}")
|
70
|
-
self._job_artifact: Optional[
|
74
|
+
self._job_artifact: Optional["Artifact"] = None
|
71
75
|
self.api = api
|
72
76
|
self.launch_spec = launch_spec
|
73
77
|
self.target_entity = target_entity
|
@@ -78,11 +82,12 @@ class LaunchProject:
|
|
78
82
|
# runner, so we need to pop the builder key out
|
79
83
|
resource_args_build = resource_args.get(resource, {}).pop("builder", {})
|
80
84
|
self.resource = resource
|
81
|
-
self.resource_args = resource_args
|
85
|
+
self.resource_args = resource_args.copy()
|
86
|
+
self.sweep_id = sweep_id
|
82
87
|
self.python_version: Optional[str] = launch_spec.get("python_version")
|
83
|
-
self.
|
84
|
-
"
|
85
|
-
)
|
88
|
+
self.accelerator_base_image: Optional[str] = resource_args_build.get(
|
89
|
+
"accelerator", {}
|
90
|
+
).get("base_image") or resource_args_build.get("cuda", {}).get("base_image")
|
86
91
|
self._base_image: Optional[str] = launch_spec.get("base_image")
|
87
92
|
self.docker_image: Optional[str] = docker_config.get(
|
88
93
|
"docker_image"
|
@@ -110,6 +115,9 @@ class LaunchProject:
|
|
110
115
|
self.override_entrypoint = self.add_entry_point(
|
111
116
|
overrides.get("entry_point") # type: ignore
|
112
117
|
)
|
118
|
+
if overrides.get("sweep_id") is not None:
|
119
|
+
_logger.info("Adding override sweep id")
|
120
|
+
self.sweep_id = overrides["sweep_id"]
|
113
121
|
if self.docker_image is not None:
|
114
122
|
self.source = LaunchSource.DOCKER
|
115
123
|
self.project_dir = None
|
@@ -172,6 +180,43 @@ class LaunchProject:
|
|
172
180
|
assert self.job is not None
|
173
181
|
return wandb.util.make_docker_image_name_safe(self.job.split(":")[0])
|
174
182
|
|
183
|
+
def fill_macros(self, image: str) -> None:
|
184
|
+
"""Substitute values for macros in resource arguments.
|
185
|
+
|
186
|
+
Certain macros can be used in resource args. These macros allow the
|
187
|
+
user to set resource args dynamically in the context of the
|
188
|
+
run being launched. The macros are given in the ${macro} format. The
|
189
|
+
following macros are currently supported:
|
190
|
+
|
191
|
+
${project_name} - the name of the project the run is being launched to.
|
192
|
+
${entity_name} - the owner of the project the run being launched to.
|
193
|
+
${run_id} - the id of the run being launched.
|
194
|
+
${run_name} - the name of the run that is launching.
|
195
|
+
${image_uri} - the URI of the container image for this run.
|
196
|
+
|
197
|
+
Additionally, you may use ${<ENV-VAR-NAME>} to refer to the value of any
|
198
|
+
environment variables that you plan to set in the environment of any
|
199
|
+
agents that will receive these resource args.
|
200
|
+
|
201
|
+
Calling this method will overwrite the contents of self.resource_args
|
202
|
+
with the substituted values.
|
203
|
+
|
204
|
+
Args:
|
205
|
+
image (str): The image name to fill in for ${wandb-image}.
|
206
|
+
|
207
|
+
Returns:
|
208
|
+
None
|
209
|
+
"""
|
210
|
+
update_dict = {
|
211
|
+
"project_name": self.target_project,
|
212
|
+
"entity_name": self.target_entity,
|
213
|
+
"run_id": self.run_id,
|
214
|
+
"run_name": self.name,
|
215
|
+
"image_uri": image,
|
216
|
+
}
|
217
|
+
update_dict.update(os.environ)
|
218
|
+
self.resource_args = recursive_macro_sub(self.resource_args, update_dict)
|
219
|
+
|
175
220
|
def build_required(self) -> bool:
|
176
221
|
"""Checks the source to see if a build is required."""
|
177
222
|
# since the image tag for images built from jobs
|
@@ -416,6 +461,7 @@ def create_project_from_spec(launch_spec: Dict[str, Any], api: Api) -> LaunchPro
|
|
416
461
|
launch_spec.get("resource", None),
|
417
462
|
launch_spec.get("resource_args", {}),
|
418
463
|
launch_spec.get("run_id", None),
|
464
|
+
launch_spec.get("sweep_id", {}),
|
419
465
|
)
|
420
466
|
|
421
467
|
|
@@ -446,8 +492,8 @@ def fetch_and_validate_project(
|
|
446
492
|
launch_project._fetch_project_local(internal_api=api)
|
447
493
|
|
448
494
|
assert launch_project.project_dir is not None
|
449
|
-
# this prioritizes pip, and we don't support any cases where both are present
|
450
|
-
#
|
495
|
+
# this prioritizes pip, and we don't support any cases where both are present conda projects when uploaded to
|
496
|
+
# wandb become pip projects via requirements.frozen.txt, wandb doesn't preserve conda envs
|
451
497
|
if os.path.exists(
|
452
498
|
os.path.join(launch_project.project_dir, "requirements.txt")
|
453
499
|
) or os.path.exists(
|
wandb/sdk/launch/agent/agent.py
CHANGED
@@ -5,7 +5,6 @@ import pprint
|
|
5
5
|
import threading
|
6
6
|
import time
|
7
7
|
import traceback
|
8
|
-
from dataclasses import dataclass
|
9
8
|
from multiprocessing import Event
|
10
9
|
from multiprocessing.pool import ThreadPool
|
11
10
|
from typing import Any, Dict, List, Optional, Union
|
@@ -13,22 +12,18 @@ from typing import Any, Dict, List, Optional, Union
|
|
13
12
|
import wandb
|
14
13
|
from wandb.apis.internal import Api
|
15
14
|
from wandb.errors import CommError
|
16
|
-
from wandb.sdk.launch.
|
15
|
+
from wandb.sdk.launch.launch_add import launch_add
|
17
16
|
from wandb.sdk.launch.runner.local_container import LocalSubmittedRun
|
18
|
-
from wandb.sdk.launch.sweeps import
|
17
|
+
from wandb.sdk.launch.sweeps.scheduler import Scheduler
|
19
18
|
from wandb.sdk.lib import runid
|
20
19
|
|
21
20
|
from .. import loader
|
22
21
|
from .._project_spec import create_project_from_spec, fetch_and_validate_project
|
23
22
|
from ..builder.build import construct_builder_args
|
24
|
-
from ..
|
25
|
-
from ..utils import
|
26
|
-
|
27
|
-
|
28
|
-
PROJECT_SYNCHRONOUS,
|
29
|
-
LaunchDockerError,
|
30
|
-
LaunchError,
|
31
|
-
)
|
23
|
+
from ..errors import LaunchDockerError, LaunchError
|
24
|
+
from ..utils import LAUNCH_DEFAULT_PROJECT, LOG_PREFIX, PROJECT_SYNCHRONOUS
|
25
|
+
from .job_status_tracker import JobAndRunStatusTracker
|
26
|
+
from .run_queue_item_file_saver import RunQueueItemFileSaver
|
32
27
|
|
33
28
|
AGENT_POLLING_INTERVAL = 10
|
34
29
|
ACTIVE_SWEEP_POLLING_INTERVAL = 1 # more frequent when we know we have jobs
|
@@ -37,30 +32,13 @@ AGENT_POLLING = "POLLING"
|
|
37
32
|
AGENT_RUNNING = "RUNNING"
|
38
33
|
AGENT_KILLED = "KILLED"
|
39
34
|
|
40
|
-
|
41
|
-
|
42
|
-
_logger = logging.getLogger(__name__)
|
35
|
+
HIDDEN_AGENT_RUN_TYPE = "sweep-controller"
|
43
36
|
|
37
|
+
MAX_THREADS = 64
|
44
38
|
|
45
|
-
|
46
|
-
class JobAndRunStatus:
|
47
|
-
run_queue_item_id: str
|
48
|
-
run_id: Optional[str] = None
|
49
|
-
project: Optional[str] = None
|
50
|
-
entity: Optional[str] = None
|
51
|
-
run: Optional[AbstractRun] = None
|
52
|
-
failed_to_start: bool = False
|
53
|
-
completed_status: Optional[str] = None
|
54
|
-
is_scheduler: bool = False
|
55
|
-
|
56
|
-
@property
|
57
|
-
def job_completed(self) -> bool:
|
58
|
-
return self.failed_to_start or self.completed_status is not None
|
39
|
+
MAX_RESUME_COUNT = 5
|
59
40
|
|
60
|
-
|
61
|
-
self.run_id = launch_project.run_id
|
62
|
-
self.project = launch_project.target_project
|
63
|
-
self.entity = launch_project.target_entity
|
41
|
+
_logger = logging.getLogger(__name__)
|
64
42
|
|
65
43
|
|
66
44
|
def _convert_access(access: str) -> str:
|
@@ -101,16 +79,21 @@ def _max_from_config(
|
|
101
79
|
return max_from_config
|
102
80
|
|
103
81
|
|
104
|
-
def
|
82
|
+
def _is_scheduler_job(run_spec: Dict[str, Any]) -> bool:
|
105
83
|
"""Determine whether a job/runSpec is a sweep scheduler."""
|
106
84
|
if not run_spec:
|
107
|
-
_logger.debug("Recieved runSpec in
|
85
|
+
_logger.debug("Recieved runSpec in _is_scheduler_job that was empty")
|
108
86
|
|
109
|
-
if run_spec.get("uri") !=
|
87
|
+
if run_spec.get("uri") != Scheduler.PLACEHOLDER_URI:
|
110
88
|
return False
|
111
89
|
|
112
90
|
if run_spec.get("resource") == "local-process":
|
113
|
-
#
|
91
|
+
# Any job pushed to a run queue that has a scheduler uri is
|
92
|
+
# allowed to use local-process
|
93
|
+
if run_spec.get("job"):
|
94
|
+
return True
|
95
|
+
|
96
|
+
# If a scheduler is local-process and run through CLI, also
|
114
97
|
# confirm command is in format: [wandb scheduler <sweep>]
|
115
98
|
cmd = run_spec.get("overrides", {}).get("entry_point", [])
|
116
99
|
if len(cmd) < 3:
|
@@ -137,7 +120,7 @@ class LaunchAgent:
|
|
137
120
|
self._api = api
|
138
121
|
self._base_url = self._api.settings().get("base_url")
|
139
122
|
self._ticks = 0
|
140
|
-
self._jobs: Dict[int,
|
123
|
+
self._jobs: Dict[int, JobAndRunStatusTracker] = {}
|
141
124
|
self._jobs_lock = threading.Lock()
|
142
125
|
self._jobs_event = Event()
|
143
126
|
self._jobs_event.set()
|
@@ -169,15 +152,40 @@ class LaunchAgent:
|
|
169
152
|
self.gorilla_supports_agents,
|
170
153
|
)
|
171
154
|
self._id = create_response["launchAgentId"]
|
172
|
-
self._name = "" # hacky: want to display this to the user but we don't get it back from gql until polling starts. fix later
|
173
155
|
if self._api.entity_is_team(self._entity):
|
174
156
|
wandb.termwarn(
|
175
157
|
f"{LOG_PREFIX}Agent is running on team entity ({self._entity}). Members of this team will be able to run code on this device."
|
176
158
|
)
|
177
159
|
|
178
|
-
|
160
|
+
agent_response = self._api.get_launch_agent(
|
161
|
+
self._id, self.gorilla_supports_agents
|
162
|
+
)
|
163
|
+
self._name = agent_response["name"]
|
164
|
+
self._init_agent_run()
|
165
|
+
|
166
|
+
def fail_run_queue_item(
|
167
|
+
self,
|
168
|
+
run_queue_item_id: str,
|
169
|
+
message: str,
|
170
|
+
phase: str,
|
171
|
+
files: Optional[List[str]] = None,
|
172
|
+
) -> None:
|
179
173
|
if self._gorilla_supports_fail_run_queue_items:
|
180
|
-
self._api.fail_run_queue_item(run_queue_item_id)
|
174
|
+
self._api.fail_run_queue_item(run_queue_item_id, message, phase, files)
|
175
|
+
|
176
|
+
def _init_agent_run(self) -> None:
|
177
|
+
# TODO: has it been long enough that all backends support agents?
|
178
|
+
if self.gorilla_supports_agents:
|
179
|
+
settings = wandb.Settings(silent=True, disable_git=True)
|
180
|
+
self._wandb_run = wandb.init(
|
181
|
+
project=self._project,
|
182
|
+
entity=self._entity,
|
183
|
+
settings=settings,
|
184
|
+
id=self._name,
|
185
|
+
job_type=HIDDEN_AGENT_RUN_TYPE,
|
186
|
+
)
|
187
|
+
else:
|
188
|
+
self._wandb_run = None
|
181
189
|
|
182
190
|
@property
|
183
191
|
def thread_ids(self) -> List[int]:
|
@@ -253,24 +261,43 @@ class LaunchAgent:
|
|
253
261
|
if not update_ret["success"]:
|
254
262
|
wandb.termerror(f"{LOG_PREFIX}Failed to update agent status to {status}")
|
255
263
|
|
256
|
-
def finish_thread_id(
|
264
|
+
def finish_thread_id(
|
265
|
+
self,
|
266
|
+
thread_id: int,
|
267
|
+
exception: Optional[Union[Exception, LaunchDockerError]] = None,
|
268
|
+
) -> None:
|
257
269
|
"""Removes the job from our list for now."""
|
258
270
|
job_and_run_status = self._jobs[thread_id]
|
259
|
-
if
|
260
|
-
|
261
|
-
|
271
|
+
if (
|
272
|
+
job_and_run_status.entity is not None
|
273
|
+
and job_and_run_status.entity != self._entity
|
274
|
+
):
|
262
275
|
_logger.info(
|
263
276
|
"Skipping check for completed run status because run is on a different entity than agent"
|
264
277
|
)
|
278
|
+
elif exception is not None:
|
279
|
+
tb_str = traceback.format_exception(
|
280
|
+
type(exception), value=exception, tb=exception.__traceback__
|
281
|
+
)
|
282
|
+
fnames = job_and_run_status.saver.save_contents(
|
283
|
+
"".join(tb_str), "error.log", "error"
|
284
|
+
)
|
285
|
+
self.fail_run_queue_item(
|
286
|
+
job_and_run_status.run_queue_item_id,
|
287
|
+
str(exception),
|
288
|
+
job_and_run_status.err_stage,
|
289
|
+
fnames,
|
290
|
+
)
|
265
291
|
elif job_and_run_status.completed_status not in ["stopped", "failed"]:
|
266
292
|
_logger.info(
|
267
293
|
"Skipping check for completed run status because run was successful"
|
268
294
|
)
|
269
|
-
|
295
|
+
elif job_and_run_status.run is not None:
|
270
296
|
run_info = None
|
271
297
|
# sweep runs exist but have no info before they are started
|
272
298
|
# so run_info returned will be None
|
273
299
|
# normal runs just throw a comm error
|
300
|
+
# TODO: make more clear
|
274
301
|
try:
|
275
302
|
run_info = self._api.get_run_info(
|
276
303
|
self._entity, job_and_run_status.project, job_and_run_status.run_id
|
@@ -279,7 +306,22 @@ class LaunchAgent:
|
|
279
306
|
except CommError:
|
280
307
|
pass
|
281
308
|
if run_info is None:
|
282
|
-
|
309
|
+
_msg = "The submitted run was not successfully started"
|
310
|
+
fnames = None
|
311
|
+
|
312
|
+
logs = job_and_run_status.run.get_logs()
|
313
|
+
if logs:
|
314
|
+
fnames = job_and_run_status.saver.save_contents(
|
315
|
+
logs, "error.log", "error"
|
316
|
+
)
|
317
|
+
self.fail_run_queue_item(
|
318
|
+
job_and_run_status.run_queue_item_id, _msg, "run", fnames
|
319
|
+
)
|
320
|
+
else:
|
321
|
+
_logger.info("Finish thread id had no exception, ror run")
|
322
|
+
wandb._sentry.exception(
|
323
|
+
"launch agent called finish thread id on thread without run or exception"
|
324
|
+
)
|
283
325
|
|
284
326
|
# TODO: keep logs or something for the finished jobs
|
285
327
|
with self._jobs_lock:
|
@@ -296,7 +338,9 @@ class LaunchAgent:
|
|
296
338
|
if job.job_completed:
|
297
339
|
self.finish_thread_id(thread_id)
|
298
340
|
|
299
|
-
def run_job(
|
341
|
+
def run_job(
|
342
|
+
self, job: Dict[str, Any], queue: str, file_saver: RunQueueItemFileSaver
|
343
|
+
) -> None:
|
300
344
|
"""Set up project and run the job.
|
301
345
|
|
302
346
|
Arguments:
|
@@ -322,6 +366,8 @@ class LaunchAgent:
|
|
322
366
|
job,
|
323
367
|
self.default_config,
|
324
368
|
self._api,
|
369
|
+
queue,
|
370
|
+
file_saver,
|
325
371
|
),
|
326
372
|
)
|
327
373
|
|
@@ -367,7 +413,6 @@ class LaunchAgent:
|
|
367
413
|
agent_response = self._api.get_launch_agent(
|
368
414
|
self._id, self.gorilla_supports_agents
|
369
415
|
)
|
370
|
-
self._name = agent_response["name"] # hack: first time we get name
|
371
416
|
if agent_response["stopPolling"]:
|
372
417
|
# shutdown process and all jobs if requested from ui
|
373
418
|
raise KeyboardInterrupt
|
@@ -376,7 +421,10 @@ class LaunchAgent:
|
|
376
421
|
for queue in self._queues:
|
377
422
|
job = self.pop_from_queue(queue)
|
378
423
|
if job:
|
379
|
-
|
424
|
+
file_saver = RunQueueItemFileSaver(
|
425
|
+
self._wandb_run, job["runQueueItemId"]
|
426
|
+
)
|
427
|
+
if _is_scheduler_job(job.get("runSpec")):
|
380
428
|
# If job is a scheduler, and we are already at the cap, ignore,
|
381
429
|
# don't ack, and it will be pushed back onto the queue in 1 min
|
382
430
|
if self.num_running_schedulers >= self._max_schedulers:
|
@@ -388,13 +436,25 @@ class LaunchAgent:
|
|
388
436
|
continue
|
389
437
|
|
390
438
|
try:
|
391
|
-
self.run_job(job)
|
439
|
+
self.run_job(job, queue, file_saver)
|
392
440
|
except Exception as e:
|
393
441
|
wandb.termerror(
|
394
442
|
f"{LOG_PREFIX}Error running job: {traceback.format_exc()}"
|
395
443
|
)
|
396
444
|
wandb._sentry.exception(e)
|
397
|
-
|
445
|
+
|
446
|
+
# always the first phase, because we only enter phase 2 within the thread
|
447
|
+
files = file_saver.save_contents(
|
448
|
+
contents=traceback.format_exc(),
|
449
|
+
fname="error.log",
|
450
|
+
file_sub_type="error",
|
451
|
+
)
|
452
|
+
self.fail_run_queue_item(
|
453
|
+
run_queue_item_id=job["runQueueItemId"],
|
454
|
+
message=str(e),
|
455
|
+
phase="agent",
|
456
|
+
files=files,
|
457
|
+
)
|
398
458
|
|
399
459
|
for thread_id in self.thread_ids:
|
400
460
|
self._update_finished(thread_id)
|
@@ -429,20 +489,27 @@ class LaunchAgent:
|
|
429
489
|
job: Dict[str, Any],
|
430
490
|
default_config: Dict[str, Any],
|
431
491
|
api: Api,
|
492
|
+
queue: str,
|
493
|
+
file_saver: RunQueueItemFileSaver,
|
432
494
|
) -> None:
|
433
495
|
thread_id = threading.current_thread().ident
|
434
496
|
assert thread_id is not None
|
497
|
+
job_tracker = JobAndRunStatusTracker(job["runQueueItemId"], queue, file_saver)
|
498
|
+
with self._jobs_lock:
|
499
|
+
self._jobs[thread_id] = job_tracker
|
435
500
|
try:
|
436
|
-
self._thread_run_job(
|
501
|
+
self._thread_run_job(
|
502
|
+
launch_spec, job, default_config, api, queue, thread_id, job_tracker
|
503
|
+
)
|
437
504
|
except LaunchDockerError as e:
|
438
505
|
wandb.termerror(
|
439
506
|
f"{LOG_PREFIX}agent {self._name} encountered an issue while starting Docker, see above output for details."
|
440
507
|
)
|
441
|
-
self.finish_thread_id(thread_id)
|
508
|
+
self.finish_thread_id(thread_id, e)
|
442
509
|
wandb._sentry.exception(e)
|
443
510
|
except Exception as e:
|
444
511
|
wandb.termerror(f"{LOG_PREFIX}Error running job: {traceback.format_exc()}")
|
445
|
-
self.finish_thread_id(thread_id)
|
512
|
+
self.finish_thread_id(thread_id, e)
|
446
513
|
wandb._sentry.exception(e)
|
447
514
|
|
448
515
|
def _thread_run_job(
|
@@ -451,11 +518,10 @@ class LaunchAgent:
|
|
451
518
|
job: Dict[str, Any],
|
452
519
|
default_config: Dict[str, Any],
|
453
520
|
api: Api,
|
521
|
+
queue: str,
|
454
522
|
thread_id: int,
|
523
|
+
job_tracker: JobAndRunStatusTracker,
|
455
524
|
) -> None:
|
456
|
-
job_tracker = JobAndRunStatus(job["runQueueItemId"])
|
457
|
-
with self._jobs_lock:
|
458
|
-
self._jobs[thread_id] = job_tracker
|
459
525
|
project = create_project_from_spec(launch_spec, api)
|
460
526
|
job_tracker.update_run_info(project)
|
461
527
|
_logger.info("Fetching and validating project...")
|
@@ -480,9 +546,8 @@ class LaunchAgent:
|
|
480
546
|
backend = loader.runner_from_config(resource, api, backend_config, environment)
|
481
547
|
_logger.info("Backend loaded...")
|
482
548
|
api.ack_run_queue_item(job["runQueueItemId"], project.run_id)
|
483
|
-
run = backend.run(project, builder)
|
484
|
-
|
485
|
-
if _job_is_scheduler(launch_spec):
|
549
|
+
run = backend.run(project, builder, job_tracker)
|
550
|
+
if _is_scheduler_job(launch_spec):
|
486
551
|
with self._jobs_lock:
|
487
552
|
self._jobs[thread_id].is_scheduler = True
|
488
553
|
wandb.termlog(
|
@@ -497,15 +562,17 @@ class LaunchAgent:
|
|
497
562
|
with self._jobs_lock:
|
498
563
|
job_tracker.run = run
|
499
564
|
while self._jobs_event.is_set():
|
500
|
-
if self._check_run_finished(job_tracker):
|
565
|
+
if self._check_run_finished(job_tracker, launch_spec):
|
501
566
|
return
|
502
567
|
time.sleep(AGENT_POLLING_INTERVAL)
|
503
568
|
# temp: for local, kill all jobs. we don't yet have good handling for different
|
504
569
|
# types of runners in general
|
505
|
-
if isinstance(run, LocalSubmittedRun):
|
506
|
-
run.
|
570
|
+
if isinstance(run, LocalSubmittedRun) and run._command_proc is not None:
|
571
|
+
run._command_proc.kill()
|
507
572
|
|
508
|
-
def _check_run_finished(
|
573
|
+
def _check_run_finished(
|
574
|
+
self, job_tracker: JobAndRunStatusTracker, launch_spec: Dict[str, Any]
|
575
|
+
) -> bool:
|
509
576
|
if job_tracker.completed_status:
|
510
577
|
return True
|
511
578
|
|
@@ -522,13 +589,28 @@ class LaunchAgent:
|
|
522
589
|
try:
|
523
590
|
run = job_tracker.run
|
524
591
|
status = run.get_status().state
|
525
|
-
if status in ["stopped", "failed", "finished"]:
|
592
|
+
if status in ["stopped", "failed", "finished", "preempted"]:
|
526
593
|
if job_tracker.is_scheduler:
|
527
594
|
wandb.termlog(f"{LOG_PREFIX}Scheduler finished with ID: {run.id}")
|
528
595
|
else:
|
529
596
|
wandb.termlog(f"{LOG_PREFIX}Job finished with ID: {run.id}")
|
530
597
|
with self._jobs_lock:
|
531
598
|
job_tracker.completed_status = status
|
599
|
+
if status == "preempted":
|
600
|
+
config = launch_spec.copy()
|
601
|
+
config["run_id"] = job_tracker.run_id
|
602
|
+
config["_resume_count"] = config.get("_resume_count", 0) + 1
|
603
|
+
if config["_resume_count"] > MAX_RESUME_COUNT:
|
604
|
+
wandb.termlog(
|
605
|
+
f"{LOG_PREFIX}Run {job_tracker.run_id} has already resumed {MAX_RESUME_COUNT} times."
|
606
|
+
)
|
607
|
+
return True
|
608
|
+
wandb.termlog(f"{LOG_PREFIX}Requeueing run {job_tracker.run_id}.")
|
609
|
+
launch_add(
|
610
|
+
config=config,
|
611
|
+
project_queue=self._project,
|
612
|
+
queue_name=job_tracker.queue,
|
613
|
+
)
|
532
614
|
return True
|
533
615
|
return False
|
534
616
|
except LaunchError as e:
|
@@ -0,0 +1,34 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
from wandb.sdk.launch._project_spec import LaunchProject
|
5
|
+
|
6
|
+
from ..runner.abstract import AbstractRun
|
7
|
+
from .run_queue_item_file_saver import RunQueueItemFileSaver
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass
|
11
|
+
class JobAndRunStatusTracker:
|
12
|
+
run_queue_item_id: str
|
13
|
+
queue: str
|
14
|
+
saver: RunQueueItemFileSaver
|
15
|
+
run_id: Optional[str] = None
|
16
|
+
project: Optional[str] = None
|
17
|
+
entity: Optional[str] = None
|
18
|
+
run: Optional[AbstractRun] = None
|
19
|
+
failed_to_start: bool = False
|
20
|
+
completed_status: Optional[str] = None
|
21
|
+
is_scheduler: bool = False
|
22
|
+
err_stage: str = "agent"
|
23
|
+
|
24
|
+
@property
|
25
|
+
def job_completed(self) -> bool:
|
26
|
+
return self.failed_to_start or self.completed_status is not None
|
27
|
+
|
28
|
+
def update_run_info(self, launch_project: LaunchProject) -> None:
|
29
|
+
self.run_id = launch_project.run_id
|
30
|
+
self.project = launch_project.target_project
|
31
|
+
self.entity = launch_project.target_entity
|
32
|
+
|
33
|
+
def set_err_stage(self, stage: str) -> None:
|
34
|
+
self.err_stage = stage
|