wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/importers/base.py +20 -5
  4. wandb/apis/importers/mlflow.py +7 -1
  5. wandb/apis/internal.py +12 -0
  6. wandb/apis/public.py +247 -1387
  7. wandb/apis/reports/_panels.py +58 -35
  8. wandb/beta/workflows.py +6 -7
  9. wandb/cli/cli.py +130 -60
  10. wandb/data_types.py +3 -1
  11. wandb/filesync/dir_watcher.py +21 -27
  12. wandb/filesync/step_checksum.py +8 -8
  13. wandb/filesync/step_prepare.py +23 -10
  14. wandb/filesync/step_upload.py +13 -13
  15. wandb/filesync/upload_job.py +4 -8
  16. wandb/integration/cohere/__init__.py +3 -0
  17. wandb/integration/cohere/cohere.py +21 -0
  18. wandb/integration/cohere/resolver.py +347 -0
  19. wandb/integration/gym/__init__.py +4 -6
  20. wandb/integration/huggingface/__init__.py +3 -0
  21. wandb/integration/huggingface/huggingface.py +18 -0
  22. wandb/integration/huggingface/resolver.py +213 -0
  23. wandb/integration/langchain/wandb_tracer.py +16 -179
  24. wandb/integration/openai/__init__.py +1 -3
  25. wandb/integration/openai/openai.py +11 -143
  26. wandb/integration/openai/resolver.py +111 -38
  27. wandb/integration/sagemaker/config.py +2 -2
  28. wandb/integration/tensorboard/log.py +4 -4
  29. wandb/old/settings.py +24 -7
  30. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  31. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/wandb_deprecated.py +3 -1
  33. wandb/sdk/__init__.py +1 -1
  34. wandb/sdk/artifacts/__init__.py +0 -0
  35. wandb/sdk/artifacts/artifact.py +2101 -0
  36. wandb/sdk/artifacts/artifact_download_logger.py +42 -0
  37. wandb/sdk/artifacts/artifact_manifest.py +67 -0
  38. wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
  39. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  40. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
  41. wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
  42. wandb/sdk/artifacts/artifact_state.py +10 -0
  43. wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
  44. wandb/sdk/artifacts/exceptions.py +55 -0
  45. wandb/sdk/artifacts/storage_handler.py +59 -0
  46. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  47. wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
  48. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  49. wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
  50. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
  51. wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
  53. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
  54. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
  55. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  56. wandb/sdk/artifacts/storage_layout.py +6 -0
  57. wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
  58. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
  60. wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
  61. wandb/sdk/data_types/_dtypes.py +7 -12
  62. wandb/sdk/data_types/base_types/json_metadata.py +3 -2
  63. wandb/sdk/data_types/base_types/media.py +8 -8
  64. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  65. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
  66. wandb/sdk/data_types/helper_types/classes.py +6 -8
  67. wandb/sdk/data_types/helper_types/image_mask.py +5 -6
  68. wandb/sdk/data_types/histogram.py +4 -3
  69. wandb/sdk/data_types/html.py +3 -4
  70. wandb/sdk/data_types/image.py +11 -9
  71. wandb/sdk/data_types/molecule.py +5 -3
  72. wandb/sdk/data_types/object_3d.py +7 -5
  73. wandb/sdk/data_types/plotly.py +3 -2
  74. wandb/sdk/data_types/saved_model.py +11 -11
  75. wandb/sdk/data_types/trace_tree.py +5 -4
  76. wandb/sdk/data_types/utils.py +3 -5
  77. wandb/sdk/data_types/video.py +5 -4
  78. wandb/sdk/integration_utils/auto_logging.py +215 -0
  79. wandb/sdk/interface/interface.py +15 -15
  80. wandb/sdk/internal/file_pusher.py +8 -16
  81. wandb/sdk/internal/file_stream.py +5 -11
  82. wandb/sdk/internal/handler.py +13 -1
  83. wandb/sdk/internal/internal_api.py +287 -13
  84. wandb/sdk/internal/job_builder.py +119 -30
  85. wandb/sdk/internal/sender.py +6 -26
  86. wandb/sdk/internal/settings_static.py +2 -0
  87. wandb/sdk/internal/system/assets/__init__.py +2 -0
  88. wandb/sdk/internal/system/assets/gpu.py +42 -0
  89. wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
  90. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  91. wandb/sdk/internal/system/system_info.py +3 -3
  92. wandb/sdk/internal/tb_watcher.py +32 -22
  93. wandb/sdk/internal/thread_local_settings.py +18 -0
  94. wandb/sdk/launch/_project_spec.py +57 -11
  95. wandb/sdk/launch/agent/agent.py +147 -65
  96. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  97. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  98. wandb/sdk/launch/builder/abstract.py +5 -1
  99. wandb/sdk/launch/builder/build.py +21 -18
  100. wandb/sdk/launch/builder/docker_builder.py +10 -4
  101. wandb/sdk/launch/builder/kaniko_builder.py +113 -23
  102. wandb/sdk/launch/builder/noop.py +6 -3
  103. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
  104. wandb/sdk/launch/environment/aws_environment.py +3 -2
  105. wandb/sdk/launch/environment/azure_environment.py +124 -0
  106. wandb/sdk/launch/environment/gcp_environment.py +2 -4
  107. wandb/sdk/launch/environment/local_environment.py +1 -1
  108. wandb/sdk/launch/errors.py +19 -0
  109. wandb/sdk/launch/github_reference.py +32 -19
  110. wandb/sdk/launch/launch.py +3 -8
  111. wandb/sdk/launch/launch_add.py +6 -2
  112. wandb/sdk/launch/loader.py +21 -2
  113. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  114. wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
  115. wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
  116. wandb/sdk/launch/registry/local_registry.py +2 -1
  117. wandb/sdk/launch/runner/abstract.py +24 -3
  118. wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
  119. wandb/sdk/launch/runner/local_container.py +103 -51
  120. wandb/sdk/launch/runner/local_process.py +1 -1
  121. wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
  122. wandb/sdk/launch/runner/vertex_runner.py +10 -5
  123. wandb/sdk/launch/sweeps/__init__.py +7 -9
  124. wandb/sdk/launch/sweeps/scheduler.py +307 -77
  125. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  126. wandb/sdk/launch/sweeps/utils.py +82 -35
  127. wandb/sdk/launch/utils.py +89 -75
  128. wandb/sdk/lib/_settings_toposort_generated.py +7 -0
  129. wandb/sdk/lib/capped_dict.py +26 -0
  130. wandb/sdk/lib/{git.py → gitlib.py} +76 -59
  131. wandb/sdk/lib/hashutil.py +12 -4
  132. wandb/sdk/lib/paths.py +96 -8
  133. wandb/sdk/lib/sock_client.py +2 -2
  134. wandb/sdk/lib/timer.py +1 -0
  135. wandb/sdk/service/server.py +22 -9
  136. wandb/sdk/service/server_sock.py +1 -1
  137. wandb/sdk/service/service.py +27 -8
  138. wandb/sdk/verify/verify.py +4 -7
  139. wandb/sdk/wandb_config.py +2 -6
  140. wandb/sdk/wandb_init.py +57 -53
  141. wandb/sdk/wandb_require.py +7 -0
  142. wandb/sdk/wandb_run.py +61 -223
  143. wandb/sdk/wandb_settings.py +28 -4
  144. wandb/testing/relay.py +15 -2
  145. wandb/util.py +74 -36
  146. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
  147. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
  148. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
  149. wandb/integration/langchain/util.py +0 -191
  150. wandb/sdk/interface/artifacts/__init__.py +0 -33
  151. wandb/sdk/interface/artifacts/artifact.py +0 -615
  152. wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
  153. wandb/sdk/wandb_artifacts.py +0 -2226
  154. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  155. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  156. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,7 @@ import traceback
