wandb 0.16.6__py3-none-any.whl → 0.17.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (193) hide show
  1. package_readme.md +95 -0
  2. wandb/__init__.py +2 -3
  3. wandb/agents/pyagent.py +0 -1
  4. wandb/analytics/sentry.py +2 -1
  5. wandb/apis/importers/internals/internal.py +0 -1
  6. wandb/apis/importers/internals/protocols.py +30 -56
  7. wandb/apis/importers/mlflow.py +13 -26
  8. wandb/apis/importers/wandb.py +8 -14
  9. wandb/apis/internal.py +0 -3
  10. wandb/apis/public/api.py +55 -3
  11. wandb/apis/public/artifacts.py +1 -0
  12. wandb/apis/public/files.py +1 -0
  13. wandb/apis/public/history.py +1 -0
  14. wandb/apis/public/jobs.py +17 -4
  15. wandb/apis/public/projects.py +1 -0
  16. wandb/apis/public/reports.py +1 -0
  17. wandb/apis/public/runs.py +15 -17
  18. wandb/apis/public/sweeps.py +1 -0
  19. wandb/apis/public/teams.py +1 -0
  20. wandb/apis/public/users.py +1 -0
  21. wandb/apis/reports/v1/_blocks.py +3 -7
  22. wandb/apis/reports/v2/gql.py +1 -0
  23. wandb/apis/reports/v2/interface.py +3 -4
  24. wandb/apis/reports/v2/internal.py +5 -8
  25. wandb/cli/cli.py +92 -22
  26. wandb/data_types.py +9 -6
  27. wandb/docker/__init__.py +1 -1
  28. wandb/env.py +38 -8
  29. wandb/errors/__init__.py +5 -0
  30. wandb/errors/term.py +10 -2
  31. wandb/filesync/step_checksum.py +1 -4
  32. wandb/filesync/step_prepare.py +4 -24
  33. wandb/filesync/step_upload.py +4 -106
  34. wandb/filesync/upload_job.py +0 -76
  35. wandb/integration/catboost/catboost.py +1 -1
  36. wandb/integration/fastai/__init__.py +1 -0
  37. wandb/integration/huggingface/resolver.py +2 -2
  38. wandb/integration/keras/__init__.py +1 -0
  39. wandb/integration/keras/callbacks/metrics_logger.py +1 -1
  40. wandb/integration/keras/keras.py +7 -7
  41. wandb/integration/langchain/wandb_tracer.py +1 -0
  42. wandb/integration/lightning/fabric/logger.py +1 -3
  43. wandb/integration/metaflow/metaflow.py +41 -6
  44. wandb/integration/openai/fine_tuning.py +3 -3
  45. wandb/integration/prodigy/prodigy.py +1 -1
  46. wandb/old/summary.py +1 -1
  47. wandb/plot/confusion_matrix.py +1 -1
  48. wandb/plot/pr_curve.py +2 -1
  49. wandb/plot/roc_curve.py +2 -1
  50. wandb/{plots → plot}/utils.py +13 -25
  51. wandb/proto/v3/wandb_internal_pb2.py +364 -332
  52. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  53. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  54. wandb/proto/v4/wandb_internal_pb2.py +322 -316
  55. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  56. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  57. wandb/proto/wandb_deprecated.py +7 -1
  58. wandb/proto/wandb_internal_codegen.py +3 -29
  59. wandb/sdk/artifacts/artifact.py +26 -11
  60. wandb/sdk/artifacts/artifact_download_logger.py +1 -0
  61. wandb/sdk/artifacts/artifact_file_cache.py +18 -4
  62. wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
  63. wandb/sdk/artifacts/artifact_manifest.py +1 -0
  64. wandb/sdk/artifacts/artifact_manifest_entry.py +7 -3
  65. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  66. wandb/sdk/artifacts/artifact_saver.py +2 -8
  67. wandb/sdk/artifacts/artifact_state.py +1 -0
  68. wandb/sdk/artifacts/artifact_ttl.py +1 -0
  69. wandb/sdk/artifacts/exceptions.py +1 -0
  70. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  71. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
  72. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
  73. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
  74. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
  75. wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
  76. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
  77. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
  78. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
  79. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -42
  80. wandb/sdk/artifacts/storage_policy.py +2 -12
  81. wandb/sdk/data_types/_dtypes.py +8 -8
  82. wandb/sdk/data_types/base_types/media.py +3 -6
  83. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
  84. wandb/sdk/data_types/image.py +1 -1
  85. wandb/sdk/data_types/video.py +1 -1
  86. wandb/sdk/integration_utils/auto_logging.py +5 -6
  87. wandb/sdk/integration_utils/data_logging.py +10 -6
  88. wandb/sdk/interface/interface.py +68 -32
  89. wandb/sdk/interface/interface_shared.py +7 -13
  90. wandb/sdk/internal/datastore.py +1 -1
  91. wandb/sdk/internal/file_pusher.py +2 -5
  92. wandb/sdk/internal/file_stream.py +5 -18
  93. wandb/sdk/internal/handler.py +18 -2
  94. wandb/sdk/internal/internal.py +0 -1
  95. wandb/sdk/internal/internal_api.py +1 -129
  96. wandb/sdk/internal/internal_util.py +0 -1
  97. wandb/sdk/internal/job_builder.py +159 -45
  98. wandb/sdk/internal/profiler.py +1 -0
  99. wandb/sdk/internal/progress.py +0 -28
  100. wandb/sdk/internal/run.py +1 -0
  101. wandb/sdk/internal/sender.py +1 -2
  102. wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
  103. wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
  104. wandb/sdk/internal/system/assets/interfaces.py +6 -8
  105. wandb/sdk/internal/system/assets/open_metrics.py +2 -2
  106. wandb/sdk/internal/system/assets/trainium.py +1 -3
  107. wandb/sdk/launch/__init__.py +9 -1
  108. wandb/sdk/launch/_launch.py +4 -24
  109. wandb/sdk/launch/_launch_add.py +1 -3
  110. wandb/sdk/launch/_project_spec.py +186 -224
  111. wandb/sdk/launch/agent/agent.py +37 -13
  112. wandb/sdk/launch/agent/config.py +72 -14
  113. wandb/sdk/launch/builder/abstract.py +69 -1
  114. wandb/sdk/launch/builder/build.py +156 -555
  115. wandb/sdk/launch/builder/context_manager.py +235 -0
  116. wandb/sdk/launch/builder/docker_builder.py +8 -23
  117. wandb/sdk/launch/builder/kaniko_builder.py +12 -25
  118. wandb/sdk/launch/builder/noop.py +1 -0
  119. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  120. wandb/sdk/launch/create_job.py +47 -37
  121. wandb/sdk/launch/environment/abstract.py +1 -0
  122. wandb/sdk/launch/environment/gcp_environment.py +1 -0
  123. wandb/sdk/launch/environment/local_environment.py +1 -0
  124. wandb/sdk/launch/inputs/files.py +148 -0
  125. wandb/sdk/launch/inputs/internal.py +217 -0
  126. wandb/sdk/launch/inputs/manage.py +95 -0
  127. wandb/sdk/launch/loader.py +1 -0
  128. wandb/sdk/launch/registry/abstract.py +1 -0
  129. wandb/sdk/launch/registry/azure_container_registry.py +1 -0
  130. wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
  131. wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
  132. wandb/sdk/launch/registry/local_registry.py +1 -0
  133. wandb/sdk/launch/runner/abstract.py +1 -0
  134. wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
  135. wandb/sdk/launch/runner/kubernetes_runner.py +9 -10
  136. wandb/sdk/launch/runner/local_container.py +2 -3
  137. wandb/sdk/launch/runner/local_process.py +8 -29
  138. wandb/sdk/launch/runner/sagemaker_runner.py +21 -20
  139. wandb/sdk/launch/runner/vertex_runner.py +8 -7
  140. wandb/sdk/launch/sweeps/scheduler.py +4 -3
  141. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  142. wandb/sdk/launch/sweeps/utils.py +3 -3
  143. wandb/sdk/launch/utils.py +15 -140
  144. wandb/sdk/lib/_settings_toposort_generated.py +0 -5
  145. wandb/sdk/lib/fsm.py +8 -12
  146. wandb/sdk/lib/gitlib.py +4 -4
  147. wandb/sdk/lib/import_hooks.py +1 -1
  148. wandb/sdk/lib/lazyloader.py +0 -1
  149. wandb/sdk/lib/proto_util.py +23 -2
  150. wandb/sdk/lib/redirect.py +19 -14
  151. wandb/sdk/lib/retry.py +3 -2
  152. wandb/sdk/lib/tracelog.py +1 -1
  153. wandb/sdk/service/service.py +19 -16
  154. wandb/sdk/verify/verify.py +2 -1
  155. wandb/sdk/wandb_init.py +14 -55
  156. wandb/sdk/wandb_manager.py +2 -2
  157. wandb/sdk/wandb_require.py +5 -0
  158. wandb/sdk/wandb_run.py +114 -56
  159. wandb/sdk/wandb_settings.py +0 -48
  160. wandb/sdk/wandb_setup.py +1 -1
  161. wandb/sklearn/__init__.py +1 -0
  162. wandb/sklearn/plot/__init__.py +1 -0
  163. wandb/sklearn/plot/classifier.py +11 -12
  164. wandb/sklearn/plot/clusterer.py +2 -1
  165. wandb/sklearn/plot/regressor.py +1 -0
  166. wandb/sklearn/plot/shared.py +1 -0
  167. wandb/sklearn/utils.py +1 -0
  168. wandb/testing/relay.py +4 -4
  169. wandb/trigger.py +1 -0
  170. wandb/util.py +67 -54
  171. wandb/wandb_controller.py +2 -3
  172. wandb/wandb_torch.py +1 -2
  173. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/METADATA +67 -70
  174. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/RECORD +177 -187
  175. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/WHEEL +1 -2
  176. wandb/bin/apple_gpu_stats +0 -0
  177. wandb/catboost/__init__.py +0 -9
  178. wandb/fastai/__init__.py +0 -9
  179. wandb/keras/__init__.py +0 -18
  180. wandb/lightgbm/__init__.py +0 -9
  181. wandb/plots/__init__.py +0 -6
  182. wandb/plots/explain_text.py +0 -36
  183. wandb/plots/heatmap.py +0 -81
  184. wandb/plots/named_entity.py +0 -43
  185. wandb/plots/part_of_speech.py +0 -50
  186. wandb/plots/plot_definitions.py +0 -768
  187. wandb/plots/precision_recall.py +0 -121
  188. wandb/plots/roc.py +0 -103
  189. wandb/sacred/__init__.py +0 -3
  190. wandb/xgboost/__init__.py +0 -9
  191. wandb-0.16.6.dist-info/top_level.txt +0 -1
  192. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/entry_points.txt +0 -0
  193. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info/licenses}/LICENSE +0 -0
