wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/importers/base.py +20 -5
  4. wandb/apis/importers/mlflow.py +7 -1
  5. wandb/apis/internal.py +12 -0
  6. wandb/apis/public.py +247 -1387
  7. wandb/apis/reports/_panels.py +58 -35
  8. wandb/beta/workflows.py +6 -7
  9. wandb/cli/cli.py +130 -60
  10. wandb/data_types.py +3 -1
  11. wandb/filesync/dir_watcher.py +21 -27
  12. wandb/filesync/step_checksum.py +8 -8
  13. wandb/filesync/step_prepare.py +23 -10
  14. wandb/filesync/step_upload.py +13 -13
  15. wandb/filesync/upload_job.py +4 -8
  16. wandb/integration/cohere/__init__.py +3 -0
  17. wandb/integration/cohere/cohere.py +21 -0
  18. wandb/integration/cohere/resolver.py +347 -0
  19. wandb/integration/gym/__init__.py +4 -6
  20. wandb/integration/huggingface/__init__.py +3 -0
  21. wandb/integration/huggingface/huggingface.py +18 -0
  22. wandb/integration/huggingface/resolver.py +213 -0
  23. wandb/integration/langchain/wandb_tracer.py +16 -179
  24. wandb/integration/openai/__init__.py +1 -3
  25. wandb/integration/openai/openai.py +11 -143
  26. wandb/integration/openai/resolver.py +111 -38
  27. wandb/integration/sagemaker/config.py +2 -2
  28. wandb/integration/tensorboard/log.py +4 -4
  29. wandb/old/settings.py +24 -7
  30. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  31. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/wandb_deprecated.py +3 -1
  33. wandb/sdk/__init__.py +1 -1
  34. wandb/sdk/artifacts/__init__.py +0 -0
  35. wandb/sdk/artifacts/artifact.py +2101 -0
  36. wandb/sdk/artifacts/artifact_download_logger.py +42 -0
  37. wandb/sdk/artifacts/artifact_manifest.py +67 -0
  38. wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
  39. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  40. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
  41. wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
  42. wandb/sdk/artifacts/artifact_state.py +10 -0
  43. wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
  44. wandb/sdk/artifacts/exceptions.py +55 -0
  45. wandb/sdk/artifacts/storage_handler.py +59 -0
  46. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  47. wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
  48. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  49. wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
  50. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
  51. wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
  53. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
  54. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
  55. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  56. wandb/sdk/artifacts/storage_layout.py +6 -0
  57. wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
  58. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
  60. wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
  61. wandb/sdk/data_types/_dtypes.py +7 -12
  62. wandb/sdk/data_types/base_types/json_metadata.py +3 -2
  63. wandb/sdk/data_types/base_types/media.py +8 -8
  64. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  65. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
  66. wandb/sdk/data_types/helper_types/classes.py +6 -8
  67. wandb/sdk/data_types/helper_types/image_mask.py +5 -6
  68. wandb/sdk/data_types/histogram.py +4 -3
  69. wandb/sdk/data_types/html.py +3 -4
  70. wandb/sdk/data_types/image.py +11 -9
  71. wandb/sdk/data_types/molecule.py +5 -3
  72. wandb/sdk/data_types/object_3d.py +7 -5
  73. wandb/sdk/data_types/plotly.py +3 -2
  74. wandb/sdk/data_types/saved_model.py +11 -11
  75. wandb/sdk/data_types/trace_tree.py +5 -4
  76. wandb/sdk/data_types/utils.py +3 -5
  77. wandb/sdk/data_types/video.py +5 -4
  78. wandb/sdk/integration_utils/auto_logging.py +215 -0
  79. wandb/sdk/interface/interface.py +15 -15
  80. wandb/sdk/internal/file_pusher.py +8 -16
  81. wandb/sdk/internal/file_stream.py +5 -11
  82. wandb/sdk/internal/handler.py +13 -1
  83. wandb/sdk/internal/internal_api.py +287 -13
  84. wandb/sdk/internal/job_builder.py +119 -30
  85. wandb/sdk/internal/sender.py +6 -26
  86. wandb/sdk/internal/settings_static.py +2 -0
  87. wandb/sdk/internal/system/assets/__init__.py +2 -0
  88. wandb/sdk/internal/system/assets/gpu.py +42 -0
  89. wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
  90. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  91. wandb/sdk/internal/system/system_info.py +3 -3
  92. wandb/sdk/internal/tb_watcher.py +32 -22
  93. wandb/sdk/internal/thread_local_settings.py +18 -0
  94. wandb/sdk/launch/_project_spec.py +57 -11
  95. wandb/sdk/launch/agent/agent.py +147 -65
  96. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  97. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  98. wandb/sdk/launch/builder/abstract.py +5 -1
  99. wandb/sdk/launch/builder/build.py +21 -18
  100. wandb/sdk/launch/builder/docker_builder.py +10 -4
  101. wandb/sdk/launch/builder/kaniko_builder.py +113 -23
  102. wandb/sdk/launch/builder/noop.py +6 -3
  103. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
  104. wandb/sdk/launch/environment/aws_environment.py +3 -2
  105. wandb/sdk/launch/environment/azure_environment.py +124 -0
  106. wandb/sdk/launch/environment/gcp_environment.py +2 -4
  107. wandb/sdk/launch/environment/local_environment.py +1 -1
  108. wandb/sdk/launch/errors.py +19 -0
  109. wandb/sdk/launch/github_reference.py +32 -19
  110. wandb/sdk/launch/launch.py +3 -8
  111. wandb/sdk/launch/launch_add.py +6 -2
  112. wandb/sdk/launch/loader.py +21 -2
  113. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  114. wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
  115. wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
  116. wandb/sdk/launch/registry/local_registry.py +2 -1
  117. wandb/sdk/launch/runner/abstract.py +24 -3
  118. wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
  119. wandb/sdk/launch/runner/local_container.py +103 -51
  120. wandb/sdk/launch/runner/local_process.py +1 -1
  121. wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
  122. wandb/sdk/launch/runner/vertex_runner.py +10 -5
  123. wandb/sdk/launch/sweeps/__init__.py +7 -9
  124. wandb/sdk/launch/sweeps/scheduler.py +307 -77
  125. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  126. wandb/sdk/launch/sweeps/utils.py +82 -35
  127. wandb/sdk/launch/utils.py +89 -75
  128. wandb/sdk/lib/_settings_toposort_generated.py +7 -0
  129. wandb/sdk/lib/capped_dict.py +26 -0
  130. wandb/sdk/lib/{git.py → gitlib.py} +76 -59
  131. wandb/sdk/lib/hashutil.py +12 -4
  132. wandb/sdk/lib/paths.py +96 -8
  133. wandb/sdk/lib/sock_client.py +2 -2
  134. wandb/sdk/lib/timer.py +1 -0
  135. wandb/sdk/service/server.py +22 -9
  136. wandb/sdk/service/server_sock.py +1 -1
  137. wandb/sdk/service/service.py +27 -8
  138. wandb/sdk/verify/verify.py +4 -7
  139. wandb/sdk/wandb_config.py +2 -6
  140. wandb/sdk/wandb_init.py +57 -53
  141. wandb/sdk/wandb_require.py +7 -0
  142. wandb/sdk/wandb_run.py +61 -223
  143. wandb/sdk/wandb_settings.py +28 -4
  144. wandb/testing/relay.py +15 -2
  145. wandb/util.py +74 -36
  146. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
  147. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
  148. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
  149. wandb/integration/langchain/util.py +0 -191
  150. wandb/sdk/interface/artifacts/__init__.py +0 -33
  151. wandb/sdk/interface/artifacts/artifact.py +0 -615
  152. wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
  153. wandb/sdk/wandb_artifacts.py +0 -2226
  154. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  155. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  156. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -58,12 +58,16 @@ def _link_and_save_file(
58
58
  interface.publish_files(dict(files=[(GlobStr(glob.escape(file_name)), "live")]))
59
59
 
60
60
 
61
- def is_tfevents_file_created_by(path: str, hostname: str, start_time: float) -> bool:
62
- """Check if a path is a tfevents file created by hostname.
61
+ def is_tfevents_file_created_by(
62
+ path: str, hostname: Optional[str], start_time: Optional[float]
63
+ ) -> bool:
64
+ """Check if a path is a tfevents file.
65
+
66
+ Optionally checks that it was created by [hostname] after [start_time].
63
67
 
64
68
  tensorboard tfevents filename format:
65
69
  https://github.com/tensorflow/tensorboard/blob/f3f26b46981da5bd46a5bb93fcf02d9eb7608bc1/tensorboard/summary/writer/event_file_writer.py#L81
66
- tensorflow tfevents fielname format:
70
+ tensorflow tfevents filename format:
67
71
  https://github.com/tensorflow/tensorflow/blob/8f597046dc30c14b5413813d02c0e0aed399c177/tensorflow/core/util/events_writer.cc#L68
68
72
  """
69
73
  if not path:
@@ -77,23 +81,27 @@ def is_tfevents_file_created_by(path: str, hostname: str, start_time: float) ->
77
81
  except ValueError:
78
82
  return False
79
83
  # check the hostname, which may have dots
80
- for i, part in enumerate(hostname.split(".")):
84
+ if hostname is not None:
85
+ for i, part in enumerate(hostname.split(".")):
86
+ try:
87
+ fname_component_part = fname_components[tfevents_idx + 2 + i]
88
+ except IndexError:
89
+ return False
90
+ if part != fname_component_part:
91
+ return False
92
+ if start_time is not None:
81
93
  try:
82
- fname_component_part = fname_components[tfevents_idx + 2 + i]
83
- except IndexError:
94
+ created_time = int(fname_components[tfevents_idx + 1])
95
+ except (ValueError, IndexError):
84
96
  return False
85
- if part != fname_component_part:
97
+ # Ensure that the file is newer then our start time, and that it was
98
+ # created from the same hostname.
99
+ # TODO: we should also check the PID (also contained in the tfevents
100
+ # filename). Can we assume that our parent pid is the user process
101
+ # that wrote these files?
102
+ if created_time < int(start_time):
86
103
  return False
87
- try:
88
- created_time = int(fname_components[tfevents_idx + 1])
89
- except (ValueError, IndexError):
90
- return False
91
- # Ensure that the file is newer then our start time, and that it was
92
- # created from the same hostname.
93
- # TODO: we should also check the PID (also contained in the tfevents
94
- # filename). Can we assume that our parent pid is the user process
95
- # that wrote these files?
96
- return created_time >= int(start_time)
104
+ return True
97
105
 
98
106
 
99
107
  class TBWatcher:
@@ -136,6 +144,7 @@ class TBWatcher:
136
144
  # Note that we strip '/' instead of os.sep, because elsewhere we've
137
145
  # converted paths to forward slash.
138
146
  namespace = logdir.replace(filename, "").replace(rootdir, "").strip("/")
147
+
139
148
  # TODO: revisit this heuristic, it exists because we don't know the
140
149
  # root log directory until more than one tfevents file is written to
141
150
  if len(dirs) == 1 and namespace not in ["train", "validation"]:
@@ -217,12 +226,13 @@ class TBDirWatcher:
217
226
  """Check if a path has been modified since launch and contains tfevents."""
218
227
  if not path:
219
228
  raise ValueError("Path must be a nonempty string")
220
- if self._force:
221
- return True
222
229
  path = self.tf_compat.tf.compat.as_str_any(path)
223
- return is_tfevents_file_created_by(
224
- path, self._hostname, self._tbwatcher._settings._start_time
225
- )
230
+ if self._force:
231
+ return is_tfevents_file_created_by(path, None, None)
232
+ else:
233
+ return is_tfevents_file_created_by(
234
+ path, self._hostname, self._tbwatcher._settings._start_time
235
+ )
226
236
 
227
237
  def _loader(
228
238
  self, save: bool = True, namespace: Optional[str] = None
@@ -0,0 +1,18 @@
1
+ import threading
2
+ from typing import Dict, Optional
3
+
4
+
5
+ # Context variable for setting API settings (api keys, etc.) for internal and public apis thread-locally
6
+ # TODO: move this into actual settings
7
+ class _ThreadLocalApiSettings(threading.local):
8
+ api_key: Optional[str]
9
+ cookies: Optional[Dict]
10
+ headers: Optional[Dict]
11
+
12
+ def __init__(self) -> None:
13
+ self.api_key = None
14
+ self.cookies = None
15
+ self.headers = None
16
+
17
+
18
+ _thread_local_api_settings: _ThreadLocalApiSettings = _ThreadLocalApiSettings()
@@ -7,17 +7,20 @@ import json
7
7
  import logging
8
8
  import os
9
9
  import tempfile
10
- from typing import Any, Dict, List, Optional
10
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
11
11
 
12
12
  import wandb
13
13
  import wandb.docker as docker
14
14
  from wandb.apis.internal import Api
15
- from wandb.apis.public import Artifact as PublicArtifact
16
15
  from wandb.errors import CommError
16
+ from wandb.sdk.launch import utils
17
17
  from wandb.sdk.lib.runid import generate_id
18
18
 
19
- from . import utils
20
- from .utils import LOG_PREFIX, LaunchError
19
+ from .errors import LaunchError
20
+ from .utils import LOG_PREFIX, recursive_macro_sub
21
+
22
+ if TYPE_CHECKING:
23
+ from wandb.sdk.artifacts.artifact import Artifact
21
24
 
22
25
  _logger = logging.getLogger(__name__)
23
26
 
@@ -59,6 +62,7 @@ class LaunchProject:
59
62
  resource: str,
60
63
  resource_args: Dict[str, Any],
61
64
  run_id: Optional[str],
65
+ sweep_id: Optional[str] = None,
62
66
  ):
63
67
  if uri is not None and utils.is_bare_wandb_uri(uri):
64
68
  uri = api.settings("base_url") + uri
@@ -67,7 +71,7 @@ class LaunchProject:
67
71
  self.job = job
68
72
  if job is not None:
69
73
  wandb.termlog(f"{LOG_PREFIX}Launching job: {job}")
70
- self._job_artifact: Optional[PublicArtifact] = None
74
+ self._job_artifact: Optional["Artifact"] = None
71
75
  self.api = api
72
76
  self.launch_spec = launch_spec
73
77
  self.target_entity = target_entity
@@ -78,11 +82,12 @@ class LaunchProject:
78
82
  # runner, so we need to pop the builder key out
79
83
  resource_args_build = resource_args.get(resource, {}).pop("builder", {})
80
84
  self.resource = resource
81
- self.resource_args = resource_args
85
+ self.resource_args = resource_args.copy()
86
+ self.sweep_id = sweep_id
82
87
  self.python_version: Optional[str] = launch_spec.get("python_version")
83
- self.cuda_base_image: Optional[str] = resource_args_build.get("cuda", {}).get(
84
- "base_image"
85
- )
88
+ self.accelerator_base_image: Optional[str] = resource_args_build.get(
89
+ "accelerator", {}
90
+ ).get("base_image") or resource_args_build.get("cuda", {}).get("base_image")
86
91
  self._base_image: Optional[str] = launch_spec.get("base_image")
87
92
  self.docker_image: Optional[str] = docker_config.get(
88
93
  "docker_image"
@@ -110,6 +115,9 @@ class LaunchProject:
110
115
  self.override_entrypoint = self.add_entry_point(
111
116
  overrides.get("entry_point") # type: ignore
112
117
  )
118
+ if overrides.get("sweep_id") is not None:
119
+ _logger.info("Adding override sweep id")
120
+ self.sweep_id = overrides["sweep_id"]
113
121
  if self.docker_image is not None:
114
122
  self.source = LaunchSource.DOCKER
115
123
  self.project_dir = None
@@ -172,6 +180,43 @@ class LaunchProject:
172
180
  assert self.job is not None
173
181
  return wandb.util.make_docker_image_name_safe(self.job.split(":")[0])
174
182
 
183
+ def fill_macros(self, image: str) -> None:
184
+ """Substitute values for macros in resource arguments.
185
+
186
+ Certain macros can be used in resource args. These macros allow the
187
+ user to set resource args dynamically in the context of the
188
+ run being launched. The macros are given in the ${macro} format. The
189
+ following macros are currently supported:
190
+
191
+ ${project_name} - the name of the project the run is being launched to.
192
+ ${entity_name} - the owner of the project the run being launched to.
193
+ ${run_id} - the id of the run being launched.
194
+ ${run_name} - the name of the run that is launching.
195
+ ${image_uri} - the URI of the container image for this run.
196
+
197
+ Additionally, you may use ${<ENV-VAR-NAME>} to refer to the value of any
198
+ environment variables that you plan to set in the environment of any
199
+ agents that will receive these resource args.
200
+
201
+ Calling this method will overwrite the contents of self.resource_args
202
+ with the substituted values.
203
+
204
+ Args:
205
+ image (str): The image name to fill in for ${wandb-image}.
206
+
207
+ Returns:
208
+ None
209
+ """
210
+ update_dict = {
211
+ "project_name": self.target_project,
212
+ "entity_name": self.target_entity,
213
+ "run_id": self.run_id,
214
+ "run_name": self.name,
215
+ "image_uri": image,
216
+ }
217
+ update_dict.update(os.environ)
218
+ self.resource_args = recursive_macro_sub(self.resource_args, update_dict)
219
+
175
220
  def build_required(self) -> bool:
176
221
  """Checks the source to see if a build is required."""
177
222
  # since the image tag for images built from jobs
@@ -416,6 +461,7 @@ def create_project_from_spec(launch_spec: Dict[str, Any], api: Api) -> LaunchPro
416
461
  launch_spec.get("resource", None),
417
462
  launch_spec.get("resource_args", {}),
418
463
  launch_spec.get("run_id", None),
464
+ launch_spec.get("sweep_id", {}),
419
465
  )
420
466
 
421
467
 
@@ -446,8 +492,8 @@ def fetch_and_validate_project(
446
492
  launch_project._fetch_project_local(internal_api=api)
447
493
 
448
494
  assert launch_project.project_dir is not None
449
- # this prioritizes pip, and we don't support any cases where both are present
450
- # conda projects when uploaded to wandb become pip projects via requirements.frozen.txt, wandb doesn't preserve conda envs
495
+ # this prioritizes pip, and we don't support any cases where both are present conda projects when uploaded to
496
+ # wandb become pip projects via requirements.frozen.txt, wandb doesn't preserve conda envs
451
497
  if os.path.exists(
452
498
  os.path.join(launch_project.project_dir, "requirements.txt")
453
499
  ) or os.path.exists(
@@ -5,7 +5,6 @@ import pprint
5
5
  import threading
6
6
  import time
7
7
  import traceback
8
- from dataclasses import dataclass
9
8
  from multiprocessing import Event
10
9
  from multiprocessing.pool import ThreadPool
11
10
  from typing import Any, Dict, List, Optional, Union
@@ -13,22 +12,18 @@ from typing import Any, Dict, List, Optional, Union
13
12
  import wandb
14
13
  from wandb.apis.internal import Api
15
14
  from wandb.errors import CommError
16
- from wandb.sdk.launch._project_spec import LaunchProject
15
+ from wandb.sdk.launch.launch_add import launch_add
17
16
  from wandb.sdk.launch.runner.local_container import LocalSubmittedRun
18
- from wandb.sdk.launch.sweeps import SCHEDULER_URI
17
+ from wandb.sdk.launch.sweeps.scheduler import Scheduler
19
18
  from wandb.sdk.lib import runid
20
19
 
21
20
  from .. import loader
22
21
  from .._project_spec import create_project_from_spec, fetch_and_validate_project
23
22
  from ..builder.build import construct_builder_args
24
- from ..runner.abstract import AbstractRun
25
- from ..utils import (
26
- LAUNCH_DEFAULT_PROJECT,
27
- LOG_PREFIX,
28
- PROJECT_SYNCHRONOUS,
29
- LaunchDockerError,
30
- LaunchError,
31
- )
23
+ from ..errors import LaunchDockerError, LaunchError
24
+ from ..utils import LAUNCH_DEFAULT_PROJECT, LOG_PREFIX, PROJECT_SYNCHRONOUS
25
+ from .job_status_tracker import JobAndRunStatusTracker
26
+ from .run_queue_item_file_saver import RunQueueItemFileSaver
32
27
 
33
28
  AGENT_POLLING_INTERVAL = 10
34
29
  ACTIVE_SWEEP_POLLING_INTERVAL = 1 # more frequent when we know we have jobs
@@ -37,30 +32,13 @@ AGENT_POLLING = "POLLING"
37
32
  AGENT_RUNNING = "RUNNING"
38
33
  AGENT_KILLED = "KILLED"
39
34
 
40
- MAX_THREADS = 64
41
-
42
- _logger = logging.getLogger(__name__)
35
+ HIDDEN_AGENT_RUN_TYPE = "sweep-controller"
43
36
 
37
+ MAX_THREADS = 64
44
38
 
45
- @dataclass
46
- class JobAndRunStatus:
47
- run_queue_item_id: str
48
- run_id: Optional[str] = None
49
- project: Optional[str] = None
50
- entity: Optional[str] = None
51
- run: Optional[AbstractRun] = None
52
- failed_to_start: bool = False
53
- completed_status: Optional[str] = None
54
- is_scheduler: bool = False
55
-
56
- @property
57
- def job_completed(self) -> bool:
58
- return self.failed_to_start or self.completed_status is not None
39
+ MAX_RESUME_COUNT = 5
59
40
 
60
- def update_run_info(self, launch_project: LaunchProject) -> None:
61
- self.run_id = launch_project.run_id
62
- self.project = launch_project.target_project
63
- self.entity = launch_project.target_entity
41
+ _logger = logging.getLogger(__name__)
64
42
 
65
43
 
66
44
  def _convert_access(access: str) -> str:
@@ -101,16 +79,21 @@ def _max_from_config(
101
79
  return max_from_config
102
80
 
103
81
 
104
- def _job_is_scheduler(run_spec: Dict[str, Any]) -> bool:
82
+ def _is_scheduler_job(run_spec: Dict[str, Any]) -> bool:
105
83
  """Determine whether a job/runSpec is a sweep scheduler."""
106
84
  if not run_spec:
107
- _logger.debug("Recieved runSpec in _job_is_scheduler that was empty")
85
+ _logger.debug("Recieved runSpec in _is_scheduler_job that was empty")
108
86
 
109
- if run_spec.get("uri") != SCHEDULER_URI:
87
+ if run_spec.get("uri") != Scheduler.PLACEHOLDER_URI:
110
88
  return False
111
89
 
112
90
  if run_spec.get("resource") == "local-process":
113
- # If a scheduler is a local-process (100%), also
91
+ # Any job pushed to a run queue that has a scheduler uri is
92
+ # allowed to use local-process
93
+ if run_spec.get("job"):
94
+ return True
95
+
96
+ # If a scheduler is local-process and run through CLI, also
114
97
  # confirm command is in format: [wandb scheduler <sweep>]
115
98
  cmd = run_spec.get("overrides", {}).get("entry_point", [])
116
99
  if len(cmd) < 3:
@@ -137,7 +120,7 @@ class LaunchAgent:
137
120
  self._api = api
138
121
  self._base_url = self._api.settings().get("base_url")
139
122
  self._ticks = 0
140
- self._jobs: Dict[int, JobAndRunStatus] = {}
123
+ self._jobs: Dict[int, JobAndRunStatusTracker] = {}
141
124
  self._jobs_lock = threading.Lock()
142
125
  self._jobs_event = Event()
143
126
  self._jobs_event.set()
@@ -169,15 +152,40 @@ class LaunchAgent:
169
152
  self.gorilla_supports_agents,
170
153
  )
171
154
  self._id = create_response["launchAgentId"]
172
- self._name = "" # hacky: want to display this to the user but we don't get it back from gql until polling starts. fix later
173
155
  if self._api.entity_is_team(self._entity):
174
156
  wandb.termwarn(
175
157
  f"{LOG_PREFIX}Agent is running on team entity ({self._entity}). Members of this team will be able to run code on this device."
176
158
  )
177
159
 
178
- def fail_run_queue_item(self, run_queue_item_id: str) -> None:
160
+ agent_response = self._api.get_launch_agent(
161
+ self._id, self.gorilla_supports_agents
162
+ )
163
+ self._name = agent_response["name"]
164
+ self._init_agent_run()
165
+
166
+ def fail_run_queue_item(
167
+ self,
168
+ run_queue_item_id: str,
169
+ message: str,
170
+ phase: str,
171
+ files: Optional[List[str]] = None,
172
+ ) -> None:
179
173
  if self._gorilla_supports_fail_run_queue_items:
180
- self._api.fail_run_queue_item(run_queue_item_id)
174
+ self._api.fail_run_queue_item(run_queue_item_id, message, phase, files)
175
+
176
+ def _init_agent_run(self) -> None:
177
+ # TODO: has it been long enough that all backends support agents?
178
+ if self.gorilla_supports_agents:
179
+ settings = wandb.Settings(silent=True, disable_git=True)
180
+ self._wandb_run = wandb.init(
181
+ project=self._project,
182
+ entity=self._entity,
183
+ settings=settings,
184
+ id=self._name,
185
+ job_type=HIDDEN_AGENT_RUN_TYPE,
186
+ )
187
+ else:
188
+ self._wandb_run = None
181
189
 
182
190
  @property
183
191
  def thread_ids(self) -> List[int]:
@@ -253,24 +261,43 @@ class LaunchAgent:
253
261
  if not update_ret["success"]:
254
262
  wandb.termerror(f"{LOG_PREFIX}Failed to update agent status to {status}")
255
263
 
256
- def finish_thread_id(self, thread_id: int) -> None:
264
+ def finish_thread_id(
265
+ self,
266
+ thread_id: int,
267
+ exception: Optional[Union[Exception, LaunchDockerError]] = None,
268
+ ) -> None:
257
269
  """Removes the job from our list for now."""
258
270
  job_and_run_status = self._jobs[thread_id]
259
- if not job_and_run_status.run_id or not job_and_run_status.project:
260
- self.fail_run_queue_item(job_and_run_status.run_queue_item_id)
261
- elif job_and_run_status.entity != self._entity:
271
+ if (
272
+ job_and_run_status.entity is not None
273
+ and job_and_run_status.entity != self._entity
274
+ ):
262
275
  _logger.info(
263
276
  "Skipping check for completed run status because run is on a different entity than agent"
264
277
  )
278
+ elif exception is not None:
279
+ tb_str = traceback.format_exception(
280
+ type(exception), value=exception, tb=exception.__traceback__
281
+ )
282
+ fnames = job_and_run_status.saver.save_contents(
283
+ "".join(tb_str), "error.log", "error"
284
+ )
285
+ self.fail_run_queue_item(
286
+ job_and_run_status.run_queue_item_id,
287
+ str(exception),
288
+ job_and_run_status.err_stage,
289
+ fnames,
290
+ )
265
291
  elif job_and_run_status.completed_status not in ["stopped", "failed"]:
266
292
  _logger.info(
267
293
  "Skipping check for completed run status because run was successful"
268
294
  )
269
- else:
295
+ elif job_and_run_status.run is not None:
270
296
  run_info = None
271
297
  # sweep runs exist but have no info before they are started
272
298
  # so run_info returned will be None
273
299
  # normal runs just throw a comm error
300
+ # TODO: make more clear
274
301
  try:
275
302
  run_info = self._api.get_run_info(
276
303
  self._entity, job_and_run_status.project, job_and_run_status.run_id
@@ -279,7 +306,22 @@ class LaunchAgent:
279
306
  except CommError:
280
307
  pass
281
308
  if run_info is None:
282
- self.fail_run_queue_item(job_and_run_status.run_queue_item_id)
309
+ _msg = "The submitted run was not successfully started"
310
+ fnames = None
311
+
312
+ logs = job_and_run_status.run.get_logs()
313
+ if logs:
314
+ fnames = job_and_run_status.saver.save_contents(
315
+ logs, "error.log", "error"
316
+ )
317
+ self.fail_run_queue_item(
318
+ job_and_run_status.run_queue_item_id, _msg, "run", fnames
319
+ )
320
+ else:
321
+ _logger.info("Finish thread id had no exception, ror run")
322
+ wandb._sentry.exception(
323
+ "launch agent called finish thread id on thread without run or exception"
324
+ )
283
325
 
284
326
  # TODO: keep logs or something for the finished jobs
285
327
  with self._jobs_lock:
@@ -296,7 +338,9 @@ class LaunchAgent:
296
338
  if job.job_completed:
297
339
  self.finish_thread_id(thread_id)
298
340
 
299
- def run_job(self, job: Dict[str, Any]) -> None:
341
+ def run_job(
342
+ self, job: Dict[str, Any], queue: str, file_saver: RunQueueItemFileSaver
343
+ ) -> None:
300
344
  """Set up project and run the job.
301
345
 
302
346
  Arguments:
@@ -322,6 +366,8 @@ class LaunchAgent:
322
366
  job,
323
367
  self.default_config,
324
368
  self._api,
369
+ queue,
370
+ file_saver,
325
371
  ),
