wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,215 @@
|
|
1
|
+
import functools
|
2
|
+
import logging
|
3
|
+
import sys
|
4
|
+
from typing import Any, Dict, Optional, Sequence, TypeVar
|
5
|
+
|
6
|
+
import wandb.sdk
|
7
|
+
import wandb.util
|
8
|
+
from wandb.sdk.lib import telemetry as wb_telemetry
|
9
|
+
from wandb.sdk.lib.timer import Timer
|
10
|
+
|
11
|
+
if sys.version_info >= (3, 8):
|
12
|
+
from typing import Protocol
|
13
|
+
else:
|
14
|
+
from typing_extensions import Protocol
|
15
|
+
|
16
|
+
|
17
|
+
# Type variables for the generic ``Response`` protocol: keys are bound to
# ``str``; values are unconstrained.
K = TypeVar("K", bound=str)
V = TypeVar("V")

# Optional keyword arguments that autolog forwards to ``wandb.init()``.
AutologInitArgs = Optional[Dict[str, Any]]

# Module-level logger used for autologging status messages and warnings.
logger = logging.getLogger(__name__)
|
25
|
+
|
26
|
+
|
27
|
+
class Response(Protocol[K, V]):
    """Structural type for a mapping-like API response.

    Any object that supports ``response[key]`` and
    ``response.get(key, default)`` matches this protocol structurally
    (a plain ``dict`` does, for example).
    """

    def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
        """Return the value for ``key``, or ``default`` when it is missing."""

    def __getitem__(self, key: K) -> V:
        """Return the value stored under ``key``."""
|
33
|
+
|
34
|
+
|
35
|
+
class ArgumentResponseResolver(Protocol):
    """Callable that turns one wrapped API call into loggable data.

    Args:
        args: Positional arguments the wrapped API was called with.
        kwargs: Keyword arguments the wrapped API was called with.
        response: The value returned by the wrapped API.
        start_time: Start timestamp of the call, as measured by ``Timer``.
        time_elapsed: Duration of the call, as measured by ``Timer``.

    Returns:
        A dictionary of wandb media objects or metrics to log, or ``None``
        when there is nothing to log for this call.
    """

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> Optional[Dict[str, Any]]:
        """Resolve a single API call; see the class docstring for details."""
