wandb-0.13.10-py3-none-any.whl → wandb-0.14.0-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (228)
  1. wandb/__init__.py +2 -3
  2. wandb/apis/__init__.py +1 -3
  3. wandb/apis/importers/__init__.py +4 -0
  4. wandb/apis/importers/base.py +312 -0
  5. wandb/apis/importers/mlflow.py +113 -0
  6. wandb/apis/internal.py +29 -2
  7. wandb/apis/normalize.py +6 -5
  8. wandb/apis/public.py +163 -180
  9. wandb/apis/reports/_templates.py +6 -12
  10. wandb/apis/reports/report.py +1 -1
  11. wandb/apis/reports/runset.py +1 -3
  12. wandb/apis/reports/util.py +12 -10
  13. wandb/beta/workflows.py +57 -34
  14. wandb/catboost/__init__.py +1 -2
  15. wandb/cli/cli.py +215 -133
  16. wandb/data_types.py +63 -56
  17. wandb/docker/__init__.py +78 -16
  18. wandb/docker/auth.py +21 -22
  19. wandb/env.py +0 -1
  20. wandb/errors/__init__.py +8 -116
  21. wandb/errors/term.py +1 -1
  22. wandb/fastai/__init__.py +1 -2
  23. wandb/filesync/dir_watcher.py +8 -5
  24. wandb/filesync/step_prepare.py +76 -75
  25. wandb/filesync/step_upload.py +1 -2
  26. wandb/integration/catboost/__init__.py +1 -3
  27. wandb/integration/catboost/catboost.py +8 -14
  28. wandb/integration/fastai/__init__.py +7 -13
  29. wandb/integration/gym/__init__.py +35 -4
  30. wandb/integration/keras/__init__.py +3 -3
  31. wandb/integration/keras/callbacks/metrics_logger.py +9 -8
  32. wandb/integration/keras/callbacks/model_checkpoint.py +9 -9
  33. wandb/integration/keras/callbacks/tables_builder.py +31 -19
  34. wandb/integration/kfp/kfp_patch.py +20 -17
  35. wandb/integration/kfp/wandb_logging.py +1 -2
  36. wandb/integration/lightgbm/__init__.py +21 -19
  37. wandb/integration/prodigy/prodigy.py +6 -7
  38. wandb/integration/sacred/__init__.py +9 -12
  39. wandb/integration/sagemaker/__init__.py +1 -3
  40. wandb/integration/sagemaker/auth.py +0 -1
  41. wandb/integration/sagemaker/config.py +1 -1
  42. wandb/integration/sagemaker/resources.py +1 -1
  43. wandb/integration/sb3/sb3.py +8 -4
  44. wandb/integration/tensorboard/__init__.py +1 -3
  45. wandb/integration/tensorboard/log.py +8 -8
  46. wandb/integration/tensorboard/monkeypatch.py +11 -9
  47. wandb/integration/tensorflow/__init__.py +1 -3
  48. wandb/integration/xgboost/__init__.py +4 -6
  49. wandb/integration/yolov8/__init__.py +7 -0
  50. wandb/integration/yolov8/yolov8.py +250 -0
  51. wandb/jupyter.py +31 -35
  52. wandb/lightgbm/__init__.py +1 -2
  53. wandb/old/settings.py +2 -2
  54. wandb/plot/bar.py +1 -2
  55. wandb/plot/confusion_matrix.py +1 -3
  56. wandb/plot/histogram.py +1 -2
  57. wandb/plot/line.py +1 -2
  58. wandb/plot/line_series.py +4 -4
  59. wandb/plot/pr_curve.py +17 -20
  60. wandb/plot/roc_curve.py +1 -3
  61. wandb/plot/scatter.py +1 -2
  62. wandb/proto/v3/wandb_server_pb2.py +85 -39
  63. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  64. wandb/proto/v4/wandb_server_pb2.py +51 -39
  65. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  66. wandb/sdk/__init__.py +1 -3
  67. wandb/sdk/backend/backend.py +1 -1
  68. wandb/sdk/data_types/_dtypes.py +38 -30
  69. wandb/sdk/data_types/base_types/json_metadata.py +1 -3
  70. wandb/sdk/data_types/base_types/media.py +17 -17
  71. wandb/sdk/data_types/base_types/wb_value.py +33 -26
  72. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +91 -125
  73. wandb/sdk/data_types/helper_types/classes.py +1 -1
  74. wandb/sdk/data_types/helper_types/image_mask.py +12 -12
  75. wandb/sdk/data_types/histogram.py +5 -4
  76. wandb/sdk/data_types/html.py +1 -2
  77. wandb/sdk/data_types/image.py +11 -11
  78. wandb/sdk/data_types/molecule.py +3 -6
  79. wandb/sdk/data_types/object_3d.py +1 -2
  80. wandb/sdk/data_types/plotly.py +1 -2
  81. wandb/sdk/data_types/saved_model.py +10 -8
  82. wandb/sdk/data_types/video.py +1 -1
  83. wandb/sdk/integration_utils/data_logging.py +5 -5
  84. wandb/sdk/interface/artifacts.py +288 -266
  85. wandb/sdk/interface/interface.py +2 -3
  86. wandb/sdk/interface/interface_grpc.py +1 -1
  87. wandb/sdk/interface/interface_queue.py +1 -1
  88. wandb/sdk/interface/interface_relay.py +1 -1
  89. wandb/sdk/interface/interface_shared.py +1 -2
  90. wandb/sdk/interface/interface_sock.py +1 -1
  91. wandb/sdk/interface/message_future.py +1 -1
  92. wandb/sdk/interface/message_future_poll.py +1 -1
  93. wandb/sdk/interface/router.py +1 -1
  94. wandb/sdk/interface/router_queue.py +1 -1
  95. wandb/sdk/interface/router_relay.py +1 -1
  96. wandb/sdk/interface/router_sock.py +1 -1
  97. wandb/sdk/interface/summary_record.py +1 -1
  98. wandb/sdk/internal/artifacts.py +1 -1
  99. wandb/sdk/internal/datastore.py +2 -3
  100. wandb/sdk/internal/file_pusher.py +5 -3
  101. wandb/sdk/internal/file_stream.py +22 -19
  102. wandb/sdk/internal/handler.py +5 -4
  103. wandb/sdk/internal/internal.py +1 -1
  104. wandb/sdk/internal/internal_api.py +115 -55
  105. wandb/sdk/internal/job_builder.py +1 -3
  106. wandb/sdk/internal/profiler.py +1 -1
  107. wandb/sdk/internal/progress.py +4 -6
  108. wandb/sdk/internal/sample.py +1 -3
  109. wandb/sdk/internal/sender.py +28 -16
  110. wandb/sdk/internal/settings_static.py +5 -5
  111. wandb/sdk/internal/system/assets/__init__.py +1 -0
  112. wandb/sdk/internal/system/assets/cpu.py +3 -9
  113. wandb/sdk/internal/system/assets/disk.py +2 -4
  114. wandb/sdk/internal/system/assets/gpu.py +6 -18
  115. wandb/sdk/internal/system/assets/gpu_apple.py +2 -4
  116. wandb/sdk/internal/system/assets/interfaces.py +50 -22
  117. wandb/sdk/internal/system/assets/ipu.py +1 -3
  118. wandb/sdk/internal/system/assets/memory.py +7 -13
  119. wandb/sdk/internal/system/assets/network.py +4 -8
  120. wandb/sdk/internal/system/assets/open_metrics.py +283 -0
  121. wandb/sdk/internal/system/assets/tpu.py +1 -4
  122. wandb/sdk/internal/system/assets/trainium.py +26 -14
  123. wandb/sdk/internal/system/system_info.py +2 -3
  124. wandb/sdk/internal/system/system_monitor.py +52 -20
  125. wandb/sdk/internal/tb_watcher.py +12 -13
  126. wandb/sdk/launch/_project_spec.py +54 -65
  127. wandb/sdk/launch/agent/agent.py +374 -90
  128. wandb/sdk/launch/builder/abstract.py +61 -7
  129. wandb/sdk/launch/builder/build.py +81 -110
  130. wandb/sdk/launch/builder/docker_builder.py +181 -0
  131. wandb/sdk/launch/builder/kaniko_builder.py +419 -0
  132. wandb/sdk/launch/builder/noop.py +31 -12
  133. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +70 -20
  134. wandb/sdk/launch/environment/abstract.py +28 -0
  135. wandb/sdk/launch/environment/aws_environment.py +276 -0
  136. wandb/sdk/launch/environment/gcp_environment.py +271 -0
  137. wandb/sdk/launch/environment/local_environment.py +65 -0
  138. wandb/sdk/launch/github_reference.py +3 -8
  139. wandb/sdk/launch/launch.py +38 -29
  140. wandb/sdk/launch/launch_add.py +6 -8
  141. wandb/sdk/launch/loader.py +230 -0
  142. wandb/sdk/launch/registry/abstract.py +54 -0
  143. wandb/sdk/launch/registry/elastic_container_registry.py +163 -0
  144. wandb/sdk/launch/registry/google_artifact_registry.py +203 -0
  145. wandb/sdk/launch/registry/local_registry.py +62 -0
  146. wandb/sdk/launch/runner/abstract.py +1 -16
  147. wandb/sdk/launch/runner/{kubernetes.py → kubernetes_runner.py} +83 -95
  148. wandb/sdk/launch/runner/local_container.py +46 -22
  149. wandb/sdk/launch/runner/local_process.py +1 -4
  150. wandb/sdk/launch/runner/{aws.py → sagemaker_runner.py} +53 -212
  151. wandb/sdk/launch/runner/{gcp_vertex.py → vertex_runner.py} +38 -55
  152. wandb/sdk/launch/sweeps/__init__.py +3 -2
  153. wandb/sdk/launch/sweeps/scheduler.py +132 -39
  154. wandb/sdk/launch/sweeps/scheduler_sweep.py +80 -89
  155. wandb/sdk/launch/utils.py +101 -30
  156. wandb/sdk/launch/wandb_reference.py +2 -7
  157. wandb/sdk/lib/_settings_toposort_generate.py +166 -0
  158. wandb/sdk/lib/_settings_toposort_generated.py +201 -0
  159. wandb/sdk/lib/apikey.py +2 -4
  160. wandb/sdk/lib/config_util.py +4 -1
  161. wandb/sdk/lib/console.py +1 -3
  162. wandb/sdk/lib/deprecate.py +3 -3
  163. wandb/sdk/lib/file_stream_utils.py +7 -5
  164. wandb/sdk/lib/filenames.py +1 -1
  165. wandb/sdk/lib/filesystem.py +61 -5
  166. wandb/sdk/lib/git.py +1 -3
  167. wandb/sdk/lib/import_hooks.py +4 -7
  168. wandb/sdk/lib/ipython.py +8 -5
  169. wandb/sdk/lib/lazyloader.py +1 -3
  170. wandb/sdk/lib/mailbox.py +14 -4
  171. wandb/sdk/lib/proto_util.py +10 -5
  172. wandb/sdk/lib/redirect.py +15 -22
  173. wandb/sdk/lib/reporting.py +1 -3
  174. wandb/sdk/lib/retry.py +4 -5
  175. wandb/sdk/lib/runid.py +1 -3
  176. wandb/sdk/lib/server.py +15 -9
  177. wandb/sdk/lib/sock_client.py +1 -1
  178. wandb/sdk/lib/sparkline.py +1 -1
  179. wandb/sdk/lib/wburls.py +1 -1
  180. wandb/sdk/service/port_file.py +1 -2
  181. wandb/sdk/service/service.py +36 -13
  182. wandb/sdk/service/service_base.py +12 -1
  183. wandb/sdk/verify/verify.py +5 -7
  184. wandb/sdk/wandb_artifacts.py +142 -177
  185. wandb/sdk/wandb_config.py +5 -8
  186. wandb/sdk/wandb_helper.py +1 -1
  187. wandb/sdk/wandb_init.py +24 -13
  188. wandb/sdk/wandb_login.py +9 -9
  189. wandb/sdk/wandb_manager.py +39 -4
  190. wandb/sdk/wandb_metric.py +2 -6
  191. wandb/sdk/wandb_require.py +4 -15
  192. wandb/sdk/wandb_require_helpers.py +1 -9
  193. wandb/sdk/wandb_run.py +95 -141
  194. wandb/sdk/wandb_save.py +1 -3
  195. wandb/sdk/wandb_settings.py +149 -54
  196. wandb/sdk/wandb_setup.py +66 -46
  197. wandb/sdk/wandb_summary.py +13 -10
  198. wandb/sdk/wandb_sweep.py +6 -7
  199. wandb/sdk/wandb_watch.py +1 -1
  200. wandb/sklearn/calculate/confusion_matrix.py +1 -1
  201. wandb/sklearn/calculate/learning_curve.py +1 -1
  202. wandb/sklearn/calculate/summary_metrics.py +1 -3
  203. wandb/sklearn/plot/__init__.py +1 -1
  204. wandb/sklearn/plot/classifier.py +27 -18
  205. wandb/sklearn/plot/clusterer.py +4 -5
  206. wandb/sklearn/plot/regressor.py +4 -4
  207. wandb/sklearn/plot/shared.py +2 -2
  208. wandb/sync/__init__.py +1 -3
  209. wandb/sync/sync.py +4 -5
  210. wandb/testing/relay.py +11 -10
  211. wandb/trigger.py +1 -1
  212. wandb/util.py +106 -81
  213. wandb/viz.py +4 -4
  214. wandb/wandb_agent.py +50 -50
  215. wandb/wandb_controller.py +2 -3
  216. wandb/wandb_run.py +1 -2
  217. wandb/wandb_torch.py +1 -1
  218. wandb/xgboost/__init__.py +1 -2
  219. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/METADATA +6 -2
  220. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/RECORD +224 -209
  221. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
  222. wandb/sdk/launch/builder/docker.py +0 -80
  223. wandb/sdk/launch/builder/kaniko.py +0 -393
  224. wandb/sdk/launch/builder/loader.py +0 -32
  225. wandb/sdk/launch/runner/loader.py +0 -50
  226. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
  227. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
  228. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
