wandb 0.18.3__py3-none-any.whl → 0.18.4__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (45)
  1. wandb/__init__.py +16 -7
  2. wandb/__init__.pyi +96 -63
  3. wandb/analytics/sentry.py +91 -88
  4. wandb/apis/public/api.py +18 -4
  5. wandb/apis/public/runs.py +53 -2
  6. wandb/bin/gpu_stats +0 -0
  7. wandb/cli/beta.py +178 -0
  8. wandb/cli/cli.py +5 -171
  9. wandb/data_types.py +3 -0
  10. wandb/env.py +74 -73
  11. wandb/errors/term.py +300 -43
  12. wandb/proto/v3/wandb_internal_pb2.py +263 -223
  13. wandb/proto/v3/wandb_server_pb2.py +57 -37
  14. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  15. wandb/proto/v4/wandb_internal_pb2.py +226 -218
  16. wandb/proto/v4/wandb_server_pb2.py +41 -37
  17. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  18. wandb/proto/v5/wandb_internal_pb2.py +226 -218
  19. wandb/proto/v5/wandb_server_pb2.py +41 -37
  20. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  21. wandb/sdk/__init__.py +3 -3
  22. wandb/sdk/artifacts/_validators.py +41 -8
  23. wandb/sdk/artifacts/artifact.py +32 -1
  24. wandb/sdk/artifacts/artifact_file_cache.py +1 -2
  25. wandb/sdk/data_types/_dtypes.py +7 -3
  26. wandb/sdk/data_types/video.py +15 -6
  27. wandb/sdk/interface/interface.py +2 -0
  28. wandb/sdk/internal/internal_api.py +122 -5
  29. wandb/sdk/internal/sender.py +16 -3
  30. wandb/sdk/launch/inputs/internal.py +1 -1
  31. wandb/sdk/lib/module.py +12 -0
  32. wandb/sdk/lib/printer.py +291 -105
  33. wandb/sdk/lib/progress.py +274 -0
  34. wandb/sdk/service/streams.py +21 -11
  35. wandb/sdk/wandb_init.py +58 -54
  36. wandb/sdk/wandb_run.py +380 -454
  37. wandb/sdk/wandb_settings.py +2 -0
  38. wandb/sdk/wandb_watch.py +17 -11
  39. wandb/util.py +6 -2
  40. {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/METADATA +4 -3
  41. {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/RECORD +44 -42
  42. wandb/bin/nvidia_gpu_stats +0 -0
  43. {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/WHEEL +0 -0
  44. {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/entry_points.txt +0 -0
  45. {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/licenses/LICENSE +0 -0
@@ -20,6 +20,7 @@ from typing import (
20
20
  Dict,
21
21
  Iterable,
22
22
  List,
23
+ Literal,
23
24
  Mapping,
24
25
  MutableMapping,
25
26
  Optional,
@@ -41,6 +42,7 @@ from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messa
41
42
  from wandb.errors import AuthenticationError, CommError, UnsupportedError, UsageError
42
43
  from wandb.integration.sagemaker import parse_sm_secrets
43
44
  from wandb.old.settings import Settings
45
+ from wandb.sdk.artifacts._validators import is_artifact_registry_project
44
46
  from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
45
47
  from wandb.sdk.lib.gql_request import GraphQLSession
46
48
  from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
@@ -306,6 +308,7 @@ class Api:
306
308
  self.server_use_artifact_input_info: Optional[List[str]] = None
307
309
  self.server_create_artifact_input_info: Optional[List[str]] = None
308
310
  self.server_artifact_fields_info: Optional[List[str]] = None
311
+ self.server_organization_type_fields_info: Optional[List[str]] = None
309
312
  self._max_cli_version: Optional[str] = None
310
313
  self._server_settings_type: Optional[List[str]] = None
311
314
  self.fail_run_queue_item_input_info: Optional[List[str]] = None
@@ -3366,7 +3369,7 @@ class Api:
3366
3369
  project: Optional[str] = None,
3367
3370
  description: Optional[str] = None,
3368
3371
  force: bool = True,
3369
- progress: Union[TextIO, bool] = False,
3372
+ progress: Union[TextIO, Literal[False]] = False,
3370
3373
  ) -> "List[Optional[requests.Response]]":
3371
3374
  """Uploads multiple files to W&B.
3372
3375
 
@@ -3378,7 +3381,7 @@ class Api:
3378
3381
  description (str, optional): The description of the changes
3379
3382
  force (bool, optional): Whether to prevent push if git has uncommitted changes
3380
3383
  progress (callable, or stream): If callable, will be called with (chunk_bytes,
3381
- total_bytes) as argument else if True, renders a progress bar to stream.
3384
+ total_bytes) as argument. If TextIO, renders a progress bar to it.
3382
3385
 
3383
3386
  Returns:
3384
3387
  A list of `requests.Response` objects
@@ -3439,8 +3442,8 @@ class Api:
3439
3442
  )
3440
3443
  else:
3441
3444
  length = os.fstat(open_file.fileno()).st_size
3442
- with click.progressbar(
3443
- file=progress, # type: ignore
3445
+ with click.progressbar( # type: ignore
3446
+ file=progress,
3444
3447
  length=length,
3445
3448
  label=f"Uploading file: {file_name}",
3446
3449
  fill_char=click.style("&", fg="green"),
@@ -3464,6 +3467,7 @@ class Api:
3464
3467
  entity: str,
3465
3468
  project: str,
3466
3469
  aliases: Sequence[str],
3470
+ organization: str,
3467
3471
  ) -> Dict[str, Any]:
3468
3472
  template = """
3469
3473
  mutation LinkArtifact(
@@ -3485,6 +3489,14 @@ class Api:
3485
3489
  }
3486
3490
  """
3487
3491
 
3492
+ org_entity = ""
3493
+ if is_artifact_registry_project(project):
3494
+ try:
3495
+ org_entity = self._resolve_org_entity_name(entity, organization)
3496
+ except ValueError as e:
3497
+ wandb.termerror(str(e))
3498
+ raise
3499
+
3488
3500
  def replace(a: str, b: str) -> None:
3489
3501
  nonlocal template
3490
3502
  template = template.replace(a, b)
@@ -3500,7 +3512,7 @@ class Api:
3500
3512
  "clientID": client_id,
3501
3513
  "artifactID": server_id,
3502
3514
  "artifactPortfolioName": portfolio_name,
3503
- "entityName": entity,
3515
+ "entityName": org_entity or entity,
3504
3516
  "projectName": project,
3505
3517
  "aliases": [
3506
3518
  {"alias": alias, "artifactCollectionName": portfolio_name}
@@ -3513,6 +3525,89 @@ class Api:
3513
3525
  link_artifact: Dict[str, Any] = response["linkArtifact"]
3514
3526
  return link_artifact
3515
3527
 
3528
+ def _resolve_org_entity_name(self, entity: str, organization: str = "") -> str:
3529
+ # resolveOrgEntityName fetches the portfolio's org entity's name.
3530
+ #
3531
+ # The organization parameter may be empty, an org's display name, or an org entity name.
3532
+ #
3533
+ # If the server doesn't support fetching the org name of a portfolio, then this returns
3534
+ # the organization parameter, or an error if it is empty. Otherwise, this returns the
3535
+ # fetched value after validating that the given organization, if not empty, matches
3536
+ # either the org's display or entity name.
3537
+ org_fields = self.server_organization_type_introspection()
3538
+ can_fetch_org_entity = "orgEntity" in org_fields
3539
+ if not organization and not can_fetch_org_entity:
3540
+ raise ValueError(
3541
+ "Fetching Registry artifacts without inputting an organization "
3542
+ "is unavailable for your server version. "
3543
+ "Please upgrade your server to 0.50.0 or later."
3544
+ )
3545
+ if not can_fetch_org_entity:
3546
+ # Server doesn't support fetching org entity to validate,
3547
+ # assume org entity is correctly inputted
3548
+ return organization
3549
+
3550
+ org_entity, org_name = self.fetch_org_entity_from_entity(entity)
3551
+ if organization:
3552
+ if organization != org_name and organization != org_entity:
3553
+ raise ValueError(
3554
+ f"Artifact belongs to the organization {org_name!r} "
3555
+ f"and cannot be linked/fetched with {organization!r}. "
3556
+ "Please update the target path with the correct organization name."
3557
+ )
3558
+ wandb.termwarn(
3559
+ "Registries can be linked/fetched using a shorthand form without specifying the organization name. "
3560
+ "Try using shorthand path format: <my_registry_name>/<artifact_name>"
3561
+ )
3562
+ return org_entity
3563
+
3564
def fetch_org_entity_from_entity(self, entity: str) -> Tuple[str, str]:
    """Look up the organization that owns ``entity``.

    Args:
        entity: Name of the entity to look up (expected to be a team entity).

    Returns:
        A ``(org_entity_name, org_name)`` tuple. An element is ``""`` when
        the server returns a null name for it.

    Raises:
        ValueError: If the entity is missing or null in the response, has no
            organization, or its organization has no org entity.
    """
    query = gql(
        """
        query FetchOrgEntityFromEntity(
            $entityName: String!,
        ) {
            entity(name: $entityName) {
                isTeam
                organization {
                    name
                    orgEntity {
                        name
                    }
                }
            }
        }
        """
    )
    response = self.gql(
        query,
        variable_values={
            "entityName": entity,
        },
    )

    # Resolve ``entity``/``isTeam`` before the try block. A null ``entity``
    # would otherwise raise AttributeError on ``None.get``, which the handler
    # does not catch. A missing key would leave ``is_team`` unbound when the
    # handler reads it.
    entity_info = response.get("entity") or {}
    is_team = entity_info.get("isTeam", False)
    try:
        org = entity_info["organization"]
        org_name = org["name"] or ""
        org_entity_name = org["orgEntity"]["name"] or ""
    except (LookupError, TypeError) as e:
        if is_team:
            # This path should pretty much never be reached as all team entities have an organization.
            raise ValueError(
                f"Unable to find an organization under entity {entity!r}. "
            ) from e
        raise ValueError(
            f"Unable to resolve an organization associated with the entity: {entity!r} "
            "that is initialized in the API or Run settings. This could be because "
            f"{entity!r} is a personal entity or the team entity doesn't exist. "
            "Please re-initialize the API or Run with a team entity using "
            "wandb.Api(overrides={'entity': '<my_team_entity>'}) "
            "or wandb.init(entity='<my_team_entity>') "
        ) from e
    return org_entity_name, org_name
3610
+
3516
3611
  def use_artifact(
3517
3612
  self,
3518
3613
  artifact_id: str,
@@ -3580,6 +3675,28 @@ class Api:
3580
3675
  return artifact
3581
3676
  return None
3582
3677
 
3678
# Fetch fields available in backend of Organization type
def server_organization_type_introspection(self) -> List[str]:
    """Return the field names of the server's GraphQL ``Organization`` type.

    The result is cached on ``self.server_organization_type_fields_info``.
    After a successful query, later calls return the cached list without
    contacting the server.

    Returns:
        The field names, or an empty list if the server does not describe
        an ``Organization`` type or reports no fields for it.
    """
    if self.server_organization_type_fields_info is None:
        query_string = """
            query ProbeServerOrganization {
                OrganizationInfoType: __type(name:"Organization") {
                    fields {
                        name
                    }
                }
            }
        """
        res = self.gql(gql(query_string))
        # GraphQL returns null for ``__type`` on unknown types and for
        # ``fields`` on non-object types. ``dict.get`` defaults do not cover
        # a key that is present with a null value, hence the ``or`` fallbacks.
        type_info = res.get("OrganizationInfoType") or {}
        fields = type_info.get("fields") or []
        self.server_organization_type_fields_info = [
            field["name"] for field in fields if "name" in field
        ]

    return self.server_organization_type_fields_info
3699
+
3583
3700
  def create_artifact_type(
3584
3701
  self,
3585
3702
  artifact_type_name: str,
@@ -1455,16 +1455,29 @@ class SendManager:
1455
1455
  entity = link.portfolio_entity
1456
1456
  project = link.portfolio_project
1457
1457
  aliases = link.portfolio_aliases
1458
+ organization = link.portfolio_organization
1458
1459
  logger.debug(
1459
- f"link_artifact params - client_id={client_id}, server_id={server_id}, pfolio={portfolio_name}, entity={entity}, project={project}"
1460
+ f"link_artifact params - client_id={client_id}, server_id={server_id}, "
1461
+ f"portfolio_name={portfolio_name}, entity={entity}, project={project}, "
1462
+ f"organization={organization}"
1460
1463
  )
1461
1464
  if (client_id or server_id) and portfolio_name and entity and project:
1462
1465
  try:
1463
1466
  self._api.link_artifact(
1464
- client_id, server_id, portfolio_name, entity, project, aliases
1467
+ client_id,
1468
+ server_id,
1469
+ portfolio_name,
1470
+ entity,
1471
+ project,
1472
+ aliases,
1473
+ organization,
1465
1474
  )
1466
1475
  except Exception as e:
1467
- result.response.log_artifact_response.error_message = f'error linking artifact to "{entity}/{project}/{portfolio_name}"; error: {e}'
1476
+ org_or_entity = organization or entity
1477
+ result.response.log_artifact_response.error_message = (
1478
+ f"error linking artifact to "
1479
+ f'"{org_or_entity}/{project}/{portfolio_name}"; error: {e}'
1480
+ )
1468
1481
  logger.warning("Failed to link artifact to portfolio: %s", e)
1469
1482
  self._respond_result(result)
1470
1483
 
@@ -143,7 +143,7 @@ def _replace_refs_and_allofs(schema: dict, defs: Optional[dict]) -> dict:
143
143
  # Reference found, replace it with its definition
144
144
  def_key = schema.pop("$ref").split("#/$defs/")[1]
145
145
  # Also run recursive replacement in case a ref contains more refs
146
- ret = _replace_refs_and_allofs(defs.pop(def_key), defs)
146
+ ret = _replace_refs_and_allofs(defs[def_key], defs)
147
147
  for key, val in schema.items():
148
148
  if isinstance(val, dict):
149
149
  # Step into dicts recursively
wandb/sdk/lib/module.py CHANGED
@@ -19,6 +19,8 @@ def set_global(
19
19
  log_model=None,
20
20
  use_model=None,
21
21
  link_model=None,
22
+ watch=None,
23
+ unwatch=None,
22
24
  ):
23
25
  if run:
24
26
  wandb.run = run
@@ -48,6 +50,10 @@ def set_global(
48
50
  wandb.use_model = use_model
49
51
  if link_model:
50
52
  wandb.link_model = link_model
53
+ if watch:
54
+ wandb.watch = watch
55
+ if unwatch:
56
+ wandb.unwatch = unwatch
51
57
 
52
58
 
53
59
  def unset_globals():
@@ -55,6 +61,12 @@ def unset_globals():
55
61
  wandb.config = preinit.PreInitObject("wandb.config")
56
62
  wandb.summary = preinit.PreInitObject("wandb.summary")
57
63
  wandb.log = preinit.PreInitCallable("wandb.log", wandb.wandb_sdk.wandb_run.Run.log)
64
+ wandb.watch = preinit.PreInitCallable(
65
+ "wandb.watch", wandb.wandb_sdk.wandb_run.Run.watch
66
+ )
67
+ wandb.unwatch = preinit.PreInitCallable(
68
+ "wandb.unwatch", wandb.wandb_sdk.wandb_run.Run.unwatch
69
+ )
58
70
  wandb.save = preinit.PreInitCallable(
59
71
  "wandb.save", wandb.wandb_sdk.wandb_run.Run.save
60
72
  )