@@ -2,18 +2,20 @@
2
2
 
3
3
  Arguments can come from a launch spec or call to wandb launch.
4
4
  """
5
+
5
6
  import enum
7
+ import json
6
8
  import logging
7
9
  import os
8
10
  import tempfile
9
11
  from copy import deepcopy
10
12
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
11
13
 
14
+ from six.moves import shlex_quote
15
+
12
16
  import wandb
13
- import wandb.docker as docker
14
17
  from wandb.apis.internal import Api
15
18
  from wandb.errors import CommError
16
- from wandb.sdk.launch import utils
17
19
  from wandb.sdk.launch.utils import get_entrypoint_file
18
20
  from wandb.sdk.lib.runid import generate_id
19
21
 
@@ -33,15 +35,18 @@ IMAGE_TAG_MAX_LENGTH = 32
33
35
 
34
36
 
35
37
  class LaunchSource(enum.IntEnum):
36
- WANDB: int = 1
37
- GIT: int = 2
38
- LOCAL: int = 3
39
- DOCKER: int = 4
40
- JOB: int = 5
38
+ """Enumeration of possible sources for a launch project.
41
39
 
40
+ Attributes:
41
+ DOCKER: Source is a Docker image. This can happen if a user runs
42
+ `wandb launch -d <docker-image>`.
43
+ JOB: Source is a job. This is the standard case.
44
+ SCHEDULER: Source is a wandb sweep scheduler command.
45
+ """
42
46
 
43
- class EntrypointDefaults(List[str]):
44
- PYTHON = ["python", "main.py"]
47
+ DOCKER: int = 1
48
+ JOB: int = 2
49
+ SCHEDULER: int = 3
45
50
 
46
51
 
47
52
  class LaunchProject:
@@ -60,8 +65,16 @@ class LaunchProject:
60
65
 
61
66
  This class is stateful and certain methods can only be called after
62
67
  `LaunchProject.fetch_and_validate_project()` has been called.
68
+
69
+ Notes on the entrypoint:
70
+ - The entrypoint is the command that will be run inside the container.
71
+ - The LaunchProject stores two entrypoints:
72
+ - The job entrypoint is the entrypoint specified in the job's config.
73
+ - The override entrypoint is the entrypoint specified in the launch spec.
74
+ - The override entrypoint takes precedence over the job entrypoint.
63
75
  """