326
372
  )
327
373
 
@@ -367,7 +413,6 @@ class LaunchAgent:
367
413
  agent_response = self._api.get_launch_agent(
368
414
  self._id, self.gorilla_supports_agents
369
415
  )
370
- self._name = agent_response["name"] # hack: first time we get name
371
416
  if agent_response["stopPolling"]:
372
417
  # shutdown process and all jobs if requested from ui
373
418
  raise KeyboardInterrupt
@@ -376,7 +421,10 @@ class LaunchAgent:
376
421
  for queue in self._queues:
377
422
  job = self.pop_from_queue(queue)
378
423
  if job:
379
- if _job_is_scheduler(job.get("runSpec")):
424
+ file_saver = RunQueueItemFileSaver(
425
+ self._wandb_run, job["runQueueItemId"]
426
+ )
427
+ if _is_scheduler_job(job.get("runSpec")):
380
428
  # If job is a scheduler, and we are already at the cap, ignore,
381
429
  # don't ack, and it will be pushed back onto the queue in 1 min
382
430
  if self.num_running_schedulers >= self._max_schedulers:
@@ -388,13 +436,25 @@ class LaunchAgent:
388
436
  continue
389
437
 
390
438
  try:
391
- self.run_job(job)
439
+ self.run_job(job, queue, file_saver)
392
440
  except Exception as e:
393
441
  wandb.termerror(
394
442
  f"{LOG_PREFIX}Error running job: {traceback.format_exc()}"
395
443
  )
396
444
  wandb._sentry.exception(e)
397
- self.fail_run_queue_item(job["runQueueItemId"])
445
+
446
+ # always the first phase, because we only enter phase 2 within the thread
447
+ files = file_saver.save_contents(
448
+ contents=traceback.format_exc(),
449
+ fname="error.log",
450
+ file_sub_type="error",
451
+ )
452
+ self.fail_run_queue_item(
453
+ run_queue_item_id=job["runQueueItemId"],
454
+ message=str(e),
455
+ phase="agent",
456
+ files=files,
457
+ )
398
458
 
399
459
  for thread_id in self.thread_ids:
400
460
  self._update_finished(thread_id)
@@ -429,20 +489,27 @@ class LaunchAgent:
429
489
  job: Dict[str, Any],
430
490
  default_config: Dict[str, Any],
431
491
  api: Api,
492
+ queue: str,
493
+ file_saver: RunQueueItemFileSaver,
432
494
  ) -> None:
433
495
  thread_id = threading.current_thread().ident
434
496
  assert thread_id is not None
497
+ job_tracker = JobAndRunStatusTracker(job["runQueueItemId"], queue, file_saver)
498
+ with self._jobs_lock:
499
+ self._jobs[thread_id] = job_tracker
435
500
  try:
