wandb 0.13.10__py3-none-any.whl → 0.14.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (228) hide show
  1. wandb/__init__.py +2 -3
  2. wandb/apis/__init__.py +1 -3
  3. wandb/apis/importers/__init__.py +4 -0
  4. wandb/apis/importers/base.py +312 -0
  5. wandb/apis/importers/mlflow.py +113 -0
  6. wandb/apis/internal.py +29 -2
  7. wandb/apis/normalize.py +6 -5
  8. wandb/apis/public.py +163 -180
  9. wandb/apis/reports/_templates.py +6 -12
  10. wandb/apis/reports/report.py +1 -1
  11. wandb/apis/reports/runset.py +1 -3
  12. wandb/apis/reports/util.py +12 -10
  13. wandb/beta/workflows.py +57 -34
  14. wandb/catboost/__init__.py +1 -2
  15. wandb/cli/cli.py +215 -133
  16. wandb/data_types.py +63 -56
  17. wandb/docker/__init__.py +78 -16
  18. wandb/docker/auth.py +21 -22
  19. wandb/env.py +0 -1
  20. wandb/errors/__init__.py +8 -116
  21. wandb/errors/term.py +1 -1
  22. wandb/fastai/__init__.py +1 -2
  23. wandb/filesync/dir_watcher.py +8 -5
  24. wandb/filesync/step_prepare.py +76 -75
  25. wandb/filesync/step_upload.py +1 -2
  26. wandb/integration/catboost/__init__.py +1 -3
  27. wandb/integration/catboost/catboost.py +8 -14
  28. wandb/integration/fastai/__init__.py +7 -13
  29. wandb/integration/gym/__init__.py +35 -4
  30. wandb/integration/keras/__init__.py +3 -3
  31. wandb/integration/keras/callbacks/metrics_logger.py +9 -8
  32. wandb/integration/keras/callbacks/model_checkpoint.py +9 -9
  33. wandb/integration/keras/callbacks/tables_builder.py +31 -19
  34. wandb/integration/kfp/kfp_patch.py +20 -17
  35. wandb/integration/kfp/wandb_logging.py +1 -2
  36. wandb/integration/lightgbm/__init__.py +21 -19
  37. wandb/integration/prodigy/prodigy.py +6 -7
  38. wandb/integration/sacred/__init__.py +9 -12
  39. wandb/integration/sagemaker/__init__.py +1 -3
  40. wandb/integration/sagemaker/auth.py +0 -1
  41. wandb/integration/sagemaker/config.py +1 -1
  42. wandb/integration/sagemaker/resources.py +1 -1
  43. wandb/integration/sb3/sb3.py +8 -4
  44. wandb/integration/tensorboard/__init__.py +1 -3
  45. wandb/integration/tensorboard/log.py +8 -8
  46. wandb/integration/tensorboard/monkeypatch.py +11 -9
  47. wandb/integration/tensorflow/__init__.py +1 -3
  48. wandb/integration/xgboost/__init__.py +4 -6
  49. wandb/integration/yolov8/__init__.py +7 -0
  50. wandb/integration/yolov8/yolov8.py +250 -0
  51. wandb/jupyter.py +31 -35
  52. wandb/lightgbm/__init__.py +1 -2
  53. wandb/old/settings.py +2 -2
  54. wandb/plot/bar.py +1 -2
  55. wandb/plot/confusion_matrix.py +1 -3
  56. wandb/plot/histogram.py +1 -2
  57. wandb/plot/line.py +1 -2
  58. wandb/plot/line_series.py +4 -4
  59. wandb/plot/pr_curve.py +17 -20
  60. wandb/plot/roc_curve.py +1 -3
  61. wandb/plot/scatter.py +1 -2
  62. wandb/proto/v3/wandb_server_pb2.py +85 -39
  63. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  64. wandb/proto/v4/wandb_server_pb2.py +51 -39
  65. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  66. wandb/sdk/__init__.py +1 -3
  67. wandb/sdk/backend/backend.py +1 -1
  68. wandb/sdk/data_types/_dtypes.py +38 -30
  69. wandb/sdk/data_types/base_types/json_metadata.py +1 -3
  70. wandb/sdk/data_types/base_types/media.py +17 -17
  71. wandb/sdk/data_types/base_types/wb_value.py +33 -26
  72. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +91 -125
  73. wandb/sdk/data_types/helper_types/classes.py +1 -1
  74. wandb/sdk/data_types/helper_types/image_mask.py +12 -12
  75. wandb/sdk/data_types/histogram.py +5 -4
  76. wandb/sdk/data_types/html.py +1 -2
  77. wandb/sdk/data_types/image.py +11 -11
  78. wandb/sdk/data_types/molecule.py +3 -6
  79. wandb/sdk/data_types/object_3d.py +1 -2
  80. wandb/sdk/data_types/plotly.py +1 -2
  81. wandb/sdk/data_types/saved_model.py +10 -8
  82. wandb/sdk/data_types/video.py +1 -1
  83. wandb/sdk/integration_utils/data_logging.py +5 -5
  84. wandb/sdk/interface/artifacts.py +288 -266
  85. wandb/sdk/interface/interface.py +2 -3
  86. wandb/sdk/interface/interface_grpc.py +1 -1
  87. wandb/sdk/interface/interface_queue.py +1 -1
  88. wandb/sdk/interface/interface_relay.py +1 -1
  89. wandb/sdk/interface/interface_shared.py +1 -2
  90. wandb/sdk/interface/interface_sock.py +1 -1
  91. wandb/sdk/interface/message_future.py +1 -1
  92. wandb/sdk/interface/message_future_poll.py +1 -1
  93. wandb/sdk/interface/router.py +1 -1
  94. wandb/sdk/interface/router_queue.py +1 -1
  95. wandb/sdk/interface/router_relay.py +1 -1
  96. wandb/sdk/interface/router_sock.py +1 -1
  97. wandb/sdk/interface/summary_record.py +1 -1
  98. wandb/sdk/internal/artifacts.py +1 -1
  99. wandb/sdk/internal/datastore.py +2 -3
  100. wandb/sdk/internal/file_pusher.py +5 -3
  101. wandb/sdk/internal/file_stream.py +22 -19
  102. wandb/sdk/internal/handler.py +5 -4
  103. wandb/sdk/internal/internal.py +1 -1
  104. wandb/sdk/internal/internal_api.py +115 -55
  105. wandb/sdk/internal/job_builder.py +1 -3
  106. wandb/sdk/internal/profiler.py +1 -1
  107. wandb/sdk/internal/progress.py +4 -6
  108. wandb/sdk/internal/sample.py +1 -3
  109. wandb/sdk/internal/sender.py +28 -16
  110. wandb/sdk/internal/settings_static.py +5 -5
  111. wandb/sdk/internal/system/assets/__init__.py +1 -0
  112. wandb/sdk/internal/system/assets/cpu.py +3 -9
  113. wandb/sdk/internal/system/assets/disk.py +2 -4
  114. wandb/sdk/internal/system/assets/gpu.py +6 -18
  115. wandb/sdk/internal/system/assets/gpu_apple.py +2 -4
  116. wandb/sdk/internal/system/assets/interfaces.py +50 -22
  117. wandb/sdk/internal/system/assets/ipu.py +1 -3
  118. wandb/sdk/internal/system/assets/memory.py +7 -13
  119. wandb/sdk/internal/system/assets/network.py +4 -8
  120. wandb/sdk/internal/system/assets/open_metrics.py +283 -0
  121. wandb/sdk/internal/system/assets/tpu.py +1 -4
  122. wandb/sdk/internal/system/assets/trainium.py +26 -14
  123. wandb/sdk/internal/system/system_info.py +2 -3
  124. wandb/sdk/internal/system/system_monitor.py +52 -20
  125. wandb/sdk/internal/tb_watcher.py +12 -13
  126. wandb/sdk/launch/_project_spec.py +54 -65
  127. wandb/sdk/launch/agent/agent.py +374 -90
  128. wandb/sdk/launch/builder/abstract.py +61 -7
  129. wandb/sdk/launch/builder/build.py +81 -110
  130. wandb/sdk/launch/builder/docker_builder.py +181 -0
  131. wandb/sdk/launch/builder/kaniko_builder.py +419 -0
  132. wandb/sdk/launch/builder/noop.py +31 -12
  133. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +70 -20
  134. wandb/sdk/launch/environment/abstract.py +28 -0
  135. wandb/sdk/launch/environment/aws_environment.py +276 -0
  136. wandb/sdk/launch/environment/gcp_environment.py +271 -0
  137. wandb/sdk/launch/environment/local_environment.py +65 -0
  138. wandb/sdk/launch/github_reference.py +3 -8
  139. wandb/sdk/launch/launch.py +38 -29
  140. wandb/sdk/launch/launch_add.py +6 -8
  141. wandb/sdk/launch/loader.py +230 -0
  142. wandb/sdk/launch/registry/abstract.py +54 -0
  143. wandb/sdk/launch/registry/elastic_container_registry.py +163 -0
  144. wandb/sdk/launch/registry/google_artifact_registry.py +203 -0
  145. wandb/sdk/launch/registry/local_registry.py +62 -0
  146. wandb/sdk/launch/runner/abstract.py +1 -16
  147. wandb/sdk/launch/runner/{kubernetes.py → kubernetes_runner.py} +83 -95
  148. wandb/sdk/launch/runner/local_container.py +46 -22
  149. wandb/sdk/launch/runner/local_process.py +1 -4
  150. wandb/sdk/launch/runner/{aws.py → sagemaker_runner.py} +53 -212
  151. wandb/sdk/launch/runner/{gcp_vertex.py → vertex_runner.py} +38 -55
  152. wandb/sdk/launch/sweeps/__init__.py +3 -2
  153. wandb/sdk/launch/sweeps/scheduler.py +132 -39
  154. wandb/sdk/launch/sweeps/scheduler_sweep.py +80 -89
  155. wandb/sdk/launch/utils.py +101 -30
  156. wandb/sdk/launch/wandb_reference.py +2 -7
  157. wandb/sdk/lib/_settings_toposort_generate.py +166 -0
  158. wandb/sdk/lib/_settings_toposort_generated.py +201 -0
  159. wandb/sdk/lib/apikey.py +2 -4
  160. wandb/sdk/lib/config_util.py +4 -1
  161. wandb/sdk/lib/console.py +1 -3
  162. wandb/sdk/lib/deprecate.py +3 -3
  163. wandb/sdk/lib/file_stream_utils.py +7 -5
  164. wandb/sdk/lib/filenames.py +1 -1
  165. wandb/sdk/lib/filesystem.py +61 -5
  166. wandb/sdk/lib/git.py +1 -3
  167. wandb/sdk/lib/import_hooks.py +4 -7
  168. wandb/sdk/lib/ipython.py +8 -5
  169. wandb/sdk/lib/lazyloader.py +1 -3
  170. wandb/sdk/lib/mailbox.py +14 -4
  171. wandb/sdk/lib/proto_util.py +10 -5
  172. wandb/sdk/lib/redirect.py +15 -22
  173. wandb/sdk/lib/reporting.py +1 -3
  174. wandb/sdk/lib/retry.py +4 -5
  175. wandb/sdk/lib/runid.py +1 -3
  176. wandb/sdk/lib/server.py +15 -9
  177. wandb/sdk/lib/sock_client.py +1 -1
  178. wandb/sdk/lib/sparkline.py +1 -1
  179. wandb/sdk/lib/wburls.py +1 -1
  180. wandb/sdk/service/port_file.py +1 -2
  181. wandb/sdk/service/service.py +36 -13
  182. wandb/sdk/service/service_base.py +12 -1
  183. wandb/sdk/verify/verify.py +5 -7
  184. wandb/sdk/wandb_artifacts.py +142 -177
  185. wandb/sdk/wandb_config.py +5 -8
  186. wandb/sdk/wandb_helper.py +1 -1
  187. wandb/sdk/wandb_init.py +24 -13
  188. wandb/sdk/wandb_login.py +9 -9
  189. wandb/sdk/wandb_manager.py +39 -4
  190. wandb/sdk/wandb_metric.py +2 -6
  191. wandb/sdk/wandb_require.py +4 -15
  192. wandb/sdk/wandb_require_helpers.py +1 -9
  193. wandb/sdk/wandb_run.py +95 -141
  194. wandb/sdk/wandb_save.py +1 -3
  195. wandb/sdk/wandb_settings.py +149 -54
  196. wandb/sdk/wandb_setup.py +66 -46
  197. wandb/sdk/wandb_summary.py +13 -10
  198. wandb/sdk/wandb_sweep.py +6 -7
  199. wandb/sdk/wandb_watch.py +1 -1
  200. wandb/sklearn/calculate/confusion_matrix.py +1 -1
  201. wandb/sklearn/calculate/learning_curve.py +1 -1
  202. wandb/sklearn/calculate/summary_metrics.py +1 -3
  203. wandb/sklearn/plot/__init__.py +1 -1
  204. wandb/sklearn/plot/classifier.py +27 -18
  205. wandb/sklearn/plot/clusterer.py +4 -5
  206. wandb/sklearn/plot/regressor.py +4 -4
  207. wandb/sklearn/plot/shared.py +2 -2
  208. wandb/sync/__init__.py +1 -3
  209. wandb/sync/sync.py +4 -5
  210. wandb/testing/relay.py +11 -10
  211. wandb/trigger.py +1 -1
  212. wandb/util.py +106 -81
  213. wandb/viz.py +4 -4
  214. wandb/wandb_agent.py +50 -50
  215. wandb/wandb_controller.py +2 -3
  216. wandb/wandb_run.py +1 -2
  217. wandb/wandb_torch.py +1 -1
  218. wandb/xgboost/__init__.py +1 -2
  219. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/METADATA +6 -2
  220. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/RECORD +224 -209
  221. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
  222. wandb/sdk/launch/builder/docker.py +0 -80
  223. wandb/sdk/launch/builder/kaniko.py +0 -393
  224. wandb/sdk/launch/builder/loader.py +0 -32
  225. wandb/sdk/launch/runner/loader.py +0 -50
  226. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
  227. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
  228. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
