wandb 0.18.3__py3-none-win32.whl → 0.18.4__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. wandb/__init__.py +16 -7
  2. wandb/__init__.pyi +96 -63
  3. wandb/analytics/sentry.py +91 -88
  4. wandb/apis/public/api.py +18 -4
  5. wandb/apis/public/runs.py +53 -2
  6. wandb/bin/gpu_stats.exe +0 -0
  7. wandb/bin/wandb-core +0 -0
  8. wandb/cli/beta.py +178 -0
  9. wandb/cli/cli.py +5 -171
  10. wandb/data_types.py +3 -0
  11. wandb/env.py +74 -73
  12. wandb/errors/term.py +300 -43
  13. wandb/proto/v3/wandb_internal_pb2.py +263 -223
  14. wandb/proto/v3/wandb_server_pb2.py +57 -37
  15. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  16. wandb/proto/v4/wandb_internal_pb2.py +226 -218
  17. wandb/proto/v4/wandb_server_pb2.py +41 -37
  18. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  19. wandb/proto/v5/wandb_internal_pb2.py +226 -218
  20. wandb/proto/v5/wandb_server_pb2.py +41 -37
  21. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  22. wandb/sdk/__init__.py +3 -3
  23. wandb/sdk/artifacts/_validators.py +41 -8
  24. wandb/sdk/artifacts/artifact.py +32 -1
  25. wandb/sdk/artifacts/artifact_file_cache.py +1 -2
  26. wandb/sdk/data_types/_dtypes.py +7 -3
  27. wandb/sdk/data_types/video.py +15 -6
  28. wandb/sdk/interface/interface.py +2 -0
  29. wandb/sdk/internal/internal_api.py +122 -5
  30. wandb/sdk/internal/sender.py +16 -3
  31. wandb/sdk/launch/inputs/internal.py +1 -1
  32. wandb/sdk/lib/module.py +12 -0
  33. wandb/sdk/lib/printer.py +291 -105
  34. wandb/sdk/lib/progress.py +274 -0
  35. wandb/sdk/service/streams.py +21 -11
  36. wandb/sdk/wandb_init.py +58 -54
  37. wandb/sdk/wandb_run.py +380 -454
  38. wandb/sdk/wandb_settings.py +2 -0
  39. wandb/sdk/wandb_watch.py +17 -11
  40. wandb/util.py +6 -2
  41. {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/METADATA +4 -3
  42. {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/RECORD +45 -43
  43. wandb/bin/nvidia_gpu_stats.exe +0 -0
  44. {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/WHEEL +0 -0
  45. {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/entry_points.txt +0 -0
  46. {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/licenses/LICENSE +0 -0
@@ -20,6 +20,7 @@ from typing import (
20
20
  Dict,
21
21
  Iterable,
22
22
  List,
23
+ Literal,
23
24
  Mapping,
24
25
  MutableMapping,
25
26
  Optional,
@@ -41,6 +42,7 @@ from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messa
41
42
  from wandb.errors import AuthenticationError, CommError, UnsupportedError, UsageError
42
43
  from wandb.integration.sagemaker import parse_sm_secrets
43
44
  from wandb.old.settings import Settings
45
+ from wandb.sdk.artifacts._validators import is_artifact_registry_project
44
46
  from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
45
47
  from wandb.sdk.lib.gql_request import GraphQLSession
46
48
  from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
@@ -306,6 +308,7 @@ class Api:
306
308
  self.server_use_artifact_input_info: Optional[List[str]] = None
307
309
  self.server_create_artifact_input_info: Optional[List[str]] = None
308
310
  self.server_artifact_fields_info: Optional[List[str]] = None
311
+ self.server_organization_type_fields_info: Optional[List[str]] = None
309
312
  self._max_cli_version: Optional[str] = None
310
313
  self._server_settings_type: Optional[List[str]] = None
311
314
  self.fail_run_queue_item_input_info: Optional[List[str]] = None
@@ -3366,7 +3369,7 @@ class Api:
3366
3369
  project: Optional[str] = None,
3367
3370
  description: Optional[str] = None,
3368
3371
  force: bool = True,
3369
- progress: Union[TextIO, bool] = False,
3372
+ progress: Union[TextIO, Literal[False]] = False,
3370
3373
  ) -> "List[Optional[requests.Response]]":
3371
3374
  """Uploads multiple files to W&B.
3372
3375
 
@@ -3378,7 +3381,7 @@ class Api:
3378
3381
  description (str, optional): The description of the changes
3379
3382
  force (bool, optional): Whether to prevent push if git has uncommitted changes
3380
3383
  progress (callable, or stream): If callable, will be called with (chunk_bytes,
3381
- total_bytes) as argument else if True, renders a progress bar to stream.
3384
+ total_bytes) as argument. If TextIO, renders a progress bar to it.
3382
3385
 
3383
3386
  Returns:
3384
3387
  A list of `requests.Response` objects
@@ -3439,8 +3442,8 @@ class Api:
3439
3442
  )
3440
3443
  else:
3441
3444
  length = os.fstat(open_file.fileno()).st_size
3442
- with click.progressbar(
3443
- file=progress, # type: ignore
3445
+ with click.progressbar( # type: ignore
3446
+ file=progress,
3444
3447
  length=length,
3445
3448
  label=f"Uploading file: {file_name}",
3446
3449
  fill_char=click.style("&", fg="green"),
@@ -3464,6 +3467,7 @@ class Api:
3464
3467
  entity: str,
3465
3468
  project: str,
3466
3469
  aliases: Sequence[str],
3470
+ organization: str,
3467
3471
  ) -> Dict[str, Any]:
3468
3472
  template = """
3469
3473
  mutation LinkArtifact(
@@ -3485,6 +3489,14 @@ class Api:
3485
3489
  }
3486
3490
  """
3487
3491
 
3492
+ org_entity = ""
3493
+ if is_artifact_registry_project(project):
3494
+ try:
3495
+ org_entity = self._resolve_org_entity_name(entity, organization)
3496
+ except ValueError as e:
3497
+ wandb.termerror(str(e))
3498
+ raise
3499
+
3488
3500
  def replace(a: str, b: str) -> None:
3489
3501
  nonlocal template
3490
3502
  template = template.replace(a, b)
@@ -3500,7 +3512,7 @@ class Api:
3500
3512
  "clientID": client_id,
3501
3513
  "artifactID": server_id,
3502
3514
  "artifactPortfolioName": portfolio_name,
3503
- "entityName": entity,
3515
+ "entityName": org_entity or entity,
3504
3516
  "projectName": project,
3505
3517
  "aliases": [
3506
3518
  {"alias": alias, "artifactCollectionName": portfolio_name}
@@ -3513,6 +3525,89 @@ class Api:
3513
3525
  link_artifact: Dict[str, Any] = response["linkArtifact"]
3514
3526
  return link_artifact
3515
3527
 
3528
+ def _resolve_org_entity_name(self, entity: str, organization: str = "") -> str:
3529
+ # resolveOrgEntityName fetches the portfolio's org entity's name.
3530
+ #
3531
+ # The organization parameter may be empty, an org's display name, or an org entity name.
3532
+ #
3533
+ # If the server doesn't support fetching the org name of a portfolio, then this returns
3534
+ # the organization parameter, or an error if it is empty. Otherwise, this returns the
3535
+ # fetched value after validating that the given organization, if not empty, matches
3536
+ # either the org's display or entity name.
3537
+ org_fields = self.server_organization_type_introspection()
3538
+ can_fetch_org_entity = "orgEntity" in org_fields
3539
+ if not organization and not can_fetch_org_entity:
3540
+ raise ValueError(
3541
+ "Fetching Registry artifacts without inputting an organization "
3542
+ "is unavailable for your server version. "
3543
+ "Please upgrade your server to 0.50.0 or later."
3544
+ )
3545
+ if not can_fetch_org_entity:
3546
+ # Server doesn't support fetching org entity to validate,
3547
+ # assume org entity is correctly inputted
3548
+ return organization
3549
+
3550
+ org_entity, org_name = self.fetch_org_entity_from_entity(entity)
3551
+ if organization:
3552
+ if organization != org_name and organization != org_entity:
3553
+ raise ValueError(
3554
+ f"Artifact belongs to the organization {org_name!r} "
3555
+ f"and cannot be linked/fetched with {organization!r}. "
3556
+ "Please update the target path with the correct organization name."
3557
+ )
3558
+ wandb.termwarn(
3559
+ "Registries can be linked/fetched using a shorthand form without specifying the organization name. "
3560
+ "Try using shorthand path format: <my_registry_name>/<artifact_name>"
3561
+ )
3562
+ return org_entity
3563
+
3564
+ def fetch_org_entity_from_entity(self, entity: str) -> Tuple[str, str]:
3565
+ query = gql(
3566
+ """
3567
+ query FetchOrgEntityFromEntity(
3568
+ $entityName: String!,
3569
+ ) {
3570
+ entity(name: $entityName) {
3571
+ isTeam
3572
+ organization {
3573
+ name
3574
+ orgEntity {
3575
+ name
3576
+ }
3577
+ }
3578
+ }
3579
+ }
3580
+ """
3581
+ )
3582
+ response = self.gql(
3583
+ query,
3584
+ variable_values={
3585
+ "entityName": entity,
3586
+ },
3587
+ )
3588
+ try:
3589
+ is_team = response["entity"].get("isTeam", False)
3590
+ org = response["entity"]["organization"]
3591
+ org_name = org["name"] or ""
3592
+ org_entity_name = org["orgEntity"]["name"] or ""
3593
+ except (LookupError, TypeError) as e:
3594
+ if is_team:
3595
+ # This path should pretty much never be reached as all team entities have an organization.
3596
+ raise ValueError(
3597
+ f"Unable to find an organization under entity {entity!r}. "
3598
+ ) from e
3599
+ else:
3600
+ raise ValueError(
3601
+ f"Unable to resolve an organization associated with the entity: {entity!r} "
3602
+ "that is initialized in the API or Run settings. This could be because "
3603
+ f"{entity!r} is a personal entity or the team entity doesn't exist. "
3604
+ "Please re-initialize the API or Run with a team entity using "
3605
+ "wandb.Api(overrides={'entity': '<my_team_entity>'}) "
3606
+ "or wandb.init(entity='<my_team_entity>') "
3607
+ ) from e
3608
+ else:
3609
+ return org_entity_name, org_name
3610
+
3516
3611
  def use_artifact(
3517
3612
  self,
3518
3613
  artifact_id: str,
@@ -3580,6 +3675,28 @@ class Api:
3580
3675
  return artifact
3581
3676
  return None
3582
3677
 
3678
+ # Fetch fields available in backend of Organization type
3679
+ def server_organization_type_introspection(self) -> List[str]:
3680
+ query_string = """
3681
+ query ProbeServerOrganization {
3682
+ OrganizationInfoType: __type(name:"Organization") {
3683
+ fields {
3684
+ name
3685
+ }
3686
+ }
3687
+ }
3688
+ """
3689
+
3690
+ if self.server_organization_type_fields_info is None:
3691
+ query = gql(query_string)
3692
+ res = self.gql(query)
3693
+ input_fields = res.get("OrganizationInfoType", {}).get("fields", [{}])
3694
+ self.server_organization_type_fields_info = [
3695
+ field["name"] for field in input_fields if "name" in field
3696
+ ]
3697
+
3698
+ return self.server_organization_type_fields_info
3699
+
3583
3700
  def create_artifact_type(
3584
3701
  self,
3585
3702
  artifact_type_name: str,
@@ -1455,16 +1455,29 @@ class SendManager:
1455
1455
  entity = link.portfolio_entity
1456
1456
  project = link.portfolio_project
1457
1457
  aliases = link.portfolio_aliases
1458
+ organization = link.portfolio_organization
1458
1459
  logger.debug(
1459
- f"link_artifact params - client_id={client_id}, server_id={server_id}, pfolio={portfolio_name}, entity={entity}, project={project}"
1460
+ f"link_artifact params - client_id={client_id}, server_id={server_id}, "
1461
+ f"portfolio_name={portfolio_name}, entity={entity}, project={project}, "
1462
+ f"organization={organization}"
1460
1463
  )
1461
1464
  if (client_id or server_id) and portfolio_name and entity and project:
1462
1465
  try:
1463
1466
  self._api.link_artifact(
1464
- client_id, server_id, portfolio_name, entity, project, aliases
1467
+ client_id,
1468
+ server_id,
1469
+ portfolio_name,
1470
+ entity,
1471
+ project,
1472
+ aliases,
1473
+ organization,
1465
1474
  )
1466
1475
  except Exception as e:
1467
- result.response.log_artifact_response.error_message = f'error linking artifact to "{entity}/{project}/{portfolio_name}"; error: {e}'
1476
+ org_or_entity = organization or entity
1477
+ result.response.log_artifact_response.error_message = (
1478
+ f"error linking artifact to "
1479
+ f'"{org_or_entity}/{project}/{portfolio_name}"; error: {e}'
1480
+ )
1468
1481
  logger.warning("Failed to link artifact to portfolio: %s", e)
1469
1482
  self._respond_result(result)
1470
1483
 
@@ -143,7 +143,7 @@ def _replace_refs_and_allofs(schema: dict, defs: Optional[dict]) -> dict:
143
143
  # Reference found, replace it with its definition
144
144
  def_key = schema.pop("$ref").split("#/$defs/")[1]
145
145
  # Also run recursive replacement in case a ref contains more refs
146
- ret = _replace_refs_and_allofs(defs.pop(def_key), defs)
146
+ ret = _replace_refs_and_allofs(defs[def_key], defs)
147
147
  for key, val in schema.items():
148
148
  if isinstance(val, dict):
149
149
  # Step into dicts recursively
wandb/sdk/lib/module.py CHANGED
@@ -19,6 +19,8 @@ def set_global(
19
19
  log_model=None,
20
20
  use_model=None,
21
21
  link_model=None,
22
+ watch=None,
23
+ unwatch=None,
22
24
  ):
23
25
  if run:
24
26
  wandb.run = run
@@ -48,6 +50,10 @@ def set_global(
48
50
  wandb.use_model = use_model
49
51
  if link_model:
50
52
  wandb.link_model = link_model
53
+ if watch:
54
+ wandb.watch = watch
55
+ if unwatch:
56
+ wandb.unwatch = unwatch
51
57
 
52
58
 
53
59
  def unset_globals():
@@ -55,6 +61,12 @@ def unset_globals():
55
61
  wandb.config = preinit.PreInitObject("wandb.config")
56
62
  wandb.summary = preinit.PreInitObject("wandb.summary")
57
63
  wandb.log = preinit.PreInitCallable("wandb.log", wandb.wandb_sdk.wandb_run.Run.log)
64
+ wandb.watch = preinit.PreInitCallable(
65
+ "wandb.watch", wandb.wandb_sdk.wandb_run.Run.watch
66
+ )
67
+ wandb.unwatch = preinit.PreInitCallable(
68
+ "wandb.unwatch", wandb.wandb_sdk.wandb_run.Run.unwatch
69
+ )
58
70
  wandb.save = preinit.PreInitCallable(
59
71
  "wandb.save", wandb.wandb_sdk.wandb_run.Run.save
60
72
  )