64
76
 
77
+ # This init is way too long, and there are too many attributes on this sucker.
65
78
  def __init__(
66
79
  self,
67
80
  uri: Optional[str],
@@ -79,9 +92,6 @@ class LaunchProject:
79
92
  run_id: Optional[str],
80
93
  sweep_id: Optional[str] = None,
81
94
  ):
82
- if uri is not None and utils.is_bare_wandb_uri(uri):
83
- uri = api.settings("base_url") + uri
84
- _logger.info(f"{LOG_PREFIX}Updating uri with base uri: {uri}")
85
95
  self.uri = uri
86
96
  self.job = job
87
97
  if job is not None:
@@ -105,74 +115,57 @@ class LaunchProject:
105
115
  self.accelerator_base_image: Optional[str] = resource_args_build.get(
106
116
  "accelerator", {}
107
117
  ).get("base_image") or resource_args_build.get("cuda", {}).get("base_image")
108
- self._base_image: Optional[str] = launch_spec.get("base_image")
109
118
  self.docker_image: Optional[str] = docker_config.get(
110
119
  "docker_image"
111
120
  ) or launch_spec.get("image_uri")
112
- uid = RESOURCE_UID_MAP.get(resource, 1000)
113
- if self._base_image:
114
- uid = docker.get_image_uid(self._base_image)
115
- _logger.info(f"{LOG_PREFIX}Retrieved base image uid {uid}")
116
- self.docker_user_id: int = docker_config.get("user_id", uid)
117
- self.git_version: Optional[str] = git_info.get("version")
118
- self.git_repo: Optional[str] = git_info.get("repo")
119
- self.overrides = overrides
120
- self.override_args: List[str] = overrides.get("args", [])
121
- self.override_config: Dict[str, Any] = overrides.get("run_config", {})
122
- self.override_artifacts: Dict[str, Any] = overrides.get("artifacts", {})
123
- self.override_entrypoint: Optional[EntryPoint] = None
124
- self.override_dockerfile: Optional[str] = overrides.get("dockerfile")
121
+ self.docker_user_id = docker_config.get("user_id", 1000)
122
+ self._entry_point: Optional[EntryPoint] = (
123
+ None # todo: keep multiple entrypoint support?
124
+ )
125
+ self.init_overrides(overrides)
126
+ self.init_source()
127
+ self.init_git(git_info)
125
128
  self.deps_type: Optional[str] = None
126
129
  self._runtime: Optional[str] = None
127
130
  self.run_id = run_id or generate_id()