@@ -1,42 +1,69 @@
1
- """
2
- Implementation of launch agent.
3
- """
4
-
1
+ """Implementation of launch agent."""
5
2
  import logging
6
3
  import os
7
4
  import pprint
5
+ import threading
8
6
  import time
9
7
  import traceback
10
- from typing import Any, Dict, List, Union
8
+ from dataclasses import dataclass
9
+ from multiprocessing import Event
10
+ from multiprocessing.pool import ThreadPool
11
+ from typing import Any, Dict, List, Optional, Union
11
12
 
12
13
  import wandb
13
14
  import wandb.util as util
14
15
  from wandb.apis.internal import Api
16
+ from wandb.errors import CommError
17
+ from wandb.sdk.launch._project_spec import LaunchProject
15
18
  from wandb.sdk.launch.runner.local_container import LocalSubmittedRun
19
+ from wandb.sdk.launch.sweeps import SCHEDULER_URI
16
20
  from wandb.sdk.lib import runid
17
21
 
22
+ from .. import loader
18
23
  from .._project_spec import create_project_from_spec, fetch_and_validate_project
19
- from ..builder.loader import load_builder
24
+ from ..builder.build import construct_builder_args
20
25
  from ..runner.abstract import AbstractRun
21
- from ..runner.loader import load_backend
22
26
  from ..utils import (
23
27
  LAUNCH_DEFAULT_PROJECT,
24
28
  LOG_PREFIX,
25
29
  PROJECT_SYNCHRONOUS,
26
- resolve_build_and_registry_config,
30
+ LaunchDockerError,
31
+ LaunchError,
27
32
  )
