wandb 0.18.3__py3-none-win32.whl → 0.18.5__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. wandb/__init__.py +16 -7
  2. wandb/__init__.pyi +96 -63
  3. wandb/analytics/sentry.py +91 -88
  4. wandb/apis/public/api.py +18 -4
  5. wandb/apis/public/runs.py +53 -2
  6. wandb/bin/gpu_stats.exe +0 -0
  7. wandb/bin/wandb-core +0 -0
  8. wandb/cli/beta.py +178 -0
  9. wandb/cli/cli.py +5 -171
  10. wandb/data_types.py +3 -0
  11. wandb/env.py +74 -73
  12. wandb/errors/term.py +300 -43
  13. wandb/proto/v3/wandb_internal_pb2.py +263 -223
  14. wandb/proto/v3/wandb_server_pb2.py +57 -37
  15. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  16. wandb/proto/v4/wandb_internal_pb2.py +226 -218
  17. wandb/proto/v4/wandb_server_pb2.py +41 -37
  18. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  19. wandb/proto/v5/wandb_internal_pb2.py +226 -218
  20. wandb/proto/v5/wandb_server_pb2.py +41 -37
  21. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  22. wandb/sdk/__init__.py +3 -3
  23. wandb/sdk/artifacts/_validators.py +41 -8
  24. wandb/sdk/artifacts/artifact.py +32 -1
  25. wandb/sdk/artifacts/artifact_file_cache.py +1 -2
  26. wandb/sdk/data_types/_dtypes.py +7 -3
  27. wandb/sdk/data_types/video.py +15 -6
  28. wandb/sdk/interface/interface.py +2 -0
  29. wandb/sdk/internal/internal_api.py +126 -5
  30. wandb/sdk/internal/sender.py +16 -3
  31. wandb/sdk/launch/inputs/internal.py +1 -1
  32. wandb/sdk/lib/module.py +12 -0
  33. wandb/sdk/lib/printer.py +291 -105
  34. wandb/sdk/lib/progress.py +274 -0
  35. wandb/sdk/service/streams.py +21 -11
  36. wandb/sdk/wandb_init.py +58 -54
  37. wandb/sdk/wandb_run.py +380 -454
  38. wandb/sdk/wandb_settings.py +2 -0
  39. wandb/sdk/wandb_watch.py +17 -11
  40. wandb/util.py +6 -2
  41. {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/METADATA +4 -3
  42. {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/RECORD +45 -43
  43. wandb/bin/nvidia_gpu_stats.exe +0 -0
  44. {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/WHEEL +0 -0
  45. {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/entry_points.txt +0 -0
  46. {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/licenses/LICENSE +0 -0
@@ -29,6 +29,11 @@ from typing import (
29
29
  Union,
30
30
  )
31
31
 
32
+ if sys.version_info >= (3, 8):
33
+ from typing import Literal
34
+ else:
35
+ from typing_extensions import Literal
36
+
32
37
  import click
33
38
  import requests
34
39
  import yaml
@@ -41,6 +46,7 @@ from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messa
41
46
  from wandb.errors import AuthenticationError, CommError, UnsupportedError, UsageError
42
47
  from wandb.integration.sagemaker import parse_sm_secrets
43
48
  from wandb.old.settings import Settings
49
+ from wandb.sdk.artifacts._validators import is_artifact_registry_project
44
50
  from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
45
51
  from wandb.sdk.lib.gql_request import GraphQLSession
46
52
  from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
@@ -306,6 +312,7 @@ class Api:
306
312
  self.server_use_artifact_input_info: Optional[List[str]] = None
307
313
  self.server_create_artifact_input_info: Optional[List[str]] = None
308
314
  self.server_artifact_fields_info: Optional[List[str]] = None
315
+ self.server_organization_type_fields_info: Optional[List[str]] = None
309
316
  self._max_cli_version: Optional[str] = None
310
317
  self._server_settings_type: Optional[List[str]] = None
311
318
  self.fail_run_queue_item_input_info: Optional[List[str]] = None
@@ -3366,7 +3373,7 @@ class Api:
3366
3373
  project: Optional[str] = None,
3367
3374
  description: Optional[str] = None,
3368
3375
  force: bool = True,
3369
- progress: Union[TextIO, bool] = False,
3376
+ progress: Union[TextIO, Literal[False]] = False,
3370
3377
  ) -> "List[Optional[requests.Response]]":
3371
3378
  """Uploads multiple files to W&B.
3372
3379
 
@@ -3378,7 +3385,7 @@ class Api:
3378
3385
  description (str, optional): The description of the changes
3379
3386
  force (bool, optional): Whether to prevent push if git has uncommitted changes
3380
3387
  progress (callable, or stream): If callable, will be called with (chunk_bytes,
3381
- total_bytes) as argument else if True, renders a progress bar to stream.
3388
+ total_bytes) as argument. If TextIO, renders a progress bar to it.
3382
3389
 
3383
3390
  Returns:
3384
3391
  A list of `requests.Response` objects
@@ -3439,8 +3446,8 @@ class Api:
3439
3446
  )
3440
3447
  else:
3441
3448
  length = os.fstat(open_file.fileno()).st_size
3442
- with click.progressbar(
3443
- file=progress, # type: ignore
3449
+ with click.progressbar( # type: ignore
3450
+ file=progress,
3444
3451
  length=length,
3445
3452
  label=f"Uploading file: {file_name}",
3446
3453
  fill_char=click.style("&", fg="green"),
@@ -3464,6 +3471,7 @@ class Api:
3464
3471
  entity: str,
3465
3472
  project: str,
3466
3473
  aliases: Sequence[str],
3474
+ organization: str,
3467
3475
  ) -> Dict[str, Any]:
3468
3476
  template = """
3469
3477
  mutation LinkArtifact(
@@ -3485,6 +3493,14 @@ class Api:
3485
3493
  }
3486
3494
  """
3487
3495
 
3496
+ org_entity = ""
3497
+ if is_artifact_registry_project(project):
3498
+ try:
3499
+ org_entity = self._resolve_org_entity_name(entity, organization)
3500
+ except ValueError as e:
3501
+ wandb.termerror(str(e))
3502
+ raise
3503
+
3488
3504
  def replace(a: str, b: str) -> None:
3489
3505
  nonlocal template
3490
3506
  template = template.replace(a, b)
@@ -3500,7 +3516,7 @@ class Api:
3500
3516
  "clientID": client_id,
3501
3517
  "artifactID": server_id,
3502
3518
  "artifactPortfolioName": portfolio_name,
3503
- "entityName": entity,
3519
+ "entityName": org_entity or entity,
3504
3520
  "projectName": project,
3505
3521
  "aliases": [
3506
3522
  {"alias": alias, "artifactCollectionName": portfolio_name}
@@ -3513,6 +3529,89 @@ class Api:
3513
3529
  link_artifact: Dict[str, Any] = response["linkArtifact"]
3514
3530
  return link_artifact
3515
3531
 
3532
+ def _resolve_org_entity_name(self, entity: str, organization: str = "") -> str:
3533
+ # resolveOrgEntityName fetches the portfolio's org entity's name.
3534
+ #
3535
+ # The organization parameter may be empty, an org's display name, or an org entity name.
3536
+ #
3537
+ # If the server doesn't support fetching the org name of a portfolio, then this returns
3538
+ # the organization parameter, or an error if it is empty. Otherwise, this returns the
3539
+ # fetched value after validating that the given organization, if not empty, matches
3540
+ # either the org's display or entity name.
3541
+ org_fields = self.server_organization_type_introspection()
3542
+ can_fetch_org_entity = "orgEntity" in org_fields
3543
+ if not organization and not can_fetch_org_entity:
3544
+ raise ValueError(
3545
+ "Fetching Registry artifacts without inputting an organization "
3546
+ "is unavailable for your server version. "
3547
+ "Please upgrade your server to 0.50.0 or later."
3548
+ )
3549
+ if not can_fetch_org_entity:
3550
+ # Server doesn't support fetching org entity to validate,
3551
+ # assume org entity is correctly inputted
3552
+ return organization
3553
+
3554
+ org_entity, org_name = self.fetch_org_entity_from_entity(entity)
3555
+ if organization:
3556
+ if organization != org_name and organization != org_entity:
3557
+ raise ValueError(
3558
+ f"Artifact belongs to the organization {org_name!r} "
3559
+ f"and cannot be linked/fetched with {organization!r}. "
3560
+ "Please update the target path with the correct organization name."
3561
+ )
3562
+ wandb.termwarn(
3563
+ "Registries can be linked/fetched using a shorthand form without specifying the organization name. "
3564
+ "Try using shorthand path format: <my_registry_name>/<artifact_name>"
3565
+ )
3566
+ return org_entity
3567
+
3568
+ def fetch_org_entity_from_entity(self, entity: str) -> Tuple[str, str]:
3569
+ query = gql(
3570
+ """
3571
+ query FetchOrgEntityFromEntity(
3572
+ $entityName: String!,
3573
+ ) {
3574
+ entity(name: $entityName) {
3575
+ isTeam
3576
+ organization {
3577
+ name
3578
+ orgEntity {
3579
+ name
3580
+ }
3581
+ }
3582
+ }
3583
+ }
3584
+ """
3585
+ )
3586
+ response = self.gql(
3587
+ query,
3588
+ variable_values={
3589
+ "entityName": entity,
3590
+ },
3591
+ )
3592
+ try:
3593
+ is_team = response["entity"].get("isTeam", False)
3594
+ org = response["entity"]["organization"]
3595
+ org_name = org["name"] or ""
3596
+ org_entity_name = org["orgEntity"]["name"] or ""
3597
+ except (LookupError, TypeError) as e:
3598
+ if is_team:
3599
+ # This path should pretty much never be reached as all team entities have an organization.
3600
+ raise ValueError(
3601
+ f"Unable to find an organization under entity {entity!r}. "
3602
+ ) from e
3603
+ else:
3604
+ raise ValueError(
3605
+ f"Unable to resolve an organization associated with the entity: {entity!r} "
3606
+ "that is initialized in the API or Run settings. This could be because "
3607
+ f"{entity!r} is a personal entity or the team entity doesn't exist. "
3608
+ "Please re-initialize the API or Run with a team entity using "
3609
+ "wandb.Api(overrides={'entity': '<my_team_entity>'}) "
3610
+ "or wandb.init(entity='<my_team_entity>') "
3611
+ ) from e
3612
+ else:
3613
+ return org_entity_name, org_name
3614
+
3516
3615
  def use_artifact(
3517
3616
  self,
3518
3617
  artifact_id: str,
@@ -3580,6 +3679,28 @@ class Api:
3580
3679
  return artifact
3581
3680
  return None
3582
3681
 
3682
+ # Fetch fields available in backend of Organization type
3683
+ def server_organization_type_introspection(self) -> List[str]:
3684
+ query_string = """
3685
+ query ProbeServerOrganization {
3686
+ OrganizationInfoType: __type(name:"Organization") {
3687
+ fields {
3688
+ name
3689
+ }
3690
+ }
3691
+ }
3692
+ """
3693
+
3694
+ if self.server_organization_type_fields_info is None:
3695
+ query = gql(query_string)
3696
+ res = self.gql(query)
3697
+ input_fields = res.get("OrganizationInfoType", {}).get("fields", [{}])
3698
+ self.server_organization_type_fields_info = [
3699
+ field["name"] for field in input_fields if "name" in field
3700
+ ]
3701
+
3702
+ return self.server_organization_type_fields_info
3703
+
3583
3704
  def create_artifact_type(
3584
3705
  self,
3585
3706
  artifact_type_name: str,
@@ -1455,16 +1455,29 @@ class SendManager:
1455
1455
  entity = link.portfolio_entity
1456
1456
  project = link.portfolio_project
1457
1457
  aliases = link.portfolio_aliases
1458
+ organization = link.portfolio_organization
1458
1459
  logger.debug(
1459
- f"link_artifact params - client_id={client_id}, server_id={server_id}, pfolio={portfolio_name}, entity={entity}, project={project}"
1460
+ f"link_artifact params - client_id={client_id}, server_id={server_id}, "
1461
+ f"portfolio_name={portfolio_name}, entity={entity}, project={project}, "
1462
+ f"organization={organization}"
1460
1463
  )
1461
1464
  if (client_id or server_id) and portfolio_name and entity and project:
1462
1465
  try:
1463
1466
  self._api.link_artifact(
1464
- client_id, server_id, portfolio_name, entity, project, aliases
1467
+ client_id,
1468
+ server_id,
1469
+ portfolio_name,
1470
+ entity,
1471
+ project,
1472
+ aliases,
1473
+ organization,
1465
1474
  )
1466
1475
  except Exception as e:
1467
- result.response.log_artifact_response.error_message = f'error linking artifact to "{entity}/{project}/{portfolio_name}"; error: {e}'
1476
+ org_or_entity = organization or entity
1477
+ result.response.log_artifact_response.error_message = (
1478
+ f"error linking artifact to "
1479
+ f'"{org_or_entity}/{project}/{portfolio_name}"; error: {e}'
1480
+ )
1468
1481
  logger.warning("Failed to link artifact to portfolio: %s", e)
1469
1482
  self._respond_result(result)
1470
1483
 
@@ -143,7 +143,7 @@ def _replace_refs_and_allofs(schema: dict, defs: Optional[dict]) -> dict:
143
143
  # Reference found, replace it with its definition
144
144
  def_key = schema.pop("$ref").split("#/$defs/")[1]
145
145
  # Also run recursive replacement in case a ref contains more refs
146
- ret = _replace_refs_and_allofs(defs.pop(def_key), defs)
146
+ ret = _replace_refs_and_allofs(defs[def_key], defs)
147
147
  for key, val in schema.items():
148
148
  if isinstance(val, dict):
149
149
  # Step into dicts recursively
wandb/sdk/lib/module.py CHANGED
@@ -19,6 +19,8 @@ def set_global(
19
19
  log_model=None,
20
20
  use_model=None,
21
21
  link_model=None,
22
+ watch=None,
23
+ unwatch=None,
22
24
  ):
23
25
  if run:
24
26
  wandb.run = run
@@ -48,6 +50,10 @@ def set_global(
48
50
  wandb.use_model = use_model
49
51
  if link_model:
50
52
  wandb.link_model = link_model
53
+ if watch:
54
+ wandb.watch = watch
55
+ if unwatch:
56
+ wandb.unwatch = unwatch
51
57
 
52
58
 
53
59
  def unset_globals():
@@ -55,6 +61,12 @@ def unset_globals():
55
61
  wandb.config = preinit.PreInitObject("wandb.config")
56
62
  wandb.summary = preinit.PreInitObject("wandb.summary")
57
63
  wandb.log = preinit.PreInitCallable("wandb.log", wandb.wandb_sdk.wandb_run.Run.log)
64
+ wandb.watch = preinit.PreInitCallable(
65
+ "wandb.watch", wandb.wandb_sdk.wandb_run.Run.watch
66
+ )
67
+ wandb.unwatch = preinit.PreInitCallable(
68
+ "wandb.unwatch", wandb.wandb_sdk.wandb_run.Run.unwatch
69
+ )
58
70
  wandb.save = preinit.PreInitCallable(
59
71
  "wandb.save", wandb.wandb_sdk.wandb_run.Run.save
60
72
  )