wandb 0.13.11__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/apis/importers/__init__.py +4 -0
- wandb/apis/importers/base.py +312 -0
- wandb/apis/importers/mlflow.py +113 -0
- wandb/apis/internal.py +9 -0
- wandb/apis/public.py +0 -2
- wandb/cli/cli.py +100 -72
- wandb/docker/__init__.py +33 -5
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/internal/internal_api.py +85 -9
- wandb/sdk/launch/_project_spec.py +45 -55
- wandb/sdk/launch/agent/agent.py +80 -18
- wandb/sdk/launch/builder/build.py +16 -74
- wandb/sdk/launch/builder/docker_builder.py +36 -8
- wandb/sdk/launch/builder/kaniko_builder.py +78 -37
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +68 -18
- wandb/sdk/launch/environment/aws_environment.py +4 -0
- wandb/sdk/launch/launch.py +1 -6
- wandb/sdk/launch/launch_add.py +0 -5
- wandb/sdk/launch/registry/abstract.py +12 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +31 -1
- wandb/sdk/launch/registry/google_artifact_registry.py +32 -0
- wandb/sdk/launch/registry/local_registry.py +15 -1
- wandb/sdk/launch/runner/abstract.py +0 -14
- wandb/sdk/launch/runner/kubernetes_runner.py +25 -19
- wandb/sdk/launch/runner/local_container.py +7 -8
- wandb/sdk/launch/runner/local_process.py +0 -3
- wandb/sdk/launch/runner/sagemaker_runner.py +0 -3
- wandb/sdk/launch/runner/vertex_runner.py +0 -2
- wandb/sdk/launch/sweeps/scheduler.py +39 -10
- wandb/sdk/launch/utils.py +52 -4
- wandb/sdk/wandb_run.py +3 -10
- wandb/sync/sync.py +1 -0
- wandb/util.py +1 -0
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/METADATA +1 -1
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/RECORD +41 -38
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,6 @@
|
|
2
2
|
|
3
3
|
Arguments can come from a launch spec or call to wandb launch.
|
4
4
|
"""
|
5
|
-
import binascii
|
6
5
|
import enum
|
7
6
|
import json
|
8
7
|
import logging
|
@@ -60,7 +59,6 @@ class LaunchProject:
|
|
60
59
|
overrides: Dict[str, Any],
|
61
60
|
resource: str,
|
62
61
|
resource_args: Dict[str, Any],
|
63
|
-
cuda: Optional[bool],
|
64
62
|
run_id: Optional[str],
|
65
63
|
):
|
66
64
|
if uri is not None and utils.is_bare_wandb_uri(uri):
|
@@ -76,10 +74,16 @@ class LaunchProject:
|
|
76
74
|
self.target_entity = target_entity
|
77
75
|
self.target_project = target_project.lower()
|
78
76
|
self.name = name # TODO: replace with run_id
|
77
|
+
# the builder key can be passed in through the resource args
|
78
|
+
# but these resource_args are then passed to the appropriate
|
79
|
+
# runner, so we need to pop the builder key out
|
80
|
+
resource_args_build = resource_args.get(resource, {}).pop("builder", {})
|
79
81
|
self.resource = resource
|
80
82
|
self.resource_args = resource_args
|
81
83
|
self.python_version: Optional[str] = launch_spec.get("python_version")
|
82
|
-
self.
|
84
|
+
self.cuda_base_image: Optional[str] = resource_args_build.get("cuda", {}).get(
|
85
|
+
"base_image"
|
86
|
+
)
|
83
87
|
self._base_image: Optional[str] = launch_spec.get("base_image")
|
84
88
|
self.docker_image: Optional[str] = docker_config.get(
|
85
89
|
"docker_image"
|
@@ -96,11 +100,8 @@ class LaunchProject:
|
|
96
100
|
self.override_artifacts: Dict[str, Any] = overrides.get("artifacts", {})
|
97
101
|
self.override_entrypoint: Optional[EntryPoint] = None
|
98
102
|
self.deps_type: Optional[str] = None
|
99
|
-
self.cuda = cuda
|
100
103
|
self._runtime: Optional[str] = None
|
101
104
|
self.run_id = run_id or generate_id()
|
102
|
-
self._image_tag: str = self._initialize_image_job_tag() or self.run_id
|
103
|
-
wandb.termlog(f"{LOG_PREFIX}Launch project using image tag {self._image_tag}")
|
104
105
|
self._entry_points: Dict[
|
105
106
|
str, EntryPoint
|
106
107
|
] = {} # todo: keep multiple entrypoint support?
|
@@ -140,8 +141,6 @@ class LaunchProject:
|
|
140
141
|
)
|
141
142
|
self.source = LaunchSource.LOCAL
|
142
143
|
self.project_dir = self.uri
|
143
|
-
if launch_spec.get("resource_args"):
|
144
|
-
self.resource_args = launch_spec["resource_args"]
|
145
144
|
|
146
145
|
self.aux_dir = tempfile.mkdtemp()
|
147
146
|
self.clear_parameter_run_config_collisions()
|
@@ -175,24 +174,15 @@ class LaunchProject:
|
|
175
174
|
assert self.job is not None
|
176
175
|
return wandb.util.make_docker_image_name_safe(self.job.split(":")[0])
|
177
176
|
|
178
|
-
def
|
179
|
-
if
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
@property
|
188
|
-
def image_uri(self) -> str:
|
189
|
-
if self.docker_image:
|
190
|
-
return self.docker_image
|
191
|
-
return f"{self.image_name}:{self.image_tag}"
|
192
|
-
|
193
|
-
@property
|
194
|
-
def image_tag(self) -> str:
|
195
|
-
return self._image_tag[:IMAGE_TAG_MAX_LENGTH]
|
177
|
+
def build_required(self) -> bool:
|
178
|
+
"""Checks the source to see if a build is required."""
|
179
|
+
# since the image tag for images built from jobs
|
180
|
+
# is based on the job version index, which is immutable
|
181
|
+
# we don't need to build the image for a job if that tag
|
182
|
+
# already exists
|
183
|
+
if self.source != LaunchSource.JOB:
|
184
|
+
return True
|
185
|
+
return False
|
196
186
|
|
197
187
|
@property
|
198
188
|
def docker_image(self) -> Optional[str]:
|
@@ -243,10 +233,37 @@ class LaunchProject:
|
|
243
233
|
try:
|
244
234
|
job = public_api.job(self.job, path=job_dir)
|
245
235
|
except CommError:
|
246
|
-
raise LaunchError(
|
236
|
+
raise LaunchError(
|
237
|
+
f"Job {self.job} not found. Jobs have the format: <entity>/<project>/<name>:<alias>"
|
238
|
+
)
|
247
239
|
job.configure_launch_project(self)
|
248
240
|
self._job_artifact = job._job_artifact
|
249
241
|
|
242
|
+
def get_image_source_string(self) -> str:
|
243
|
+
"""Returns a unique string identifying the source of an image."""
|
244
|
+
if self.source == LaunchSource.LOCAL:
|
245
|
+
# TODO: more correct to get a hash of local uri contents
|
246
|
+
assert isinstance(self.uri, str)
|
247
|
+
return self.uri
|
248
|
+
elif self.source == LaunchSource.JOB:
|
249
|
+
assert self._job_artifact is not None
|
250
|
+
return f"{self._job_artifact.name}:v{self._job_artifact.version}"
|
251
|
+
elif self.source == LaunchSource.GIT:
|
252
|
+
assert isinstance(self.uri, str)
|
253
|
+
ret = self.uri
|
254
|
+
if self.git_version:
|
255
|
+
ret += self.git_version
|
256
|
+
return ret
|
257
|
+
elif self.source == LaunchSource.WANDB:
|
258
|
+
assert isinstance(self.uri, str)
|
259
|
+
return self.uri
|
260
|
+
elif self.source == LaunchSource.DOCKER:
|
261
|
+
assert isinstance(self.docker_image, str)
|
262
|
+
_logger.debug("")
|
263
|
+
return self.docker_image
|
264
|
+
else:
|
265
|
+
raise LaunchError("Unknown source type when determing image source string")
|
266
|
+
|
250
267
|
def _fetch_project_local(self, internal_api: Api) -> None:
|
251
268
|
"""Fetch a project (either wandb run or git repo) into a local directory, returning the path to the local project directory."""
|
252
269
|
# these asserts are all guaranteed to pass, but are required by mypy
|
@@ -263,24 +280,6 @@ class LaunchProject:
|
|
263
280
|
)
|
264
281
|
program_name = run_info.get("codePath") or run_info["program"]
|
265
282
|
|
266
|
-
if run_info.get("cudaVersion"):
|
267
|
-
original_cuda_version = ".".join(run_info["cudaVersion"].split(".")[:2])
|
268
|
-
|
269
|
-
if self.cuda is None:
|
270
|
-
# only set cuda on by default if cuda is None (unspecified), not False (user specifically requested cpu image)
|
271
|
-
wandb.termlog(
|
272
|
-
f"{LOG_PREFIX}Original wandb run {source_run_name} was run with cuda version {original_cuda_version}. Enabling cuda builds by default; to build on a CPU-only image, run again with --cuda=False"
|
273
|
-
)
|
274
|
-
self.cuda_version = original_cuda_version
|
275
|
-
self.cuda = True
|
276
|
-
if (
|
277
|
-
self.cuda
|
278
|
-
and self.cuda_version
|
279
|
-
and self.cuda_version != original_cuda_version
|
280
|
-
):
|
281
|
-
wandb.termlog(
|
282
|
-
f"{LOG_PREFIX}Specified cuda version {self.cuda_version} differs from original cuda version {original_cuda_version}. Running with specified version {self.cuda_version}"
|
283
|
-
)
|
284
283
|
self.python_version = run_info.get("python", "3")
|
285
284
|
downloaded_code_artifact = utils.check_and_download_code_artifacts(
|
286
285
|
source_entity,
|
@@ -289,11 +288,7 @@ class LaunchProject:
|
|
289
288
|
internal_api,
|
290
289
|
self.project_dir,
|
291
290
|
)
|
292
|
-
if downloaded_code_artifact:
|
293
|
-
self._image_tag = binascii.hexlify(
|
294
|
-
downloaded_code_artifact.digest.encode()
|
295
|
-
).decode()
|
296
|
-
else:
|
291
|
+
if not downloaded_code_artifact:
|
297
292
|
if not run_info["git"]:
|
298
293
|
raise LaunchError(
|
299
294
|
"Reproducing a run requires either an associated git repo or a code artifact logged with `run.log_code()`"
|
@@ -308,12 +303,8 @@ class LaunchProject:
|
|
308
303
|
patch = utils.fetch_project_diff(
|
309
304
|
source_entity, source_project, source_run_name, internal_api
|
310
305
|
)
|
311
|
-
tag_string = run_info["git"]["remote"] + run_info["git"]["commit"]
|
312
306
|
if patch:
|
313
307
|
utils.apply_patch(patch, self.project_dir)
|
314
|
-
tag_string += patch
|
315
|
-
|
316
|
-
self._image_tag = binascii.hexlify(tag_string.encode()).decode()
|
317
308
|
|
318
309
|
# For cases where the entry point wasn't checked into git
|
319
310
|
if not os.path.exists(os.path.join(self.project_dir, program_name)):
|
@@ -450,7 +441,6 @@ def create_project_from_spec(launch_spec: Dict[str, Any], api: Api) -> LaunchPro
|
|
450
441
|
launch_spec.get("overrides", {}),
|
451
442
|
launch_spec.get("resource", None),
|
452
443
|
launch_spec.get("resource_args", {}),
|
453
|
-
launch_spec.get("cuda", None),
|
454
444
|
launch_spec.get("run_id", None),
|
455
445
|
)
|
456
446
|
|
wandb/sdk/launch/agent/agent.py
CHANGED
@@ -13,6 +13,8 @@ from typing import Any, Dict, List, Optional, Union
|
|
13
13
|
import wandb
|
14
14
|
import wandb.util as util
|
15
15
|
from wandb.apis.internal import Api
|
16
|
+
from wandb.errors import CommError
|
17
|
+
from wandb.sdk.launch._project_spec import LaunchProject
|
16
18
|
from wandb.sdk.launch.runner.local_container import LocalSubmittedRun
|
17
19
|
from wandb.sdk.launch.sweeps import SCHEDULER_URI
|
18
20
|
from wandb.sdk.lib import runid
|
@@ -21,7 +23,13 @@ from .. import loader
|
|
21
23
|
from .._project_spec import create_project_from_spec, fetch_and_validate_project
|
22
24
|
from ..builder.build import construct_builder_args
|
23
25
|
from ..runner.abstract import AbstractRun
|
24
|
-
from ..utils import
|
26
|
+
from ..utils import (
|
27
|
+
LAUNCH_DEFAULT_PROJECT,
|
28
|
+
LOG_PREFIX,
|
29
|
+
PROJECT_SYNCHRONOUS,
|
30
|
+
LaunchDockerError,
|
31
|
+
LaunchError,
|
32
|
+
)
|
25
33
|
|
26
34
|
AGENT_POLLING_INTERVAL = 10
|
27
35
|
ACTIVE_SWEEP_POLLING_INTERVAL = 1 # more frequent when we know we have jobs
|
@@ -37,6 +45,9 @@ _logger = logging.getLogger(__name__)
|
|
37
45
|
|
38
46
|
@dataclass
|
39
47
|
class JobAndRunStatus:
|
48
|
+
run_queue_item_id: str
|
49
|
+
run_id: Optional[str] = None
|
50
|
+
project: Optional[str] = None
|
40
51
|
run: Optional[AbstractRun] = None
|
41
52
|
failed_to_start: bool = False
|
42
53
|
completed: bool = False
|
@@ -46,6 +57,10 @@ class JobAndRunStatus:
|
|
46
57
|
def job_completed(self) -> bool:
|
47
58
|
return self.completed or self.failed_to_start
|
48
59
|
|
60
|
+
def update_run_info(self, launch_project: LaunchProject) -> None:
|
61
|
+
self.run_id = launch_project.run_id
|
62
|
+
self.project = launch_project.target_project
|
63
|
+
|
49
64
|
|
50
65
|
def _convert_access(access: str) -> str:
|
51
66
|
"""Convert access string to a value accepted by wandb."""
|
@@ -90,7 +105,20 @@ def _job_is_scheduler(run_spec: Dict[str, Any]) -> bool:
|
|
90
105
|
if not run_spec:
|
91
106
|
_logger.debug("Recieved runSpec in _job_is_scheduler that was empty")
|
92
107
|
|
93
|
-
|
108
|
+
if run_spec.get("uri") != SCHEDULER_URI:
|
109
|
+
return False
|
110
|
+
|
111
|
+
if run_spec.get("resource") == "local-process":
|
112
|
+
# If a scheduler is a local-process (100%), also
|
113
|
+
# confirm command is in format: [wandb scheduler <sweep>]
|
114
|
+
cmd = run_spec.get("overrides", {}).get("entry_point", [])
|
115
|
+
if len(cmd) < 3:
|
116
|
+
return False
|
117
|
+
|
118
|
+
if cmd[:2] != ["wandb", "scheduler"]:
|
119
|
+
return False
|
120
|
+
|
121
|
+
return True
|
94
122
|
|
95
123
|
|
96
124
|
class LaunchAgent:
|
@@ -119,7 +147,6 @@ class LaunchAgent:
|
|
119
147
|
self._max_schedulers = _max_from_config(config, "max_schedulers")
|
120
148
|
self._pool = ThreadPool(
|
121
149
|
processes=int(min(MAX_THREADS, self._max_jobs + self._max_schedulers)),
|
122
|
-
# initializer=init_pool_processes,
|
123
150
|
initargs=(self._jobs, self._jobs_lock),
|
124
151
|
)
|
125
152
|
self.default_config: Dict[str, Any] = config
|
@@ -128,6 +155,10 @@ class LaunchAgent:
|
|
128
155
|
self.gorilla_supports_agents = (
|
129
156
|
self._api.launch_agent_introspection() is not None
|
130
157
|
)
|
158
|
+
self._gorilla_supports_fail_run_queue_items = (
|
159
|
+
self._api.fail_run_queue_item_introspection()
|
160
|
+
)
|
161
|
+
|
131
162
|
self._queues = config.get("queues", ["default"])
|
132
163
|
create_response = self._api.create_launch_agent(
|
133
164
|
self._entity,
|
@@ -137,6 +168,14 @@ class LaunchAgent:
|
|
137
168
|
)
|
138
169
|
self._id = create_response["launchAgentId"]
|
139
170
|
self._name = "" # hacky: want to display this to the user but we don't get it back from gql until polling starts. fix later
|
171
|
+
if self._api.entity_is_team(self._entity):
|
172
|
+
wandb.termwarn(
|
173
|
+
f"{LOG_PREFIX}Agent is running on team entity ({self._entity}). Members of this team will be able to run code on this device."
|
174
|
+
)
|
175
|
+
|
176
|
+
def fail_run_queue_item(self, run_queue_item_id: str) -> None:
|
177
|
+
if self._gorilla_supports_fail_run_queue_items:
|
178
|
+
self._api.fail_run_queue_item(run_queue_item_id)
|
140
179
|
|
141
180
|
@property
|
142
181
|
def thread_ids(self) -> List[int]:
|
@@ -214,9 +253,28 @@ class LaunchAgent:
|
|
214
253
|
|
215
254
|
def finish_thread_id(self, thread_id: int) -> None:
|
216
255
|
"""Removes the job from our list for now."""
|
256
|
+
job_and_run_status = self._jobs[thread_id]
|
257
|
+
if not job_and_run_status.run_id or not job_and_run_status.project:
|
258
|
+
self.fail_run_queue_item(job_and_run_status.run_queue_item_id)
|
259
|
+
else:
|
260
|
+
run_info = None
|
261
|
+
# sweep runs exist but have no info before they are started
|
262
|
+
# so run_info returned will be None
|
263
|
+
# normal runs just throw a comm error
|
264
|
+
try:
|
265
|
+
run_info = self._api.get_run_info(
|
266
|
+
self._entity, job_and_run_status.project, job_and_run_status.run_id
|
267
|
+
)
|
268
|
+
|
269
|
+
except CommError:
|
270
|
+
pass
|
271
|
+
if run_info is None:
|
272
|
+
self.fail_run_queue_item(job_and_run_status.run_queue_item_id)
|
273
|
+
|
217
274
|
# TODO: keep logs or something for the finished jobs
|
218
275
|
with self._jobs_lock:
|
219
276
|
del self._jobs[thread_id]
|
277
|
+
|
220
278
|
# update status back to polling if no jobs are running
|
221
279
|
if len(self.thread_ids) == 0:
|
222
280
|
self.update_status(AGENT_POLLING)
|
@@ -295,16 +353,12 @@ class LaunchAgent:
|
|
295
353
|
|
296
354
|
try:
|
297
355
|
self.run_job(job)
|
298
|
-
except Exception:
|
356
|
+
except Exception as e:
|
299
357
|
wandb.termerror(
|
300
358
|
f"{LOG_PREFIX}Error running job: {traceback.format_exc()}"
|
301
359
|
)
|
302
|
-
|
303
|
-
|
304
|
-
except Exception:
|
305
|
-
_logger.error(
|
306
|
-
f"{LOG_PREFIX}Error acking job when job errored: {traceback.format_exc()}"
|
307
|
-
)
|
360
|
+
util.sentry_exc(e)
|
361
|
+
self.fail_run_queue_item(job["runQueueItemId"])
|
308
362
|
|
309
363
|
for thread_id in self.thread_ids:
|
310
364
|
self._update_finished(thread_id)
|
@@ -340,11 +394,20 @@ class LaunchAgent:
|
|
340
394
|
default_config: Dict[str, Any],
|
341
395
|
api: Api,
|
342
396
|
) -> None:
|
397
|
+
thread_id = threading.current_thread().ident
|
398
|
+
assert thread_id is not None
|
343
399
|
try:
|
344
|
-
self._thread_run_job(launch_spec, job, default_config, api)
|
345
|
-
except
|
400
|
+
self._thread_run_job(launch_spec, job, default_config, api, thread_id)
|
401
|
+
except LaunchDockerError as e:
|
402
|
+
wandb.termerror(
|
403
|
+
f"{LOG_PREFIX}agent {self._name} encountered an issue while starting Docker, see above output for details."
|
404
|
+
)
|
405
|
+
self.finish_thread_id(thread_id)
|
406
|
+
util.sentry_exc(e)
|
407
|
+
except Exception as e:
|
346
408
|
wandb.termerror(f"{LOG_PREFIX}Error running job: {traceback.format_exc()}")
|
347
|
-
|
409
|
+
self.finish_thread_id(thread_id)
|
410
|
+
util.sentry_exc(e)
|
348
411
|
|
349
412
|
def _thread_run_job(
|
350
413
|
self,
|
@@ -352,13 +415,13 @@ class LaunchAgent:
|
|
352
415
|
job: Dict[str, Any],
|
353
416
|
default_config: Dict[str, Any],
|
354
417
|
api: Api,
|
418
|
+
thread_id: int,
|
355
419
|
) -> None:
|
356
|
-
|
357
|
-
assert thread_id is not None
|
358
|
-
job_tracker = JobAndRunStatus()
|
420
|
+
job_tracker = JobAndRunStatus(job["runQueueItemId"])
|
359
421
|
with self._jobs_lock:
|
360
422
|
self._jobs[thread_id] = job_tracker
|
361
423
|
project = create_project_from_spec(launch_spec, api)
|
424
|
+
job_tracker.update_run_info(project)
|
362
425
|
_logger.info("Fetching and validating project...")
|
363
426
|
project = fetch_and_validate_project(project, api)
|
364
427
|
_logger.info("Fetching resource...")
|
@@ -366,8 +429,6 @@ class LaunchAgent:
|
|
366
429
|
backend_config: Dict[str, Any] = {
|
367
430
|
PROJECT_SYNCHRONOUS: False, # agent always runs async
|
368
431
|
}
|
369
|
-
|
370
|
-
backend_config["runQueueItemId"] = job["runQueueItemId"]
|
371
432
|
_logger.info("Loading backend")
|
372
433
|
override_build_config = launch_spec.get("builder")
|
373
434
|
|
@@ -382,6 +443,7 @@ class LaunchAgent:
|
|
382
443
|
builder = loader.builder_from_config(build_config, environment, registry)
|
383
444
|
backend = loader.runner_from_config(resource, api, backend_config, environment)
|
384
445
|
_logger.info("Backend loaded...")
|
446
|
+
api.ack_run_queue_item(job["runQueueItemId"], project.run_id)
|
385
447
|
run = backend.run(project, builder)
|
386
448
|
|
387
449
|
if _job_is_scheduler(launch_spec):
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import hashlib
|
1
2
|
import json
|
2
3
|
import logging
|
3
4
|
import os
|
@@ -21,7 +22,6 @@ from wandb.sdk.launch.loader import (
|
|
21
22
|
registry_from_config,
|
22
23
|
)
|
23
24
|
|
24
|
-
from ...lib.git import GitRepo
|
25
25
|
from .._project_spec import (
|
26
26
|
EntryPoint,
|
27
27
|
EntrypointDefaults,
|
@@ -132,6 +132,7 @@ PIP_TEMPLATE = """
|
|
132
132
|
RUN python -m venv /env
|
133
133
|
# make sure we install into the env
|
134
134
|
ENV PATH="/env/bin:$PATH"
|
135
|
+
|
135
136
|
COPY {requirements_files} ./
|
136
137
|
{buildx_optional_prefix} {pip_install}
|
137
138
|
"""
|
@@ -192,8 +193,8 @@ def get_base_setup(
|
|
192
193
|
CPU version is built on python, GPU version is built on nvidia:cuda.
|
193
194
|
"""
|
194
195
|
python_base_image = f"python:{py_version}-buster"
|
195
|
-
if launch_project.
|
196
|
-
|
196
|
+
if launch_project.cuda_base_image:
|
197
|
+
_logger.info(f"Using cuda base image: {launch_project.cuda_base_image}")
|
197
198
|
# cuda image doesn't come with python tooling
|
198
199
|
if py_major == "2":
|
199
200
|
python_packages = [
|
@@ -210,7 +211,7 @@ def get_base_setup(
|
|
210
211
|
"python3-setuptools",
|
211
212
|
]
|
212
213
|
base_setup = CUDA_SETUP_TEMPLATE.format(
|
213
|
-
cuda_base_image=
|
214
|
+
cuda_base_image=launch_project.cuda_base_image,
|
214
215
|
python_packages=" \\\n".join(python_packages),
|
215
216
|
py_version=py_version,
|
216
217
|
)
|
@@ -390,57 +391,6 @@ def generate_dockerfile(
|
|
390
391
|
return dockerfile_contents
|
391
392
|
|
392
393
|
|
393
|
-
_inspected_images = {}
|
394
|
-
|
395
|
-
|
396
|
-
def docker_image_exists(docker_image: str, should_raise: bool = False) -> bool:
|
397
|
-
"""Check if a specific image is already available.
|
398
|
-
|
399
|
-
Optionally raises an exception if the image is not found.
|
400
|
-
"""
|
401
|
-
_logger.info("Checking if base image exists...")
|
402
|
-
try:
|
403
|
-
data = docker.run(["docker", "image", "inspect", docker_image])
|
404
|
-
# always true, since return stderr defaults to false
|
405
|
-
assert isinstance(data, str)
|
406
|
-
parsed = json.loads(data)[0]
|
407
|
-
_inspected_images[docker_image] = parsed
|
408
|
-
return True
|
409
|
-
except (docker.DockerError, ValueError) as e:
|
410
|
-
if should_raise:
|
411
|
-
raise e
|
412
|
-
_logger.info("Base image not found. Generating new base image")
|
413
|
-
return False
|
414
|
-
|
415
|
-
|
416
|
-
def docker_image_inspect(docker_image: str) -> Dict[str, Any]:
|
417
|
-
"""Get the parsed json result of docker inspect image_name."""
|
418
|
-
if _inspected_images.get(docker_image) is None:
|
419
|
-
docker_image_exists(docker_image, True)
|
420
|
-
return _inspected_images.get(docker_image, {})
|
421
|
-
|
422
|
-
|
423
|
-
def pull_docker_image(docker_image: str) -> None:
|
424
|
-
"""Pull the requested docker image."""
|
425
|
-
if docker_image_exists(docker_image):
|
426
|
-
# don't pull images if they exist already, eg if they are local images
|
427
|
-
return
|
428
|
-
try:
|
429
|
-
docker.run(["docker", "pull", docker_image])
|
430
|
-
except docker.DockerError as e:
|
431
|
-
raise LaunchError(f"Docker server returned error: {e}")
|
432
|
-
|
433
|
-
|
434
|
-
def construct_gcp_image_uri(
|
435
|
-
launch_project: LaunchProject,
|
436
|
-
gcp_repo: str,
|
437
|
-
gcp_project: str,
|
438
|
-
gcp_registry: str,
|
439
|
-
) -> str:
|
440
|
-
base_uri = launch_project.image_uri
|
441
|
-
return "/".join([gcp_registry, gcp_project, gcp_repo, base_uri])
|
442
|
-
|
443
|
-
|
444
394
|
def construct_gcp_registry_uri(
|
445
395
|
gcp_repo: str, gcp_project: str, gcp_registry: str
|
446
396
|
) -> str:
|
@@ -474,24 +424,6 @@ def _parse_existing_requirements(launch_project: LaunchProject) -> str:
|
|
474
424
|
return requirements_line
|
475
425
|
|
476
426
|
|
477
|
-
def _get_docker_image_uri(name: Optional[str], work_dir: str, image_id: str) -> str:
|
478
|
-
"""Create a Docker image URI for a project.
|
479
|
-
|
480
|
-
The resulting URI is based on the git hash of the specified working directory.
|
481
|
-
:param name: The URI of the Docker repository with which to tag the image. The
|
482
|
-
repository URI is used as the prefix of the image URI.
|
483
|
-
:param work_dir: Path to the working directory in which to search for a git commit hash.
|
484
|
-
"""
|
485
|
-
name = name.replace(" ", "-") if name else "wandb-launch"
|
486
|
-
# Optionally include first 7 digits of git SHA in tag name, if available.
|
487
|
-
|
488
|
-
git_commit = GitRepo(work_dir).last_commit
|
489
|
-
version_string = (
|
490
|
-
":" + str(git_commit[:7]) + image_id if git_commit else ":" + image_id
|
491
|
-
)
|
492
|
-
return name + version_string
|
493
|
-
|
494
|
-
|
495
427
|
def _create_docker_build_ctx(
|
496
428
|
launch_project: LaunchProject,
|
497
429
|
dockerfile_contents: str,
|
@@ -537,7 +469,7 @@ def construct_builder_args(
|
|
537
469
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
538
470
|
registry_config = None
|
539
471
|
if launch_config is not None:
|
540
|
-
build_config = launch_config.get("
|
472
|
+
build_config = launch_config.get("builder")
|
541
473
|
registry_config = launch_config.get("registry")
|
542
474
|
|
543
475
|
default_launch_config = None
|
@@ -625,3 +557,13 @@ def build_image_from_project(
|
|
625
557
|
raise LaunchError("Error building image uri")
|
626
558
|
else:
|
627
559
|
return image_uri
|
560
|
+
|
561
|
+
|
562
|
+
def image_tag_from_dockerfile_and_source(
|
563
|
+
launch_project: LaunchProject, dockerfile_contents: str
|
564
|
+
) -> str:
|
565
|
+
"""Hashes the source and dockerfile contents into a unique tag."""
|
566
|
+
image_source_string = launch_project.get_image_source_string()
|
567
|
+
unique_id_string = image_source_string + dockerfile_contents
|
568
|
+
image_tag = hashlib.sha256(unique_id_string.encode("utf-8")).hexdigest()[:8]
|
569
|
+
return image_tag
|
@@ -16,10 +16,17 @@ from .._project_spec import (
|
|
16
16
|
get_entry_point_command,
|
17
17
|
)
|
18
18
|
from ..registry.local_registry import LocalRegistry
|
19
|
-
from ..utils import
|
19
|
+
from ..utils import (
|
20
|
+
LOG_PREFIX,
|
21
|
+
LaunchDockerError,
|
22
|
+
LaunchError,
|
23
|
+
sanitize_wandb_api_key,
|
24
|
+
warn_failed_packages_from_build_logs,
|
25
|
+
)
|
20
26
|
from .build import (
|
21
27
|
_create_docker_build_ctx,
|
22
28
|
generate_dockerfile,
|
29
|
+
image_tag_from_dockerfile_and_source,
|
23
30
|
validate_docker_installation,
|
24
31
|
)
|
25
32
|
|
@@ -110,15 +117,32 @@ class DockerBuilder(AbstractBuilder):
|
|
110
117
|
launch_project (LaunchProject): The project to build.
|
111
118
|
entrypoint (EntryPoint): The entrypoint to use.
|
112
119
|
"""
|
120
|
+
dockerfile_str = generate_dockerfile(
|
121
|
+
launch_project, entrypoint, launch_project.resource, "docker"
|
122
|
+
)
|
123
|
+
|
124
|
+
image_tag = image_tag_from_dockerfile_and_source(launch_project, dockerfile_str)
|
125
|
+
|
113
126
|
repository = None if not self.registry else self.registry.get_repo_uri()
|
127
|
+
# if repo is set, use the repo name as the image name
|
114
128
|
if repository:
|
115
|
-
image_uri = f"{repository}:{
|
129
|
+
image_uri = f"{repository}:{image_tag}"
|
130
|
+
# otherwise, base the image name off of the source
|
131
|
+
# which the launch_project checks in image_name
|
116
132
|
else:
|
117
|
-
image_uri = launch_project.
|
118
|
-
|
119
|
-
|
120
|
-
|
133
|
+
image_uri = f"{launch_project.image_name}:{image_tag}"
|
134
|
+
|
135
|
+
if not launch_project.build_required() and self.registry.check_image_exists(
|
136
|
+
image_uri
|
137
|
+
):
|
138
|
+
return image_uri
|
139
|
+
|
140
|
+
_logger.info(
|
141
|
+
f"image {image_uri} does not already exist in repository, building."
|
121
142
|
)
|
143
|
+
|
144
|
+
entry_cmd = get_entry_point_command(entrypoint, launch_project.override_args)
|
145
|
+
|
122
146
|
create_metadata_file(
|
123
147
|
launch_project,
|
124
148
|
image_uri,
|
@@ -128,9 +152,13 @@ class DockerBuilder(AbstractBuilder):
|
|
128
152
|
build_ctx_path = _create_docker_build_ctx(launch_project, dockerfile_str)
|
129
153
|
dockerfile = os.path.join(build_ctx_path, _GENERATED_DOCKERFILE_NAME)
|
130
154
|
try:
|
131
|
-
docker.build(
|
155
|
+
output = docker.build(
|
156
|
+
tags=[image_uri], file=dockerfile, context_path=build_ctx_path
|
157
|
+
)
|
158
|
+
warn_failed_packages_from_build_logs(output, image_uri)
|
159
|
+
|
132
160
|
except docker.DockerError as e:
|
133
|
-
raise
|
161
|
+
raise LaunchDockerError(f"Error communicating with docker client: {e}")
|
134
162
|
|
135
163
|
try:
|
136
164
|
os.remove(build_ctx_path)
|