wandb 0.16.5__py3-none-any.whl → 0.17.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (194) hide show
  1. package_readme.md +95 -0
  2. wandb/__init__.py +2 -3
  3. wandb/agents/pyagent.py +0 -1
  4. wandb/analytics/sentry.py +2 -1
  5. wandb/apis/importers/internals/internal.py +0 -1
  6. wandb/apis/importers/internals/protocols.py +30 -56
  7. wandb/apis/importers/mlflow.py +13 -26
  8. wandb/apis/importers/wandb.py +8 -14
  9. wandb/apis/internal.py +0 -3
  10. wandb/apis/public/api.py +55 -3
  11. wandb/apis/public/artifacts.py +1 -0
  12. wandb/apis/public/files.py +1 -0
  13. wandb/apis/public/history.py +1 -0
  14. wandb/apis/public/jobs.py +17 -4
  15. wandb/apis/public/projects.py +1 -0
  16. wandb/apis/public/reports.py +1 -0
  17. wandb/apis/public/runs.py +15 -17
  18. wandb/apis/public/sweeps.py +1 -0
  19. wandb/apis/public/teams.py +1 -0
  20. wandb/apis/public/users.py +1 -0
  21. wandb/apis/reports/v1/_blocks.py +3 -7
  22. wandb/apis/reports/v2/gql.py +1 -0
  23. wandb/apis/reports/v2/interface.py +3 -4
  24. wandb/apis/reports/v2/internal.py +5 -8
  25. wandb/cli/cli.py +95 -22
  26. wandb/data_types.py +9 -6
  27. wandb/docker/__init__.py +1 -1
  28. wandb/env.py +38 -8
  29. wandb/errors/__init__.py +5 -0
  30. wandb/errors/term.py +10 -2
  31. wandb/filesync/step_checksum.py +1 -4
  32. wandb/filesync/step_prepare.py +4 -24
  33. wandb/filesync/step_upload.py +4 -106
  34. wandb/filesync/upload_job.py +0 -76
  35. wandb/integration/catboost/catboost.py +1 -1
  36. wandb/integration/fastai/__init__.py +1 -0
  37. wandb/integration/huggingface/resolver.py +2 -2
  38. wandb/integration/keras/__init__.py +1 -0
  39. wandb/integration/keras/callbacks/metrics_logger.py +1 -1
  40. wandb/integration/keras/keras.py +7 -7
  41. wandb/integration/langchain/wandb_tracer.py +1 -0
  42. wandb/integration/lightning/fabric/logger.py +1 -3
  43. wandb/integration/metaflow/metaflow.py +41 -6
  44. wandb/integration/openai/fine_tuning.py +77 -40
  45. wandb/integration/prodigy/prodigy.py +1 -1
  46. wandb/old/summary.py +1 -1
  47. wandb/plot/confusion_matrix.py +1 -1
  48. wandb/plot/pr_curve.py +2 -1
  49. wandb/plot/roc_curve.py +2 -1
  50. wandb/{plots → plot}/utils.py +13 -25
  51. wandb/proto/v3/wandb_internal_pb2.py +364 -332
  52. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  53. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  54. wandb/proto/v4/wandb_internal_pb2.py +322 -316
  55. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  56. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  57. wandb/proto/wandb_deprecated.py +7 -1
  58. wandb/proto/wandb_internal_codegen.py +3 -29
  59. wandb/sdk/artifacts/artifact.py +51 -20
  60. wandb/sdk/artifacts/artifact_download_logger.py +1 -0
  61. wandb/sdk/artifacts/artifact_file_cache.py +18 -4
  62. wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
  63. wandb/sdk/artifacts/artifact_manifest.py +1 -0
  64. wandb/sdk/artifacts/artifact_manifest_entry.py +7 -3
  65. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  66. wandb/sdk/artifacts/artifact_saver.py +18 -27
  67. wandb/sdk/artifacts/artifact_state.py +1 -0
  68. wandb/sdk/artifacts/artifact_ttl.py +1 -0
  69. wandb/sdk/artifacts/exceptions.py +1 -0
  70. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  71. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
  72. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
  73. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
  74. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
  75. wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
  76. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
  77. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
  78. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
  79. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -42
  80. wandb/sdk/artifacts/storage_policy.py +2 -12
  81. wandb/sdk/data_types/_dtypes.py +8 -8
  82. wandb/sdk/data_types/base_types/media.py +3 -6
  83. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
  84. wandb/sdk/data_types/image.py +1 -1
  85. wandb/sdk/data_types/video.py +1 -1
  86. wandb/sdk/integration_utils/auto_logging.py +5 -6
  87. wandb/sdk/integration_utils/data_logging.py +10 -6
  88. wandb/sdk/interface/interface.py +86 -38
  89. wandb/sdk/interface/interface_shared.py +7 -13
  90. wandb/sdk/internal/datastore.py +1 -1
  91. wandb/sdk/internal/file_pusher.py +2 -5
  92. wandb/sdk/internal/file_stream.py +5 -18
  93. wandb/sdk/internal/handler.py +18 -2
  94. wandb/sdk/internal/internal.py +0 -1
  95. wandb/sdk/internal/internal_api.py +1 -129
  96. wandb/sdk/internal/internal_util.py +0 -1
  97. wandb/sdk/internal/job_builder.py +159 -45
  98. wandb/sdk/internal/profiler.py +1 -0
  99. wandb/sdk/internal/progress.py +0 -28
  100. wandb/sdk/internal/run.py +1 -0
  101. wandb/sdk/internal/sender.py +1 -2
  102. wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
  103. wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
  104. wandb/sdk/internal/system/assets/interfaces.py +6 -8
  105. wandb/sdk/internal/system/assets/open_metrics.py +2 -2
  106. wandb/sdk/internal/system/assets/trainium.py +1 -3
  107. wandb/sdk/launch/__init__.py +9 -1
  108. wandb/sdk/launch/_launch.py +9 -24
  109. wandb/sdk/launch/_launch_add.py +1 -3
  110. wandb/sdk/launch/_project_spec.py +188 -241
  111. wandb/sdk/launch/agent/agent.py +115 -48
  112. wandb/sdk/launch/agent/config.py +80 -14
  113. wandb/sdk/launch/builder/abstract.py +69 -1
  114. wandb/sdk/launch/builder/build.py +156 -555
  115. wandb/sdk/launch/builder/context_manager.py +235 -0
  116. wandb/sdk/launch/builder/docker_builder.py +8 -23
  117. wandb/sdk/launch/builder/kaniko_builder.py +161 -159
  118. wandb/sdk/launch/builder/noop.py +1 -0
  119. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  120. wandb/sdk/launch/create_job.py +68 -63
  121. wandb/sdk/launch/environment/abstract.py +1 -0
  122. wandb/sdk/launch/environment/gcp_environment.py +1 -0
  123. wandb/sdk/launch/environment/local_environment.py +1 -0
  124. wandb/sdk/launch/inputs/files.py +148 -0
  125. wandb/sdk/launch/inputs/internal.py +217 -0
  126. wandb/sdk/launch/inputs/manage.py +95 -0
  127. wandb/sdk/launch/loader.py +1 -0
  128. wandb/sdk/launch/registry/abstract.py +1 -0
  129. wandb/sdk/launch/registry/azure_container_registry.py +1 -0
  130. wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
  131. wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
  132. wandb/sdk/launch/registry/local_registry.py +1 -0
  133. wandb/sdk/launch/runner/abstract.py +1 -0
  134. wandb/sdk/launch/runner/kubernetes_monitor.py +4 -1
  135. wandb/sdk/launch/runner/kubernetes_runner.py +9 -10
  136. wandb/sdk/launch/runner/local_container.py +2 -3
  137. wandb/sdk/launch/runner/local_process.py +8 -29
  138. wandb/sdk/launch/runner/sagemaker_runner.py +21 -20
  139. wandb/sdk/launch/runner/vertex_runner.py +8 -7
  140. wandb/sdk/launch/sweeps/scheduler.py +7 -4
  141. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  142. wandb/sdk/launch/sweeps/utils.py +3 -3
  143. wandb/sdk/launch/utils.py +33 -140
  144. wandb/sdk/lib/_settings_toposort_generated.py +1 -5
  145. wandb/sdk/lib/fsm.py +8 -12
  146. wandb/sdk/lib/gitlib.py +4 -4
  147. wandb/sdk/lib/import_hooks.py +1 -1
  148. wandb/sdk/lib/lazyloader.py +0 -1
  149. wandb/sdk/lib/proto_util.py +23 -2
  150. wandb/sdk/lib/redirect.py +19 -14
  151. wandb/sdk/lib/retry.py +3 -2
  152. wandb/sdk/lib/run_moment.py +7 -1
  153. wandb/sdk/lib/tracelog.py +1 -1
  154. wandb/sdk/service/service.py +19 -16
  155. wandb/sdk/verify/verify.py +2 -1
  156. wandb/sdk/wandb_init.py +16 -63
  157. wandb/sdk/wandb_manager.py +2 -2
  158. wandb/sdk/wandb_require.py +5 -0
  159. wandb/sdk/wandb_run.py +164 -90
  160. wandb/sdk/wandb_settings.py +2 -48
  161. wandb/sdk/wandb_setup.py +1 -1
  162. wandb/sklearn/__init__.py +1 -0
  163. wandb/sklearn/plot/__init__.py +1 -0
  164. wandb/sklearn/plot/classifier.py +11 -12
  165. wandb/sklearn/plot/clusterer.py +2 -1
  166. wandb/sklearn/plot/regressor.py +1 -0
  167. wandb/sklearn/plot/shared.py +1 -0
  168. wandb/sklearn/utils.py +1 -0
  169. wandb/testing/relay.py +4 -4
  170. wandb/trigger.py +1 -0
  171. wandb/util.py +67 -54
  172. wandb/wandb_controller.py +2 -3
  173. wandb/wandb_torch.py +1 -2
  174. {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/METADATA +67 -70
  175. {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/RECORD +178 -188
  176. {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/WHEEL +1 -2
  177. wandb/bin/apple_gpu_stats +0 -0
  178. wandb/catboost/__init__.py +0 -9
  179. wandb/fastai/__init__.py +0 -9
  180. wandb/keras/__init__.py +0 -18
  181. wandb/lightgbm/__init__.py +0 -9
  182. wandb/plots/__init__.py +0 -6
  183. wandb/plots/explain_text.py +0 -36
  184. wandb/plots/heatmap.py +0 -81
  185. wandb/plots/named_entity.py +0 -43
  186. wandb/plots/part_of_speech.py +0 -50
  187. wandb/plots/plot_definitions.py +0 -768
  188. wandb/plots/precision_recall.py +0 -121
  189. wandb/plots/roc.py +0 -103
  190. wandb/sacred/__init__.py +0 -3
  191. wandb/xgboost/__init__.py +0 -9
  192. wandb-0.16.5.dist-info/top_level.txt +0 -1
  193. {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/entry_points.txt +0 -0
  194. {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info/licenses}/LICENSE +0 -0
@@ -1,6 +1,7 @@
1
1
  import json
2
2
  import logging
3
3
  import os
4
+ import re
4
5
  import sys
5
6
  import tempfile
6
7
  from typing import Any, Dict, List, Optional, Tuple
@@ -9,9 +10,12 @@ import wandb
9
10
  from wandb.apis.internal import Api
10
11
  from wandb.sdk.artifacts.artifact import Artifact
11
12
  from wandb.sdk.internal.job_builder import JobBuilder
12
- from wandb.sdk.launch.builder.build import get_current_python_version
13
13
  from wandb.sdk.launch.git_reference import GitReference
14
- from wandb.sdk.launch.utils import _is_git_uri
14
+ from wandb.sdk.launch.utils import (
15
+ _is_git_uri,
16
+ get_current_python_version,
17
+ get_entrypoint_file,
18
+ )
15
19
  from wandb.sdk.lib import filesystem
16
20
  from wandb.util import make_artifact_name_safe
17
21
 
@@ -19,6 +23,9 @@ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
19
23
  _logger = logging.getLogger("wandb")
20
24
 
21
25
 
26
+ CODE_ARTIFACT_EXCLUDE_PATHS = ["wandb", ".git"]
27
+
28
+
22
29
  def create_job(
23
30
  path: str,
24
31
  job_type: str,
@@ -30,6 +37,8 @@ def create_job(
30
37
  runtime: Optional[str] = None,
31
38
  entrypoint: Optional[str] = None,
32
39
  git_hash: Optional[str] = None,
40
+ build_context: Optional[str] = None,
41
+ dockerfile: Optional[str] = None,
33
42
  ) -> Optional[Artifact]:
34
43
  """Create a job from a path, not as the output of a run.
35
44
 
@@ -42,9 +51,12 @@ def create_job(
42
51
  description (Optional[str]): Description of the job.
43
52
  aliases (Optional[List[str]]): Aliases for the job.
44
53
  runtime (Optional[str]): Python runtime of the job, like 3.9.
45
- entrypoint (Optional[str]): Entrypoint of the job.
54
+ entrypoint (Optional[str]): Entrypoint of the job. If build_context is
55
+ provided, path is relative to build_context.
46
56
  git_hash (Optional[str]): Git hash of a specific commit, when using git type jobs.
47
-
57
+ build_context (Optional[str]): Path to the build context, when using image type jobs.
58
+ dockerfile (Optional[str]): Path to the Dockerfile, when using image type jobs.
59
+ If build_context is provided, path is relative to build_context.
48
60
 
49
61
  Returns:
50
62
  Optional[Artifact]: The artifact created by the job, the action (for printing), and job aliases.
@@ -81,6 +93,8 @@ def create_job(
81
93
  runtime,
82
94
  entrypoint,
83
95
  git_hash,
96
+ build_context,
97
+ dockerfile,
84
98
  )
85
99
 
86
100
  return artifact_job
@@ -98,6 +112,8 @@ def _create_job(
98
112
  runtime: Optional[str] = None,
99
113
  entrypoint: Optional[str] = None,
100
114
  git_hash: Optional[str] = None,
115
+ build_context: Optional[str] = None,
116
+ dockerfile: Optional[str] = None,
101
117
  ) -> Tuple[Optional[Artifact], str, List[str]]:
102
118
  wandb.termlog(f"Creating launch job of type: {job_type}...")
103
119
 
@@ -107,6 +123,13 @@ def _create_job(
107
123
  )
108
124
  return None, "", []
109
125
 
126
+ if runtime is not None:
127
+ if not re.match(r"^3\.\d+$", runtime):
128
+ wandb.termerror(
129
+ f"Runtime (-r, --runtime) must be a minor version of Python 3, "
130
+ f"e.g. 3.9 or 3.10, received {runtime}"
131
+ )
132
+ return None, "", []
110
133
  aliases = aliases or []
111
134
  tempdir = tempfile.TemporaryDirectory()
112
135
  try:
@@ -145,6 +168,7 @@ def _create_job(
145
168
 
146
169
  job_builder = _configure_job_builder_for_partial(tempdir.name, job_source=job_type)
147
170
  if job_type == "code":
171
+ assert entrypoint is not None
148
172
  job_name = _make_code_artifact(
149
173
  api=api,
150
174
  job_builder=job_builder,
@@ -160,7 +184,10 @@ def _create_job(
160
184
  name = job_name
161
185
 
162
186
  # build job artifact, loads wandb-metadata and creates wandb-job.json here
163
- artifact = job_builder.build()
187
+ artifact = job_builder.build(
188
+ dockerfile=dockerfile,
189
+ build_context=build_context,
190
+ )
164
191
  if not artifact:
165
192
  wandb.termerror("JobBuilder failed to build a job")
166
193
  _logger.debug("Failed to build job, check job source and metadata")
@@ -219,6 +246,7 @@ def _make_metadata_for_partial_job(
219
246
  """Create metadata for partial jobs, return metadata and requirements."""
220
247
  metadata = {"_partial": "v0"}
221
248
  if job_type == "git":
249
+ assert entrypoint is not None
222
250
  repo_metadata = _create_repo_metadata(
223
251
  path=path,
224
252
  tempdir=tempdir.name,
@@ -233,13 +261,7 @@ def _make_metadata_for_partial_job(
233
261
  return metadata, None
234
262
 
235
263
  if job_type == "code":
236
- path, entrypoint = _handle_artifact_entrypoint(path, entrypoint)
237
- if not entrypoint:
238
- wandb.termerror(
239
- "Artifact jobs must have an entrypoint, either included in the path or specified with -E"
240
- )
241
- return None, None
242
-
264
+ assert entrypoint is not None
243
265
  artifact_metadata, requirements = _create_artifact_metadata(
244
266
  path=path, entrypoint=entrypoint, runtime=runtime
245
267
  )
@@ -268,7 +290,7 @@ def _make_metadata_for_partial_job(
268
290
  def _create_repo_metadata(
269
291
  path: str,
270
292
  tempdir: str,
271
- entrypoint: Optional[str] = None,
293
+ entrypoint: str,
272
294
  git_hash: Optional[str] = None,
273
295
  runtime: Optional[str] = None,
274
296
  ) -> Optional[Dict[str, Any]]:
@@ -304,25 +326,16 @@ def _create_repo_metadata(
304
326
  with open(os.path.join(local_dir, ".python-version")) as f:
305
327
  python_version = f.read().strip().splitlines()[0]
306
328
  else:
307
- major, minor = get_current_python_version()
308
- python_version = f"{major}.{minor}"
329
+ _, python_version = get_current_python_version()
309
330
 
310
331
  python_version = _clean_python_version(python_version)
311
332
 
312
- # check if entrypoint is valid
313
- assert entrypoint is not None
314
- if not os.path.exists(os.path.join(local_dir, entrypoint)):
315
- wandb.termerror(f"Entrypoint {entrypoint} not found in git repo")
316
- return None
317
-
318
333
  metadata = {
319
334
  "git": {
320
335
  "commit": commit,
321
336
  "remote": ref.url,
322
337
  },
323
- "codePathLocal": entrypoint, # not in git context, optionally also set local
324
- "codePath": entrypoint,
325
- "entrypoint": [f"python{python_version}", entrypoint],
338
+ "entrypoint": entrypoint.split(" "),
326
339
  "python": python_version, # used to build container
327
340
  "notebook": False, # partial jobs from notebooks not supported
328
341
  }
@@ -332,10 +345,17 @@ def _create_repo_metadata(
332
345
 
333
346
  def _create_artifact_metadata(
334
347
  path: str, entrypoint: str, runtime: Optional[str] = None
335
- ) -> Tuple[Dict[str, Any], List[str]]:
348
+ ) -> Tuple[Optional[Dict[str, Any]], Optional[List[str]]]:
336
349
  if not os.path.isdir(path):
337
350
  wandb.termerror("Path must be a valid file or directory")
338
351
  return {}, []
352
+ entrypoint_list = entrypoint.split(" ")
353
+ entrypoint_file = get_entrypoint_file(entrypoint_list)
354
+ if not entrypoint_file:
355
+ wandb.termerror(
356
+ f"Entrypoint {entrypoint} is invalid. An entrypoint should include both an executable and a file, for example 'python train.py'"
357
+ )
358
+ return None, None
339
359
 
340
360
  # read local requirements.txt and dump to temp dir for builder
341
361
  requirements = []
@@ -347,41 +367,17 @@ def _create_artifact_metadata(
347
367
  if runtime:
348
368
  python_version = _clean_python_version(runtime)
349
369
  else:
350
- python_version = ".".join(get_current_python_version())
370
+ python_version, _ = get_current_python_version()
371
+ python_version = _clean_python_version(python_version)
351
372
 
352
- metadata = {"python": python_version, "codePath": entrypoint}
373
+ metadata = {
374
+ "python": python_version,
375
+ "codePath": entrypoint_file,
376
+ "entrypoint": entrypoint_list,
377
+ }
353
378
  return metadata, requirements
354
379
 
355
380
 
356
- def _handle_artifact_entrypoint(
357
- path: str, entrypoint: Optional[str] = None
358
- ) -> Tuple[str, Optional[str]]:
359
- if os.path.isfile(path):
360
- if entrypoint and path.endswith(entrypoint):
361
- path = path.replace(entrypoint, "")
362
- wandb.termwarn(
363
- f"Both entrypoint provided and path contains file. Using provided entrypoint: {entrypoint}, path is now: {path}"
364
- )
365
- elif entrypoint:
366
- wandb.termwarn(
367
- f"Ignoring passed in entrypoint as it does not match file path found in 'path'. Path entrypoint: {path.split('/')[-1]}"
368
- )
369
- entrypoint = path.split("/")[-1]
370
- path = "/".join(path.split("/")[:-1])
371
- elif not entrypoint:
372
- wandb.termerror("Entrypoint not valid")
373
- return "", None
374
- path = path or "." # when path is just an entrypoint, use cdw
375
-
376
- if not os.path.exists(os.path.join(path, entrypoint)):
377
- wandb.termerror(
378
- f"Could not find execution point: {os.path.join(path, entrypoint)}"
379
- )
380
- return "", None
381
-
382
- return path, entrypoint
383
-
384
-
385
381
  def _configure_job_builder_for_partial(tmpdir: str, job_source: str) -> JobBuilder:
386
382
  """Configure job builder with temp dir and job source."""
387
383
  # adjust git source to repo
@@ -411,7 +407,7 @@ def _make_code_artifact(
411
407
  job_builder: JobBuilder,
412
408
  run: "wandb.sdk.wandb_run.Run",
413
409
  path: str,
414
- entrypoint: Optional[str],
410
+ entrypoint: str,
415
411
  entity: Optional[str],
416
412
  project: Optional[str],
417
413
  name: Optional[str],
@@ -420,17 +416,19 @@ def _make_code_artifact(
420
416
 
421
417
  Returns the name of the eventual job.
422
418
  """
423
- artifact_name = _make_code_artifact_name(os.path.join(path, entrypoint or ""), name)
419
+ entrypoint_list = entrypoint.split(" ")
420
+ # We no longer require the entrypoint to end in an existing file. But we
421
+ # need something to use as the default job artifact name. In the future we
422
+ # may require the user to provide a job name explicitly when calling
423
+ # wandb job create.
424
+ entrypoint_file = entrypoint_list[-1]
425
+ artifact_name = _make_code_artifact_name(os.path.join(path, entrypoint_file), name)
424
426
  code_artifact = wandb.Artifact(
425
427
  name=artifact_name,
426
428
  type="code",
427
429
  description="Code artifact for job",
428
430
  )
429
431
 
430
- # Update path and entrypoint vars to match metadata
431
- # TODO(gst): consolidate into one place
432
- path, entrypoint = _handle_artifact_entrypoint(path, entrypoint)
433
-
434
432
  try:
435
433
  code_artifact.add_dir(path)
436
434
  except Exception as e:
@@ -441,6 +439,13 @@ def _make_code_artifact(
441
439
  wandb.termerror(f"Error adding to code artifact: {e}")
442
440
  return None
443
441
 
442
+ # Remove paths we don't want to include, if present
443
+ for item in CODE_ARTIFACT_EXCLUDE_PATHS:
444
+ try:
445
+ code_artifact.remove(item)
446
+ except FileNotFoundError:
447
+ pass
448
+
444
449
  res, _ = api.create_artifact(
445
450
  artifact_type_name="code",
446
451
  artifact_collection_name=artifact_name,
@@ -451,7 +456,7 @@ def _make_code_artifact(
451
456
  project_name=project,
452
457
  run_name=run.id, # run will be deleted after creation
453
458
  description="Code artifact for job",
454
- metadata={"codePath": path, "entrypoint": entrypoint},
459
+ metadata={"codePath": path, "entrypoint": entrypoint_file},
455
460
  is_user_created=True,
456
461
  aliases=[
457
462
  {"artifactCollectionName": artifact_name, "alias": a} for a in ["latest"]
@@ -1,4 +1,5 @@
1
1
  """Abstract base class for environments."""
2
+
2
3
  from abc import ABC, abstractmethod
3
4
 
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Implementation of the GCP environment for wandb launch."""
2
+
2
3
  import logging
3
4
  import os
4
5
  import subprocess
@@ -1,4 +1,5 @@
1
1
  """Dummy local environment implementation. This is the default environment."""
2
+
2
3
  from typing import Any, Dict, Union
3
4
 
4
5
  from wandb.sdk.launch.errors import LaunchError
@@ -0,0 +1,148 @@
1
+ import json
2
+ import os
3
+ from typing import Any, Dict
4
+
5
+ import yaml
6
+
7
+ from ..errors import LaunchError
8
+
9
+ FILE_OVERRIDE_ENV_VAR = "WANDB_LAUNCH_FILE_OVERRIDES"
10
+
11
+
12
class FileOverrides:
    """Singleton that reads file-override JSON from environment variables.

    The overrides are read from the variable named by
    ``FILE_OVERRIDE_ENV_VAR``. When that variable is unset, the value is
    reassembled from the numbered chunks ``<name>_0``, ``<name>_1``, ...,
    concatenated in index order up to the first missing index.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = object.__new__(cls)
            instance.overrides = {}
            instance.load()
            # Publish the singleton only after a successful load, so that a
            # failed load is reported again on the next access instead of
            # being cached as an instance with empty overrides.
            cls._instance = instance
        return cls._instance

    def load(self) -> None:
        """Load overrides from the environment into ``self.overrides``.

        Nothing is changed when no override value is present or the value is
        an empty string.

        Raises:
            LaunchError: If the value is not valid JSON, or is valid JSON that
                does not decode to a dict.
        """
        overrides = os.environ.get(FILE_OVERRIDE_ENV_VAR)
        if overrides is None:
            # Fall back to the chunked form of the variable.
            chunks = []
            idx = 0
            while f"{FILE_OVERRIDE_ENV_VAR}_{idx}" in os.environ:
                chunks.append(os.environ[f"{FILE_OVERRIDE_ENV_VAR}_{idx}"])
                idx += 1
            overrides = "".join(chunks)
        if not overrides:
            return
        try:
            contents = json.loads(overrides)
        except json.JSONDecodeError as e:
            raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}") from e
        if not isinstance(contents, dict):
            raise LaunchError(f"Invalid JSON in {FILE_OVERRIDE_ENV_VAR}")
        self.overrides = contents
42
+
43
+
44
def config_path_is_valid(path: str) -> None:
    """Validate a config file path.

    This function checks if a given config file path is valid. A valid path
    should meet the following criteria:

    - The path must be expressed as a relative path without any upwards path
      traversal, e.g. `../config.json`.
    - The file specified by the path must exist.
    - The file must have a supported extension (`.json`, `.yaml`, or `.yml`).

    Args:
        path (str): The path to validate.

    Raises:
        LaunchError: If the path is not valid.
    """
    if os.path.isabs(path):
        raise LaunchError(
            f"Invalid config path: {path}. Please provide a relative path."
        )
    # Compare whole path components (accepting either separator) so that a
    # file name that merely contains consecutive dots, such as
    # `my..config.json`, is not mistaken for upward traversal.
    if ".." in path.replace("\\", "/").split("/"):
        raise LaunchError(
            f"Invalid config path: {path}. Please provide a relative path "
            "without any upward path traversal, e.g. `../config.json`."
        )
    path = os.path.normpath(path)
    if not os.path.exists(path):
        raise LaunchError(f"Invalid config path: {path}. File does not exist.")
    if not path.endswith((".json", ".yaml", ".yml")):
        raise LaunchError(
            f"Invalid config path: {path}. Only JSON and YAML files are supported."
        )
77
+
78
+
79
def override_file(path: str) -> None:
    """Check for file overrides in the environment and apply them if found.

    If the `FileOverrides` singleton holds an entry for `path`, the config
    file at `path` is read, the entry is merged into its contents, and the
    merged result is written back to the same path. Otherwise the file is
    left untouched.

    Args:
        path (str): The path to the config file, used both as the lookup key
            in the overrides and as the file to rewrite.
    """
    overrides = FileOverrides().overrides.get(path)
    if overrides is None:
        return
    config = _read_config_file(path)
    if config is None:
        # An empty YAML document parses to None; start from an empty mapping
        # so the overrides have something to merge into.
        config = {}
    _update_dict(config, overrides)
    _write_config_file(path, config)
88
+
89
+
90
+ def _write_config_file(path: str, config: Any) -> None:
91
+ """Write a config file to disk.
92
+
93
+ Args:
94
+ path (str): The path to the config file.
95
+ config (Any): The contents of the config file as a Python object.
96
+
97
+ Raises:
98
+ LaunchError: If the file extension is not supported.
99
+ """
100
+ _, ext = os.path.splitext(path)
101
+ if ext == ".json":
102
+ with open(path, "w") as f:
103
+ json.dump(config, f, indent=2)
104
+ elif ext in [".yaml", ".yml"]:
105
+ with open(path, "w") as f:
106
+ yaml.safe_dump(config, f)
107
+ else:
108
+ raise LaunchError(f"Unsupported file extension: {ext}")
109
+
110
+
111
+ def _read_config_file(path: str) -> Any:
112
+ """Read a config file from disk.
113
+
114
+ Args:
115
+ path (str): The path to the config file.
116
+
117
+ Returns:
118
+ Any: The contents of the config file as a Python object.
119
+ """
120
+ _, ext = os.path.splitext(path)
121
+ if ext == ".json":
122
+ with open(
123
+ path,
124
+ ) as f:
125
+ return json.load(f)
126
+ elif ext in [".yaml", ".yml"]:
127
+ with open(
128
+ path,
129
+ ) as f:
130
+ return yaml.safe_load(f)
131
+ else:
132
+ raise LaunchError(f"Unsupported file extension: {ext}")
133
+
134
+
135
+ def _update_dict(target: Dict, source: Dict) -> None:
136
+ """Update a dictionary with the contents of another dictionary.
137
+
138
+ Args:
139
+ target (Dict): The dictionary to update.
140
+ source (Dict): The dictionary to update from.
141
+ """
142
+ for key, value in source.items():
143
+ if isinstance(value, dict):
144
+ if key not in target:
145
+ target[key] = {}
146
+ _update_dict(target[key], value)
147
+ else:
148
+ target[key] = value
@@ -0,0 +1,217 @@
1
+ """The layer between launch sdk user code and the wandb internal process.
2
+
3
+ If there is an active run this communication is done through the wandb run's
4
+ backend interface.
5
+
6
+ If there is no active run, the messages are staged on the StagedLaunchInputs
7
+ singleton and sent when a run is created.
8
+ """
9
+
10
+ import os
11
+ import pathlib
12
+ import shutil
13
+ import tempfile
14
+ from typing import List, Optional
15
+
16
+ import wandb
17
+ import wandb.data_types
18
+ from wandb.sdk.launch.errors import LaunchError
19
+ from wandb.sdk.wandb_run import Run
20
+
21
+ from .files import config_path_is_valid, override_file
22
+
23
+ PERIOD = "."
24
+ BACKSLASH = "\\"
25
+
26
+
27
class ConfigTmpDir:
    """Singleton for managing temporary directories for configuration files.

    Any configuration files designated as inputs to a launch job are copied to
    a temporary directory. This singleton manages the temporary directory and
    provides paths to the configuration files.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = object.__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every ConfigTmpDir() call, including calls that
        # return the existing singleton, so only create the directories once.
        if not hasattr(self, "_tmp_dir"):
            self._tmp_dir = tempfile.mkdtemp()
            self._configs_dir = os.path.join(self._tmp_dir, "configs")
            os.mkdir(self._configs_dir)

    @property
    def tmp_dir(self) -> pathlib.Path:
        """The root temporary directory."""
        return pathlib.Path(self._tmp_dir)

    @property
    def configs_dir(self) -> pathlib.Path:
        """The `configs` sub-directory of the temporary directory."""
        return pathlib.Path(self._configs_dir)
55
+
56
+
57
class JobInputArguments:
    """Arguments for the publish_job_input of Interface.

    Attributes are stored exactly as passed:

    - include: paths to include, or None.
    - exclude: paths to exclude, or None.
    - file_path: path of the config file this input refers to, or None.
    - run_config: boolean flag for run-config inputs, or None when unset.
    """

    def __init__(
        self,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        file_path: Optional[str] = None,
        run_config: Optional[bool] = None,
    ):
        self.include = include
        self.exclude = exclude
        self.file_path = file_path
        self.run_config = run_config

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(include={self.include!r}, "
            f"exclude={self.exclude!r}, file_path={self.file_path!r}, "
            f"run_config={self.run_config!r})"
        )
71
+
72
+
73
class StagedLaunchInputs:
    """Singleton that collects job inputs declared while no run is active.

    Inputs are appended with `add_staged_input` and later published to a run
    with `apply`.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = object.__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every StagedLaunchInputs() call; keep the inputs
        # that were already staged on the shared instance.
        if not hasattr(self, "_staged_inputs"):
            self._staged_inputs: List[JobInputArguments] = []

    def add_staged_input(
        self,
        input_arguments: JobInputArguments,
    ) -> None:
        """Stage a job input to be published once a run is available."""
        self._staged_inputs.append(input_arguments)

    def apply(self, run: Run) -> None:
        """Apply the staged inputs to the given run."""
        for staged in self._staged_inputs:
            _publish_job_input(staged, run)
95
+
96
+
97
def _publish_job_input(
    input: JobInputArguments,
    run: Run,
) -> None:
    """Publish a job input to the backend interface of the given run.

    When the input refers to a config file, the copy of that file under the
    `ConfigTmpDir` configs directory is first passed to `run.save`, with the
    temporary directory root as `base_path`. The include and exclude paths
    are each split on unescaped dots before being published.

    Arguments:
        input (JobInputArguments): The arguments for the job input.
        run (Run): The run to publish the job input to.
    """
    # Internal invariants: the run must have a live backend interface and the
    # caller must have set run_config explicitly.
    assert run._backend is not None
    assert run._backend.interface is not None
    assert input.run_config is not None

    interface = run._backend.interface
    if input.file_path:
        config_dir = ConfigTmpDir()
        dest = os.path.join(config_dir.configs_dir, input.file_path)
        run.save(dest, base_path=config_dir.tmp_dir)
    interface.publish_job_input(
        # None and empty lists both publish as an empty list of paths.
        include_paths=[_split_on_unesc_dot(path) for path in input.include or []],
        exclude_paths=[_split_on_unesc_dot(path) for path in input.exclude or []],
        run_config=input.run_config,
        file_path=input.file_path or "",
    )
126
+
127
+
128
def handle_config_file_input(
    path: str,
    include: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
) -> None:
    """Declare an overridable configuration file for a launch job.

    The path is validated, any overrides found in the environment are applied
    to the file, and the file is then copied to a temporary directory under
    the same relative path. The input is sent to the backend interface of the
    active run.

    If there is no active run, the input is staged and sent when a run is
    created.

    Args:
        path (str): Relative path to the configuration file.
        include (Optional[List[str]]): Paths to include, passed through
            unchanged.
        exclude (Optional[List[str]]): Paths to exclude, passed through
            unchanged.

    Raises:
        LaunchError: If `path` fails validation.
    """
    config_path_is_valid(path)
    override_file(path)
    tmp_dir = ConfigTmpDir()
    dest = os.path.join(tmp_dir.configs_dir, path)
    # `path` may contain sub-directories (e.g. `configs/train.yaml`); create
    # the matching parents under the temp dir or the copy below would fail.
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    shutil.copy(path, dest)
    arguments = JobInputArguments(
        include=include,
        exclude=exclude,
        file_path=path,
        run_config=False,
    )
    if wandb.run is not None:
        _publish_job_input(arguments, wandb.run)
    else:
        staged_inputs = StagedLaunchInputs()
        staged_inputs.add_staged_input(arguments)
158
+
159
+
160
def handle_run_config_input(
    include: Optional[List[str]] = None, exclude: Optional[List[str]] = None
) -> None:
    """Declare wandb.config as an overridable configuration for a launch job.

    The include and exclude paths are sent to the backend interface of the
    active run and used to configure the job builder.

    If there is no active run, the include and exclude paths are staged and sent
    when a run is created.

    Args:
        include (Optional[List[str]]): Paths to include, passed through
            unchanged.
        exclude (Optional[List[str]]): Paths to exclude, passed through
            unchanged.
    """
    arguments = JobInputArguments(
        include=include,
        exclude=exclude,
        run_config=True,
        file_path=None,
    )
    if wandb.run is not None:
        _publish_job_input(arguments, wandb.run)
    else:
        staged_inputs = StagedLaunchInputs()
        staged_inputs.add_staged_input(arguments)
182
+
183
+
184
+ def _split_on_unesc_dot(path: str) -> List[str]:
185
+ r"""Split a string on unescaped dots.
186
+
187
+ Arguments:
188
+ path (str): The string to split.
189
+
190
+ Raises:
191
+ ValueError: If the path has a trailing escape character.
192
+
193
+ Returns:
194
+ List[str]: The split string.
195
+ """
196
+ parts = []
197
+ part = ""
198
+ i = 0
199
+ while i < len(path):
200
+ if path[i] == BACKSLASH:
201
+ if i == len(path) - 1:
202
+ raise LaunchError(
203
+ f"Invalid config path {path}: trailing {BACKSLASH}.",
204
+ )
205
+ if path[i + 1] == PERIOD:
206
+ part += PERIOD
207
+ i += 2
208
+ elif path[i] == PERIOD:
209
+ parts.append(part)
210
+ part = ""
211
+ i += 1
212
+ else:
213
+ part += path[i]
214
+ i += 1
215
+ if part:
216
+ parts.append(part)
217
+ return parts