9
9
  from abc import ABC, abstractmethod
10
10
  from dataclasses import dataclass
11
11
  from enum import Enum
12
- from typing import Any, Dict, Iterator, List, Optional, Tuple
12
+ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
13
13
 
14
14
  import click
15
15
  import yaml
@@ -17,7 +17,10 @@ import yaml
17
17
  import wandb
18
18
  import wandb.apis.public as public
19
19
  from wandb.apis.internal import Api
20
+ from wandb.apis.public import Api as PublicApi
21
+ from wandb.apis.public import QueuedRun, Run
20
22
  from wandb.errors import CommError
23
+ from wandb.sdk.launch.errors import LaunchError
21
24
  from wandb.sdk.launch.launch_add import launch_add
22
25
  from wandb.sdk.launch.sweeps import SchedulerError
23
26
  from wandb.sdk.launch.sweeps.utils import (
@@ -25,10 +28,13 @@ from wandb.sdk.launch.sweeps.utils import (
25
28
  make_launch_sweep_entrypoint,
26
29
  )
27
30
  from wandb.sdk.lib.runid import generate_id
31
+ from wandb.sdk.wandb_run import Run as SdkRun
28
32
 
29
33
  _logger = logging.getLogger(__name__)
30
34
  LOG_PREFIX = f"{click.style('sched:', fg='cyan')} "
31
35
 
36
+ DEFAULT_POLLING_SLEEP = 5.0
37
+
32
38
 
33
39
  class SchedulerState(Enum):
34
40
  PENDING = 0
@@ -42,9 +48,29 @@ class SchedulerState(Enum):
42
48
 
43
49
 
44
50
class RunState(Enum):
    """State of a run, tagged with whether the run still counts as live.

    Every member is declared as a ``(value, life)`` pair. Only ``value``
    becomes the enum value, so ``RunState("running")`` resolves to
    ``RunState.RUNNING``; ``life`` is stored privately and surfaced through
    :attr:`is_alive`.
    """

    RUNNING = "running", "alive"
    PENDING = "pending", "alive"
    PREEMPTING = "preempting", "alive"
    CRASHED = "crashed", "dead"
    FAILED = "failed", "dead"
    KILLED = "killed", "dead"
    FINISHED = "finished", "dead"
    PREEMPTED = "preempted", "dead"
    # unknown when api.get_run_state fails or returns unexpected state
    # assumed alive, unless we get unknown 2x then move to failed (dead)
    UNKNOWN = "unknown", "alive"

    def __new__(cls: Any, *args: List, **kwds: Any) -> "RunState":
        """Create the member, using only the first declared element as its value."""
        member: "RunState" = object.__new__(cls)
        value = args[0]
        member._value_ = value
        return member

    def __init__(self, _: str, life: str = "unknown") -> None:
        """Remember the liveness tag; the first argument is already the value."""
        self._life = life

    @property
    def is_alive(self) -> bool:
        """True when the member was declared with the ``"alive"`` tag."""
        alive = self._life == "alive"
        return alive
48
74
 
49
75
 
50
76
  @dataclass
@@ -57,7 +83,7 @@ class _Worker:
57
83
  class SweepRun:
58
84
  id: str
59
85
  worker_id: int
60
- state: RunState = RunState.ALIVE
86
+ state: RunState = RunState.RUNNING
61
87
  queued_run: Optional[public.QueuedRun] = None
62
88
  args: Optional[Dict[str, Any]] = None
63
89
  logs: Optional[List[str]] = None
@@ -66,20 +92,24 @@ class SweepRun:
66
92
  class Scheduler(ABC):
67
93
  """A controller/agent that populates a Launch RunQueue from a hyperparameter sweep."""
68
94
 
95
+ PLACEHOLDER_URI = "placeholder-uri-scheduler"
96
+ SWEEP_JOB_TYPE = "sweep-controller"
97
+ ENTRYPOINT = ["wandb", "scheduler", "WANDB_SWEEP_ID"]
98
+
69
99
  def __init__(
70
100
  self,
71
101
  api: Api,
72
102
  *args: Optional[Any],
73
- num_workers: int = 8,
74
- polling_sleep: float = 5.0,
103
+ polling_sleep: Optional[float] = None,
75
104
  sweep_id: Optional[str] = None,
76
105
  entity: Optional[str] = None,
77
106
  project: Optional[str] = None,
78
107
  project_queue: Optional[str] = None,
108
+ num_workers: Optional[Union[int, str]] = None,
79
109
  **kwargs: Optional[Any],
80
110
  ):
81
111
  self._api = api
82
- self._public_api = public.Api()
112
+ self._public_api = PublicApi()
83
113
  self._entity = (
84
114
  entity
85
115
  or os.environ.get("WANDB_ENTITY")
@@ -100,27 +130,42 @@ class Scheduler(ABC):
100
130
  if resp.get("state") == SchedulerState.CANCELLED.name:
101
131
  self._state = SchedulerState.CANCELLED
102
132
  self._sweep_config = yaml.safe_load(resp["config"])
133
+ self._num_runs_launched: int = self._get_num_runs_launched(resp["runs"])
134
+ if self._num_runs_launched > 0:
135
+ wandb.termlog(
136
+ f"{LOG_PREFIX}Found {self._num_runs_launched} previous valid runs for sweep {self._sweep_id}"
137
+ )
103
138
  except Exception as e:
104
139
  raise SchedulerError(
105
140
  f"{LOG_PREFIX}Exception when finding sweep ({sweep_id}) {e}"
106
141
  )
107
142
 
143
+ # Scheduler may receive additional kwargs which will be piped into the launch command
144
+ self._kwargs: Dict[str, Any] = kwargs
145
+
108
146
  # Dictionary of the runs being managed by the scheduler
109
147
  self._runs: Dict[str, SweepRun] = {}
110
148
  # Threading lock to ensure thread-safe access to the runs dictionary
111
149
  self._threading_lock: threading.Lock = threading.Lock()
112
- self._polling_sleep = polling_sleep
150
+ self._polling_sleep = polling_sleep or DEFAULT_POLLING_SLEEP
113
151
  self._project_queue = project_queue
114
152
  # Optionally run multiple workers in (pseudo-)parallel. Workers do not
115
153
  # actually run training workloads, they simply send heartbeat messages
116
154
  # (emulating a real agent) and add new runs to the launch queue. The
117
155
  # launch agent is the one that actually runs the training workloads.
118
156
  self._workers: Dict[int, _Worker] = {}
119
- self._num_workers = num_workers
120
- self._num_runs_launched = 0
121
157
 
122
- # Scheduler may receive additional kwargs which will be piped into the launch command
123
- self._kwargs: Dict[str, Any] = kwargs
158
+ # Init wandb scheduler run
159
+ self._wandb_run = self._init_wandb_run()
160
+
161
+ # Grab params from scheduler wandb run config
162
+ num_workers = num_workers or self._wandb_run.config.get("scheduler", {}).get(
163
+ "num_workers"
164
+ )
165
+ self._num_workers = int(num_workers) if str(num_workers).isdigit() else 8
166
+ self._settings_config: Dict[str, Any] = self._wandb_run.config.get(
167
+ "settings", {}
168
+ )
124
169
 
125
170
  @abstractmethod
126
171
  def _get_next_sweep_run(self, worker_id: int) -> Optional[SweepRun]:
@@ -168,7 +213,6 @@ class Scheduler(ABC):
168
213
  @property
169
214
  def at_runcap(self) -> bool:
170
215
  """False if under user-specified cap on # of runs."""
171
- # TODO(gst): Count previous runs for resumed sweeps
172
216
  run_cap = self._sweep_config.get("run_cap")
173
217
  if not run_cap:
174
218
  return False
@@ -200,6 +244,18 @@ class Scheduler(ABC):
200
244
  _id: w for _id, w in self._workers.items() if _id not in self.busy_workers
201
245
  }
202
246
 
247
def _init_wandb_run(self) -> SdkRun:
    """Create the wandb run that represents this scheduler and return it.

    The run is named ``<sweep_type>-scheduler-<sweep_id>`` (``sweep_type``
    comes from the scheduler kwargs and defaults to ``"sweep"``), uses
    ``SWEEP_JOB_TYPE`` as its job type and is initialised with
    ``resume="allow"``.
    """
    sweep_type = self._kwargs.get("sweep_type", "sweep")
    run_name = f"{sweep_type}-scheduler-{self._sweep_id}"
    # WANDB_RUN_ID = sweep_id for scheduler
    scheduler_run: SdkRun = wandb.init(
        name=run_name,
        job_type=self.SWEEP_JOB_TYPE,
        resume="allow",
        config=self._kwargs,  # when run as a job, this sets config
    )
    return scheduler_run
258
+
203
259
  def stop_sweep(self) -> None:
204
260
  """Stop the sweep."""
205
261
  self._state = SchedulerState.STOPPED
@@ -228,6 +284,7 @@ class Scheduler(ABC):
228
284
  self.exit()
229
285
  return
230
286
 
287
+ # For resuming sweeps
231
288
  self._load_state()
232
289
  self._register_agents()
233
290
  self.run()
@@ -238,10 +295,12 @@ class Scheduler(ABC):
238
295
  self.state = SchedulerState.RUNNING
239
296
  try:
240
297
  while True:
241
- wandb.termlog(f"{LOG_PREFIX}Polling for new runs to launch")
298
+ self._update_scheduler_run_state()
242
299
  if not self.is_alive:
243
300
  break
244
301
 
302
+ wandb.termlog(f"{LOG_PREFIX}Polling for new runs to launch")
303
+
245
304
  self._update_run_states()
246
305
  self._poll()
247
306
  if self.state == SchedulerState.FLUSH_RUNS:
@@ -259,8 +318,17 @@ class Scheduler(ABC):
259
318
  self.state = SchedulerState.FLUSH_RUNS
260
319
  break
261
320
 
262
- run: Optional[SweepRun] = self._get_next_sweep_run(worker_id)
263
- if not run:
321
+ try:
322
+ run: Optional[SweepRun] = self._get_next_sweep_run(worker_id)
323
+ if not run:
324
+ break
325
+ except SchedulerError as e:
326
+ raise SchedulerError(e)
327
+ except Exception as e:
328
+ wandb.termerror(
329
+ f"{LOG_PREFIX}Failed to get next sweep run: {e}"
330
+ )
331
+ self.state = SchedulerState.FAILED
264
332
  break
265
333
 
266
334
  if self._add_to_launch_queue(run):
@@ -278,18 +346,49 @@ class Scheduler(ABC):
278
346
  self.exit()
279
347
  raise e
280
348
  else:
281
- wandb.termlog(f"{LOG_PREFIX}Scheduler completed")
349
+ wandb.termlog(f"{LOG_PREFIX}Scheduler completed successfully")
350
+ # don't overwrite special states (e.g. STOPPED, FAILED)
351
+ if self.state in [SchedulerState.RUNNING, SchedulerState.FLUSH_RUNS]:
352
+ self.state = SchedulerState.COMPLETED
282
353
  self.exit()
283
354
 
284
355
def exit(self) -> None:
    """Shut the scheduler down and publish the final sweep state.

    Order of operations: subclass ``_exit`` hook, best-effort ``_save_state``,
    sweep-state update, stopping the managed runs, then finishing the
    scheduler's own wandb run.
    """
    self._exit()

    # _save_state isn't controlled, possibly fails
    try:
        self._save_state()
    except Exception:
        wandb.termerror(
            f"{LOG_PREFIX}Failed to save state: {traceback.format_exc()}"
        )

    ended_cleanly = self.state in [
        SchedulerState.COMPLETED,
        SchedulerState.STOPPED,
    ]
    if ended_cleanly:
        self._set_sweep_state("FINISHED")
    else:
        # Any other state at exit time is treated as a failure.
        self.state = SchedulerState.FAILED
        self._set_sweep_state("CRASHED")

    self._stop_runs()
    self._wandb_run.finish()
376
+
377
+ def _get_num_runs_launched(self, runs: List[Dict[str, Any]]) -> int:
378
+ """Returns the number of valid runs in the sweep."""
379
+ count = 0
380
+ for run in runs:
381
+ # if bad run, shouldn't be counted against run cap
382
+ if run.get("state", "") in ["killed", "crashed"] and not run.get(
383
+ "summaryMetrics"
384
+ ):
385
+ _logger.debug(
386
+ f"excluding run: {run['name']} with state: {run['state']} from run cap \n{run}"
387
+ )
388
+ continue
389
+ count += 1
390
+
391
+ return count
293
392
 
294
393
  def _try_load_executable(self) -> bool:
295
394
  """Check existance of valid executable for a run.
@@ -297,9 +396,8 @@ class Scheduler(ABC):
297
396
  logs and returns False when job is unreachable
298
397
  """
299
398
  if self._kwargs.get("job"):
300
- _public_api = public.Api()
301
399
  try:
302
- _job_artifact = _public_api.artifact(self._kwargs["job"], type="job")
400
+ _job_artifact = self._public_api.job(self._kwargs["job"])
303
401
  wandb.termlog(
304
402
  f"{LOG_PREFIX}Successfully loaded job ({_job_artifact.name}) in scheduler"
305
403
  )
@@ -316,12 +414,17 @@ class Scheduler(ABC):
316
414
  def _register_agents(self) -> None:
317
415
  for worker_id in range(self._num_workers):
318
416
  _logger.debug(f"{LOG_PREFIX}Starting AgentHeartbeat worker ({worker_id})")
319
- agent_config = self._api.register_agent(
320
- f"{socket.gethostname()}-{worker_id}", # host
321
- sweep_id=self._sweep_id,
322
- project_name=self._project,
323
- entity=self._entity,
324
- )
417
+ try:
418
+ agent_config = self._api.register_agent(
419
+ f"{socket.gethostname()}-{worker_id}", # host
420
+ sweep_id=self._sweep_id,
421
+ project_name=self._project,
422
+ entity=self._entity,
423
+ )
424
+ except Exception as e:
425
+ _logger.debug(f"failed to register agent: {e}")
426
+ self.fail_sweep(f"failed to register agent: {e}")
427
+
325
428
  self._workers[worker_id] = _Worker(
326
429
  agent_config=agent_config,
327
430
  agent_id=agent_config["id"],
@@ -332,6 +435,17 @@ class Scheduler(ABC):
332
435
  with self._threading_lock:
333
436
  yield from self._runs.items()
334
437
 
438
+ def _cleanup_runs(self, runs_to_remove: List[str]) -> None:
439
+ """Helper for removing runs from memory.
440
+
441
+ Can be overloaded to prevent deletion of runs, which is useful
442
+ for debugging or when polling on completed runs.
443
+ """
444
+ with self._threading_lock:
445
+ for run_id in runs_to_remove:
446
+ wandb.termlog(f"{LOG_PREFIX}Cleaning up finished run ({run_id})")
447
+ del self._runs[run_id]
448
+
335
449
  def _stop_runs(self) -> None:
336
450
  to_delete = []
337
451
  for run_id, _ in self._yield_runs():
@@ -357,7 +471,7 @@ class Scheduler(ABC):
357
471
  )
358
472
  return False
359
473
 
360
- if run.state == RunState.DEAD:
474
+ if not run.state.is_alive:
361
475
  # run already dead, just delete reference
362
476
  return True
363
477
 
@@ -366,82 +480,195 @@ class Scheduler(ABC):
366
480
  f"Run:v1:{run_id}:{self._project}:{self._entity}".encode()
367
481
  ).decode("utf-8")
