wandb 0.16.4__py3-none-any.whl → 0.16.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. wandb/__init__.py +2 -2
  2. wandb/agents/pyagent.py +1 -1
  3. wandb/apis/public/api.py +6 -6
  4. wandb/apis/reports/v2/interface.py +4 -8
  5. wandb/apis/reports/v2/internal.py +12 -45
  6. wandb/cli/cli.py +29 -5
  7. wandb/integration/openai/fine_tuning.py +74 -37
  8. wandb/integration/ultralytics/callback.py +0 -1
  9. wandb/proto/v3/wandb_internal_pb2.py +332 -312
  10. wandb/proto/v3/wandb_settings_pb2.py +13 -3
  11. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  12. wandb/proto/v4/wandb_internal_pb2.py +316 -312
  13. wandb/proto/v4/wandb_settings_pb2.py +5 -3
  14. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  15. wandb/sdk/artifacts/artifact.py +92 -26
  16. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
  17. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  18. wandb/sdk/artifacts/artifact_saver.py +16 -36
  19. wandb/sdk/artifacts/storage_handler.py +2 -1
  20. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +13 -5
  21. wandb/sdk/interface/interface.py +60 -15
  22. wandb/sdk/interface/interface_shared.py +13 -7
  23. wandb/sdk/internal/file_stream.py +19 -0
  24. wandb/sdk/internal/handler.py +1 -4
  25. wandb/sdk/internal/internal_api.py +2 -0
  26. wandb/sdk/internal/job_builder.py +45 -17
  27. wandb/sdk/internal/sender.py +53 -28
  28. wandb/sdk/internal/settings_static.py +9 -0
  29. wandb/sdk/internal/system/system_info.py +4 -1
  30. wandb/sdk/launch/_launch.py +5 -0
  31. wandb/sdk/launch/_project_spec.py +5 -20
  32. wandb/sdk/launch/agent/agent.py +80 -37
  33. wandb/sdk/launch/agent/config.py +8 -0
  34. wandb/sdk/launch/builder/kaniko_builder.py +149 -134
  35. wandb/sdk/launch/create_job.py +44 -48
  36. wandb/sdk/launch/runner/kubernetes_monitor.py +3 -1
  37. wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
  38. wandb/sdk/launch/sweeps/scheduler.py +3 -1
  39. wandb/sdk/launch/utils.py +23 -5
  40. wandb/sdk/lib/__init__.py +2 -5
  41. wandb/sdk/lib/_settings_toposort_generated.py +2 -0
  42. wandb/sdk/lib/filesystem.py +11 -1
  43. wandb/sdk/lib/run_moment.py +78 -0
  44. wandb/sdk/service/streams.py +1 -6
  45. wandb/sdk/wandb_init.py +12 -7
  46. wandb/sdk/wandb_login.py +43 -26
  47. wandb/sdk/wandb_run.py +179 -94
  48. wandb/sdk/wandb_settings.py +55 -16
  49. wandb/testing/relay.py +5 -6
  50. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/METADATA +1 -1
  51. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/RECORD +55 -54
  52. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/WHEEL +1 -1
  53. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/LICENSE +0 -0
  54. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/entry_points.txt +0 -0
  55. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,7 @@ from wandb.sdk.artifacts.artifact import Artifact
11
11
  from wandb.sdk.internal.job_builder import JobBuilder
12
12
  from wandb.sdk.launch.builder.build import get_current_python_version
13
13
  from wandb.sdk.launch.git_reference import GitReference
14
- from wandb.sdk.launch.utils import _is_git_uri
14
+ from wandb.sdk.launch.utils import _is_git_uri, get_entrypoint_file
15
15
  from wandb.sdk.lib import filesystem
16
16
  from wandb.util import make_artifact_name_safe
17
17
 