28
33
 
29
34
  AGENT_POLLING_INTERVAL = 10
35
+ ACTIVE_SWEEP_POLLING_INTERVAL = 1 # more frequent when we know we have jobs
30
36
 
31
37
  AGENT_POLLING = "POLLING"
32
38
  AGENT_RUNNING = "RUNNING"
33
39
  AGENT_KILLED = "KILLED"
34
40
 
41
+ MAX_THREADS = 64
42
+
35
43
  _logger = logging.getLogger(__name__)
36
44
 
37
45
 
46
+ @dataclass
47
+ class JobAndRunStatus:
48
+ run_queue_item_id: str
49
+ run_id: Optional[str] = None
50
+ project: Optional[str] = None
51
+ run: Optional[AbstractRun] = None
52
+ failed_to_start: bool = False
53
+ completed: bool = False
54
+ is_scheduler: bool = False
55
+
56
+ @property
57
+ def job_completed(self) -> bool:
58
+ return self.completed or self.failed_to_start
59
+
60
+ def update_run_info(self, launch_project: LaunchProject) -> None:
61
+ self.run_id = launch_project.run_id
62
+ self.project = launch_project.target_project
63
+
64
+
38
65
  def _convert_access(access: str) -> str:
39
- """Converts access string to a value accepted by wandb."""
66
+ """Convert access string to a value accepted by wandb."""
40
67
  access = access.upper()
