wandb 0.16.6__py3-none-any.whl → 0.17.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- package_readme.md +95 -0
- wandb/__init__.py +2 -3
- wandb/agents/pyagent.py +0 -1
- wandb/analytics/sentry.py +2 -1
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/internals/protocols.py +30 -56
- wandb/apis/importers/mlflow.py +13 -26
- wandb/apis/importers/wandb.py +8 -14
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +55 -3
- wandb/apis/public/artifacts.py +1 -0
- wandb/apis/public/files.py +1 -0
- wandb/apis/public/history.py +1 -0
- wandb/apis/public/jobs.py +17 -4
- wandb/apis/public/projects.py +1 -0
- wandb/apis/public/reports.py +1 -0
- wandb/apis/public/runs.py +15 -17
- wandb/apis/public/sweeps.py +1 -0
- wandb/apis/public/teams.py +1 -0
- wandb/apis/public/users.py +1 -0
- wandb/apis/reports/v1/_blocks.py +3 -7
- wandb/apis/reports/v2/gql.py +1 -0
- wandb/apis/reports/v2/interface.py +3 -4
- wandb/apis/reports/v2/internal.py +5 -8
- wandb/cli/cli.py +92 -22
- wandb/data_types.py +9 -6
- wandb/docker/__init__.py +1 -1
- wandb/env.py +38 -8
- wandb/errors/__init__.py +5 -0
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +4 -106
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/catboost/catboost.py +1 -1
- wandb/integration/fastai/__init__.py +1 -0
- wandb/integration/huggingface/resolver.py +2 -2
- wandb/integration/keras/__init__.py +1 -0
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/keras.py +7 -7
- wandb/integration/langchain/wandb_tracer.py +1 -0
- wandb/integration/lightning/fabric/logger.py +1 -3
- wandb/integration/metaflow/metaflow.py +41 -6
- wandb/integration/openai/fine_tuning.py +3 -3
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/old/summary.py +1 -1
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +364 -332
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +322 -316
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/wandb_deprecated.py +7 -1
- wandb/proto/wandb_internal_codegen.py +3 -29
- wandb/sdk/artifacts/artifact.py +26 -11
- wandb/sdk/artifacts/artifact_download_logger.py +1 -0
- wandb/sdk/artifacts/artifact_file_cache.py +18 -4
- wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
- wandb/sdk/artifacts/artifact_manifest.py +1 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +7 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
- wandb/sdk/artifacts/artifact_saver.py +2 -8
- wandb/sdk/artifacts/artifact_state.py +1 -0
- wandb/sdk/artifacts/artifact_ttl.py +1 -0
- wandb/sdk/artifacts/exceptions.py +1 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -42
- wandb/sdk/artifacts/storage_policy.py +2 -12
- wandb/sdk/data_types/_dtypes.py +8 -8
- wandb/sdk/data_types/base_types/media.py +3 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/auto_logging.py +5 -6
- wandb/sdk/integration_utils/data_logging.py +10 -6
- wandb/sdk/interface/interface.py +68 -32
- wandb/sdk/interface/interface_shared.py +7 -13
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +5 -18
- wandb/sdk/internal/handler.py +18 -2
- wandb/sdk/internal/internal.py +0 -1
- wandb/sdk/internal/internal_api.py +1 -129
- wandb/sdk/internal/internal_util.py +0 -1
- wandb/sdk/internal/job_builder.py +159 -45
- wandb/sdk/internal/profiler.py +1 -0
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/run.py +1 -0
- wandb/sdk/internal/sender.py +1 -2
- wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
- wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
- wandb/sdk/internal/system/assets/interfaces.py +6 -8
- wandb/sdk/internal/system/assets/open_metrics.py +2 -2
- wandb/sdk/internal/system/assets/trainium.py +1 -3
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +4 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +186 -224
- wandb/sdk/launch/agent/agent.py +37 -13
- wandb/sdk/launch/agent/config.py +72 -14
- wandb/sdk/launch/builder/abstract.py +69 -1
- wandb/sdk/launch/builder/build.py +156 -555
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +8 -23
- wandb/sdk/launch/builder/kaniko_builder.py +12 -25
- wandb/sdk/launch/builder/noop.py +1 -0
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +47 -37
- wandb/sdk/launch/environment/abstract.py +1 -0
- wandb/sdk/launch/environment/gcp_environment.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +1 -0
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +217 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/loader.py +1 -0
- wandb/sdk/launch/registry/abstract.py +1 -0
- wandb/sdk/launch/registry/azure_container_registry.py +1 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
- wandb/sdk/launch/registry/local_registry.py +1 -0
- wandb/sdk/launch/runner/abstract.py +1 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +9 -10
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +21 -20
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +4 -3
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +15 -140
- wandb/sdk/lib/_settings_toposort_generated.py +0 -5
- wandb/sdk/lib/fsm.py +8 -12
- wandb/sdk/lib/gitlib.py +4 -4
- wandb/sdk/lib/import_hooks.py +1 -1
- wandb/sdk/lib/lazyloader.py +0 -1
- wandb/sdk/lib/proto_util.py +23 -2
- wandb/sdk/lib/redirect.py +19 -14
- wandb/sdk/lib/retry.py +3 -2
- wandb/sdk/lib/tracelog.py +1 -1
- wandb/sdk/service/service.py +19 -16
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_init.py +14 -55
- wandb/sdk/wandb_manager.py +2 -2
- wandb/sdk/wandb_require.py +5 -0
- wandb/sdk/wandb_run.py +114 -56
- wandb/sdk/wandb_settings.py +0 -48
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sklearn/__init__.py +1 -0
- wandb/sklearn/plot/__init__.py +1 -0
- wandb/sklearn/plot/classifier.py +11 -12
- wandb/sklearn/plot/clusterer.py +2 -1
- wandb/sklearn/plot/regressor.py +1 -0
- wandb/sklearn/plot/shared.py +1 -0
- wandb/sklearn/utils.py +1 -0
- wandb/testing/relay.py +4 -4
- wandb/trigger.py +1 -0
- wandb/util.py +67 -54
- wandb/wandb_controller.py +2 -3
- wandb/wandb_torch.py +1 -2
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/METADATA +67 -70
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/RECORD +177 -187
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/WHEEL +1 -2
- wandb/bin/apple_gpu_stats +0 -0
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -18
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- wandb-0.16.6.dist-info/top_level.txt +0 -1
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info/licenses}/LICENSE +0 -0
@@ -2,18 +2,20 @@
|
|
2
2
|
|
3
3
|
Arguments can come from a launch spec or call to wandb launch.
|
4
4
|
"""
|
5
|
+
|
5
6
|
import enum
|
7
|
+
import json
|
6
8
|
import logging
|
7
9
|
import os
|
8
10
|
import tempfile
|
9
11
|
from copy import deepcopy
|
10
12
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
|
11
13
|
|
14
|
+
from six.moves import shlex_quote
|
15
|
+
|
12
16
|
import wandb
|
13
|
-
import wandb.docker as docker
|
14
17
|
from wandb.apis.internal import Api
|
15
18
|
from wandb.errors import CommError
|
16
|
-
from wandb.sdk.launch import utils
|
17
19
|
from wandb.sdk.launch.utils import get_entrypoint_file
|
18
20
|
from wandb.sdk.lib.runid import generate_id
|
19
21
|
|
@@ -33,15 +35,18 @@ IMAGE_TAG_MAX_LENGTH = 32
|
|
33
35
|
|
34
36
|
|
35
37
|
class LaunchSource(enum.IntEnum):
|
36
|
-
|
37
|
-
GIT: int = 2
|
38
|
-
LOCAL: int = 3
|
39
|
-
DOCKER: int = 4
|
40
|
-
JOB: int = 5
|
38
|
+
"""Enumeration of possible sources for a launch project.
|
41
39
|
|
40
|
+
Attributes:
|
41
|
+
DOCKER: Source is a Docker image. This can happen if a user runs
|
42
|
+
`wandb launch -d <docker-image>`.
|
43
|
+
JOB: Source is a job. This is standard case.
|
44
|
+
SCHEDULER: Source is a wandb sweep scheduler command.
|
45
|
+
"""
|
42
46
|
|
43
|
-
|
44
|
-
|
47
|
+
DOCKER: int = 1
|
48
|
+
JOB: int = 2
|
49
|
+
SCHEDULER: int = 3
|
45
50
|
|
46
51
|
|
47
52
|
class LaunchProject:
|
@@ -60,8 +65,16 @@ class LaunchProject:
|
|
60
65
|
|
61
66
|
This class is stateful and certain methods can only be called after
|
62
67
|
`LaunchProject.fetch_and_validate_project()` has been called.
|
68
|
+
|
69
|
+
Notes on the entrypoint:
|
70
|
+
- The entrypoint is the command that will be run inside the container.
|
71
|
+
- The LaunchProject stores two entrypoints
|
72
|
+
- The job entrypoint is the entrypoint specified in the job's config.
|
73
|
+
- The override entrypoint is the entrypoint specified in the launch spec.
|
74
|
+
- The override entrypoint takes precedence over the job entrypoint.
|
63
75
|
"""
|
64
76
|
|
77
|
+
# This init is way to long, and there are too many attributes on this sucker.
|
65
78
|
def __init__(
|
66
79
|
self,
|
67
80
|
uri: Optional[str],
|
@@ -79,9 +92,6 @@ class LaunchProject:
|
|
79
92
|
run_id: Optional[str],
|
80
93
|
sweep_id: Optional[str] = None,
|
81
94
|
):
|
82
|
-
if uri is not None and utils.is_bare_wandb_uri(uri):
|
83
|
-
uri = api.settings("base_url") + uri
|
84
|
-
_logger.info(f"{LOG_PREFIX}Updating uri with base uri: {uri}")
|
85
95
|
self.uri = uri
|
86
96
|
self.job = job
|
87
97
|
if job is not None:
|
@@ -105,74 +115,57 @@ class LaunchProject:
|
|
105
115
|
self.accelerator_base_image: Optional[str] = resource_args_build.get(
|
106
116
|
"accelerator", {}
|
107
117
|
).get("base_image") or resource_args_build.get("cuda", {}).get("base_image")
|
108
|
-
self._base_image: Optional[str] = launch_spec.get("base_image")
|
109
118
|
self.docker_image: Optional[str] = docker_config.get(
|
110
119
|
"docker_image"
|
111
120
|
) or launch_spec.get("image_uri")
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
self.
|
117
|
-
self.
|
118
|
-
self.
|
119
|
-
self.overrides = overrides
|
120
|
-
self.override_args: List[str] = overrides.get("args", [])
|
121
|
-
self.override_config: Dict[str, Any] = overrides.get("run_config", {})
|
122
|
-
self.override_artifacts: Dict[str, Any] = overrides.get("artifacts", {})
|
123
|
-
self.override_entrypoint: Optional[EntryPoint] = None
|
124
|
-
self.override_dockerfile: Optional[str] = overrides.get("dockerfile")
|
121
|
+
self.docker_user_id = docker_config.get("user_id", 1000)
|
122
|
+
self._entry_point: Optional[EntryPoint] = (
|
123
|
+
None # todo: keep multiple entrypoint support?
|
124
|
+
)
|
125
|
+
self.init_overrides(overrides)
|
126
|
+
self.init_source()
|
127
|
+
self.init_git(git_info)
|
125
128
|
self.deps_type: Optional[str] = None
|
126
129
|
self._runtime: Optional[str] = None
|
127
130
|
self.run_id = run_id or generate_id()
|
128
131
|
self._queue_name: Optional[str] = None
|
129
132
|
self._queue_entity: Optional[str] = None
|
130
133
|
self._run_queue_item_id: Optional[str] = None
|
131
|
-
self.
|
132
|
-
|
133
|
-
] = None # todo: keep multiple entrypoint support?
|
134
|
-
|
135
|
-
override_entrypoint = overrides.get("entry_point")
|
136
|
-
if override_entrypoint:
|
137
|
-
_logger.info("Adding override entry point")
|
138
|
-
self.override_entrypoint = EntryPoint(
|
139
|
-
name=get_entrypoint_file(override_entrypoint),
|
140
|
-
command=override_entrypoint,
|
141
|
-
)
|
134
|
+
self._job_dockerfile: Optional[str] = None
|
135
|
+
self._job_build_context: Optional[str] = None
|
142
136
|
|
143
|
-
|
144
|
-
_logger.info("Adding override sweep id")
|
145
|
-
self.sweep_id = overrides["sweep_id"]
|
137
|
+
def init_source(self) -> None:
|
146
138
|
if self.docker_image is not None:
|
147
139
|
self.source = LaunchSource.DOCKER
|
148
140
|
self.project_dir = None
|
149
141
|
elif self.job is not None:
|
150
142
|
self.source = LaunchSource.JOB
|
151
143
|
self.project_dir = tempfile.mkdtemp()
|
152
|
-
|
153
|
-
|
154
|
-
self.
|
155
|
-
self.
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
144
|
+
if self.uri and self.uri.startswith("placeholder"):
|
145
|
+
self.source = LaunchSource.SCHEDULER
|
146
|
+
self.project_dir = os.getcwd()
|
147
|
+
self._entry_point = self.override_entrypoint
|
148
|
+
|
149
|
+
def init_git(self, git_info: Dict[str, str]) -> None:
|
150
|
+
self.git_version = git_info.get("version")
|
151
|
+
self.git_repo = git_info.get("repo")
|
152
|
+
|
153
|
+
def init_overrides(self, overrides: Dict[str, Any]) -> None:
|
154
|
+
"""Initialize override attributes for a launch project."""
|
155
|
+
self.overrides = overrides
|
156
|
+
self.override_args: List[str] = overrides.get("args", [])
|
157
|
+
self.override_config: Dict[str, Any] = overrides.get("run_config", {})
|
158
|
+
self.override_artifacts: Dict[str, Any] = overrides.get("artifacts", {})
|
159
|
+
self.override_files: Dict[str, Any] = overrides.get("files", {})
|
160
|
+
self.override_entrypoint: Optional[EntryPoint] = None
|
161
|
+
self.override_dockerfile: Optional[str] = overrides.get("dockerfile")
|
162
|
+
override_entrypoint = overrides.get("entry_point")
|
163
|
+
if override_entrypoint:
|
164
|
+
_logger.info("Adding override entry point")
|
165
|
+
self.override_entrypoint = EntryPoint(
|
166
|
+
name=get_entrypoint_file(override_entrypoint),
|
167
|
+
command=override_entrypoint,
|
163
168
|
)
|
164
|
-
self.uri = os.getcwd()
|
165
|
-
self.source = LaunchSource.LOCAL
|
166
|
-
self.project_dir = self.uri
|
167
|
-
else:
|
168
|
-
_logger.info(f"URI {self.uri} indicates a local uri")
|
169
|
-
# assume local
|
170
|
-
if self.uri is not None and not os.path.exists(self.uri):
|
171
|
-
raise LaunchError(
|
172
|
-
"Assumed URI supplied is a local path but path is not valid"
|
173
|
-
)
|
174
|
-
self.source = LaunchSource.LOCAL
|
175
|
-
self.project_dir = self.uri
|
176
169
|
|
177
170
|
def __repr__(self) -> str:
|
178
171
|
"""String representation of LaunchProject."""
|
@@ -211,6 +204,20 @@ class LaunchProject:
|
|
211
204
|
launch_spec.get("sweep_id", {}),
|
212
205
|
)
|
213
206
|
|
207
|
+
@property
|
208
|
+
def job_dockerfile(self) -> Optional[str]:
|
209
|
+
return self._job_dockerfile
|
210
|
+
|
211
|
+
@property
|
212
|
+
def job_build_context(self) -> Optional[str]:
|
213
|
+
return self._job_build_context
|
214
|
+
|
215
|
+
def set_job_dockerfile(self, dockerfile: str) -> None:
|
216
|
+
self._job_dockerfile = dockerfile
|
217
|
+
|
218
|
+
def set_job_build_context(self, build_context: str) -> None:
|
219
|
+
self._job_build_context = build_context
|
220
|
+
|
214
221
|
@property
|
215
222
|
def image_name(self) -> str:
|
216
223
|
if self.docker_image is not None:
|
@@ -274,7 +281,7 @@ class LaunchProject:
|
|
274
281
|
image (str): The image name to fill in for ${wandb-image}.
|
275
282
|
|
276
283
|
Returns:
|
277
|
-
|
284
|
+
Dict[str, Any]: The resource args with all macros filled in.
|
278
285
|
"""
|
279
286
|
update_dict = {
|
280
287
|
"project_name": self.target_project,
|
@@ -324,8 +331,8 @@ class LaunchProject:
|
|
324
331
|
self._docker_image = value
|
325
332
|
self._ensure_not_docker_image_and_local_process()
|
326
333
|
|
327
|
-
def
|
328
|
-
"""Returns the
|
334
|
+
def get_job_entry_point(self) -> Optional["EntryPoint"]:
|
335
|
+
"""Returns the job entrypoint for the project."""
|
329
336
|
# assuming project only has 1 entry point, pull that out
|
330
337
|
# tmp fn until we figure out if we want to support multiple entry points or not
|
331
338
|
if not self._entry_point:
|
@@ -336,8 +343,8 @@ class LaunchProject:
|
|
336
343
|
return None
|
337
344
|
return self._entry_point
|
338
345
|
|
339
|
-
def
|
340
|
-
"""
|
346
|
+
def set_job_entry_point(self, command: List[str]) -> "EntryPoint":
|
347
|
+
"""Set job entrypoint for the project."""
|
341
348
|
assert (
|
342
349
|
self._entry_point is None
|
343
350
|
), "Cannot set entry point twice. Use LaunchProject.override_entrypoint"
|
@@ -358,51 +365,23 @@ class LaunchProject:
|
|
358
365
|
"""
|
359
366
|
if self.source == LaunchSource.DOCKER:
|
360
367
|
return
|
361
|
-
if self.source == LaunchSource.LOCAL:
|
362
|
-
if not self._entry_point:
|
363
|
-
wandb.termlog(
|
364
|
-
f"{LOG_PREFIX}Entry point for repo not specified, defaulting to `python main.py`"
|
365
|
-
)
|
366
|
-
self.set_entry_point(EntrypointDefaults.PYTHON)
|
367
368
|
elif self.source == LaunchSource.JOB:
|
368
369
|
self._fetch_job()
|
369
|
-
else:
|
370
|
-
self._fetch_project_local(internal_api=self.api)
|
371
|
-
|
372
370
|
assert self.project_dir is not None
|
373
|
-
# this prioritizes pip, and we don't support any cases where both are present conda projects when uploaded to
|
374
|
-
# wandb become pip projects via requirements.frozen.txt, wandb doesn't preserve conda envs
|
375
|
-
if os.path.exists(
|
376
|
-
os.path.join(self.project_dir, "requirements.txt")
|
377
|
-
) or os.path.exists(os.path.join(self.project_dir, "requirements.frozen.txt")):
|
378
|
-
self.deps_type = "pip"
|
379
|
-
elif os.path.exists(os.path.join(self.project_dir, "environment.yml")):
|
380
|
-
self.deps_type = "conda"
|
381
371
|
|
372
|
+
# Let's make sure we document this very clearly.
|
382
373
|
def get_image_source_string(self) -> str:
|
383
374
|
"""Returns a unique string identifying the source of an image."""
|
384
|
-
if self.source == LaunchSource.
|
385
|
-
# TODO: more correct to get a hash of local uri contents
|
386
|
-
assert isinstance(self.uri, str)
|
387
|
-
return self.uri
|
388
|
-
elif self.source == LaunchSource.JOB:
|
375
|
+
if self.source == LaunchSource.JOB:
|
389
376
|
assert self._job_artifact is not None
|
390
377
|
return f"{self._job_artifact.name}:v{self._job_artifact.version}"
|
391
|
-
elif self.source == LaunchSource.GIT:
|
392
|
-
assert isinstance(self.uri, str)
|
393
|
-
ret = self.uri
|
394
|
-
if self.git_version:
|
395
|
-
ret += self.git_version
|
396
|
-
return ret
|
397
|
-
elif self.source == LaunchSource.WANDB:
|
398
|
-
assert isinstance(self.uri, str)
|
399
|
-
return self.uri
|
400
378
|
elif self.source == LaunchSource.DOCKER:
|
401
379
|
assert isinstance(self.docker_image, str)
|
402
|
-
_logger.debug("")
|
403
380
|
return self.docker_image
|
404
381
|
else:
|
405
|
-
raise LaunchError(
|
382
|
+
raise LaunchError(
|
383
|
+
"Unknown source type when determining image source string"
|
384
|
+
)
|
406
385
|
|
407
386
|
def _ensure_not_docker_image_and_local_process(self) -> None:
|
408
387
|
"""Ensure that docker image is not specified with local-process resource runner.
|
@@ -430,111 +409,84 @@ class LaunchProject:
|
|
430
409
|
raise LaunchError(
|
431
410
|
f"Error accessing job {self.job}: {msg} on {public_api.settings.get('base_url')}"
|
432
411
|
)
|
433
|
-
job.configure_launch_project(self)
|
412
|
+
job.configure_launch_project(self) # Why is this a method of the job?
|
434
413
|
self._job_artifact = job._job_artifact
|
435
414
|
|
436
|
-
def
|
437
|
-
"""
|
438
|
-
# these asserts are all guaranteed to pass, but are required by mypy
|
439
|
-
assert self.source != LaunchSource.LOCAL and self.source != LaunchSource.JOB
|
440
|
-
assert isinstance(self.uri, str)
|
441
|
-
assert self.project_dir is not None
|
442
|
-
_logger.info("Fetching project locally...")
|
443
|
-
if utils._is_wandb_uri(self.uri):
|
444
|
-
source_entity, source_project, source_run_name = utils.parse_wandb_uri(
|
445
|
-
self.uri
|
446
|
-
)
|
447
|
-
run_info = utils.fetch_wandb_project_run_info(
|
448
|
-
source_entity, source_project, source_run_name, internal_api
|
449
|
-
)
|
450
|
-
program_name = run_info.get("codePath") or run_info["program"]
|
451
|
-
|
452
|
-
self.python_version = run_info.get("python", "3")
|
453
|
-
downloaded_code_artifact = utils.check_and_download_code_artifacts(
|
454
|
-
source_entity,
|
455
|
-
source_project,
|
456
|
-
source_run_name,
|
457
|
-
internal_api,
|
458
|
-
self.project_dir,
|
459
|
-
)
|
460
|
-
if not downloaded_code_artifact:
|
461
|
-
if not run_info["git"]:
|
462
|
-
raise LaunchError(
|
463
|
-
"Reproducing a run requires either an associated git repo or a code artifact logged with `run.log_code()`"
|
464
|
-
)
|
465
|
-
branch_name = utils._fetch_git_repo(
|
466
|
-
self.project_dir,
|
467
|
-
run_info["git"]["remote"],
|
468
|
-
run_info["git"]["commit"],
|
469
|
-
)
|
470
|
-
if self.git_version is None:
|
471
|
-
self.git_version = branch_name
|
472
|
-
patch = utils.fetch_project_diff(
|
473
|
-
source_entity, source_project, source_run_name, internal_api
|
474
|
-
)
|
475
|
-
if patch:
|
476
|
-
utils.apply_patch(patch, self.project_dir)
|
477
|
-
|
478
|
-
# For cases where the entry point wasn't checked into git
|
479
|
-
if not os.path.exists(os.path.join(self.project_dir, program_name)):
|
480
|
-
downloaded_entrypoint = utils.download_entry_point(
|
481
|
-
source_entity,
|
482
|
-
source_project,
|
483
|
-
source_run_name,
|
484
|
-
internal_api,
|
485
|
-
program_name,
|
486
|
-
self.project_dir,
|
487
|
-
)
|
488
|
-
|
489
|
-
if not downloaded_entrypoint:
|
490
|
-
raise LaunchError(
|
491
|
-
f"Entrypoint file: {program_name} does not exist, "
|
492
|
-
"and could not be downloaded. Please specify the entrypoint for this run."
|
493
|
-
)
|
494
|
-
|
495
|
-
if (
|
496
|
-
"_session_history.ipynb" in os.listdir(self.project_dir)
|
497
|
-
or ".ipynb" in program_name
|
498
|
-
):
|
499
|
-
program_name = utils.convert_jupyter_notebook_to_script(
|
500
|
-
program_name, self.project_dir
|
501
|
-
)
|
415
|
+
def get_env_vars_dict(self, api: Api, max_env_length: int) -> Dict[str, str]:
|
416
|
+
"""Generate environment variables for the project.
|
502
417
|
|
503
|
-
|
504
|
-
|
505
|
-
source_entity,
|
506
|
-
source_project,
|
507
|
-
source_run_name,
|
508
|
-
internal_api,
|
509
|
-
self.project_dir,
|
510
|
-
)
|
418
|
+
Arguments:
|
419
|
+
launch_project: LaunchProject to generate environment variables for.
|
511
420
|
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
421
|
+
Returns:
|
422
|
+
Dictionary of environment variables.
|
423
|
+
"""
|
424
|
+
env_vars = {}
|
425
|
+
env_vars["WANDB_BASE_URL"] = api.settings("base_url")
|
426
|
+
override_api_key = self.launch_spec.get("_wandb_api_key")
|
427
|
+
env_vars["WANDB_API_KEY"] = override_api_key or api.api_key
|
428
|
+
if self.target_project:
|
429
|
+
env_vars["WANDB_PROJECT"] = self.target_project
|
430
|
+
env_vars["WANDB_ENTITY"] = self.target_entity
|
431
|
+
env_vars["WANDB_LAUNCH"] = "True"
|
432
|
+
env_vars["WANDB_RUN_ID"] = self.run_id
|
433
|
+
if self.docker_image:
|
434
|
+
env_vars["WANDB_DOCKER"] = self.docker_image
|
435
|
+
if self.name is not None:
|
436
|
+
env_vars["WANDB_NAME"] = self.name
|
437
|
+
if "author" in self.launch_spec and not override_api_key:
|
438
|
+
env_vars["WANDB_USERNAME"] = self.launch_spec["author"]
|
439
|
+
if self.sweep_id:
|
440
|
+
env_vars["WANDB_SWEEP_ID"] = self.sweep_id
|
441
|
+
if self.launch_spec.get("_resume_count", 0) > 0:
|
442
|
+
env_vars["WANDB_RESUME"] = "allow"
|
443
|
+
if self.queue_name:
|
444
|
+
env_vars[wandb.env.LAUNCH_QUEUE_NAME] = self.queue_name
|
445
|
+
if self.queue_entity:
|
446
|
+
env_vars[wandb.env.LAUNCH_QUEUE_ENTITY] = self.queue_entity
|
447
|
+
if self.run_queue_item_id:
|
448
|
+
env_vars[wandb.env.LAUNCH_TRACE_ID] = self.run_queue_item_id
|
449
|
+
|
450
|
+
_inject_wandb_config_env_vars(self.override_config, env_vars, max_env_length)
|
451
|
+
_inject_file_overrides_env_vars(self.override_files, env_vars, max_env_length)
|
452
|
+
|
453
|
+
artifacts = {}
|
454
|
+
# if we're spinning up a launch process from a job
|
455
|
+
# we should tell the run to use that artifact
|
456
|
+
if self.job:
|
457
|
+
artifacts = {wandb.util.LAUNCH_JOB_ARTIFACT_SLOT_NAME: self.job}
|
458
|
+
env_vars["WANDB_ARTIFACTS"] = json.dumps(
|
459
|
+
{**artifacts, **self.override_artifacts}
|
460
|
+
)
|
461
|
+
return env_vars
|
462
|
+
|
463
|
+
def parse_existing_requirements(self) -> str:
|
464
|
+
import pkg_resources
|
465
|
+
|
466
|
+
requirements_line = ""
|
467
|
+
assert self.project_dir is not None
|
468
|
+
base_requirements = os.path.join(self.project_dir, "requirements.txt")
|
469
|
+
if os.path.exists(base_requirements):
|
470
|
+
include_only = set()
|
471
|
+
with open(base_requirements) as f:
|
472
|
+
iter = pkg_resources.parse_requirements(f)
|
473
|
+
while True:
|
474
|
+
try:
|
475
|
+
pkg = next(iter)
|
476
|
+
if hasattr(pkg, "name"):
|
477
|
+
name = pkg.name.lower()
|
478
|
+
else:
|
479
|
+
name = str(pkg)
|
480
|
+
include_only.add(shlex_quote(name))
|
481
|
+
except StopIteration:
|
482
|
+
break
|
483
|
+
# Different versions of pkg_resources throw different errors
|
484
|
+
# just catch them all and ignore packages we can't parse
|
485
|
+
except Exception as e:
|
486
|
+
_logger.warn(f"Unable to parse requirements.txt: {e}")
|
487
|
+
continue
|
488
|
+
requirements_line += "WANDB_ONLY_INCLUDE={} ".format(",".join(include_only))
|
489
|
+
return requirements_line
|
538
490
|
|
539
491
|
|
540
492
|
class EntryPoint:
|
@@ -544,13 +496,6 @@ class EntryPoint:
|
|
544
496
|
self.name = name
|
545
497
|
self.command = command
|
546
498
|
|
547
|
-
def compute_command(self, user_parameters: Optional[List[str]]) -> List[str]:
|
548
|
-
"""Converts user parameter dictionary to a string."""
|
549
|
-
ret = self.command
|
550
|
-
if user_parameters:
|
551
|
-
return ret + user_parameters
|
552
|
-
return ret
|
553
|
-
|
554
499
|
def update_entrypoint_path(self, new_path: str) -> None:
|
555
500
|
"""Updates the entrypoint path to a new path."""
|
556
501
|
if len(self.command) == 2 and (
|
@@ -559,18 +504,35 @@ class EntryPoint:
|
|
559
504
|
self.command[1] = new_path
|
560
505
|
|
561
506
|
|
562
|
-
def
|
563
|
-
|
564
|
-
) ->
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
507
|
+
def _inject_wandb_config_env_vars(
|
508
|
+
config: Dict[str, Any], env_dict: Dict[str, Any], maximum_env_length: int
|
509
|
+
) -> None:
|
510
|
+
str_config = json.dumps(config)
|
511
|
+
if len(str_config) <= maximum_env_length:
|
512
|
+
env_dict["WANDB_CONFIG"] = str_config
|
513
|
+
return
|
514
|
+
|
515
|
+
chunks = [
|
516
|
+
str_config[i : i + maximum_env_length]
|
517
|
+
for i in range(0, len(str_config), maximum_env_length)
|
518
|
+
]
|
519
|
+
config_chunks_dict = {f"WANDB_CONFIG_{i}": chunk for i, chunk in enumerate(chunks)}
|
520
|
+
env_dict.update(config_chunks_dict)
|
521
|
+
|
522
|
+
|
523
|
+
def _inject_file_overrides_env_vars(
|
524
|
+
overrides: Dict[str, Any], env_dict: Dict[str, Any], maximum_env_length: int
|
525
|
+
) -> None:
|
526
|
+
str_overrides = json.dumps(overrides)
|
527
|
+
if len(str_overrides) <= maximum_env_length:
|
528
|
+
env_dict["WANDB_LAUNCH_FILE_OVERRIDES"] = str_overrides
|
529
|
+
return
|
530
|
+
|
531
|
+
chunks = [
|
532
|
+
str_overrides[i : i + maximum_env_length]
|
533
|
+
for i in range(0, len(str_overrides), maximum_env_length)
|
534
|
+
]
|
535
|
+
overrides_chunks_dict = {
|
536
|
+
f"WANDB_LAUNCH_FILE_OVERRIDES_{i}": chunk for i, chunk in enumerate(chunks)
|
537
|
+
}
|
538
|
+
env_dict.update(overrides_chunks_dict)
|
wandb/sdk/launch/agent/agent.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Implementation of launch agent."""
|
2
|
+
|
2
3
|
import asyncio
|
3
4
|
import logging
|
4
5
|
import os
|
@@ -8,7 +9,9 @@ import time
|
|
8
9
|
import traceback
|
9
10
|
from dataclasses import dataclass
|
10
11
|
from multiprocessing import Event
|
11
|
-
from typing import Any, Dict, List, Optional, Union
|
12
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
13
|
+
|
14
|
+
import yaml
|
12
15
|
|
13
16
|
import wandb
|
14
17
|
from wandb.apis.internal import Api
|
@@ -17,11 +20,11 @@ from wandb.sdk.launch._launch_add import launch_add
|
|
17
20
|
from wandb.sdk.launch.runner.local_container import LocalSubmittedRun
|
18
21
|
from wandb.sdk.launch.runner.local_process import LocalProcessRunner
|
19
22
|
from wandb.sdk.launch.sweeps.scheduler import Scheduler
|
23
|
+
from wandb.sdk.launch.utils import LAUNCH_CONFIG_FILE, resolve_build_and_registry_config
|
20
24
|
from wandb.sdk.lib import runid
|
21
25
|
|
22
26
|
from .. import loader
|
23
27
|
from .._project_spec import LaunchProject
|
24
|
-
from ..builder.build import construct_agent_configs
|
25
28
|
from ..errors import LaunchDockerError, LaunchError
|
26
29
|
from ..utils import (
|
27
30
|
LAUNCH_DEFAULT_PROJECT,
|
@@ -133,6 +136,31 @@ class InternalAgentLogger:
|
|
133
136
|
_logger.debug(f"{LOG_PREFIX}{message}")
|
134
137
|
|
135
138
|
|
139
|
+
def construct_agent_configs(
|
140
|
+
launch_config: Optional[Dict] = None,
|
141
|
+
build_config: Optional[Dict] = None,
|
142
|
+
) -> Tuple[Optional[Dict[str, Any]], Dict[str, Any], Dict[str, Any]]:
|
143
|
+
registry_config = None
|
144
|
+
environment_config = None
|
145
|
+
if launch_config is not None:
|
146
|
+
build_config = launch_config.get("builder")
|
147
|
+
registry_config = launch_config.get("registry")
|
148
|
+
|
149
|
+
default_launch_config = None
|
150
|
+
if os.path.exists(os.path.expanduser(LAUNCH_CONFIG_FILE)):
|
151
|
+
with open(os.path.expanduser(LAUNCH_CONFIG_FILE)) as f:
|
152
|
+
default_launch_config = (
|
153
|
+
yaml.safe_load(f) or {}
|
154
|
+
) # In case the config is empty, we want it to be {} instead of None.
|
155
|
+
environment_config = default_launch_config.get("environment")
|
156
|
+
|
157
|
+
build_config, registry_config = resolve_build_and_registry_config(
|
158
|
+
default_launch_config, build_config, registry_config
|
159
|
+
)
|
160
|
+
|
161
|
+
return environment_config, build_config, registry_config
|
162
|
+
|
163
|
+
|
136
164
|
class LaunchAgent:
|
137
165
|
"""Launch agent class which polls run given run queues and launches runs for wandb launch."""
|
138
166
|
|
@@ -172,7 +200,7 @@ class LaunchAgent:
|
|
172
200
|
config: Config dictionary for the agent.
|
173
201
|
"""
|
174
202
|
self._entity = config["entity"]
|
175
|
-
self._project =
|
203
|
+
self._project = LAUNCH_DEFAULT_PROJECT
|
176
204
|
self._api = api
|
177
205
|
self._base_url = self._api.settings().get("base_url")
|
178
206
|
self._ticks = 0
|
@@ -240,7 +268,7 @@ class LaunchAgent:
|
|
240
268
|
"""Determine whether a job/runSpec is a sweep scheduler."""
|
241
269
|
if not run_spec:
|
242
270
|
self._internal_logger.debug(
|
243
|
-
"
|
271
|
+
"Received runSpec in _is_scheduler_job that was empty"
|
244
272
|
)
|
245
273
|
|
246
274
|
if run_spec.get("uri") != Scheduler.PLACEHOLDER_URI:
|
@@ -276,6 +304,8 @@ class LaunchAgent:
|
|
276
304
|
|
277
305
|
def _init_agent_run(self) -> None:
|
278
306
|
# TODO: has it been long enough that all backends support agents?
|
307
|
+
self._wandb_run = None
|
308
|
+
|
279
309
|
if self.gorilla_supports_agents:
|
280
310
|
settings = wandb.Settings(silent=True, disable_git=True)
|
281
311
|
self._wandb_run = wandb.init(
|
@@ -285,8 +315,6 @@ class LaunchAgent:
|
|
285
315
|
id=self._name,
|
286
316
|
job_type=HIDDEN_AGENT_RUN_TYPE,
|
287
317
|
)
|
288
|
-
else:
|
289
|
-
self._wandb_run = None
|
290
318
|
|
291
319
|
@property
|
292
320
|
def thread_ids(self) -> List[int]:
|
@@ -338,10 +366,7 @@ class LaunchAgent:
|
|
338
366
|
if self._name:
|
339
367
|
output_str += f"{self._name} "
|
340
368
|
if self.num_running_jobs < self._max_jobs:
|
341
|
-
output_str += "polling on "
|
342
|
-
if self._project != LAUNCH_DEFAULT_PROJECT:
|
343
|
-
output_str += f"project {self._project}, "
|
344
|
-
output_str += f"queues {','.join(self._queues)}, "
|
369
|
+
output_str += f"polling on queues {','.join(self._queues)}, "
|
345
370
|
output_str += (
|
346
371
|
f"running {self.num_running_jobs} out of a maximum of {self._max_jobs} jobs"
|
347
372
|
)
|
@@ -433,7 +458,6 @@ class LaunchAgent:
|
|
433
458
|
# We retry for 60 seconds with an exponential backoff in case
|
434
459
|
# upsert run is taking a while.
|
435
460
|
logs = None
|
436
|
-
start_time = time.time()
|
437
461
|
interval = 1
|
438
462
|
while True:
|
439
463
|
called_init = self._check_run_exists_and_inited(
|
@@ -442,7 +466,7 @@ class LaunchAgent:
|
|
442
466
|
job_and_run_status.run_id,
|
443
467
|
job_and_run_status.run_queue_item_id,
|
444
468
|
)
|
445
|
-
if called_init or
|
469
|
+
if called_init or interval > RUN_INFO_GRACE_PERIOD:
|
446
470
|
break
|
447
471
|
if not called_init:
|
448
472
|
# Fetch the logs now if we don't get run info on the
|
@@ -691,7 +715,7 @@ class LaunchAgent:
|
|
691
715
|
default_config, override_build_config
|
692
716
|
)
|
693
717
|
image_uri = project.docker_image
|
694
|
-
entrypoint = project.
|
718
|
+
entrypoint = project.get_job_entry_point()
|
695
719
|
environment = loader.environment_from_config(
|
696
720
|
default_config.get("environment", {})
|
697
721
|
)
|