wandb 0.21.1__py3-none-win_amd64.whl → 0.21.2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +1 -1
- wandb/apis/public/api.py +1 -2
- wandb/apis/public/artifacts.py +3 -5
- wandb/apis/public/registries/_utils.py +14 -16
- wandb/apis/public/registries/registries_search.py +176 -289
- wandb/apis/public/reports.py +13 -10
- wandb/automations/_generated/delete_automation.py +1 -3
- wandb/automations/_generated/enums.py +13 -11
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +47 -2
- wandb/integration/metaflow/data_pandas.py +2 -2
- wandb/integration/metaflow/data_pytorch.py +75 -0
- wandb/integration/metaflow/data_sklearn.py +76 -0
- wandb/integration/metaflow/metaflow.py +16 -87
- wandb/integration/weave/__init__.py +6 -0
- wandb/integration/weave/interface.py +49 -0
- wandb/integration/weave/weave.py +63 -0
- wandb/proto/v3/wandb_internal_pb2.py +3 -2
- wandb/proto/v4/wandb_internal_pb2.py +2 -2
- wandb/proto/v5/wandb_internal_pb2.py +2 -2
- wandb/proto/v6/wandb_internal_pb2.py +2 -2
- wandb/sdk/artifacts/_factories.py +17 -0
- wandb/sdk/artifacts/_generated/__init__.py +221 -13
- wandb/sdk/artifacts/_generated/artifact_by_id.py +17 -0
- wandb/sdk/artifacts/_generated/artifact_by_name.py +22 -0
- wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +43 -0
- wandb/sdk/artifacts/_generated/artifact_created_by.py +47 -0
- wandb/sdk/artifacts/_generated/artifact_file_urls.py +22 -0
- wandb/sdk/artifacts/_generated/artifact_type.py +31 -0
- wandb/sdk/artifacts/_generated/artifact_used_by.py +43 -0
- wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +26 -0
- wandb/sdk/artifacts/_generated/delete_artifact.py +28 -0
- wandb/sdk/artifacts/_generated/enums.py +5 -0
- wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +38 -0
- wandb/sdk/artifacts/_generated/fetch_registries.py +32 -0
- wandb/sdk/artifacts/_generated/fragments.py +279 -41
- wandb/sdk/artifacts/_generated/link_artifact.py +6 -0
- wandb/sdk/artifacts/_generated/operations.py +654 -51
- wandb/sdk/artifacts/_generated/registry_collections.py +34 -0
- wandb/sdk/artifacts/_generated/registry_versions.py +34 -0
- wandb/sdk/artifacts/_generated/unlink_artifact.py +25 -0
- wandb/sdk/artifacts/_graphql_fragments.py +3 -86
- wandb/sdk/artifacts/_validators.py +6 -4
- wandb/sdk/artifacts/artifact.py +406 -543
- wandb/sdk/artifacts/artifact_file_cache.py +10 -6
- wandb/sdk/artifacts/artifact_manifest.py +10 -9
- wandb/sdk/artifacts/artifact_manifest_entry.py +9 -10
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +5 -3
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -1
- wandb/sdk/data_types/video.py +2 -2
- wandb/sdk/interface/interface_queue.py +1 -4
- wandb/sdk/interface/interface_shared.py +26 -37
- wandb/sdk/interface/interface_sock.py +24 -14
- wandb/sdk/internal/settings_static.py +2 -3
- wandb/sdk/launch/create_job.py +12 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +24 -29
- wandb/sdk/lib/asyncio_compat.py +16 -16
- wandb/sdk/lib/asyncio_manager.py +252 -0
- wandb/sdk/lib/hashutil.py +13 -4
- wandb/sdk/lib/printer.py +2 -2
- wandb/sdk/lib/printer_asyncio.py +3 -1
- wandb/sdk/lib/retry.py +185 -78
- wandb/sdk/lib/service/service_client.py +106 -0
- wandb/sdk/lib/service/service_connection.py +20 -26
- wandb/sdk/lib/service/service_token.py +30 -13
- wandb/sdk/mailbox/mailbox.py +13 -5
- wandb/sdk/mailbox/mailbox_handle.py +22 -13
- wandb/sdk/mailbox/response_handle.py +42 -106
- wandb/sdk/mailbox/wait_with_progress.py +7 -42
- wandb/sdk/wandb_init.py +11 -25
- wandb/sdk/wandb_login.py +1 -1
- wandb/sdk/wandb_run.py +91 -55
- wandb/sdk/wandb_settings.py +45 -32
- wandb/sdk/wandb_setup.py +176 -96
- wandb/util.py +1 -1
- {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/METADATA +1 -1
- {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/RECORD +84 -68
- wandb/sdk/interface/interface_relay.py +0 -38
- wandb/sdk/interface/router.py +0 -89
- wandb/sdk/interface/router_queue.py +0 -43
- wandb/sdk/interface/router_relay.py +0 -50
- wandb/sdk/interface/router_sock.py +0 -32
- wandb/sdk/lib/sock_client.py +0 -232
- {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/WHEEL +0 -0
- {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/entry_points.txt +0 -0
- {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/licenses/LICENSE +0 -0
@@ -6,13 +6,10 @@ from __future__ import annotations
|
|
6
6
|
from enum import Enum
|
7
7
|
|
8
8
|
|
9
|
-
class
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
LINK_MODEL = "LINK_MODEL"
|
14
|
-
RUN_METRIC = "RUN_METRIC"
|
15
|
-
RUN_METRIC_CHANGE = "RUN_METRIC_CHANGE"
|
9
|
+
class AlertSeverity(str, Enum):
|
10
|
+
INFO = "INFO"
|
11
|
+
WARN = "WARN"
|
12
|
+
ERROR = "ERROR"
|
16
13
|
|
17
14
|
|
18
15
|
class TriggerScopeType(str, Enum):
|
@@ -20,10 +17,15 @@ class TriggerScopeType(str, Enum):
|
|
20
17
|
ARTIFACT_COLLECTION = "ARTIFACT_COLLECTION"
|
21
18
|
|
22
19
|
|
23
|
-
class
|
24
|
-
|
25
|
-
|
26
|
-
|
20
|
+
class EventTriggeringConditionType(str, Enum):
|
21
|
+
CREATE_ARTIFACT = "CREATE_ARTIFACT"
|
22
|
+
UPDATE_ARTIFACT_ALIAS = "UPDATE_ARTIFACT_ALIAS"
|
23
|
+
ADD_ARTIFACT_ALIAS = "ADD_ARTIFACT_ALIAS"
|
24
|
+
ADD_ARTIFACT_TAG = "ADD_ARTIFACT_TAG"
|
25
|
+
LINK_MODEL = "LINK_MODEL"
|
26
|
+
RUN_METRIC = "RUN_METRIC"
|
27
|
+
RUN_METRIC_CHANGE = "RUN_METRIC_CHANGE"
|
28
|
+
RUN_STATE = "RUN_STATE"
|
27
29
|
|
28
30
|
|
29
31
|
class TriggeredActionType(str, Enum):
|
wandb/bin/gpu_stats.exe
CHANGED
Binary file
|
wandb/bin/wandb-core
CHANGED
Binary file
|
wandb/cli/cli.py
CHANGED
@@ -1906,6 +1906,12 @@ def describe(job):
|
|
1906
1906
|
help="Service configurations in format serviceName=policy. Valid policies: always, never",
|
1907
1907
|
hidden=True,
|
1908
1908
|
)
|
1909
|
+
@click.option(
|
1910
|
+
"--schema",
|
1911
|
+
type=str,
|
1912
|
+
help="Path to the schema file for the job.",
|
1913
|
+
hidden=True,
|
1914
|
+
)
|
1909
1915
|
@click.argument("path")
|
1910
1916
|
def create(
|
1911
1917
|
path,
|
@@ -1922,6 +1928,7 @@ def create(
|
|
1922
1928
|
base_image,
|
1923
1929
|
dockerfile,
|
1924
1930
|
services,
|
1931
|
+
schema,
|
1925
1932
|
):
|
1926
1933
|
"""Create a job from a source, without a wandb run.
|
1927
1934
|
|
@@ -1956,6 +1963,12 @@ def create(
|
|
1956
1963
|
wandb.termerror("Cannot provide --base-image/-B for an `image` job")
|
1957
1964
|
return
|
1958
1965
|
|
1966
|
+
if schema:
|
1967
|
+
schema_dict = util.load_json_yaml_dict(schema)
|
1968
|
+
if schema_dict is None:
|
1969
|
+
wandb.termerror(f"Invalid format for schema file: {schema}")
|
1970
|
+
return
|
1971
|
+
|
1959
1972
|
artifact, action, aliases = _create_job(
|
1960
1973
|
api=api,
|
1961
1974
|
path=path,
|
@@ -1972,6 +1985,7 @@ def create(
|
|
1972
1985
|
base_image=base_image,
|
1973
1986
|
dockerfile=dockerfile,
|
1974
1987
|
services=services,
|
1988
|
+
schema=schema_dict,
|
1975
1989
|
)
|
1976
1990
|
if not artifact:
|
1977
1991
|
wandb.termerror("Job creation failed")
|
@@ -2539,7 +2553,8 @@ def pull(run, project, entity):
|
|
2539
2553
|
|
2540
2554
|
|
2541
2555
|
@cli.command(
|
2542
|
-
context_settings=CONTEXT,
|
2556
|
+
context_settings=CONTEXT,
|
2557
|
+
help="Restore code, config and docker state for a run. Retrieves code from latest commit if code was not saved with `wandb.save()` or `wandb.init(save_code=True)`.",
|
2543
2558
|
)
|
2544
2559
|
@click.pass_context
|
2545
2560
|
@click.argument("run", envvar=env.RUN_ID)
|
@@ -2779,7 +2794,37 @@ def enabled(service):
|
|
2779
2794
|
)
|
2780
2795
|
|
2781
2796
|
|
2782
|
-
@cli.command(
|
2797
|
+
@cli.command(
|
2798
|
+
context_settings=CONTEXT,
|
2799
|
+
help="""Checks and verifies local instance of W&B. W&B checks for:
|
2800
|
+
|
2801
|
+
Checks that the host is not `api.wandb.ai` (host check).
|
2802
|
+
|
2803
|
+
Verifies if the user is logged in correctly using the provided API key (login check).
|
2804
|
+
|
2805
|
+
Checks that requests are made over HTTPS (secure requests).
|
2806
|
+
|
2807
|
+
Validates the CORS (Cross-Origin Resource Sharing) configuration of the
|
2808
|
+
object store (CORS configuration).
|
2809
|
+
|
2810
|
+
Logs metrics, saves, and downloads files to check if runs are correctly
|
2811
|
+
recorded and accessible (run check).
|
2812
|
+
|
2813
|
+
Saves and downloads artifacts to verify that the artifact storage and
|
2814
|
+
retrieval system is working as expected (artifact check).
|
2815
|
+
|
2816
|
+
Tests the GraphQL endpoint by uploading a file to ensure it can handle
|
2817
|
+
signed URL uploads (GraphQL PUT check).
|
2818
|
+
|
2819
|
+
Checks the ability to send large payloads through the proxy (large payload check).
|
2820
|
+
|
2821
|
+
Verifies that the installed version of the W&B package is up-to-date and
|
2822
|
+
compatible with the server (W&B version check).
|
2823
|
+
|
2824
|
+
Creates and executes a sweep to ensure that sweep functionality is
|
2825
|
+
working correctly (sweeps check).
|
2826
|
+
""",
|
2827
|
+
)
|
2783
2828
|
@click.option("--host", default=None, help="Test a specific instance of W&B")
|
2784
2829
|
def verify(host):
|
2785
2830
|
# TODO: (kdg) Build this all into a WandbVerify object, and clean this up.
|
@@ -5,7 +5,7 @@ May raise MissingDependencyError on import.
|
|
5
5
|
|
6
6
|
from __future__ import annotations
|
7
7
|
|
8
|
-
from typing_extensions import Any,
|
8
|
+
from typing_extensions import Any, TypeIs
|
9
9
|
|
10
10
|
import wandb
|
11
11
|
|
@@ -21,7 +21,7 @@ except ImportError as e:
|
|
21
21
|
raise errors.MissingDependencyError(warning=warning) from e
|
22
22
|
|
23
23
|
|
24
|
-
def is_dataframe(data: Any) ->
|
24
|
+
def is_dataframe(data: Any) -> TypeIs[pd.DataFrame]:
|
25
25
|
"""Returns whether the data is a Pandas DataFrame."""
|
26
26
|
return isinstance(data, pd.DataFrame)
|
27
27
|
|
@@ -0,0 +1,75 @@
|
|
1
|
+
"""Support for PyTorch datatypes.
|
2
|
+
|
3
|
+
May raise MissingDependencyError on import.
|
4
|
+
"""
|
5
|
+
|
6
|
+
from __future__ import annotations
|
7
|
+
|
8
|
+
from typing_extensions import Any, TypeIs
|
9
|
+
|
10
|
+
import wandb
|
11
|
+
|
12
|
+
from . import errors
|
13
|
+
|
14
|
+
try:
|
15
|
+
import torch
|
16
|
+
import torch.nn as nn
|
17
|
+
except ImportError as e:
|
18
|
+
warning = (
|
19
|
+
"`torch` (PyTorch) not installed >>"
|
20
|
+
" @wandb_log(models=True) may not auto log your model!"
|
21
|
+
)
|
22
|
+
raise errors.MissingDependencyError(warning=warning) from e
|
23
|
+
|
24
|
+
|
25
|
+
def is_nn_module(data: Any) -> TypeIs[nn.Module]:
|
26
|
+
"""Returns whether the data is a PyTorch nn.Module."""
|
27
|
+
return isinstance(data, nn.Module)
|
28
|
+
|
29
|
+
|
30
|
+
def use_nn_module(
|
31
|
+
name: str,
|
32
|
+
run: wandb.Run | None,
|
33
|
+
testing: bool = False,
|
34
|
+
) -> str | None:
|
35
|
+
"""Log a dependency on a PyTorch model input.
|
36
|
+
|
37
|
+
Args:
|
38
|
+
name: Name of the input.
|
39
|
+
run: The run to update.
|
40
|
+
testing: True in unit tests.
|
41
|
+
"""
|
42
|
+
if testing:
|
43
|
+
return "models"
|
44
|
+
assert run
|
45
|
+
|
46
|
+
wandb.termlog(f"Using artifact: {name} (PyTorch nn.Module)")
|
47
|
+
run.use_artifact(f"{name}:latest")
|
48
|
+
return None
|
49
|
+
|
50
|
+
|
51
|
+
def track_nn_module(
|
52
|
+
name: str,
|
53
|
+
data: nn.Module,
|
54
|
+
run: wandb.Run | None,
|
55
|
+
testing: bool = False,
|
56
|
+
) -> str | None:
|
57
|
+
"""Log a PyTorch model output as an artifact.
|
58
|
+
|
59
|
+
Args:
|
60
|
+
name: The output's name.
|
61
|
+
data: The output's value.
|
62
|
+
run: The run to update.
|
63
|
+
testing: True in unit tests.
|
64
|
+
"""
|
65
|
+
if testing:
|
66
|
+
return "nn.Module"
|
67
|
+
assert run
|
68
|
+
|
69
|
+
artifact = wandb.Artifact(name, type="model")
|
70
|
+
with artifact.new_file(f"{name}.pkl", "wb") as f:
|
71
|
+
torch.save(data, f)
|
72
|
+
|
73
|
+
wandb.termlog(f"Logging artifact: {name} (PyTorch nn.Module)")
|
74
|
+
run.log_artifact(artifact)
|
75
|
+
return None
|
@@ -0,0 +1,76 @@
|
|
1
|
+
"""Support for sklearn datatypes.
|
2
|
+
|
3
|
+
May raise MissingDependencyError on import.
|
4
|
+
"""
|
5
|
+
|
6
|
+
from __future__ import annotations
|
7
|
+
|
8
|
+
import pickle
|
9
|
+
|
10
|
+
from typing_extensions import Any, TypeIs
|
11
|
+
|
12
|
+
import wandb
|
13
|
+
|
14
|
+
from . import errors
|
15
|
+
|
16
|
+
try:
|
17
|
+
from sklearn.base import BaseEstimator
|
18
|
+
except ImportError as e:
|
19
|
+
warning = (
|
20
|
+
"`sklearn` not installed >>"
|
21
|
+
" @wandb_log(models=True) may not auto log your model!"
|
22
|
+
)
|
23
|
+
raise errors.MissingDependencyError(warning=warning) from e
|
24
|
+
|
25
|
+
|
26
|
+
def is_estimator(data: Any) -> TypeIs[BaseEstimator]:
|
27
|
+
"""Returns whether the data is an sklearn BaseEstimator."""
|
28
|
+
return isinstance(data, BaseEstimator)
|
29
|
+
|
30
|
+
|
31
|
+
def use_estimator(
|
32
|
+
name: str,
|
33
|
+
run: wandb.Run | None,
|
34
|
+
testing: bool = False,
|
35
|
+
) -> str | None:
|
36
|
+
"""Log a dependency on an sklearn estimator.
|
37
|
+
|
38
|
+
Args:
|
39
|
+
name: Name of the input.
|
40
|
+
run: The run to update.
|
41
|
+
testing: True in unit tests.
|
42
|
+
"""
|
43
|
+
if testing:
|
44
|
+
return "models"
|
45
|
+
assert run
|
46
|
+
|
47
|
+
wandb.termlog(f"Using artifact: {name} (sklearn BaseEstimator)")
|
48
|
+
run.use_artifact(f"{name}:latest")
|
49
|
+
return None
|
50
|
+
|
51
|
+
|
52
|
+
def track_estimator(
|
53
|
+
name: str,
|
54
|
+
data: BaseEstimator,
|
55
|
+
run: wandb.Run | None,
|
56
|
+
testing: bool = False,
|
57
|
+
) -> str | None:
|
58
|
+
"""Log an sklearn estimator output as an artifact.
|
59
|
+
|
60
|
+
Args:
|
61
|
+
name: The output's name.
|
62
|
+
data: The output's value.
|
63
|
+
run: The run to update.
|
64
|
+
testing: True in unit tests.
|
65
|
+
"""
|
66
|
+
if testing:
|
67
|
+
return "BaseEstimator"
|
68
|
+
assert run
|
69
|
+
|
70
|
+
artifact = wandb.Artifact(name, type="model")
|
71
|
+
with artifact.new_file(f"{name}.pkl", "wb") as f:
|
72
|
+
pickle.dump(data, f)
|
73
|
+
|
74
|
+
wandb.termlog(f"Logging artifact: {name} (sklearn BaseEstimator)")
|
75
|
+
run.log_artifact(artifact)
|
76
|
+
return None
|
@@ -18,10 +18,6 @@ except ImportError as e:
|
|
18
18
|
) from e
|
19
19
|
|
20
20
|
|
21
|
-
# Classes for isinstance() checks.
|
22
|
-
_NN_MODULE = None
|
23
|
-
_BASE_ESTIMATOR = None
|
24
|
-
|
25
21
|
try:
|
26
22
|
from . import data_pandas
|
27
23
|
except errors.MissingDependencyError as e:
|
@@ -29,83 +25,16 @@ except errors.MissingDependencyError as e:
|
|
29
25
|
data_pandas = None
|
30
26
|
|
31
27
|
try:
|
32
|
-
import
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
def _use_torch_module(
|
38
|
-
name: str,
|
39
|
-
data: nn.Module,
|
40
|
-
run,
|
41
|
-
testing: bool = False,
|
42
|
-
) -> Optional[str]:
|
43
|
-
if testing:
|
44
|
-
return "models"
|
45
|
-
|
46
|
-
run.use_artifact(f"{name}:latest")
|
47
|
-
wandb.termlog(f"Using artifact: {name} ({type(data)})")
|
48
|
-
return None
|
49
|
-
|
50
|
-
def _track_torch_module(
|
51
|
-
name: str,
|
52
|
-
data: nn.Module,
|
53
|
-
run,
|
54
|
-
testing: bool = False,
|
55
|
-
) -> Optional[str]:
|
56
|
-
if testing:
|
57
|
-
return "nn.Module"
|
58
|
-
|
59
|
-
artifact = wandb.Artifact(name, type="model")
|
60
|
-
with artifact.new_file(f"{name}.pkl", "wb") as f:
|
61
|
-
torch.save(data, f)
|
62
|
-
run.log_artifact(artifact)
|
63
|
-
wandb.termlog(f"Logging artifact: {name} ({type(data)})")
|
64
|
-
return None
|
65
|
-
|
66
|
-
except ImportError:
|
67
|
-
wandb.termwarn(
|
68
|
-
"`pytorch` not installed >> @wandb_log(models=True) may not auto log your model!"
|
69
|
-
)
|
28
|
+
from . import data_pytorch
|
29
|
+
except errors.MissingDependencyError as e:
|
30
|
+
e.warn()
|
31
|
+
data_pytorch = None
|
70
32
|
|
71
33
|
try:
|
72
|
-
from
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
def _use_sklearn_estimator(
|
77
|
-
name: str,
|
78
|
-
data: BaseEstimator,
|
79
|
-
run,
|
80
|
-
testing: bool = False,
|
81
|
-
) -> Optional[str]:
|
82
|
-
if testing:
|
83
|
-
return "models"
|
84
|
-
|
85
|
-
run.use_artifact(f"{name}:latest")
|
86
|
-
wandb.termlog(f"Using artifact: {name} ({type(data)})")
|
87
|
-
return None
|
88
|
-
|
89
|
-
def _track_sklearn_estimator(
|
90
|
-
name: str,
|
91
|
-
data: BaseEstimator,
|
92
|
-
run,
|
93
|
-
testing: bool = False,
|
94
|
-
) -> Optional[str]:
|
95
|
-
if testing:
|
96
|
-
return "BaseEstimator"
|
97
|
-
|
98
|
-
artifact = wandb.Artifact(name, type="model")
|
99
|
-
with artifact.new_file(f"{name}.pkl", "wb") as f:
|
100
|
-
pickle.dump(data, f)
|
101
|
-
run.log_artifact(artifact)
|
102
|
-
wandb.termlog(f"Logging artifact: {name} ({type(data)})")
|
103
|
-
return None
|
104
|
-
|
105
|
-
except ImportError:
|
106
|
-
wandb.termwarn(
|
107
|
-
"`sklearn` not installed >> @wandb_log(models=True) may not auto log your model!"
|
108
|
-
)
|
34
|
+
from . import data_sklearn
|
35
|
+
except errors.MissingDependencyError as e:
|
36
|
+
e.warn()
|
37
|
+
data_sklearn = None
|
109
38
|
|
110
39
|
|
111
40
|
class ArtifactProxy:
|
@@ -195,12 +124,12 @@ def wandb_track(
|
|
195
124
|
return data_pandas.track_dataframe(name, data, run, testing)
|
196
125
|
|
197
126
|
# Check for PyTorch Module
|
198
|
-
if
|
199
|
-
return
|
127
|
+
if data_pytorch and data_pytorch.is_nn_module(data) and models:
|
128
|
+
return data_pytorch.track_nn_module(name, data, run, testing)
|
200
129
|
|
201
130
|
# Check for scikit-learn BaseEstimator
|
202
|
-
if
|
203
|
-
return
|
131
|
+
if data_sklearn and data_sklearn.is_estimator(data) and models:
|
132
|
+
return data_sklearn.track_estimator(name, data, run, testing)
|
204
133
|
|
205
134
|
# Check for Path objects
|
206
135
|
if isinstance(data, Path) and datasets:
|
@@ -238,12 +167,12 @@ def wandb_use(
|
|
238
167
|
return data_pandas.use_dataframe(name, run, testing)
|
239
168
|
|
240
169
|
# Check for PyTorch Module
|
241
|
-
elif
|
242
|
-
return
|
170
|
+
elif data_pytorch and data_pytorch.is_nn_module(data) and models:
|
171
|
+
return data_pytorch.use_nn_module(name, run, testing)
|
243
172
|
|
244
173
|
# Check for scikit-learn BaseEstimator
|
245
|
-
elif
|
246
|
-
return
|
174
|
+
elif data_sklearn and data_sklearn.is_estimator(data) and models:
|
175
|
+
return data_sklearn.use_estimator(name, run, testing)
|
247
176
|
|
248
177
|
# Check for Path objects
|
249
178
|
elif isinstance(data, Path) and datasets:
|
@@ -0,0 +1,49 @@
|
|
1
|
+
"""Internal APIs for integrating with weave.
|
2
|
+
|
3
|
+
The public functions here are intended to be called by weave and care should
|
4
|
+
be taken to maintain backward compatibility.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from __future__ import annotations
|
8
|
+
|
9
|
+
import dataclasses
|
10
|
+
|
11
|
+
from wandb.sdk import wandb_setup
|
12
|
+
|
13
|
+
|
14
|
+
@dataclasses.dataclass(frozen=True)
|
15
|
+
class RunPath:
|
16
|
+
entity: str
|
17
|
+
"""The entity to which the run is logging. Never empty."""
|
18
|
+
|
19
|
+
project: str
|
20
|
+
"""The project to which the run is logging. Never empty."""
|
21
|
+
|
22
|
+
run_id: str
|
23
|
+
"""The run's ID. Never empty."""
|
24
|
+
|
25
|
+
|
26
|
+
def active_run_path() -> RunPath | None:
|
27
|
+
"""Returns the path of an initialized, unfinished run.
|
28
|
+
|
29
|
+
Returns None if all initialized runs are finished. If there is
|
30
|
+
more than one active run, an arbitrary path is returned.
|
31
|
+
The run may be finished by the time its path is returned.
|
32
|
+
|
33
|
+
Thread-safe.
|
34
|
+
"""
|
35
|
+
singleton = wandb_setup.singleton()
|
36
|
+
|
37
|
+
if (
|
38
|
+
(run := singleton.most_recent_active_run)
|
39
|
+
and run.entity
|
40
|
+
and run.project
|
41
|
+
and run.id
|
42
|
+
):
|
43
|
+
return RunPath(
|
44
|
+
entity=run.entity,
|
45
|
+
project=run.project,
|
46
|
+
run_id=run.id,
|
47
|
+
)
|
48
|
+
|
49
|
+
return None
|
@@ -0,0 +1,63 @@
|
|
1
|
+
"""Integration module for automatic Weave initialization with W&B.
|
2
|
+
|
3
|
+
This module provides automatic initialization of Weave when:
|
4
|
+
1. Weave is installed
|
5
|
+
2. A W&B run is active with a project
|
6
|
+
3. Weave is imported (init-on-import)
|
7
|
+
|
8
|
+
The integration can be disabled by setting the WANDB_DISABLE_WEAVE environment variable.
|
9
|
+
"""
|
10
|
+
|
11
|
+
from __future__ import annotations
|
12
|
+
|
13
|
+
import os
|
14
|
+
import sys
|
15
|
+
import threading
|
16
|
+
|
17
|
+
import wandb
|
18
|
+
|
19
|
+
_weave_init_lock = threading.Lock()
|
20
|
+
|
21
|
+
_DISABLE_WEAVE = "WANDB_DISABLE_WEAVE"
|
22
|
+
_WEAVE_PACKAGE_NAME = "weave"
|
23
|
+
|
24
|
+
|
25
|
+
def setup(entity: str | None, project: str | None) -> None:
|
26
|
+
"""Set up automatic Weave initialization for the current W&B run.
|
27
|
+
|
28
|
+
Args:
|
29
|
+
project: The W&B project name to use for Weave initialization.
|
30
|
+
"""
|
31
|
+
# We can't or shouldn't init weave; return
|
32
|
+
if os.getenv(_DISABLE_WEAVE):
|
33
|
+
return
|
34
|
+
if not project:
|
35
|
+
return
|
36
|
+
|
37
|
+
# Use entity/project when available; otherwise fall back to project only
|
38
|
+
if entity:
|
39
|
+
project_path = f"{entity}/{project}"
|
40
|
+
else:
|
41
|
+
project_path = project
|
42
|
+
|
43
|
+
# If weave is not yet imported, we can't init it from here. Instead, we'll
|
44
|
+
# rely on the weave library itself to detect a run and init itself.
|
45
|
+
if _WEAVE_PACKAGE_NAME not in sys.modules:
|
46
|
+
return
|
47
|
+
|
48
|
+
# If weave has already been imported, initialize immediately
|
49
|
+
with _weave_init_lock:
|
50
|
+
try:
|
51
|
+
# This import should have already happened, so it's effectively a no-op.
|
52
|
+
# We just import to keep the symbol for the init that follows
|
53
|
+
import weave
|
54
|
+
except ImportError:
|
55
|
+
# This should never happen; but we don't raise here to avoid
|
56
|
+
# breaking the wandb run init flow just in case
|
57
|
+
return
|
58
|
+
|
59
|
+
wandb.termlog("Initializing weave.")
|
60
|
+
try:
|
61
|
+
weave.init(project_path)
|
62
|
+
except Exception as e:
|
63
|
+
wandb.termwarn(f"Failed to automatically initialize Weave: {e}")
|