368
482
 
369
- success: bool = self._api.stop_run(run_id=encoded_run_id)
370
- if success:
371
- wandb.termlog(f"{LOG_PREFIX}Stopped run {run_id}.")
483
+ try:
484
+ success: bool = self._api.stop_run(run_id=encoded_run_id)
485
+ if success:
486
+ wandb.termlog(f"{LOG_PREFIX}Stopped run {run_id}.")
487
+ return True
488
+ except Exception as e:
489
+ _logger.debug(f"error stopping run ({run_id}): {e}")
490
+
491
+ return False
492
+
493
def _update_scheduler_run_state(self) -> None:
    """Update the scheduler state from the scheduler run and the sweep.

    First maps the state of the scheduler's own wandb run onto the scheduler
    state (KILLED -> STOPPED, FAILED/CRASHED -> FAILED, FINISHED -> COMPLETED).
    Then queries the sweep state, which takes precedence when it is terminal
    or paused: FINISHED/CANCELLED -> COMPLETED, PAUSED/STOPPED -> FLUSH_RUNS.
    If the sweep state cannot be fetched, the error is logged and the
    scheduler state derived from the run is left as is.
    """
    state: RunState = self._get_run_state(self._wandb_run.id)

    if state == RunState.KILLED:
        self.state = SchedulerState.STOPPED
    elif state in [RunState.FAILED, RunState.CRASHED]:
        self.state = SchedulerState.FAILED
    elif state == RunState.FINISHED:
        self.state = SchedulerState.COMPLETED

    try:
        sweep_state = self._api.get_sweep_state(
            self._sweep_id, self._entity, self._project
        )
    except Exception as e:
        # `sweep_state` is not bound when the call above raises, so it must
        # not be referenced here (doing so raised UnboundLocalError and
        # escaped this best-effort handler).
        _logger.debug(f"sweep state error: {e}")
        return

    if sweep_state in ["FINISHED", "CANCELLED"]:
        self.state = SchedulerState.COMPLETED
    elif sweep_state in ["PAUSED", "STOPPED"]:
        self.state = SchedulerState.FLUSH_RUNS