436
- self._thread_run_job(launch_spec, job, default_config, api, thread_id)
501
+ self._thread_run_job(
502
+ launch_spec, job, default_config, api, queue, thread_id, job_tracker
503
+ )
437
504
  except LaunchDockerError as e:
438
505
  wandb.termerror(
439
506
  f"{LOG_PREFIX}agent {self._name} encountered an issue while starting Docker, see above output for details."
440
507
  )
441
- self.finish_thread_id(thread_id)
508
+ self.finish_thread_id(thread_id, e)
442
509
  wandb._sentry.exception(e)
443
510
  except Exception as e:
444
511
  wandb.termerror(f"{LOG_PREFIX}Error running job: {traceback.format_exc()}")
445
- self.finish_thread_id(thread_id)
512
+ self.finish_thread_id(thread_id, e)
446
513
  wandb._sentry.exception(e)
447
514
 
448
515
  def _thread_run_job(
@@ -451,11 +518,10 @@ class LaunchAgent:
451
518
  job: Dict[str, Any],
452
519
  default_config: Dict[str, Any],
453
520
  api: Api,
521
+ queue: str,
454
522
  thread_id: int,
523
+ job_tracker: JobAndRunStatusTracker,
455
524
  ) -> None:
456
- job_tracker = JobAndRunStatus(job["runQueueItemId"])
457
- with self._jobs_lock:
458
- self._jobs[thread_id] = job_tracker
459
525
  project = create_project_from_spec(launch_spec, api)
