wandb-0.18.1-py3-none-win_amd64.whl → wandb-0.18.3-py3-none-win_amd64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (79)
  1. wandb/__init__.py +3 -3
  2. wandb/__init__.pyi +67 -12
  3. wandb/apis/internal.py +3 -0
  4. wandb/apis/public/api.py +128 -2
  5. wandb/apis/public/artifacts.py +11 -7
  6. wandb/apis/public/jobs.py +8 -0
  7. wandb/apis/public/runs.py +16 -5
  8. wandb/bin/nvidia_gpu_stats.exe +0 -0
  9. wandb/bin/wandb-core +0 -0
  10. wandb/cli/cli.py +0 -3
  11. wandb/errors/__init__.py +11 -40
  12. wandb/errors/errors.py +37 -0
  13. wandb/errors/warnings.py +2 -0
  14. wandb/integration/tensorboard/log.py +1 -1
  15. wandb/old/core.py +2 -80
  16. wandb/plot/bar.py +7 -4
  17. wandb/plot/confusion_matrix.py +5 -4
  18. wandb/plot/histogram.py +7 -4
  19. wandb/plot/line.py +7 -4
  20. wandb/proto/v3/wandb_internal_pb2.py +31 -21
  21. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  22. wandb/proto/v4/wandb_internal_pb2.py +23 -21
  23. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  24. wandb/proto/v5/wandb_internal_pb2.py +23 -21
  25. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  26. wandb/sdk/artifacts/_validators.py +48 -3
  27. wandb/sdk/artifacts/artifact.py +160 -186
  28. wandb/sdk/artifacts/artifact_file_cache.py +13 -11
  29. wandb/sdk/artifacts/artifact_instance_cache.py +4 -2
  30. wandb/sdk/artifacts/artifact_manifest.py +13 -11
  31. wandb/sdk/artifacts/artifact_manifest_entry.py +24 -22
  32. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +9 -7
  33. wandb/sdk/artifacts/artifact_saver.py +27 -25
  34. wandb/sdk/artifacts/exceptions.py +26 -25
  35. wandb/sdk/artifacts/storage_handler.py +11 -9
  36. wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -14
  37. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +15 -13
  38. wandb/sdk/artifacts/storage_handlers/http_handler.py +15 -14
  39. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -8
  40. wandb/sdk/artifacts/storage_handlers/multi_handler.py +14 -12
  41. wandb/sdk/artifacts/storage_handlers/s3_handler.py +19 -19
  42. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +10 -8
  43. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +12 -10
  44. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +9 -7
  45. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +31 -29
  46. wandb/sdk/artifacts/storage_policy.py +20 -20
  47. wandb/sdk/backend/backend.py +8 -26
  48. wandb/sdk/data_types/base_types/wb_value.py +1 -3
  49. wandb/sdk/data_types/video.py +2 -2
  50. wandb/sdk/interface/interface.py +0 -24
  51. wandb/sdk/interface/interface_shared.py +0 -12
  52. wandb/sdk/internal/handler.py +0 -10
  53. wandb/sdk/internal/internal_api.py +71 -0
  54. wandb/sdk/internal/sender.py +0 -43
  55. wandb/sdk/internal/tb_watcher.py +1 -1
  56. wandb/sdk/lib/_settings_toposort_generated.py +1 -0
  57. wandb/sdk/lib/hashutil.py +34 -12
  58. wandb/sdk/lib/service_connection.py +216 -0
  59. wandb/sdk/lib/service_token.py +94 -0
  60. wandb/sdk/lib/sock_client.py +7 -3
  61. wandb/sdk/service/server.py +2 -5
  62. wandb/sdk/service/service.py +0 -22
  63. wandb/sdk/wandb_init.py +33 -22
  64. wandb/sdk/wandb_run.py +45 -33
  65. wandb/sdk/wandb_settings.py +2 -0
  66. wandb/sdk/wandb_setup.py +25 -16
  67. wandb/sdk/wandb_sync.py +9 -3
  68. wandb/sdk/wandb_watch.py +31 -15
  69. wandb/util.py +8 -1
  70. {wandb-0.18.1.dist-info → wandb-0.18.3.dist-info}/METADATA +3 -2
  71. {wandb-0.18.1.dist-info → wandb-0.18.3.dist-info}/RECORD +75 -74
  72. wandb/sdk/internal/update.py +0 -113
  73. wandb/sdk/service/service_base.py +0 -50
  74. wandb/sdk/service/service_sock.py +0 -70
  75. wandb/sdk/wandb_manager.py +0 -232
  76. wandb/{sdk/lib → plot}/viz.py +0 -0
  77. {wandb-0.18.1.dist-info → wandb-0.18.3.dist-info}/WHEEL +0 -0
  78. {wandb-0.18.1.dist-info → wandb-0.18.3.dist-info}/entry_points.txt +0 -0
  79. {wandb-0.18.1.dist-info → wandb-0.18.3.dist-info}/licenses/LICENSE +0 -0