wandb/sdk/launch/sweeps/scheduler_sweep.py CHANGED
@@ -1,36 +1,28 @@
1
1
  """Scheduler for classic wandb Sweeps."""
2
2
  import logging
3
- import pprint
4
3
  import queue
5
4
  import socket
6
5
  import time
7
- from dataclasses import dataclass
6
+ from pprint import pformat as pf
8
7
  from typing import Any, Dict, List
9
8
 
10
9
  import wandb
11
10
  from wandb.sdk.launch.sweeps import SchedulerError
12
11
  from wandb.sdk.launch.sweeps.scheduler import (
13
12
  LOG_PREFIX,
13
+ RunState,
14
14
  Scheduler,
15
15
  SchedulerState,
16
- SimpleRunState,
17
16
  SweepRun,
17
+ _Worker,
18
18
  )
19
- from wandb.wandb_agent import Agent as LegacySweepAgent
19
+ from wandb.wandb_agent import _create_sweep_command_args
20
20
 
21
- logger = logging.getLogger(__name__)
22
-
23
-
24
- @dataclass
25
- class _Worker:
26
- agent_config: Dict[str, Any]
27
- agent_id: str
21
+ _logger = logging.getLogger(__name__)
28
22
 
29
23
 
30
24
  class SweepScheduler(Scheduler):
