wandb 0.18.3__py3-none-macosx_11_0_arm64.whl → 0.18.5__py3-none-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (46) hide show
  1. wandb/__init__.py +16 -7
  2. wandb/__init__.pyi +96 -63
  3. wandb/analytics/sentry.py +91 -88
  4. wandb/apis/public/api.py +18 -4
  5. wandb/apis/public/runs.py +53 -2
  6. wandb/bin/gpu_stats +0 -0
  7. wandb/bin/wandb-core +0 -0
  8. wandb/cli/beta.py +178 -0
  9. wandb/cli/cli.py +5 -171
  10. wandb/data_types.py +3 -0
  11. wandb/env.py +74 -73
  12. wandb/errors/term.py +300 -43
  13. wandb/proto/v3/wandb_internal_pb2.py +263 -223
  14. wandb/proto/v3/wandb_server_pb2.py +57 -37
  15. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  16. wandb/proto/v4/wandb_internal_pb2.py +226 -218
  17. wandb/proto/v4/wandb_server_pb2.py +41 -37
  18. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  19. wandb/proto/v5/wandb_internal_pb2.py +226 -218
  20. wandb/proto/v5/wandb_server_pb2.py +41 -37
  21. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  22. wandb/sdk/__init__.py +3 -3
  23. wandb/sdk/artifacts/_validators.py +41 -8
  24. wandb/sdk/artifacts/artifact.py +32 -1
  25. wandb/sdk/artifacts/artifact_file_cache.py +1 -2
  26. wandb/sdk/data_types/_dtypes.py +7 -3
  27. wandb/sdk/data_types/video.py +15 -6
  28. wandb/sdk/interface/interface.py +2 -0
  29. wandb/sdk/internal/internal_api.py +126 -5
  30. wandb/sdk/internal/sender.py +16 -3
  31. wandb/sdk/launch/inputs/internal.py +1 -1
  32. wandb/sdk/lib/module.py +12 -0
  33. wandb/sdk/lib/printer.py +291 -105
  34. wandb/sdk/lib/progress.py +274 -0
  35. wandb/sdk/service/streams.py +21 -11
  36. wandb/sdk/wandb_init.py +58 -54
  37. wandb/sdk/wandb_run.py +380 -454
  38. wandb/sdk/wandb_settings.py +2 -0
  39. wandb/sdk/wandb_watch.py +17 -11
  40. wandb/util.py +6 -2
  41. {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/METADATA +4 -3
  42. {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/RECORD +45 -43
  43. wandb/bin/apple_gpu_stats +0 -0
  44. {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/WHEEL +0 -0
  45. {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/entry_points.txt +0 -0
  46. {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/licenses/LICENSE +0 -0
@@ -29,6 +29,11 @@ from typing import (
29
29
  Union,
30
30
  )
31
31
 
32
+ if sys.version_info >= (3, 8):
33
+ from typing import Literal
34
+ else:
35
+ from typing_extensions import Literal
36
+
32
37
  import click
33
38
  import requests
34
39
  import yaml
@@ -41,6 +46,7 @@ from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messa
41
46
  from wandb.errors import AuthenticationError, CommError, UnsupportedError, UsageError
42
47
  from wandb.integration.sagemaker import parse_sm_secrets
43
48
  from wandb.old.settings import Settings
49
+ from wandb.sdk.artifacts._validators import is_artifact_registry_project
44
50
  from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
45
51
  from wandb.sdk.lib.gql_request import GraphQLSession
46
52
  from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
@@ -306,6 +312,7 @@ class Api:
306
312
  self.server_use_artifact_input_info: Optional[List[str]] = None
307
313
  self.server_create_artifact_input_info: Optional[List[str]] = None
308
314
  self.server_artifact_fields_info: Optional[List[str]] = None
315
+ self.server_organization_type_fields_info: Optional[List[str]] = None
309
316
  self._max_cli_version: Optional[str] = None
310
317
  self._server_settings_type: Optional[List[str]] = None
311
318
  self.fail_run_queue_item_input_info: Optional[List[str]] = None
@@ -3366,7 +3373,7 @@ class Api:
3366
3373
  project: Optional[str] = None,
3367
3374
  description: Optional[str] = None,
3368
3375
  force: bool = True,
3369
- progress: Union[TextIO, bool] = False,
3376
+ progress: Union[TextIO, Literal[False]] = False,
3370
3377
  ) -> "List[Optional[requests.Response]]":
3371
3378
  """Uploads multiple files to W&B.
3372
3379
 
@@ -3378,7 +3385,7 @@ class Api:
3378
3385
  description (str, optional): The description of the changes
3379
3386
  force (bool, optional): Whether to prevent push if git has uncommitted changes
3380
3387
  progress (callable, or stream): If callable, will be called with (chunk_bytes,
3381
- total_bytes) as argument else if True, renders a progress bar to stream.
3388
+ total_bytes) as argument. If TextIO, renders a progress bar to it.
3382
3389
 
3383
3390
  Returns:
3384
3391
  A list of `requests.Response` objects
@@ -3439,8 +3446,8 @@ class Api:
3439
3446
  )
3440
3447
  else:
3441
3448
  length = os.fstat(open_file.fileno()).st_size
3442
- with click.progressbar(
3443
- file=progress, # type: ignore
3449
+ with click.progressbar( # type: ignore
3450
+ file=progress,
3444
3451
  length=length,
3445
3452
  label=f"Uploading file: {file_name}",
3446
3453
  fill_char=click.style("&", fg="green"),
@@ -3464,6 +3471,7 @@ class Api:
3464
3471
  entity: str,
3465
3472
  project: str,
3466
3473
  aliases: Sequence[str],
3474
+ organization: str,
3467
3475
  ) -> Dict[str, Any]:
3468
3476
  template = """
3469
3477
  mutation LinkArtifact(
@@ -3485,6 +3493,14 @@ class Api:
3485
3493
  }
3486
3494
  """
3487
3495
 
3496
+ org_entity = ""
3497
+ if is_artifact_registry_project(project):
3498
+ try:
3499
+ org_entity = self._resolve_org_entity_name(entity, organization)
3500
+ except ValueError as e:
3501
+ wandb.termerror(str(e))
3502
+ raise
3503
+
3488
3504
  def replace(a: str, b: str) -> None:
3489
3505
  nonlocal template
3490
3506
  template = template.replace(a, b)
@@ -3500,7 +3516,7 @@ class Api:
3500
3516
  "clientID": client_id,
3501
3517
  "artifactID": server_id,
3502
3518
  "artifactPortfolioName": portfolio_name,
3503
- "entityName": entity,
3519
+ "entityName": org_entity or entity,
3504
3520
  "projectName": project,
3505
3521
  "aliases": [
3506
3522
  {"alias": alias, "artifactCollectionName": portfolio_name}
@@ -3513,6 +3529,89 @@ class Api:
3513
3529
  link_artifact: Dict[str, Any] = response["linkArtifact"]
3514
3530
  return link_artifact
3515
3531
 
3532
+ def _resolve_org_entity_name(self, entity: str, organization: str = "") -> str:
3533
+ # resolveOrgEntityName fetches the portfolio's org entity's name.
3534
+ #
3535
+ # The organization parameter may be empty, an org's display name, or an org entity name.
3536
+ #
3537
+ # If the server doesn't support fetching the org name of a portfolio, then this returns
3538
+ # the organization parameter, or an error if it is empty. Otherwise, this returns the
3539
+ # fetched value after validating that the given organization, if not empty, matches
3540
+ # either the org's display or entity name.
3541
+ org_fields = self.server_organization_type_introspection()
3542
+ can_fetch_org_entity = "orgEntity" in org_fields
3543
+ if not organization and not can_fetch_org_entity:
3544
+ raise ValueError(
3545
+ "Fetching Registry artifacts without inputting an organization "
3546
+ "is unavailable for your server version. "
3547
+ "Please upgrade your server to 0.50.0 or later."
3548
+ )
3549
+ if not can_fetch_org_entity:
3550
+ # Server doesn't support fetching org entity to validate,
3551
+ # assume org entity is correctly inputted
3552
+ return organization
3553
+
3554
+ org_entity, org_name = self.fetch_org_entity_from_entity(entity)
3555
+ if organization:
3556
+ if organization != org_name and organization != org_entity:
3557
+ raise ValueError(
3558
+ f"Artifact belongs to the organization {org_name!r} "
3559
+ f"and cannot be linked/fetched with {organization!r}. "
3560
+ "Please update the target path with the correct organization name."
3561
+ )
3562
+ wandb.termwarn(
3563
+ "Registries can be linked/fetched using a shorthand form without specifying the organization name. "
3564
+ "Try using shorthand path format: <my_registry_name>/<artifact_name>"
3565
+ )
3566
+ return org_entity
3567
+
3568
+ def fetch_org_entity_from_entity(self, entity: str) -> Tuple[str, str]:
3569
+ query = gql(
3570
+ """
3571
+ query FetchOrgEntityFromEntity(
3572
+ $entityName: String!,
3573
+ ) {
3574
+ entity(name: $entityName) {
3575
+ isTeam
3576
+ organization {
3577
+ name
3578
+ orgEntity {
3579
+ name
3580
+ }
3581
+ }
3582
+ }
3583
+ }
3584
+ """
3585
+ )
3586
+ response = self.gql(
3587
+ query,
3588
+ variable_values={
3589
+ "entityName": entity,
3590
+ },
3591
+ )
3592
+ try:
3593
+ is_team = response["entity"].get("isTeam", False)
3594
+ org = response["entity"]["organization"]
3595
+ org_name = org["name"] or ""
3596
+ org_entity_name = org["orgEntity"]["name"] or ""
3597
+ except (LookupError, TypeError) as e:
3598
+ if is_team:
3599
+ # This path should pretty much never be reached as all team entities have an organization.
3600
+ raise ValueError(
3601
+ f"Unable to find an organization under entity {entity!r}. "
3602
+ ) from e
3603
+ else:
3604
+ raise ValueError(
3605
+ f"Unable to resolve an organization associated with the entity: {entity!r} "
3606
+ "that is initialized in the API or Run settings. This could be because "
3607
+ f"{entity!r} is a personal entity or the team entity doesn't exist. "
3608
+ "Please re-initialize the API or Run with a team entity using "
3609
+ "wandb.Api(overrides={'entity': '<my_team_entity>'}) "
3610
+ "or wandb.init(entity='<my_team_entity>') "
3611
+ ) from e
3612
+ else:
3613
+ return org_entity_name, org_name
3614
+
3516
3615
  def use_artifact(
3517
3616
  self,
3518
3617
  artifact_id: str,
@@ -3580,6 +3679,28 @@ class Api:
3580
3679
  return artifact
3581
3680
  return None
3582
3681
 
3682
+ # Fetch fields available in backend of Organization type
3683
+ def server_organization_type_introspection(self) -> List[str]:
3684
+ query_string = """
3685
+ query ProbeServerOrganization {
3686
+ OrganizationInfoType: __type(name:"Organization") {
3687
+ fields {
3688
+ name
3689
+ }
3690
+ }
3691
+ }
3692
+ """
3693
+
3694
+ if self.server_organization_type_fields_info is None:
3695
+ query = gql(query_string)
3696
+ res = self.gql(query)
3697
+ input_fields = res.get("OrganizationInfoType", {}).get("fields", [{}])
3698
+ self.server_organization_type_fields_info = [
3699
+ field["name"] for field in input_fields if "name" in field
3700
+ ]
3701
+
3702
+ return self.server_organization_type_fields_info
3703
+
3583
3704
  def create_artifact_type(
3584
3705
  self,
3585
3706
  artifact_type_name: str,
@@ -1455,16 +1455,29 @@ class SendManager:
1455
1455
  entity = link.portfolio_entity
1456
1456
  project = link.portfolio_project
1457
1457
  aliases = link.portfolio_aliases
1458
+ organization = link.portfolio_organization
1458
1459
  logger.debug(
1459
- f"link_artifact params - client_id={client_id}, server_id={server_id}, pfolio={portfolio_name}, entity={entity}, project={project}"
1460
+ f"link_artifact params - client_id={client_id}, server_id={server_id}, "
1461
+ f"portfolio_name={portfolio_name}, entity={entity}, project={project}, "
1462
+ f"organization={organization}"
1460
1463
  )
1461
1464
  if (client_id or server_id) and portfolio_name and entity and project:
1462
1465
  try:
1463
1466
  self._api.link_artifact(
1464
- client_id, server_id, portfolio_name, entity, project, aliases
1467
+ client_id,
1468
+ server_id,
1469
+ portfolio_name,
1470
+ entity,
1471
+ project,
1472
+ aliases,
1473
+ organization,
1465
1474
  )
1466
1475
  except Exception as e:
1467
- result.response.log_artifact_response.error_message = f'error linking artifact to "{entity}/{project}/{portfolio_name}"; error: {e}'
1476
+ org_or_entity = organization or entity
1477
+ result.response.log_artifact_response.error_message = (
1478
+ f"error linking artifact to "
1479
+ f'"{org_or_entity}/{project}/{portfolio_name}"; error: {e}'
1480
+ )
1468
1481
  logger.warning("Failed to link artifact to portfolio: %s", e)
1469
1482
  self._respond_result(result)
1470
1483
 
@@ -143,7 +143,7 @@ def _replace_refs_and_allofs(schema: dict, defs: Optional[dict]) -> dict:
143
143
  # Reference found, replace it with its definition
144
144
  def_key = schema.pop("$ref").split("#/$defs/")[1]
145
145
  # Also run recursive replacement in case a ref contains more refs
146
- ret = _replace_refs_and_allofs(defs.pop(def_key), defs)
146
+ ret = _replace_refs_and_allofs(defs[def_key], defs)
147
147
  for key, val in schema.items():
148
148
  if isinstance(val, dict):
149
149
  # Step into dicts recursively
wandb/sdk/lib/module.py CHANGED
@@ -19,6 +19,8 @@ def set_global(
19
19
  log_model=None,
20
20
  use_model=None,
21
21
  link_model=None,
22
+ watch=None,
23
+ unwatch=None,
22
24
  ):
23
25
  if run:
24
26
  wandb.run = run
@@ -48,6 +50,10 @@ def set_global(
48
50
  wandb.use_model = use_model
49
51
  if link_model:
50
52
  wandb.link_model = link_model
53
+ if watch:
54
+ wandb.watch = watch
55
+ if unwatch:
56
+ wandb.unwatch = unwatch
51
57
 
52
58
 
53
59
  def unset_globals():
@@ -55,6 +61,12 @@ def unset_globals():
55
61
  wandb.config = preinit.PreInitObject("wandb.config")
56
62
  wandb.summary = preinit.PreInitObject("wandb.summary")
57
63
  wandb.log = preinit.PreInitCallable("wandb.log", wandb.wandb_sdk.wandb_run.Run.log)
64
+ wandb.watch = preinit.PreInitCallable(
65
+ "wandb.watch", wandb.wandb_sdk.wandb_run.Run.watch
66
+ )
67
+ wandb.unwatch = preinit.PreInitCallable(
68
+ "wandb.unwatch", wandb.wandb_sdk.wandb_run.Run.unwatch
69
+ )
58
70
  wandb.save = preinit.PreInitCallable(
59
71
  "wandb.save", wandb.wandb_sdk.wandb_run.Run.save
60
72
  )