wandb-0.16.4-py3-none-any.whl → wandb-0.16.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47):
  1. wandb/__init__.py +2 -2
  2. wandb/agents/pyagent.py +1 -1
  3. wandb/apis/public/api.py +6 -6
  4. wandb/apis/reports/v2/interface.py +4 -8
  5. wandb/apis/reports/v2/internal.py +12 -45
  6. wandb/cli/cli.py +24 -3
  7. wandb/integration/ultralytics/callback.py +0 -1
  8. wandb/proto/v3/wandb_internal_pb2.py +332 -312
  9. wandb/proto/v3/wandb_settings_pb2.py +13 -3
  10. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  11. wandb/proto/v4/wandb_internal_pb2.py +316 -312
  12. wandb/proto/v4/wandb_settings_pb2.py +5 -3
  13. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  14. wandb/sdk/artifacts/artifact.py +67 -17
  15. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
  16. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  17. wandb/sdk/artifacts/artifact_saver.py +1 -18
  18. wandb/sdk/artifacts/storage_handler.py +2 -1
  19. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +13 -5
  20. wandb/sdk/interface/interface.py +42 -9
  21. wandb/sdk/interface/interface_shared.py +13 -7
  22. wandb/sdk/internal/file_stream.py +19 -0
  23. wandb/sdk/internal/handler.py +1 -4
  24. wandb/sdk/internal/internal_api.py +2 -0
  25. wandb/sdk/internal/job_builder.py +45 -17
  26. wandb/sdk/internal/sender.py +53 -28
  27. wandb/sdk/internal/settings_static.py +9 -0
  28. wandb/sdk/internal/system/system_info.py +4 -1
  29. wandb/sdk/launch/create_job.py +1 -0
  30. wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
  31. wandb/sdk/launch/utils.py +5 -5
  32. wandb/sdk/lib/__init__.py +2 -5
  33. wandb/sdk/lib/_settings_toposort_generated.py +1 -0
  34. wandb/sdk/lib/filesystem.py +11 -1
  35. wandb/sdk/lib/run_moment.py +72 -0
  36. wandb/sdk/service/streams.py +1 -6
  37. wandb/sdk/wandb_init.py +12 -1
  38. wandb/sdk/wandb_login.py +43 -26
  39. wandb/sdk/wandb_run.py +158 -89
  40. wandb/sdk/wandb_settings.py +53 -16
  41. wandb/testing/relay.py +5 -6
  42. {wandb-0.16.4.dist-info → wandb-0.16.5.dist-info}/METADATA +1 -1
  43. {wandb-0.16.4.dist-info → wandb-0.16.5.dist-info}/RECORD +47 -46
  44. {wandb-0.16.4.dist-info → wandb-0.16.5.dist-info}/WHEEL +1 -1
  45. {wandb-0.16.4.dist-info → wandb-0.16.5.dist-info}/LICENSE +0 -0
  46. {wandb-0.16.4.dist-info → wandb-0.16.5.dist-info}/entry_points.txt +0 -0
  47. {wandb-0.16.4.dist-info → wandb-0.16.5.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  import base64
2
2
  import functools
3
3
  import itertools
4
+ import json
4
5
  import logging
5
6
  import os
6
7
  import queue
@@ -58,6 +59,7 @@ class Chunk(NamedTuple):
58
59
  class DefaultFilePolicy:
59
60
  def __init__(self, start_chunk_id: int = 0) -> None:
60
61
  self._chunk_id = start_chunk_id
62
+ self.has_debug_log = False
61
63
 
62
64
  def process_chunks(
63
65
  self, chunks: List[Chunk]
@@ -66,6 +68,21 @@ class DefaultFilePolicy:
66
68
  self._chunk_id += len(chunks)
67
69
  return {"offset": chunk_id, "content": [c.data for c in chunks]}
68
70
 
71
# TODO: this is very inefficient, this is meant for temporary debugging and will be removed in future releases
def _debug_log(self, data: Any):
    """Print, once, a per-key size breakdown of an oversized JSON payload.

    Opt-in debugging aid: it does nothing unless the
    ``WANDB_DEBUG_FILESTREAM_LOG`` environment variable is set. It reports at
    most once per policy instance, tracked via ``self.has_debug_log``.

    Args:
        data: A JSON-encoded payload (a history line or the summary),
            as handed over by the chunk-processing code.

    This helper never raises for malformed or unexpected input. It is invoked
    on the "payload too large" error path, and a debugging aid must not take
    down file streaming.
    """
    if self.has_debug_log or not os.environ.get("WANDB_DEBUG_FILESTREAM_LOG"):
        return

    try:
        loaded = json.loads(data)
    except (TypeError, ValueError):
        # Not str/bytes, or not valid JSON: there is nothing meaningful to report.
        return
    if not isinstance(loaded, dict):
        return

    # Size of each top-level value once re-serialized, converted to MB.
    key_sizes = [(k, len(json.dumps(v))) for k, v in loaded.items()]
    key_msg = [f"{k}: {v/1048576:.5f} MB" for k, v in key_sizes]
    # Not every payload carries `_step`; report None instead of raising KeyError.
    wandb.termerror(f"Step: {loaded.get('_step')} | {key_msg}", repeat=False)
    self.has_debug_log = True
85
+
69
86
 
70
87
  class JsonlFilePolicy(DefaultFilePolicy):
71
88
  def process_chunks(self, chunks: List[Chunk]) -> "ProcessedChunk":
@@ -81,6 +98,7 @@ class JsonlFilePolicy(DefaultFilePolicy):
81
98
  )
82
99
  wandb.termerror(msg, repeat=False)
83
100
  wandb._sentry.message(msg, repeat=False)
101
+ self._debug_log(chunk.data)
84
102
  else:
85
103
  chunk_data.append(chunk.data)
86
104
 
@@ -99,6 +117,7 @@ class SummaryFilePolicy(DefaultFilePolicy):
99
117
  )