374
516
 
375
517
def _update_run_states(self) -> None:
    """Refresh the state of every managed run and drop the finished ones.

    For each run, the state is re-read through ``_get_run_state`` (passing the
    previous state). A run is scheduled for removal when its state is no
    longer alive or when its queued run reports ``"failed"``.
    """
    finished_ids: List[str] = []
    for run_id, run in self._yield_runs():
        run.state = self._get_run_state(run_id, run.state)

        rqi_state = None
        try:
            if run.queued_run:
                rqi_state = run.queued_run.state
        except (CommError, LaunchError) as e:
            _logger.debug(f"Failed to get queued_run.state: {e}")
            rqi_state = None

        if rqi_state == "failed" or not run.state.is_alive:
            _logger.debug(f"({run_id}) states: ({run.state}, {rqi_state})")
            finished_ids.append(run_id)

    # Removal happens only after iteration has finished: _cleanup_runs takes
    # the same threading lock that _yield_runs holds while yielding.
    self._cleanup_runs(finished_ids)
536
+
537
+ def _get_metrics_from_run(self, run_id: str) -> List[Any]:
538
+ """Use the public api to get metrics from a run.
539
+
540
+ Uses the metric name found in the sweep config, any
541
+ misspellings will result in an empty list.
542
+ """
543
+ try:
544
+ queued_run: Optional[QueuedRun] = self._runs[run_id].queued_run
545
+ if not queued_run:
546
+ return []
547
+
548
+ api_run: Run = self._public_api.run(
549
+ f"{queued_run.entity}/{queued_run.project}/{run_id}"
550
+ )
551
+ metric_name = self._sweep_config["metric"]["name"]
552
+ history = api_run.scan_history(keys=["_step", metric_name])
553
+ metrics = [x[metric_name] for x in history]
554
+
555
+ return metrics
556
+ except Exception as e:
557
+ _logger.debug(f"[_get_metrics_from_run] {e}")
558
+ return []
559
+
560
+ def _get_run_info(self, run_id: str) -> Dict[str, Any]:
561
+ """Use the public api to get info about a run."""
562
+ try:
563
+ info: Dict[str, Any] = self._api.get_run_info(
564
+ self._entity, self._project, run_id
565
+ )
566
+ if info:
567
+ return info
568
+ except Exception as e:
569
+ _logger.debug(f"[_get_run_info] {e}")
570
+ return {}
571
+
572
def _get_run_state(
    self, run_id: str, prev_run_state: RunState = RunState.UNKNOWN
) -> RunState:
    """Fetch the state of a run through the scheduler's internal API client.

    Args:
        run_id: id of the run to query.
        prev_run_state: state recorded on the previous poll; used to tell a
            first failed lookup from a repeated one.

    Returns:
        The parsed ``RunState``. On a ``CommError`` the result is
        ``RunState.UNKNOWN`` the first time and ``RunState.FAILED`` when the
        previous state was already ``UNKNOWN``. When the backend value cannot
        be turned into a ``RunState`` the result is ``RunState.UNKNOWN``.
    """
    # Raw value reported by the backend; kept separately so the bad-state
    # warning can show what was actually received.
    state = None
    run_state = None
    try:
        state = self._api.get_run_state(self._entity, self._project, run_id)
        run_state = RunState(state)
    except CommError as e:
        _logger.debug(f"error getting state for run ({run_id}): {e}")
        if prev_run_state == RunState.UNKNOWN:
            # triggers when we get an unknown state for the second time
            wandb.termwarn(
                f"Failed to get runstate for run ({run_id}). Error: {traceback.format_exc()}"
            )
            run_state = RunState.FAILED
        else:  # first time we get unknown state
            run_state = RunState.UNKNOWN
    except (AttributeError, ValueError):
        # Report the raw backend value: `run_state` is always None on this
        # path, which made the original message uninformative.
        wandb.termwarn(
            f"Bad state ({state}) for run ({run_id}). Error: {traceback.format_exc()}"
        )
        run_state = RunState.UNKNOWN
    return run_state