460
526
  job_tracker.update_run_info(project)
461
527
  _logger.info("Fetching and validating project...")
@@ -480,9 +546,8 @@ class LaunchAgent:
480
546
  backend = loader.runner_from_config(resource, api, backend_config, environment)
481
547
  _logger.info("Backend loaded...")
482
548
  api.ack_run_queue_item(job["runQueueItemId"], project.run_id)
483
- run = backend.run(project, builder)
484
-
485
- if _job_is_scheduler(launch_spec):
549
+ run = backend.run(project, builder, job_tracker)
550
+ if _is_scheduler_job(launch_spec):
486
551
  with self._jobs_lock:
487
552
  self._jobs[thread_id].is_scheduler = True
488
553
  wandb.termlog(
@@ -497,15 +562,17 @@ class LaunchAgent:
497
562
  with self._jobs_lock:
498
563
  job_tracker.run = run
499
564
  while self._jobs_event.is_set():
500
- if self._check_run_finished(job_tracker):
565
+ if self._check_run_finished(job_tracker, launch_spec):
501
566
  return
502
567
  time.sleep(AGENT_POLLING_INTERVAL)
503
568
  # temp: for local, kill all jobs. we don't yet have good handling for different
504
569
  # types of runners in general
505
- if isinstance(run, LocalSubmittedRun):
506
- run.command_proc.kill()
570
+ if isinstance(run, LocalSubmittedRun) and run._command_proc is not None:
571
+ run._command_proc.kill()
507
572
 
508
- def _check_run_finished(self, job_tracker: JobAndRunStatus) -> bool:
573
+ def _check_run_finished(
574
+ self, job_tracker: JobAndRunStatusTracker, launch_spec: Dict[str, Any]
575
+ ) -> bool:
509
576
  if job_tracker.completed_status:
510
577
  return True
511
578
 
@@ -522,13 +589,28 @@ class LaunchAgent:
522
589
  try:
523
590
  run = job_tracker.run
524
591
  status = run.get_status().state
525
- if status in ["stopped", "failed", "finished"]:
592
+ if status in ["stopped", "failed", "finished", "preempted"]:
526
593
  if job_tracker.is_scheduler:
527
594
  wandb.termlog(f"{LOG_PREFIX}Scheduler finished with ID: {run.id}")
528
595
  else:
529
596
  wandb.termlog(f"{LOG_PREFIX}Job finished with ID: {run.id}")
530
597
  with self._jobs_lock:
531
598
  job_tracker.completed_status = status
599
+ if status == "preempted":
600
+ config = launch_spec.copy()
601
+ config["run_id"] = job_tracker.run_id
602
+ config["_resume_count"] = config.get("_resume_count", 0) + 1
603
+ if config["_resume_count"] > MAX_RESUME_COUNT:
604
+ wandb.termlog(
605
+ f"{LOG_PREFIX}Run {job_tracker.run_id} has already resumed {MAX_RESUME_COUNT} times."
606
+ )
607
+ return True
608
+ wandb.termlog(f"{LOG_PREFIX}Requeueing run {job_tracker.run_id}.")
609
+ launch_add(
610
+ config=config,
611
+ project_queue=self._project,
612
+ queue_name=job_tracker.queue,
613
+ )
532
614
  return True
533
615
  return False
534
616
  except LaunchError as e:
@@ -0,0 +1,34 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ from wandb.sdk.launch._project_spec import LaunchProject
5
+
6
+ from ..runner.abstract import AbstractRun
7
+ from .run_queue_item_file_saver import RunQueueItemFileSaver
8
+
9
+
10
+ @dataclass
11
+ class JobAndRunStatusTracker:
12
+ run_queue_item_id: str
13
+ queue: str
14
+ saver: RunQueueItemFileSaver
15
+ run_id: Optional[str] = None
16
+ project: Optional[str] = None
17
+ entity: Optional[str] = None
18
+ run: Optional[AbstractRun] = None
19
+ failed_to_start: bool = False
20
+ completed_status: Optional[str] = None
21
+ is_scheduler: bool = False
22
+ err_stage: str = "agent"
23
+
24
+ @property
25
+ def job_completed(self) -> bool:
26
+ return self.failed_to_start or self.completed_status is not None
27
+
28
+ def update_run_info(self, launch_project: LaunchProject) -> None:
29
+ self.run_id = launch_project.run_id
30
+ self.project = launch_project.target_project
31
+ self.entity = launch_project.target_entity
32
+
33
+ def set_err_stage(self, stage: str) -> None:
34
+ self.err_stage = stage