wandb 0.18.3__py3-none-any.whl → 0.18.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +16 -7
- wandb/__init__.pyi +96 -63
- wandb/analytics/sentry.py +91 -88
- wandb/apis/public/api.py +18 -4
- wandb/apis/public/runs.py +53 -2
- wandb/bin/gpu_stats +0 -0
- wandb/cli/beta.py +178 -0
- wandb/cli/cli.py +5 -171
- wandb/data_types.py +3 -0
- wandb/env.py +74 -73
- wandb/errors/term.py +300 -43
- wandb/proto/v3/wandb_internal_pb2.py +263 -223
- wandb/proto/v3/wandb_server_pb2.py +57 -37
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_internal_pb2.py +226 -218
- wandb/proto/v4/wandb_server_pb2.py +41 -37
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_internal_pb2.py +226 -218
- wandb/proto/v5/wandb_server_pb2.py +41 -37
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/sdk/__init__.py +3 -3
- wandb/sdk/artifacts/_validators.py +41 -8
- wandb/sdk/artifacts/artifact.py +32 -1
- wandb/sdk/artifacts/artifact_file_cache.py +1 -2
- wandb/sdk/data_types/_dtypes.py +7 -3
- wandb/sdk/data_types/video.py +15 -6
- wandb/sdk/interface/interface.py +2 -0
- wandb/sdk/internal/internal_api.py +126 -5
- wandb/sdk/internal/sender.py +16 -3
- wandb/sdk/launch/inputs/internal.py +1 -1
- wandb/sdk/lib/module.py +12 -0
- wandb/sdk/lib/printer.py +291 -105
- wandb/sdk/lib/progress.py +274 -0
- wandb/sdk/service/streams.py +21 -11
- wandb/sdk/wandb_init.py +58 -54
- wandb/sdk/wandb_run.py +380 -454
- wandb/sdk/wandb_settings.py +2 -0
- wandb/sdk/wandb_watch.py +17 -11
- wandb/util.py +6 -2
- {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/METADATA +4 -3
- {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/RECORD +44 -42
- wandb/bin/nvidia_gpu_stats +0 -0
- {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/WHEEL +0 -0
- {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.3.dist-info → wandb-0.18.5.dist-info}/licenses/LICENSE +0 -0
@@ -29,6 +29,11 @@ from typing import (
|
|
29
29
|
Union,
|
30
30
|
)
|
31
31
|
|
32
|
+
if sys.version_info >= (3, 8):
|
33
|
+
from typing import Literal
|
34
|
+
else:
|
35
|
+
from typing_extensions import Literal
|
36
|
+
|
32
37
|
import click
|
33
38
|
import requests
|
34
39
|
import yaml
|
@@ -41,6 +46,7 @@ from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messa
|
|
41
46
|
from wandb.errors import AuthenticationError, CommError, UnsupportedError, UsageError
|
42
47
|
from wandb.integration.sagemaker import parse_sm_secrets
|
43
48
|
from wandb.old.settings import Settings
|
49
|
+
from wandb.sdk.artifacts._validators import is_artifact_registry_project
|
44
50
|
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
|
45
51
|
from wandb.sdk.lib.gql_request import GraphQLSession
|
46
52
|
from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
|
@@ -306,6 +312,7 @@ class Api:
|
|
306
312
|
self.server_use_artifact_input_info: Optional[List[str]] = None
|
307
313
|
self.server_create_artifact_input_info: Optional[List[str]] = None
|
308
314
|
self.server_artifact_fields_info: Optional[List[str]] = None
|
315
|
+
self.server_organization_type_fields_info: Optional[List[str]] = None
|
309
316
|
self._max_cli_version: Optional[str] = None
|
310
317
|
self._server_settings_type: Optional[List[str]] = None
|
311
318
|
self.fail_run_queue_item_input_info: Optional[List[str]] = None
|
@@ -3366,7 +3373,7 @@ class Api:
|
|
3366
3373
|
project: Optional[str] = None,
|
3367
3374
|
description: Optional[str] = None,
|
3368
3375
|
force: bool = True,
|
3369
|
-
progress: Union[TextIO,
|
3376
|
+
progress: Union[TextIO, Literal[False]] = False,
|
3370
3377
|
) -> "List[Optional[requests.Response]]":
|
3371
3378
|
"""Uploads multiple files to W&B.
|
3372
3379
|
|
@@ -3378,7 +3385,7 @@ class Api:
|
|
3378
3385
|
description (str, optional): The description of the changes
|
3379
3386
|
force (bool, optional): Whether to prevent push if git has uncommitted changes
|
3380
3387
|
progress (callable, or stream): If callable, will be called with (chunk_bytes,
|
3381
|
-
total_bytes) as argument
|
3388
|
+
total_bytes) as argument. If TextIO, renders a progress bar to it.
|
3382
3389
|
|
3383
3390
|
Returns:
|
3384
3391
|
A list of `requests.Response` objects
|
@@ -3439,8 +3446,8 @@ class Api:
|
|
3439
3446
|
)
|
3440
3447
|
else:
|
3441
3448
|
length = os.fstat(open_file.fileno()).st_size
|
3442
|
-
with click.progressbar(
|
3443
|
-
file=progress,
|
3449
|
+
with click.progressbar( # type: ignore
|
3450
|
+
file=progress,
|
3444
3451
|
length=length,
|
3445
3452
|
label=f"Uploading file: {file_name}",
|
3446
3453
|
fill_char=click.style("&", fg="green"),
|
@@ -3464,6 +3471,7 @@ class Api:
|
|
3464
3471
|
entity: str,
|
3465
3472
|
project: str,
|
3466
3473
|
aliases: Sequence[str],
|
3474
|
+
organization: str,
|
3467
3475
|
) -> Dict[str, Any]:
|
3468
3476
|
template = """
|
3469
3477
|
mutation LinkArtifact(
|
@@ -3485,6 +3493,14 @@ class Api:
|
|
3485
3493
|
}
|
3486
3494
|
"""
|
3487
3495
|
|
3496
|
+
org_entity = ""
|
3497
|
+
if is_artifact_registry_project(project):
|
3498
|
+
try:
|
3499
|
+
org_entity = self._resolve_org_entity_name(entity, organization)
|
3500
|
+
except ValueError as e:
|
3501
|
+
wandb.termerror(str(e))
|
3502
|
+
raise
|
3503
|
+
|
3488
3504
|
def replace(a: str, b: str) -> None:
|
3489
3505
|
nonlocal template
|
3490
3506
|
template = template.replace(a, b)
|
@@ -3500,7 +3516,7 @@ class Api:
|
|
3500
3516
|
"clientID": client_id,
|
3501
3517
|
"artifactID": server_id,
|
3502
3518
|
"artifactPortfolioName": portfolio_name,
|
3503
|
-
"entityName": entity,
|
3519
|
+
"entityName": org_entity or entity,
|
3504
3520
|
"projectName": project,
|
3505
3521
|
"aliases": [
|
3506
3522
|
{"alias": alias, "artifactCollectionName": portfolio_name}
|
@@ -3513,6 +3529,89 @@ class Api:
|
|
3513
3529
|
link_artifact: Dict[str, Any] = response["linkArtifact"]
|
3514
3530
|
return link_artifact
|
3515
3531
|
|
3532
|
+
def _resolve_org_entity_name(self, entity: str, organization: str = "") -> str:
    """Resolve the name of the org entity that owns a Registry portfolio.

    Args:
        entity: Entity whose organization is looked up on the server.
        organization: Optional value supplied by the caller. It may be
            empty, an organization's display name, or an org entity name.

    Returns:
        The org entity name fetched for ``entity`` when the server exposes
        ``orgEntity`` on its ``Organization`` type; otherwise
        ``organization`` is returned unchanged.

    Raises:
        ValueError: If the server does not expose ``orgEntity`` and no
            ``organization`` was given, or if a non-empty ``organization``
            matches neither the fetched display name nor the fetched org
            entity name. Errors raised by ``fetch_org_entity_from_entity``
            propagate unchanged.
    """
    org_fields = self.server_organization_type_introspection()
    if "orgEntity" not in org_fields:
        if not organization:
            raise ValueError(
                "Fetching Registry artifacts without inputting an organization "
                "is unavailable for your server version. "
                "Please upgrade your server to 0.50.0 or later."
            )
        # The server cannot be asked to validate the value, so assume the
        # caller passed a correct org entity name.
        return organization

    org_entity, org_name = self.fetch_org_entity_from_entity(entity)
    if organization:
        # Either the display name or the org entity name is an acceptable match.
        if organization not in (org_name, org_entity):
            raise ValueError(
                f"Artifact belongs to the organization {org_name!r} "
                f"and cannot be linked/fetched with {organization!r}. "
                "Please update the target path with the correct organization name."
            )
        wandb.termwarn(
            "Registries can be linked/fetched using a shorthand form without specifying the organization name. "
            "Try using shorthand path format: <my_registry_name>/<artifact_name>"
        )
    return org_entity
|
3567
|
+
|
3568
|
+
def fetch_org_entity_from_entity(self, entity: str) -> Tuple[str, str]:
    """Fetch the organization that ``entity`` belongs to.

    Args:
        entity: Name of the entity to query.

    Returns:
        A ``(org_entity_name, org_name)`` tuple. A null name in the
        response is returned as an empty string.

    Raises:
        ValueError: If no organization (or no org entity) can be read from
            the response for ``entity``.
    """
    query = gql(
        """
        query FetchOrgEntityFromEntity(
            $entityName: String!,
        ) {
            entity(name: $entityName) {
                isTeam
                organization {
                    name
                    orgEntity {
                        name
                    }
                }
            }
        }
        """
    )
    response = self.gql(
        query,
        variable_values={
            "entityName": entity,
        },
    )
    # Read `isTeam` before the try block so the handler below can always use
    # it. A missing or null `entity` is normalised to an empty mapping;
    # previously that case raised AttributeError/UnboundLocalError instead
    # of the ValueError below. (Presumably the server returns a null
    # `entity` for unknown names -- confirm against the backend.)
    entity_info = (response or {}).get("entity") or {}
    is_team = entity_info.get("isTeam", False)
    try:
        org = entity_info["organization"]
        org_name = org["name"] or ""
        org_entity_name = org["orgEntity"]["name"] or ""
    except (LookupError, TypeError) as e:
        if is_team:
            # This path should pretty much never be reached as all team entities have an organization.
            raise ValueError(
                f"Unable to find an organization under entity {entity!r}. "
            ) from e
        raise ValueError(
            f"Unable to resolve an organization associated with the entity: {entity!r} "
            "that is initialized in the API or Run settings. This could be because "
            f"{entity!r} is a personal entity or the team entity doesn't exist. "
            "Please re-initialize the API or Run with a team entity using "
            "wandb.Api(overrides={'entity': '<my_team_entity>'}) "
            "or wandb.init(entity='<my_team_entity>') "
        ) from e
    return org_entity_name, org_name
|
3614
|
+
|
3516
3615
|
def use_artifact(
|
3517
3616
|
self,
|
3518
3617
|
artifact_id: str,
|
@@ -3580,6 +3679,28 @@ class Api:
|
|
3580
3679
|
return artifact
|
3581
3680
|
return None
|
3582
3681
|
|
3682
|
+
def server_organization_type_introspection(self) -> List[str]:
    """Return the field names of the backend's ``Organization`` type.

    The introspection query is sent only while
    ``self.server_organization_type_fields_info`` is ``None``; the result
    is stored on that attribute and returned directly on later calls.

    Returns:
        The list of field names. It is empty when the response carries no
        ``Organization`` type or no fields for it.
    """
    if self.server_organization_type_fields_info is None:
        query_string = """
            query ProbeServerOrganization {
                OrganizationInfoType: __type(name:"Organization") {
                    fields {
                        name
                    }
                }
            }
        """
        res = self.gql(gql(query_string))
        # `__type` resolves to null for an unknown type, so the key can be
        # present with a None value; `dict.get(key, default)` would hand
        # that None back, hence the explicit `or` fallbacks.
        org_type = res.get("OrganizationInfoType") or {}
        input_fields = org_type.get("fields") or []
        self.server_organization_type_fields_info = [
            field["name"] for field in input_fields if "name" in field
        ]

    return self.server_organization_type_fields_info
|
3703
|
+
|
3583
3704
|
def create_artifact_type(
|
3584
3705
|
self,
|
3585
3706
|
artifact_type_name: str,
|
wandb/sdk/internal/sender.py
CHANGED
@@ -1455,16 +1455,29 @@ class SendManager:
|
|
1455
1455
|
entity = link.portfolio_entity
|
1456
1456
|
project = link.portfolio_project
|
1457
1457
|
aliases = link.portfolio_aliases
|
1458
|
+
organization = link.portfolio_organization
|
1458
1459
|
logger.debug(
|
1459
|
-
f"link_artifact params - client_id={client_id}, server_id={server_id},
|
1460
|
+
f"link_artifact params - client_id={client_id}, server_id={server_id}, "
|
1461
|
+
f"portfolio_name={portfolio_name}, entity={entity}, project={project}, "
|
1462
|
+
f"organization={organization}"
|
1460
1463
|
)
|
1461
1464
|
if (client_id or server_id) and portfolio_name and entity and project:
|
1462
1465
|
try:
|
1463
1466
|
self._api.link_artifact(
|
1464
|
-
client_id,
|
1467
|
+
client_id,
|
1468
|
+
server_id,
|
1469
|
+
portfolio_name,
|
1470
|
+
entity,
|
1471
|
+
project,
|
1472
|
+
aliases,
|
1473
|
+
organization,
|
1465
1474
|
)
|
1466
1475
|
except Exception as e:
|
1467
|
-
|
1476
|
+
org_or_entity = organization or entity
|
1477
|
+
result.response.log_artifact_response.error_message = (
|
1478
|
+
f"error linking artifact to "
|
1479
|
+
f'"{org_or_entity}/{project}/{portfolio_name}"; error: {e}'
|
1480
|
+
)
|
1468
1481
|
logger.warning("Failed to link artifact to portfolio: %s", e)
|
1469
1482
|
self._respond_result(result)
|
1470
1483
|
|
@@ -143,7 +143,7 @@ def _replace_refs_and_allofs(schema: dict, defs: Optional[dict]) -> dict:
|
|
143
143
|
# Reference found, replace it with its definition
|
144
144
|
def_key = schema.pop("$ref").split("#/$defs/")[1]
|
145
145
|
# Also run recursive replacement in case a ref contains more refs
|
146
|
-
ret = _replace_refs_and_allofs(defs
|
146
|
+
ret = _replace_refs_and_allofs(defs[def_key], defs)
|
147
147
|
for key, val in schema.items():
|
148
148
|
if isinstance(val, dict):
|
149
149
|
# Step into dicts recursively
|
wandb/sdk/lib/module.py
CHANGED
@@ -19,6 +19,8 @@ def set_global(
|
|
19
19
|
log_model=None,
|
20
20
|
use_model=None,
|
21
21
|
link_model=None,
|
22
|
+
watch=None,
|
23
|
+
unwatch=None,
|
22
24
|
):
|
23
25
|
if run:
|
24
26
|
wandb.run = run
|
@@ -48,6 +50,10 @@ def set_global(
|
|
48
50
|
wandb.use_model = use_model
|
49
51
|
if link_model:
|
50
52
|
wandb.link_model = link_model
|
53
|
+
if watch:
|
54
|
+
wandb.watch = watch
|
55
|
+
if unwatch:
|
56
|
+
wandb.unwatch = unwatch
|
51
57
|
|
52
58
|
|
53
59
|
def unset_globals():
|
@@ -55,6 +61,12 @@ def unset_globals():
|
|
55
61
|
wandb.config = preinit.PreInitObject("wandb.config")
|
56
62
|
wandb.summary = preinit.PreInitObject("wandb.summary")
|
57
63
|
wandb.log = preinit.PreInitCallable("wandb.log", wandb.wandb_sdk.wandb_run.Run.log)
|
64
|
+
wandb.watch = preinit.PreInitCallable(
|
65
|
+
"wandb.watch", wandb.wandb_sdk.wandb_run.Run.watch
|
66
|
+
)
|
67
|
+
wandb.unwatch = preinit.PreInitCallable(
|
68
|
+
"wandb.unwatch", wandb.wandb_sdk.wandb_run.Run.unwatch
|
69
|
+
)
|
58
70
|
wandb.save = preinit.PreInitCallable(
|
59
71
|
"wandb.save", wandb.wandb_sdk.wandb_run.Run.save
|
60
72
|
)
|