wandb 0.18.4__py3-none-any.whl → 0.18.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +2 -2
- wandb/__init__.pyi +21 -19
- wandb/agents/pyagent.py +1 -1
- wandb/apis/importers/wandb.py +1 -1
- wandb/apis/normalize.py +2 -18
- wandb/apis/public/api.py +122 -62
- wandb/apis/public/artifacts.py +8 -3
- wandb/apis/public/files.py +17 -2
- wandb/apis/public/jobs.py +2 -2
- wandb/apis/public/query_generator.py +1 -1
- wandb/apis/public/runs.py +8 -8
- wandb/apis/public/teams.py +3 -3
- wandb/apis/public/users.py +1 -1
- wandb/apis/public/utils.py +68 -0
- wandb/bin/gpu_stats +0 -0
- wandb/cli/cli.py +12 -3
- wandb/data_types.py +1 -1
- wandb/docker/__init__.py +2 -1
- wandb/docker/auth.py +2 -3
- wandb/errors/links.py +73 -0
- wandb/errors/term.py +7 -6
- wandb/filesync/step_prepare.py +1 -1
- wandb/filesync/upload_job.py +1 -1
- wandb/integration/catboost/catboost.py +2 -2
- wandb/integration/diffusers/pipeline_resolver.py +1 -1
- wandb/integration/diffusers/resolvers/multimodal.py +6 -6
- wandb/integration/diffusers/resolvers/utils.py +1 -1
- wandb/integration/fastai/__init__.py +3 -2
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
- wandb/integration/keras/keras.py +1 -1
- wandb/integration/kfp/kfp_patch.py +1 -1
- wandb/integration/lightgbm/__init__.py +2 -2
- wandb/integration/magic.py +2 -2
- wandb/integration/metaflow/metaflow.py +1 -1
- wandb/integration/sacred/__init__.py +1 -1
- wandb/integration/sagemaker/auth.py +1 -1
- wandb/integration/sklearn/plot/classifier.py +7 -7
- wandb/integration/sklearn/plot/clusterer.py +3 -3
- wandb/integration/sklearn/plot/regressor.py +3 -3
- wandb/integration/sklearn/plot/shared.py +2 -2
- wandb/integration/tensorboard/log.py +2 -2
- wandb/integration/ultralytics/callback.py +2 -2
- wandb/integration/xgboost/xgboost.py +1 -1
- wandb/jupyter.py +0 -1
- wandb/plot/__init__.py +17 -8
- wandb/plot/bar.py +53 -27
- wandb/plot/confusion_matrix.py +151 -70
- wandb/plot/custom_chart.py +124 -0
- wandb/plot/histogram.py +46 -20
- wandb/plot/line.py +57 -26
- wandb/plot/line_series.py +148 -60
- wandb/plot/pr_curve.py +89 -44
- wandb/plot/roc_curve.py +82 -37
- wandb/plot/scatter.py +53 -20
- wandb/plot/viz.py +20 -102
- wandb/sdk/artifacts/artifact.py +280 -328
- wandb/sdk/artifacts/artifact_manifest.py +10 -9
- wandb/sdk/artifacts/artifact_manifest_entry.py +1 -1
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +9 -4
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -3
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
- wandb/sdk/backend/backend.py +0 -1
- wandb/sdk/data_types/audio.py +1 -1
- wandb/sdk/data_types/base_types/media.py +66 -5
- wandb/sdk/data_types/bokeh.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +1 -1
- wandb/sdk/data_types/helper_types/image_mask.py +2 -2
- wandb/sdk/data_types/histogram.py +1 -1
- wandb/sdk/data_types/html.py +1 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/molecule.py +3 -3
- wandb/sdk/data_types/object_3d.py +4 -4
- wandb/sdk/data_types/plotly.py +1 -1
- wandb/sdk/data_types/saved_model.py +0 -1
- wandb/sdk/data_types/table.py +7 -7
- wandb/sdk/data_types/trace_tree.py +1 -1
- wandb/sdk/data_types/video.py +4 -3
- wandb/sdk/interface/router.py +0 -2
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/file_pusher.py +1 -1
- wandb/sdk/internal/file_stream.py +4 -4
- wandb/sdk/internal/handler.py +3 -2
- wandb/sdk/internal/internal.py +1 -1
- wandb/sdk/internal/internal_api.py +183 -64
- wandb/sdk/internal/job_builder.py +4 -3
- wandb/sdk/internal/system/assets/__init__.py +0 -2
- wandb/sdk/internal/tb_watcher.py +11 -10
- wandb/sdk/launch/_launch.py +4 -3
- wandb/sdk/launch/_launch_add.py +2 -2
- wandb/sdk/launch/builder/kaniko_builder.py +0 -1
- wandb/sdk/launch/create_job.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +0 -1
- wandb/sdk/launch/errors.py +0 -6
- wandb/sdk/launch/registry/local_registry.py +0 -2
- wandb/sdk/launch/runner/abstract.py +0 -5
- wandb/sdk/launch/sweeps/__init__.py +0 -2
- wandb/sdk/launch/sweeps/scheduler.py +0 -2
- wandb/sdk/launch/sweeps/scheduler_sweep.py +0 -1
- wandb/sdk/lib/apikey.py +3 -3
- wandb/sdk/lib/file_stream_utils.py +1 -1
- wandb/sdk/lib/filesystem.py +1 -1
- wandb/sdk/lib/ipython.py +16 -9
- wandb/sdk/lib/mailbox.py +0 -4
- wandb/sdk/lib/printer.py +44 -8
- wandb/sdk/lib/retry.py +1 -1
- wandb/sdk/service/service.py +3 -3
- wandb/sdk/service/streams.py +2 -4
- wandb/sdk/wandb_init.py +20 -20
- wandb/sdk/wandb_login.py +1 -1
- wandb/sdk/wandb_require.py +1 -4
- wandb/sdk/wandb_run.py +57 -69
- wandb/sdk/wandb_settings.py +3 -4
- wandb/sdk/wandb_sync.py +2 -1
- wandb/util.py +46 -18
- wandb/wandb_agent.py +3 -3
- wandb/wandb_controller.py +2 -2
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/METADATA +1 -1
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/RECORD +124 -125
- wandb/sdk/internal/system/assets/gpu_apple.py +0 -177
- wandb/sdk/lib/_wburls_generate.py +0 -25
- wandb/sdk/lib/_wburls_generated.py +0 -22
- wandb/sdk/lib/wburls.py +0 -46
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/WHEEL +0 -0
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/licenses/LICENSE +0 -0
@@ -20,9 +20,9 @@ from typing import (
|
|
20
20
|
Dict,
|
21
21
|
Iterable,
|
22
22
|
List,
|
23
|
-
Literal,
|
24
23
|
Mapping,
|
25
24
|
MutableMapping,
|
25
|
+
NamedTuple,
|
26
26
|
Optional,
|
27
27
|
Sequence,
|
28
28
|
TextIO,
|
@@ -30,6 +30,11 @@ from typing import (
|
|
30
30
|
Union,
|
31
31
|
)
|
32
32
|
|
33
|
+
if sys.version_info >= (3, 8):
|
34
|
+
from typing import Literal
|
35
|
+
else:
|
36
|
+
from typing_extensions import Literal
|
37
|
+
|
33
38
|
import click
|
34
39
|
import requests
|
35
40
|
import yaml
|
@@ -165,6 +170,50 @@ class _ThreadLocalData(threading.local):
|
|
165
170
|
self.context = None
|
166
171
|
|
167
172
|
|
173
|
+
class _OrgNames(NamedTuple):
    """Entity name and display name of a single organization."""

    # Name of the organization's org entity; this is what a successful match returns.
    entity_name: str
    # Display name of the organization; quoted in the single-org error message.
    display_name: str


def _match_org_with_fetched_org_entities(
    organization: str, orgs: Sequence[_OrgNames]
) -> str:
    """Match the organization provided in the path with the org entity or org name of the input entity.

    Args:
        organization: The organization name to match.
        orgs: Sequence of `_OrgNames(entity_name, display_name)` tuples to search.

    Returns:
        str: The entity name of the first entry whose entity name or display
            name equals `organization`.

    Raises:
        ValueError: If `orgs` is empty, or if `organization` equals neither
            name of any entry in `orgs`.
    """
    for org_names in orgs:
        # `_OrgNames` is a tuple, so `in` is an exact comparison against both
        # the entity name and the display name.
        if organization in org_names:
            wandb.termwarn(
                "Registries can be linked/fetched using a shorthand form without specifying the organization name. "
                "Try using shorthand path format: <my_registry_name>/<artifact_name> or "
                "just <my_registry_name> if fetching just the project."
            )
            return org_names.entity_name

    # Without this guard an empty `orgs` would fall through to the
    # "belongs to multiple organizations" error below, which is misleading.
    if not orgs:
        raise ValueError(
            f"No organizations were found to match against {organization!r}. "
            "Please update the target path with the correct organization name."
        )

    if len(orgs) == 1:
        raise ValueError(
            f"Expecting the organization name or entity name to match {orgs[0].display_name!r} "
            f"and cannot be linked/fetched with {organization!r}. "
            "Please update the target path with the correct organization name."
        )

    raise ValueError(
        "Personal entity belongs to multiple organizations "
        f"and cannot be linked/fetched with {organization!r}. "
        "Please update the target path with the correct organization name "
        "or use a team entity in the entity settings."
    )
|
215
|
+
|
216
|
+
|
168
217
|
class Api:
|
169
218
|
"""W&B Internal Api wrapper.
|
170
219
|
|
@@ -174,7 +223,7 @@ class Api:
|
|
174
223
|
directory. If none can be found, we look in the current user's home
|
175
224
|
directory.
|
176
225
|
|
177
|
-
|
226
|
+
Args:
|
178
227
|
default_settings(dict, optional): If you aren't using a settings
|
179
228
|
file, or you wish to override the section to use in the settings file
|
180
229
|
Override the settings here.
|
@@ -309,6 +358,7 @@ class Api:
|
|
309
358
|
self.server_create_artifact_input_info: Optional[List[str]] = None
|
310
359
|
self.server_artifact_fields_info: Optional[List[str]] = None
|
311
360
|
self.server_organization_type_fields_info: Optional[List[str]] = None
|
361
|
+
self.server_supports_enabling_artifact_usage_tracking: Optional[bool] = None
|
312
362
|
self._max_cli_version: Optional[str] = None
|
313
363
|
self._server_settings_type: Optional[List[str]] = None
|
314
364
|
self.fail_run_queue_item_input_info: Optional[List[str]] = None
|
@@ -430,7 +480,7 @@ class Api:
|
|
430
480
|
def settings(self, key: Optional[str] = None, section: Optional[str] = None) -> Any:
|
431
481
|
"""The settings overridden from the wandb/settings file.
|
432
482
|
|
433
|
-
|
483
|
+
Args:
|
434
484
|
key (str, optional): If provided only this setting is returned
|
435
485
|
section (str, optional): If provided this section of the setting file is
|
436
486
|
used, defaults to "default"
|
@@ -510,7 +560,7 @@ class Api:
|
|
510
560
|
) -> Tuple[str, str]:
|
511
561
|
"""Parse a slug into a project and run.
|
512
562
|
|
513
|
-
|
563
|
+
Args:
|
514
564
|
slug (str): The slug to parse
|
515
565
|
project (str, optional): The project to use, if not provided it will be
|
516
566
|
inferred from the slug
|
@@ -949,7 +999,7 @@ class Api:
|
|
949
999
|
def list_projects(self, entity: Optional[str] = None) -> List[Dict[str, str]]:
|
950
1000
|
"""List projects in W&B scoped by entity.
|
951
1001
|
|
952
|
-
|
1002
|
+
Args:
|
953
1003
|
entity (str, optional): The entity to scope this project to.
|
954
1004
|
|
955
1005
|
Returns:
|
@@ -981,7 +1031,7 @@ class Api:
|
|
981
1031
|
def project(self, project: str, entity: Optional[str] = None) -> "_Response":
|
982
1032
|
"""Retrieve project.
|
983
1033
|
|
984
|
-
|
1034
|
+
Args:
|
985
1035
|
project (str): The project to get details for
|
986
1036
|
entity (str, optional): The entity to scope this project to.
|
987
1037
|
|
@@ -1016,7 +1066,7 @@ class Api:
|
|
1016
1066
|
) -> Dict[str, Any]:
|
1017
1067
|
"""Retrieve sweep.
|
1018
1068
|
|
1019
|
-
|
1069
|
+
Args:
|
1020
1070
|
sweep (str): The sweep to get details for
|
1021
1071
|
specs (str): history specs
|
1022
1072
|
project (str, optional): The project to scope this sweep to.
|
@@ -1089,7 +1139,7 @@ class Api:
|
|
1089
1139
|
) -> List[Dict[str, str]]:
|
1090
1140
|
"""List runs in W&B scoped by project.
|
1091
1141
|
|
1092
|
-
|
1142
|
+
Args:
|
1093
1143
|
project (str): The project to scope the runs to
|
1094
1144
|
entity (str, optional): The entity to scope this project to. Defaults to public models
|
1095
1145
|
|
@@ -1130,7 +1180,7 @@ class Api:
|
|
1130
1180
|
) -> Tuple[str, Dict[str, Any], Optional[str], Dict[str, Any]]:
|
1131
1181
|
"""Get the relevant configs for a run.
|
1132
1182
|
|
1133
|
-
|
1183
|
+
Args:
|
1134
1184
|
project (str): The project to download, (can include bucket)
|
1135
1185
|
run (str, optional): The run to download
|
1136
1186
|
entity (str, optional): The entity to scope this project to.
|
@@ -1219,7 +1269,7 @@ class Api:
|
|
1219
1269
|
) -> Optional[Dict[str, Any]]:
|
1220
1270
|
"""Check if a run exists and get resume information.
|
1221
1271
|
|
1222
|
-
|
1272
|
+
Args:
|
1223
1273
|
entity (str): The entity to scope this project to.
|
1224
1274
|
project_name (str): The project to download, (can include bucket)
|
1225
1275
|
name (str): The run to download
|
@@ -1324,7 +1374,7 @@ class Api:
|
|
1324
1374
|
) -> Dict[str, Any]:
|
1325
1375
|
"""Create a new project.
|
1326
1376
|
|
1327
|
-
|
1377
|
+
Args:
|
1328
1378
|
project (str): The project to create
|
1329
1379
|
description (str, optional): A description of this project
|
1330
1380
|
entity (str, optional): The entity to scope this project to.
|
@@ -2152,7 +2202,7 @@ class Api:
|
|
2152
2202
|
) -> Tuple[dict, bool, Optional[List]]:
|
2153
2203
|
"""Update a run.
|
2154
2204
|
|
2155
|
-
|
2205
|
+
Args:
|
2156
2206
|
id (str, optional): The existing run to update
|
2157
2207
|
name (str, optional): The name of the run to create
|
2158
2208
|
group (str, optional): Name of the group this run is a part of
|
@@ -2339,7 +2389,7 @@ class Api:
|
|
2339
2389
|
) -> dict:
|
2340
2390
|
"""Rewinds a run to a previous state.
|
2341
2391
|
|
2342
|
-
|
2392
|
+
Args:
|
2343
2393
|
run_name (str): The name of the run to rewind
|
2344
2394
|
metric_name (str): The name of the metric to rewind to
|
2345
2395
|
metric_value (float): The value of the metric to rewind to
|
@@ -2526,7 +2576,7 @@ class Api:
|
|
2526
2576
|
) -> Tuple[str, List[str], Dict[str, Dict[str, Any]]]:
|
2527
2577
|
"""Generate temporary resumable upload urls.
|
2528
2578
|
|
2529
|
-
|
2579
|
+
Args:
|
2530
2580
|
project (str): The project to download
|
2531
2581
|
files (list or dict): The filenames to upload
|
2532
2582
|
run (str, optional): The run to upload to
|
@@ -2663,7 +2713,7 @@ class Api:
|
|
2663
2713
|
) -> Dict[str, Dict[str, str]]:
|
2664
2714
|
"""Generate download urls.
|
2665
2715
|
|
2666
|
-
|
2716
|
+
Args:
|
2667
2717
|
project (str): The project to download
|
2668
2718
|
run (str): The run to upload to
|
2669
2719
|
entity (str, optional): The entity to scope this project to. Defaults to wandb models
|
@@ -2722,7 +2772,7 @@ class Api:
|
|
2722
2772
|
) -> Optional[Dict[str, str]]:
|
2723
2773
|
"""Generate download urls.
|
2724
2774
|
|
2725
|
-
|
2775
|
+
Args:
|
2726
2776
|
project (str): The project to download
|
2727
2777
|
file_name (str): The name of the file to download
|
2728
2778
|
run (str): The run to upload to
|
@@ -2775,7 +2825,7 @@ class Api:
|
|
2775
2825
|
def download_file(self, url: str) -> Tuple[int, requests.Response]:
|
2776
2826
|
"""Initiate a streaming download.
|
2777
2827
|
|
2778
|
-
|
2828
|
+
Args:
|
2779
2829
|
url (str): The url to download
|
2780
2830
|
|
2781
2831
|
Returns:
|
@@ -2809,7 +2859,7 @@ class Api:
|
|
2809
2859
|
) -> Tuple[str, Optional[requests.Response]]:
|
2810
2860
|
"""Download a file from a run and write it to wandb/.
|
2811
2861
|
|
2812
|
-
|
2862
|
+
Args:
|
2813
2863
|
metadata (obj): The metadata object for the file to download. Comes from Api.download_urls().
|
2814
2864
|
out_dir (str, optional): The directory to write the file to. Defaults to wandb/
|
2815
2865
|
|
@@ -2873,7 +2923,7 @@ class Api:
|
|
2873
2923
|
) -> Optional[requests.Response]:
|
2874
2924
|
"""Upload a file chunk to S3 with failure resumption.
|
2875
2925
|
|
2876
|
-
|
2926
|
+
Args:
|
2877
2927
|
url: The url to download
|
2878
2928
|
upload_chunk: The path to the file you want to upload
|
2879
2929
|
extra_headers: A dictionary of extra headers to send with the request
|
@@ -2926,7 +2976,7 @@ class Api:
|
|
2926
2976
|
) -> Optional[requests.Response]:
|
2927
2977
|
"""Upload a file to W&B with failure resumption.
|
2928
2978
|
|
2929
|
-
|
2979
|
+
Args:
|
2930
2980
|
url: The url to download
|
2931
2981
|
file: The path to the file you want to upload
|
2932
2982
|
callback: A callback which is passed the number of
|
@@ -2998,7 +3048,7 @@ class Api:
|
|
2998
3048
|
) -> dict:
|
2999
3049
|
"""Register a new agent.
|
3000
3050
|
|
3001
|
-
|
3051
|
+
Args:
|
3002
3052
|
host (str): hostname
|
3003
3053
|
sweep_id (str): sweep id
|
3004
3054
|
project_name: (str): model that contains sweep
|
@@ -3048,7 +3098,7 @@ class Api:
|
|
3048
3098
|
) -> List[Dict[str, Any]]:
|
3049
3099
|
"""Notify server about agent state, receive commands.
|
3050
3100
|
|
3051
|
-
|
3101
|
+
Args:
|
3052
3102
|
agent_id (str): agent_id
|
3053
3103
|
metrics (dict): system metrics
|
3054
3104
|
run_states (dict): run_id: state mapping
|
@@ -3157,7 +3207,7 @@ class Api:
|
|
3157
3207
|
) -> Tuple[str, List[str]]:
|
3158
3208
|
"""Upsert a sweep object.
|
3159
3209
|
|
3160
|
-
|
3210
|
+
Args:
|
3161
3211
|
config (dict): sweep config (will be converted to yaml)
|
3162
3212
|
controller (str): controller to use
|
3163
3213
|
launch_scheduler (str): launch scheduler to use
|
@@ -3338,7 +3388,7 @@ class Api:
|
|
3338
3388
|
) -> "List[requests.Response]":
|
3339
3389
|
"""Download files from W&B.
|
3340
3390
|
|
3341
|
-
|
3391
|
+
Args:
|
3342
3392
|
project (str): The project to download
|
3343
3393
|
run (str, optional): The run to upload to
|
3344
3394
|
entity (str, optional): The entity to scope this project to. Defaults to wandb models
|
@@ -3373,7 +3423,7 @@ class Api:
|
|
3373
3423
|
) -> "List[Optional[requests.Response]]":
|
3374
3424
|
"""Uploads multiple files to W&B.
|
3375
3425
|
|
3376
|
-
|
3426
|
+
Args:
|
3377
3427
|
files (list or dict): The filenames to upload, when dict the values are open files
|
3378
3428
|
run (str, optional): The run to upload to
|
3379
3429
|
entity (str, optional): The entity to scope this project to. Defaults to wandb models
|
@@ -3492,7 +3542,9 @@ class Api:
|
|
3492
3542
|
org_entity = ""
|
3493
3543
|
if is_artifact_registry_project(project):
|
3494
3544
|
try:
|
3495
|
-
org_entity = self._resolve_org_entity_name(
|
3545
|
+
org_entity = self._resolve_org_entity_name(
|
3546
|
+
entity=entity, organization=organization
|
3547
|
+
)
|
3496
3548
|
except ValueError as e:
|
3497
3549
|
wandb.termerror(str(e))
|
3498
3550
|
raise
|
@@ -3534,47 +3586,67 @@ class Api:
|
|
3534
3586
|
# the organization parameter, or an error if it is empty. Otherwise, this returns the
|
3535
3587
|
# fetched value after validating that the given organization, if not empty, matches
|
3536
3588
|
# either the org's display or entity name.
|
3589
|
+
|
3590
|
+
if not entity:
|
3591
|
+
raise ValueError("Entity name is required to resolve org entity name.")
|
3592
|
+
|
3537
3593
|
org_fields = self.server_organization_type_introspection()
|
3538
|
-
|
3539
|
-
if not organization and not
|
3594
|
+
can_shorthand_org_entity = "orgEntity" in org_fields
|
3595
|
+
if not organization and not can_shorthand_org_entity:
|
3540
3596
|
raise ValueError(
|
3541
3597
|
"Fetching Registry artifacts without inputting an organization "
|
3542
3598
|
"is unavailable for your server version. "
|
3543
3599
|
"Please upgrade your server to 0.50.0 or later."
|
3544
3600
|
)
|
3545
|
-
if not
|
3601
|
+
if not can_shorthand_org_entity:
|
3546
3602
|
# Server doesn't support fetching org entity to validate,
|
3547
3603
|
# assume org entity is correctly inputted
|
3548
3604
|
return organization
|
3549
3605
|
|
3550
|
-
|
3606
|
+
orgs_from_entity = self._fetch_orgs_and_org_entities_from_entity(entity)
|
3551
3607
|
if organization:
|
3552
|
-
|
3553
|
-
|
3554
|
-
|
3555
|
-
|
3556
|
-
|
3557
|
-
|
3558
|
-
|
3559
|
-
"
|
3560
|
-
"
|
3608
|
+
return _match_org_with_fetched_org_entities(organization, orgs_from_entity)
|
3609
|
+
|
3610
|
+
# If no input organization provided, error if entity belongs to multiple orgs because we
|
3611
|
+
# cannot determine which one to use.
|
3612
|
+
if len(orgs_from_entity) > 1:
|
3613
|
+
raise ValueError(
|
3614
|
+
f"Personal entity {entity!r} belongs to multiple organizations "
|
3615
|
+
"and cannot be used without specifying the organization name. "
|
3616
|
+
"Please specify the organization in the Registry path or use a team entity in the entity settings."
|
3561
3617
|
)
|
3562
|
-
return
|
3618
|
+
return orgs_from_entity[0].entity_name
|
3563
3619
|
|
3564
|
-
def
|
3620
|
+
def _fetch_orgs_and_org_entities_from_entity(self, entity: str) -> List[_OrgNames]:
|
3621
|
+
"""Fetches organization entity names and display names for a given entity.
|
3622
|
+
|
3623
|
+
Args:
|
3624
|
+
entity (str): Entity name to lookup. Can be either a personal or team entity.
|
3625
|
+
|
3626
|
+
Returns:
|
3627
|
+
List[_OrgNames]: List of _OrgNames tuples. (_OrgNames(entity_name, display_name))
|
3628
|
+
|
3629
|
+
Raises:
|
3630
|
+
ValueError: If entity is not found, has no organizations, or other validation errors.
|
3631
|
+
"""
|
3565
3632
|
query = gql(
|
3566
3633
|
"""
|
3567
|
-
query FetchOrgEntityFromEntity(
|
3568
|
-
$entityName: String!,
|
3569
|
-
) {
|
3634
|
+
query FetchOrgEntityFromEntity($entityName: String!) {
|
3570
3635
|
entity(name: $entityName) {
|
3571
|
-
isTeam
|
3572
3636
|
organization {
|
3573
3637
|
name
|
3574
3638
|
orgEntity {
|
3575
3639
|
name
|
3576
3640
|
}
|
3577
3641
|
}
|
3642
|
+
user {
|
3643
|
+
organizations {
|
3644
|
+
name
|
3645
|
+
orgEntity {
|
3646
|
+
name
|
3647
|
+
}
|
3648
|
+
}
|
3649
|
+
}
|
3578
3650
|
}
|
3579
3651
|
}
|
3580
3652
|
"""
|
@@ -3585,28 +3657,40 @@ class Api:
|
|
3585
3657
|
"entityName": entity,
|
3586
3658
|
},
|
3587
3659
|
)
|
3588
|
-
|
3589
|
-
|
3590
|
-
|
3591
|
-
|
3592
|
-
|
3593
|
-
|
3594
|
-
|
3595
|
-
|
3660
|
+
|
3661
|
+
# Parse organization from response
|
3662
|
+
entity_resp = response["entity"]["organization"]
|
3663
|
+
user_resp = response["entity"]["user"]
|
3664
|
+
# Check for organization under team/org entity type
|
3665
|
+
if entity_resp:
|
3666
|
+
org_name = entity_resp.get("name")
|
3667
|
+
org_entity_name = entity_resp.get("orgEntity") and entity_resp[
|
3668
|
+
"orgEntity"
|
3669
|
+
].get("name")
|
3670
|
+
if not org_name or not org_entity_name:
|
3596
3671
|
raise ValueError(
|
3597
|
-
f"Unable to find an organization under entity {entity!r}.
|
3598
|
-
)
|
3599
|
-
|
3672
|
+
f"Unable to find an organization under entity {entity!r}."
|
3673
|
+
)
|
3674
|
+
return [_OrgNames(entity_name=org_entity_name, display_name=org_name)]
|
3675
|
+
# Check for organization under personal entity type, where a user can belong to multiple orgs
|
3676
|
+
elif user_resp:
|
3677
|
+
orgs = user_resp.get("organizations", [])
|
3678
|
+
org_entities_return = [
|
3679
|
+
_OrgNames(
|
3680
|
+
entity_name=org["orgEntity"]["name"], display_name=org["name"]
|
3681
|
+
)
|
3682
|
+
for org in orgs
|
3683
|
+
if org.get("orgEntity") and org.get("name")
|
3684
|
+
]
|
3685
|
+
if not org_entities_return:
|
3600
3686
|
raise ValueError(
|
3601
|
-
f"Unable to resolve an organization associated with
|
3602
|
-
"
|
3603
|
-
|
3604
|
-
|
3605
|
-
|
3606
|
-
"or wandb.init(entity='<my_team_entity>') "
|
3607
|
-
) from e
|
3687
|
+
f"Unable to resolve an organization associated with personal entity: {entity!r}. "
|
3688
|
+
"This could be because its a personal entity that doesn't belong to any organizations. "
|
3689
|
+
"Please specify the organization in the Registry path or use a team entity in the entity settings."
|
3690
|
+
)
|
3691
|
+
return org_entities_return
|
3608
3692
|
else:
|
3609
|
-
|
3693
|
+
raise ValueError(f"Unable to find an organization under entity {entity!r}.")
|
3610
3694
|
|
3611
3695
|
def use_artifact(
|
3612
3696
|
self,
|
@@ -3697,6 +3781,41 @@ class Api:
|
|
3697
3781
|
|
3698
3782
|
return self.server_organization_type_fields_info
|
3699
3783
|
|
3784
|
+
# Fetch input arguments for the "artifact" endpoint on the "Project" type
|
3785
|
+
def server_project_type_introspection(self) -> bool:
    """Report whether the server's `Project.artifact` field accepts `enableTracking`.

    Introspects the GraphQL "Project" type and caches the answer on
    `self.server_supports_enabling_artifact_usage_tracking`; once that
    attribute is not None, it is returned without querying the server.

    Returns:
        bool: True if the "artifact" field of the "Project" type lists an
            argument named "enableTracking", otherwise False.
    """
    if self.server_supports_enabling_artifact_usage_tracking is not None:
        return self.server_supports_enabling_artifact_usage_tracking

    query_string = """
        query ProbeServerProjectInfo {
            ProjectInfoType: __type(name:"Project") {
                fields {
                    name
                    args {
                        name
                    }
                }
            }
        }
    """

    query = gql(query_string)
    res = self.gql(query)
    # Introspection yields an explicit null (not a missing key) when the type
    # is unknown, so `.get(key, default)` would still hand back None here.
    project_type = res.get("ProjectInfoType") or {}
    input_fields = project_type.get("fields") or []
    artifact_args: List[Dict[str, str]] = next(
        (
            field.get("args") or []
            for field in input_fields
            if field.get("name") == "artifact"
        ),
        [],
    )
    self.server_supports_enabling_artifact_usage_tracking = any(
        arg.get("name") == "enableTracking" for arg in artifact_args
    )

    return self.server_supports_enabling_artifact_usage_tracking
|
3818
|
+
|
3700
3819
|
def create_artifact_type(
|
3701
3820
|
self,
|
3702
3821
|
artifact_type_name: str,
|
@@ -313,7 +313,8 @@ class JobBuilder:
|
|
313
313
|
"build_context": metadata.get("build_context"),
|
314
314
|
"dockerfile": metadata.get("dockerfile"),
|
315
315
|
}
|
316
|
-
|
316
|
+
artifact_basename, *_ = self._logged_code_artifact["name"].split(":")
|
317
|
+
name = self._make_job_name(artifact_basename)
|
317
318
|
|
318
319
|
return source, name
|
319
320
|
|
@@ -380,7 +381,7 @@ class JobBuilder:
|
|
380
381
|
]:
|
381
382
|
"""Construct a job source dict and name from the current run.
|
382
383
|
|
383
|
-
|
384
|
+
Args:
|
384
385
|
source_type (str): The type of source to build the job from. One of
|
385
386
|
"repo", "artifact", or "image".
|
386
387
|
"""
|
@@ -427,7 +428,7 @@ class JobBuilder:
|
|
427
428
|
) -> Optional[Artifact]:
|
428
429
|
"""Build a job artifact from the current run.
|
429
430
|
|
430
|
-
|
431
|
+
Args:
|
431
432
|
api (Api): The API object to use to create the job artifact.
|
432
433
|
build_context (Optional[str]): Path within the job source code to
|
433
434
|
the image build context. Saved as part of the job for future
|
@@ -4,7 +4,6 @@ __all__ = (
|
|
4
4
|
"Disk",
|
5
5
|
"GPU",
|
6
6
|
"GPUAMD",
|
7
|
-
"GPUApple",
|
8
7
|
"IPU",
|
9
8
|
"Memory",
|
10
9
|
"Network",
|
@@ -18,7 +17,6 @@ from .cpu import CPU
|
|
18
17
|
from .disk import Disk
|
19
18
|
from .gpu import GPU
|
20
19
|
from .gpu_amd import GPUAMD
|
21
|
-
from .gpu_apple import GPUApple
|
22
20
|
from .ipu import IPU
|
23
21
|
from .memory import Memory
|
24
22
|
from .network import Network
|
wandb/sdk/internal/tb_watcher.py
CHANGED
@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
|
12
12
|
|
13
13
|
import wandb
|
14
14
|
from wandb import util
|
15
|
-
from wandb.plot
|
15
|
+
from wandb.plot import CustomChart
|
16
16
|
from wandb.sdk.interface.interface import GlobStr
|
17
17
|
from wandb.sdk.lib import filesystem
|
18
18
|
|
@@ -73,7 +73,7 @@ def is_tfevents_file_created_by(
|
|
73
73
|
if not path:
|
74
74
|
raise ValueError("Path must be a nonempty string")
|
75
75
|
basename = os.path.basename(path)
|
76
|
-
if basename.endswith(".profile-empty"
|
76
|
+
if basename.endswith((".profile-empty", ".sagemaker-uploaded")):
|
77
77
|
return False
|
78
78
|
fname_components = basename.split(".")
|
79
79
|
try:
|
@@ -439,18 +439,19 @@ class TBEventConsumer:
|
|
439
439
|
|
440
440
|
def _save_row(self, row: "HistoryDict") -> None:
|
441
441
|
chart_keys = set()
|
442
|
-
for k in row:
|
443
|
-
if isinstance(
|
442
|
+
for k, v in row.items():
|
443
|
+
if isinstance(v, CustomChart):
|
444
444
|
chart_keys.add(k)
|
445
|
-
|
446
|
-
|
447
|
-
|
445
|
+
v.set_key(k)
|
446
|
+
self._tbwatcher._interface.publish_config(
|
447
|
+
key=v.spec.config_key,
|
448
|
+
val=v.spec.config_value,
|
448
449
|
)
|
449
|
-
row[k] = row[k]._data
|
450
|
-
self._tbwatcher._interface.publish_config(val=value, key=key)
|
451
450
|
|
452
451
|
for k in chart_keys:
|
453
|
-
|
452
|
+
chart = row.pop(k)
|
453
|
+
if isinstance(chart, CustomChart):
|
454
|
+
row[chart.spec.table_key] = chart.table
|
454
455
|
|
455
456
|
self._tbwatcher._interface.publish_history(
|
456
457
|
row, run=self._internal_run, publish_step=False
|
wandb/sdk/launch/_launch.py
CHANGED
@@ -120,10 +120,11 @@ def resolve_agent_config( # noqa: C901
|
|
120
120
|
if isinstance(resolved_config.get("queue"), str):
|
121
121
|
resolved_config["queues"].append(resolved_config["queue"])
|
122
122
|
else:
|
123
|
-
|
124
|
-
|
125
|
-
|
123
|
+
# The first literal must be an f-string: without the prefix the
# `{type(...)}` placeholder is emitted verbatim instead of the actual type.
msg = (
    f"Invalid launch agent config for key 'queue' with type: {type(resolved_config.get('queue'))} "
    "(expected str). Specify multiple queues with the 'queues' key"
)
raise LaunchError(msg)
|
127
128
|
|
128
129
|
keys = ["entity"]
|
129
130
|
settings = {
|
wandb/sdk/launch/_launch_add.py
CHANGED
@@ -61,7 +61,7 @@ def launch_add(
|
|
61
61
|
config: A dictionary containing the configuration for the run. May also contain
|
62
62
|
resource specific arguments under the key "resource_args"
|
63
63
|
template_variables: A dictionary containing values of template variables for a run queue.
|
64
|
-
Expected format of {"VAR_NAME": VAR_VALUE}
|
64
|
+
Expected format of `{"VAR_NAME": VAR_VALUE}`
|
65
65
|
project: Target project to send launched run to
|
66
66
|
entity: Target entity to send launched run to
|
67
67
|
queue: the name of the queue to enqueue the run to
|
@@ -240,7 +240,7 @@ def _launch_add(
|
|
240
240
|
public_api = public.Api()
|
241
241
|
if job is not None:
|
242
242
|
try:
|
243
|
-
public_api.
|
243
|
+
public_api._artifact(job, type="job")
|
244
244
|
except (ValueError, CommError) as e:
|
245
245
|
raise LaunchError(f"Unable to fetch job with name {job}: {e}")
|
246
246
|
|
wandb/sdk/launch/create_job.py
CHANGED
@@ -168,6 +168,7 @@ def _create_job(
|
|
168
168
|
return None, "", []
|
169
169
|
|
170
170
|
job_builder = _configure_job_builder_for_partial(tempdir.name, job_source=job_type)
|
171
|
+
job_builder._settings.update(job_name=name)
|
171
172
|
if job_type == "code":
|
172
173
|
assert entrypoint is not None
|
173
174
|
job_name = _make_code_artifact(
|
wandb/sdk/launch/errors.py
CHANGED
@@ -4,16 +4,10 @@ from wandb.errors import Error
|
|
4
4
|
class LaunchError(Error):
    """Error for a known failure condition in wandb launch."""
|
6
6
|
|
7
|
-
pass
|
8
|
-
|
9
7
|
|
10
8
|
class LaunchDockerError(Error):
    """Error signalling that the Docker daemon is not running."""
|
12
10
|
|
13
|
-
pass
|
14
|
-
|
15
11
|
|
16
12
|
class ExecutionError(Error):
    """General-purpose exception for execution failures."""
|
18
|
-
|
19
|
-
pass
|
@@ -20,7 +20,6 @@ class LocalRegistry(AbstractRegistry):
|
|
20
20
|
|
21
21
|
def __init__(self) -> None:
|
22
22
|
"""Initialize a local registry."""
|
23
|
-
pass
|
24
23
|
|
25
24
|
@classmethod
|
26
25
|
def from_config(
|
@@ -40,7 +39,6 @@ class LocalRegistry(AbstractRegistry):
|
|
40
39
|
|
41
40
|
async def verify(self) -> None:
|
42
41
|
"""Verify the local registry by doing nothing."""
|
43
|
-
pass
|
44
42
|
|
45
43
|
async def get_username_password(self) -> Tuple[str, str]:
|
46
44
|
"""Get the username and password of the local registry."""
|