wandb 0.16.5__py3-none-any.whl → 0.17.0rc1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- package_readme.md +95 -0
- wandb/__init__.py +2 -2
- wandb/agents/pyagent.py +0 -1
- wandb/analytics/sentry.py +2 -1
- wandb/apis/importers/internals/protocols.py +30 -56
- wandb/apis/importers/mlflow.py +13 -26
- wandb/apis/importers/wandb.py +8 -14
- wandb/apis/public/api.py +1 -0
- wandb/apis/public/artifacts.py +1 -0
- wandb/apis/public/files.py +1 -0
- wandb/apis/public/history.py +1 -0
- wandb/apis/public/jobs.py +1 -0
- wandb/apis/public/projects.py +1 -0
- wandb/apis/public/reports.py +1 -0
- wandb/apis/public/runs.py +1 -0
- wandb/apis/public/sweeps.py +1 -0
- wandb/apis/public/teams.py +1 -0
- wandb/apis/public/users.py +1 -0
- wandb/apis/reports/v1/_blocks.py +2 -6
- wandb/apis/reports/v2/gql.py +1 -0
- wandb/apis/reports/v2/interface.py +3 -4
- wandb/apis/reports/v2/internal.py +5 -8
- wandb/cli/cli.py +7 -4
- wandb/data_types.py +3 -3
- wandb/env.py +35 -5
- wandb/errors/__init__.py +5 -0
- wandb/integration/catboost/catboost.py +1 -1
- wandb/integration/fastai/__init__.py +1 -0
- wandb/integration/keras/__init__.py +1 -0
- wandb/integration/keras/keras.py +6 -6
- wandb/integration/langchain/wandb_tracer.py +1 -0
- wandb/integration/lightning/fabric/logger.py +1 -3
- wandb/integration/metaflow/metaflow.py +41 -6
- wandb/integration/openai/fine_tuning.py +77 -40
- wandb/keras/__init__.py +1 -0
- wandb/proto/v3/wandb_internal_pb2.py +364 -332
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_internal_pb2.py +322 -316
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/wandb_internal_codegen.py +0 -25
- wandb/sdk/artifacts/artifact.py +41 -13
- wandb/sdk/artifacts/artifact_download_logger.py +1 -0
- wandb/sdk/artifacts/artifact_file_cache.py +18 -4
- wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
- wandb/sdk/artifacts/artifact_manifest.py +1 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +1 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
- wandb/sdk/artifacts/artifact_saver.py +21 -21
- wandb/sdk/artifacts/artifact_state.py +1 -0
- wandb/sdk/artifacts/artifact_ttl.py +1 -0
- wandb/sdk/artifacts/exceptions.py +1 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -0
- wandb/sdk/artifacts/storage_policy.py +1 -0
- wandb/sdk/data_types/base_types/media.py +3 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
- wandb/sdk/integration_utils/auto_logging.py +5 -6
- wandb/sdk/integration_utils/data_logging.py +5 -1
- wandb/sdk/interface/interface.py +72 -37
- wandb/sdk/interface/interface_shared.py +7 -13
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/handler.py +18 -2
- wandb/sdk/internal/internal.py +0 -1
- wandb/sdk/internal/internal_util.py +0 -1
- wandb/sdk/internal/job_builder.py +4 -3
- wandb/sdk/internal/profiler.py +1 -0
- wandb/sdk/internal/run.py +1 -0
- wandb/sdk/internal/sender.py +1 -1
- wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
- wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
- wandb/sdk/internal/system/assets/interfaces.py +6 -8
- wandb/sdk/internal/system/assets/open_metrics.py +2 -2
- wandb/sdk/internal/system/assets/trainium.py +1 -3
- wandb/sdk/launch/_launch.py +5 -0
- wandb/sdk/launch/_project_spec.py +10 -23
- wandb/sdk/launch/agent/agent.py +81 -37
- wandb/sdk/launch/agent/config.py +80 -11
- wandb/sdk/launch/builder/abstract.py +1 -0
- wandb/sdk/launch/builder/build.py +28 -1
- wandb/sdk/launch/builder/docker_builder.py +1 -0
- wandb/sdk/launch/builder/kaniko_builder.py +149 -134
- wandb/sdk/launch/builder/noop.py +1 -0
- wandb/sdk/launch/create_job.py +61 -48
- wandb/sdk/launch/environment/abstract.py +1 -0
- wandb/sdk/launch/environment/gcp_environment.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +1 -0
- wandb/sdk/launch/loader.py +1 -0
- wandb/sdk/launch/registry/abstract.py +1 -0
- wandb/sdk/launch/registry/azure_container_registry.py +1 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +1 -0
- wandb/sdk/launch/registry/local_registry.py +1 -0
- wandb/sdk/launch/runner/abstract.py +1 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +4 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +4 -3
- wandb/sdk/launch/runner/sagemaker_runner.py +11 -10
- wandb/sdk/launch/sweeps/scheduler.py +4 -1
- wandb/sdk/launch/sweeps/scheduler_sweep.py +1 -0
- wandb/sdk/launch/sweeps/utils.py +1 -1
- wandb/sdk/launch/utils.py +21 -3
- wandb/sdk/lib/_settings_toposort_generated.py +1 -0
- wandb/sdk/lib/fsm.py +8 -12
- wandb/sdk/lib/gitlib.py +4 -4
- wandb/sdk/lib/lazyloader.py +0 -1
- wandb/sdk/lib/proto_util.py +1 -1
- wandb/sdk/lib/retry.py +3 -2
- wandb/sdk/lib/run_moment.py +7 -1
- wandb/sdk/service/service.py +17 -15
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_init.py +2 -8
- wandb/sdk/wandb_manager.py +2 -2
- wandb/sdk/wandb_require.py +5 -0
- wandb/sdk/wandb_run.py +64 -46
- wandb/sdk/wandb_settings.py +2 -1
- wandb/sklearn/__init__.py +1 -0
- wandb/sklearn/plot/__init__.py +1 -0
- wandb/sklearn/plot/classifier.py +1 -0
- wandb/sklearn/plot/clusterer.py +1 -0
- wandb/sklearn/plot/regressor.py +1 -0
- wandb/sklearn/plot/shared.py +1 -0
- wandb/sklearn/utils.py +1 -0
- wandb/testing/relay.py +4 -4
- wandb/trigger.py +1 -0
- wandb/util.py +40 -17
- wandb/wandb_controller.py +0 -1
- wandb/wandb_torch.py +1 -2
- {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/METADATA +68 -69
- {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/RECORD +139 -140
- {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/WHEEL +1 -2
- wandb/bin/apple_gpu_stats +0 -0
- wandb-0.16.5.dist-info/top_level.txt +0 -1
- {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info/licenses}/LICENSE +0 -0
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
Arguments can come from a launch spec or call to wandb launch.
|
4
4
|
"""
|
5
|
+
|
5
6
|
import enum
|
6
7
|
import logging
|
7
8
|
import os
|
@@ -14,6 +15,7 @@ import wandb.docker as docker
|
|
14
15
|
from wandb.apis.internal import Api
|
15
16
|
from wandb.errors import CommError
|
16
17
|
from wandb.sdk.launch import utils
|
18
|
+
from wandb.sdk.launch.utils import get_entrypoint_file
|
17
19
|
from wandb.sdk.lib.runid import generate_id
|
18
20
|
|
19
21
|
from .errors import LaunchError
|
@@ -119,6 +121,7 @@ class LaunchProject:
|
|
119
121
|
self.override_args: List[str] = overrides.get("args", [])
|
120
122
|
self.override_config: Dict[str, Any] = overrides.get("run_config", {})
|
121
123
|
self.override_artifacts: Dict[str, Any] = overrides.get("artifacts", {})
|
124
|
+
self.override_files: Dict[str, Any] = overrides.get("files", {})
|
122
125
|
self.override_entrypoint: Optional[EntryPoint] = None
|
123
126
|
self.override_dockerfile: Optional[str] = overrides.get("dockerfile")
|
124
127
|
self.deps_type: Optional[str] = None
|
@@ -127,15 +130,15 @@ class LaunchProject:
|
|
127
130
|
self._queue_name: Optional[str] = None
|
128
131
|
self._queue_entity: Optional[str] = None
|
129
132
|
self._run_queue_item_id: Optional[str] = None
|
130
|
-
self._entry_point: Optional[
|
131
|
-
|
132
|
-
|
133
|
+
self._entry_point: Optional[EntryPoint] = (
|
134
|
+
None # todo: keep multiple entrypoint support?
|
135
|
+
)
|
133
136
|
|
134
137
|
override_entrypoint = overrides.get("entry_point")
|
135
138
|
if override_entrypoint:
|
136
139
|
_logger.info("Adding override entry point")
|
137
140
|
self.override_entrypoint = EntryPoint(
|
138
|
-
name=
|
141
|
+
name=get_entrypoint_file(override_entrypoint),
|
139
142
|
command=override_entrypoint,
|
140
143
|
)
|
141
144
|
|
@@ -536,24 +539,6 @@ class LaunchProject:
|
|
536
539
|
self.git_version = branch_name
|
537
540
|
|
538
541
|
|
539
|
-
def _get_entrypoint_file(entrypoint: List[str]) -> Optional[str]:
|
540
|
-
"""Get the entrypoint file from the given command.
|
541
|
-
|
542
|
-
Args:
|
543
|
-
entrypoint (List[str]): List of command and arguments.
|
544
|
-
|
545
|
-
Returns:
|
546
|
-
Optional[str]: The entrypoint file if found, otherwise None.
|
547
|
-
"""
|
548
|
-
if not entrypoint:
|
549
|
-
return None
|
550
|
-
if entrypoint[0].endswith(".py") or entrypoint[0].endswith(".sh"):
|
551
|
-
return entrypoint[0]
|
552
|
-
if len(entrypoint) < 2:
|
553
|
-
return None
|
554
|
-
return entrypoint[1]
|
555
|
-
|
556
|
-
|
557
542
|
class EntryPoint:
|
558
543
|
"""An entry point into a wandb launch specification."""
|
559
544
|
|
@@ -570,7 +555,9 @@ class EntryPoint:
|
|
570
555
|
|
571
556
|
def update_entrypoint_path(self, new_path: str) -> None:
|
572
557
|
"""Updates the entrypoint path to a new path."""
|
573
|
-
if len(self.command) == 2 and
|
558
|
+
if len(self.command) == 2 and (
|
559
|
+
self.command[0].startswith("python") or self.command[0] == "bash"
|
560
|
+
):
|
574
561
|
self.command[1] = new_path
|
575
562
|
|
576
563
|
|
wandb/sdk/launch/agent/agent.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Implementation of launch agent."""
|
2
|
+
|
2
3
|
import asyncio
|
3
4
|
import logging
|
4
5
|
import os
|
@@ -45,7 +46,10 @@ MAX_RESUME_COUNT = 5
|
|
45
46
|
|
46
47
|
RUN_INFO_GRACE_PERIOD = 60
|
47
48
|
|
48
|
-
|
49
|
+
DEFAULT_STOPPED_RUN_TIMEOUT = 60
|
50
|
+
|
51
|
+
DEFAULT_PRINT_INTERVAL = 5 * 60
|
52
|
+
VERBOSE_PRINT_INTERVAL = 20
|
49
53
|
|
50
54
|
_env_timeout = os.environ.get("WANDB_LAUNCH_START_TIMEOUT")
|
51
55
|
if _env_timeout:
|
@@ -105,30 +109,29 @@ def _max_from_config(
|
|
105
109
|
return max_from_config
|
106
110
|
|
107
111
|
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
_logger.debug("Recieved runSpec in _is_scheduler_job that was empty")
|
112
|
+
class InternalAgentLogger:
|
113
|
+
def __init__(self, verbosity=0):
|
114
|
+
self._print_to_terminal = verbosity >= 2
|
112
115
|
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
# Any job pushed to a run queue that has a scheduler uri is
|
118
|
-
# allowed to use local-process
|
119
|
-
if run_spec.get("job"):
|
120
|
-
return True
|
116
|
+
def error(self, message: str):
|
117
|
+
if self._print_to_terminal:
|
118
|
+
wandb.termerror(f"{LOG_PREFIX}{message}")
|
119
|
+
_logger.error(f"{LOG_PREFIX}{message}")
|
121
120
|
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
return False
|
121
|
+
def warn(self, message: str):
|
122
|
+
if self._print_to_terminal:
|
123
|
+
wandb.termwarn(f"{LOG_PREFIX}{message}")
|
124
|
+
_logger.warn(f"{LOG_PREFIX}{message}")
|
127
125
|
|
128
|
-
|
129
|
-
|
126
|
+
def info(self, message: str):
|
127
|
+
if self._print_to_terminal:
|
128
|
+
wandb.termlog(f"{LOG_PREFIX}{message}")
|
129
|
+
_logger.info(f"{LOG_PREFIX}{message}")
|
130
130
|
|
131
|
-
|
131
|
+
def debug(self, message: str):
|
132
|
+
if self._print_to_terminal:
|
133
|
+
wandb.termlog(f"{LOG_PREFIX}{message}")
|
134
|
+
_logger.debug(f"{LOG_PREFIX}{message}")
|
132
135
|
|
133
136
|
|
134
137
|
class LaunchAgent:
|
@@ -184,7 +187,13 @@ class LaunchAgent:
|
|
184
187
|
self._max_jobs = _max_from_config(config, "max_jobs")
|
185
188
|
self._max_schedulers = _max_from_config(config, "max_schedulers")
|
186
189
|
self._secure_mode = config.get("secure_mode", False)
|
190
|
+
self._verbosity = config.get("verbosity", 0)
|
191
|
+
self._internal_logger = InternalAgentLogger(verbosity=self._verbosity)
|
192
|
+
self._last_status_print_time = 0.0
|
187
193
|
self.default_config: Dict[str, Any] = config
|
194
|
+
self._stopped_run_timeout = config.get(
|
195
|
+
"stopped_run_timeout", DEFAULT_STOPPED_RUN_TIMEOUT
|
196
|
+
)
|
188
197
|
|
189
198
|
# Get agent version from env var if present, otherwise wandb version
|
190
199
|
self.version: str = "wandb@" + wandb.__version__
|
@@ -228,6 +237,33 @@ class LaunchAgent:
|
|
228
237
|
self._name = agent_response["name"]
|
229
238
|
self._init_agent_run()
|
230
239
|
|
240
|
+
def _is_scheduler_job(self, run_spec: Dict[str, Any]) -> bool:
|
241
|
+
"""Determine whether a job/runSpec is a sweep scheduler."""
|
242
|
+
if not run_spec:
|
243
|
+
self._internal_logger.debug(
|
244
|
+
"Recieved runSpec in _is_scheduler_job that was empty"
|
245
|
+
)
|
246
|
+
|
247
|
+
if run_spec.get("uri") != Scheduler.PLACEHOLDER_URI:
|
248
|
+
return False
|
249
|
+
|
250
|
+
if run_spec.get("resource") == "local-process":
|
251
|
+
# Any job pushed to a run queue that has a scheduler uri is
|
252
|
+
# allowed to use local-process
|
253
|
+
if run_spec.get("job"):
|
254
|
+
return True
|
255
|
+
|
256
|
+
# If a scheduler is local-process and run through CLI, also
|
257
|
+
# confirm command is in format: [wandb scheduler <sweep>]
|
258
|
+
cmd = run_spec.get("overrides", {}).get("entry_point", [])
|
259
|
+
if len(cmd) < 3:
|
260
|
+
return False
|
261
|
+
|
262
|
+
if cmd[:2] != ["wandb", "scheduler"]:
|
263
|
+
return False
|
264
|
+
|
265
|
+
return True
|
266
|
+
|
231
267
|
async def fail_run_queue_item(
|
232
268
|
self,
|
233
269
|
run_queue_item_id: str,
|
@@ -298,6 +334,7 @@ class LaunchAgent:
|
|
298
334
|
|
299
335
|
def print_status(self) -> None:
|
300
336
|
"""Prints the current status of the agent."""
|
337
|
+
self._last_status_print_time = time.time()
|
301
338
|
output_str = "agent "
|
302
339
|
if self._name:
|
303
340
|
output_str += f"{self._name} "
|
@@ -344,8 +381,8 @@ class LaunchAgent:
|
|
344
381
|
if run_state.lower() != "pending":
|
345
382
|
return True
|
346
383
|
except CommError:
|
347
|
-
|
348
|
-
f"Run {entity}/{project}/{run_id} with rqi id: {rqi_id} did not have associated run"
|
384
|
+
self._internal_logger.info(
|
385
|
+
f"Run {entity}/{project}/{run_id} with rqi id: {rqi_id} did not have associated run",
|
349
386
|
)
|
350
387
|
return False
|
351
388
|
|
@@ -361,8 +398,8 @@ class LaunchAgent:
|
|
361
398
|
job_and_run_status.entity is not None
|
362
399
|
and job_and_run_status.entity != self._entity
|
363
400
|
):
|
364
|
-
|
365
|
-
"Skipping check for completed run status because run is on a different entity than agent"
|
401
|
+
self._internal_logger.info(
|
402
|
+
"Skipping check for completed run status because run is on a different entity than agent",
|
366
403
|
)
|
367
404
|
elif exception is not None:
|
368
405
|
tb_str = traceback.format_exception(
|
@@ -378,8 +415,8 @@ class LaunchAgent:
|
|
378
415
|
fnames,
|
379
416
|
)
|
380
417
|
elif job_and_run_status.project is None or job_and_run_status.run_id is None:
|
381
|
-
|
382
|
-
f"called finish_thread_id on thread whose tracker has no project or run id. RunQueueItemID: {job_and_run_status.run_queue_item_id}"
|
418
|
+
self._internal_logger.info(
|
419
|
+
f"called finish_thread_id on thread whose tracker has no project or run id. RunQueueItemID: {job_and_run_status.run_queue_item_id}",
|
383
420
|
)
|
384
421
|
wandb.termerror(
|
385
422
|
"Missing project or run id on thread called finish thread id"
|
@@ -430,7 +467,9 @@ class LaunchAgent:
|
|
430
467
|
job_and_run_status.run_queue_item_id, _msg, "run", fnames
|
431
468
|
)
|
432
469
|
else:
|
433
|
-
|
470
|
+
self._internal_logger.info(
|
471
|
+
f"Finish thread id {thread_id} had no exception and no run"
|
472
|
+
)
|
434
473
|
wandb._sentry.exception(
|
435
474
|
"launch agent called finish thread id on thread without run or exception"
|
436
475
|
)
|
@@ -458,7 +497,7 @@ class LaunchAgent:
|
|
458
497
|
await self.update_status(AGENT_RUNNING)
|
459
498
|
|
460
499
|
# parse job
|
461
|
-
|
500
|
+
self._internal_logger.info("Parsing launch spec")
|
462
501
|
launch_spec = job["runSpec"]
|
463
502
|
|
464
503
|
# Abort if this job attempts to override secure mode
|
@@ -511,6 +550,10 @@ class LaunchAgent:
|
|
511
550
|
KeyboardInterrupt: if the agent is requested to stop.
|
512
551
|
"""
|
513
552
|
self.print_status()
|
553
|
+
if self._verbosity == 0:
|
554
|
+
print_interval = DEFAULT_PRINT_INTERVAL
|
555
|
+
else:
|
556
|
+
print_interval = VERBOSE_PRINT_INTERVAL
|
514
557
|
try:
|
515
558
|
while True:
|
516
559
|
job = None
|
@@ -532,7 +575,7 @@ class LaunchAgent:
|
|
532
575
|
file_saver = RunQueueItemFileSaver(
|
533
576
|
self._wandb_run, job["runQueueItemId"]
|
534
577
|
)
|
535
|
-
if _is_scheduler_job(job.get("runSpec", {})):
|
578
|
+
if self._is_scheduler_job(job.get("runSpec", {})):
|
536
579
|
# If job is a scheduler, and we are already at the cap, ignore,
|
537
580
|
# don't ack, and it will be pushed back onto the queue in 1 min
|
538
581
|
if self.num_running_schedulers >= self._max_schedulers:
|
@@ -567,6 +610,7 @@ class LaunchAgent:
|
|
567
610
|
await self.update_status(AGENT_POLLING)
|
568
611
|
else:
|
569
612
|
await self.update_status(AGENT_RUNNING)
|
613
|
+
if time.time() - self._last_status_print_time > print_interval:
|
570
614
|
self.print_status()
|
571
615
|
|
572
616
|
if self.num_running_jobs == self._max_jobs or job is None:
|
@@ -634,14 +678,14 @@ class LaunchAgent:
|
|
634
678
|
await self.check_sweep_state(launch_spec, api)
|
635
679
|
|
636
680
|
job_tracker.update_run_info(project)
|
637
|
-
|
681
|
+
self._internal_logger.info("Fetching and validating project...")
|
638
682
|
project.fetch_and_validate_project()
|
639
|
-
|
683
|
+
self._internal_logger.info("Fetching resource...")
|
640
684
|
resource = launch_spec.get("resource") or "local-container"
|
641
685
|
backend_config: Dict[str, Any] = {
|
642
686
|
PROJECT_SYNCHRONOUS: False, # agent always runs async
|
643
687
|
}
|
644
|
-
|
688
|
+
self._internal_logger.info("Loading backend")
|
645
689
|
override_build_config = launch_spec.get("builder")
|
646
690
|
|
647
691
|
_, build_config, registry_config = construct_agent_configs(
|
@@ -661,13 +705,13 @@ class LaunchAgent:
|
|
661
705
|
assert entrypoint is not None
|
662
706
|
image_uri = await builder.build_image(project, entrypoint, job_tracker)
|
663
707
|
|
664
|
-
|
708
|
+
self._internal_logger.info("Backend loaded...")
|
665
709
|
if isinstance(backend, LocalProcessRunner):
|
666
710
|
run = await backend.run(project, image_uri)
|
667
711
|
else:
|
668
712
|
assert image_uri
|
669
713
|
run = await backend.run(project, image_uri)
|
670
|
-
if _is_scheduler_job(launch_spec):
|
714
|
+
if self._is_scheduler_job(launch_spec):
|
671
715
|
with self._jobs_lock:
|
672
716
|
self._jobs[thread_id].is_scheduler = True
|
673
717
|
wandb.termlog(
|
@@ -700,7 +744,7 @@ class LaunchAgent:
|
|
700
744
|
if stopped_time is None:
|
701
745
|
stopped_time = time.time()
|
702
746
|
else:
|
703
|
-
if time.time() - stopped_time >
|
747
|
+
if time.time() - stopped_time > self._stopped_run_timeout:
|
704
748
|
await run.cancel()
|
705
749
|
await asyncio.sleep(AGENT_POLLING_INTERVAL)
|
706
750
|
|
@@ -720,7 +764,7 @@ class LaunchAgent:
|
|
720
764
|
project=launch_spec["project"],
|
721
765
|
)
|
722
766
|
except Exception as e:
|
723
|
-
|
767
|
+
self._internal_logger.debug(f"Fetch sweep state error: {e}")
|
724
768
|
state = None
|
725
769
|
|
726
770
|
if state != "RUNNING" and state != "PAUSED":
|
wandb/sdk/launch/agent/config.py
CHANGED
@@ -80,17 +80,7 @@ class RegistryConfig(BaseModel):
|
|
80
80
|
@validator("uri") # type: ignore
|
81
81
|
@classmethod
|
82
82
|
def validate_uri(cls, uri: str) -> str:
|
83
|
-
|
84
|
-
GCP_ARTIFACT_REGISTRY_URI_REGEX,
|
85
|
-
AZURE_CONTAINER_REGISTRY_URI_REGEX,
|
86
|
-
ELASTIC_CONTAINER_REGISTRY_URI_REGEX,
|
87
|
-
]:
|
88
|
-
if regex.match(uri):
|
89
|
-
return uri
|
90
|
-
raise ValueError(
|
91
|
-
"Invalid uri. URI must be a repository URI for an "
|
92
|
-
"ECR, ACR, or GCP Artifact Registry."
|
93
|
-
)
|
83
|
+
return validate_registry_uri(uri)
|
94
84
|
|
95
85
|
|
96
86
|
class EnvironmentConfig(BaseModel):
|
@@ -186,6 +176,14 @@ class BuilderConfig(BaseModel):
|
|
186
176
|
"""Right now there are no required fields for docker builds."""
|
187
177
|
return values
|
188
178
|
|
179
|
+
@validator("destination") # type: ignore
|
180
|
+
@classmethod
|
181
|
+
def validate_destination(cls, destination: Optional[str]) -> Optional[str]:
|
182
|
+
"""Validate that the destination is a valid container registry URI."""
|
183
|
+
if destination is None:
|
184
|
+
return None
|
185
|
+
return validate_registry_uri(destination)
|
186
|
+
|
189
187
|
|
190
188
|
class AgentConfig(BaseModel):
|
191
189
|
"""Configuration for the Launch agent."""
|
@@ -225,6 +223,77 @@ class AgentConfig(BaseModel):
|
|
225
223
|
None,
|
226
224
|
description="The builder to use.",
|
227
225
|
)
|
226
|
+
verbosity: Optional[int] = Field(
|
227
|
+
0,
|
228
|
+
description="How verbose to print, 0 = default, 1 = verbose, 2 = very verbose",
|
229
|
+
)
|
230
|
+
stopped_run_timeout: Optional[int] = Field(
|
231
|
+
60,
|
232
|
+
description="How many seconds to wait after receiving the stop command before forcibly cancelling a run.",
|
233
|
+
)
|
228
234
|
|
229
235
|
class Config:
|
230
236
|
extra = "forbid"
|
237
|
+
|
238
|
+
|
239
|
+
def validate_registry_uri(uri: str) -> str:
|
240
|
+
"""Validate that the registry URI is a valid container registry URI.
|
241
|
+
|
242
|
+
The URI should resolve to an image name in a container registry. The recognized
|
243
|
+
formats are for ECR, ACR, and GCP Artifact Registry. If the URI does not match
|
244
|
+
any of these formats, a warning is printed indicating the registry type is not
|
245
|
+
recognized and the agent can't guarantee that images can be pushed.
|
246
|
+
|
247
|
+
If the format is recognized but does not resolve to an image name, an
|
248
|
+
error is raised. For example, if the URI is an ECR URI but does not include
|
249
|
+
an image name or includes a tag as well as an image name, an error is raised.
|
250
|
+
"""
|
251
|
+
tag_msg = (
|
252
|
+
"Destination for built images may not include a tag, but the URI provided "
|
253
|
+
"includes the suffix '{tag}'. Please remove the tag and try again. The agent "
|
254
|
+
"will automatically tag each image with a unique hash of the source code."
|
255
|
+
)
|
256
|
+
if uri.startswith("https://"):
|
257
|
+
uri = uri[8:]
|
258
|
+
|
259
|
+
match = GCP_ARTIFACT_REGISTRY_URI_REGEX.match(uri)
|
260
|
+
if match:
|
261
|
+
if match.group("tag"):
|
262
|
+
raise ValueError(tag_msg.format(tag=match.group("tag")))
|
263
|
+
if not match.group("image_name"):
|
264
|
+
raise ValueError(
|
265
|
+
"An image name must be specified in the URI for a GCP Artifact Registry. "
|
266
|
+
"Please provide a uri with the format "
|
267
|
+
"'https://<region>-docker.pkg.dev/<project>/<repository>/<image>'."
|
268
|
+
)
|
269
|
+
return uri
|
270
|
+
|
271
|
+
match = AZURE_CONTAINER_REGISTRY_URI_REGEX.match(uri)
|
272
|
+
if match:
|
273
|
+
if match.group("tag"):
|
274
|
+
raise ValueError(tag_msg.format(tag=match.group("tag")))
|
275
|
+
if not match.group("repository"):
|
276
|
+
raise ValueError(
|
277
|
+
"A repository name must be specified in the URI for an "
|
278
|
+
"Azure Container Registry. Please provide a uri with the format "
|
279
|
+
"'https://<registry-name>.azurecr.io/<repository>'."
|
280
|
+
)
|
281
|
+
return uri
|
282
|
+
|
283
|
+
match = ELASTIC_CONTAINER_REGISTRY_URI_REGEX.match(uri)
|
284
|
+
if match:
|
285
|
+
if match.group("tag"):
|
286
|
+
raise ValueError(tag_msg.format(tag=match.group("tag")))
|
287
|
+
if not match.group("repository"):
|
288
|
+
raise ValueError(
|
289
|
+
"A repository name must be specified in the URI for an "
|
290
|
+
"Elastic Container Registry. Please provide a uri with the format "
|
291
|
+
"'https://<account-id>.dkr.ecr.<region>.amazonaws.com/<repository>'."
|
292
|
+
)
|
293
|
+
return uri
|
294
|
+
|
295
|
+
wandb.termwarn(
|
296
|
+
f"Unable to recognize registry type in URI {uri}. You are responsible "
|
297
|
+
"for ensuring the agent can push images to this registry."
|
298
|
+
)
|
299
|
+
return uri
|
@@ -237,7 +237,11 @@ def get_base_setup(
|
|
237
237
|
|
238
238
|
CPU version is built on python, Accelerator version is built on user provided.
|
239
239
|
"""
|
240
|
-
|
240
|
+
minor = int(py_version.split(".")[1])
|
241
|
+
if minor < 12:
|
242
|
+
python_base_image = f"python:{py_version}-buster"
|
243
|
+
else:
|
244
|
+
python_base_image = f"python:{py_version}-bookworm"
|
241
245
|
if launch_project.accelerator_base_image:
|
242
246
|
_logger.info(
|
243
247
|
f"Using accelerator base image: {launch_project.accelerator_base_image}"
|
@@ -311,6 +315,11 @@ def get_env_vars_dict(
|
|
311
315
|
_inject_wandb_config_env_vars(
|
312
316
|
launch_project.override_config, env_vars, max_env_length
|
313
317
|
)
|
318
|
+
|
319
|
+
_inject_file_overrides_env_vars(
|
320
|
+
launch_project.override_files, env_vars, max_env_length
|
321
|
+
)
|
322
|
+
|
314
323
|
artifacts = {}
|
315
324
|
# if we're spinning up a launch process from a job
|
316
325
|
# we should tell the run to use that artifact
|
@@ -677,3 +686,21 @@ def _inject_wandb_config_env_vars(
|
|
677
686
|
]
|
678
687
|
config_chunks_dict = {f"WANDB_CONFIG_{i}": chunk for i, chunk in enumerate(chunks)}
|
679
688
|
env_dict.update(config_chunks_dict)
|
689
|
+
|
690
|
+
|
691
|
+
def _inject_file_overrides_env_vars(
|
692
|
+
overrides: Dict[str, Any], env_dict: Dict[str, Any], maximum_env_length: int
|
693
|
+
) -> None:
|
694
|
+
str_overrides = json.dumps(overrides)
|
695
|
+
if len(str_overrides) <= maximum_env_length:
|
696
|
+
env_dict["WANDB_LAUNCH_FILE_OVERRIDES"] = str_overrides
|
697
|
+
return
|
698
|
+
|
699
|
+
chunks = [
|
700
|
+
str_overrides[i : i + maximum_env_length]
|
701
|
+
for i in range(0, len(str_overrides), maximum_env_length)
|
702
|
+
]
|
703
|
+
overrides_chunks_dict = {
|
704
|
+
f"WANDB_LAUNCH_FILE_OVERRIDES_{i}": chunk for i, chunk in enumerate(chunks)
|
705
|
+
}
|
706
|
+
env_dict.update(overrides_chunks_dict)
|