wandb 0.13.11__py3-none-any.whl → 0.14.0__py3-none-any.whl
- wandb/__init__.py +1 -1
- wandb/apis/importers/__init__.py +4 -0
- wandb/apis/importers/base.py +312 -0
- wandb/apis/importers/mlflow.py +113 -0
- wandb/apis/internal.py +9 -0
- wandb/apis/public.py +0 -2
- wandb/cli/cli.py +100 -72
- wandb/docker/__init__.py +33 -5
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/internal/internal_api.py +85 -9
- wandb/sdk/launch/_project_spec.py +45 -55
- wandb/sdk/launch/agent/agent.py +80 -18
- wandb/sdk/launch/builder/build.py +16 -74
- wandb/sdk/launch/builder/docker_builder.py +36 -8
- wandb/sdk/launch/builder/kaniko_builder.py +78 -37
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +68 -18
- wandb/sdk/launch/environment/aws_environment.py +4 -0
- wandb/sdk/launch/launch.py +1 -6
- wandb/sdk/launch/launch_add.py +0 -5
- wandb/sdk/launch/registry/abstract.py +12 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +31 -1
- wandb/sdk/launch/registry/google_artifact_registry.py +32 -0
- wandb/sdk/launch/registry/local_registry.py +15 -1
- wandb/sdk/launch/runner/abstract.py +0 -14
- wandb/sdk/launch/runner/kubernetes_runner.py +25 -19
- wandb/sdk/launch/runner/local_container.py +7 -8
- wandb/sdk/launch/runner/local_process.py +0 -3
- wandb/sdk/launch/runner/sagemaker_runner.py +0 -3
- wandb/sdk/launch/runner/vertex_runner.py +0 -2
- wandb/sdk/launch/sweeps/scheduler.py +39 -10
- wandb/sdk/launch/utils.py +52 -4
- wandb/sdk/wandb_run.py +3 -10
- wandb/sync/sync.py +1 -0
- wandb/util.py +1 -0
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/METADATA +1 -1
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/RECORD +41 -38
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
wandb/sdk/launch/_project_spec.py
CHANGED
@@ -2,7 +2,6 @@
 
 Arguments can come from a launch spec or call to wandb launch.
 """
-import binascii
 import enum
 import json
 import logging
@@ -60,7 +59,6 @@ class LaunchProject:
         overrides: Dict[str, Any],
         resource: str,
         resource_args: Dict[str, Any],
-        cuda: Optional[bool],
         run_id: Optional[str],
     ):
         if uri is not None and utils.is_bare_wandb_uri(uri):
@@ -76,10 +74,16 @@ class LaunchProject:
         self.target_entity = target_entity
         self.target_project = target_project.lower()
         self.name = name  # TODO: replace with run_id
+        # the builder key can be passed in through the resource args
+        # but these resource_args are then passed to the appropriate
+        # runner, so we need to pop the builder key out
+        resource_args_build = resource_args.get(resource, {}).pop("builder", {})
         self.resource = resource
         self.resource_args = resource_args
         self.python_version: Optional[str] = launch_spec.get("python_version")
-        self.
+        self.cuda_base_image: Optional[str] = resource_args_build.get("cuda", {}).get(
+            "base_image"
+        )
         self._base_image: Optional[str] = launch_spec.get("base_image")
         self.docker_image: Optional[str] = docker_config.get(
             "docker_image"
@@ -96,11 +100,8 @@ class LaunchProject:
         self.override_artifacts: Dict[str, Any] = overrides.get("artifacts", {})
         self.override_entrypoint: Optional[EntryPoint] = None
         self.deps_type: Optional[str] = None
-        self.cuda = cuda
         self._runtime: Optional[str] = None
         self.run_id = run_id or generate_id()
-        self._image_tag: str = self._initialize_image_job_tag() or self.run_id
-        wandb.termlog(f"{LOG_PREFIX}Launch project using image tag {self._image_tag}")
         self._entry_points: Dict[
             str, EntryPoint
         ] = {}  # todo: keep multiple entrypoint support?
@@ -140,8 +141,6 @@ class LaunchProject:
             )
             self.source = LaunchSource.LOCAL
             self.project_dir = self.uri
-        if launch_spec.get("resource_args"):
-            self.resource_args = launch_spec["resource_args"]
 
         self.aux_dir = tempfile.mkdtemp()
         self.clear_parameter_run_config_collisions()
@@ -175,24 +174,15 @@ class LaunchProject:
         assert self.job is not None
         return wandb.util.make_docker_image_name_safe(self.job.split(":")[0])
 
-    def
-        if
-
-
-
-
-
-
-
-    @property
-    def image_uri(self) -> str:
-        if self.docker_image:
-            return self.docker_image
-        return f"{self.image_name}:{self.image_tag}"
-
-    @property
-    def image_tag(self) -> str:
-        return self._image_tag[:IMAGE_TAG_MAX_LENGTH]
+    def build_required(self) -> bool:
+        """Checks the source to see if a build is required."""
+        # since the image tag for images built from jobs
+        # is based on the job version index, which is immutable
+        # we don't need to build the image for a job if that tag
+        # already exists
+        if self.source != LaunchSource.JOB:
+            return True
+        return False
 
     @property
     def docker_image(self) -> Optional[str]:
@@ -243,10 +233,37 @@ class LaunchProject:
         try:
             job = public_api.job(self.job, path=job_dir)
         except CommError:
-            raise LaunchError(
+            raise LaunchError(
+                f"Job {self.job} not found. Jobs have the format: <entity>/<project>/<name>:<alias>"
+            )
         job.configure_launch_project(self)
         self._job_artifact = job._job_artifact
 
+    def get_image_source_string(self) -> str:
+        """Returns a unique string identifying the source of an image."""
+        if self.source == LaunchSource.LOCAL:
+            # TODO: more correct to get a hash of local uri contents
+            assert isinstance(self.uri, str)
+            return self.uri
+        elif self.source == LaunchSource.JOB:
+            assert self._job_artifact is not None
+            return f"{self._job_artifact.name}:v{self._job_artifact.version}"
+        elif self.source == LaunchSource.GIT:
+            assert isinstance(self.uri, str)
+            ret = self.uri
+            if self.git_version:
+                ret += self.git_version
+            return ret
+        elif self.source == LaunchSource.WANDB:
+            assert isinstance(self.uri, str)
+            return self.uri
+        elif self.source == LaunchSource.DOCKER:
+            assert isinstance(self.docker_image, str)
+            _logger.debug("")
+            return self.docker_image
+        else:
+            raise LaunchError("Unknown source type when determing image source string")
+
     def _fetch_project_local(self, internal_api: Api) -> None:
         """Fetch a project (either wandb run or git repo) into a local directory, returning the path to the local project directory."""
         # these asserts are all guaranteed to pass, but are required by mypy
@@ -263,24 +280,6 @@ class LaunchProject:
             )
             program_name = run_info.get("codePath") or run_info["program"]
 
-            if run_info.get("cudaVersion"):
-                original_cuda_version = ".".join(run_info["cudaVersion"].split(".")[:2])
-
-                if self.cuda is None:
-                    # only set cuda on by default if cuda is None (unspecified), not False (user specifically requested cpu image)
-                    wandb.termlog(
-                        f"{LOG_PREFIX}Original wandb run {source_run_name} was run with cuda version {original_cuda_version}. Enabling cuda builds by default; to build on a CPU-only image, run again with --cuda=False"
-                    )
-                    self.cuda_version = original_cuda_version
-                    self.cuda = True
-                if (
-                    self.cuda
-                    and self.cuda_version
-                    and self.cuda_version != original_cuda_version
-                ):
-                    wandb.termlog(
-                        f"{LOG_PREFIX}Specified cuda version {self.cuda_version} differs from original cuda version {original_cuda_version}. Running with specified version {self.cuda_version}"
-                    )
             self.python_version = run_info.get("python", "3")
             downloaded_code_artifact = utils.check_and_download_code_artifacts(
                 source_entity,
@@ -289,11 +288,7 @@ class LaunchProject:
                 internal_api,
                 self.project_dir,
             )
-            if downloaded_code_artifact:
-                self._image_tag = binascii.hexlify(
-                    downloaded_code_artifact.digest.encode()
-                ).decode()
-            else:
+            if not downloaded_code_artifact:
                 if not run_info["git"]:
                     raise LaunchError(
                         "Reproducing a run requires either an associated git repo or a code artifact logged with `run.log_code()`"
@@ -308,12 +303,8 @@ class LaunchProject:
                 patch = utils.fetch_project_diff(
                     source_entity, source_project, source_run_name, internal_api
                 )
-                tag_string = run_info["git"]["remote"] + run_info["git"]["commit"]
                 if patch:
                     utils.apply_patch(patch, self.project_dir)
-                    tag_string += patch
-
-                self._image_tag = binascii.hexlify(tag_string.encode()).decode()
 
                 # For cases where the entry point wasn't checked into git
                 if not os.path.exists(os.path.join(self.project_dir, program_name)):
@@ -450,7 +441,6 @@ def create_project_from_spec(launch_spec: Dict[str, Any], api: Api) -> LaunchPro
         launch_spec.get("overrides", {}),
         launch_spec.get("resource", None),
         launch_spec.get("resource_args", {}),
-        launch_spec.get("cuda", None),
         launch_spec.get("run_id", None),
     )
 
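Note: with the `cuda` flag removed from LaunchProject, the CUDA base image now travels through the resource args under a "builder" key, which the constructor pops out before the remaining resource args reach the runner. A minimal standalone sketch of that lookup, assuming a made-up resource name and image tag (not values taken from this diff):

    resource = "local-container"  # hypothetical target resource
    resource_args = {
        resource: {
            "builder": {"cuda": {"base_image": "nvidia/cuda:11.7.1-devel-ubuntu20.04"}},
        }
    }
    # mirrors the new constructor logic: pop "builder" so runners never see it
    resource_args_build = resource_args.get(resource, {}).pop("builder", {})
    cuda_base_image = resource_args_build.get("cuda", {}).get("base_image")
    assert "builder" not in resource_args[resource]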
wandb/sdk/launch/agent/agent.py
CHANGED
@@ -13,6 +13,8 @@ from typing import Any, Dict, List, Optional, Union
 import wandb
 import wandb.util as util
 from wandb.apis.internal import Api
+from wandb.errors import CommError
+from wandb.sdk.launch._project_spec import LaunchProject
 from wandb.sdk.launch.runner.local_container import LocalSubmittedRun
 from wandb.sdk.launch.sweeps import SCHEDULER_URI
 from wandb.sdk.lib import runid
@@ -21,7 +23,13 @@ from .. import loader
 from .._project_spec import create_project_from_spec, fetch_and_validate_project
 from ..builder.build import construct_builder_args
 from ..runner.abstract import AbstractRun
-from ..utils import
+from ..utils import (
+    LAUNCH_DEFAULT_PROJECT,
+    LOG_PREFIX,
+    PROJECT_SYNCHRONOUS,
+    LaunchDockerError,
+    LaunchError,
+)
 
 AGENT_POLLING_INTERVAL = 10
 ACTIVE_SWEEP_POLLING_INTERVAL = 1  # more frequent when we know we have jobs
@@ -37,6 +45,9 @@ _logger = logging.getLogger(__name__)
 
 @dataclass
 class JobAndRunStatus:
+    run_queue_item_id: str
+    run_id: Optional[str] = None
+    project: Optional[str] = None
     run: Optional[AbstractRun] = None
     failed_to_start: bool = False
     completed: bool = False
@@ -46,6 +57,10 @@ class JobAndRunStatus:
     def job_completed(self) -> bool:
         return self.completed or self.failed_to_start
 
+    def update_run_info(self, launch_project: LaunchProject) -> None:
+        self.run_id = launch_project.run_id
+        self.project = launch_project.target_project
+
 
 def _convert_access(access: str) -> str:
     """Convert access string to a value accepted by wandb."""
@@ -90,7 +105,20 @@ def _job_is_scheduler(run_spec: Dict[str, Any]) -> bool:
     if not run_spec:
         _logger.debug("Recieved runSpec in _job_is_scheduler that was empty")
 
-
+    if run_spec.get("uri") != SCHEDULER_URI:
+        return False
+
+    if run_spec.get("resource") == "local-process":
+        # If a scheduler is a local-process (100%), also
+        # confirm command is in format: [wandb scheduler <sweep>]
+        cmd = run_spec.get("overrides", {}).get("entry_point", [])
+        if len(cmd) < 3:
+            return False
+
+        if cmd[:2] != ["wandb", "scheduler"]:
+            return False
+
+    return True
 
 
 class LaunchAgent:
@@ -119,7 +147,6 @@ class LaunchAgent:
         self._max_schedulers = _max_from_config(config, "max_schedulers")
         self._pool = ThreadPool(
             processes=int(min(MAX_THREADS, self._max_jobs + self._max_schedulers)),
-            # initializer=init_pool_processes,
             initargs=(self._jobs, self._jobs_lock),
         )
         self.default_config: Dict[str, Any] = config
@@ -128,6 +155,10 @@ class LaunchAgent:
         self.gorilla_supports_agents = (
             self._api.launch_agent_introspection() is not None
         )
+        self._gorilla_supports_fail_run_queue_items = (
+            self._api.fail_run_queue_item_introspection()
+        )
+
         self._queues = config.get("queues", ["default"])
         create_response = self._api.create_launch_agent(
             self._entity,
@@ -137,6 +168,14 @@ class LaunchAgent:
         )
         self._id = create_response["launchAgentId"]
         self._name = ""  # hacky: want to display this to the user but we don't get it back from gql until polling starts. fix later
+        if self._api.entity_is_team(self._entity):
+            wandb.termwarn(
+                f"{LOG_PREFIX}Agent is running on team entity ({self._entity}). Members of this team will be able to run code on this device."
+            )
+
+    def fail_run_queue_item(self, run_queue_item_id: str) -> None:
+        if self._gorilla_supports_fail_run_queue_items:
+            self._api.fail_run_queue_item(run_queue_item_id)
 
     @property
     def thread_ids(self) -> List[int]:
@@ -214,9 +253,28 @@ class LaunchAgent:
 
     def finish_thread_id(self, thread_id: int) -> None:
         """Removes the job from our list for now."""
+        job_and_run_status = self._jobs[thread_id]
+        if not job_and_run_status.run_id or not job_and_run_status.project:
+            self.fail_run_queue_item(job_and_run_status.run_queue_item_id)
+        else:
+            run_info = None
+            # sweep runs exist but have no info before they are started
+            # so run_info returned will be None
+            # normal runs just throw a comm error
+            try:
+                run_info = self._api.get_run_info(
+                    self._entity, job_and_run_status.project, job_and_run_status.run_id
+                )
+
+            except CommError:
+                pass
+            if run_info is None:
+                self.fail_run_queue_item(job_and_run_status.run_queue_item_id)
+
         # TODO: keep logs or something for the finished jobs
         with self._jobs_lock:
             del self._jobs[thread_id]
+
         # update status back to polling if no jobs are running
         if len(self.thread_ids) == 0:
             self.update_status(AGENT_POLLING)
@@ -295,16 +353,12 @@ class LaunchAgent:
 
             try:
                 self.run_job(job)
-            except Exception:
+            except Exception as e:
                 wandb.termerror(
                     f"{LOG_PREFIX}Error running job: {traceback.format_exc()}"
                 )
-
-
-                except Exception:
-                    _logger.error(
-                        f"{LOG_PREFIX}Error acking job when job errored: {traceback.format_exc()}"
-                    )
+                util.sentry_exc(e)
+                self.fail_run_queue_item(job["runQueueItemId"])
 
         for thread_id in self.thread_ids:
             self._update_finished(thread_id)
@@ -340,11 +394,20 @@ class LaunchAgent:
         default_config: Dict[str, Any],
         api: Api,
     ) -> None:
+        thread_id = threading.current_thread().ident
+        assert thread_id is not None
        try:
-            self._thread_run_job(launch_spec, job, default_config, api)
-        except
+            self._thread_run_job(launch_spec, job, default_config, api, thread_id)
+        except LaunchDockerError as e:
+            wandb.termerror(
+                f"{LOG_PREFIX}agent {self._name} encountered an issue while starting Docker, see above output for details."
+            )
+            self.finish_thread_id(thread_id)
+            util.sentry_exc(e)
+        except Exception as e:
             wandb.termerror(f"{LOG_PREFIX}Error running job: {traceback.format_exc()}")
-
+            self.finish_thread_id(thread_id)
+            util.sentry_exc(e)
 
     def _thread_run_job(
         self,
@@ -352,13 +415,13 @@ class LaunchAgent:
         job: Dict[str, Any],
         default_config: Dict[str, Any],
         api: Api,
+        thread_id: int,
     ) -> None:
-
-        assert thread_id is not None
-        job_tracker = JobAndRunStatus()
+        job_tracker = JobAndRunStatus(job["runQueueItemId"])
         with self._jobs_lock:
             self._jobs[thread_id] = job_tracker
         project = create_project_from_spec(launch_spec, api)
+        job_tracker.update_run_info(project)
         _logger.info("Fetching and validating project...")
         project = fetch_and_validate_project(project, api)
         _logger.info("Fetching resource...")
@@ -366,8 +429,6 @@ class LaunchAgent:
         backend_config: Dict[str, Any] = {
             PROJECT_SYNCHRONOUS: False,  # agent always runs async
         }
-
-        backend_config["runQueueItemId"] = job["runQueueItemId"]
         _logger.info("Loading backend")
         override_build_config = launch_spec.get("builder")
 
@@ -382,6 +443,7 @@ class LaunchAgent:
         builder = loader.builder_from_config(build_config, environment, registry)
         backend = loader.runner_from_config(resource, api, backend_config, environment)
         _logger.info("Backend loaded...")
+        api.ack_run_queue_item(job["runQueueItemId"], project.run_id)
        run = backend.run(project, builder)
 
        if _job_is_scheduler(launch_spec):
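Note: the agent now validates scheduler jobs explicitly. A run spec is treated as a scheduler only when its uri equals SCHEDULER_URI, and a local-process scheduler must carry an entry point of the form [wandb, scheduler, <sweep>]. A minimal sketch of a spec that would pass the new check (the sweep path is a placeholder, not from this diff):

    from wandb.sdk.launch.sweeps import SCHEDULER_URI

    run_spec = {
        "uri": SCHEDULER_URI,
        "resource": "local-process",
        "overrides": {"entry_point": ["wandb", "scheduler", "entity/project/sweep-id"]},
    }
    cmd = run_spec.get("overrides", {}).get("entry_point", [])
    assert run_spec.get("uri") == SCHEDULER_URI
    assert len(cmd) >= 3 and cmd[:2] == ["wandb", "scheduler"]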
wandb/sdk/launch/builder/build.py
CHANGED
@@ -1,3 +1,4 @@
+import hashlib
 import json
 import logging
 import os
@@ -21,7 +22,6 @@ from wandb.sdk.launch.loader import (
     registry_from_config,
 )
 
-from ...lib.git import GitRepo
 from .._project_spec import (
     EntryPoint,
     EntrypointDefaults,
@@ -132,6 +132,7 @@ PIP_TEMPLATE = """
 RUN python -m venv /env
 # make sure we install into the env
 ENV PATH="/env/bin:$PATH"
+
 COPY {requirements_files} ./
 {buildx_optional_prefix} {pip_install}
 """
@@ -192,8 +193,8 @@ def get_base_setup(
     CPU version is built on python, GPU version is built on nvidia:cuda.
     """
     python_base_image = f"python:{py_version}-buster"
-    if launch_project.
-
+    if launch_project.cuda_base_image:
+        _logger.info(f"Using cuda base image: {launch_project.cuda_base_image}")
         # cuda image doesn't come with python tooling
         if py_major == "2":
             python_packages = [
@@ -210,7 +211,7 @@ def get_base_setup(
                 "python3-setuptools",
             ]
         base_setup = CUDA_SETUP_TEMPLATE.format(
-            cuda_base_image=
+            cuda_base_image=launch_project.cuda_base_image,
             python_packages=" \\\n".join(python_packages),
             py_version=py_version,
         )
@@ -390,57 +391,6 @@ def generate_dockerfile(
     return dockerfile_contents
 
 
-_inspected_images = {}
-
-
-def docker_image_exists(docker_image: str, should_raise: bool = False) -> bool:
-    """Check if a specific image is already available.
-
-    Optionally raises an exception if the image is not found.
-    """
-    _logger.info("Checking if base image exists...")
-    try:
-        data = docker.run(["docker", "image", "inspect", docker_image])
-        # always true, since return stderr defaults to false
-        assert isinstance(data, str)
-        parsed = json.loads(data)[0]
-        _inspected_images[docker_image] = parsed
-        return True
-    except (docker.DockerError, ValueError) as e:
-        if should_raise:
-            raise e
-        _logger.info("Base image not found. Generating new base image")
-        return False
-
-
-def docker_image_inspect(docker_image: str) -> Dict[str, Any]:
-    """Get the parsed json result of docker inspect image_name."""
-    if _inspected_images.get(docker_image) is None:
-        docker_image_exists(docker_image, True)
-    return _inspected_images.get(docker_image, {})
-
-
-def pull_docker_image(docker_image: str) -> None:
-    """Pull the requested docker image."""
-    if docker_image_exists(docker_image):
-        # don't pull images if they exist already, eg if they are local images
-        return
-    try:
-        docker.run(["docker", "pull", docker_image])
-    except docker.DockerError as e:
-        raise LaunchError(f"Docker server returned error: {e}")
-
-
-def construct_gcp_image_uri(
-    launch_project: LaunchProject,
-    gcp_repo: str,
-    gcp_project: str,
-    gcp_registry: str,
-) -> str:
-    base_uri = launch_project.image_uri
-    return "/".join([gcp_registry, gcp_project, gcp_repo, base_uri])
-
-
 def construct_gcp_registry_uri(
     gcp_repo: str, gcp_project: str, gcp_registry: str
 ) -> str:
@@ -474,24 +424,6 @@ def _parse_existing_requirements(launch_project: LaunchProject) -> str:
     return requirements_line
 
 
-def _get_docker_image_uri(name: Optional[str], work_dir: str, image_id: str) -> str:
-    """Create a Docker image URI for a project.
-
-    The resulting URI is based on the git hash of the specified working directory.
-    :param name: The URI of the Docker repository with which to tag the image. The
-                 repository URI is used as the prefix of the image URI.
-    :param work_dir: Path to the working directory in which to search for a git commit hash.
-    """
-    name = name.replace(" ", "-") if name else "wandb-launch"
-    # Optionally include first 7 digits of git SHA in tag name, if available.
-
-    git_commit = GitRepo(work_dir).last_commit
-    version_string = (
-        ":" + str(git_commit[:7]) + image_id if git_commit else ":" + image_id
-    )
-    return name + version_string
-
-
 def _create_docker_build_ctx(
     launch_project: LaunchProject,
     dockerfile_contents: str,
@@ -537,7 +469,7 @@ def construct_builder_args(
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     registry_config = None
     if launch_config is not None:
-        build_config = launch_config.get("
+        build_config = launch_config.get("builder")
         registry_config = launch_config.get("registry")
 
     default_launch_config = None
@@ -625,3 +557,13 @@ def build_image_from_project(
         raise LaunchError("Error building image uri")
     else:
         return image_uri
+
+
+def image_tag_from_dockerfile_and_source(
+    launch_project: LaunchProject, dockerfile_contents: str
+) -> str:
+    """Hashes the source and dockerfile contents into a unique tag."""
+    image_source_string = launch_project.get_image_source_string()
+    unique_id_string = image_source_string + dockerfile_contents
+    image_tag = hashlib.sha256(unique_id_string.encode("utf-8")).hexdigest()[:8]
+    return image_tag
wandb/sdk/launch/builder/docker_builder.py
CHANGED
@@ -16,10 +16,17 @@ from .._project_spec import (
     get_entry_point_command,
 )
 from ..registry.local_registry import LocalRegistry
-from ..utils import
+from ..utils import (
+    LOG_PREFIX,
+    LaunchDockerError,
+    LaunchError,
+    sanitize_wandb_api_key,
+    warn_failed_packages_from_build_logs,
+)
 from .build import (
     _create_docker_build_ctx,
     generate_dockerfile,
+    image_tag_from_dockerfile_and_source,
     validate_docker_installation,
 )
 
@@ -110,15 +117,32 @@ class DockerBuilder(AbstractBuilder):
             launch_project (LaunchProject): The project to build.
             entrypoint (EntryPoint): The entrypoint to use.
         """
+        dockerfile_str = generate_dockerfile(
+            launch_project, entrypoint, launch_project.resource, "docker"
+        )
+
+        image_tag = image_tag_from_dockerfile_and_source(launch_project, dockerfile_str)
+
         repository = None if not self.registry else self.registry.get_repo_uri()
+        # if repo is set, use the repo name as the image name
         if repository:
-            image_uri = f"{repository}:{
+            image_uri = f"{repository}:{image_tag}"
+        # otherwise, base the image name off of the source
+        # which the launch_project checks in image_name
         else:
-            image_uri = launch_project.
-
-
-
+            image_uri = f"{launch_project.image_name}:{image_tag}"
+
+        if not launch_project.build_required() and self.registry.check_image_exists(
+            image_uri
+        ):
+            return image_uri
+
+        _logger.info(
+            f"image {image_uri} does not already exist in repository, building."
         )
+
+        entry_cmd = get_entry_point_command(entrypoint, launch_project.override_args)
+
         create_metadata_file(
             launch_project,
             image_uri,
@@ -128,9 +152,13 @@ class DockerBuilder(AbstractBuilder):
         build_ctx_path = _create_docker_build_ctx(launch_project, dockerfile_str)
         dockerfile = os.path.join(build_ctx_path, _GENERATED_DOCKERFILE_NAME)
         try:
-            docker.build(
+            output = docker.build(
+                tags=[image_uri], file=dockerfile, context_path=build_ctx_path
+            )
+            warn_failed_packages_from_build_logs(output, image_uri)
+
         except docker.DockerError as e:
-            raise
+            raise LaunchDockerError(f"Error communicating with docker client: {e}")
 
         try:
             os.remove(build_ctx_path)
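Note: DockerBuilder.build_image can now short-circuit: when the source is an immutable job artifact (build_required() returns False) and the registry already holds the computed tag, the existing image URI is returned without rebuilding. A rough sketch of that decision under those assumptions, not the real builder class:

    def should_skip_build(launch_project, registry, image_uri: str) -> bool:
        # reuse the image only for job-sourced projects whose tag already exists
        return not launch_project.build_required() and registry.check_image_exists(image_uri)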