wandb 0.18.2__py3-none-any.whl → 0.18.4__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +16 -7
- wandb/__init__.pyi +96 -63
- wandb/analytics/sentry.py +91 -88
- wandb/apis/public/api.py +18 -4
- wandb/apis/public/runs.py +53 -2
- wandb/bin/gpu_stats +0 -0
- wandb/cli/beta.py +178 -0
- wandb/cli/cli.py +5 -171
- wandb/data_types.py +3 -0
- wandb/env.py +74 -73
- wandb/errors/term.py +300 -43
- wandb/proto/v3/wandb_internal_pb2.py +271 -221
- wandb/proto/v3/wandb_server_pb2.py +57 -37
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_internal_pb2.py +226 -216
- wandb/proto/v4/wandb_server_pb2.py +41 -37
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_internal_pb2.py +226 -216
- wandb/proto/v5/wandb_server_pb2.py +41 -37
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/sdk/__init__.py +3 -3
- wandb/sdk/artifacts/_validators.py +41 -8
- wandb/sdk/artifacts/artifact.py +35 -4
- wandb/sdk/artifacts/artifact_file_cache.py +1 -2
- wandb/sdk/data_types/_dtypes.py +7 -3
- wandb/sdk/data_types/video.py +15 -6
- wandb/sdk/interface/interface.py +2 -0
- wandb/sdk/internal/internal_api.py +122 -5
- wandb/sdk/internal/sender.py +16 -3
- wandb/sdk/launch/inputs/internal.py +1 -1
- wandb/sdk/lib/module.py +12 -0
- wandb/sdk/lib/printer.py +291 -105
- wandb/sdk/lib/progress.py +274 -0
- wandb/sdk/service/streams.py +21 -11
- wandb/sdk/wandb_init.py +59 -54
- wandb/sdk/wandb_run.py +413 -480
- wandb/sdk/wandb_settings.py +2 -0
- wandb/sdk/wandb_watch.py +17 -11
- wandb/util.py +6 -2
- {wandb-0.18.2.dist-info → wandb-0.18.4.dist-info}/METADATA +5 -4
- {wandb-0.18.2.dist-info → wandb-0.18.4.dist-info}/RECORD +44 -42
- wandb/bin/nvidia_gpu_stats +0 -0
- {wandb-0.18.2.dist-info → wandb-0.18.4.dist-info}/WHEEL +0 -0
- {wandb-0.18.2.dist-info → wandb-0.18.4.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.2.dist-info → wandb-0.18.4.dist-info}/licenses/LICENSE +0 -0
wandb/__init__.py
CHANGED
@@ -8,9 +8,10 @@ For scripts and interactive notebooks, see https://github.com/wandb/examples.
|
|
8
8
|
|
9
9
|
For reference documentation, see https://docs.wandb.com/ref/python.
|
10
10
|
"""
|
11
|
-
|
11
|
+
from __future__ import annotations
|
12
|
+
|
13
|
+
__version__ = "0.18.4"
|
12
14
|
|
13
|
-
from typing import Optional
|
14
15
|
|
15
16
|
from wandb.errors import Error
|
16
17
|
|
@@ -28,8 +29,6 @@ setup = wandb_sdk.setup
|
|
28
29
|
_attach = wandb_sdk._attach
|
29
30
|
_sync = wandb_sdk._sync
|
30
31
|
_teardown = wandb_sdk.teardown
|
31
|
-
watch = wandb_sdk.watch
|
32
|
-
unwatch = wandb_sdk.unwatch
|
33
32
|
finish = wandb_sdk.finish
|
34
33
|
join = finish
|
35
34
|
login = wandb_sdk.login
|
@@ -112,10 +111,12 @@ def _assert_is_user_process():
|
|
112
111
|
# globals
|
113
112
|
Api = PublicApi
|
114
113
|
api = InternalApi()
|
115
|
-
run:
|
114
|
+
run: wandb_sdk.wandb_run.Run | None = None
|
116
115
|
config = _preinit.PreInitObject("wandb.config", wandb_sdk.wandb_config.Config)
|
117
116
|
summary = _preinit.PreInitObject("wandb.summary", wandb_sdk.wandb_summary.Summary)
|
118
117
|
log = _preinit.PreInitCallable("wandb.log", wandb_sdk.wandb_run.Run.log) # type: ignore
|
118
|
+
watch = _preinit.PreInitCallable("wandb.watch", wandb_sdk.wandb_run.Run.watch) # type: ignore
|
119
|
+
unwatch = _preinit.PreInitCallable("wandb.unwatch", wandb_sdk.wandb_run.Run.unwatch) # type: ignore
|
119
120
|
save = _preinit.PreInitCallable("wandb.save", wandb_sdk.wandb_run.Run.save) # type: ignore
|
120
121
|
restore = wandb_sdk.wandb_run.restore
|
121
122
|
use_artifact = _preinit.PreInitCallable(
|
@@ -200,9 +201,16 @@ if "dev" in __version__:
|
|
200
201
|
import wandb.env
|
201
202
|
import os
|
202
203
|
|
203
|
-
#
|
204
|
+
# Disable error reporting in dev versions.
|
204
205
|
os.environ[wandb.env.ERROR_REPORTING] = os.environ.get(
|
205
|
-
wandb.env.ERROR_REPORTING,
|
206
|
+
wandb.env.ERROR_REPORTING,
|
207
|
+
"false",
|
208
|
+
)
|
209
|
+
|
210
|
+
# Enable new features in dev versions.
|
211
|
+
os.environ["WANDB__SHOW_OPERATION_STATS"] = os.environ.get(
|
212
|
+
"WANDB__SHOW_OPERATION_STATS",
|
213
|
+
"true",
|
206
214
|
)
|
207
215
|
|
208
216
|
_sentry = _Sentry()
|
@@ -242,4 +250,5 @@ __all__ = (
|
|
242
250
|
"link_model",
|
243
251
|
"define_metric",
|
244
252
|
"watch",
|
253
|
+
"unwatch",
|
245
254
|
)
|
wandb/__init__.pyi
CHANGED
@@ -52,6 +52,7 @@ __all__ = (
|
|
52
52
|
"Settings",
|
53
53
|
"teardown",
|
54
54
|
"watch",
|
55
|
+
"unwatch",
|
55
56
|
)
|
56
57
|
|
57
58
|
import os
|
@@ -95,12 +96,14 @@ from wandb.wandb_controller import _WandbController
|
|
95
96
|
if TYPE_CHECKING:
|
96
97
|
import torch # type: ignore [import-not-found]
|
97
98
|
|
98
|
-
|
99
|
+
from wandb.plot.viz import CustomChart
|
100
|
+
|
101
|
+
__version__: str = "0.18.4"
|
99
102
|
|
100
103
|
run: Run | None
|
101
104
|
config: wandb_config.Config
|
102
105
|
summary: wandb_summary.Summary
|
103
|
-
Api: PublicApi
|
106
|
+
Api: type[PublicApi]
|
104
107
|
|
105
108
|
# private attributes
|
106
109
|
_sentry: Sentry
|
@@ -181,32 +184,32 @@ def teardown(exit_code: Optional[int] = None) -> None:
|
|
181
184
|
...
|
182
185
|
|
183
186
|
def init(
|
184
|
-
job_type:
|
185
|
-
dir:
|
186
|
-
config:
|
187
|
-
project:
|
188
|
-
entity:
|
189
|
-
reinit:
|
190
|
-
tags:
|
191
|
-
group:
|
192
|
-
name:
|
193
|
-
notes:
|
194
|
-
magic:
|
195
|
-
config_exclude_keys:
|
196
|
-
config_include_keys:
|
197
|
-
anonymous:
|
198
|
-
mode:
|
199
|
-
allow_val_change:
|
200
|
-
resume:
|
201
|
-
force:
|
202
|
-
tensorboard:
|
203
|
-
sync_tensorboard:
|
204
|
-
monitor_gym:
|
205
|
-
save_code:
|
206
|
-
id:
|
207
|
-
fork_from:
|
208
|
-
resume_from:
|
209
|
-
settings:
|
187
|
+
job_type: str | None = None,
|
188
|
+
dir: StrPath | None = None,
|
189
|
+
config: dict | str | None = None,
|
190
|
+
project: str | None = None,
|
191
|
+
entity: str | None = None,
|
192
|
+
reinit: bool | None = None,
|
193
|
+
tags: Sequence | None = None,
|
194
|
+
group: str | None = None,
|
195
|
+
name: str | None = None,
|
196
|
+
notes: str | None = None,
|
197
|
+
magic: dict | str | bool | None = None,
|
198
|
+
config_exclude_keys: list[str] | None = None,
|
199
|
+
config_include_keys: list[str] | None = None,
|
200
|
+
anonymous: str | None = None,
|
201
|
+
mode: str | None = None,
|
202
|
+
allow_val_change: bool | None = None,
|
203
|
+
resume: bool | str | None = None,
|
204
|
+
force: bool | None = None,
|
205
|
+
tensorboard: bool | None = None, # alias for sync_tensorboard
|
206
|
+
sync_tensorboard: bool | None = None,
|
207
|
+
monitor_gym: bool | None = None,
|
208
|
+
save_code: bool | None = None,
|
209
|
+
id: str | None = None,
|
210
|
+
fork_from: str | None = None,
|
211
|
+
resume_from: str | None = None,
|
212
|
+
settings: Settings | dict[str, Any] | None = None,
|
210
213
|
) -> Run:
|
211
214
|
r"""Start a new run to track and log to W&B.
|
212
215
|
|
@@ -420,7 +423,7 @@ def init(
|
|
420
423
|
"""
|
421
424
|
...
|
422
425
|
|
423
|
-
def finish(exit_code:
|
426
|
+
def finish(exit_code: int | None = None, quiet: bool | None = None) -> None:
|
424
427
|
"""Mark a run as finished, and finish uploading all data.
|
425
428
|
|
426
429
|
This is used when creating multiple runs in the same process.
|
@@ -470,10 +473,10 @@ def login(
|
|
470
473
|
...
|
471
474
|
|
472
475
|
def log(
|
473
|
-
data:
|
474
|
-
step:
|
475
|
-
commit:
|
476
|
-
sync:
|
476
|
+
data: dict[str, Any],
|
477
|
+
step: int | None = None,
|
478
|
+
commit: bool | None = None,
|
479
|
+
sync: bool | None = None,
|
477
480
|
) -> None:
|
478
481
|
"""Upload run data.
|
479
482
|
|
@@ -704,10 +707,10 @@ def log(
|
|
704
707
|
...
|
705
708
|
|
706
709
|
def save(
|
707
|
-
glob_str:
|
708
|
-
base_path:
|
710
|
+
glob_str: str | os.PathLike | None = None,
|
711
|
+
base_path: str | os.PathLike | None = None,
|
709
712
|
policy: PolicyName = "live",
|
710
|
-
) ->
|
713
|
+
) -> bool | list[str]:
|
711
714
|
"""Sync one or more files to W&B.
|
712
715
|
|
713
716
|
Relative paths are relative to the current working directory.
|
@@ -846,12 +849,12 @@ def agent(
|
|
846
849
|
|
847
850
|
def define_metric(
|
848
851
|
name: str,
|
849
|
-
step_metric:
|
850
|
-
step_sync:
|
851
|
-
hidden:
|
852
|
-
summary:
|
853
|
-
goal:
|
854
|
-
overwrite:
|
852
|
+
step_metric: str | wandb_metric.Metric | None = None,
|
853
|
+
step_sync: bool | None = None,
|
854
|
+
hidden: bool | None = None,
|
855
|
+
summary: str | None = None,
|
856
|
+
goal: str | None = None,
|
857
|
+
overwrite: bool | None = None,
|
855
858
|
) -> wandb_metric.Metric:
|
856
859
|
"""Customize metrics logged with `wandb.log()`.
|
857
860
|
|
@@ -882,11 +885,11 @@ def define_metric(
|
|
882
885
|
...
|
883
886
|
|
884
887
|
def log_artifact(
|
885
|
-
artifact_or_path:
|
886
|
-
name:
|
887
|
-
type:
|
888
|
-
aliases:
|
889
|
-
tags:
|
888
|
+
artifact_or_path: Artifact | StrPath,
|
889
|
+
name: str | None = None,
|
890
|
+
type: str | None = None,
|
891
|
+
aliases: list[str] | None = None,
|
892
|
+
tags: list[str] | None = None,
|
890
893
|
) -> Artifact:
|
891
894
|
"""Declare an artifact as an output of a run.
|
892
895
|
|
@@ -915,10 +918,10 @@ def log_artifact(
|
|
915
918
|
...
|
916
919
|
|
917
920
|
def use_artifact(
|
918
|
-
artifact_or_name:
|
919
|
-
type:
|
920
|
-
aliases:
|
921
|
-
use_as:
|
921
|
+
artifact_or_name: str | Artifact,
|
922
|
+
type: str | None = None,
|
923
|
+
aliases: list[str] | None = None,
|
924
|
+
use_as: str | None = None,
|
922
925
|
) -> Artifact:
|
923
926
|
"""Declare an artifact as an input to a run.
|
924
927
|
|
@@ -926,8 +929,9 @@ def use_artifact(
|
|
926
929
|
|
927
930
|
Arguments:
|
928
931
|
artifact_or_name: (str or Artifact) An artifact name.
|
929
|
-
May be prefixed with entity/project/.
|
930
|
-
|
932
|
+
May be prefixed with project/ or entity/project/.
|
933
|
+
If no entity is specified in the name, the Run or API setting's entity is used.
|
934
|
+
Valid names can be in the following forms:
|
931
935
|
- name:version
|
932
936
|
- name:alias
|
933
937
|
You can also pass an Artifact object created by calling `wandb.Artifact`
|
@@ -943,8 +947,8 @@ def use_artifact(
|
|
943
947
|
|
944
948
|
def log_model(
|
945
949
|
path: StrPath,
|
946
|
-
name:
|
947
|
-
aliases:
|
950
|
+
name: str | None = None,
|
951
|
+
aliases: list[str] | None = None,
|
948
952
|
) -> None:
|
949
953
|
"""Logs a model artifact containing the contents inside the 'path' to a run and marks it as an output to this run.
|
950
954
|
|
@@ -1031,8 +1035,8 @@ def use_model(name: str) -> FilePathStr:
|
|
1031
1035
|
def link_model(
|
1032
1036
|
path: StrPath,
|
1033
1037
|
registered_model_name: str,
|
1034
|
-
name:
|
1035
|
-
aliases:
|
1038
|
+
name: str | None = None,
|
1039
|
+
aliases: list[str] | None = None,
|
1036
1040
|
) -> None:
|
1037
1041
|
"""Log a model artifact version and link it to a registered model in the model registry.
|
1038
1042
|
|
@@ -1099,6 +1103,28 @@ def link_model(
|
|
1099
1103
|
"""
|
1100
1104
|
...
|
1101
1105
|
|
1106
|
+
def plot_table(
|
1107
|
+
vega_spec_name: str,
|
1108
|
+
data_table: Table,
|
1109
|
+
fields: dict[str, Any],
|
1110
|
+
string_fields: dict[str, Any] | None = None,
|
1111
|
+
split_table: bool | None = False,
|
1112
|
+
) -> CustomChart:
|
1113
|
+
"""Create a custom plot on a table.
|
1114
|
+
|
1115
|
+
Arguments:
|
1116
|
+
vega_spec_name: the name of the spec for the plot
|
1117
|
+
data_table: a wandb.Table object containing the data to
|
1118
|
+
be used on the visualization
|
1119
|
+
fields: a dict mapping from table keys to fields that the custom
|
1120
|
+
visualization needs
|
1121
|
+
string_fields: a dict that provides values for any string constants
|
1122
|
+
the custom visualization needs
|
1123
|
+
split_table: a boolean that indicates whether the table should be in
|
1124
|
+
a separate section in the UI
|
1125
|
+
"""
|
1126
|
+
...
|
1127
|
+
|
1102
1128
|
def watch(
|
1103
1129
|
models: torch.nn.Module | Sequence[torch.nn.Module],
|
1104
1130
|
criterion: torch.F | None = None,
|
@@ -1106,7 +1132,7 @@ def watch(
|
|
1106
1132
|
log_freq: int = 1000,
|
1107
1133
|
idx: int | None = None,
|
1108
1134
|
log_graph: bool = False,
|
1109
|
-
) ->
|
1135
|
+
) -> None:
|
1110
1136
|
"""Hooks into the given PyTorch model(s) to monitor gradients and the model's computational graph.
|
1111
1137
|
|
1112
1138
|
This function can track parameters, gradients, or both during training. It should be
|
@@ -1124,16 +1150,23 @@ def watch(
|
|
1124
1150
|
Frequency (in batches) to log gradients and parameters. (default=1000)
|
1125
1151
|
idx (Optional[int]):
|
1126
1152
|
Index used when tracking multiple models with `wandb.watch`. (default=None)
|
1127
|
-
|
1153
|
+
log_graph (bool):
|
1128
1154
|
Whether to log the model's computational graph. (default=False)
|
1129
1155
|
|
1130
|
-
Returns:
|
1131
|
-
wandb.Graph:
|
1132
|
-
The graph object, which will be populated after the first backward pass.
|
1133
|
-
|
1134
1156
|
Raises:
|
1135
1157
|
ValueError:
|
1136
1158
|
If `wandb.init` has not been called or if any of the models are not instances
|
1137
1159
|
of `torch.nn.Module`.
|
1138
1160
|
"""
|
1139
1161
|
...
|
1162
|
+
|
1163
|
+
def unwatch(
|
1164
|
+
models: torch.nn.Module | Sequence[torch.nn.Module] | None = None,
|
1165
|
+
) -> None:
|
1166
|
+
"""Remove pytorch model topology, gradient and parameter hooks.
|
1167
|
+
|
1168
|
+
Args:
|
1169
|
+
models (torch.nn.Module | Sequence[torch.nn.Module]):
|
1170
|
+
Optional list of pytorch models that have had watch called on them
|
1171
|
+
"""
|
1172
|
+
...
|
wandb/analytics/sentry.py
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
__all__ = ("Sentry",)
|
2
4
|
|
3
5
|
|
@@ -7,7 +9,7 @@ import os
|
|
7
9
|
import pathlib
|
8
10
|
import sys
|
9
11
|
from types import TracebackType
|
10
|
-
from typing import TYPE_CHECKING, Any, Callable
|
12
|
+
from typing import TYPE_CHECKING, Any, Callable
|
11
13
|
from urllib.parse import quote
|
12
14
|
|
13
15
|
if sys.version_info >= (3, 8):
|
@@ -16,6 +18,7 @@ else:
|
|
16
18
|
from typing_extensions import Literal
|
17
19
|
|
18
20
|
import sentry_sdk # type: ignore
|
21
|
+
import sentry_sdk.scope # type: ignore
|
19
22
|
import sentry_sdk.utils # type: ignore
|
20
23
|
|
21
24
|
import wandb
|
@@ -36,7 +39,7 @@ def _safe_noop(func: Callable) -> Callable:
|
|
36
39
|
"""Decorator to ensure that Sentry methods do nothing if disabled and don't raise."""
|
37
40
|
|
38
41
|
@functools.wraps(func)
|
39
|
-
def wrapper(self:
|
42
|
+
def wrapper(self: type[Sentry], *args: Any, **kwargs: Any) -> Any:
|
40
43
|
if self._disabled:
|
41
44
|
return None
|
42
45
|
try:
|
@@ -59,7 +62,7 @@ class Sentry:
|
|
59
62
|
|
60
63
|
self.dsn = os.environ.get(wandb.env.SENTRY_DSN, SENTRY_DEFAULT_DSN)
|
61
64
|
|
62
|
-
self.
|
65
|
+
self.scope: sentry_sdk.scope.Scope | None = None
|
63
66
|
|
64
67
|
# ensure we always end the Sentry session
|
65
68
|
atexit.register(self.end_session)
|
@@ -87,47 +90,50 @@ class Sentry:
|
|
87
90
|
environment=self.environment,
|
88
91
|
release=wandb.__version__,
|
89
92
|
)
|
90
|
-
self.
|
93
|
+
self.scope = sentry_sdk.get_global_scope().fork()
|
94
|
+
self.scope.clear()
|
95
|
+
self.scope.set_client(client)
|
91
96
|
|
92
97
|
@_safe_noop
|
93
|
-
def message(self, message: str, repeat: bool = True) -> None:
|
98
|
+
def message(self, message: str, repeat: bool = True) -> str | None:
|
94
99
|
"""Send a message to Sentry."""
|
95
100
|
if not repeat and message in self._sent_messages:
|
96
|
-
return
|
101
|
+
return None
|
97
102
|
self._sent_messages.add(message)
|
98
|
-
|
103
|
+
with sentry_sdk.scope.use_isolation_scope(self.scope): # type: ignore
|
104
|
+
return sentry_sdk.capture_message(message) # type: ignore
|
99
105
|
|
100
106
|
@_safe_noop
|
101
107
|
def exception(
|
102
108
|
self,
|
103
|
-
exc:
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
None,
|
112
|
-
],
|
109
|
+
exc: str
|
110
|
+
| BaseException
|
111
|
+
| tuple[
|
112
|
+
type[BaseException] | None,
|
113
|
+
BaseException | None,
|
114
|
+
TracebackType | None,
|
115
|
+
]
|
116
|
+
| None,
|
113
117
|
handled: bool = False,
|
114
|
-
status:
|
115
|
-
) -> None:
|
118
|
+
status: SessionStatus | None = None,
|
119
|
+
) -> str | None:
|
116
120
|
"""Log an exception to Sentry."""
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
exc_info = sentry_sdk.utils.exc_info_from_error(
|
121
|
+
if isinstance(exc, str):
|
122
|
+
exc_info = sentry_sdk.utils.exc_info_from_error(Exception(exc))
|
123
|
+
elif isinstance(exc, BaseException):
|
124
|
+
exc_info = sentry_sdk.utils.exc_info_from_error(exc)
|
121
125
|
else:
|
122
126
|
exc_info = sys.exc_info()
|
123
127
|
|
124
|
-
event,
|
128
|
+
event, _ = sentry_sdk.utils.event_from_exception(
|
125
129
|
exc_info,
|
126
|
-
client_options=self.
|
130
|
+
client_options=self.scope.get_client().options, # type: ignore
|
127
131
|
mechanism={"type": "generic", "handled": handled},
|
128
132
|
)
|
133
|
+
event_id = None
|
129
134
|
try:
|
130
|
-
|
135
|
+
with sentry_sdk.scope.use_isolation_scope(self.scope): # type: ignore
|
136
|
+
event_id = sentry_sdk.capture_event(event) # type: ignore
|
131
137
|
except Exception:
|
132
138
|
pass
|
133
139
|
|
@@ -136,11 +142,11 @@ class Sentry:
|
|
136
142
|
status = status or ("crashed" if not handled else "errored") # type: ignore
|
137
143
|
self.mark_session(status=status)
|
138
144
|
|
139
|
-
client
|
145
|
+
client = self.scope.get_client() # type: ignore
|
140
146
|
if client is not None:
|
141
147
|
client.flush()
|
142
148
|
|
143
|
-
return
|
149
|
+
return event_id
|
144
150
|
|
145
151
|
def reraise(self, exc: Any) -> None:
|
146
152
|
"""Re-raise an exception after logging it to Sentry.
|
@@ -157,33 +163,31 @@ class Sentry:
|
|
157
163
|
@_safe_noop
|
158
164
|
def start_session(self) -> None:
|
159
165
|
"""Start a new session."""
|
160
|
-
assert self.
|
166
|
+
assert self.scope is not None
|
161
167
|
# get the current client and scope
|
162
|
-
|
163
|
-
session = scope._session
|
168
|
+
session = self.scope._session
|
164
169
|
|
165
170
|
# if there's no session, start one
|
166
171
|
if session is None:
|
167
|
-
self.
|
172
|
+
self.scope.start_session()
|
168
173
|
|
169
174
|
@_safe_noop
|
170
175
|
def end_session(self) -> None:
|
171
176
|
"""End the current session."""
|
172
|
-
assert self.
|
177
|
+
assert self.scope is not None
|
173
178
|
# get the current client and scope
|
174
|
-
client
|
175
|
-
session = scope._session
|
179
|
+
client = self.scope.get_client()
|
180
|
+
session = self.scope._session
|
176
181
|
|
177
182
|
if session is not None and client is not None:
|
178
|
-
self.
|
183
|
+
self.scope.end_session()
|
179
184
|
client.flush()
|
180
185
|
|
181
186
|
@_safe_noop
|
182
|
-
def mark_session(self, status:
|
187
|
+
def mark_session(self, status: SessionStatus | None = None) -> None:
|
183
188
|
"""Mark the current session with a status."""
|
184
|
-
assert self.
|
185
|
-
|
186
|
-
session = scope._session
|
189
|
+
assert self.scope is not None
|
190
|
+
session = self.scope._session
|
187
191
|
|
188
192
|
if session is not None:
|
189
193
|
session.update(status=status)
|
@@ -191,8 +195,8 @@ class Sentry:
|
|
191
195
|
@_safe_noop
|
192
196
|
def configure_scope(
|
193
197
|
self,
|
194
|
-
tags:
|
195
|
-
process_context:
|
198
|
+
tags: dict[str, Any] | None = None,
|
199
|
+
process_context: str | None = None,
|
196
200
|
) -> None:
|
197
201
|
"""Configure the Sentry scope for the current thread.
|
198
202
|
|
@@ -201,7 +205,7 @@ class Sentry:
|
|
201
205
|
all events sent from this thread. It also tries to start a session
|
202
206
|
if one doesn't already exist for this thread.
|
203
207
|
"""
|
204
|
-
assert self.
|
208
|
+
assert self.scope is not None
|
205
209
|
settings_tags = (
|
206
210
|
"entity",
|
207
211
|
"project",
|
@@ -215,51 +219,50 @@ class Sentry:
|
|
215
219
|
"launch",
|
216
220
|
)
|
217
221
|
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
scope.user = {"email": email} # noqa
|
222
|
+
self.scope.set_tag("platform", wandb.util.get_platform_name())
|
223
|
+
|
224
|
+
# set context
|
225
|
+
if process_context:
|
226
|
+
self.scope.set_tag("process_context", process_context)
|
227
|
+
|
228
|
+
# apply settings tags
|
229
|
+
if tags is None:
|
230
|
+
return None
|
231
|
+
|
232
|
+
for tag in settings_tags:
|
233
|
+
val = tags.get(tag, None)
|
234
|
+
if val not in (None, ""):
|
235
|
+
self.scope.set_tag(tag, val)
|
236
|
+
|
237
|
+
if tags.get("_colab", None):
|
238
|
+
python_runtime = "colab"
|
239
|
+
elif tags.get("_jupyter", None):
|
240
|
+
python_runtime = "jupyter"
|
241
|
+
elif tags.get("_ipython", None):
|
242
|
+
python_runtime = "ipython"
|
243
|
+
else:
|
244
|
+
python_runtime = "python"
|
245
|
+
self.scope.set_tag("python_runtime", python_runtime)
|
246
|
+
|
247
|
+
# Construct run_url and sweep_url given run_id and sweep_id
|
248
|
+
for obj in ("run", "sweep"):
|
249
|
+
obj_id, obj_url = f"{obj}_id", f"{obj}_url"
|
250
|
+
if tags.get(obj_url, None):
|
251
|
+
continue
|
252
|
+
|
253
|
+
try:
|
254
|
+
app_url = wandb.util.app_url(tags["base_url"]) # type: ignore
|
255
|
+
entity, project = (quote(tags[k]) for k in ("entity", "project")) # type: ignore
|
256
|
+
self.scope.set_tag(
|
257
|
+
obj_url,
|
258
|
+
f"{app_url}/{entity}/{project}/{obj}s/{tags[obj_id]}",
|
259
|
+
)
|
260
|
+
except Exception:
|
261
|
+
pass
|
262
|
+
|
263
|
+
email = tags.get("email")
|
264
|
+
if email:
|
265
|
+
self.scope.user = {"email": email} # noqa
|
263
266
|
|
264
267
|
# todo: add back the option to pass general tags see: c645f625d1c1a3db4a6b0e2aa8e924fee101904c (wandb/util.py)
|
265
268
|
|
wandb/apis/public/api.py
CHANGED
@@ -27,6 +27,7 @@ from wandb.apis import public
|
|
27
27
|
from wandb.apis.internal import Api as InternalApi
|
28
28
|
from wandb.apis.normalize import normalize_exceptions
|
29
29
|
from wandb.apis.public.const import RETRY_TIMEDELTA
|
30
|
+
from wandb.sdk.artifacts._validators import is_artifact_registry_project
|
30
31
|
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
|
31
32
|
from wandb.sdk.launch.utils import LAUNCH_DEFAULT_PROJECT
|
32
33
|
from wandb.sdk.lib import retry, runid
|
@@ -1142,11 +1143,12 @@ class Api:
|
|
1142
1143
|
|
1143
1144
|
@normalize_exceptions
|
1144
1145
|
def artifact(self, name, type=None):
|
1145
|
-
"""Return a single artifact by parsing path in the form `entity/project/name`.
|
1146
|
+
"""Return a single artifact by parsing path in the form `project/name` or `entity/project/name`.
|
1146
1147
|
|
1147
1148
|
Arguments:
|
1148
|
-
name: (str) An artifact name. May be prefixed with entity/project
|
1149
|
-
|
1149
|
+
name: (str) An artifact name. May be prefixed with project/ or entity/project/.
|
1150
|
+
If no entity is specified in the name, the Run or API setting's entity is used.
|
1151
|
+
Valid names can be in the following forms:
|
1150
1152
|
name:version
|
1151
1153
|
name:alias
|
1152
1154
|
type: (str, optional) The type of artifact to fetch.
|
@@ -1157,8 +1159,20 @@ class Api:
|
|
1157
1159
|
if name is None:
|
1158
1160
|
raise ValueError("You must specify name= to fetch an artifact.")
|
1159
1161
|
entity, project, artifact_name = self._parse_artifact_path(name)
|
1162
|
+
|
1163
|
+
organization = ""
|
1164
|
+
# If its an Registry artifact, the entity is an org instead
|
1165
|
+
if is_artifact_registry_project(project):
|
1166
|
+
# Update `organization` only if an organization name was provided,
|
1167
|
+
# otherwise use the default that you already set above.
|
1168
|
+
try:
|
1169
|
+
organization, _, _ = name.split("/")
|
1170
|
+
except ValueError:
|
1171
|
+
organization = ""
|
1172
|
+
# set entity to match the settings since in above code it was potentially set to an org
|
1173
|
+
entity = self.settings["entity"] or self.default_entity
|
1160
1174
|
artifact = wandb.Artifact._from_name(
|
1161
|
-
entity, project, artifact_name, self.client
|
1175
|
+
entity, project, artifact_name, self.client, organization
|
1162
1176
|
)
|
1163
1177
|
if type is not None and artifact.type != type:
|
1164
1178
|
raise ValueError(
|