|
45
|
+
|
46
|
+
|
47
|
+
class PatchAPI:
    """Monkey-patches callables of a third-party API module to log to W&B.

    Each entry of ``symbols`` is a dotted attribute path relative to the API
    module (imported under ``name.lower()``). ``patch`` replaces each one with
    a wrapper that calls the original, hands the call to ``resolver`` and logs
    the resolver's result to a run; ``unpatch`` restores the originals.

    NOTE(review): originals are captured with ``getattr``, so a
    ``staticmethod``/``classmethod`` is wrapped (and later restored) as the
    plain function / bound method rather than the original descriptor —
    confirm this is acceptable for the APIs being patched.
    """

    def __init__(
        self,
        name: str,
        symbols: Sequence[str],
        resolver: ArgumentResponseResolver,
    ) -> None:
        """Patches the API to log wandb Media or metrics.

        Args:
            name: Name of the LLM provider or package, e.g. "Cohere", "OpenAI"
                or "Transformers"; its lower-cased form is the module imported.
            symbols: Attribute paths to patch, e.g. ["Client.generate",
                "Edit.create"] or ["Pipeline.__call__"].
            resolver: Converts a call's args/kwargs/response and timing into a
                dictionary of wandb media objects or metrics.
        """
        # name of the LLM provider, e.g. "Cohere" or "OpenAI" or package name like "Transformers"
        self.name = name
        # lazily imported api module, e.g. "cohere" or "openai" or "transformers"
        self._api = None
        # dictionary of original (unpatched) attributes, keyed by symbol
        self.original_methods: Dict[str, Any] = {}
        # list of symbols to patch, e.g. ["Client.generate", "Edit.create"] or ["Pipeline.__call__"]
        self.symbols = symbols
        # resolver callable to convert args/response into a dictionary of wandb media objects or metrics
        self.resolver = resolver

    @property
    def set_api(self) -> Any:
        """Returns the API module, importing it on first access.

        Despite its name this is a read-only property: the module is looked up
        with ``wandb.util.get_module`` once and cached on ``self._api``.
        """
        lib_name = self.name.lower()
        if self._api is None:
            self._api = wandb.util.get_module(
                name=lib_name,
                required=f"To use the W&B {self.name} Autolog, "
                f"you need to have the `{lib_name}` python "
                f"package installed. Please install it with `pip install {lib_name}`.",
                lazy=False,
            )
        return self._api

    def _resolve_owner(self, symbol_parts: Sequence[str]) -> Any:
        """Return the object that holds the last attribute of a split symbol.

        ``["Client", "generate"]`` resolves to ``<api module>.Client``; a
        single-part symbol resolves to the API module itself (``reduce`` over
        an empty prefix returns the initial value).
        """
        return functools.reduce(getattr, symbol_parts[:-1], self.set_api)

    def _wrap(self, original_method: Any, run: "wandb.sdk.wandb_run.Run") -> Any:
        """Return a wrapper around ``original_method`` that logs each call to ``run``.

        The original's return value (or exception) is passed through unchanged;
        exceptions from the resolver or from ``run.log`` are only logged as
        warnings.
        """

        @functools.wraps(original_method)
        def method(*args, **kwargs):
            with Timer() as timer:
                result = original_method(*args, **kwargs)
            try:
                loggable_dict = self.resolver(
                    args, kwargs, result, timer.start_time, timer.elapsed
                )
                if loggable_dict is not None:
                    run.log(loggable_dict)
            except Exception as e:
                # autologging must never break the user's API call
                logger.warning(e)
            return result

        return method

    def patch(self, run: "wandb.sdk.wandb_run.Run") -> None:
        """Patches the API to log media or metrics to W&B.

        Args:
            run: Run that every wrapped call logs to.

        Calling ``patch`` again before ``unpatch`` re-wraps the saved
        originals, so wrappers never stack and ``unpatch`` still restores the
        true originals.
        """
        for symbol in self.symbols:
            # split on dots, e.g. "Client.generate" -> ["Client", "generate"]
            symbol_parts = symbol.split(".")
            owner = self._resolve_owner(symbol_parts)
            attr_name = symbol_parts[-1]
            # keep the first original we saw: if the symbol is already patched,
            # re-reading it would capture our own wrapper instead
            if symbol not in self.original_methods:
                self.original_methods[symbol] = getattr(owner, attr_name)
            setattr(owner, attr_name, self._wrap(self.original_methods[symbol], run))

    def unpatch(self) -> None:
        """Unpatches the API, restoring every attribute saved by ``patch``."""
        for symbol, original in self.original_methods.items():
            # split on dots, e.g. "Client.generate" -> ["Client", "generate"]
            symbol_parts = symbol.split(".")
            setattr(self._resolve_owner(symbol_parts), symbol_parts[-1], original)
        # forget the restored originals so a later patch captures fresh ones
        self.original_methods.clear()
|
131
|
+
|
132
|
+
|
133
|
+
class AutologAPI:
    """User-facing switch that turns autologging of one API on and off.

    Calling the instance (or ``enable``) attaches to or creates a wandb run and
    patches the API through ``PatchAPI``; ``disable`` restores the API and
    finishes the run only if autolog created it.
    """

    def __init__(
        self,
        name: str,
        symbols: Sequence[str],
        resolver: ArgumentResponseResolver,
        telemetry_feature: Optional[str] = None,
    ) -> None:
        """Autolog API calls to W&B.

        Args:
            name: Provider/package name, forwarded to ``PatchAPI``.
            symbols: Dotted attribute paths to patch, forwarded to ``PatchAPI``.
            resolver: Converts each call into loggable data, forwarded to
                ``PatchAPI``.
            telemetry_feature: Optional name of a telemetry feature flag that is
                set to ``True`` on the run whenever autologging is enabled.
        """
        self._telemetry_feature = telemetry_feature
        self._patch_api = PatchAPI(
            name=name,
            symbols=symbols,
            resolver=resolver,
        )
        self._name = self._patch_api.name
        self._run: Optional["wandb.sdk.wandb_run.Run"] = None
        self.__run_created_by_autolog: bool = False

    @property
    def _is_enabled(self) -> bool:
        """Returns whether autologging is enabled."""
        return self._run is not None

    def __call__(self, init: AutologInitArgs = None) -> None:
        """Enable autologging."""
        self.enable(init=init)

    def _run_init(self, init: AutologInitArgs = None) -> None:
        """Handle wandb run initialization.

        - ``autolog(init={...})`` calls ``wandb.init(**{...})`` regardless of
          whether there is a ``wandb.run``; the run counts as created by autolog
          only if it is not the ``wandb.run`` that existed beforehand.
        - ``autolog()`` uses ``wandb.run`` if there is one, otherwise it calls
          ``wandb.init()`` and the run counts as created by autolog.
        - todo: ``autolog(init: dict | run = run)`` would use the user-provided run.

        The ownership flag is assigned on every path, so a stale value from an
        earlier session can never make ``disable`` finish a user-owned run.
        """
        if init:
            previous_run = wandb.run
            # we delegate dealing with the init dict to wandb.init()
            self._run = wandb.init(**init)
            self.__run_created_by_autolog = self._run is not previous_run
        elif wandb.run is None:
            self._run = wandb.init()
            self.__run_created_by_autolog = True
        else:
            self._run = wandb.run
            self.__run_created_by_autolog = False

    def enable(self, init: AutologInitArgs = None) -> None:
        """Enable autologging.

        Args:
            init: Optional dictionary of arguments to pass to wandb.init().

        If autologging is already enabled it is disabled first. If patching the
        API fails, everything done so far is rolled back via ``disable`` and
        the error is re-raised.
        """
        if self._is_enabled:
            logger.info(
                f"{self._name} autologging is already enabled, disabling and re-enabling."
            )
            self.disable()

        logger.info(f"Enabling {self._name} autologging.")
        # import the API module up front so a missing package fails before a
        # run is created on the user's behalf
        _ = self._patch_api.set_api
        self._run_init(init=init)

        try:
            self._patch_api.patch(self._run)
        except Exception:
            # roll back: restore anything already wrapped and finish a run we created
            self.disable()
            raise

        if self._telemetry_feature:
            with wb_telemetry.context(self._run) as tel:
                setattr(tel.feature, self._telemetry_feature, True)

    def disable(self) -> None:
        """Disable autologging.

        Finishes the run only if autolog created it, then restores the original
        API. The API is restored even if finishing the run raises. Does nothing
        when autologging is not enabled.
        """
        if self._run is None:
            return

        logger.info(f"Disabling {self._name} autologging.")

        run, created_by_autolog = self._run, self.__run_created_by_autolog
        # reset state first so a failing finish() cannot leave us half-enabled
        self._run = None
        self.__run_created_by_autolog = False
        try:
            if created_by_autolog:
                run.finish()
        finally:
            self._patch_api.unpatch()