128
131
  self._queue_name: Optional[str] = None
129
132
  self._queue_entity: Optional[str] = None
130
133
  self._run_queue_item_id: Optional[str] = None
131
- self._entry_point: Optional[
132
- EntryPoint
133
- ] = None # todo: keep multiple entrypoint support?
134
-
135
- override_entrypoint = overrides.get("entry_point")
136
- if override_entrypoint:
137
- _logger.info("Adding override entry point")
138
- self.override_entrypoint = EntryPoint(
139
- name=get_entrypoint_file(override_entrypoint),
140
- command=override_entrypoint,
141
- )
134
+ self._job_dockerfile: Optional[str] = None
135
+ self._job_build_context: Optional[str] = None
142
136
 
143
- if overrides.get("sweep_id") is not None:
144
- _logger.info("Adding override sweep id")
145
- self.sweep_id = overrides["sweep_id"]
137
+ def init_source(self) -> None:
146
138
  if self.docker_image is not None:
147
139
  self.source = LaunchSource.DOCKER
148
140
  self.project_dir = None
149
141
  elif self.job is not None:
150
142
  self.source = LaunchSource.JOB
151
143
  self.project_dir = tempfile.mkdtemp()
152
- elif self.uri is not None and utils._is_wandb_uri(self.uri):
153
- _logger.info(f"URI {self.uri} indicates a wandb uri")
154
- self.source = LaunchSource.WANDB
155
- self.project_dir = tempfile.mkdtemp()
156
- elif self.uri is not None and utils._is_git_uri(self.uri):
157
- _logger.info(f"URI {self.uri} indicates a git uri")
158
- self.source = LaunchSource.GIT
159
- self.project_dir = tempfile.mkdtemp()
160
- elif self.uri is not None and "placeholder-" in self.uri:
161
- wandb.termlog(
162
- f"{LOG_PREFIX}Launch received placeholder URI, replacing with local path."
144
+ if self.uri and self.uri.startswith("placeholder"):
145
+ self.source = LaunchSource.SCHEDULER
146
+ self.project_dir = os.getcwd()
147
+ self._entry_point = self.override_entrypoint
148
+
149
+ def init_git(self, git_info: Dict[str, str]) -> None:
150
+ self.git_version = git_info.get("version")
151
+ self.git_repo = git_info.get("repo")
152
+
153
+ def init_overrides(self, overrides: Dict[str, Any]) -> None:
154
+ """Initialize override attributes for a launch project."""
155
+ self.overrides = overrides
156
+ self.override_args: List[str] = overrides.get("args", [])
157
+ self.override_config: Dict[str, Any] = overrides.get("run_config", {})
158
+ self.override_artifacts: Dict[str, Any] = overrides.get("artifacts", {})
159
+ self.override_files: Dict[str, Any] = overrides.get("files", {})
160
+ self.override_entrypoint: Optional[EntryPoint] = None
161
+ self.override_dockerfile: Optional[str] = overrides.get("dockerfile")
162
+ override_entrypoint = overrides.get("entry_point")
163
+ if override_entrypoint:
164
+ _logger.info("Adding override entry point")
165
+ self.override_entrypoint = EntryPoint(
166
+ name=get_entrypoint_file(override_entrypoint),
167
+ command=override_entrypoint,
163
168
  )
164
- self.uri = os.getcwd()
165
- self.source = LaunchSource.LOCAL
166
- self.project_dir = self.uri
167
- else:
168
- _logger.info(f"URI {self.uri} indicates a local uri")
169
- # assume local
170
- if self.uri is not None and not os.path.exists(self.uri):
171
- raise LaunchError(
172
- "Assumed URI supplied is a local path but path is not valid"
173
- )
174
- self.source = LaunchSource.LOCAL
175
- self.project_dir = self.uri
176
169
 
177
170
  def __repr__(self) -> str:
178
171
  """String representation of LaunchProject."""
@@ -211,6 +204,20 @@ class LaunchProject:
211
204
  launch_spec.get("sweep_id", {}),
212
205
  )
213
206
 
207
+ @property
208
+ def job_dockerfile(self) -> Optional[str]:
209
+ return self._job_dockerfile
210
+
211
+ @property
212
+ def job_build_context(self) -> Optional[str]:
213
+ return self._job_build_context
214
+
215
+ def set_job_dockerfile(self, dockerfile: str) -> None:
216
+ self._job_dockerfile = dockerfile
217
+
218
+ def set_job_build_context(self, build_context: str) -> None:
219
+ self._job_build_context = build_context
220
+
214
221
  @property
215
222
  def image_name(self) -> str:
216
223
  if self.docker_image is not None:
@@ -274,7 +281,7 @@ class LaunchProject:
274
281
  image (str): The image name to fill in for ${wandb-image}.
275
282
 
276
283
  Returns:
277
- None
284
+ Dict[str, Any]: The resource args with all macros filled in.
278
285
  """
279
286
  update_dict = {
280
287
  "project_name": self.target_project,
@@ -324,8 +331,8 @@ class LaunchProject:
324
331
  self._docker_image = value
325
332
  self._ensure_not_docker_image_and_local_process()
326
333
 
327
- def get_single_entry_point(self) -> Optional["EntryPoint"]:
328
- """Returns the first entrypoint for the project, or None if no entry point was provided because a docker image was provided."""
334
+ def get_job_entry_point(self) -> Optional["EntryPoint"]:
335
+ """Returns the job entrypoint for the project."""
329
336
  # assuming project only has 1 entry point, pull that out
330
337
  # tmp fn until we figure out if we want to support multiple entry points or not
331
338
  if not self._entry_point:
@@ -336,8 +343,8 @@ class LaunchProject:
336
343
  return None
337
344
  return self._entry_point
338
345
 
339
- def set_entry_point(self, command: List[str]) -> "EntryPoint":
340
- """Add an entry point to the project."""
346
+ def set_job_entry_point(self, command: List[str]) -> "EntryPoint":
347
+ """Set job entrypoint for the project."""
341
348
  assert (
342
349
  self._entry_point is None
343
350
  ), "Cannot set entry point twice. Use LaunchProject.override_entrypoint"
@@ -358,51 +365,23 @@ class LaunchProject:
358
365
  """
359
366
  if self.source == LaunchSource.DOCKER:
360
367
  return
361
- if self.source == LaunchSource.LOCAL:
362
- if not self._entry_point:
363
- wandb.termlog(
364
- f"{LOG_PREFIX}Entry point for repo not specified, defaulting to `python main.py`"
365
- )
366
- self.set_entry_point(EntrypointDefaults.PYTHON)
367
368
  elif self.source == LaunchSource.JOB:
368
369
  self._fetch_job()
369
- else:
370
- self._fetch_project_local(internal_api=self.api)
371
-
372
370
  assert self.project_dir is not None
373
- # this prioritizes pip, and we don't support any cases where both are present conda projects when uploaded to
374
- # wandb become pip projects via requirements.frozen.txt, wandb doesn't preserve conda envs
375
- if os.path.exists(
376
- os.path.join(self.project_dir, "requirements.txt")
377
- ) or os.path.exists(os.path.join(self.project_dir, "requirements.frozen.txt")):
378
- self.deps_type = "pip"
379
- elif os.path.exists(os.path.join(self.project_dir, "environment.yml")):
380
- self.deps_type = "conda"
381
371
 
372
+ # Let's make sure we document this very clearly.
382
373
  def get_image_source_string(self) -> str:
383
374
  """Returns a unique string identifying the source of an image."""
384
- if self.source == LaunchSource.LOCAL:
385
- # TODO: more correct to get a hash of local uri contents
386
- assert isinstance(self.uri, str)
387
- return self.uri
388
- elif self.source == LaunchSource.JOB:
375
+ if self.source == LaunchSource.JOB:
389
376
  assert self._job_artifact is not None
390
377
  return f"{self._job_artifact.name}:v{self._job_artifact.version}"
391
- elif self.source == LaunchSource.GIT:
392
- assert isinstance(self.uri, str)
393
- ret = self.uri
394
- if self.git_version:
395
- ret += self.git_version
396
- return ret
397
- elif self.source == LaunchSource.WANDB:
398
- assert isinstance(self.uri, str)
399
- return self.uri
400
378
  elif self.source == LaunchSource.DOCKER:
401
379
  assert isinstance(self.docker_image, str)
402
- _logger.debug("")
403
380
  return self.docker_image
404
381
  else:
405
- raise LaunchError("Unknown source type when determing image source string")
382
+ raise LaunchError(
383
+ "Unknown source type when determining image source string"
384
+ )
406
385
 
407
386
  def _ensure_not_docker_image_and_local_process(self) -> None:
408
387
  """Ensure that docker image is not specified with local-process resource runner.
@@ -430,111 +409,84 @@ class LaunchProject:
430
409
  raise LaunchError(
431
410
  f"Error accessing job {self.job}: {msg} on {public_api.settings.get('base_url')}"
432
411
  )
433
- job.configure_launch_project(self)
412
+ job.configure_launch_project(self) # Why is this a method of the job?
434
413
  self._job_artifact = job._job_artifact
435
414
 
436
- def _fetch_project_local(self, internal_api: Api) -> None:
437
- """Fetch a project (either wandb run or git repo) into a local directory, returning the path to the local project directory."""
438
- # these asserts are all guaranteed to pass, but are required by mypy
439
- assert self.source != LaunchSource.LOCAL and self.source != LaunchSource.JOB
440
- assert isinstance(self.uri, str)
441
- assert self.project_dir is not None
442
- _logger.info("Fetching project locally...")
443
- if utils._is_wandb_uri(self.uri):
444
- source_entity, source_project, source_run_name = utils.parse_wandb_uri(
445
- self.uri
446
- )
447
- run_info = utils.fetch_wandb_project_run_info(
448
- source_entity, source_project, source_run_name, internal_api
449
- )
450
- program_name = run_info.get("codePath") or run_info["program"]
451
-
452
- self.python_version = run_info.get("python", "3")
453
- downloaded_code_artifact = utils.check_and_download_code_artifacts(
454
- source_entity,
455
- source_project,
456
- source_run_name,
457
- internal_api,
458
- self.project_dir,
459
- )
460
- if not downloaded_code_artifact:
461
- if not run_info["git"]:
462
- raise LaunchError(
463
- "Reproducing a run requires either an associated git repo or a code artifact logged with `run.log_code()`"
464
- )
465
- branch_name = utils._fetch_git_repo(
466
- self.project_dir,
467
- run_info["git"]["remote"],
468
- run_info["git"]["commit"],
469
- )
470
- if self.git_version is None:
471
- self.git_version = branch_name
472
- patch = utils.fetch_project_diff(
473
- source_entity, source_project, source_run_name, internal_api
474
- )
475
- if patch:
476
- utils.apply_patch(patch, self.project_dir)
477
-
478
- # For cases where the entry point wasn't checked into git
479
- if not os.path.exists(os.path.join(self.project_dir, program_name)):
480
- downloaded_entrypoint = utils.download_entry_point(
481
- source_entity,
482
- source_project,
483
- source_run_name,
484
- internal_api,
485
- program_name,
486
- self.project_dir,
487
- )
488
-
489
- if not downloaded_entrypoint:
490
- raise LaunchError(
491
- f"Entrypoint file: {program_name} does not exist, "
492
- "and could not be downloaded. Please specify the entrypoint for this run."
493
- )
494
-
495
- if (
496
- "_session_history.ipynb" in os.listdir(self.project_dir)
497
- or ".ipynb" in program_name
498
- ):
499
- program_name = utils.convert_jupyter_notebook_to_script(
500
- program_name, self.project_dir
501
- )
415
+ def get_env_vars_dict(self, api: Api, max_env_length: int) -> Dict[str, str]:
416
+ """Generate environment variables for the project.
502
417
 