414
596
 
415
- def _add_to_launch_queue(self, run: SweepRun) -> bool:
416
- """Convert a sweeprun into a launch job then push to runqueue."""
417
- # job and image first from CLI args, then from sweep config
418
- _job = self._kwargs.get("job") or self._sweep_config.get("job")
419
- _sweep_config_uri = self._sweep_config.get("image_uri")
420
- _image_uri = self._kwargs.get("image_uri") or _sweep_config_uri
421
- if _job is None and _image_uri is None:
422
- raise SchedulerError(f"{LOG_PREFIX}No 'job' nor 'image_uri' ({run.id})")
423
- elif _job is not None and _image_uri is not None:
424
- raise SchedulerError(f"{LOG_PREFIX}Sweep has both 'job' and 'image_uri'")
597
+ def _create_run(self) -> Dict[str, Any]:
598
+ """Use the public api to create a blank run."""
599
+ try:
600
+ run: List[Dict[str, Any]] = self._api.upsert_run(
601
+ project=self._project,
602
+ entity=self._entity,
603
+ sweep_name=self._sweep_id,
604
+ )
605
+ if run:
606
+ return run[0]
607
+ except Exception as e:
608
+ _logger.debug(f"[_create_run] {e}")
609
+ raise SchedulerError(
610
+ "Error creating run from scheduler, check API connection and CLI version."
611
+ )
612
+ return {}
613
+
614
def _set_sweep_state(self, state: str) -> None:
    """Announce and push a new sweep state to the backend (best effort).

    Failures of the API call are only logged at debug level.
    """
    announcement = f"{LOG_PREFIX}Updating sweep state to: {state.lower()}"
    wandb.termlog(announcement)
    try:
        self._api.set_sweep_state(sweep=self._sweep_id, state=state)
    except Exception as e:
        _logger.debug(f"[set_sweep_state] {e}")