100
118
  wandb.termerror(msg, repeat=False)
101
119
  wandb._sentry.message(msg, repeat=False)
120
+ self._debug_log(data)
102
121
  return False
103
122
  return {"offset": 0, "content": [data]}
104
123
 
@@ -689,7 +689,7 @@ class HandleManager:
689
689
  self._settings, interface=self._interface, run_proto=run_start.run
690
690
  )
691
691
 
692
- if run_start.run.resumed:
692
+ if run_start.run.resumed or run_start.run.forked:
693
693
  self._step = run_start.run.starting_step
694
694
  result = proto_util._result_from_record(record)
695
695
  self._respond_result(result)
@@ -862,9 +862,6 @@ class HandleManager:
862
862
  self._respond_result(result)
863
863
  self._stopped.set()
864
864
 
865
- def handle_request_job_info(self, record: Record) -> None:
866
- self._dispatch_record(record, always_send=True)
867
-
868
865
  def finish(self) -> None:
869
866
  logger.info("shutting down handler")
870
867
  if self._system_monitor is not None:
@@ -2150,6 +2150,7 @@ class Api:
2150
2150
  name
2151
2151
  }
2152
2152
  }
2153
+ historyLineCount
2153
2154
  }
2154
2155
  inserted
2155
2156
  _Server_Settings_
@@ -2237,6 +2238,7 @@ class Api:
2237
2238
  .get("serverSettings", {})
2238
2239
  .get("serverMessages", [])
2239
2240
  )
2241
+
2240
2242
  return (
2241
2243
  response["upsertBucket"]["bucket"],
2242
2244
  response["upsertBucket"]["inserted"],
@@ -4,7 +4,7 @@ import logging
4
4
  import os
5
5
  import re
6
6
  import sys
7
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
7
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
8
8
 
9
9
  import wandb
10
10
  from wandb.sdk.artifacts.artifact import Artifact
@@ -28,6 +28,8 @@ FROZEN_REQUIREMENTS_FNAME = "requirements.frozen.txt"
28
28
  JOB_FNAME = "wandb-job.json"
29
29
  JOB_ARTIFACT_TYPE = "job"
30
30
 
31
+ LOG_LEVEL = Literal["log", "warn", "error"]
32
+
31
33
 
32
34
  class GitInfo(TypedDict):
33
35
  remote: str
@@ -89,8 +91,9 @@ class JobBuilder:
89
91
  _job_seq_id: Optional[str]
90
92
  _job_version_alias: Optional[str]
91
93
  _is_notebook_run: bool
94
+ _verbose: bool
92
95
 
93
- def __init__(self, settings: SettingsStatic):
96
+ def __init__(self, settings: SettingsStatic, verbose: bool = False):
94
97
  self._settings = settings
95
98
  self._metadatafile_path = None
96
99
  self._requirements_path = None
@@ -106,6 +109,7 @@ class JobBuilder:
106
109
  Literal["repo", "artifact", "image"]
107
110
  ] = settings.job_source # type: ignore[assignment]