41
68
  assert (
42
69
  access == "PROJECT" or access == "USER"
@@ -44,31 +71,94 @@ def _convert_access(access: str) -> str:
44
71
  return access
45
72
 
46
73
 
74
+ def _max_from_config(
75
+ config: Dict[str, Any], key: str, default: int = 1
76
+ ) -> Union[int, float]:
77
+ """Get an integer from the config, or float.inf if -1.
78
+
79
+ Utility for parsing integers from the agent config with a default, infinity
80
+ handling, and integer parsing. Raises more informative error if parse error.
81
+ """
82
+ try:
83
+ val = config.get(key)
84
+ if val is None:
85
+ val = default
86
+ max_from_config = int(val)
87
+ except ValueError as e:
88
+ raise LaunchError(
89
+ f"Error when parsing LaunchAgent config key: ['{key}': "
90
+ f"{config.get(key)}]. Error: {str(e)}"
91
+ )
92
+ if max_from_config == -1:
93
+ return float("inf")
94
+
95
+ if max_from_config < 0:
96
+ raise LaunchError(
97
+ f"Error when parsing LaunchAgent config key: ['{key}': "
98
+ f"{config.get(key)}]. Error: negative value."
99
+ )
100
+ return max_from_config
101
+
102
+
103
+ def _job_is_scheduler(run_spec: Dict[str, Any]) -> bool:
104
+ """Determine whether a job/runSpec is a sweep scheduler."""
105
+ if not run_spec:
106
+ _logger.debug("Recieved runSpec in _job_is_scheduler that was empty")
107
+
108
+ if run_spec.get("uri") != SCHEDULER_URI:
109
+ return False
110
+
111
+ if run_spec.get("resource") == "local-process":
112
+ # If a scheduler is a local-process (100%), also
113
+ # confirm command is in format: [wandb scheduler <sweep>]
114
+ cmd = run_spec.get("overrides", {}).get("entry_point", [])
115
+ if len(cmd) < 3:
116
+ return False
117
+
118
+ if cmd[:2] != ["wandb", "scheduler"]:
119
+ return False
120
+
121
+ return True
122
+
123
+
47
124
  class LaunchAgent:
48
125
  """Launch agent class which polls run given run queues and launches runs for wandb launch."""
49
126
 
50
127
  def __init__(self, api: Api, config: Dict[str, Any]):
51
- self._entity = config.get("entity")
52
- self._project = config.get("project")
128
+ """Initialize a launch agent.
129
+
130
+ Arguments:
131
+ api: Api object to use for making requests to the backend.
132
+ config: Config dictionary for the agent.
133
+ """
134
+ self._entity = config["entity"]
135
+ self._project = config["project"]
53
136
  self._api = api
54
137
  self._base_url = self._api.settings().get("base_url")
55
- self._jobs: Dict[Union[int, str], AbstractRun] = {}
56
138
  self._ticks = 0
57
- self._running = 0
139
+ self._jobs: Dict[int, JobAndRunStatus] = {}
140
+ self._jobs_lock = threading.Lock()
141
+ self._jobs_event = Event()
142
+ self._jobs_event.set()
58
143
  self._cwd = os.getcwd()
59
144
  self._namespace = runid.generate_id()
60
145
  self._access = _convert_access("project")
61
- max_jobs_from_config = int(config.get("max_jobs", 1))
62
- if max_jobs_from_config == -1:
63
- self._max_jobs = float("inf")
64
- else:
65
- self._max_jobs = max_jobs_from_config
146
+ self._max_jobs = _max_from_config(config, "max_jobs")
147
+ self._max_schedulers = _max_from_config(config, "max_schedulers")
148
+ self._pool = ThreadPool(
149
+ processes=int(min(MAX_THREADS, self._max_jobs + self._max_schedulers)),
150
+ initargs=(self._jobs, self._jobs_lock),
151
+ )
66
152
  self.default_config: Dict[str, Any] = config
67
153
 
68
154
  # serverside creation
69
155
  self.gorilla_supports_agents = (
70
156
  self._api.launch_agent_introspection() is not None
71
157
  )
158
+ self._gorilla_supports_fail_run_queue_items = (
159
+ self._api.fail_run_queue_item_introspection()
160
+ )
161
+
72
162
  self._queues = config.get("queues", ["default"])
73
163
  create_response = self._api.create_launch_agent(
74
164
  self._entity,
@@ -78,14 +168,45 @@ class LaunchAgent:
78
168
  )
79
169
  self._id = create_response["launchAgentId"]
80
170
  self._name = "" # hacky: want to display this to the user but we don't get it back from gql until polling starts. fix later
171
+ if self._api.entity_is_team(self._entity):
172
+ wandb.termwarn(
173
+ f"{LOG_PREFIX}Agent is running on team entity ({self._entity}). Members of this team will be able to run code on this device."
174
+ )
175
+
176
+ def fail_run_queue_item(self, run_queue_item_id: str) -> None:
177
+ if self._gorilla_supports_fail_run_queue_items:
178
+ self._api.fail_run_queue_item(run_queue_item_id)
179
+
180
+ @property
181
+ def thread_ids(self) -> List[int]:
182
+ """Returns a list of keys running thread ids for the agent."""
183
+ with self._jobs_lock:
184
+ return list(self._jobs.keys())
185
+
186
+ @property
187
+ def num_running_schedulers(self) -> int:
188
+ """Return just the number of schedulers."""
189
+ with self._jobs_lock:
190
+ return len([x for x in self._jobs if self._jobs[x].is_scheduler])
81
191
 
82
192
  @property
83
- def job_ids(self) -> List[Union[int, str]]:
84
- """Returns a list of keys running job ids for the agent."""
85
- return list(self._jobs.keys())
193
+ def num_running_jobs(self) -> int:
194
+ """Return the number of jobs not including schedulers."""
195
+ with self._jobs_lock:
196
+ return len([x for x in self._jobs if not self._jobs[x].is_scheduler])
86
197
 
87
198
  def pop_from_queue(self, queue: str) -> Any:
88
- """Pops an item off the runqueue to run as a job."""
199
+ """Pops an item off the runqueue to run as a job.
200
+
201
+ Arguments:
202
+ queue: Queue to pop from.
203
+
204
+ Returns:
205
+ Item popped off the queue.
206
+
207
+ Raises:
208
+ Exception: if there is an error popping from the queue.
209
+ """
89
210
  try:
90
211
  ups = self._api.pop_from_run_queue(
91
212
  queue,
@@ -100,41 +221,77 @@ class LaunchAgent:
100
221
 
101
222
  def print_status(self) -> None:
102
223
  """Prints the current status of the agent."""
103
- if self._project == LAUNCH_DEFAULT_PROJECT:
104
- wandb.termlog(
105
- f"{LOG_PREFIX}agent {self._name} polling on queues {','.join(self._queues)} for jobs"
106
- )
107
- else:
108
- wandb.termlog(
109
- f"{LOG_PREFIX}agent {self._name} polling on project {self._project}, queues {','.join(self._queues)} for jobs"
110
- )
224
+ output_str = "agent "
225
+ if self._name:
226
+ output_str += f"{self._name} "
227
+ if self.num_running_jobs < self._max_jobs:
228
+ output_str += "polling on "
229
+ if self._project != LAUNCH_DEFAULT_PROJECT:
230
+ output_str += f"project {self._project}, "
231
+ output_str += f"queues {','.join(self._queues)}, "
232
+ output_str += (
233
+ f"running {self.num_running_jobs} out of a maximum of {self._max_jobs} jobs"
234
+ )
235
+
236
+ wandb.termlog(f"{LOG_PREFIX}{output_str}")
237
+ if self.num_running_jobs > 0:
238
+ output_str += f": {','.join(str(job_id) for job_id in self.thread_ids)}"
239
+
240
+ _logger.info(output_str)
111
241
 
112
242
  def update_status(self, status: str) -> None:
243
+ """Update the status of the agent.
244
+
245
+ Arguments:
246
+ status: Status to update the agent to.
247
+ """
113
248
  update_ret = self._api.update_launch_agent_status(
114
249
  self._id, status, self.gorilla_supports_agents
115
250
  )
116
251
  if not update_ret["success"]:
117
- wandb.termerror(f"Failed to update agent status to {status}")
252
+ wandb.termerror(f"{LOG_PREFIX}Failed to update agent status to {status}")
118
253
 
119
- def finish_job_id(self, job_id: Union[str, int]) -> None:
254
+ def finish_thread_id(self, thread_id: int) -> None:
120
255
  """Removes the job from our list for now."""
256
+ job_and_run_status = self._jobs[thread_id]
257
+ if not job_and_run_status.run_id or not job_and_run_status.project:
258
+ self.fail_run_queue_item(job_and_run_status.run_queue_item_id)
259
+ else:
260
+ run_info = None
261
+ # sweep runs exist but have no info before they are started
262
+ # so run_info returned will be None
263
+ # normal runs just throw a comm error
264
+ try:
265
+ run_info = self._api.get_run_info(
266
+ self._entity, job_and_run_status.project, job_and_run_status.run_id
267
+ )
268
+
269
+ except CommError:
270
+ pass
271
+ if run_info is None:
272
+ self.fail_run_queue_item(job_and_run_status.run_queue_item_id)
273
+
121
274
  # TODO: keep logs or something for the finished jobs
122
- del self._jobs[job_id]
123
- self._running -= 1
275
+ with self._jobs_lock:
276
+ del self._jobs[thread_id]
277
+
124
278
  # update status back to polling if no jobs are running
125
- if self._running == 0:
279
+ if len(self.thread_ids) == 0:
126
280
  self.update_status(AGENT_POLLING)
127
281
 
128
- def _update_finished(self, job_id: Union[int, str]) -> None:
282
+ def _update_finished(self, thread_id: int) -> None:
129
283
  """Check our status enum."""
130
- try:
131
- if self._jobs[job_id].get_status().state in ["failed", "finished"]:
132
- self.finish_job_id(job_id)
133
- except Exception:
134
- self.finish_job_id(job_id)
284
+ with self._jobs_lock:
285
+ job = self._jobs[thread_id]
286
+ if job.job_completed:
287
+ self.finish_thread_id(thread_id)
135
288
 
136
289
  def run_job(self, job: Dict[str, Any]) -> None:
137
- """Sets up project and runs the job."""
290
+ """Set up project and run the job.
291
+
292
+ Arguments:
293
+ job: Job to run.
294
+ """
138
295
  _msg = f"{LOG_PREFIX}Launch agent received job:\n{pprint.pformat(job)}\n"
139
296
  wandb.termlog(_msg)
140
297
  _logger.info(_msg)
@@ -151,81 +308,208 @@ class LaunchAgent:
151
308
  launch_spec["overrides"].get("args", [])
152
309
  )
153
310
 
154
- project = create_project_from_spec(launch_spec, self._api)
155
- _logger.info("Fetching and validating project...")
156
- project = fetch_and_validate_project(project, self._api)
157
- _logger.info("Fetching resource...")
158
- resource = launch_spec.get("resource") or "local-container"
159
- backend_config: Dict[str, Any] = {
160
- PROJECT_SYNCHRONOUS: False, # agent always runs async
161
- }
162
-
163
- backend_config["runQueueItemId"] = job["runQueueItemId"]
164
- _logger.info("Loading backend")
165
- override_build_config = launch_spec.get("build")
166
- override_registry_config = launch_spec.get("registry")
167
-
168
- build_config, registry_config = resolve_build_and_registry_config(
169
- self.default_config, override_build_config, override_registry_config
311
+ self._pool.apply_async(
312
+ self.thread_run_job,
313
+ (
314
+ launch_spec,
315
+ job,
316
+ self.default_config,
317
+ self._api,
318
+ ),
170
319
  )
171
- builder = load_builder(build_config)
172
-
173
- default_runner = self.default_config.get("runner", {}).get("type")
174
- if default_runner == resource:
175
- backend_config["runner"] = self.default_config.get("runner")
176
- backend = load_backend(resource, self._api, backend_config)
177
- backend.verify()
178
- _logger.info("Backend loaded...")
179
- run = backend.run(project, builder, registry_config)
180
- if run:
181
- self._jobs[run.id] = run
182
- self._running += 1
183
320
 
184
321
  def loop(self) -> None:
185
- """Main loop function for agent."""
322
+ """Loop infinitely to poll for jobs and run them.
323
+
324
+ Raises:
325
+ KeyboardInterrupt: if the agent is requested to stop.
326
+ """
186
327
  self.print_status()
187
328
  try:
188
329
  while True:
189
330
  self._ticks += 1
190
- job = None
191
-
192
331
  agent_response = self._api.get_launch_agent(
193
332
  self._id, self.gorilla_supports_agents
194
333
  )
195
- self._name = agent_response[
196
- "name"
197
- ] # hacky, but we don't return the name on create so this is first time
334
+ self._name = agent_response["name"] # hack: first time we get name
198
335
  if agent_response["stopPolling"]:
199
336
  # shutdown process and all jobs if requested from ui
200
337
  raise KeyboardInterrupt
201
- if self._running < self._max_jobs:
338
+ if self.num_running_jobs < self._max_jobs:
202
339
  # only check for new jobs if we're not at max
203
340
  for queue in self._queues:
204
341
  job = self.pop_from_queue(queue)
205
342
  if job:
343
+ if _job_is_scheduler(job.get("runSpec")):
344
+ # If job is a scheduler, and we are already at the cap, ignore,
345
+ # don't ack, and it will be pushed back onto the queue in 1 min
346
+ if self.num_running_schedulers >= self._max_schedulers:
347
+ wandb.termwarn(
348
+ f"{LOG_PREFIX}Agent already running the maximum number "
349
+ f"of sweep schedulers: {self._max_schedulers}. To set "
350
+ "this value use `max_schedulers` key in the agent config"
351
+ )
352
+ continue
353
+
206
354
  try:
207
355
  self.run_job(job)
208
- except Exception:
356
+ except Exception as e:
209
357
  wandb.termerror(
210
- f"Error running job: {traceback.format_exc()}"
358
+ f"{LOG_PREFIX}Error running job: {traceback.format_exc()}"
211
359
  )
212
- self._api.ack_run_queue_item(job["runQueueItemId"])
213
- for job_id in self.job_ids:
214
- self._update_finished(job_id)
360
+ util.sentry_exc(e)
361
+ self.fail_run_queue_item(job["runQueueItemId"])
362
+
363
+ for thread_id in self.thread_ids:
364
+ self._update_finished(thread_id)
215
365
  if self._ticks % 2 == 0:
216
- if self._running == 0:
366
+ if len(self.thread_ids) == 0:
217
367
  self.update_status(AGENT_POLLING)
218
- self.print_status()
219
368
  else:
220
369
  self.update_status(AGENT_RUNNING)
221
- time.sleep(AGENT_POLLING_INTERVAL)
370
+ self.print_status()
371
+
372
+ if (
373
+ self.num_running_jobs == self._max_jobs
374
+ or self.num_running_schedulers == 0
375
+ ):
376
+ # all threads busy or no schedulers running
377
+ time.sleep(AGENT_POLLING_INTERVAL)
378
+ else:
379
+ time.sleep(ACTIVE_SWEEP_POLLING_INTERVAL)
222
380
 
223
381
  except KeyboardInterrupt:
224
- # temp: for local, kill all jobs. we don't yet have good handling for different
225
- # types of runners in general
226
- for _, run in self._jobs.items():
227
- if isinstance(run, LocalSubmittedRun):
228
- run.command_proc.kill()
382
+ self._jobs_event.clear()
229
383
  self.update_status(AGENT_KILLED)
230
384
  wandb.termlog(f"{LOG_PREFIX}Shutting down, active jobs:")
231
385
  self.print_status()
386
+ self._pool.close()
387
+ self._pool.join()
388
+
389
+ # Threaded functions
390
+ def thread_run_job(
391
+ self,
392
+ launch_spec: Dict[str, Any],
393
+ job: Dict[str, Any],
394
+ default_config: Dict[str, Any],
395
+ api: Api,
396
+ ) -> None:
397
+ thread_id = threading.current_thread().ident
398
+ assert thread_id is not None
399
+ try:
400
+ self._thread_run_job(launch_spec, job, default_config, api, thread_id)
401
+ except LaunchDockerError as e:
402
+ wandb.termerror(
403
+ f"{LOG_PREFIX}agent {self._name} encountered an issue while starting Docker, see above output for details."
404
+ )
405
+ self.finish_thread_id(thread_id)
406
+ util.sentry_exc(e)
407
+ except Exception as e:
408
+ wandb.termerror(f"{LOG_PREFIX}Error running job: {traceback.format_exc()}")
409
+ self.finish_thread_id(thread_id)
410
+ util.sentry_exc(e)
411
+
412
+ def _thread_run_job(
413
+ self,
414
+ launch_spec: Dict[str, Any],
415
+ job: Dict[str, Any],
416
+ default_config: Dict[str, Any],
417
+ api: Api,
418
+ thread_id: int,
419
+ ) -> None:
420
+ job_tracker = JobAndRunStatus(job["runQueueItemId"])
421
+ with self._jobs_lock:
422
+ self._jobs[thread_id] = job_tracker
423
+ project = create_project_from_spec(launch_spec, api)
424
+ job_tracker.update_run_info(project)
425
+ _logger.info("Fetching and validating project...")
426
+ project = fetch_and_validate_project(project, api)
427
+ _logger.info("Fetching resource...")
428
+ resource = launch_spec.get("resource") or "local-container"
429
+ backend_config: Dict[str, Any] = {
430
+ PROJECT_SYNCHRONOUS: False, # agent always runs async
431
+ }
432
+ _logger.info("Loading backend")
433
+ override_build_config = launch_spec.get("builder")
434
+
435
+ build_config, registry_config = construct_builder_args(
436
+ default_config, override_build_config
437
+ )
438
+
439
+ environment = loader.environment_from_config(
440
+ default_config.get("environment", {})
441
+ )
442
+ registry = loader.registry_from_config(registry_config, environment)
443
+ builder = loader.builder_from_config(build_config, environment, registry)
444
+ backend = loader.runner_from_config(resource, api, backend_config, environment)
445
+ _logger.info("Backend loaded...")
446
+ api.ack_run_queue_item(job["runQueueItemId"], project.run_id)
447
+ run = backend.run(project, builder)
448
+
449
+ if _job_is_scheduler(launch_spec):
450
+ with self._jobs_lock:
451
+ self._jobs[thread_id].is_scheduler = True
452
+ wandb.termlog(
453
+ f"{LOG_PREFIX}Preparing to run sweep scheduler "
454
+ f"({self.num_running_schedulers}/{self._max_schedulers})"
455
+ )
456
+
457
+ if not run:
458
+ with self._jobs_lock:
459
+ job_tracker.failed_to_start = True
460
+ return
461
+ with self._jobs_lock:
462
+ job_tracker.run = run
463
+ while self._jobs_event.is_set():
464
+ if self._check_run_finished(job_tracker):
465
+ return
466
+ time.sleep(AGENT_POLLING_INTERVAL)
467
+ # temp: for local, kill all jobs. we don't yet have good handling for different
468
+ # types of runners in general
469
+ if isinstance(run, LocalSubmittedRun):
470
+ run.command_proc.kill()
471
+
472
+ def _check_run_finished(self, job_tracker: JobAndRunStatus) -> bool:
473
+ if job_tracker.completed:
474
+ return True
475
+
476
+ # the run can be done before the run has started
477
+ # but can also be none if the run failed to start
478
+ # so if there is no run, either the run hasn't started yet
479
+ # or it has failed
480
+ if job_tracker.run is None:
481
+ if job_tracker.failed_to_start:
482
+ return True
483
+ return False
484
+
485
+ known_error = False
486
+ try:
487
+ run = job_tracker.run
488
+ status = run.get_status().state
489
+ if status in ["stopped", "failed", "finished"]:
490
+ if job_tracker.is_scheduler:
491
+ wandb.termlog(f"{LOG_PREFIX}Scheduler finished with ID: {run.id}")
492
+ else:
493
+ wandb.termlog(f"{LOG_PREFIX}Job finished with ID: {run.id}")
494
+ with self._jobs_lock:
495
+ job_tracker.completed = True
496
+ return True
497
+ return False
498
+ except LaunchError as e:
499
+ wandb.termerror(
500
+ f"{LOG_PREFIX}Terminating job {run.id} because it failed to start: {str(e)}"
501
+ )
502
+ known_error = True
503
+ with self._jobs_lock:
504
+ job_tracker.failed_to_start = True
505
+ # TODO: make get_status robust to errors for each runner, and handle them
506
+ # TODO: add sentry to track this case and solve issues
507
+ except Exception:
508
+ wandb.termerror(f"{LOG_PREFIX}Error getting status for job {run.id}")
509
+ wandb.termerror(traceback.format_exc())
510
+ _logger.info("---")
511
+ _logger.info("Caught exception while getting status.")
512
+ _logger.info(f"Job ID: {run.id}")
513
+ _logger.info(traceback.format_exc())
514
+ _logger.info("---")
515
+ return known_error
@@ -1,5 +1,9 @@
1
+ """Abstract plugin class defining the interface needed to build container images for W&B Launch."""
1
2
  from abc import ABC, abstractmethod
2
- from typing import Any, Dict, Optional
3
+ from typing import Any, Dict
4
+
5
+ from wandb.sdk.launch.environment.abstract import AbstractEnvironment
6
+ from wandb.sdk.launch.registry.abstract import AbstractRegistry
3
7
 
4
8
  from .._project_spec import EntryPoint, LaunchProject
5
9
 
@@ -7,25 +11,75 @@ from .._project_spec import EntryPoint, LaunchProject
7
11
  class AbstractBuilder(ABC):
8
12
  """Abstract plugin class defining the interface needed to build container images for W&B Launch."""
9
13
 
10
- type: str
14
+ builder_type: str
15
+ environment: AbstractEnvironment
16
+ registry: AbstractRegistry
17
+ builder_config: Dict[str, Any]
18
+
19
+ @abstractmethod
20
+ def __init__(
21
+ self,
22
+ environment: AbstractEnvironment,
23
+ registry: AbstractRegistry,
24
+ verify: bool = True,
25
+ ) -> None:
26
+ """Initialize a builder.
27
+
28
+ Arguments:
29
+ builder_config: The builder config.
30
+ registry: The registry to use.
31
+ verify: Whether to verify the functionality of the builder.
32
+
33
+ Raises:
34
+ LaunchError: If the builder cannot be intialized or verified.
35
+ """
36
+ raise NotImplementedError
37
+
38
+ @classmethod
39
+ @abstractmethod
40
+ def from_config(
41
+ cls,
42
+ config: dict,
43
+ environment: AbstractEnvironment,
44
+ registry: AbstractRegistry,
45
+ verify: bool = True,
46
+ ) -> "AbstractBuilder":
47
+ """Create a builder from a config dictionary.
48
+
49
+ Arguments:
50
+ config: The config dictionary.
51
+ environment: The environment to use.
52
+ registry: The registry to use.
53
+ verify: Whether to verify the functionality of the builder.
54
+ login: Whether to login to the registry immediately.
11
55
 
12
- def __init__(self, builder_config: Dict[str, Any]) -> None:
13
- self.builder_config = builder_config
56
+ Returns:
57
+ The builder.
58
+ """
59
+ raise NotImplementedError
14
60
 
15
61
  @abstractmethod
16
62
  def build_image(
17
63
  self,
18
64
  launch_project: LaunchProject,
19
- registry: Optional[str],
20
65
  entrypoint: EntryPoint,
21
66
  ) -> str:
22
67
  """Build the image for the given project.
23
68
 
24
- Args:
69
+ Arguments:
25
70
  launch_project: The project to build.
26
71
  build_ctx_path: The path to the build context.
27
72
 
28
73
  Returns:
29
74
  The image name.
30
75
  """
31
- pass
76
+ raise NotImplementedError
77
+
78
+ @abstractmethod
79
+ def verify(self) -> None:
80
+ """Verify that the builder can be used to build images.
81
+
82
+ Raises:
83
+ LaunchError: If the builder cannot be used to build images.
84
+ """
85
+ raise NotImplementedError