wandb 0.17.0rc2__py3-none-any.whl → 0.17.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (158) hide show
  1. wandb/__init__.py +1 -2
  2. wandb/apis/importers/internals/internal.py +0 -1
  3. wandb/apis/importers/wandb.py +12 -7
  4. wandb/apis/internal.py +0 -3
  5. wandb/apis/public/api.py +213 -79
  6. wandb/apis/public/artifacts.py +335 -100
  7. wandb/apis/public/files.py +9 -9
  8. wandb/apis/public/jobs.py +16 -4
  9. wandb/apis/public/projects.py +26 -28
  10. wandb/apis/public/query_generator.py +1 -1
  11. wandb/apis/public/runs.py +163 -65
  12. wandb/apis/public/sweeps.py +2 -2
  13. wandb/apis/reports/__init__.py +1 -7
  14. wandb/apis/reports/v1/__init__.py +5 -27
  15. wandb/apis/reports/v2/__init__.py +7 -19
  16. wandb/apis/workspaces/__init__.py +8 -0
  17. wandb/beta/workflows.py +8 -3
  18. wandb/cli/cli.py +131 -59
  19. wandb/docker/__init__.py +1 -1
  20. wandb/errors/term.py +10 -2
  21. wandb/filesync/step_checksum.py +1 -4
  22. wandb/filesync/step_prepare.py +4 -24
  23. wandb/filesync/step_upload.py +5 -107
  24. wandb/filesync/upload_job.py +0 -76
  25. wandb/integration/gym/__init__.py +35 -15
  26. wandb/integration/openai/fine_tuning.py +21 -3
  27. wandb/integration/prodigy/prodigy.py +1 -1
  28. wandb/jupyter.py +16 -17
  29. wandb/plot/pr_curve.py +2 -1
  30. wandb/plot/roc_curve.py +2 -1
  31. wandb/{plots → plot}/utils.py +13 -25
  32. wandb/proto/v3/wandb_internal_pb2.py +54 -54
  33. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  34. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  35. wandb/proto/v4/wandb_internal_pb2.py +54 -54
  36. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  37. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  38. wandb/proto/v5/wandb_base_pb2.py +30 -0
  39. wandb/proto/v5/wandb_internal_pb2.py +355 -0
  40. wandb/proto/v5/wandb_server_pb2.py +63 -0
  41. wandb/proto/v5/wandb_settings_pb2.py +45 -0
  42. wandb/proto/v5/wandb_telemetry_pb2.py +41 -0
  43. wandb/proto/wandb_base_pb2.py +2 -0
  44. wandb/proto/wandb_deprecated.py +9 -1
  45. wandb/proto/wandb_generate_deprecated.py +34 -0
  46. wandb/proto/{wandb_internal_codegen.py → wandb_generate_proto.py} +1 -35
  47. wandb/proto/wandb_internal_pb2.py +2 -0
  48. wandb/proto/wandb_server_pb2.py +2 -0
  49. wandb/proto/wandb_settings_pb2.py +2 -0
  50. wandb/proto/wandb_telemetry_pb2.py +2 -0
  51. wandb/sdk/artifacts/artifact.py +68 -22
  52. wandb/sdk/artifacts/artifact_manifest.py +1 -1
  53. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -3
  54. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -1
  55. wandb/sdk/artifacts/artifact_saver.py +1 -10
  56. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +6 -2
  57. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
  58. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +2 -42
  60. wandb/sdk/artifacts/storage_policy.py +1 -12
  61. wandb/sdk/data_types/image.py +1 -1
  62. wandb/sdk/data_types/video.py +4 -2
  63. wandb/sdk/interface/interface.py +13 -0
  64. wandb/sdk/interface/interface_shared.py +1 -1
  65. wandb/sdk/internal/file_pusher.py +2 -5
  66. wandb/sdk/internal/file_stream.py +6 -19
  67. wandb/sdk/internal/internal_api.py +148 -136
  68. wandb/sdk/internal/job_builder.py +207 -135
  69. wandb/sdk/internal/progress.py +0 -28
  70. wandb/sdk/internal/sender.py +102 -39
  71. wandb/sdk/internal/settings_static.py +8 -1
  72. wandb/sdk/internal/system/assets/trainium.py +3 -3
  73. wandb/sdk/internal/system/system_info.py +4 -2
  74. wandb/sdk/internal/update.py +1 -1
  75. wandb/sdk/launch/__init__.py +9 -1
  76. wandb/sdk/launch/_launch.py +4 -24
  77. wandb/sdk/launch/_launch_add.py +1 -3
  78. wandb/sdk/launch/_project_spec.py +184 -224
  79. wandb/sdk/launch/agent/agent.py +58 -18
  80. wandb/sdk/launch/agent/config.py +0 -3
  81. wandb/sdk/launch/builder/abstract.py +67 -0
  82. wandb/sdk/launch/builder/build.py +165 -576
  83. wandb/sdk/launch/builder/context_manager.py +235 -0
  84. wandb/sdk/launch/builder/docker_builder.py +7 -23
  85. wandb/sdk/launch/builder/kaniko_builder.py +10 -23
  86. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  87. wandb/sdk/launch/create_job.py +51 -45
  88. wandb/sdk/launch/environment/aws_environment.py +26 -1
  89. wandb/sdk/launch/inputs/files.py +148 -0
  90. wandb/sdk/launch/inputs/internal.py +224 -0
  91. wandb/sdk/launch/inputs/manage.py +95 -0
  92. wandb/sdk/launch/runner/abstract.py +2 -2
  93. wandb/sdk/launch/runner/kubernetes_monitor.py +45 -12
  94. wandb/sdk/launch/runner/kubernetes_runner.py +6 -8
  95. wandb/sdk/launch/runner/local_container.py +2 -3
  96. wandb/sdk/launch/runner/local_process.py +8 -29
  97. wandb/sdk/launch/runner/sagemaker_runner.py +20 -14
  98. wandb/sdk/launch/runner/vertex_runner.py +8 -7
  99. wandb/sdk/launch/sweeps/scheduler.py +2 -0
  100. wandb/sdk/launch/sweeps/utils.py +2 -2
  101. wandb/sdk/launch/utils.py +16 -138
  102. wandb/sdk/lib/_settings_toposort_generated.py +2 -5
  103. wandb/sdk/lib/apikey.py +4 -2
  104. wandb/sdk/lib/config_util.py +3 -3
  105. wandb/sdk/lib/proto_util.py +22 -1
  106. wandb/sdk/lib/redirect.py +1 -1
  107. wandb/sdk/service/service.py +2 -1
  108. wandb/sdk/service/streams.py +5 -5
  109. wandb/sdk/wandb_init.py +25 -59
  110. wandb/sdk/wandb_login.py +28 -25
  111. wandb/sdk/wandb_run.py +112 -45
  112. wandb/sdk/wandb_settings.py +33 -64
  113. wandb/sdk/wandb_watch.py +1 -1
  114. wandb/sklearn/plot/classifier.py +4 -6
  115. wandb/sync/sync.py +2 -2
  116. wandb/testing/relay.py +32 -17
  117. wandb/util.py +36 -37
  118. wandb/wandb_agent.py +3 -3
  119. wandb/wandb_controller.py +3 -2
  120. {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/METADATA +7 -9
  121. {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/RECORD +124 -146
  122. {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/WHEEL +1 -1
  123. wandb/apis/reports/v1/_blocks.py +0 -1406
  124. wandb/apis/reports/v1/_helpers.py +0 -70
  125. wandb/apis/reports/v1/_panels.py +0 -1282
  126. wandb/apis/reports/v1/_templates.py +0 -478
  127. wandb/apis/reports/v1/blocks.py +0 -27
  128. wandb/apis/reports/v1/helpers.py +0 -2
  129. wandb/apis/reports/v1/mutations.py +0 -66
  130. wandb/apis/reports/v1/panels.py +0 -17
  131. wandb/apis/reports/v1/report.py +0 -268
  132. wandb/apis/reports/v1/runset.py +0 -144
  133. wandb/apis/reports/v1/templates.py +0 -7
  134. wandb/apis/reports/v1/util.py +0 -406
  135. wandb/apis/reports/v1/validators.py +0 -131
  136. wandb/apis/reports/v2/blocks.py +0 -25
  137. wandb/apis/reports/v2/expr_parsing.py +0 -257
  138. wandb/apis/reports/v2/gql.py +0 -68
  139. wandb/apis/reports/v2/interface.py +0 -1911
  140. wandb/apis/reports/v2/internal.py +0 -867
  141. wandb/apis/reports/v2/metrics.py +0 -6
  142. wandb/apis/reports/v2/panels.py +0 -15
  143. wandb/catboost/__init__.py +0 -9
  144. wandb/fastai/__init__.py +0 -9
  145. wandb/keras/__init__.py +0 -19
  146. wandb/lightgbm/__init__.py +0 -9
  147. wandb/plots/__init__.py +0 -6
  148. wandb/plots/explain_text.py +0 -36
  149. wandb/plots/heatmap.py +0 -81
  150. wandb/plots/named_entity.py +0 -43
  151. wandb/plots/part_of_speech.py +0 -50
  152. wandb/plots/plot_definitions.py +0 -768
  153. wandb/plots/precision_recall.py +0 -121
  154. wandb/plots/roc.py +0 -103
  155. wandb/sacred/__init__.py +0 -3
  156. wandb/xgboost/__init__.py +0 -9
  157. {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/entry_points.txt +0 -0
  158. {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/licenses/LICENSE +0 -0
@@ -4,17 +4,18 @@ Arguments can come from a launch spec or call to wandb launch.
4
4
  """
5
5
 
6
6
  import enum
7
+ import json
7
8
  import logging
8
9
  import os
9
10
  import tempfile
10
11
  from copy import deepcopy
11
12
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
12
13
 
14
+ from six.moves import shlex_quote
15
+
13
16
  import wandb
14
- import wandb.docker as docker
15
17
  from wandb.apis.internal import Api
16
18
  from wandb.errors import CommError
17
- from wandb.sdk.launch import utils
18
19
  from wandb.sdk.launch.utils import get_entrypoint_file
19
20
  from wandb.sdk.lib.runid import generate_id
20
21
 
@@ -34,15 +35,18 @@ IMAGE_TAG_MAX_LENGTH = 32
34
35
 
35
36
 
36
37
  class LaunchSource(enum.IntEnum):
37
- WANDB: int = 1
38
- GIT: int = 2
39
- LOCAL: int = 3
40
- DOCKER: int = 4
41
- JOB: int = 5
38
+ """Enumeration of possible sources for a launch project.
42
39
 
40
+ Attributes:
41
+ DOCKER: Source is a Docker image. This can happen if a user runs
42
+ `wandb launch -d <docker-image>`.
43
+ JOB: Source is a job. This is standard case.
44
+ SCHEDULER: Source is a wandb sweep scheduler command.
45
+ """
43
46
 
44
- class EntrypointDefaults(List[str]):
45
- PYTHON = ["python", "main.py"]
47
+ DOCKER: int = 1
48
+ JOB: int = 2
49
+ SCHEDULER: int = 3
46
50
 
47
51
 
48
52
  class LaunchProject:
@@ -61,8 +65,16 @@ class LaunchProject:
61
65
 
62
66
  This class is stateful and certain methods can only be called after
63
67
  `LaunchProject.fetch_and_validate_project()` has been called.
68
+
69
+ Notes on the entrypoint:
70
+ - The entrypoint is the command that will be run inside the container.
71
+ - The LaunchProject stores two entrypoints
72
+ - The job entrypoint is the entrypoint specified in the job's config.
73
+ - The override entrypoint is the entrypoint specified in the launch spec.
74
+ - The override entrypoint takes precedence over the job entrypoint.
64
75
  """
65
76
 
77
+ # This init is way to long, and there are too many attributes on this sucker.
66
78
  def __init__(
67
79
  self,
68
80
  uri: Optional[str],
@@ -80,9 +92,6 @@ class LaunchProject:
80
92
  run_id: Optional[str],
81
93
  sweep_id: Optional[str] = None,
82
94
  ):
83
- if uri is not None and utils.is_bare_wandb_uri(uri):
84
- uri = api.settings("base_url") + uri
85
- _logger.info(f"{LOG_PREFIX}Updating uri with base uri: {uri}")
86
95
  self.uri = uri
87
96
  self.job = job
88
97
  if job is not None:
@@ -106,75 +115,57 @@ class LaunchProject:
106
115
  self.accelerator_base_image: Optional[str] = resource_args_build.get(
107
116
  "accelerator", {}
108
117
  ).get("base_image") or resource_args_build.get("cuda", {}).get("base_image")
109
- self._base_image: Optional[str] = launch_spec.get("base_image")
110
118
  self.docker_image: Optional[str] = docker_config.get(
111
119
  "docker_image"
112
120
  ) or launch_spec.get("image_uri")
113
- uid = RESOURCE_UID_MAP.get(resource, 1000)
114
- if self._base_image:
115
- uid = docker.get_image_uid(self._base_image)
116
- _logger.info(f"{LOG_PREFIX}Retrieved base image uid {uid}")
117
- self.docker_user_id: int = docker_config.get("user_id", uid)
118
- self.git_version: Optional[str] = git_info.get("version")
119
- self.git_repo: Optional[str] = git_info.get("repo")
120
- self.overrides = overrides
121
- self.override_args: List[str] = overrides.get("args", [])
122
- self.override_config: Dict[str, Any] = overrides.get("run_config", {})
123
- self.override_artifacts: Dict[str, Any] = overrides.get("artifacts", {})
124
- self.override_files: Dict[str, Any] = overrides.get("files", {})
125
- self.override_entrypoint: Optional[EntryPoint] = None
126
- self.override_dockerfile: Optional[str] = overrides.get("dockerfile")
121
+ self.docker_user_id = docker_config.get("user_id", 1000)
122
+ self._entry_point: Optional[EntryPoint] = (
123
+ None # todo: keep multiple entrypoint support?
124
+ )
125
+ self.init_overrides(overrides)
126
+ self.init_source()
127
+ self.init_git(git_info)
127
128
  self.deps_type: Optional[str] = None
128
129
  self._runtime: Optional[str] = None
129
130
  self.run_id = run_id or generate_id()
130
131
  self._queue_name: Optional[str] = None
131
132
  self._queue_entity: Optional[str] = None
132
133
  self._run_queue_item_id: Optional[str] = None
133
- self._entry_point: Optional[EntryPoint] = (
134
- None # todo: keep multiple entrypoint support?
135
- )
136
-
137
- override_entrypoint = overrides.get("entry_point")
138
- if override_entrypoint:
139
- _logger.info("Adding override entry point")
140
- self.override_entrypoint = EntryPoint(
141
- name=get_entrypoint_file(override_entrypoint),
142
- command=override_entrypoint,
143
- )
134
+ self._job_dockerfile: Optional[str] = None
135
+ self._job_build_context: Optional[str] = None
144
136
 
145
- if overrides.get("sweep_id") is not None:
146
- _logger.info("Adding override sweep id")
147
- self.sweep_id = overrides["sweep_id"]
137
+ def init_source(self) -> None:
148
138
  if self.docker_image is not None:
149
139
  self.source = LaunchSource.DOCKER
150
140
  self.project_dir = None
151
141
  elif self.job is not None:
152
142
  self.source = LaunchSource.JOB
153
143
  self.project_dir = tempfile.mkdtemp()
154
- elif self.uri is not None and utils._is_wandb_uri(self.uri):
155
- _logger.info(f"URI {self.uri} indicates a wandb uri")
156
- self.source = LaunchSource.WANDB
157
- self.project_dir = tempfile.mkdtemp()
158
- elif self.uri is not None and utils._is_git_uri(self.uri):
159
- _logger.info(f"URI {self.uri} indicates a git uri")
160
- self.source = LaunchSource.GIT
161
- self.project_dir = tempfile.mkdtemp()
162
- elif self.uri is not None and "placeholder-" in self.uri:
163
- wandb.termlog(
164
- f"{LOG_PREFIX}Launch received placeholder URI, replacing with local path."
144
+ if self.uri and self.uri.startswith("placeholder"):
145
+ self.source = LaunchSource.SCHEDULER
146
+ self.project_dir = os.getcwd()
147
+ self._entry_point = self.override_entrypoint
148
+
149
+ def init_git(self, git_info: Dict[str, str]) -> None:
150
+ self.git_version = git_info.get("version")
151
+ self.git_repo = git_info.get("repo")
152
+
153
+ def init_overrides(self, overrides: Dict[str, Any]) -> None:
154
+ """Initialize override attributes for a launch project."""
155
+ self.overrides = overrides
156
+ self.override_args: List[str] = overrides.get("args", [])
157
+ self.override_config: Dict[str, Any] = overrides.get("run_config", {})
158
+ self.override_artifacts: Dict[str, Any] = overrides.get("artifacts", {})
159
+ self.override_files: Dict[str, Any] = overrides.get("files", {})
160
+ self.override_entrypoint: Optional[EntryPoint] = None
161
+ self.override_dockerfile: Optional[str] = overrides.get("dockerfile")
162
+ override_entrypoint = overrides.get("entry_point")
163
+ if override_entrypoint:
164
+ _logger.info("Adding override entry point")
165
+ self.override_entrypoint = EntryPoint(
166
+ name=get_entrypoint_file(override_entrypoint),
167
+ command=override_entrypoint,
165
168
  )
166
- self.uri = os.getcwd()
167
- self.source = LaunchSource.LOCAL
168
- self.project_dir = self.uri
169
- else:
170
- _logger.info(f"URI {self.uri} indicates a local uri")
171
- # assume local
172
- if self.uri is not None and not os.path.exists(self.uri):
173
- raise LaunchError(
174
- "Assumed URI supplied is a local path but path is not valid"
175
- )
176
- self.source = LaunchSource.LOCAL
177
- self.project_dir = self.uri
178
169
 
179
170
  def __repr__(self) -> str:
180
171
  """String representation of LaunchProject."""
@@ -213,6 +204,20 @@ class LaunchProject:
213
204
  launch_spec.get("sweep_id", {}),
214
205
  )
215
206
 
207
+ @property
208
+ def job_dockerfile(self) -> Optional[str]:
209
+ return self._job_dockerfile
210
+
211
+ @property
212
+ def job_build_context(self) -> Optional[str]:
213
+ return self._job_build_context
214
+
215
+ def set_job_dockerfile(self, dockerfile: str) -> None:
216
+ self._job_dockerfile = dockerfile
217
+
218
+ def set_job_build_context(self, build_context: str) -> None:
219
+ self._job_build_context = build_context
220
+
216
221
  @property
217
222
  def image_name(self) -> str:
218
223
  if self.docker_image is not None:
@@ -276,7 +281,7 @@ class LaunchProject:
276
281
  image (str): The image name to fill in for ${wandb-image}.
277
282
 
278
283
  Returns:
279
- None
284
+ Dict[str, Any]: The resource args with all macros filled in.
280
285
  """
281
286
  update_dict = {
282
287
  "project_name": self.target_project,
@@ -326,8 +331,8 @@ class LaunchProject:
326
331
  self._docker_image = value
327
332
  self._ensure_not_docker_image_and_local_process()
328
333
 
329
- def get_single_entry_point(self) -> Optional["EntryPoint"]:
330
- """Returns the first entrypoint for the project, or None if no entry point was provided because a docker image was provided."""
334
+ def get_job_entry_point(self) -> Optional["EntryPoint"]:
335
+ """Returns the job entrypoint for the project."""
331
336
  # assuming project only has 1 entry point, pull that out
332
337
  # tmp fn until we figure out if we want to support multiple entry points or not
333
338
  if not self._entry_point:
@@ -338,8 +343,8 @@ class LaunchProject:
338
343
  return None
339
344
  return self._entry_point
340
345
 
341
- def set_entry_point(self, command: List[str]) -> "EntryPoint":
342
- """Add an entry point to the project."""
346
+ def set_job_entry_point(self, command: List[str]) -> "EntryPoint":
347
+ """Set job entrypoint for the project."""
343
348
  assert (
344
349
  self._entry_point is None
345
350
  ), "Cannot set entry point twice. Use LaunchProject.override_entrypoint"
@@ -360,48 +365,18 @@ class LaunchProject:
360
365
  """
361
366
  if self.source == LaunchSource.DOCKER:
362
367
  return
363
- if self.source == LaunchSource.LOCAL:
364
- if not self._entry_point:
365
- wandb.termlog(
366
- f"{LOG_PREFIX}Entry point for repo not specified, defaulting to `python main.py`"
367
- )
368
- self.set_entry_point(EntrypointDefaults.PYTHON)
369
368
  elif self.source == LaunchSource.JOB:
370
369
  self._fetch_job()
371
- else:
372
- self._fetch_project_local(internal_api=self.api)
373
-
374
370
  assert self.project_dir is not None
375
- # this prioritizes pip, and we don't support any cases where both are present conda projects when uploaded to
376
- # wandb become pip projects via requirements.frozen.txt, wandb doesn't preserve conda envs
377
- if os.path.exists(
378
- os.path.join(self.project_dir, "requirements.txt")
379
- ) or os.path.exists(os.path.join(self.project_dir, "requirements.frozen.txt")):
380
- self.deps_type = "pip"
381
- elif os.path.exists(os.path.join(self.project_dir, "environment.yml")):
382
- self.deps_type = "conda"
383
371
 
372
+ # Let's make sure we document this very clearly.
384
373
  def get_image_source_string(self) -> str:
385
374
  """Returns a unique string identifying the source of an image."""
386
- if self.source == LaunchSource.LOCAL:
387
- # TODO: more correct to get a hash of local uri contents
388
- assert isinstance(self.uri, str)
389
- return self.uri
390
- elif self.source == LaunchSource.JOB:
375
+ if self.source == LaunchSource.JOB:
391
376
  assert self._job_artifact is not None
392
377
  return f"{self._job_artifact.name}:v{self._job_artifact.version}"
393
- elif self.source == LaunchSource.GIT:
394
- assert isinstance(self.uri, str)
395
- ret = self.uri
396
- if self.git_version:
397
- ret += self.git_version
398
- return ret
399
- elif self.source == LaunchSource.WANDB:
400
- assert isinstance(self.uri, str)
401
- return self.uri
402
378
  elif self.source == LaunchSource.DOCKER:
403
379
  assert isinstance(self.docker_image, str)
404
- _logger.debug("")
405
380
  return self.docker_image
406
381
  else:
407
382
  raise LaunchError(
@@ -434,111 +409,86 @@ class LaunchProject:
434
409
  raise LaunchError(
435
410
  f"Error accessing job {self.job}: {msg} on {public_api.settings.get('base_url')}"
436
411
  )
437
- job.configure_launch_project(self)
412
+ job.configure_launch_project(self) # Why is this a method of the job?
438
413
  self._job_artifact = job._job_artifact
439
414
 
440
- def _fetch_project_local(self, internal_api: Api) -> None:
441
- """Fetch a project (either wandb run or git repo) into a local directory, returning the path to the local project directory."""
442
- # these asserts are all guaranteed to pass, but are required by mypy
443
- assert self.source != LaunchSource.LOCAL and self.source != LaunchSource.JOB
444
- assert isinstance(self.uri, str)
445
- assert self.project_dir is not None
446
- _logger.info("Fetching project locally...")
447
- if utils._is_wandb_uri(self.uri):
448
- source_entity, source_project, source_run_name = utils.parse_wandb_uri(
449
- self.uri
450
- )
451
- run_info = utils.fetch_wandb_project_run_info(
452
- source_entity, source_project, source_run_name, internal_api
453
- )
454
- program_name = run_info.get("codePath") or run_info["program"]
455
-
456
- self.python_version = run_info.get("python", "3")
457
- downloaded_code_artifact = utils.check_and_download_code_artifacts(
458
- source_entity,
459
- source_project,
460
- source_run_name,
461
- internal_api,
462
- self.project_dir,
463
- )
464
- if not downloaded_code_artifact:
465
- if not run_info["git"]:
466
- raise LaunchError(
467
- "Reproducing a run requires either an associated git repo or a code artifact logged with `run.log_code()`"
468
- )
469
- branch_name = utils._fetch_git_repo(
470
- self.project_dir,
471
- run_info["git"]["remote"],
472
- run_info["git"]["commit"],
473
- )
474
- if self.git_version is None:
475
- self.git_version = branch_name
476
- patch = utils.fetch_project_diff(
477
- source_entity, source_project, source_run_name, internal_api
478
- )
479
- if patch:
480
- utils.apply_patch(patch, self.project_dir)
481
-
482
- # For cases where the entry point wasn't checked into git
483
- if not os.path.exists(os.path.join(self.project_dir, program_name)):
484
- downloaded_entrypoint = utils.download_entry_point(
485
- source_entity,
486
- source_project,
487
- source_run_name,
488
- internal_api,
489
- program_name,
490
- self.project_dir,
491
- )
492
-
493
- if not downloaded_entrypoint:
494
- raise LaunchError(
495
- f"Entrypoint file: {program_name} does not exist, "
496
- "and could not be downloaded. Please specify the entrypoint for this run."
497
- )
498
-
499
- if (
500
- "_session_history.ipynb" in os.listdir(self.project_dir)
501
- or ".ipynb" in program_name
502
- ):
503
- program_name = utils.convert_jupyter_notebook_to_script(
504
- program_name, self.project_dir
505
- )
415
+ def get_env_vars_dict(self, api: Api, max_env_length: int) -> Dict[str, str]:
416
+ """Generate environment variables for the project.
506
417
 
507
- # Download any frozen requirements
508
- utils.download_wandb_python_deps(
509
- source_entity,
510
- source_project,
511
- source_run_name,
512
- internal_api,
513
- self.project_dir,
514
- )
418
+ Arguments:
419
+ launch_project: LaunchProject to generate environment variables for.
515
420
 
516
- if not self._entry_point:
517
- _, ext = os.path.splitext(program_name)
518
- if ext == ".py":
519
- entry_point = ["python", program_name]
520
- elif ext == ".sh":
521
- command = os.environ.get("SHELL", "bash")
522
- entry_point = [command, program_name]
523
- else:
524
- raise LaunchError(f"Unsupported entrypoint: {program_name}")
525
- self.set_entry_point(entry_point)
526
- if not self.override_args:
527
- self.override_args = run_info["args"]
528
- else:
529
- assert utils._GIT_URI_REGEX.match(self.uri), (
530
- "Non-wandb URI %s should be a Git URI" % self.uri
531
- )
532
- if not self._entry_point:
533
- wandb.termlog(
534
- f"{LOG_PREFIX}Entry point for repo not specified, defaulting to python main.py"
535
- )
536
- self.set_entry_point(EntrypointDefaults.PYTHON)
537
- branch_name = utils._fetch_git_repo(
538
- self.project_dir, self.uri, self.git_version
539
- )
540
- if self.git_version is None:
541
- self.git_version = branch_name
421
+ Returns:
422
+ Dictionary of environment variables.
423
+ """
424
+ env_vars = {}
425
+ env_vars["WANDB_BASE_URL"] = api.settings("base_url")
426
+ override_api_key = self.launch_spec.get("_wandb_api_key")
427
+ env_vars["WANDB_API_KEY"] = override_api_key or api.api_key
428
+ if self.target_project:
429
+ env_vars["WANDB_PROJECT"] = self.target_project
430
+ env_vars["WANDB_ENTITY"] = self.target_entity
431
+ env_vars["WANDB_LAUNCH"] = "True"
432
+ env_vars["WANDB_RUN_ID"] = self.run_id
433
+ if self.docker_image:
434
+ env_vars["WANDB_DOCKER"] = self.docker_image
435
+ if self.name is not None:
436
+ env_vars["WANDB_NAME"] = self.name
437
+ if "author" in self.launch_spec and not override_api_key:
438
+ env_vars["WANDB_USERNAME"] = self.launch_spec["author"]
439
+ if self.sweep_id:
440
+ env_vars["WANDB_SWEEP_ID"] = self.sweep_id
441
+ if self.launch_spec.get("_resume_count", 0) > 0:
442
+ env_vars["WANDB_RESUME"] = "allow"
443
+ if self.queue_name:
444
+ env_vars[wandb.env.LAUNCH_QUEUE_NAME] = self.queue_name
445
+ if self.queue_entity:
446
+ env_vars[wandb.env.LAUNCH_QUEUE_ENTITY] = self.queue_entity
447
+ if self.run_queue_item_id:
448
+ env_vars[wandb.env.LAUNCH_TRACE_ID] = self.run_queue_item_id
449
+
450
+ _inject_wandb_config_env_vars(self.override_config, env_vars, max_env_length)
451
+ _inject_file_overrides_env_vars(self.override_files, env_vars, max_env_length)
452
+
453
+ artifacts = {}
454
+ # if we're spinning up a launch process from a job
455
+ # we should tell the run to use that artifact
456
+ if self.job:
457
+ artifacts = {wandb.util.LAUNCH_JOB_ARTIFACT_SLOT_NAME: self.job}
458
+ env_vars["WANDB_ARTIFACTS"] = json.dumps(
459
+ {**artifacts, **self.override_artifacts}
460
+ )
461
+ return env_vars
462
+
463
+ def parse_existing_requirements(self) -> str:
464
+ import pkg_resources
465
+
466
+ requirements_line = ""
467
+ assert self.project_dir is not None
468
+ base_requirements = os.path.join(self.project_dir, "requirements.txt")
469
+ if os.path.exists(base_requirements):
470
+ include_only = set()
471
+ with open(base_requirements) as f:
472
+ iter = pkg_resources.parse_requirements(f)
473
+ while True:
474
+ try:
475
+ pkg = next(iter)
476
+ if hasattr(pkg, "name"):
477
+ name = pkg.name.lower()
478
+ else:
479
+ name = str(pkg)
480
+ include_only.add(shlex_quote(name))
481
+ except StopIteration:
482
+ break
483
+ # Different versions of pkg_resources throw different errors
484
+ # just catch them all and ignore packages we can't parse
485
+ except Exception as e:
486
+ _logger.warn(f"Unable to parse requirements.txt: {e}")
487
+ continue
488
+ requirements_line += "WANDB_ONLY_INCLUDE={} ".format(",".join(include_only))
489
+ if "wandb" not in requirements_line:
490
+ wandb.termwarn(f"{LOG_PREFIX}wandb is not present in requirements.txt.")
491
+ return requirements_line
542
492
 
543
493
 
544
494
  class EntryPoint:
@@ -548,13 +498,6 @@ class EntryPoint:
548
498
  self.name = name
549
499
  self.command = command
550
500
 
551
- def compute_command(self, user_parameters: Optional[List[str]]) -> List[str]:
552
- """Converts user parameter dictionary to a string."""
553
- ret = self.command
554
- if user_parameters:
555
- return ret + user_parameters
556
- return ret
557
-
558
501
  def update_entrypoint_path(self, new_path: str) -> None:
559
502
  """Updates the entrypoint path to a new path."""
560
503
  if len(self.command) == 2 and (
@@ -563,18 +506,35 @@ class EntryPoint:
563
506
  self.command[1] = new_path
564
507
 
565
508
 
566
- def get_entry_point_command(
567
- entry_point: Optional["EntryPoint"], parameters: List[str]
568
- ) -> List[str]:
569
- """Returns the shell command to execute in order to run the specified entry point.
570
-
571
- Arguments:
572
- entry_point: Entry point to run
573
- parameters: Parameters (dictionary) for the entry point command
574
-
575
- Returns:
576
- List of strings representing the shell command to be executed
577
- """
578
- if entry_point is None:
579
- return []
580
- return entry_point.compute_command(parameters)
509
+ def _inject_wandb_config_env_vars(
510
+ config: Dict[str, Any], env_dict: Dict[str, Any], maximum_env_length: int
511
+ ) -> None:
512
+ str_config = json.dumps(config)
513
+ if len(str_config) <= maximum_env_length:
514
+ env_dict["WANDB_CONFIG"] = str_config
515
+ return
516
+
517
+ chunks = [
518
+ str_config[i : i + maximum_env_length]
519
+ for i in range(0, len(str_config), maximum_env_length)
520
+ ]
521
+ config_chunks_dict = {f"WANDB_CONFIG_{i}": chunk for i, chunk in enumerate(chunks)}
522
+ env_dict.update(config_chunks_dict)
523
+
524
+
525
+ def _inject_file_overrides_env_vars(
526
+ overrides: Dict[str, Any], env_dict: Dict[str, Any], maximum_env_length: int
527
+ ) -> None:
528
+ str_overrides = json.dumps(overrides)
529
+ if len(str_overrides) <= maximum_env_length:
530
+ env_dict["WANDB_LAUNCH_FILE_OVERRIDES"] = str_overrides
531
+ return
532
+
533
+ chunks = [
534
+ str_overrides[i : i + maximum_env_length]
535
+ for i in range(0, len(str_overrides), maximum_env_length)
536
+ ]
537
+ overrides_chunks_dict = {
538
+ f"WANDB_LAUNCH_FILE_OVERRIDES_{i}": chunk for i, chunk in enumerate(chunks)
539
+ }
540
+ env_dict.update(overrides_chunks_dict)