wandb-0.13.10-py3-none-any.whl → wandb-0.14.0-py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +2 -3
- wandb/apis/__init__.py +1 -3
- wandb/apis/importers/__init__.py +4 -0
- wandb/apis/importers/base.py +312 -0
- wandb/apis/importers/mlflow.py +113 -0
- wandb/apis/internal.py +29 -2
- wandb/apis/normalize.py +6 -5
- wandb/apis/public.py +163 -180
- wandb/apis/reports/_templates.py +6 -12
- wandb/apis/reports/report.py +1 -1
- wandb/apis/reports/runset.py +1 -3
- wandb/apis/reports/util.py +12 -10
- wandb/beta/workflows.py +57 -34
- wandb/catboost/__init__.py +1 -2
- wandb/cli/cli.py +215 -133
- wandb/data_types.py +63 -56
- wandb/docker/__init__.py +78 -16
- wandb/docker/auth.py +21 -22
- wandb/env.py +0 -1
- wandb/errors/__init__.py +8 -116
- wandb/errors/term.py +1 -1
- wandb/fastai/__init__.py +1 -2
- wandb/filesync/dir_watcher.py +8 -5
- wandb/filesync/step_prepare.py +76 -75
- wandb/filesync/step_upload.py +1 -2
- wandb/integration/catboost/__init__.py +1 -3
- wandb/integration/catboost/catboost.py +8 -14
- wandb/integration/fastai/__init__.py +7 -13
- wandb/integration/gym/__init__.py +35 -4
- wandb/integration/keras/__init__.py +3 -3
- wandb/integration/keras/callbacks/metrics_logger.py +9 -8
- wandb/integration/keras/callbacks/model_checkpoint.py +9 -9
- wandb/integration/keras/callbacks/tables_builder.py +31 -19
- wandb/integration/kfp/kfp_patch.py +20 -17
- wandb/integration/kfp/wandb_logging.py +1 -2
- wandb/integration/lightgbm/__init__.py +21 -19
- wandb/integration/prodigy/prodigy.py +6 -7
- wandb/integration/sacred/__init__.py +9 -12
- wandb/integration/sagemaker/__init__.py +1 -3
- wandb/integration/sagemaker/auth.py +0 -1
- wandb/integration/sagemaker/config.py +1 -1
- wandb/integration/sagemaker/resources.py +1 -1
- wandb/integration/sb3/sb3.py +8 -4
- wandb/integration/tensorboard/__init__.py +1 -3
- wandb/integration/tensorboard/log.py +8 -8
- wandb/integration/tensorboard/monkeypatch.py +11 -9
- wandb/integration/tensorflow/__init__.py +1 -3
- wandb/integration/xgboost/__init__.py +4 -6
- wandb/integration/yolov8/__init__.py +7 -0
- wandb/integration/yolov8/yolov8.py +250 -0
- wandb/jupyter.py +31 -35
- wandb/lightgbm/__init__.py +1 -2
- wandb/old/settings.py +2 -2
- wandb/plot/bar.py +1 -2
- wandb/plot/confusion_matrix.py +1 -3
- wandb/plot/histogram.py +1 -2
- wandb/plot/line.py +1 -2
- wandb/plot/line_series.py +4 -4
- wandb/plot/pr_curve.py +17 -20
- wandb/plot/roc_curve.py +1 -3
- wandb/plot/scatter.py +1 -2
- wandb/proto/v3/wandb_server_pb2.py +85 -39
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_server_pb2.py +51 -39
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/__init__.py +1 -3
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/_dtypes.py +38 -30
- wandb/sdk/data_types/base_types/json_metadata.py +1 -3
- wandb/sdk/data_types/base_types/media.py +17 -17
- wandb/sdk/data_types/base_types/wb_value.py +33 -26
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +91 -125
- wandb/sdk/data_types/helper_types/classes.py +1 -1
- wandb/sdk/data_types/helper_types/image_mask.py +12 -12
- wandb/sdk/data_types/histogram.py +5 -4
- wandb/sdk/data_types/html.py +1 -2
- wandb/sdk/data_types/image.py +11 -11
- wandb/sdk/data_types/molecule.py +3 -6
- wandb/sdk/data_types/object_3d.py +1 -2
- wandb/sdk/data_types/plotly.py +1 -2
- wandb/sdk/data_types/saved_model.py +10 -8
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/data_logging.py +5 -5
- wandb/sdk/interface/artifacts.py +288 -266
- wandb/sdk/interface/interface.py +2 -3
- wandb/sdk/interface/interface_grpc.py +1 -1
- wandb/sdk/interface/interface_queue.py +1 -1
- wandb/sdk/interface/interface_relay.py +1 -1
- wandb/sdk/interface/interface_shared.py +1 -2
- wandb/sdk/interface/interface_sock.py +1 -1
- wandb/sdk/interface/message_future.py +1 -1
- wandb/sdk/interface/message_future_poll.py +1 -1
- wandb/sdk/interface/router.py +1 -1
- wandb/sdk/interface/router_queue.py +1 -1
- wandb/sdk/interface/router_relay.py +1 -1
- wandb/sdk/interface/router_sock.py +1 -1
- wandb/sdk/interface/summary_record.py +1 -1
- wandb/sdk/internal/artifacts.py +1 -1
- wandb/sdk/internal/datastore.py +2 -3
- wandb/sdk/internal/file_pusher.py +5 -3
- wandb/sdk/internal/file_stream.py +22 -19
- wandb/sdk/internal/handler.py +5 -4
- wandb/sdk/internal/internal.py +1 -1
- wandb/sdk/internal/internal_api.py +115 -55
- wandb/sdk/internal/job_builder.py +1 -3
- wandb/sdk/internal/profiler.py +1 -1
- wandb/sdk/internal/progress.py +4 -6
- wandb/sdk/internal/sample.py +1 -3
- wandb/sdk/internal/sender.py +28 -16
- wandb/sdk/internal/settings_static.py +5 -5
- wandb/sdk/internal/system/assets/__init__.py +1 -0
- wandb/sdk/internal/system/assets/cpu.py +3 -9
- wandb/sdk/internal/system/assets/disk.py +2 -4
- wandb/sdk/internal/system/assets/gpu.py +6 -18
- wandb/sdk/internal/system/assets/gpu_apple.py +2 -4
- wandb/sdk/internal/system/assets/interfaces.py +50 -22
- wandb/sdk/internal/system/assets/ipu.py +1 -3
- wandb/sdk/internal/system/assets/memory.py +7 -13
- wandb/sdk/internal/system/assets/network.py +4 -8
- wandb/sdk/internal/system/assets/open_metrics.py +283 -0
- wandb/sdk/internal/system/assets/tpu.py +1 -4
- wandb/sdk/internal/system/assets/trainium.py +26 -14
- wandb/sdk/internal/system/system_info.py +2 -3
- wandb/sdk/internal/system/system_monitor.py +52 -20
- wandb/sdk/internal/tb_watcher.py +12 -13
- wandb/sdk/launch/_project_spec.py +54 -65
- wandb/sdk/launch/agent/agent.py +374 -90
- wandb/sdk/launch/builder/abstract.py +61 -7
- wandb/sdk/launch/builder/build.py +81 -110
- wandb/sdk/launch/builder/docker_builder.py +181 -0
- wandb/sdk/launch/builder/kaniko_builder.py +419 -0
- wandb/sdk/launch/builder/noop.py +31 -12
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +70 -20
- wandb/sdk/launch/environment/abstract.py +28 -0
- wandb/sdk/launch/environment/aws_environment.py +276 -0
- wandb/sdk/launch/environment/gcp_environment.py +271 -0
- wandb/sdk/launch/environment/local_environment.py +65 -0
- wandb/sdk/launch/github_reference.py +3 -8
- wandb/sdk/launch/launch.py +38 -29
- wandb/sdk/launch/launch_add.py +6 -8
- wandb/sdk/launch/loader.py +230 -0
- wandb/sdk/launch/registry/abstract.py +54 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +163 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +203 -0
- wandb/sdk/launch/registry/local_registry.py +62 -0
- wandb/sdk/launch/runner/abstract.py +1 -16
- wandb/sdk/launch/runner/{kubernetes.py → kubernetes_runner.py} +83 -95
- wandb/sdk/launch/runner/local_container.py +46 -22
- wandb/sdk/launch/runner/local_process.py +1 -4
- wandb/sdk/launch/runner/{aws.py → sagemaker_runner.py} +53 -212
- wandb/sdk/launch/runner/{gcp_vertex.py → vertex_runner.py} +38 -55
- wandb/sdk/launch/sweeps/__init__.py +3 -2
- wandb/sdk/launch/sweeps/scheduler.py +132 -39
- wandb/sdk/launch/sweeps/scheduler_sweep.py +80 -89
- wandb/sdk/launch/utils.py +101 -30
- wandb/sdk/launch/wandb_reference.py +2 -7
- wandb/sdk/lib/_settings_toposort_generate.py +166 -0
- wandb/sdk/lib/_settings_toposort_generated.py +201 -0
- wandb/sdk/lib/apikey.py +2 -4
- wandb/sdk/lib/config_util.py +4 -1
- wandb/sdk/lib/console.py +1 -3
- wandb/sdk/lib/deprecate.py +3 -3
- wandb/sdk/lib/file_stream_utils.py +7 -5
- wandb/sdk/lib/filenames.py +1 -1
- wandb/sdk/lib/filesystem.py +61 -5
- wandb/sdk/lib/git.py +1 -3
- wandb/sdk/lib/import_hooks.py +4 -7
- wandb/sdk/lib/ipython.py +8 -5
- wandb/sdk/lib/lazyloader.py +1 -3
- wandb/sdk/lib/mailbox.py +14 -4
- wandb/sdk/lib/proto_util.py +10 -5
- wandb/sdk/lib/redirect.py +15 -22
- wandb/sdk/lib/reporting.py +1 -3
- wandb/sdk/lib/retry.py +4 -5
- wandb/sdk/lib/runid.py +1 -3
- wandb/sdk/lib/server.py +15 -9
- wandb/sdk/lib/sock_client.py +1 -1
- wandb/sdk/lib/sparkline.py +1 -1
- wandb/sdk/lib/wburls.py +1 -1
- wandb/sdk/service/port_file.py +1 -2
- wandb/sdk/service/service.py +36 -13
- wandb/sdk/service/service_base.py +12 -1
- wandb/sdk/verify/verify.py +5 -7
- wandb/sdk/wandb_artifacts.py +142 -177
- wandb/sdk/wandb_config.py +5 -8
- wandb/sdk/wandb_helper.py +1 -1
- wandb/sdk/wandb_init.py +24 -13
- wandb/sdk/wandb_login.py +9 -9
- wandb/sdk/wandb_manager.py +39 -4
- wandb/sdk/wandb_metric.py +2 -6
- wandb/sdk/wandb_require.py +4 -15
- wandb/sdk/wandb_require_helpers.py +1 -9
- wandb/sdk/wandb_run.py +95 -141
- wandb/sdk/wandb_save.py +1 -3
- wandb/sdk/wandb_settings.py +149 -54
- wandb/sdk/wandb_setup.py +66 -46
- wandb/sdk/wandb_summary.py +13 -10
- wandb/sdk/wandb_sweep.py +6 -7
- wandb/sdk/wandb_watch.py +1 -1
- wandb/sklearn/calculate/confusion_matrix.py +1 -1
- wandb/sklearn/calculate/learning_curve.py +1 -1
- wandb/sklearn/calculate/summary_metrics.py +1 -3
- wandb/sklearn/plot/__init__.py +1 -1
- wandb/sklearn/plot/classifier.py +27 -18
- wandb/sklearn/plot/clusterer.py +4 -5
- wandb/sklearn/plot/regressor.py +4 -4
- wandb/sklearn/plot/shared.py +2 -2
- wandb/sync/__init__.py +1 -3
- wandb/sync/sync.py +4 -5
- wandb/testing/relay.py +11 -10
- wandb/trigger.py +1 -1
- wandb/util.py +106 -81
- wandb/viz.py +4 -4
- wandb/wandb_agent.py +50 -50
- wandb/wandb_controller.py +2 -3
- wandb/wandb_run.py +1 -2
- wandb/wandb_torch.py +1 -1
- wandb/xgboost/__init__.py +1 -2
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/METADATA +6 -2
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/RECORD +224 -209
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
- wandb/sdk/launch/builder/docker.py +0 -80
- wandb/sdk/launch/builder/kaniko.py +0 -393
- wandb/sdk/launch/builder/loader.py +0 -32
- wandb/sdk/launch/runner/loader.py +0 -50
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
wandb/sdk/internal/internal_api.py
CHANGED
@@ -58,7 +58,7 @@ if TYPE_CHECKING:
|
|
58
58
|
from .progress import ProgressFn
|
59
59
|
|
60
60
|
class CreateArtifactFileSpecInput(TypedDict, total=False):
|
61
|
-
"""Corresponds to `type CreateArtifactFileSpecInput` in schema.graphql"""
|
61
|
+
"""Corresponds to `type CreateArtifactFileSpecInput` in schema.graphql."""
|
62
62
|
|
63
63
|
artifactID: str # noqa: N815
|
64
64
|
name: str
|
@@ -106,7 +106,7 @@ class _ThreadLocalData(threading.local):
|
|
106
106
|
|
107
107
|
|
108
108
|
class Api:
|
109
|
-
"""W&B Internal Api wrapper
|
109
|
+
"""W&B Internal Api wrapper.
|
110
110
|
|
111
111
|
Note:
|
112
112
|
Settings are automatically overridden by looking for
|
@@ -205,6 +205,7 @@ class Api:
|
|
205
205
|
self._azure_blob_module = util.get_module("azure.storage.blob")
|
206
206
|
|
207
207
|
self.query_types: Optional[List[str]] = None
|
208
|
+
self.mutation_types: Optional[List[str]] = None
|
208
209
|
self.server_info_types: Optional[List[str]] = None
|
209
210
|
self.server_use_artifact_input_info: Optional[List[str]] = None
|
210
211
|
self._max_cli_version: Optional[str] = None
|
@@ -229,11 +230,11 @@ class Api:
|
|
229
230
|
return self._local_data.context or self._global_context
|
230
231
|
|
231
232
|
def reauth(self) -> None:
|
232
|
-
"""
|
233
|
+
"""Ensure the current api key is set in the transport."""
|
233
234
|
self.client.transport.auth = ("api", self.api_key or "")
|
234
235
|
|
235
236
|
def relocate(self) -> None:
|
236
|
-
"""
|
237
|
+
"""Ensure the current api points to the right server."""
|
237
238
|
self.client.transport.url = "%s/graphql" % self.settings("base_url")
|
238
239
|
|
239
240
|
def execute(self, *args: Any, **kwargs: Any) -> "_Response":
|
@@ -384,7 +385,7 @@ class Api:
|
|
384
385
|
def parse_slug(
|
385
386
|
self, slug: str, project: Optional[str] = None, run: Optional[str] = None
|
386
387
|
) -> Tuple[str, str]:
|
387
|
-
"""
|
388
|
+
"""Parse a slug into a project and run.
|
388
389
|
|
389
390
|
Arguments:
|
390
391
|
slug (str): The slug to parse
|
@@ -409,12 +410,15 @@ class Api:
|
|
409
410
|
return project, run
|
410
411
|
|
411
412
|
@normalize_exceptions
|
412
|
-
def server_info_introspection(self) -> Tuple[List[str], List[str]]:
|
413
|
+
def server_info_introspection(self) -> Tuple[List[str], List[str], List[str]]:
|
413
414
|
query_string = """
|
414
415
|
query ProbeServerCapabilities {
|
415
416
|
QueryType: __type(name: "Query") {
|
416
417
|
...fieldData
|
417
418
|
}
|
419
|
+
MutationType: __type(name: "Mutation") {
|
420
|
+
...fieldData
|
421
|
+
}
|
418
422
|
ServerInfoType: __type(name: "ServerInfo") {
|
419
423
|
...fieldData
|
420
424
|
}
|
@@ -426,7 +430,11 @@ class Api:
|
|
426
430
|
}
|
427
431
|
}
|
428
432
|
"""
|
429
|
-
if self.query_types is None or self.server_info_types is None:
|
433
|
+
if (
|
434
|
+
self.query_types is None
|
435
|
+
or self.mutation_types is None
|
436
|
+
or self.server_info_types is None
|
437
|
+
):
|
430
438
|
query = gql(query_string)
|
431
439
|
res = self.gql(query)
|
432
440
|
|
@@ -434,11 +442,15 @@ class Api:
|
|
434
442
|
field.get("name", "")
|
435
443
|
for field in res.get("QueryType", {}).get("fields", [{}])
|
436
444
|
]
|
445
|
+
self.mutation_types = [
|
446
|
+
field.get("name", "")
|
447
|
+
for field in res.get("MutationType", {}).get("fields", [{}])
|
448
|
+
]
|
437
449
|
self.server_info_types = [
|
438
450
|
field.get("name", "")
|
439
451
|
for field in res.get("ServerInfoType", {}).get("fields", [{}])
|
440
452
|
]
|
441
|
-
return self.query_types, self.server_info_types
|
453
|
+
return self.query_types, self.server_info_types, self.mutation_types
|
442
454
|
|
443
455
|
@normalize_exceptions
|
444
456
|
def server_settings_introspection(self) -> None:
|
@@ -505,6 +517,35 @@ class Api:
|
|
505
517
|
res = self.gql(query)
|
506
518
|
return res.get("LaunchAgentType") or None
|
507
519
|
|
520
|
+
@normalize_exceptions
|
521
|
+
def fail_run_queue_item_introspection(self) -> bool:
|
522
|
+
_, _, mutations = self.server_info_introspection()
|
523
|
+
return "failRunQueueItem" in mutations
|
524
|
+
|
525
|
+
@normalize_exceptions
|
526
|
+
def fail_run_queue_item(self, run_queue_item_id: str) -> bool:
|
527
|
+
mutation = gql(
|
528
|
+
"""
|
529
|
+
mutation failRunQueueItem($runQueueItemId: ID!) {
|
530
|
+
failRunQueueItem(
|
531
|
+
input: {
|
532
|
+
runQueueItemId: $runQueueItemId
|
533
|
+
}
|
534
|
+
) {
|
535
|
+
success
|
536
|
+
}
|
537
|
+
}
|
538
|
+
"""
|
539
|
+
)
|
540
|
+
response = self.gql(
|
541
|
+
mutation,
|
542
|
+
variable_values={
|
543
|
+
"runQueueItemId": run_queue_item_id,
|
544
|
+
},
|
545
|
+
)
|
546
|
+
result: bool = response["failRunQueueItem"]["success"]
|
547
|
+
return result
|
548
|
+
|
508
549
|
@normalize_exceptions
|
509
550
|
def viewer(self) -> Dict[str, Any]:
|
510
551
|
query = gql(
|
@@ -530,11 +571,10 @@ class Api:
|
|
530
571
|
|
531
572
|
@normalize_exceptions
|
532
573
|
def max_cli_version(self) -> Optional[str]:
|
533
|
-
|
534
574
|
if self._max_cli_version is not None:
|
535
575
|
return self._max_cli_version
|
536
576
|
|
537
|
-
query_types, server_info_types = self.server_info_introspection()
|
577
|
+
query_types, server_info_types, _ = self.server_info_introspection()
|
538
578
|
cli_version_exists = (
|
539
579
|
"serverInfo" in query_types and "cliVersionInfo" in server_info_types
|
540
580
|
)
|
@@ -580,7 +620,7 @@ class Api:
|
|
580
620
|
_CLI_QUERY_
|
581
621
|
}
|
582
622
|
"""
|
583
|
-
query_types, server_info_types = self.server_info_introspection()
|
623
|
+
query_types, server_info_types, _ = self.server_info_introspection()
|
584
624
|
|
585
625
|
cli_version_exists = (
|
586
626
|
"serverInfo" in query_types and "cliVersionInfo" in server_info_types
|
@@ -603,7 +643,7 @@ class Api:
|
|
603
643
|
|
604
644
|
@normalize_exceptions
|
605
645
|
def list_projects(self, entity: Optional[str] = None) -> List[Dict[str, str]]:
|
606
|
-
"""
|
646
|
+
"""List projects in W&B scoped by entity.
|
607
647
|
|
608
648
|
Arguments:
|
609
649
|
entity (str, optional): The entity to scope this project to.
|
@@ -635,7 +675,7 @@ class Api:
|
|
635
675
|
|
636
676
|
@normalize_exceptions
|
637
677
|
def project(self, project: str, entity: Optional[str] = None) -> "_Response":
|
638
|
-
"""Retrieve project
|
678
|
+
"""Retrieve project.
|
639
679
|
|
640
680
|
Arguments:
|
641
681
|
project (str): The project to get details for
|
@@ -743,7 +783,7 @@ class Api:
|
|
743
783
|
def list_runs(
|
744
784
|
self, project: str, entity: Optional[str] = None
|
745
785
|
) -> List[Dict[str, str]]:
|
746
|
-
"""
|
786
|
+
"""List runs in W&B scoped by project.
|
747
787
|
|
748
788
|
Arguments:
|
749
789
|
project (str): The project to scope the runs to
|
@@ -784,7 +824,7 @@ class Api:
|
|
784
824
|
def run_config(
|
785
825
|
self, project: str, run: Optional[str] = None, entity: Optional[str] = None
|
786
826
|
) -> Tuple[str, Dict[str, Any], Optional[str], Dict[str, Any]]:
|
787
|
-
"""Get the relevant configs for a run
|
827
|
+
"""Get the relevant configs for a run.
|
788
828
|
|
789
829
|
Arguments:
|
790
830
|
project (str): The project to download, (can include bucket)
|
@@ -973,7 +1013,7 @@ class Api:
|
|
973
1013
|
description: Optional[str] = None,
|
974
1014
|
entity: Optional[str] = None,
|
975
1015
|
) -> Dict[str, Any]:
|
976
|
-
"""Create a new project
|
1016
|
+
"""Create a new project.
|
977
1017
|
|
978
1018
|
Arguments:
|
979
1019
|
project (str): The project to create
|
@@ -1006,6 +1046,32 @@ class Api:
|
|
1006
1046
|
result: Dict[str, Any] = response["upsertModel"]["model"]
|
1007
1047
|
return result
|
1008
1048
|
|
1049
|
+
@normalize_exceptions
|
1050
|
+
def entity_is_team(self, entity: str) -> bool:
|
1051
|
+
query = gql(
|
1052
|
+
"""
|
1053
|
+
query EntityIsTeam($entity: String!) {
|
1054
|
+
entity(name: $entity) {
|
1055
|
+
id
|
1056
|
+
isTeam
|
1057
|
+
}
|
1058
|
+
}
|
1059
|
+
"""
|
1060
|
+
)
|
1061
|
+
variable_values = {
|
1062
|
+
"entity": entity,
|
1063
|
+
}
|
1064
|
+
|
1065
|
+
res = self.gql(query, variable_values)
|
1066
|
+
if res.get("entity") is None:
|
1067
|
+
raise Exception(
|
1068
|
+
f"Error fetching entity {entity} "
|
1069
|
+
"check that you have access to this entity"
|
1070
|
+
)
|
1071
|
+
|
1072
|
+
is_team: bool = res["entity"]["isTeam"]
|
1073
|
+
return is_team
|
1074
|
+
|
1009
1075
|
@normalize_exceptions
|
1010
1076
|
def get_project_run_queues(self, entity: str, project: str) -> List[Dict[str, str]]:
|
1011
1077
|
query = gql(
|
@@ -1029,10 +1095,19 @@ class Api:
|
|
1029
1095
|
|
1030
1096
|
res = self.gql(query, variable_values)
|
1031
1097
|
if res.get("project") is None:
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1098
|
+
# circular dependency: (LAUNCH_DEFAULT_PROJECT = model-registry)
|
1099
|
+
if project == "model-registry":
|
1100
|
+
msg = (
|
1101
|
+
f"Error fetching run queues for {entity} "
|
1102
|
+
"check that you have access to this entity and project"
|
1103
|
+
)
|
1104
|
+
else:
|
1105
|
+
msg = (
|
1106
|
+
f"Error fetching run queues for {entity}/{project} "
|
1107
|
+
"check that you have access to this entity and project"
|
1108
|
+
)
|
1109
|
+
|
1110
|
+
raise Exception(msg)
|
1036
1111
|
|
1037
1112
|
project_run_queues: List[Dict[str, str]] = res["project"]["runQueues"]
|
1038
1113
|
return project_run_queues
|
@@ -1073,10 +1148,7 @@ class Api:
|
|
1073
1148
|
def push_to_run_queue_by_name(
|
1074
1149
|
self, entity: str, project: str, queue_name: str, run_spec: str
|
1075
1150
|
) -> Optional[Dict[str, Any]]:
|
1076
|
-
"""
|
1077
|
-
Queryless mutation, should be used before legacy fallback method
|
1078
|
-
"""
|
1079
|
-
|
1151
|
+
"""Queryless mutation, should be used before legacy fallback method."""
|
1080
1152
|
mutation = gql(
|
1081
1153
|
"""
|
1082
1154
|
mutation pushToRunQueueByName(
|
@@ -1445,7 +1517,7 @@ class Api:
|
|
1445
1517
|
summary_metrics: Optional[str] = None,
|
1446
1518
|
num_retries: Optional[int] = None,
|
1447
1519
|
) -> Tuple[dict, bool, Optional[List]]:
|
1448
|
-
"""Update a run
|
1520
|
+
"""Update a run.
|
1449
1521
|
|
1450
1522
|
Arguments:
|
1451
1523
|
id (str, optional): The existing run to update
|
@@ -1468,7 +1540,6 @@ class Api:
|
|
1468
1540
|
summary_metrics (str, optional): The JSON summary metrics
|
1469
1541
|
num_retries (int, optional): Number of retries
|
1470
1542
|
"""
|
1471
|
-
|
1472
1543
|
query_string = """
|
1473
1544
|
mutation UpsertBucket(
|
1474
1545
|
$id: String,
|
@@ -1706,7 +1777,7 @@ class Api:
|
|
1706
1777
|
entity: Optional[str] = None,
|
1707
1778
|
description: Optional[str] = None,
|
1708
1779
|
) -> Tuple[str, List[str], Dict[str, Dict[str, Any]]]:
|
1709
|
-
"""Generate temporary resumable upload urls
|
1780
|
+
"""Generate temporary resumable upload urls.
|
1710
1781
|
|
1711
1782
|
Arguments:
|
1712
1783
|
project (str): The project to download
|
@@ -1775,7 +1846,7 @@ class Api:
|
|
1775
1846
|
run: Optional[str] = None,
|
1776
1847
|
entity: Optional[str] = None,
|
1777
1848
|
) -> Dict[str, Dict[str, str]]:
|
1778
|
-
"""Generate download urls
|
1849
|
+
"""Generate download urls.
|
1779
1850
|
|
1780
1851
|
Arguments:
|
1781
1852
|
project (str): The project to download
|
@@ -1834,7 +1905,7 @@ class Api:
|
|
1834
1905
|
run: Optional[str] = None,
|
1835
1906
|
entity: Optional[str] = None,
|
1836
1907
|
) -> Optional[Dict[str, str]]:
|
1837
|
-
"""Generate download urls
|
1908
|
+
"""Generate download urls.
|
1838
1909
|
|
1839
1910
|
Arguments:
|
1840
1911
|
project (str): The project to download
|
@@ -1887,7 +1958,7 @@ class Api:
|
|
1887
1958
|
|
1888
1959
|
@normalize_exceptions
|
1889
1960
|
def download_file(self, url: str) -> Tuple[int, requests.Response]:
|
1890
|
-
"""Initiate a streaming download
|
1961
|
+
"""Initiate a streaming download.
|
1891
1962
|
|
1892
1963
|
Arguments:
|
1893
1964
|
url (str): The url to download
|
@@ -1905,7 +1976,7 @@ class Api:
|
|
1905
1976
|
metadata: Dict[str, str],
|
1906
1977
|
out_dir: Optional[str] = None,
|
1907
1978
|
) -> Tuple[str, Optional[requests.Response]]:
|
1908
|
-
"""Download a file from a run and write it to wandb
|
1979
|
+
"""Download a file from a run and write it to wandb/.
|
1909
1980
|
|
1910
1981
|
Arguments:
|
1911
1982
|
metadata (obj): The metadata object for the file to download. Comes from Api.download_urls().
|
@@ -1931,9 +2002,7 @@ class Api:
|
|
1931
2002
|
def upload_file_azure(
|
1932
2003
|
self, url: str, file: Any, extra_headers: Dict[str, str]
|
1933
2004
|
) -> None:
|
1934
|
-
"""
|
1935
|
-
Upload a file to azure
|
1936
|
-
"""
|
2005
|
+
"""Upload a file to azure."""
|
1937
2006
|
from azure.core.exceptions import AzureError # type: ignore
|
1938
2007
|
|
1939
2008
|
# Configure the client without retries so our existing logic can handle them
|
@@ -1972,7 +2041,7 @@ class Api:
|
|
1972
2041
|
callback: Optional["ProgressFn"] = None,
|
1973
2042
|
extra_headers: Optional[Dict[str, str]] = None,
|
1974
2043
|
) -> Optional[requests.Response]:
|
1975
|
-
"""
|
2044
|
+
"""Upload a file to W&B with failure resumption.
|
1976
2045
|
|
1977
2046
|
Arguments:
|
1978
2047
|
url: The url to download
|
@@ -2039,7 +2108,7 @@ class Api:
|
|
2039
2108
|
project_name: Optional[str] = None,
|
2040
2109
|
entity: Optional[str] = None,
|
2041
2110
|
) -> dict:
|
2042
|
-
"""Register a new agent
|
2111
|
+
"""Register a new agent.
|
2043
2112
|
|
2044
2113
|
Arguments:
|
2045
2114
|
host (str): hostname
|
@@ -2318,7 +2387,7 @@ class Api:
|
|
2318
2387
|
|
2319
2388
|
@normalize_exceptions
|
2320
2389
|
def create_anonymous_api_key(self) -> str:
|
2321
|
-
"""
|
2390
|
+
"""Create a new API key belonging to a new anonymous user."""
|
2322
2391
|
mutation = gql(
|
2323
2392
|
"""
|
2324
2393
|
mutation CreateAnonymousApiKey {
|
@@ -2337,14 +2406,14 @@ class Api:
|
|
2337
2406
|
|
2338
2407
|
@staticmethod
|
2339
2408
|
def file_current(fname: str, md5: B64MD5) -> bool:
|
2340
|
-
"""Checksum a file and compare the md5 with the known md5"""
|
2409
|
+
"""Checksum a file and compare the md5 with the known md5."""
|
2341
2410
|
return os.path.isfile(fname) and md5_file_b64(fname) == md5
|
2342
2411
|
|
2343
2412
|
@normalize_exceptions
|
2344
2413
|
def pull(
|
2345
2414
|
self, project: str, run: Optional[str] = None, entity: Optional[str] = None
|
2346
2415
|
) -> "List[requests.Response]":
|
2347
|
-
"""Download files from W&B
|
2416
|
+
"""Download files from W&B.
|
2348
2417
|
|
2349
2418
|
Arguments:
|
2350
2419
|
project (str): The project to download
|
@@ -2379,7 +2448,7 @@ class Api:
|
|
2379
2448
|
force: bool = True,
|
2380
2449
|
progress: Union[TextIO, bool] = False,
|
2381
2450
|
) -> "List[Optional[requests.Response]]":
|
2382
|
-
"""Uploads multiple files to W&B
|
2451
|
+
"""Uploads multiple files to W&B.
|
2383
2452
|
|
2384
2453
|
Arguments:
|
2385
2454
|
files (list or dict): The filenames to upload, when dict the values are open files
|
@@ -2949,7 +3018,6 @@ class Api:
|
|
2949
3018
|
self,
|
2950
3019
|
client_id: str,
|
2951
3020
|
) -> Optional[str]:
|
2952
|
-
|
2953
3021
|
if client_id in self._client_id_mapping:
|
2954
3022
|
return self._client_id_mapping[client_id]
|
2955
3023
|
|
@@ -3142,9 +3210,7 @@ class Api:
|
|
3142
3210
|
entity: Optional[str] = None,
|
3143
3211
|
project: Optional[str] = None,
|
3144
3212
|
) -> None:
|
3145
|
-
"""
|
3146
|
-
Finish the sweep to stop running new runs and let currently running runs finish.
|
3147
|
-
"""
|
3213
|
+
"""Finish the sweep to stop running new runs and let currently running runs finish."""
|
3148
3214
|
self.set_sweep_state(
|
3149
3215
|
sweep=sweep, state="FINISHED", entity=entity, project=project
|
3150
3216
|
)
|
@@ -3155,9 +3221,7 @@ class Api:
|
|
3155
3221
|
entity: Optional[str] = None,
|
3156
3222
|
project: Optional[str] = None,
|
3157
3223
|
) -> None:
|
3158
|
-
"""
|
3159
|
-
Cancel the sweep to kill all running runs and stop running new runs.
|
3160
|
-
"""
|
3224
|
+
"""Cancel the sweep to kill all running runs and stop running new runs."""
|
3161
3225
|
self.set_sweep_state(
|
3162
3226
|
sweep=sweep, state="CANCELED", entity=entity, project=project
|
3163
3227
|
)
|
@@ -3168,9 +3232,7 @@ class Api:
|
|
3168
3232
|
entity: Optional[str] = None,
|
3169
3233
|
project: Optional[str] = None,
|
3170
3234
|
) -> None:
|
3171
|
-
"""
|
3172
|
-
Pause the sweep to temporarily stop running new runs.
|
3173
|
-
"""
|
3235
|
+
"""Pause the sweep to temporarily stop running new runs."""
|
3174
3236
|
self.set_sweep_state(
|
3175
3237
|
sweep=sweep, state="PAUSED", entity=entity, project=project
|
3176
3238
|
)
|
@@ -3181,20 +3243,18 @@ class Api:
|
|
3181
3243
|
entity: Optional[str] = None,
|
3182
3244
|
project: Optional[str] = None,
|
3183
3245
|
) -> None:
|
3184
|
-
"""
|
3185
|
-
Resume the sweep to continue running new runs.
|
3186
|
-
"""
|
3246
|
+
"""Resume the sweep to continue running new runs."""
|
3187
3247
|
self.set_sweep_state(
|
3188
3248
|
sweep=sweep, state="RUNNING", entity=entity, project=project
|
3189
3249
|
)
|
3190
3250
|
|
3191
3251
|
def _status_request(self, url: str, length: int) -> requests.Response:
|
3192
|
-
"""Ask google how much we've uploaded"""
|
3252
|
+
"""Ask google how much we've uploaded."""
|
3193
3253
|
return requests.put(
|
3194
3254
|
url=url,
|
3195
3255
|
headers={"Content-Length": "0", "Content-Range": "bytes */%i" % length},
|
3196
3256
|
)
|
3197
3257
|
|
3198
3258
|
def _flatten_edges(self, response: "_Response") -> List[Dict]:
|
3199
|
-
"""Return an array from the nested graphql relay structure"""
|
3259
|
+
"""Return an array from the nested graphql relay structure."""
|
3200
3260
|
return [node["node"] for node in response["edges"]]
|
wandb/sdk/internal/profiler.py
CHANGED
@@ -10,7 +10,7 @@ PYTORCH_PROFILER_MODULE = "torch.profiler"
|
|
10
10
|
|
11
11
|
|
12
12
|
def torch_trace_handler():
|
13
|
-
"""
|
13
|
+
"""Create a trace handler for traces generated by the profiler.
|
14
14
|
|
15
15
|
Provide as an argument to `torch.profiler.profile`:
|
16
16
|
```python
|
wandb/sdk/internal/progress.py
CHANGED
@@ -1,6 +1,4 @@
|
|
1
|
-
"""
|
2
|
-
progress.
|
3
|
-
"""
|
1
|
+
"""progress."""
|
4
2
|
|
5
3
|
import os
|
6
4
|
import sys
|
@@ -20,7 +18,7 @@ if TYPE_CHECKING:
|
|
20
18
|
|
21
19
|
|
22
20
|
class Progress:
|
23
|
-
"""A helper class for displaying progress"""
|
21
|
+
"""A helper class for displaying progress."""
|
24
22
|
|
25
23
|
ITER_BYTES = 1024 * 1024
|
26
24
|
|
@@ -40,7 +38,7 @@ class Progress:
|
|
40
38
|
self.len = os.fstat(file.fileno()).st_size
|
41
39
|
|
42
40
|
def read(self, size=-1):
|
43
|
-
"""Read bytes and call the callback"""
|
41
|
+
"""Read bytes and call the callback."""
|
44
42
|
bites = self.file.read(size)
|
45
43
|
self.bytes_read += len(bites)
|
46
44
|
if not bites and self.bytes_read < self.len:
|
@@ -64,7 +62,7 @@ class Progress:
|
|
64
62
|
self.file.seek(0)
|
65
63
|
|
66
64
|
def __getattr__(self, name):
|
67
|
-
"""Fallback to the file object for attrs not defined here"""
|
65
|
+
"""Fallback to the file object for attrs not defined here."""
|
68
66
|
if hasattr(self.file, name):
|
69
67
|
return getattr(self.file, name)
|
70
68
|
else:
|
wandb/sdk/internal/sample.py
CHANGED
wandb/sdk/internal/sender.py
CHANGED
@@ -1,6 +1,4 @@
|
|
1
|
-
"""
|
2
|
-
sender.
|
3
|
-
"""
|
1
|
+
"""sender."""
|
4
2
|
|
5
3
|
|
6
4
|
import json
|
@@ -29,10 +27,11 @@ import requests
|
|
29
27
|
|
30
28
|
import wandb
|
31
29
|
from wandb import util
|
32
|
-
from wandb.errors import
|
30
|
+
from wandb.errors import CommError
|
33
31
|
from wandb.filesync.dir_watcher import DirWatcher
|
34
32
|
from wandb.proto import wandb_internal_pb2
|
35
33
|
from wandb.sdk.lib import redirect
|
34
|
+
from wandb.sdk.lib.mailbox import ContextCancelledError
|
36
35
|
|
37
36
|
from ..interface import interface
|
38
37
|
from ..interface.interface_queue import InterfaceQueue
|
@@ -271,7 +270,8 @@ class SendManager:
|
|
271
270
|
|
272
271
|
@classmethod
|
273
272
|
def setup(cls, root_dir: str, resume: Union[None, bool, str]) -> "SendManager":
|
274
|
-
"""
|
273
|
+
"""Set up a standalone SendManager.
|
274
|
+
|
275
275
|
Currently, we're using this primarily for `sync.py`.
|
276
276
|
"""
|
277
277
|
files_dir = os.path.join(root_dir, "files")
|
@@ -710,8 +710,7 @@ class SendManager:
|
|
710
710
|
def _maybe_setup_resume(
|
711
711
|
self, run: "RunRecord"
|
712
712
|
) -> Optional["wandb_internal_pb2.ErrorInfo"]:
|
713
|
-
"""
|
714
|
-
incompatible."""
|
713
|
+
"""Queries the backend for a run; fail if the settings are incompatible."""
|
715
714
|
if not self._settings.resume:
|
716
715
|
return None
|
717
716
|
|
@@ -850,7 +849,7 @@ class SendManager:
|
|
850
849
|
config_util.save_config_file_from_dict(config_path, config_value_dict)
|
851
850
|
|
852
851
|
def _sync_spell(self) -> None:
|
853
|
-
"""
|
852
|
+
"""Sync this run with spell."""
|
854
853
|
if not self._run:
|
855
854
|
return
|
856
855
|
try:
|
@@ -926,7 +925,19 @@ class SendManager:
|
|
926
925
|
config_value_dict = self._config_format(None)
|
927
926
|
self._config_save(config_value_dict)
|
928
927
|
|
929
|
-
|
928
|
+
try:
|
929
|
+
self._init_run(run, config_value_dict)
|
930
|
+
except CommError as e:
|
931
|
+
logger.error(e, exc_info=True)
|
932
|
+
if record.control.req_resp or record.control.mailbox_slot:
|
933
|
+
result = proto_util._result_from_record(record)
|
934
|
+
result.run_result.run.CopyFrom(run)
|
935
|
+
error = wandb_internal_pb2.ErrorInfo()
|
936
|
+
error.message = str(e)
|
937
|
+
result.run_result.error.CopyFrom(error)
|
938
|
+
self._respond_result(result)
|
939
|
+
return
|
940
|
+
|
930
941
|
assert self._run # self._run is configured in _init_run()
|
931
942
|
|
932
943
|
if record.control.req_resp or record.control.mailbox_slot:
|
@@ -1360,9 +1371,9 @@ class SendManager:
|
|
1360
1371
|
logger.warning("Failed to link artifact to portfolio: %s", e)
|
1361
1372
|
|
1362
1373
|
def send_use_artifact(self, record: "Record") -> None:
|
1363
|
-
"""
|
1364
|
-
|
1365
|
-
internally
|
1374
|
+
"""Pretend to send a used artifact.
|
1375
|
+
|
1376
|
+
This function doesn't actually send anything, it is just used internally.
|
1366
1377
|
"""
|
1367
1378
|
use = record.use_artifact
|
1368
1379
|
if use.type == "job":
|
@@ -1530,10 +1541,11 @@ class SendManager:
|
|
1530
1541
|
return self._cached_server_info
|
1531
1542
|
|
1532
1543
|
def get_local_info(self) -> "LocalInfo":
|
1533
|
-
"""
|
1534
|
-
|
1535
|
-
First, we perform an introspection, if it returns empty we deduce that the
|
1536
|
-
out-of-date. Otherwise, we use the returned values to deduce the
|
1544
|
+
"""Queries the server to get the local version information.
|
1545
|
+
|
1546
|
+
First, we perform an introspection, if it returns empty we deduce that the
|
1547
|
+
docker image is out-of-date. Otherwise, we use the returned values to deduce the
|
1548
|
+
state of the local server.
|
1537
1549
|
"""
|
1538
1550
|
local_info = wandb_internal_pb2.LocalInfo()
|
1539
1551
|
if self._settings._offline:
|
wandb/sdk/internal/settings_static.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1
|
-
"""
|
2
|
-
|
3
|
-
"""
|
4
|
-
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union
|
1
|
+
"""static settings."""
|
2
|
+
from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
|
5
3
|
|
6
4
|
SettingsDict = Dict[str, Union[str, float, Tuple, None]]
|
7
5
|
|
@@ -20,6 +18,8 @@ class SettingsStatic:
|
|
20
18
|
_stats_samples_to_average: int
|
21
19
|
_stats_join_assets: bool
|
22
20
|
_stats_neuron_monitor_config_path: Optional[str]
|
21
|
+
_stats_open_metrics_endpoints: Mapping[str, str]
|
22
|
+
_stats_open_metrics_filters: Mapping[str, Mapping[str, str]]
|
23
23
|
files_dir: str
|
24
24
|
program_relpath: Optional[str]
|
25
25
|
log_internal: str
|
@@ -32,7 +32,7 @@ class SettingsStatic:
|
|
32
32
|
_jupyter_name: Optional[str]
|
33
33
|
_jupyter_root: Optional[str]
|
34
34
|
_network_buffer: Optional[int]
|
35
|
-
|
35
|
+
_disable_service: Optional[bool]
|
36
36
|
_live_policy_rate_limit: Optional[int]
|
37
37
|
resume: Optional[str]
|
38
38
|
program: Optional[str]
|
wandb/sdk/internal/system/assets/__init__.py
CHANGED
@@ -6,5 +6,6 @@ from .gpu_apple import GPUApple # noqa: F401
|
|
6
6
|
from .ipu import IPU # noqa: F401
|
7
7
|
from .memory import Memory # noqa: F401
|
8
8
|
from .network import Network # noqa: F401
|
9
|
+
from .open_metrics import OpenMetrics # noqa: F401
|
9
10
|
from .tpu import TPU # noqa: F401
|
10
11
|
from .trainium import Trainium # noqa: F401
|
wandb/sdk/internal/system/assets/cpu.py
CHANGED
@@ -20,9 +20,7 @@ if TYPE_CHECKING:
|
|
20
20
|
|
21
21
|
|
22
22
|
class ProcessCpuPercent:
|
23
|
-
"""
|
24
|
-
CPU usage of the process in percent normalized by the number of CPUs.
|
25
|
-
"""
|
23
|
+
"""CPU usage of the process in percent normalized by the number of CPUs."""
|
26
24
|
|
27
25
|
# name = "process_cpu_percent"
|
28
26
|
name = "cpu"
|
@@ -58,9 +56,7 @@ class ProcessCpuPercent:
|
|
58
56
|
|
59
57
|
|
60
58
|
class CpuPercent:
|
61
|
-
"""
|
62
|
-
CPU usage of the system in percent per core.
|
63
|
-
"""
|
59
|
+
"""CPU usage of the system in percent per core."""
|
64
60
|
|
65
61
|
name = "cpu.{i}.cpu_percent"
|
66
62
|
|
@@ -87,9 +83,7 @@ class CpuPercent:
|
|
87
83
|
|
88
84
|
|
89
85
|
class ProcessCpuThreads:
|
90
|
-
"""
|
91
|
-
Number of threads used by the process.
|
92
|
-
"""
|
86
|
+
"""Number of threads used by the process."""
|
93
87
|
|
94
88
|
name = "proc.cpu.threads"
|
95
89
|
|