503
- # Download any frozen requirements
504
- utils.download_wandb_python_deps(
505
- source_entity,
506
- source_project,
507
- source_run_name,
508
- internal_api,
509
- self.project_dir,
510
- )
418
+ Arguments:
419
+ launch_project: LaunchProject to generate environment variables for.
511
420
 
512
- if not self._entry_point:
513
- _, ext = os.path.splitext(program_name)
514
- if ext == ".py":
515
- entry_point = ["python", program_name]
516
- elif ext == ".sh":
517
- command = os.environ.get("SHELL", "bash")
518
- entry_point = [command, program_name]
519
- else:
520
- raise LaunchError(f"Unsupported entrypoint: {program_name}")
521
- self.set_entry_point(entry_point)
522
- if not self.override_args:
523
- self.override_args = run_info["args"]
524
- else:
525
- assert utils._GIT_URI_REGEX.match(self.uri), (
526
- "Non-wandb URI %s should be a Git URI" % self.uri
527
- )
528
- if not self._entry_point:
529
- wandb.termlog(
530
- f"{LOG_PREFIX}Entry point for repo not specified, defaulting to python main.py"
531
- )
532
- self.set_entry_point(EntrypointDefaults.PYTHON)
533
- branch_name = utils._fetch_git_repo(
534
- self.project_dir, self.uri, self.git_version
535
- )
536
- if self.git_version is None:
537
- self.git_version = branch_name
421
+ Returns:
422
+ Dictionary of environment variables.
423
+ """
424
+ env_vars = {}
425
+ env_vars["WANDB_BASE_URL"] = api.settings("base_url")
426
+ override_api_key = self.launch_spec.get("_wandb_api_key")
427
+ env_vars["WANDB_API_KEY"] = override_api_key or api.api_key
428
+ if self.target_project:
429
+ env_vars["WANDB_PROJECT"] = self.target_project
430
+ env_vars["WANDB_ENTITY"] = self.target_entity
431
+ env_vars["WANDB_LAUNCH"] = "True"
432
+ env_vars["WANDB_RUN_ID"] = self.run_id
433
+ if self.docker_image:
434
+ env_vars["WANDB_DOCKER"] = self.docker_image
435
+ if self.name is not None:
436
+ env_vars["WANDB_NAME"] = self.name
437
+ if "author" in self.launch_spec and not override_api_key:
438
+ env_vars["WANDB_USERNAME"] = self.launch_spec["author"]
439
+ if self.sweep_id:
440
+ env_vars["WANDB_SWEEP_ID"] = self.sweep_id
441
+ if self.launch_spec.get("_resume_count", 0) > 0:
442
+ env_vars["WANDB_RESUME"] = "allow"
443
+ if self.queue_name:
444
+ env_vars[wandb.env.LAUNCH_QUEUE_NAME] = self.queue_name
445
+ if self.queue_entity:
446
+ env_vars[wandb.env.LAUNCH_QUEUE_ENTITY] = self.queue_entity
447
+ if self.run_queue_item_id:
448
+ env_vars[wandb.env.LAUNCH_TRACE_ID] = self.run_queue_item_id
449
+
450
+ _inject_wandb_config_env_vars(self.override_config, env_vars, max_env_length)
451
+ _inject_file_overrides_env_vars(self.override_files, env_vars, max_env_length)
452
+
453
+ artifacts = {}
454
+ # if we're spinning up a launch process from a job
455
+ # we should tell the run to use that artifact
456
+ if self.job:
457
+ artifacts = {wandb.util.LAUNCH_JOB_ARTIFACT_SLOT_NAME: self.job}
458
+ env_vars["WANDB_ARTIFACTS"] = json.dumps(
459
+ {**artifacts, **self.override_artifacts}
460
+ )
461
+ return env_vars
462
+
463
+ def parse_existing_requirements(self) -> str:
464
+ import pkg_resources
465
+
466
+ requirements_line = ""
467
+ assert self.project_dir is not None
468
+ base_requirements = os.path.join(self.project_dir, "requirements.txt")
469
+ if os.path.exists(base_requirements):
470
+ include_only = set()
471
+ with open(base_requirements) as f:
472
+ iter = pkg_resources.parse_requirements(f)
473
+ while True:
474
+ try:
475
+ pkg = next(iter)
476
+ if hasattr(pkg, "name"):
477
+ name = pkg.name.lower()
478
+ else:
479
+ name = str(pkg)
480
+ include_only.add(shlex_quote(name))
481
+ except StopIteration:
482
+ break
483
+ # Different versions of pkg_resources throw different errors
484
+ # just catch them all and ignore packages we can't parse
485
+ except Exception as e:
486
+ _logger.warn(f"Unable to parse requirements.txt: {e}")
487
+ continue
488
+ requirements_line += "WANDB_ONLY_INCLUDE={} ".format(",".join(include_only))
489
+ return requirements_line
538
490
 
539
491
 
540
492
  class EntryPoint:
@@ -544,13 +496,6 @@ class EntryPoint:
544
496
  self.name = name
545
497
  self.command = command
546
498
 
547
- def compute_command(self, user_parameters: Optional[List[str]]) -> List[str]:
548
- """Converts user parameter dictionary to a string."""
549
- ret = self.command
550
- if user_parameters:
551
- return ret + user_parameters
552
- return ret
553
-
554
499
  def update_entrypoint_path(self, new_path: str) -> None:
555
500
  """Updates the entrypoint path to a new path."""
556
501
  if len(self.command) == 2 and (
@@ -559,18 +504,35 @@ class EntryPoint:
559
504
  self.command[1] = new_path
560
505
 
561
506
 
562
- def get_entry_point_command(
563
- entry_point: Optional["EntryPoint"], parameters: List[str]
564
- ) -> List[str]:
565
- """Returns the shell command to execute in order to run the specified entry point.
566
-
567
- Arguments:
568
- entry_point: Entry point to run
569
- parameters: Parameters (dictionary) for the entry point command
570
-
571
- Returns:
572
- List of strings representing the shell command to be executed
573
- """
574
- if entry_point is None:
575
- return []
576
- return entry_point.compute_command(parameters)
507
+ def _inject_wandb_config_env_vars(
508
+ config: Dict[str, Any], env_dict: Dict[str, Any], maximum_env_length: int
509
+ ) -> None:
510
+ str_config = json.dumps(config)
511
+ if len(str_config) <= maximum_env_length:
512
+ env_dict["WANDB_CONFIG"] = str_config
513
+ return
514
+
515
+ chunks = [
516
+ str_config[i : i + maximum_env_length]
517
+ for i in range(0, len(str_config), maximum_env_length)
518
+ ]
519
+ config_chunks_dict = {f"WANDB_CONFIG_{i}": chunk for i, chunk in enumerate(chunks)}
520
+ env_dict.update(config_chunks_dict)
521
+
522
+
523
+ def _inject_file_overrides_env_vars(
524
+ overrides: Dict[str, Any], env_dict: Dict[str, Any], maximum_env_length: int
525
+ ) -> None:
526
+ str_overrides = json.dumps(overrides)
527
+ if len(str_overrides) <= maximum_env_length:
528
+ env_dict["WANDB_LAUNCH_FILE_OVERRIDES"] = str_overrides
529
+ return
530
+
531
+ chunks = [
532
+ str_overrides[i : i + maximum_env_length]
533
+ for i in range(0, len(str_overrides), maximum_env_length)
534
+ ]
535
+ overrides_chunks_dict = {
536
+ f"WANDB_LAUNCH_FILE_OVERRIDES_{i}": chunk for i, chunk in enumerate(chunks)
537
+ }
538
+ env_dict.update(overrides_chunks_dict)
@@ -1,4 +1,5 @@
1
1
  """Implementation of launch agent."""
2
+
2
3
  import asyncio
3
4
  import logging
4
5
  import os
@@ -8,7 +9,9 @@ import time
8
9
  import traceback
9
10
  from dataclasses import dataclass
10
11
  from multiprocessing import Event
11
- from typing import Any, Dict, List, Optional, Union
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
+
14
+ import yaml
12
15
 
13
16
  import wandb
14
17
  from wandb.apis.internal import Api
@@ -17,11 +20,11 @@ from wandb.sdk.launch._launch_add import launch_add
17
20
  from wandb.sdk.launch.runner.local_container import LocalSubmittedRun
18
21
  from wandb.sdk.launch.runner.local_process import LocalProcessRunner
19
22
  from wandb.sdk.launch.sweeps.scheduler import Scheduler
23
+ from wandb.sdk.launch.utils import LAUNCH_CONFIG_FILE, resolve_build_and_registry_config
20
24
  from wandb.sdk.lib import runid
21
25
 
22
26
  from .. import loader
23
27
  from .._project_spec import LaunchProject
24
- from ..builder.build import construct_agent_configs
25
28
  from ..errors import LaunchDockerError, LaunchError
26
29
  from ..utils import (
27
30
  LAUNCH_DEFAULT_PROJECT,
@@ -133,6 +136,31 @@ class InternalAgentLogger:
133
136
  _logger.debug(f"{LOG_PREFIX}{message}")
134
137
 
135
138
 
139
def construct_agent_configs(
    launch_config: Optional[Dict] = None,
    build_config: Optional[Dict] = None,
) -> Tuple[Optional[Dict[str, Any]], Dict[str, Any], Dict[str, Any]]:
    """Assemble the environment, build and registry configs for the agent.

    Args:
        launch_config: Optional launch configuration. When it is not None,
            its ``"builder"`` and ``"registry"`` entries are used as the
            build and registry configs, replacing the ``build_config``
            argument.
        build_config: Optional build configuration, used only when
            ``launch_config`` is None.

    Returns:
        A tuple ``(environment_config, build_config, registry_config)``.
        ``environment_config`` is the ``"environment"`` entry of the file at
        ``LAUNCH_CONFIG_FILE`` (None when that file does not exist or has no
        such entry). The other two are whatever
        ``resolve_build_and_registry_config`` returns for the default file
        config and the build/registry configs selected above.
    """
    registry_config = None
    if launch_config is not None:
        build_config = launch_config.get("builder")
        registry_config = launch_config.get("registry")

    environment_config = None
    default_launch_config = None
    config_path = os.path.expanduser(LAUNCH_CONFIG_FILE)
    if os.path.exists(config_path):
        with open(config_path) as config_file:
            # In case the config is empty, we want it to be {} instead of None.
            default_launch_config = yaml.safe_load(config_file) or {}
        environment_config = default_launch_config.get("environment")

    resolved_build, resolved_registry = resolve_build_and_registry_config(
        default_launch_config, build_config, registry_config
    )
    return environment_config, resolved_build, resolved_registry
162
+
163
+
136
164
  class LaunchAgent:
137
165
  """Launch agent class which polls run given run queues and launches runs for wandb launch."""
138
166
 
@@ -172,7 +200,7 @@ class LaunchAgent:
172
200
  config: Config dictionary for the agent.
173
201
  """