620
+
621
+ def _encode(self, _id: str) -> str:
622
+ return (
623
+ base64.b64decode(bytes(_id.encode("utf-8"))).decode("utf-8").split(":")[2]
624
+ )
425
625
 
626
def _make_entry_and_launch_config(
    self, run: SweepRun
) -> Tuple[Optional[List[str]], Dict[str, Dict[str, Any]]]:
    """Build the entrypoint and launch config overrides for a sweep run.

    The run's args are converted to sweep command args and combined with the
    sweep config's ``command`` to produce an entrypoint plus macro args. The
    first ``${program}`` entry of the entrypoint is replaced with the sweep
    config's ``program``. A warning is printed for any entrypoint item that
    still starts with ``${``.

    Returns:
        ``(entry_point, launch_config)``.

    Raises:
        SchedulerError: if the entrypoint uses ``${program}`` but the sweep
            config has no ``program``.
    """
    args = create_sweep_command_args({"args": run.args})
    command = self._sweep_config.get("command")
    entry_point, macro_args = make_launch_sweep_entrypoint(args, command)

    # handle program macro
    if entry_point and "${program}" in entry_point:
        if not self._sweep_config.get("program"):
            raise SchedulerError(
                f"{LOG_PREFIX}Program macro in command has no corresponding 'program' in sweep config."
            )
        program_slot = entry_point.index("${program}")
        entry_point[program_slot] = self._sweep_config["program"]

    launch_config = {"overrides": {"run_config": args["args_dict"]}}
    if macro_args:  # pipe in hyperparam args as params to launch
        launch_config["overrides"]["args"] = macro_args

    if entry_point:
        unresolved = [x for x in entry_point if str(x).startswith("${")]
        if unresolved:
            wandb.termwarn(
                f"{LOG_PREFIX}Sweep command contains unresolved macros: "
                f"{unresolved}, see launch docs for supported macros."
            )

    return entry_point, launch_config