31
- """A SweepScheduler is a controller/agent that will populate a Launch RunQueue with
32
- launch jobs it creates from run suggestions it pulls from an internal sweeps RunQueue.
33
- """
25
+ """A controller/agent that populates a Launch RunQueue from a sweeps RunQueue."""
34
26
 
35
27
  def __init__(
36
28
  self,
@@ -41,11 +33,6 @@ class SweepScheduler(Scheduler):
41
33
  **kwargs: Any,
42
34
  ):
43
35
  super().__init__(*args, **kwargs)
44
- # Optionally run multiple workers in (pseudo-)parallel. Workers do not
45
- # actually run training workloads, they simply send heartbeat messages
46
- # (emulating a real agent) and add new runs to the launch queue. The
47
- # launch agent is the one that actually runs the training workloads.
48
- self._workers: Dict[int, _Worker] = {}
49
36
  self._num_workers: int = num_workers
50
37
  # Thread will pop items off the Sweeps RunQueue using AgentHeartbeat
51
38
  # and put them in this internal queue, which will be used to populate
@@ -56,7 +43,7 @@ class SweepScheduler(Scheduler):
56
43
 
57
44
  def _start(self) -> None:
58
45
  for worker_id in range(self._num_workers):
59
- logger.debug(f"{LOG_PREFIX}Starting AgentHeartbeat worker {worker_id}\n")
46
+ _logger.debug(f"{LOG_PREFIX}Starting AgentHeartbeat worker {worker_id}\n")
60
47
  agent_config = self._api.register_agent(
61
48
  f"{socket.gethostname()}-{worker_id}", # host
62
49
  sweep_id=self._sweep_id,
@@ -68,92 +55,96 @@ class SweepScheduler(Scheduler):
68
55
  agent_id=agent_config["id"],
69
56
  )