174
202
  self._entity = config["entity"]
175
- self._project = config.get("project", LAUNCH_DEFAULT_PROJECT)
203
+ self._project = LAUNCH_DEFAULT_PROJECT
176
204
  self._api = api
177
205
  self._base_url = self._api.settings().get("base_url")
178
206
  self._ticks = 0
@@ -240,7 +268,7 @@ class LaunchAgent:
240
268
  """Determine whether a job/runSpec is a sweep scheduler."""
241
269
  if not run_spec:
242
270
  self._internal_logger.debug(
243
- "Recieved runSpec in _is_scheduler_job that was empty"
271
+ "Received runSpec in _is_scheduler_job that was empty"
244
272
  )
245
273
 
246
274
  if run_spec.get("uri") != Scheduler.PLACEHOLDER_URI:
@@ -276,6 +304,8 @@ class LaunchAgent:
276
304
 
277
305
  def _init_agent_run(self) -> None:
278
306
  # TODO: has it been long enough that all backends support agents?
307
+ self._wandb_run = None
308
+
279
309
  if self.gorilla_supports_agents:
280
310
  settings = wandb.Settings(silent=True, disable_git=True)
281
311
  self._wandb_run = wandb.init(
@@ -285,8 +315,6 @@ class LaunchAgent:
285
315
  id=self._name,
286
316
  job_type=HIDDEN_AGENT_RUN_TYPE,
287
317
  )