wandb/__init__.py CHANGED
@@ -8,7 +8,7 @@ For scripts and interactive notebooks, see https://github.com/wandb/examples.
8
8
 
9
9
  For reference documentation, see https://docs.wandb.com/ref/python.
10
10
  """
11
- __version__ = "0.18.1"
11
+ __version__ = "0.18.3"
12
12
 
13
13
  from typing import Optional
14
14
 
@@ -76,8 +76,7 @@ from wandb.data_types import JoinedTable
76
76
 
77
77
  from wandb.wandb_agent import agent
78
78
 
79
- from wandb.sdk.lib.viz import visualize
80
- from wandb import plot
79
+ from wandb.plot.viz import visualize
81
80
  from wandb.integration.sagemaker import sagemaker_auth
82
81
  from wandb.sdk.internal import profiler
83
82
 
@@ -242,4 +241,5 @@ __all__ = (
242
241
  "use_model",
243
242
  "link_model",
244
243
  "define_metric",
244
+ "watch",
245
245
  )
wandb/__init__.pyi CHANGED
@@ -9,6 +9,8 @@ For scripts and interactive notebooks, see https://github.com/wandb/examples.
9
9
  For reference documentation, see https://docs.wandb.com/ref/python.
10
10
  """
11
11
 
12
+ from __future__ import annotations
13
+
12
14
  __all__ = (
13
15
  "__version__",
14
16
  "init",
@@ -49,12 +51,23 @@ __all__ = (
49
51
  "Artifact",
50
52
  "Settings",
51
53
  "teardown",
54
+ "watch",
52
55
  )
53
56
 
54
57
  import os
55
- from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union
58
+ from typing import (
59
+ TYPE_CHECKING,
60
+ Any,
61
+ Callable,
62
+ Dict,
63
+ List,
64
+ Literal,
65
+ Optional,
66
+ Sequence,
67
+ Union,
68
+ )
56
69
 
57
- from wandb.analytics import Sentry as _Sentry
70
+ from wandb.analytics import Sentry
58
71
  from wandb.apis import InternalApi, PublicApi
59
72
  from wandb.data_types import (
60
73
  Audio,
@@ -79,17 +92,20 @@ from wandb.sdk.wandb_run import Run
79
92
  from wandb.sdk.wandb_setup import _WandbSetup
80
93
  from wandb.wandb_controller import _WandbController
81
94
 
82
- __version__: str = "0.18.1"
95
+ if TYPE_CHECKING:
96
+ import torch # type: ignore [import-not-found]
83
97
 
84
- run: Optional[Run] = None
85
- config = wandb_config.Config
86
- summary = wandb_summary.Summary
87
- Api = PublicApi
88
- api = InternalApi()
89
- _sentry = _Sentry()
98
+ __version__: str = "0.18.3"
90
99
 
91
- # record of patched libraries
92
- patched = {"tensorboard": [], "keras": [], "gym": []} # type: ignore
100
+ run: Run | None
101
+ config: wandb_config.Config
102
+ summary: wandb_summary.Summary
103
+ Api: PublicApi
104
+
105
+ # private attributes
106
+ _sentry: Sentry
107
+ api: InternalApi
108
+ patched: Dict[str, List[Callable]]
93
109
 
94
110
  def setup(
95
111
  settings: Optional[Settings] = None,
@@ -183,7 +199,7 @@ def init(
183
199
  allow_val_change: Optional[bool] = None,
184
200
  resume: Optional[Union[bool, str]] = None,
185
201
  force: Optional[bool] = None,
186
- tensorboard: Optional[bool] = None, # alias for sync_tensorboard
202
+ tensorboard: Optional[bool] = None,
187
203
  sync_tensorboard: Optional[bool] = None,
188
204
  monitor_gym: Optional[bool] = None,
189
205
  save_code: Optional[bool] = None,
@@ -1082,3 +1098,42 @@ def link_model(
1082
1098
  None
1083
1099
  """
1084
1100
  ...
1101
+
1102
+ def watch(
1103
+ models: torch.nn.Module | Sequence[torch.nn.Module],
1104
+ criterion: torch.F | None = None,
1105
+ log: Literal["gradients", "parameters", "all"] | None = "gradients",
1106
+ log_freq: int = 1000,
1107
+ idx: int | None = None,
1108
+ log_graph: bool = False,
1109
+ ) -> Graph:
1110
+ """Hooks into the given PyTorch model(s) to monitor gradients and the model's computational graph.
1111
+
1112
+ This function can track parameters, gradients, or both during training. It should be
1113
+ extended to support arbitrary machine learning models in the future.
1114
+
1115
+ Args:
1116
+ models (Union[torch.nn.Module, Sequence[torch.nn.Module]]):
1117
+ A single model or a sequence of models to be monitored.
1118
+ criterion (Optional[torch.F]):
1119
+ The loss function being optimized (optional).
1120
+ log (Optional[Literal["gradients", "parameters", "all"]]):
1121
+ Specifies whether to log "gradients", "parameters", or "all".
1122
+ Set to None to disable logging. (default="gradients")
1123
+ log_freq (int):
1124
+ Frequency (in batches) to log gradients and parameters. (default=1000)
1125
+ idx (Optional[int]):
1126
+ Index used when tracking multiple models with `wandb.watch`. (default=None)
1127
+ log_graph (bool):
1128
+ Whether to log the model's computational graph. (default=False)
1129
+
1130
+ Returns:
1131
+ wandb.Graph:
1132
+ The graph object, which will be populated after the first backward pass.
1133
+
1134
+ Raises:
1135
+ ValueError:
1136
+ If `wandb.init` has not been called or if any of the models are not instances
1137
+ of `torch.nn.Module`.
1138
+ """
1139
+ ...
wandb/apis/internal.py CHANGED
@@ -204,6 +204,9 @@ class Api:
204
204
  def create_run_queue(self, *args, **kwargs):
205
205
  return self.api.create_run_queue(*args, **kwargs)
206
206
 
207
+ def upsert_run_queue(self, *args, **kwargs):
208
+ return self.api.upsert_run_queue(*args, **kwargs)
209
+
207
210
  def update_launch_agent_status(self, *args, **kwargs):
208
211
  return self.api.update_launch_agent_status(*args, **kwargs)
209
212
 
wandb/apis/public/api.py CHANGED
@@ -421,6 +421,121 @@ class Api:
421
421
  _default_resource_config=config,
422
422
  )
423
423
 
424
+ def upsert_run_queue(
425
+ self,
426
+ name: str,
427
+ resource_config: dict,
428
+ resource_type: "public.RunQueueResourceType",
429
+ entity: Optional[str] = None,
430
+ template_variables: Optional[dict] = None,
431
+ external_links: Optional[dict] = None,
432
+ prioritization_mode: Optional["public.RunQueuePrioritizationMode"] = None,
433
+ ):
434
+ """Upsert a run queue (launch).
435
+
436
+ Arguments:
437
+ name: (str) Name of the queue to create
438
+ entity: (str) Optional name of the entity to create the queue. If None, will use the configured or default entity.
439
+ resource_config: (dict) Optional default resource configuration to be used for the queue. Use handlebars (eg. "{{var}}") to specify template variables.
440
+ resource_type: (str) Type of resource to be used for the queue. One of "local-container", "local-process", "kubernetes", "sagemaker", or "gcp-vertex".
441
+ template_variables: (dict) A dictionary of template variable schemas to be used with the config. Expected format of:
442
+ {
443
+ "var-name": {
444
+ "schema": {
445
+ "type": ("string", "number", or "integer"),
446
+ "default": (optional value),
447
+ "minimum": (optional minimum),
448
+ "maximum": (optional maximum),
449
+ "enum": [..."(options)"]
450
+ }
451
+ }
452
+ }
453
+ external_links: (dict) Optional dictionary of external links to be used with the queue. Expected format of:
454
+ {
455
+ "name": "url"
456
+ }
457
+ prioritization_mode: (str) Optional version of prioritization to use. Either "V0" or None
458
+
459
+ Returns:
460
+ The upserted `RunQueue`.
461
+
462
+ Raises:
463
+ ValueError if any of the parameters are invalid
464
+ wandb.Error on wandb API errors
465
+ """
466
+ if entity is None:
467
+ entity = self.settings["entity"] or self.default_entity
468
+ if entity is None:
469
+ raise ValueError(
470
+ "entity must be passed as a parameter, or set in settings"
471
+ )
472
+
473
+ if len(name) == 0:
474
+ raise ValueError("name must be non-empty")
475
+ if len(name) > 64:
476
+ raise ValueError("name must be less than 64 characters")
477
+
478
+ prioritization_mode = prioritization_mode or "DISABLED"
479
+ prioritization_mode = prioritization_mode.upper()
480
+ if prioritization_mode not in ["V0", "DISABLED"]:
481
+ raise ValueError(
482
+ "prioritization_mode must be 'V0' or 'DISABLED' if present"
483
+ )
484
+
485
+ if resource_type not in [
486
+ "local-container",
487
+ "local-process",
488
+ "kubernetes",
489
+ "sagemaker",
490
+ "gcp-vertex",
491
+ ]:
492
+ raise ValueError(
493
+ "resource_type must be one of 'local-container', 'local-process', 'kubernetes', 'sagemaker', or 'gcp-vertex'"
494
+ )
495
+
496
+ self.create_project(LAUNCH_DEFAULT_PROJECT, entity)
497
+ api = InternalApi(
498
+ default_settings={
499
+ "entity": entity,
500
+ "project": self.project(LAUNCH_DEFAULT_PROJECT),
501
+ },
502
+ retry_timedelta=RETRY_TIMEDELTA,
503
+ )
504
+ # User provides external_links as a dict with name: url format
505
+ # but backend stores it as a list of dicts with url and label keys.
506
+ external_links = external_links or {}
507
+ external_links = {
508
+ "links": [
509
+ {
510
+ "label": key,
511
+ "url": value,
512
+ }
513
+ for key, value in external_links.items()
514
+ ]
515
+ }
516
+ upsert_run_queue_result = api.upsert_run_queue(
517
+ name,
518
+ entity,
519
+ resource_type,
520
+ {"resource_args": {resource_type: resource_config}},
521
+ template_variables=template_variables,
522
+ external_links=external_links,
523
+ prioritization_mode=prioritization_mode,
524
+ )
525
+ if not upsert_run_queue_result["success"]:
526
+ raise wandb.Error("failed to create run queue")
527
+ schema_errors = (
528
+ upsert_run_queue_result.get("configSchemaValidationErrors") or []
529
+ )
530
+ for error in schema_errors:
531
+ wandb.termwarn(f"resource config validation: {error}")
532
+
533
+ return public.RunQueue(
534
+ client=self.client,
535
+ name=name,
536
+ entity=entity,
537
+ )
538
+
424
539
  def create_user(self, email, admin=False):
425
540
  """Create a new user.
426
541
 
@@ -996,7 +1111,11 @@ class Api:
996
1111
 
997
1112
  @normalize_exceptions
998
1113
  def artifacts(
999
- self, type_name: str, name: str, per_page: Optional[int] = 50
1114
+ self,
1115
+ type_name: str,
1116
+ name: str,
1117
+ per_page: Optional[int] = 50,
1118
+ tags: Optional[List[str]] = None,
1000
1119
  ) -> "public.Artifacts":
1001
1120
  """Return an `Artifacts` collection from the given parameters.
1002
1121
 
@@ -1005,13 +1124,20 @@ class Api:
1005
1124
  name: (str) An artifact collection name. May be prefixed with entity/project.
1006
1125
  per_page: (int, optional) Sets the page size for query pagination. None will use the default size.
1007
1126
  Usually there is no reason to change this.
1127
+ tags: (list[str], optional) Only return artifacts with all of these tags.
1008
1128
 
1009
1129
  Returns:
1010
1130
  An iterable `Artifacts` object.
1011
1131
  """
1012
1132
  entity, project, collection_name = self._parse_artifact_path(name)
1013
1133
  return public.Artifacts(
1014
- self.client, entity, project, collection_name, type_name, per_page=per_page
1134
+ self.client,
1135
+ entity,
1136
+ project,
1137
+ collection_name,
1138
+ type_name,
1139
+ per_page=per_page,
1140
+ tags=tags,
1015
1141
  )
1016
1142
 
1017
1143
  @normalize_exceptions
wandb/apis/public/artifacts.py CHANGED
@@ -3,7 +3,7 @@
3
3
  import json
4
4
  import re
5
5
  from copy import copy
6
- from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence
6
+ from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Union
7
7
 
8
8
  from wandb_gql import Client, gql
9
9
 
@@ -757,12 +757,14 @@ class Artifacts(Paginator):
757
757
  filters: Optional[Mapping[str, Any]] = None,
758
758
  order: Optional[str] = None,
759
759
  per_page: int = 50,
760
+ tags: Optional[Union[str, List[str]]] = None,
760
761
  ):
761
762
  self.entity = entity
762
763
  self.collection_name = collection_name
763
764
  self.type = type
764
765
  self.project = project
765
766
  self.filters = {"state": "COMMITTED"} if filters is None else filters
767
+ self.tags = [tags] if isinstance(tags, str) else tags
766
768
  self.order = order
767
769
  variables = {
768
770
  "project": self.project,
@@ -835,9 +837,9 @@ class Artifacts(Paginator):
835
837
  return None
836
838
 
837
839
  def convert_objects(self):
838
- if self.last_response["project"]["artifactType"]["artifactCollection"] is None:
839
- return []
840
- return [
840
+ collection = self.last_response["project"]["artifactType"]["artifactCollection"]
841
+ artifact_edges = collection.get("artifacts", {}).get("edges", [])
842
+ artifacts = (
841
843
  wandb.Artifact._from_attrs(
842
844
  self.entity,
843
845
  self.project,
@@ -845,9 +847,11 @@ class Artifacts(Paginator):
845
847
  a["node"],
846
848
  self.client,
847
849
  )
848
- for a in self.last_response["project"]["artifactType"][
849
- "artifactCollection"
850
- ]["artifacts"]["edges"]
850
+ for a in artifact_edges
851
+ )
852
+ required_tags = set(self.tags or [])
853
+ return [
854
+ artifact for artifact in artifacts if required_tags.issubset(artifact.tags)
851
855
  ]
852
856
 
853
857
 
wandb/apis/public/jobs.py CHANGED
@@ -473,6 +473,12 @@ class RunQueue:
473
473
  self._get_metadata()
474
474
  return self._access
475
475
 
476
+ @property
477
+ def external_links(self) -> Dict[str, str]:
478
+ if self._external_links is None:
479
+ self._get_metadata()
480
+ return self._external_links
481
+
476
482
  @property
477
483
  def type(self) -> RunQueueResourceType:
478
484
  if self._type is None:
@@ -549,6 +555,7 @@ class RunQueue:
549
555
  access
550
556
  defaultResourceConfigID
551
557
  prioritizationMode
558
+ externalLinks
552
559
  }
553
560
  }
554
561
  }
@@ -565,6 +572,7 @@ class RunQueue:
565
572
  self._default_resource_config_id = res["project"]["runQueue"][
566
573
  "defaultResourceConfigID"
567
574
  ]
575
+ self._external_links = res["project"]["runQueue"]["externalLinks"]
568
576
  if self._default_resource_config_id is None:
569
577
  self._default_resource_config = {}
570
578
  self._prioritization_mode = res["project"]["runQueue"]["prioritizationMode"]
wandb/apis/public/runs.py CHANGED
@@ -6,7 +6,7 @@ import sys
6
6
  import tempfile
7
7
  import time
8
8
  import urllib
9
- from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
9
+ from typing import TYPE_CHECKING, Any, Collection, Dict, List, Mapping, Optional
10
10
 
11
11
  if sys.version_info >= (3, 8):
12
12
  from typing import Literal
@@ -794,13 +794,20 @@ class Run(Attrs):
794
794
  raise ValueError("You must pass a wandb.Api().artifact() to use_artifact")
795
795
 
796
796
  @normalize_exceptions
797
- def log_artifact(self, artifact, aliases=None):
797
+ def log_artifact(
798
+ self,
799
+ artifact: "wandb.Artifact",
800
+ aliases: Optional[Collection[str]] = None,
801
+ tags: Optional[Collection[str]] = None,
802
+ ):
798
803
  """Declare an artifact as output of a run.
799
804
 
800
805
  Arguments:
801
806
  artifact (`Artifact`): An artifact returned from
802
- `wandb.Api().artifact(name)`
803
- aliases (list, optional): Aliases to apply to this artifact
807
+ `wandb.Api().artifact(name)`.
808
+ aliases (list, optional): Aliases to apply to this artifact.
809
+ tags: (list, optional) Tags to apply to this artifact, if any.
810
+
804
811
  Returns:
805
812
  A `Artifact` object.
806
813
  """
@@ -825,7 +832,11 @@ class Run(Attrs):
825
832
 
826
833
  artifact_collection_name = artifact.source_name.split(":")[0]
827
834
  api.create_artifact(
828
- artifact.type, artifact_collection_name, artifact.digest, aliases=aliases
835
+ artifact.type,
836
+ artifact_collection_name,
837
+ artifact.digest,
838
+ aliases=aliases,
839
+ tags=tags,
829
840
  )
830
841
  return artifact
831
842
 
wandb/bin/nvidia_gpu_stats.exe CHANGED
Binary file
wandb/bin/wandb-core CHANGED
Binary file
wandb/cli/cli.py CHANGED
@@ -275,7 +275,6 @@ def login(key, host, cloud, relogin, anonymously, verify, no_offline=False):
275
275
  @click.option("--address", default=None, help="The address to bind service.")
276
276
  @click.option("--pid", default=None, type=int, help="The parent process id to monitor.")
277
277
  @click.option("--debug", is_flag=True, help="log debug info")
278
- @click.option("--serve-sock", is_flag=True, help="use socket mode")
279
278
  @display_error
280
279
  def service(
281
280
  sock_port=None,
@@ -283,7 +282,6 @@ def service(
283
282
  address=None,
284
283
  pid=None,
285
284
  debug=False,
286
- serve_sock=False,
287
285
  ):
288
286
  from wandb.sdk.service.server import WandbServer
289
287
 
@@ -293,7 +291,6 @@ def service(
293
291
  address=address,
294
292
  pid=pid,
295
293
  debug=debug,
296
- serve_sock=serve_sock,
297
294
  )
298
295
  server.serve()
299
296
 
wandb/errors/__init__.py CHANGED
@@ -1,46 +1,17 @@
1
- __all__ = [
1
+ __all__ = (
2
2
  "Error",
3
3
  "CommError",
4
4
  "AuthenticationError",
5
5
  "UsageError",
6
6
  "UnsupportedError",
7
7
  "WandbCoreNotAvailableError",
8
- ]
9
-
10
- from typing import Optional
11
-
12
-
13
- class Error(Exception):
14
- """Base W&B Error."""
15
-
16
- def __init__(self, message, context: Optional[dict] = None) -> None:
17
- super().__init__(message)
18
- self.message = message
19
- # sentry context capture
20
- if context:
21
- self.context = context
22
-
23
-
24
- class CommError(Error):
25
- """Error communicating with W&B servers."""
26
-
27
- def __init__(self, msg, exc=None) -> None:
28
- self.exc = exc
29
- self.message = msg
30
- super().__init__(self.message)
31
-
32
-
33
- class AuthenticationError(CommError):
34
- """Raised when authentication fails."""
35
-
36
-
37
- class UsageError(Error):
38
- """Raised when an invalid usage of the SDK API is detected."""
39
-
40
-
41
- class UnsupportedError(UsageError):
42
- """Raised when trying to use a feature that is not supported."""
43
-
44
-
45
- class WandbCoreNotAvailableError(Error):
46
- """Raised when wandb core is not available."""
8
+ )
9
+
10
+ from .errors import (
11
+ AuthenticationError,
12
+ CommError,
13
+ Error,
14
+ UnsupportedError,
15
+ UsageError,
16
+ WandbCoreNotAvailableError,
17
+ )
wandb/errors/errors.py ADDED
@@ -0,0 +1,37 @@
1
+ from typing import Optional
2
+
3
+
4
+ class Error(Exception):
5
+ """Base W&B Error."""
6
+
7
+ def __init__(self, message, context: Optional[dict] = None) -> None:
8
+ super().__init__(message)
9
+ self.message = message
10
+ # sentry context capture
11
+ if context:
12
+ self.context = context
13
+
14
+
15
+ class CommError(Error):
16
+ """Error communicating with W&B servers."""
17
+
18
+ def __init__(self, msg, exc=None) -> None:
19
+ self.exc = exc
20
+ self.message = msg
21
+ super().__init__(self.message)
22
+
23
+
24
+ class AuthenticationError(CommError):
25
+ """Raised when authentication fails."""
26
+
27
+
28
+ class UsageError(Error):
29
+ """Raised when an invalid usage of the SDK API is detected."""
30
+
31
+
32
+ class UnsupportedError(UsageError):
33
+ """Raised when trying to use a feature that is not supported."""
34
+
35
+
36
+ class WandbCoreNotAvailableError(Error):
37
+ """Raised when wandb core is not available."""
wandb/errors/warnings.py ADDED
@@ -0,0 +1,2 @@
1
+ class WandbWarning(Warning):
2
+ """Base W&B Warning."""
wandb/integration/tensorboard/log.py CHANGED
@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
5
5
 
6
6
  import wandb
7
7
  import wandb.util
8
+ from wandb.plot.viz import custom_chart
8
9
  from wandb.sdk.lib import telemetry
9
- from wandb.sdk.lib.viz import custom_chart
10
10
 
11
11
  if TYPE_CHECKING:
12
12
  import numpy as np