@@ -145,6 +145,7 @@ def _create_job(
145
145
 
146
146
  job_builder = _configure_job_builder_for_partial(tempdir.name, job_source=job_type)
147
147
  if job_type == "code":
148
+ assert entrypoint is not None
148
149
  job_name = _make_code_artifact(
149
150
  api=api,
150
151
  job_builder=job_builder,
@@ -233,7 +234,6 @@ def _make_metadata_for_partial_job(
233
234
  return metadata, None
234
235
 
235
236
  if job_type == "code":
236
- path, entrypoint = _handle_artifact_entrypoint(path, entrypoint)
237
237
  if not entrypoint:
238
238
  wandb.termerror(
239
239
  "Artifact jobs must have an entrypoint, either included in the path or specified with -E"
@@ -304,15 +304,22 @@ def _create_repo_metadata(
304
304
  with open(os.path.join(local_dir, ".python-version")) as f:
305
305
  python_version = f.read().strip().splitlines()[0]
306
306
  else:
307
- major, minor = get_current_python_version()
308
- python_version = f"{major}.{minor}"
307
+ _, python_version = get_current_python_version()
309
308
 
310
309
  python_version = _clean_python_version(python_version)
311
310
 
312
311
  # check if entrypoint is valid
313
312
  assert entrypoint is not None
314
- if not os.path.exists(os.path.join(local_dir, entrypoint)):
315
- wandb.termerror(f"Entrypoint {entrypoint} not found in git repo")
313
+ entrypoint_list = entrypoint.split(" ")
314
+ entrypoint_file = get_entrypoint_file(entrypoint_list)
315
+ if not entrypoint_file:
316
+ wandb.termerror(
317
+ f"Entrypoint {entrypoint} is invalid. An entrypoint should include both an executable and a file, for example 'python train.py'"
318
+ )
319
+ return None
320
+
321
+ if not os.path.exists(os.path.join(local_dir, entrypoint_file)):
322
+ wandb.termerror(f"Entrypoint file {entrypoint_file} not found in git repo")
316
323
  return None
317
324
 
318
325
  metadata = {
@@ -320,9 +327,9 @@ def _create_repo_metadata(
320
327
  "commit": commit,
321
328
  "remote": ref.url,
322
329
  },
323
- "codePathLocal": entrypoint, # not in git context, optionally also set local
324
- "codePath": entrypoint,
325
- "entrypoint": [f"python{python_version}", entrypoint],
330
+ "codePathLocal": entrypoint_file, # not in git context, optionally also set local
331
+ "codePath": entrypoint_file,
332
+ "entrypoint": entrypoint_list,
326
333
  "python": python_version, # used to build container
327
334
  "notebook": False, # partial jobs from notebooks not supported
328
335
  }
@@ -332,10 +339,17 @@ def _create_repo_metadata(
332
339
 
333
340
  def _create_artifact_metadata(
334
341
  path: str, entrypoint: str, runtime: Optional[str] = None
335
- ) -> Tuple[Dict[str, Any], List[str]]:
342
+ ) -> Tuple[Optional[Dict[str, Any]], Optional[List[str]]]:
336
343
  if not os.path.isdir(path):
337
344
  wandb.termerror("Path must be a valid file or directory")
338
345
  return {}, []
346
+ entrypoint_list = entrypoint.split(" ")
347
+ entrypoint_file = get_entrypoint_file(entrypoint_list)
348
+ if not entrypoint_file:
349
+ wandb.termerror(
350
+ f"Entrypoint {entrypoint} is invalid. An entrypoint should include both an executable and a file, for example 'python train.py'"
351
+ )
352
+ return None, None
339
353
 
340
354
  # read local requirements.txt and dump to temp dir for builder
341
355
  requirements = []
@@ -347,41 +361,17 @@ def _create_artifact_metadata(
347
361
  if runtime:
348
362
  python_version = _clean_python_version(runtime)
349
363
  else:
350
- python_version = ".".join(get_current_python_version())
364
+ python_version, _ = get_current_python_version()
365
+ python_version = _clean_python_version(python_version)
351
366
 
352
- metadata = {"python": python_version, "codePath": entrypoint}
367
+ metadata = {
368
+ "python": python_version,
369
+ "codePath": entrypoint_file,
370
+ "entrypoint": entrypoint_list,
371
+ }
353
372
  return metadata, requirements
354
373
 
355
374
 
356
- def _handle_artifact_entrypoint(
357
- path: str, entrypoint: Optional[str] = None
358
- ) -> Tuple[str, Optional[str]]:
359
- if os.path.isfile(path):
360
- if entrypoint and path.endswith(entrypoint):
361
- path = path.replace(entrypoint, "")
362
- wandb.termwarn(
363
- f"Both entrypoint provided and path contains file. Using provided entrypoint: {entrypoint}, path is now: {path}"
364
- )
365
- elif entrypoint:
366
- wandb.termwarn(
367
- f"Ignoring passed in entrypoint as it does not match file path found in 'path'. Path entrypoint: {path.split('/')[-1]}"
368
- )
369
- entrypoint = path.split("/")[-1]
370
- path = "/".join(path.split("/")[:-1])
371
- elif not entrypoint:
372
- wandb.termerror("Entrypoint not valid")
373
- return "", None
374
- path = path or "." # when path is just an entrypoint, use cdw
375
-
376
- if not os.path.exists(os.path.join(path, entrypoint)):
377
- wandb.termerror(
378
- f"Could not find execution point: {os.path.join(path, entrypoint)}"
379
- )
380
- return "", None
381
-
382
- return path, entrypoint
383
-
384
-
385
375
  def _configure_job_builder_for_partial(tmpdir: str, job_source: str) -> JobBuilder:
386
376
  """Configure job builder with temp dir and job source."""
387
377
  # adjust git source to repo
@@ -396,6 +386,7 @@ def _configure_job_builder_for_partial(tmpdir: str, job_source: str) -> JobBuild
396
386
  settings.update({"files_dir": tmpdir, "job_source": job_source})
397
387
  job_builder = JobBuilder(
398
388
  settings=settings, # type: ignore
389
+ verbose=True,
399
390
  )
400
391
  # never allow notebook runs
401
392
  job_builder._is_notebook_run = False
@@ -410,7 +401,7 @@ def _make_code_artifact(
410
401
  job_builder: JobBuilder,
411
402
  run: "wandb.sdk.wandb_run.Run",
412
403
  path: str,
413
- entrypoint: Optional[str],
404
+ entrypoint: str,
414
405
  entity: Optional[str],
415
406
  project: Optional[str],
416
407
  name: Optional[str],
@@ -419,17 +410,22 @@ def _make_code_artifact(
419
410
 
420
411
  Returns the name of the eventual job.
421
412
  """
422
- artifact_name = _make_code_artifact_name(os.path.join(path, entrypoint or ""), name)
413
+ assert entrypoint is not None
414
+ entrypoint_list = entrypoint.split(" ")
415
+ entrypoint_file = get_entrypoint_file(entrypoint_list)
416
+ if not entrypoint_file:
417
+ wandb.termerror(
418
+ f"Entrypoint {entrypoint} is invalid. An entrypoint should include both an executable and a file, for example 'python train.py'"
419
+ )
420
+ return None
421
+
422
+ artifact_name = _make_code_artifact_name(os.path.join(path, entrypoint_file), name)
423
423
  code_artifact = wandb.Artifact(
424
424
  name=artifact_name,
425
425
  type="code",
426
426
  description="Code artifact for job",
427
427
  )
428
428
 
429
- # Update path and entrypoint vars to match metadata
430
- # TODO(gst): consolidate into one place
431
- path, entrypoint = _handle_artifact_entrypoint(path, entrypoint)
432
-
433
429
  try:
434
430
  code_artifact.add_dir(path)
435
431
  except Exception as e:
@@ -450,7 +446,7 @@ def _make_code_artifact(
450
446
  project_name=project,
451
447
  run_name=run.id, # run will be deleted after creation
452
448
  description="Code artifact for job",
453
- metadata={"codePath": path, "entrypoint": entrypoint},
449
+ metadata={"codePath": path, "entrypoint": entrypoint_file},
454
450
  is_user_created=True,
455
451
  aliases=[
456
452
  {"artifactCollectionName": artifact_name, "alias": a} for a in ["latest"]
@@ -433,6 +433,8 @@ class SafeWatch:
433
433
  del kwargs["resource_version"]
434
434
  self._last_seen_resource_version = None
435
435
  except Exception as E:
436
+ exc_type = type(E).__name__
437
+ stack_trace = traceback.format_exc()
436
438
  wandb.termerror(
437
- f"Unknown exception in event stream: {E}, attempting to recover"
439
+ f"Unknown exception in event stream of type {exc_type}: {E}, attempting to recover. Stack trace: {stack_trace}"
438
440
  )
@@ -1,6 +1,7 @@
1
1
  """Implementation of KubernetesRunner class for wandb launch."""
2
2
  import asyncio
3
3
  import base64
4
+ import datetime
4
5
  import json
5
6
  import logging
6
7
  import os
@@ -23,6 +24,7 @@ from wandb.sdk.launch.runner.kubernetes_monitor import (
23
24
  CustomResource,
24
25
  LaunchKubernetesMonitor,
25
26
  )
27
+ from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
26
28
  from wandb.util import get_module
27
29
 
28
30
  from .._project_spec import EntryPoint, LaunchProject
@@ -59,6 +61,7 @@ from kubernetes_asyncio.client.models.v1_secret import ( # type: ignore # noqa:
59
61
  from kubernetes_asyncio.client.rest import ApiException # type: ignore # noqa: E402
60
62
 
61
63
  TIMEOUT = 5
64
+ API_KEY_SECRET_MAX_RETRIES = 5
62
65
 
63
66
  _logger = logging.getLogger(__name__)
64
67
 
@@ -421,8 +424,23 @@ class KubernetesRunner(AbstractRunner):
421
424
  else:
422
425
  secret_name += f"-{launch_project.run_id}"
423
426
 
424
- api_key_secret = await ensure_api_key_secret(
425
- core_api, secret_name, namespace, value
427
+ def handle_exception(e):
428
+ wandb.termwarn(
429
+ f"Exception when ensuring Kubernetes API key secret: {e}. Retrying..."
430
+ )
431
+
432
+ api_key_secret = await retry_async(
433
+ backoff=ExponentialBackoff(
434
+ initial_sleep=datetime.timedelta(seconds=1),
435
+ max_sleep=datetime.timedelta(minutes=1),
436
+ max_retries=API_KEY_SECRET_MAX_RETRIES,
437
+ ),
438
+ fn=ensure_api_key_secret,
439
+ on_exc=handle_exception,
440
+ core_api=core_api,
441
+ secret_name=secret_name,
442
+ namespace=namespace,
443
+ api_key=value,
426
444
  )
427
445
  env.append(
428
446
  {
@@ -157,7 +157,9 @@ class Scheduler(ABC):
157
157
  self._runs: Dict[str, SweepRun] = {}
158
158
  # Threading lock to ensure thread-safe access to the runs dictionary
159
159
  self._threading_lock: threading.Lock = threading.Lock()
160
- self._polling_sleep = polling_sleep or DEFAULT_POLLING_SLEEP
160
+ self._polling_sleep = (
161
+ polling_sleep if polling_sleep is not None else DEFAULT_POLLING_SLEEP
162
+ )
161
163
  self._project_queue = project_queue
162
164
  # Optionally run multiple workers in (pseudo-)parallel. Workers do not
163
165
  # actually run training workloads, they simply send heartbeat messages
wandb/sdk/launch/utils.py CHANGED
@@ -222,14 +222,14 @@ def get_default_entity(api: Api, launch_config: Optional[Dict[str, Any]]):
222
222
 
223
223
 
224
224
  def strip_resource_args_and_template_vars(launch_spec: Dict[str, Any]) -> None:
225
- wandb.termwarn(
226
- "Launch spec contains both resource_args and template_variables, "
227
- "only one can be set. Using template_variables."
228
- )
229
225
  if launch_spec.get("resource_args", None) and launch_spec.get(
230
226
  "template_variables", None
231
227
  ):
232
- launch_spec["resource_args"] = None
228
+ wandb.termwarn(
229
+ "Launch spec contains both resource_args and template_variables, "
230
+ "only one can be set. Using template_variables."
231
+ )
232
+ launch_spec.pop("resource_args")
233
233
 
234
234
 
235
235
  def construct_launch_spec(
@@ -846,3 +846,21 @@ def fetch_and_validate_template_variables(
846
846
  raise LaunchError(f"Value for {key} must be of type {field_type}.")
847
847
  template_variables[key] = val
848
848
  return template_variables
849
+
850
+
851
def get_entrypoint_file(entrypoint: List[str]) -> Optional[str]:
    """Pick out the script file referenced by a split entrypoint command.

    If the first token is itself a ``.py`` or ``.sh`` file (for example
    ``["train.py", "--lr", "0.1"]``), that token is returned. Otherwise the
    first token is treated as the executable and the second token is returned
    verbatim (for example ``["python", "train.py"]`` -> ``"train.py"``). No
    filtering of option flags is done, so ``["python", "-u", "train.py"]``
    yields ``"-u"``.

    Args:
        entrypoint (List[str]): The command split into executable and arguments.

    Returns:
        Optional[str]: The entrypoint file, or None when ``entrypoint`` is
        empty or holds a single token that is not a script file.
    """
    if not entrypoint:
        return None
    head, *rest = entrypoint
    if head.endswith((".py", ".sh")):
        return head
    return rest[0] if rest else None
wandb/sdk/lib/__init__.py CHANGED
@@ -1,8 +1,5 @@
1
1
  from . import lazyloader
2
2
  from .disabled import RunDisabled, SummaryDisabled
3
+ from .run_moment import RunMoment
3
4
 
4
- __all__ = (
5
- "lazyloader",
6
- "RunDisabled",
7
- "SummaryDisabled",
8
- )
5
+ __all__ = ("lazyloader", "RunDisabled", "SummaryDisabled", "RunMoment")
@@ -22,6 +22,7 @@ _Setting = Literal[
22
22
  "_disable_service",
23
23
  "_disable_setproctitle",
24
24
  "_disable_stats",
25
+ "_disable_update_check",
25
26
  "_disable_viewer",
26
27
  "_disable_machine_info",
27
28
  "_except_exit",
@@ -102,6 +103,7 @@ _Setting = Literal[
102
103
  "entity",
103
104
  "files_dir",
104
105
  "force",
106
+ "fork_from",
105
107
  "git_commit",
106
108
  "git_remote",
107
109
  "git_remote_url",
@@ -227,7 +227,17 @@ def safe_open(
227
227
 
228
228
 
229
229
  def safe_copy(source_path: StrPath, target_path: StrPath) -> StrPath:
230
- """Copy a file, ensuring any changes only apply atomically once finished."""
230
+ """Copy a file atomically.
231
+
232
+ Copying is not usually atomic, and on operating systems that allow multiple
233
+ writers to the same file, the result can get corrupted. If two writers copy
234
+ to the same file, the contents can become interleaved.
235
+
236
+ We mitigate the issue somewhat by copying to a temporary file first and
237
+ then renaming. Renaming is atomic: if process 1 renames file A to X and
238
+ process 2 renames file B to X, then X will either contain the contents
239
+ of A or the contents of B, not some mixture of both.
240
+ """
231
241
  # TODO (hugh): check that there is enough free space.
232
242
  output_path = Path(target_path).resolve()
233
243
  output_path.parent.mkdir(parents=True, exist_ok=True)
@@ -0,0 +1,78 @@
1
import sys
from dataclasses import dataclass
from typing import Union, cast
from urllib import parse

if sys.version_info >= (3, 8):
    from typing import Literal
else:
    from typing_extensions import Literal

_STEP = Literal["_step"]


@dataclass
class RunMoment:
    """A moment in a run: a run name plus a value of a metric in that run."""

    run: str  # run name

    # currently, the _step value to fork from. in future, this will be optional
    value: Union[int, float]

    # only step for now, in future this will be relaxed to be any metric
    metric: _STEP = "_step"

    def __post_init__(self):
        """Validate the fields.

        Raises:
            ValueError: If ``metric`` is not ``"_step"``, ``value`` is not an
                int or float, or ``run`` is not a string.
        """
        if self.metric != "_step":
            raise ValueError(
                f"Only the metric '_step' is supported, got '{self.metric}'."
            )
        if not isinstance(self.value, (int, float)):
            raise ValueError(
                f"Only int or float values are supported, got '{self.value}'."
            )
        if not isinstance(self.run, str):
            raise ValueError(f"Only string run names are supported, got '{self.run}'.")

    @classmethod
    def from_uri(cls, uri: str) -> "RunMoment":
        """Parse a run moment from a ``<run>?<metric>=<numeric_value>`` string.

        Example: ``RunMoment.from_uri("ans3bsax?_step=123")``.

        Integer-looking values (including signed ones such as ``"-5"``) are
        returned as ``int``; other numeric values as ``float``.

        Args:
            uri: The run moment string, without any scheme prefix.

        Returns:
            The parsed ``RunMoment``.

        Raises:
            ValueError: If the string is malformed, uses a metric other than
                ``_step``, or its value is not a finite number.
        """
        # local import: only needed here, to reject nan/inf values
        import math

        # prefix a scheme so urlparse puts the run name into `netloc`
        parsable = "runmoment://" + uri
        parse_err = ValueError(
            f"Could not parse passed run moment string '{uri}', "
            "expected format '<run>?<metric>=<numeric_value>'. "
            "Currently, only the metric '_step' is supported. "
            "Example: 'ans3bsax?_step=123'."
        )

        try:
            parsed = parse.urlparse(parsable)
        except ValueError as e:
            raise parse_err from e

        if parsed.scheme != "runmoment":
            raise parse_err

        # extract run, metric, value from parsed
        if not parsed.netloc:
            raise parse_err

        run = parsed.netloc

        # anything after the run name other than the query (e.g. "run/x") is invalid
        if parsed.path or parsed.params or parsed.fragment:
            raise parse_err

        query = parse.parse_qs(parsed.query)
        if len(query) != 1:
            raise parse_err

        metric = list(query.keys())[0]
        if metric != "_step":
            raise parse_err
        # if the metric is repeated, only its first value is used
        raw_value: str = query[metric][0]

        num_value: Union[int, float]
        try:
            # try int first so signed integers ("-5", "+3") stay ints instead
            # of silently becoming floats
            num_value = int(raw_value)
        except ValueError:
            try:
                num_value = float(raw_value)
            except ValueError as e:
                raise parse_err from e
            # float() accepts "nan"/"inf", which cannot identify a moment
            if not math.isfinite(num_value):
                raise parse_err

        return cls(run=run, metric=cast(_STEP, metric), value=num_value)
@@ -5,6 +5,7 @@ StreamRecord: All the external state for the internal thread (queues, etc)
5
5
  StreamAction: Lightweight record for stream ops for thread safety
6
6
  StreamMux: Container for dictionary of stream threads per runid
7
7
  """
8
+
8
9
  import functools
9
10
  import multiprocessing
10
11
  import queue
@@ -327,7 +328,6 @@ class StreamMux:
327
328
  result = internal_messages_handle.wait(timeout=-1)
328
329
  assert result
329
330
  internal_messages_response = result.response.internal_messages_response
330
- job_info_handle = stream.interface.deliver_request_job_info()
331
331
 
332
332
  # wait for them, it's ok to do this serially but this can be improved
333
333
  result = poll_exit_handle.wait(timeout=-1)
@@ -346,17 +346,12 @@ class StreamMux:
346
346
  assert result
347
347
  final_summary = result.response.get_summary_response
348
348
 
349
- result = job_info_handle.wait(timeout=-1)
350
- assert result
351
- job_info = result.response.job_info_response
352
-
353
349
  Run._footer(
354
350
  sampled_history=sampled_history,
355
351
  final_summary=final_summary,
356
352
  poll_exit_response=poll_exit_response,
357
353
  server_info_response=server_info_response,
358
354
  internal_messages_response=internal_messages_response,
359
- job_info=job_info,
360
355
  settings=stream._settings, # type: ignore
361
356
  printer=printer,
362
357
  )
wandb/sdk/wandb_init.py CHANGED
@@ -7,6 +7,7 @@ your evaluation script, and each step would be tracked as a run in W&B.
7
7
  For more on using `wandb.init()`, including code snippets, check out our
8
8
  [guide and FAQs](https://docs.wandb.ai/guides/track/launch).
9
9
  """
10
+
10
11
  import copy
11
12
  import json
12
13
  import logging
@@ -194,12 +195,6 @@ class _WandbInit:
194
195
  # Start with settings from wandb library singleton
195
196
  settings: Settings = self._wl.settings.copy()
196
197
 
197
- # when using launch, we don't want to reuse the same run id from the singleton
198
- # since users might launch multiple runs in the same process
199
- # TODO(kdg): allow users to control this via launch settings
200
- if settings.launch and singleton is not None:
201
- settings.update({"run_id": None}, source=Source.INIT)
202
-
203
198
  settings_param = kwargs.pop("settings", None)
204
199
  if settings_param is not None and isinstance(settings_param, (Settings, dict)):
205
200
  settings.update(settings_param, source=Source.INIT)
@@ -828,7 +823,7 @@ class _WandbInit:
828
823
  and self.settings.launch_config_path
829
824
  and os.path.exists(self.settings.launch_config_path)
830
825
  ):
831
- run._save(self.settings.launch_config_path)
826
+ run.save(self.settings.launch_config_path)
832
827
  # put artifacts in run config here
833
828
  # since doing so earlier will cause an error
834
829
  # as the run is not upserted
@@ -961,6 +956,7 @@ def init(
961
956
  monitor_gym: Optional[bool] = None,
962
957
  save_code: Optional[bool] = None,
963
958
  id: Optional[str] = None,
959
+ fork_from: Optional[str] = None,
964
960
  settings: Union[Settings, Dict[str, Any], None] = None,
965
961
  ) -> Union[Run, RunDisabled, None]:
966
962
  r"""Start a new run to track and log to W&B.
@@ -1122,6 +1118,10 @@ def init(
1122
1118
  for saving hyperparameters to compare across runs. The ID cannot
1123
1119
  contain the following special characters: `/\#?%:`.
1124
1120
  See [our guide to resuming runs](https://docs.wandb.com/guides/runs/resuming).
1121
+ fork_from: (str, optional) A string with the format {run_id}?_step={step} describing
1122
+ a moment in a previous run to fork a new run from. Creates a new run that picks up
1123
+ logging history from the specified run at the specified moment. The target run must
1124
+ be in the current project. Example: `fork_from="my-run-id?_step=1234"`.
1125
1125
 
1126
1126
  Examples:
1127
1127
  ### Set where the run is logged
@@ -1167,6 +1167,11 @@ def init(
1167
1167
  error_seen = None
1168
1168
  except_exit = None
1169
1169
  run: Optional[Union[Run, RunDisabled]] = None
1170
+
1171
+ # convert fork_from into a version that can be passed to settings
1172
+ if fork_from is not None and resume is not None:
1173
+ raise ValueError("Cannot specify both `fork_from` and `resume`")
1174
+
1170
1175
  try:
1171
1176
  wi = _WandbInit()
1172
1177
  wi.setup(kwargs)
wandb/sdk/wandb_login.py CHANGED
@@ -22,7 +22,7 @@ from wandb.old.settings import Settings as OldSettings
22
22
  from ..apis import InternalApi
23
23
  from .internal.internal_api import Api
24
24
  from .lib import apikey
25
- from .wandb_settings import Settings, Source
25
+ from .wandb_settings import Settings
26
26
 
27
27
 
28
28
  def _handle_host_wandb_setting(host: Optional[str], cloud: bool = False) -> None:
@@ -80,11 +80,17 @@ def login(
80
80
  _handle_host_wandb_setting(host)
81
81
  if wandb.setup()._settings._noop:
82
82
  return True
83
- kwargs = dict(locals())
84
- _verify = kwargs.pop("verify", False)
85
- configured = _login(**kwargs)
86
83
 
87
- if _verify:
84
+ configured = _login(
85
+ anonymous=anonymous,
86
+ key=key,
87
+ relogin=relogin,
88
+ host=host,
89
+ force=force,
90
+ timeout=timeout,
91
+ )
92
+
93
+ if verify:
88
94
  from . import wandb_setup
89
95
 
90
96
  singleton = wandb_setup._WandbSetup._instance
@@ -115,22 +121,32 @@ class _WandbLogin:
115
121
  self._key = None
116
122
  self._relogin = None
117
123
 
118
- def setup(self, kwargs):
119
- self.kwargs = kwargs
124
+ def setup(
125
+ self,
126
+ *,
127
+ anonymous: Optional[Literal["must", "allow", "never"]] = None,
128
+ key: Optional[str] = None,
129
+ relogin: Optional[bool] = None,
130
+ host: Optional[str] = None,
131
+ force: Optional[bool] = None,
132
+ timeout: Optional[int] = None,
133
+ ):
134
+ self._relogin = relogin
120
135
 
121
136
  # built up login settings
122
137
  login_settings: Settings = wandb.Settings()
123
- settings_param = kwargs.pop("_settings", None)
124
- # note that this case does not come up anywhere except for the tests
125
- if settings_param is not None:
126
- if isinstance(settings_param, Settings):
127
- login_settings._apply_settings(settings_param)
128
- elif isinstance(settings_param, dict):
129
- login_settings.update(settings_param, source=Source.LOGIN)
130
- _logger = wandb.setup()._get_logger()
131
- # Do not save relogin into settings as we just want to relogin once
132
- self._relogin = kwargs.pop("relogin", None)
133
- login_settings._apply_login(kwargs, _logger=_logger)
138
+ logger = wandb.setup()._get_logger()
139
+
140
+ login_settings._apply_login(
141
+ {
142
+ "anonymous": anonymous,
143
+ "key": key,
144
+ "host": host,
145
+ "force": force,
146
+ "timeout": timeout,
147
+ },
148
+ _logger=logger,
149
+ )
134
150
 
135
151
  # make sure they are applied globally
136
152
  self._wl = wandb.setup(settings=login_settings)
@@ -259,6 +275,7 @@ class _WandbLogin:
259
275
 
260
276
 
261
277
  def _login(
278
+ *,
262
279
  anonymous: Optional[Literal["must", "allow", "never"]] = None,
263
280
  key: Optional[str] = None,
264
281
  relogin: Optional[bool] = None,
@@ -270,9 +287,6 @@ def _login(
270
287
  _disable_warning: Optional[bool] = None,
271
288
  _entity: Optional[str] = None,
272
289
  ):
273
- kwargs = dict(locals())
274
- _disable_warning = kwargs.pop("_disable_warning", None)
275
-
276
290
  if wandb.run is not None:
277
291
  if not _disable_warning:
278
292
  wandb.termwarn("Calling wandb.login() after wandb.init() has no effect.")
@@ -280,20 +294,24 @@ def _login(
280
294
 
281
295
  wlogin = _WandbLogin()
282
296
 
283
- _backend = kwargs.pop("_backend", None)
284
297
  if _backend:
285
298
  wlogin.set_backend(_backend)
286
299
 
287
- _silent = kwargs.pop("_silent", None)
288
300
  if _silent:
289
301
  wlogin.set_silent(_silent)
290
302
 
291
- _entity = kwargs.pop("_entity", None)
292
303
  if _entity:
293
304
  wlogin.set_entity(_entity)
294
305
 
295
306
  # configure login object
296
- wlogin.setup(kwargs)
307
+ wlogin.setup(
308
+ anonymous=anonymous,
309
+ key=key,
310
+ relogin=relogin,
311
+ host=host,
312
+ force=force,
313
+ timeout=timeout,
314
+ )
297
315
 
298
316
  if wlogin._settings._offline:
299
317
  return False
@@ -306,7 +324,6 @@ def _login(
306
324
  # perform a login
307
325
  logged_in = wlogin.login()
308
326
 
309
- key = kwargs.get("key")
310
327
  if key:
311
328
  wlogin.configure_api_key(key)
312
329