70
57
 
71
- def _heartbeat(self, worker_id: int) -> None:
72
- # Make sure Scheduler is alive
73
- if not self.is_alive():
74
- return
58
+ def _get_sweep_commands(self, worker_id: int) -> List[Dict[str, Any]]:
75
59
  # AgentHeartbeat wants a Dict of runs which are running or queued
76
60
  _run_states: Dict[str, bool] = {}
77
61
  for run_id, run in self._yield_runs():
78
62
  # Filter out runs that are from a different worker thread
79
- if run.worker_id == worker_id and run.state == SimpleRunState.ALIVE:
63
+ if run.worker_id == worker_id and run.state == RunState.ALIVE:
80
64
  _run_states[run_id] = True
81
- logger.debug(
82
- f"{LOG_PREFIX}AgentHeartbeat sending: \n{pprint.pformat(_run_states)}\n"
83
- )
65
+
66
+ _logger.debug(f"{LOG_PREFIX}Sending states: \n{pf(_run_states)}\n")
84
67
  commands: List[Dict[str, Any]] = self._api.agent_heartbeat(
85
68
  self._workers[worker_id].agent_id, # agent_id: str
86
69
  {}, # metrics: dict
87
70
  _run_states, # run_states: dict
88
71
  )
89
- logger.debug(
90
- f"{LOG_PREFIX}AgentHeartbeat received {len(commands)} commands: \n{pprint.pformat(commands)}\n"
91
- )
92
- if commands:
93
- for command in commands:
94
- # The command "type" can be one of "run", "resume", "stop", "exit"
95
- _type = command.get("type", None)
96
- if _type in ["exit", "stop"]:
72
+ _logger.debug(f"{LOG_PREFIX}AgentHeartbeat commands: \n{pf(commands)}\n")
73
+
74
+ return commands
75
+
76
+ def _heartbeat(self, worker_id: int) -> bool:
77
+ # Make sure Scheduler is alive
78
+ if not self.is_alive():
79
+ return False
80
+ elif self.state == SchedulerState.FLUSH_RUNS:
81
+ # already hit run_cap, just noop
82
+ return False
83
+
84
+ commands: List[Dict[str, Any]] = self._get_sweep_commands(worker_id)
85
+ for command in commands:
86
+ # The command "type" can be one of "run", "resume", "stop", "exit"
87
+ _type = command.get("type")
88
+ if _type in ["exit", "stop"]:
89
+ run_cap = command.get("run_cap")
90
+ if run_cap is not None:
91
+ # If Sweep hit run_cap, go into flushing state
92
+ wandb.termlog(f"{LOG_PREFIX}Sweep hit run_cap: {run_cap}")
93
+ self.state = SchedulerState.FLUSH_RUNS
94
+ else:
97
95
  # Tell (virtual) agent to stop running
98
96
  self.state = SchedulerState.STOPPED
99
- self.exit()
100
- return
101
- elif _type in ["run", "resume"]:
102
- _run_id = command.get("run_id", None)
103
- if _run_id is None:
104
- self.state = SchedulerState.FAILED
105
- raise SchedulerError(
106
- f"AgentHeartbeat command {command} missing run_id"
107
- )
108
- if _run_id in self._runs:
109
- wandb.termlog(f"{LOG_PREFIX} Skipping duplicate run {run_id}")
110
- else:
111
- run = SweepRun(
112
- id=_run_id,
113
- args=command.get("args", {}),
114
- logs=command.get("logs", []),
115
- program=command.get("program", None),
116
- worker_id=worker_id,
117
- )
118
- self._runs[run.id] = run
119
- self._heartbeat_queue.put(run)
120
- else:
97
+ return False
98
+
99
+ if _type in ["run", "resume"]:
100
+ _run_id = command.get("run_id")
101
+ if not _run_id:
121
102
  self.state = SchedulerState.FAILED
122
- raise SchedulerError(f"AgentHeartbeat unknown command type {_type}")
103
+ raise SchedulerError(f"No runId in agent heartbeat: {command}")
104
+ if _run_id in self._runs:
105
+ wandb.termlog(f"{LOG_PREFIX}Skipping duplicate run: {_run_id}")
106
+ continue
107
+
108
+ run = SweepRun(
109
+ id=_run_id,
110
+ args=command.get("args", {}),
111
+ logs=command.get("logs", []),
112
+ worker_id=worker_id,
113
+ )
114
+ self._runs[run.id] = run
115
+ self._heartbeat_queue.put(run)
116
+ else:
117
+ self.state = SchedulerState.FAILED
118
+ raise SchedulerError(f"AgentHeartbeat unknown command: {_type}")
119
+ return True
123
120
 