288
- else:
289
- self._wandb_run = None
290
318
 
291
319
  @property
292
320
  def thread_ids(self) -> List[int]:
@@ -338,10 +366,7 @@ class LaunchAgent:
338
366
  if self._name:
339
367
  output_str += f"{self._name} "
340
368
  if self.num_running_jobs < self._max_jobs:
341
- output_str += "polling on "
342
- if self._project != LAUNCH_DEFAULT_PROJECT:
343
- output_str += f"project {self._project}, "
344
- output_str += f"queues {','.join(self._queues)}, "
369
+ output_str += f"polling on queues {','.join(self._queues)}, "
345
370
  output_str += (
346
371
  f"running {self.num_running_jobs} out of a maximum of {self._max_jobs} jobs"
347
372
  )
@@ -433,7 +458,6 @@ class LaunchAgent:
433
458
  # We retry for 60 seconds with an exponential backoff in case
434
459
  # upsert run is taking a while.
435
460
  logs = None
436
- start_time = time.time()
437
461
  interval = 1
438
462
  while True:
439
463
  called_init = self._check_run_exists_and_inited(
@@ -442,7 +466,7 @@ class LaunchAgent:
442
466
  job_and_run_status.run_id,
443
467
  job_and_run_status.run_queue_item_id,
444
468
  )
445
- if called_init or time.time() - start_time > RUN_INFO_GRACE_PERIOD:
469
+ if called_init or interval > RUN_INFO_GRACE_PERIOD:
446
470
  break
447
471
  if not called_init:
448
472
  # Fetch the logs now if we don't get run info on the
@@ -691,7 +715,7 @@ class LaunchAgent:
691
715
  default_config, override_build_config
692
716
  )
693
717
  image_uri = project.docker_image
694
- entrypoint = project.get_single_entry_point()
718
+ entrypoint = project.get_job_entry_point()
695
719
  environment = loader.environment_from_config(
696
720
  default_config.get("environment", {})
697
721
  )