654
+
655
+ def _add_to_launch_queue(self, run: SweepRun) -> bool:
656
+ """Convert a sweeprun into a launch job then push to runqueue."""
657
+ # job and image first from CLI args, then from sweep config
658
+ _job = self._kwargs.get("job") or self._sweep_config.get("job")
659
+ _sweep_config_uri = self._sweep_config.get("image_uri")
660
+ _image_uri = self._kwargs.get("image_uri") or _sweep_config_uri
661
+ if _job is None and _image_uri is None:
662
+ raise SchedulerError(f"{LOG_PREFIX}No 'job' nor 'image_uri' ({run.id})")
663
+ elif _job is not None and _image_uri is not None:
664
+ raise SchedulerError(f"{LOG_PREFIX}Sweep has both 'job' and 'image_uri'")
665
+
666
+ entry_point, launch_config = self._make_entry_and_launch_config(run)
667
+ if entry_point:
668
+ wandb.termwarn(
669
+ f"{LOG_PREFIX}Sweep command {entry_point} will override"
670
+ f' {"job" if _job else "image_uri"} entrypoint'
671
+ )
445
672
 
446
673
  run_id = run.id or generate_id()
447
674
  queued_run = launch_add(
@@ -457,8 +684,11 @@ class Scheduler(ABC):
457
684
  resource=self._kwargs.get("resource", None),
458
685
  resource_args=self._kwargs.get("resource_args", None),
459
686
  author=self._kwargs.get("author"),
687
+ sweep_id=self._sweep_id,
460
688
  )
461
689
  run.queued_run = queued_run
690
+ # TODO(gst): unify run and queued_run state
691
+ run.state = RunState.RUNNING # assume it will get picked up
462
692
  self._runs[run_id] = run
463
693
 
464
694
  wandb.termlog(
@@ -50,6 +50,7 @@ class SweepScheduler(Scheduler):
50
50
 
51
51
  return SweepRun(
52
52
  id=_run_id,
53
+ state=RunState.PENDING,
53
54
  args=command.get("args", {}),
54
55
  logs=command.get("logs", []),
55
56
  worker_id=worker_id,
@@ -62,7 +63,7 @@ class SweepScheduler(Scheduler):
62
63
  _run_states: Dict[str, bool] = {}
63
64
  for run_id, run in self._yield_runs():
64
65
  # Filter out runs that are from a different worker thread
65
- if run.worker_id == worker_id and run.state == RunState.ALIVE:
66
+ if run.worker_id == worker_id and run.state.is_alive:
66
67
  _run_states[run_id] = True
67
68
 
68
69
  _logger.debug(f"Sending states: \n{pf(_run_states)}\n")