124
121
  def _run(self) -> None:
125
122
  # Go through all workers and heartbeat
126
- for worker_id in self._workers.keys():
123
+ for worker_id in self._workers:
127
124
  self._heartbeat(worker_id)
128
- try:
129
- run: SweepRun = self._heartbeat_queue.get(
130
- timeout=self._heartbeat_queue_timeout
131
- )
132
- except queue.Empty:
133
- wandb.termlog(f"{LOG_PREFIX}No jobs in Sweeps RunQueue, waiting...")
134
- time.sleep(self._heartbeat_queue_sleep)
135
- return
136
- # If run is already stopped just ignore the request
137
- if run.state in [
138
- SimpleRunState.DEAD,
139
- SimpleRunState.UNKNOWN,
140
- ]:
141
- return
142
- wandb.termlog(
143
- f"{LOG_PREFIX}Converting Sweep Run (RunID:{run.id}) to Launch Job"
144
- )
145
- _ = self._add_to_launch_queue(
146
- run_id=run.id,
147
- entry_point=["python3", run.program] if run.program else None,
148
- # Use legacy sweep utilities to extract args dict from agent heartbeat run.args
149
- config={
150
- "overrides": {
151
- "run_config": LegacySweepAgent._create_command_args(
152
- {"args": run.args}
153
- )["args_dict"]
154
- }
155
- },
156
- )
125
+
126
+ for _worker_id in self._workers:
127
+ try:
128
+ run: SweepRun = self._heartbeat_queue.get(
129
+ timeout=self._heartbeat_queue_timeout
130
+ )
131
+
132
+ # If run is already stopped just ignore the request
133
+ if run.state in [RunState.DEAD, RunState.UNKNOWN]:
134
+ wandb.termwarn(f"{LOG_PREFIX}Ignoring dead run {run.id}")
135
+ _logger.debug(f"dead run {run.id} state: {run.state}")
136
+ continue
137
+
138
+ sweep_args = _create_sweep_command_args({"args": run.args})["args_dict"]
139
+ launch_config = {"overrides": {"run_config": sweep_args}}
140
+ self._add_to_launch_queue(run_id=run.id, config=launch_config)
141
+ except queue.Empty:
142
+ if self.state == SchedulerState.FLUSH_RUNS:
143
+ wandb.termlog(f"{LOG_PREFIX}Sweep stopped, waiting on runs...")
144
+ else:
145
+ wandb.termlog(f"{LOG_PREFIX}No new runs to launch, waiting...")
146
+ time.sleep(self._heartbeat_queue_sleep)
147
+ return
157
148
 
158
149
  def _exit(self) -> None:
159
150
  pass
wandb/sdk/launch/utils.py CHANGED
@@ -10,15 +10,49 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
10
10
  import click
11
11
 
12
12
  import wandb
13
+ import wandb.docker as docker
13
14
  from wandb import util
14
15
  from wandb.apis.internal import Api
15
- from wandb.errors import CommError, LaunchError
16
+ from wandb.errors import CommError, Error
16
17
  from wandb.sdk.launch.wandb_reference import WandbReference
17
18
 
19
+ from .builder.templates._wandb_bootstrap import (
20
+ FAILED_PACKAGES_POSTFIX,
21
+ FAILED_PACKAGES_PREFIX,
22
+ )
23
+
24
+ FAILED_PACKAGES_REGEX = re.compile(
25
+ f"{re.escape(FAILED_PACKAGES_PREFIX)}(.*){re.escape(FAILED_PACKAGES_POSTFIX)}"
26
+ )
27
+
18
28
  if TYPE_CHECKING: # pragma: no cover
19
29
  from wandb.apis.public import Artifact as PublicArtifact
20
30
 
21
31
 
32
+ class LaunchError(Error):
33
+ """Raised when a known error occurs in wandb launch."""
34
+
35
+ pass
36
+
37
+
38
+ class LaunchDockerError(Error):
39
+ """Raised when Docker daemon is not running."""
40
+
41
+ pass
42
+
43
+
44
+ class ExecutionError(Error):
45
+ """Generic execution exception."""
46
+
47
+ pass
48
+
49
+
50
+ class SweepError(Error):
51
+ """Raised when a known error occurs with wandb sweeps."""
52
+
53
+ pass
54
+
55
+
22
56
  # TODO: this should be restricted to just Git repos and not S3 and stuff like that
23
57
  _GIT_URI_REGEX = re.compile(r"^[^/|^~|^\.].*(git|bitbucket)")
24
58
  _VALID_IP_REGEX = r"^https?://[0-9]+(?:\.[0-9]+){3}(:[0-9]+)?"
@@ -128,11 +162,10 @@ def construct_launch_spec(
128
162
  parameters: Optional[Dict[str, Any]],
129
163
  resource_args: Optional[Dict[str, Any]],
130
164
  launch_config: Optional[Dict[str, Any]],
131
- cuda: Optional[bool],
132
165
  run_id: Optional[str],
133
166
  repository: Optional[str],
134
167
  ) -> Dict[str, Any]:
135
- """Constructs the launch specification from CLI arguments."""
168
+ """Construct the launch specification from CLI arguments."""
136
169
  # override base config (if supplied) with supplied args
137
170
  launch_spec = launch_config if launch_config is not None else {}
138
171
  if uri is not None:
@@ -184,8 +217,6 @@ def construct_launch_spec(
184
217
 
185
218
  if entry_point:
186
219
  launch_spec["overrides"]["entry_point"] = entry_point
187
- if cuda is not None:
188
- launch_spec["cuda"] = cuda
189
220
 
190
221
  if run_id is not None:
191
222
  launch_spec["run_id"] = run_id
@@ -214,7 +245,7 @@ def validate_launch_spec_source(launch_spec: Dict[str, Any]) -> None:
214
245
 
215
246
 
216
247
  def parse_wandb_uri(uri: str) -> Tuple[str, str, str]:
217
- """Parses wandb uri to retrieve entity, project and run name."""
248
+ """Parse wandb uri to retrieve entity, project and run name."""
218
249
  ref = WandbReference.parse(uri)
219
250
  if not ref or not ref.entity or not ref.project or not ref.run_id:
220
251
  raise LaunchError(f"Trouble parsing wandb uri {uri}")
@@ -222,10 +253,12 @@ def parse_wandb_uri(uri: str) -> Tuple[str, str, str]:
222
253
 
223
254
 
224
255
  def is_bare_wandb_uri(uri: str) -> bool:
225
- """Checks if the uri is of the format
226
- /<entity>/<project>/runs/<run_name>[other stuff]
256
+ """Check that a wandb uri is valid.
257
+
258
+ URI must be in the format
259
+ `/<entity>/<project>/runs/<run_name>[other stuff]`
227
260
  or
228
- /<entity>/<project>/artifacts/job/<job_name>[other stuff]
261
+ `/<entity>/<project>/artifacts/job/<job_name>[other stuff]`.
229
262
  """
230
263
  _logger.info(f"Checking if uri {uri} is bare...")
231
264
  return uri.startswith("/") and WandbReference.is_uri_job_or_run(uri)
@@ -306,7 +339,7 @@ def get_local_python_deps(
306
339
 
307
340
 
308
341
  def diff_pip_requirements(req_1: List[str], req_2: List[str]) -> Dict[str, str]:
309
- """Returns a list of pip requirements that are not in req_1 but are in req_2."""
342
+ """Return a list of pip requirements that are not in req_1 but are in req_2."""
310
343
 
311
344
  def _parse_req(req: List[str]) -> Dict[str, str]:
312
345
  # TODO: This can be made more exhaustive, but for 99% of cases this is fine
@@ -366,7 +399,7 @@ def validate_wandb_python_deps(
366
399
  requirements_file: Optional[str],
367
400
  dir: str,
368
401
  ) -> None:
369
- """Warns if local python dependencies differ from wandb requirements.txt"""
402
+ """Warn if local python dependencies differ from wandb requirements.txt."""
370
403
  if requirements_file is not None:
371
404
  requirements_path = os.path.join(dir, requirements_file)
372
405
  with open(requirements_path) as f:
@@ -417,10 +450,7 @@ def apply_patch(patch_string: str, dst_dir: str) -> None:
417
450
 
418
451
 
419
452
  def _make_refspec_from_version(version: Optional[str]) -> List[str]:
420
- """
421
- Helper to create a refspec that checks for the existence of origin/main
422
- and the version, if provided.
423
- """
453
+ """Create a refspec that checks for the existence of origin/main and the version."""
424
454
  if version:
425
455
  return [f"+{version}"]
426
456
 
@@ -452,10 +482,10 @@ def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> str:
452
482
  repo.git.checkout(version)
453
483
  except git.exc.GitCommandError as e:
454
484
  raise LaunchError(
455
- "Unable to checkout version '%s' of git repo %s"
485
+ f"Unable to checkout version '{version}' of git repo {uri}"
456
486
  "- please ensure that the version exists in the repo. "
457
- "Error: %s" % (version, uri, e)
458
- )
487
+ f"Error: {e}"
488
+ ) from e
459
489
  else:
460
490
  if getattr(repo, "references", None) is not None:
461
491
  branches = [ref.name for ref in repo.references]
@@ -475,10 +505,10 @@ def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> str:
475
505
  )
476
506
  except (AttributeError, IndexError) as e:
477
507
  raise LaunchError(
478
- "Unable to checkout default version '%s' of git repo %s "
508
+ f"Unable to checkout default version '{version}' of git repo {uri} "
479
509
  "- to specify a git version use: --git-version \n"
480
- "Error: %s" % (version, uri, e)
481
- )
510
+ f"Error: {e}"
511
+ ) from e
482
512
 
483
513
  repo.submodule_update(init=True, recursive=True)
484
514
  return version
@@ -557,10 +587,9 @@ def validate_build_and_registry_configs(
557
587
 
558
588
 
559
589
  def get_kube_context_and_api_client(
560
- kubernetes: Any, # noqa: F811
561
- resource_args: Dict[str, Any], # noqa: F811
590
+ kubernetes: Any,
591
+ resource_args: Dict[str, Any],
562
592
  ) -> Tuple[Any, Any]:
563
-
564
593
  config_file = resource_args.get("config_file", None)
565
594
  context = None
566
595
  if config_file is not None or os.path.exists(os.path.expanduser("~/.kube/config")):
@@ -579,7 +608,14 @@ def get_kube_context_and_api_client(
579
608
  raise LaunchError(f"Specified context {context_name} was not found.")
580
609
  else:
581
610
  context = active_context
582
-
611
+ # TODO: We should not really be performing this check if the user is not
612
+ # using EKS but I don't see an obvious way to make an eks specific code path
613
+ # right here.
614
+ util.get_module(
615
+ "awscli",
616
+ "awscli is required to load a kubernetes context "
617
+ "from eks. Please run `pip install wandb[launch]` to install it.",
618
+ )
583
619
  kubernetes.config.load_kube_config(config_file, context["name"])
584
620
  api_client = kubernetes.config.new_client_from_config(
585
621
  config_file, context=context["name"]
@@ -598,7 +634,7 @@ def resolve_build_and_registry_config(
598
634
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
599
635
  resolved_build_config: Dict[str, Any] = {}
600
636
  if build_config is None and default_launch_config is not None:
601
- resolved_build_config = default_launch_config.get("build", {})
637
+ resolved_build_config = default_launch_config.get("builder", {})
602
638
  elif build_config is not None:
603
639
  resolved_build_config = build_config
604
640
  resolved_registry_config: Dict[str, Any] = {}
@@ -611,10 +647,10 @@ def resolve_build_and_registry_config(
611
647
 
612
648
 
613
649
  def check_logged_in(api: Api) -> bool:
614
- """
615
- Uses an internal api reference to check if a user is logged in
616
- raises an error if the viewer doesn't load, likely broken API key
617
- expected time cost is 0.1-0.2 seconds
650
+ """Check if a user is logged in.
651
+
652
+ Raises an error if the viewer doesn't load (likely a broken API key). Expected time
653
+ cost is 0.1-0.2 seconds.
618
654
  """
619
655
  res = api.api.viewer()
620
656
  if not res:
@@ -633,3 +669,38 @@ def make_name_dns_safe(name: str) -> str:
633
669
  # Actual length limit is 253, but we want to leave room for the generated suffix
634
670
  resp = resp[:200]
635
671
  return resp
672
+
673
+
674
+ def warn_failed_packages_from_build_logs(log: str, image_uri: str) -> None:
675
+ match = FAILED_PACKAGES_REGEX.search(log)
676
+ if match:
677
+ wandb.termwarn(
678
+ f"Failed to install the following packages: {match.group(1)} for image: {image_uri}. Will attempt to launch image without them."
679
+ )
680
+
681
+
682
+ def docker_image_exists(docker_image: str, should_raise: bool = False) -> bool:
683
+ """Check if a specific image is already available.
684
+
685
+ Optionally raises an exception if the image is not found.
686
+ """
687
+ _logger.info("Checking if base image exists...")
688
+ try:
689
+ docker.run(["docker", "image", "inspect", docker_image])
690
+ return True
691
+ except (docker.DockerError, ValueError) as e:
692
+ if should_raise:
693
+ raise e
694
+ _logger.info("Base image not found. Generating new base image")
695
+ return False
696
+
697
+
698
+ def pull_docker_image(docker_image: str) -> None:
699
+ """Pull the requested docker image."""
700
+ if docker_image_exists(docker_image):
701
+ # don't pull images if they exist already, eg if they are local images
702
+ return
703
+ try:
704
+ docker.run(["docker", "pull", docker_image])
705
+ except docker.DockerError as e:
706
+ raise LaunchError(f"Docker server returned error: {e}")
wandb/sdk/launch/wandb_reference.py CHANGED
@@ -1,6 +1,4 @@
1
- """
2
- Support for parsing W&B URLs (which might be user provided) into constituent parts.
3
- """
1
+ """Support for parsing W&B URLs (which might be user provided) into constituent parts."""
4
2
 
5
3
  from dataclasses import dataclass
6
4
  from enum import IntEnum
@@ -35,7 +33,6 @@ RESERVED_JOB_PATHS = ("_view",)
35
33
 
36
34
  @dataclass
37
35
  class WandbReference:
38
-
39
36
  # TODO: This will include port, should we separate that out?
40
37
  host: Optional[str] = None
41
38
 
@@ -88,9 +85,7 @@ class WandbReference:
88
85
 
89
86
  @staticmethod
90
87
  def parse(uri: str) -> Optional["WandbReference"]:
91
- """
92
- Attempt to parse a string as a W&B URL.
93
- """
88
+ """Attempt to parse a string as a W&B URL."""
94
89
  # TODO: Error if HTTP and host is not localhost?
95
90
  if (
96
91
  not uri.startswith("/")
wandb/sdk/lib/_settings_toposort_generate.py ADDED
@@ -0,0 +1,166 @@
1
+ import inspect
2
+ import sys
3
+ from typing import Any, Dict, List, Optional, Set, Tuple
4
+
5
+ from wandb.errors import UsageError
6
+ from wandb.sdk.wandb_settings import Settings
7
+
8
+ if sys.version_info >= (3, 8):
9
+ from typing import get_args, get_origin, get_type_hints
10
+ elif sys.version_info >= (3, 7):
11
+ from typing_extensions import get_args, get_origin, get_type_hints
12
+ else:
13
+
14
+ def get_args(obj: Any) -> Optional[Any]:
15
+ return obj.__args__ if hasattr(obj, "__args__") else None
16
+
17
+ def get_origin(obj: Any) -> Optional[Any]:
18
+ return obj.__origin__ if hasattr(obj, "__origin__") else None
19
+
20
+ def get_type_hints(obj: Any) -> Dict[str, Any]:
21
+ return dict(obj.__annotations__) if hasattr(obj, "__annotations__") else dict()
22
+
23
+
24
+ template = """
25
+ __all__ = ("SETTINGS_TOPOLOGICALLY_SORTED", "_Setting")
26
+
27
+ import sys
28
+ from typing import Tuple
29
+
30
+ if sys.version_info >= (3, 8):
31
+ from typing import Final, Literal
32
+ else:
33
+ from typing_extensions import Final, Literal
34
+
35
+
36
+ _Setting = Literal[
37
+ $settings_literal_list
38
+ ]
39
+
40
+ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
41
+ $settings_topologically_sorted
42
+ )
43
+ """
44
+
45
+
46
+ class Graph:
47
+ # A simple class representing an unweighted directed graph
48
+ # that uses an adjacency list representation.
49
+ # We use to ensure that we don't have cyclic dependencies in the settings
50
+ # and that modifications to the settings are applied in the correct order.
51
+ def __init__(self) -> None:
52
+ self.adj_list: Dict[str, Set[str]] = {}
53
+
54
+ def add_node(self, node: str) -> None:
55
+ if node not in self.adj_list:
56
+ self.adj_list[node] = set()
57
+
58
+ def add_edge(self, node1: str, node2: str) -> None:
59
+ self.adj_list[node1].add(node2)
60
+
61
+ def get_neighbors(self, node: str) -> Set[str]:
62
+ return self.adj_list[node]
63
+
64
+ # return a list of nodes sorted in topological order
65
+ def topological_sort_dfs(self) -> List[str]:
66
+ sorted_copy = {k: sorted(v) for k, v in self.adj_list.items()}
67
+
68
+ sorted_nodes: List[str] = []
69
+ visited_nodes: Set[str] = set()
70
+ current_nodes: Set[str] = set()
71
+
72
+ def visit(n: str) -> None:
73
+ if n in visited_nodes:
74
+ return None
75
+ if n in current_nodes:
76
+ raise UsageError("Cyclic dependency detected in wandb.Settings")
77
+
78
+ current_nodes.add(n)
79
+ for neighbor in sorted_copy[n]:
80
+ visit(neighbor)
81
+
82
+ current_nodes.remove(n)
83
+ visited_nodes.add(n)
84
+ sorted_nodes.append(n)
85
+
86
+ return None
87
+
88
+ for node in self.adj_list:
89
+ if node not in visited_nodes:
90
+ visit(node)
91
+
92
+ return sorted_nodes
93
+
94
+
95
+ def _get_modification_order(
96
+ settings: Settings,
97
+ ) -> Tuple[Tuple[str, ...], Tuple[str, ...]]:
98
+ """Return the order in which settings should be modified, based on dependencies."""
99
+ dependency_graph = Graph()
100
+
101
+ props = tuple(get_type_hints(Settings).keys())
102
+
103
+ # discover prop dependencies from validator methods and runtime hooks
104
+
105
+ prefix = "_validate_"
106
+ symbols = set(dir(settings))
107
+ validator_methods = tuple(sorted(m for m in symbols if m.startswith(prefix)))
108
+
109
+ # extract dependencies from validator methods
110
+ for m in validator_methods:
111
+ setting = m.split(prefix)[1]
112
+ dependency_graph.add_node(setting)
113
+ # if the method is not static, inspect its code to find the attributes it depends on
114
+ if (
115
+ not isinstance(Settings.__dict__[m], staticmethod)
116
+ and not isinstance(Settings.__dict__[m], classmethod)
117
+ and Settings.__dict__[m].__code__.co_argcount > 0
118
+ ):
119
+ unbound_closure_vars = inspect.getclosurevars(Settings.__dict__[m]).unbound
120
+ dependencies = (v for v in unbound_closure_vars if v in props)
121
+ for d in dependencies:
122
+ dependency_graph.add_node(d)
123
+ dependency_graph.add_edge(setting, d)
124
+
125
+ # extract dependencies from props' runtime hooks
126
+ default_props = settings._default_props()
127
+ for prop, spec in default_props.items():
128
+ if "hook" not in spec:
129
+ continue
130
+
131
+ dependency_graph.add_node(prop)
132
+
133
+ hook = spec["hook"]
134
+ if callable(hook):
135
+ hook = [hook]
136
+
137
+ for h in hook:
138
+ unbound_closure_vars = inspect.getclosurevars(h).unbound
139
+ dependencies = (v for v in unbound_closure_vars if v in props)
140
+ for d in dependencies:
141
+ dependency_graph.add_node(d)
142
+ dependency_graph.add_edge(prop, d)
143
+
144
+ modification_order = dependency_graph.topological_sort_dfs()
145
+ return props, tuple(modification_order)
146
+
147
+
148
+ def generate(settings: Settings) -> None:
149
+ _settings_literal_list, _settings_topologically_sorted = _get_modification_order(
150
+ settings
151
+ )
152
+ settings_literal_list = ", ".join(f'"{s}"' for s in _settings_literal_list)
153
+ settings_topologically_sorted = ", ".join(
154
+ f'"{s}"' for s in _settings_topologically_sorted
155
+ )
156
+
157
+ print(
158
+ template.replace("$settings_literal_list", settings_literal_list,).replace(
159
+ "$settings_topologically_sorted",
160
+ settings_topologically_sorted,
161
+ )
162
+ )
163
+
164
+
165
+ if __name__ == "__main__":
166
+ generate(Settings())