108
111
  self._is_notebook_run = self._get_is_notebook_run()
112
+ self._verbose = verbose
109
113
 
110
114
  def set_config(self, config: Dict[str, Any]) -> None:
111
115
  self._config = config
@@ -197,6 +201,21 @@ class JobBuilder:
197
201
 
198
202
  return source, name
199
203
 
204
def _log_if_verbose(self, message: str, level: LOG_LEVEL) -> None:
    """Send ``message`` to ``_logger`` and, in verbose mode, to the terminal.

    The ``_logger`` call happens for every recognized level. The matching
    ``wandb`` terminal helper (``termlog`` / ``termwarn`` / ``termerror``) is
    called only when ``self._verbose`` is true. An unrecognized ``level``
    produces no output at all.
    """
    verbose = self._verbose
    # (level name, logging method on _logger, terminal helper on wandb)
    routes = (
        ("log", "info", "termlog"),
        ("warn", "warning", "termwarn"),
        ("error", "error", "termerror"),
    )
    for name, logger_attr, term_attr in routes:
        if level == name:
            getattr(_logger, logger_attr)(message)
            if verbose:
                getattr(wandb, term_attr)(message)
            return
218
+
200
219
  def _build_artifact_job_source(
201
220
  self,
202
221
  program_relpath: str,
@@ -212,8 +231,9 @@ class JobBuilder:
212
231
  # at the directory the notebook is in instead of the jupyter core
213
232
  if not os.path.exists(os.path.basename(program_relpath)):
214
233
  _logger.info("target path does not exist, exiting")
215
- wandb.termwarn(
216
- "No program path found when generating artifact job source for a non-colab notebook run. See https://docs.wandb.ai/guides/launch/create-job"
234
+ self._log_if_verbose(
235
+ "No program path found when generating artifact job source for a non-colab notebook run. See https://docs.wandb.ai/guides/launch/create-job",
236
+ "warn",
217
237
  )
218
238
  return None, None
219
239
  full_program_relpath = os.path.basename(program_relpath)
@@ -299,22 +319,25 @@ class JobBuilder:
299
319
  if not os.path.exists(
300
320
  os.path.join(self._settings.files_dir, REQUIREMENTS_FNAME)
301
321
  ):
302
- wandb.termwarn(
303
- "No requirements.txt found, not creating job artifact. See https://docs.wandb.ai/guides/launch/create-job"
322
+ self._log_if_verbose(
323
+ "No requirements.txt found, not creating job artifact. See https://docs.wandb.ai/guides/launch/create-job",
324
+ "warn",
304
325
  )
305
326
  return None
306
327
  metadata = self._handle_metadata_file()
307
328
  if metadata is None:
308
- wandb.termwarn(
309
- f"Ensure read and write access to run files dir: {self._settings.files_dir}, control this via the WANDB_DIR env var. See https://docs.wandb.ai/guides/track/environment-variables"
329
+ self._log_if_verbose(
330
+ f"Ensure read and write access to run files dir: {self._settings.files_dir}, control this via the WANDB_DIR env var. See https://docs.wandb.ai/guides/track/environment-variables",
331
+ "warn",
310
332
  )
311
333
  return None
312
334
 
313
335
  runtime: Optional[str] = metadata.get("python")
314
336
  # can't build a job without a python version
315
337
  if runtime is None:
316
- wandb.termwarn(
317
- "No python version found in metadata, not creating job artifact. See https://docs.wandb.ai/guides/launch/create-job"
338
+ self._log_if_verbose(
339
+ "No python version found in metadata, not creating job artifact. See https://docs.wandb.ai/guides/launch/create-job",
340
+ "warn",
318
341
  )
319
342
  return None
320
343
 
@@ -345,13 +368,16 @@ class JobBuilder:
345
368
  or self._settings.job_source
346
369
  or self._source_type
347
370
  ):
348
- wandb.termwarn("No source type found, not creating job artifact")
371
+ self._log_if_verbose(
372
+ "No source type found, not creating job artifact", "warn"
373
+ )
349
374
  return None
350
375
 
351
376
  program_relpath = self._get_program_relpath(source_type, metadata)
352
377
  if source_type != "image" and not program_relpath:
353
- wandb.termwarn(
354
- "No program path found, not creating job artifact. See https://docs.wandb.ai/guides/launch/create-job"
378
+ self._log_if_verbose(
379
+ "No program path found, not creating job artifact. See https://docs.wandb.ai/guides/launch/create-job",
380
+ "warn",
355
381
  )
356
382
  return None
357
383
 
@@ -377,10 +403,11 @@ class JobBuilder:
377
403
 
378
404
  if source is None:
379
405
  if source_type:
380
- wandb.termwarn(
406
+ self._log_if_verbose(
381
407
  f"Source type is set to '{source_type}' but some required information is missing "
382
408
  "from the environment. A job will not be created from this run. See "
383
- "https://docs.wandb.ai/guides/launch/create-job"
409
+ "https://docs.wandb.ai/guides/launch/create-job",
410
+ "warn",
384
411
  )
385
412
  return None
386
413
 
@@ -447,8 +474,9 @@ class JobBuilder:
447
474
  program = metadata.get("program")
448
475
 
449
476
  if not program:
450
- wandb.termwarn(
451
- "Notebook 'program' path not found in metadata. See https://docs.wandb.ai/guides/launch/create-job"
477
+ self._log_if_verbose(
478
+ "Notebook 'program' path not found in metadata. See https://docs.wandb.ai/guides/launch/create-job",
479
+ "warn",
452
480
  )
453
481
 
454
482
  return program
@@ -115,6 +115,7 @@ def _manifest_json_from_proto(manifest: "ArtifactManifest") -> Dict:
115
115
  "ref": content.ref if content.ref else None,
116
116
  "size": content.size if content.size is not None else None,
117
117
  "local_path": content.local_path if content.local_path else None,
118
+ "skip_cache": content.skip_cache,
118
119
  "extra": {
119
120
  extra.key: json.loads(extra.value_json) for extra in content.extra
120
121
  },
@@ -733,18 +734,7 @@ class SendManager:
733
734
  )
734
735
  self._respond_result(result)
735
736
 
736
- def send_request_job_info(self, record: "Record") -> None:
737
- """Respond to a request for a job link."""
738
- result = proto_util._result_from_record(record)
739
- result.response.job_info_response.sequenceId = (
740
- self._job_builder._job_seq_id or ""
741
- )
742
- result.response.job_info_response.version = (
743
- self._job_builder._job_version_alias or ""
744
- )
745
- self._respond_result(result)
746
-
747
- def _maybe_setup_resume(
737
+ def _setup_resume(
748
738
  self, run: "RunRecord"
749
739
  ) -> Optional["wandb_internal_pb2.ErrorInfo"]:
750
740
  """Queries the backend for a run; fail if the settings are incompatible."""
@@ -890,6 +880,30 @@ class SendManager:
890
880
  pass
891
881
  # TODO: do something if sync spell is not successful?
892
882
 
883
def _setup_fork(self, server_run: dict) -> None:
    """Seed run and resume state so a forked run continues after the fork point.

    The new run starts at the step right after ``settings.fork_from.value``.
    Its resume history count is taken from the server-side run record.

    Args:
        server_run: The run record obtained from the server when the run was
            initialized. Only its ``historyLineCount`` key is read; a missing
            or null count is treated as 0.

    Returns:
        Always ``None``; this method does not produce an error value.
    """
    fork_from = self._settings.fork_from
    assert fork_from
    assert fork_from.metric == "_step"
    assert self._run
    # Logging resumes on the step after the one we forked from.
    first_step = int(fork_from.value) + 1
    self._resume_state.step = first_step
    # The server may report `historyLineCount` as null; never store None here.
    self._resume_state.history = server_run.get("historyLineCount") or 0
    self._run.forked = True
    self._run.starting_step = first_step
892
+
893
def _handle_error(
    self,
    record: "Record",
    error: "wandb_internal_pb2.ErrorInfo",
    run: "RunRecord",
) -> None:
    """Report ``error`` for ``record`` to the requester, or log it.

    If the record expects a reply (``req_resp`` or a ``mailbox_slot`` is set),
    a result carrying both ``run`` and ``error`` is sent back through
    ``_respond_result``. Otherwise, nobody is waiting for a reply, and the
    error message is only written to ``logger``.
    """
    control = record.control
    if not (control.req_resp or control.mailbox_slot):
        logger.error("Got error in async mode: %s", error.message)
        return

    response = proto_util._result_from_record(record)
    response.run_result.run.CopyFrom(run)
    response.run_result.error.CopyFrom(error)
    self._respond_result(response)
906
+
893
907
  def send_run(self, record: "Record", file_dir: Optional[str] = None) -> None:
894
908
  run = record.run
895
909
  error = None
@@ -911,21 +925,28 @@ class SendManager:
911
925
  config_value_dict = self._config_backend_dict()
912
926
  self._config_save(config_value_dict)
913
927
 
928
+ do_fork = self._settings.fork_from is not None and is_wandb_init
929
+ do_resume = bool(self._settings.resume)
930
+
931
+ if do_fork and do_resume:
932
+ error = wandb_internal_pb2.ErrorInfo()
933
+ error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.USAGE
934
+ error.message = (
935
+ "You cannot use `resume` and `fork_from` together. Please choose one."
936
+ )
937
+ self._handle_error(record, error, run)
938
+
914
939
  if is_wandb_init:
915
940
  # Ensure we have a project to query for status
916
941
  if run.project == "":
917
942
  run.project = util.auto_project_name(self._settings.program)
918
943
  # Only check resume status on `wandb.init`
919
- error = self._maybe_setup_resume(run)
944
+
945
+ if do_resume:
946
+ error = self._setup_resume(run)
920
947
 
921
948
  if error is not None:
922
- if record.control.req_resp or record.control.mailbox_slot:
923
- result = proto_util._result_from_record(record)
924
- result.run_result.run.CopyFrom(run)
925
- result.run_result.error.CopyFrom(error)
926
- self._respond_result(result)
927
- else:
928
- logger.error("Got error in async mode: %s", error.message)
949
+ self._handle_error(record, error, run)
929
950
  return
930
951
 
931
952
  # Save the resumed config
@@ -945,19 +966,22 @@ class SendManager:
945
966
  self._config_save(config_value_dict)
946
967
 
947
968
  try:
948
- self._init_run(run, config_value_dict)
969
+ server_run = self._init_run(run, config_value_dict)
949
970
  except (CommError, UsageError) as e:
950
971
  logger.error(e, exc_info=True)
951
- if record.control.req_resp or record.control.mailbox_slot:
952
- result = proto_util._result_from_record(record)
953
- result.run_result.run.CopyFrom(run)
954
- error = ProtobufErrorHandler.from_exception(e)
955
- result.run_result.error.CopyFrom(error)
956
- self._respond_result(result)
972
+ error = ProtobufErrorHandler.from_exception(e)
973
+ self._handle_error(record, error, run)
957
974
  return
958
975
 
959
976
  assert self._run # self._run is configured in _init_run()
960
977
 
978
+ if do_fork:
979
+ error = self._setup_fork(server_run)
980
+
981
+ if error is not None:
982
+ self._handle_error(record, error, run)
983
+ return
984
+
961
985
  if record.control.req_resp or record.control.mailbox_slot:
962
986
  result = proto_util._result_from_record(record)
963
987
  # TODO: we could do self._interface.publish_defer(resp) to notify
@@ -976,7 +1000,7 @@ class SendManager:
976
1000
  self,
977
1001
  run: "RunRecord",
978
1002
  config_dict: Optional[sender_config.BackendConfigDict],
979
- ) -> None:
1003
+ ) -> dict:
980
1004
  # We subtract the previous runs runtime when resuming
981
1005
  start_time = (
982
1006
  run.start_time.ToMicroseconds() / 1e6
@@ -1061,6 +1085,7 @@ class SendManager:
1061
1085
  self._run.sweep_id = sweep_id
1062
1086
  if os.getenv("SPELL_RUN_URL"):
1063
1087
  self._sync_spell()
1088
+ return server_run
1064
1089
 
1065
1090
  def _start_run_threads(self, file_dir: Optional[str] = None) -> None:
1066
1091
  assert self._run # self._run is configured by caller
@@ -2,6 +2,7 @@ from dataclasses import fields
2
2
  from typing import Any, Iterable, Sequence, Tuple
3
3
 
4
4
  from wandb.proto import wandb_settings_pb2
5
+ from wandb.sdk.lib import RunMoment
5
6
  from wandb.sdk.wandb_settings import SettingsData
6
7
 
7
8
 
@@ -38,6 +39,14 @@ class SettingsStatic(SettingsData):
38
39
  unpacked_inner[inner_key] = inner_value
39
40
  unpacked_mapping[outer_key] = unpacked_inner
40
41
  value = unpacked_mapping
42
+ elif key == "fork_from":
43
+ value = getattr(proto, key)
44
+ if value.run:
45
+ value = RunMoment(
46
+ run=value.run, value=value.value, metric=value.metric
47
+ )
48
+ else:
49
+ value = None
41
50
  else:
42
51
  if proto.HasField(key): # type: ignore [arg-type]
43
52
  value = getattr(proto, key).value
@@ -212,7 +212,10 @@ class SystemInfo:
212
212
  os.path.join(self.settings.files_dir, CONDA_ENVIRONMENTS_FNAME), "w"
213
213
  ) as f:
214
214
  subprocess.call(
215
- ["conda", "env", "export"], stdout=f, stderr=subprocess.DEVNULL
215
+ ["conda", "env", "export"],
216
+ stdout=f,
217
+ stderr=subprocess.DEVNULL,
218
+ timeout=15, # add timeout since conda env export could take a really long time
216
219
  )
217
220
  except Exception as e:
218
221
  logger.exception(f"Error saving conda packages: {e}")
@@ -396,6 +396,7 @@ def _configure_job_builder_for_partial(tmpdir: str, job_source: str) -> JobBuild
396
396
  settings.update({"files_dir": tmpdir, "job_source": job_source})
397
397
  job_builder = JobBuilder(
398
398
  settings=settings, # type: ignore
399
+ verbose=True,
399
400
  )
400
401
  # never allow notebook runs
401
402
  job_builder._is_notebook_run = False
@@ -1,6 +1,7 @@
1
1
  """Implementation of KubernetesRunner class for wandb launch."""
2
2
  import asyncio
3
3
  import base64
4
+ import datetime
4
5
  import json
5
6
  import logging
6
7
  import os
@@ -23,6 +24,7 @@ from wandb.sdk.launch.runner.kubernetes_monitor import (
23
24
  CustomResource,
24
25
  LaunchKubernetesMonitor,
25
26
  )
27
+ from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
26
28
  from wandb.util import get_module
27
29
 
28
30
  from .._project_spec import EntryPoint, LaunchProject
@@ -59,6 +61,7 @@ from kubernetes_asyncio.client.models.v1_secret import ( # type: ignore # noqa:
59
61
  from kubernetes_asyncio.client.rest import ApiException # type: ignore # noqa: E402
60
62
 
61
63
  TIMEOUT = 5
64
+ API_KEY_SECRET_MAX_RETRIES = 5
62
65
 
63
66
  _logger = logging.getLogger(__name__)
64
67
 
@@ -421,8 +424,23 @@ class KubernetesRunner(AbstractRunner):
421
424
  else:
422
425
  secret_name += f"-{launch_project.run_id}"
423
426
 
424
- api_key_secret = await ensure_api_key_secret(
425
- core_api, secret_name, namespace, value
427
+ def handle_exception(e):
428
+ wandb.termwarn(
429
+ f"Exception when ensuring Kubernetes API key secret: {e}. Retrying..."
430
+ )
431
+
432
+ api_key_secret = await retry_async(
433
+ backoff=ExponentialBackoff(
434
+ initial_sleep=datetime.timedelta(seconds=1),
435
+ max_sleep=datetime.timedelta(minutes=1),
436
+ max_retries=API_KEY_SECRET_MAX_RETRIES,
437
+ ),
438
+ fn=ensure_api_key_secret,
439
+ on_exc=handle_exception,
440
+ core_api=core_api,
441
+ secret_name=secret_name,
442
+ namespace=namespace,
443
+ api_key=value,
426
444
  )
427
445
  env.append(
428
446
  {
wandb/sdk/launch/utils.py CHANGED
@@ -222,14 +222,14 @@ def get_default_entity(api: Api, launch_config: Optional[Dict[str, Any]]):
222
222
 
223
223
 
224
224
  def strip_resource_args_and_template_vars(launch_spec: Dict[str, Any]) -> None:
225
- wandb.termwarn(
226
- "Launch spec contains both resource_args and template_variables, "
227
- "only one can be set. Using template_variables."
228
- )
229
225
  if launch_spec.get("resource_args", None) and launch_spec.get(
230
226
  "template_variables", None
231
227
  ):
232
- launch_spec["resource_args"] = None
228
+ wandb.termwarn(
229
+ "Launch spec contains both resource_args and template_variables, "
230
+ "only one can be set. Using template_variables."
231
+ )
232
+ launch_spec.pop("resource_args")
233
233
 
234
234
 
235
235
  def construct_launch_spec(
wandb/sdk/lib/__init__.py CHANGED
@@ -1,8 +1,5 @@
1
1
  from . import lazyloader
2
2
  from .disabled import RunDisabled, SummaryDisabled
3
+ from .run_moment import RunMoment
3
4
 
4
- __all__ = (
5
- "lazyloader",
6
- "RunDisabled",
7
- "SummaryDisabled",
8
- )
5
+ __all__ = ("lazyloader", "RunDisabled", "SummaryDisabled", "RunMoment")
@@ -102,6 +102,7 @@ _Setting = Literal[
102
102
  "entity",
103
103
  "files_dir",
104
104
  "force",
105
+ "fork_from",
105
106
  "git_commit",
106
107
  "git_remote",
107
108
  "git_remote_url",
@@ -227,7 +227,17 @@ def safe_open(
227
227
 
228
228
 
229
229
  def safe_copy(source_path: StrPath, target_path: StrPath) -> StrPath:
230
- """Copy a file, ensuring any changes only apply atomically once finished."""
230
+ """Copy a file atomically.
231
+
232
+ Copying is not usually atomic, and on operating systems that allow multiple
233
+ writers to the same file, the result can get corrupted. If two writers copy
234
+ to the same file, the contents can become interleaved.
235
+
236
+ We mitigate the issue somewhat by copying to a temporary file first and
237
+ then renaming. Renaming is atomic: if process 1 renames file A to X and
238
+ process 2 renames file B to X, then X will either contain the contents
239
+ of A or the contents of B, not some mixture of both.
240
+ """
231
241
  # TODO (hugh): check that there is enough free space.
232
242
  output_path = Path(target_path).resolve()
233
243
  output_path.parent.mkdir(parents=True, exist_ok=True)
@@ -0,0 +1,72 @@
1
+ from dataclasses import dataclass
2
+ from typing import Literal, Union, cast
3
+ from urllib import parse
4
+
5
+ _STEP = Literal["_step"]
6
+
7
+
8
@dataclass
class RunMoment:
    """A moment in a run, e.g. "run ``abc`` at ``_step`` 123".

    Currently only the ``_step`` metric is supported.
    """

    run: str  # run name

    # currently, the _step value to fork from. in future, this will be optional
    value: Union[int, float]

    # only step for now, in future this will be relaxed to be any metric
    metric: "_STEP" = "_step"

    def __post_init__(self):
        """Validate the fields.

        Raises:
            ValueError: if ``metric`` is not ``"_step"``. Also raised when
                ``value`` is not a real int/float (``bool`` is rejected even
                though it subclasses ``int``), when ``value`` is a non-finite
                float (nan/inf cannot be a step), or when ``run`` is not a str.
        """
        import math

        if self.metric != "_step":
            raise ValueError(
                f"Only the metric '_step' is supported, got '{self.metric}'."
            )
        if isinstance(self.value, bool) or not isinstance(self.value, (int, float)):
            raise ValueError(
                f"Only int or float values are supported, got '{self.value}'."
            )
        # Only floats can be non-finite; math.isfinite would overflow on huge ints.
        if isinstance(self.value, float) and not math.isfinite(self.value):
            raise ValueError(
                f"Only finite values are supported, got '{self.value}'."
            )
        if not isinstance(self.run, str):
            raise ValueError(f"Only string run names are supported, got '{self.run}'.")

    @classmethod
    def from_uri(cls, uri: str) -> "RunMoment":
        """Parse a ``'<run>?<metric>=<numeric_value>'`` string.

        Example: ``RunMoment.from_uri("ans3bsax?_step=123")``.

        The value is parsed as an ``int`` when possible (including negative
        numbers). Otherwise it is parsed as a ``float``.

        Raises:
            ValueError: if the string is malformed, has anything other than
                exactly one ``_step`` parameter with exactly one value, or the
                value is not numeric/finite.
        """
        parsable = "runmoment://" + uri
        parse_err = ValueError(
            f"Could not parse passed run moment string '{uri}', "
            "expected format '<run>?<metric>=<numeric_value>'. "
            "Currently, only the metric '_step' is supported. "
            "Example: 'ans3bsax?_step=123'."
        )

        try:
            parsed = parse.urlparse(parsable)
        except ValueError as e:
            raise parse_err from e

        if parsed.scheme != "runmoment":
            raise parse_err

        # The run name is carried in the netloc component.
        if not parsed.netloc:
            raise parse_err
        run = parsed.netloc

        if parsed.path or parsed.params or parsed.fragment:
            raise parse_err

        query = parse.parse_qs(parsed.query)
        if len(query) != 1:
            raise parse_err

        ((metric, values),) = query.items()
        # Reject `_step=1&_step=2` instead of silently taking the first value.
        if metric != "_step" or len(values) != 1:
            raise parse_err

        raw_value = values[0]
        try:
            num_value: Union[int, float] = int(raw_value)
        except ValueError:
            try:
                num_value = float(raw_value)
            except ValueError as e:
                raise parse_err from e

        return cls(run=run, metric=cast("_STEP", metric), value=num_value)
@@ -5,6 +5,7 @@ StreamRecord: All the external state for the internal thread (queues, etc)
5
5
  StreamAction: Lightweight record for stream ops for thread safety
6
6
  StreamMux: Container for dictionary of stream threads per runid
7
7
  """
8
+
8
9
  import functools
9
10
  import multiprocessing
10
11
  import queue
@@ -327,7 +328,6 @@ class StreamMux:
327
328
  result = internal_messages_handle.wait(timeout=-1)
328
329
  assert result
329
330
  internal_messages_response = result.response.internal_messages_response
330
- job_info_handle = stream.interface.deliver_request_job_info()
331
331
 
332
332
  # wait for them, it's ok to do this serially but this can be improved
333
333
  result = poll_exit_handle.wait(timeout=-1)
@@ -346,17 +346,12 @@ class StreamMux:
346
346
  assert result
347
347
  final_summary = result.response.get_summary_response
348
348
 
349
- result = job_info_handle.wait(timeout=-1)
350
- assert result
351
- job_info = result.response.job_info_response
352
-
353
349
  Run._footer(
354
350
  sampled_history=sampled_history,
355
351
  final_summary=final_summary,
356
352
  poll_exit_response=poll_exit_response,
357
353
  server_info_response=server_info_response,
358
354
  internal_messages_response=internal_messages_response,
359
- job_info=job_info,
360
355
  settings=stream._settings, # type: ignore
361
356
  printer=printer,
362
357
  )