|
wandb/sdk/interface/interface.py
CHANGED
@@ -17,9 +17,9 @@ import time
|
|
17
17
|
from abc import abstractmethod
|
18
18
|
from typing import TYPE_CHECKING, Any, Iterable, NewType, Optional, Tuple, Union
|
19
19
|
|
20
|
-
from wandb.apis.public import Artifact as PublicArtifact
|
21
20
|
from wandb.proto import wandb_internal_pb2 as pb
|
22
21
|
from wandb.proto import wandb_telemetry_pb2 as tpb
|
22
|
+
from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
|
23
23
|
from wandb.util import (
|
24
24
|
WandBJSONEncoderOld,
|
25
25
|
get_h5_typename,
|
@@ -32,14 +32,14 @@ from wandb.util import (
|
|
32
32
|
|
33
33
|
from ..data_types.utils import history_dict_to_json, val_to_json
|
34
34
|
from ..lib.mailbox import MailboxHandle
|
35
|
-
from ..wandb_artifacts import Artifact
|
36
35
|
from . import summary_record as sr
|
37
|
-
from .artifacts import ArtifactManifest
|
38
36
|
from .message_future import MessageFuture
|
39
37
|
|
40
38
|
GlobStr = NewType("GlobStr", str)
|
41
39
|
|
42
40
|
if TYPE_CHECKING:
|
41
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
42
|
+
|
43
43
|
from ..wandb_run import Run
|
44
44
|
|
45
45
|
if sys.version_info >= (3, 8):
|
@@ -180,8 +180,9 @@ class InterfaceBase:
|
|
180
180
|
proto_run.telemetry.MergeFrom(run._telemetry_obj)
|
181
181
|
return proto_run
|
182
182
|
|
183
|
-
def publish_run(self, run: "
|
184
|
-
self.
|
183
|
+
def publish_run(self, run: "Run") -> None:
|
184
|
+
run_record = self._make_run(run)
|
185
|
+
self._publish_run(run_record)
|
185
186
|
|
186
187
|
@abstractmethod
|
187
188
|
def _publish_run(self, run: pb.RunRecord) -> None:
|
@@ -372,7 +373,7 @@ class InterfaceBase:
|
|
372
373
|
def _publish_files(self, files: pb.FilesRecord) -> None:
|
373
374
|
raise NotImplementedError
|
374
375
|
|
375
|
-
def _make_artifact(self, artifact: Artifact) -> pb.ArtifactRecord:
|
376
|
+
def _make_artifact(self, artifact: "Artifact") -> pb.ArtifactRecord:
|
376
377
|
proto_artifact = pb.ArtifactRecord()
|
377
378
|
proto_artifact.type = artifact.type
|
378
379
|
proto_artifact.name = artifact.name
|
@@ -424,14 +425,14 @@ class InterfaceBase:
|
|
424
425
|
def publish_link_artifact(
|
425
426
|
self,
|
426
427
|
run: "Run",
|
427
|
-
artifact:
|
428
|
+
artifact: "Artifact",
|
428
429
|
portfolio_name: str,
|
429
430
|
aliases: Iterable[str],
|
430
431
|
entity: Optional[str] = None,
|
431
432
|
project: Optional[str] = None,
|
432
433
|
) -> None:
|
433
434
|
link_artifact = pb.LinkArtifactRecord()
|
434
|
-
if
|
435
|
+
if artifact.is_draft():
|
435
436
|
link_artifact.client_id = artifact._client_id
|
436
437
|
else:
|
437
438
|
link_artifact.server_id = artifact.id if artifact.id else ""
|
@@ -448,10 +449,8 @@ class InterfaceBase:
|
|
448
449
|
|
449
450
|
def publish_use_artifact(
|
450
451
|
self,
|
451
|
-
artifact: Artifact,
|
452
|
+
artifact: "Artifact",
|
452
453
|
) -> None:
|
453
|
-
# use_artifact is either a public.Artifact or a wandb.Artifact that has been
|
454
|
-
# waited on and has an id
|
455
454
|
assert artifact.id is not None, "Artifact must have an id"
|
456
455
|
use_artifact = pb.UseArtifactRecord(
|
457
456
|
id=artifact.id, type=artifact.type, name=artifact.name
|
@@ -466,7 +465,7 @@ class InterfaceBase:
|
|
466
465
|
def communicate_artifact(
|
467
466
|
self,
|
468
467
|
run: "Run",
|
469
|
-
artifact: Artifact,
|
468
|
+
artifact: "Artifact",
|
470
469
|
aliases: Iterable[str],
|
471
470
|
history_step: Optional[int] = None,
|
472
471
|
is_user_created: bool = False,
|
@@ -516,7 +515,7 @@ class InterfaceBase:
|
|
516
515
|
def publish_artifact(
|
517
516
|
self,
|
518
517
|
run: "Run",
|
519
|
-
artifact: Artifact,
|
518
|
+
artifact: "Artifact",
|
520
519
|
aliases: Iterable[str],
|
521
520
|
is_user_created: bool = False,
|
522
521
|
use_after_commit: bool = False,
|
@@ -744,8 +743,9 @@ class InterfaceBase:
|
|
744
743
|
def _communicate_shutdown(self) -> None:
|
745
744
|
raise NotImplementedError
|
746
745
|
|
747
|
-
def deliver_run(self, run: "
|
748
|
-
|
746
|
+
def deliver_run(self, run: "Run") -> MailboxHandle:
|
747
|
+
run_record = self._make_run(run)
|
748
|
+
return self._deliver_run(run_record)
|
749
749
|
|
750
750
|
@abstractmethod
|
751
751
|
def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle:
|
@@ -8,11 +8,13 @@ from typing import TYPE_CHECKING, Optional, Tuple
|
|
8
8
|
|
9
9
|
import wandb
|
10
10
|
import wandb.util
|
11
|
-
from wandb.filesync import
|
11
|
+
from wandb.filesync import stats, step_checksum, step_upload
|
12
|
+
from wandb.sdk.lib.paths import LogicalPath
|
12
13
|
|
13
14
|
if TYPE_CHECKING:
|
14
|
-
from wandb.sdk.
|
15
|
-
from wandb.sdk.
|
15
|
+
from wandb.sdk.artifacts import artifact_saver
|
16
|
+
from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
|
17
|
+
from wandb.sdk.internal import file_stream, internal_api
|
16
18
|
from wandb.sdk.internal.settings_static import SettingsStatic
|
17
19
|
|
18
20
|
|
@@ -111,12 +113,7 @@ class FilePusher:
|
|
111
113
|
def file_counts_by_category(self) -> stats.FileCountsByCategory:
|
112
114
|
return self._stats.file_counts_by_category()
|
113
115
|
|
114
|
-
def file_changed(
|
115
|
-
self,
|
116
|
-
save_name: dir_watcher.SaveName,
|
117
|
-
path: str,
|
118
|
-
copy: bool = True,
|
119
|
-
):
|
116
|
+
def file_changed(self, save_name: LogicalPath, path: str, copy: bool = True):
|
120
117
|
"""Tell the file pusher that a file's changed and should be uploaded.
|
121
118
|
|
122
119
|
Arguments:
|
@@ -130,17 +127,12 @@ class FilePusher:
|
|
130
127
|
if os.path.getsize(path) == 0:
|
131
128
|
return
|
132
129
|
|
133
|
-
|
134
|
-
event = step_checksum.RequestUpload(
|
135
|
-
path,
|
136
|
-
dir_watcher.SaveName(save_name),
|
137
|
-
copy,
|
138
|
-
)
|
130
|
+
event = step_checksum.RequestUpload(path, save_name, copy)
|
139
131
|
self._incoming_queue.put(event)
|
140
132
|
|
141
133
|
def store_manifest_files(
|
142
134
|
self,
|
143
|
-
manifest: "
|
135
|
+
manifest: "ArtifactManifest",
|
144
136
|
artifact_id: str,
|
145
137
|
save_fn: "artifact_saver.SaveFn",
|
146
138
|
save_fn_async: "artifact_saver.SaveFnAsync",
|
@@ -336,14 +336,9 @@ class FileStreamApi:
|
|
336
336
|
self._client = requests.Session()
|
337
337
|
# todo: actually use the timeout once more thorough error injection in testing covers it
|
338
338
|
# self._client.post = functools.partial(self._client.post, timeout=self.HTTP_TIMEOUT)
|
339
|
-
self._client.auth =
|
340
|
-
self._client.headers.update(
|
341
|
-
|
342
|
-
"User-Agent": api.user_agent,
|
343
|
-
"X-WANDB-USERNAME": env.get_username() or "",
|
344
|
-
"X-WANDB-USER-EMAIL": env.get_user_email() or "",
|
345
|
-
}
|
346
|
-
)
|
339
|
+
self._client.auth = api.client.transport.session.auth
|
340
|
+
self._client.headers.update(api.client.transport.headers or {})
|
341
|
+
self._client.cookies.update(api.client.transport.cookies or {}) # type: ignore[no-untyped-call]
|
347
342
|
self._file_policies: Dict[str, "DefaultFilePolicy"] = {}
|
348
343
|
self._dropped_chunks: int = 0
|
349
344
|
self._queue: queue.Queue = queue.Queue()
|
@@ -658,9 +653,8 @@ def request_with_retry(
|
|
658
653
|
e.response is not None and e.response.status_code == 429
|
659
654
|
):
|
660
655
|
err_str = (
|
661
|
-
"Filestream rate limit exceeded,
|
662
|
-
|
663
|
-
)
|
656
|
+
"Filestream rate limit exceeded, "
|
657
|
+
f"retrying in {delay:.1f} seconds. "
|
664
658
|
)
|
665
659
|
if retry_callback:
|
666
660
|
retry_callback(e.response.status_code, err_str)
|
wandb/sdk/internal/handler.py
CHANGED
@@ -191,6 +191,11 @@ class HandleManager:
|
|
191
191
|
self._dispatch_record(record)
|
192
192
|
|
193
193
|
def handle_run(self, record: Record) -> None:
|
194
|
+
if self._settings._offline:
|
195
|
+
self._run_proto = record.run
|
196
|
+
result = proto_util._result_from_record(record)
|
197
|
+
result.run_result.run.CopyFrom(record.run)
|
198
|
+
self._respond_result(result)
|
194
199
|
self._dispatch_record(record)
|
195
200
|
|
196
201
|
def handle_stats(self, record: Record) -> None:
|
@@ -624,7 +629,12 @@ class HandleManager:
|
|
624
629
|
self._dispatch_record(record)
|
625
630
|
|
626
631
|
def handle_request_attach(self, record: Record) -> None:
|
627
|
-
|
632
|
+
result = proto_util._result_from_record(record)
|
633
|
+
attach_id = record.request.attach.attach_id
|
634
|
+
assert attach_id
|
635
|
+
assert self._run_proto
|
636
|
+
result.response.attach_response.run.CopyFrom(self._run_proto)
|
637
|
+
self._respond_result(result)
|
628
638
|
|
629
639
|
def handle_request_log_artifact(self, record: Record) -> None:
|
630
640
|
self._dispatch_record(record)
|
@@ -675,6 +685,8 @@ class HandleManager:
|
|
675
685
|
assert run_start
|
676
686
|
assert run_start.run
|
677
687
|
|
688
|
+
self._run_proto = run_start.run
|
689
|
+
|
678
690
|
self._run_start_time = run_start.run.start_time.ToMicroseconds() / 1e6
|
679
691
|
|
680
692
|
self._track_time = time.time()
|