wandb 0.17.0rc1__py3-none-macosx_11_0_arm64.whl → 0.17.1__py3-none-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -2
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/wandb.py +12 -7
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +213 -79
- wandb/apis/public/artifacts.py +335 -100
- wandb/apis/public/files.py +9 -9
- wandb/apis/public/jobs.py +16 -4
- wandb/apis/public/projects.py +26 -28
- wandb/apis/public/query_generator.py +1 -1
- wandb/apis/public/runs.py +163 -65
- wandb/apis/public/sweeps.py +2 -2
- wandb/apis/reports/__init__.py +1 -7
- wandb/apis/reports/v1/__init__.py +5 -27
- wandb/apis/reports/v2/__init__.py +7 -19
- wandb/apis/workspaces/__init__.py +8 -0
- wandb/beta/workflows.py +8 -3
- wandb/bin/apple_gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +131 -59
- wandb/data_types.py +6 -3
- wandb/docker/__init__.py +2 -2
- wandb/env.py +3 -3
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +5 -107
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/gym/__init__.py +35 -15
- wandb/integration/huggingface/resolver.py +2 -2
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/keras.py +1 -1
- wandb/integration/openai/fine_tuning.py +21 -3
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/jupyter.py +16 -17
- wandb/old/summary.py +1 -1
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +54 -54
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +54 -54
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_base_pb2.py +30 -0
- wandb/proto/v5/wandb_internal_pb2.py +355 -0
- wandb/proto/v5/wandb_server_pb2.py +63 -0
- wandb/proto/v5/wandb_settings_pb2.py +45 -0
- wandb/proto/v5/wandb_telemetry_pb2.py +41 -0
- wandb/proto/wandb_base_pb2.py +2 -0
- wandb/proto/wandb_deprecated.py +9 -1
- wandb/proto/wandb_generate_deprecated.py +34 -0
- wandb/proto/{wandb_internal_codegen.py → wandb_generate_proto.py} +1 -35
- wandb/proto/wandb_internal_pb2.py +2 -0
- wandb/proto/wandb_server_pb2.py +2 -0
- wandb/proto/wandb_settings_pb2.py +2 -0
- wandb/proto/wandb_telemetry_pb2.py +2 -0
- wandb/sdk/artifacts/artifact.py +68 -22
- wandb/sdk/artifacts/artifact_manifest.py +1 -1
- wandb/sdk/artifacts/artifact_manifest_entry.py +6 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -1
- wandb/sdk/artifacts/artifact_saver.py +1 -10
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +6 -2
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +2 -42
- wandb/sdk/artifacts/storage_policy.py +1 -12
- wandb/sdk/data_types/_dtypes.py +8 -8
- wandb/sdk/data_types/image.py +2 -2
- wandb/sdk/data_types/video.py +5 -3
- wandb/sdk/integration_utils/data_logging.py +5 -5
- wandb/sdk/interface/interface.py +14 -1
- wandb/sdk/interface/interface_shared.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +6 -19
- wandb/sdk/internal/internal_api.py +148 -136
- wandb/sdk/internal/job_builder.py +208 -136
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/sender.py +102 -39
- wandb/sdk/internal/settings_static.py +8 -1
- wandb/sdk/internal/system/assets/trainium.py +3 -3
- wandb/sdk/internal/system/system_info.py +4 -2
- wandb/sdk/internal/update.py +1 -1
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +4 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +187 -225
- wandb/sdk/launch/agent/agent.py +59 -19
- wandb/sdk/launch/agent/config.py +0 -3
- wandb/sdk/launch/builder/abstract.py +68 -1
- wandb/sdk/launch/builder/build.py +165 -576
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +7 -23
- wandb/sdk/launch/builder/kaniko_builder.py +12 -25
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +51 -45
- wandb/sdk/launch/environment/aws_environment.py +26 -1
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +224 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +1 -1
- wandb/sdk/launch/runner/abstract.py +2 -2
- wandb/sdk/launch/runner/kubernetes_monitor.py +45 -12
- wandb/sdk/launch/runner/kubernetes_runner.py +6 -8
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +20 -14
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +5 -3
- wandb/sdk/launch/sweeps/scheduler_sweep.py +1 -1
- wandb/sdk/launch/sweeps/utils.py +4 -4
- wandb/sdk/launch/utils.py +16 -138
- wandb/sdk/lib/_settings_toposort_generated.py +2 -5
- wandb/sdk/lib/apikey.py +4 -2
- wandb/sdk/lib/config_util.py +3 -3
- wandb/sdk/lib/import_hooks.py +1 -1
- wandb/sdk/lib/proto_util.py +22 -1
- wandb/sdk/lib/redirect.py +20 -15
- wandb/sdk/lib/tracelog.py +1 -1
- wandb/sdk/service/service.py +2 -1
- wandb/sdk/service/streams.py +5 -5
- wandb/sdk/wandb_init.py +25 -59
- wandb/sdk/wandb_login.py +28 -25
- wandb/sdk/wandb_run.py +123 -53
- wandb/sdk/wandb_settings.py +33 -64
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sdk/wandb_watch.py +1 -1
- wandb/sklearn/plot/classifier.py +10 -12
- wandb/sklearn/plot/clusterer.py +1 -1
- wandb/sync/sync.py +2 -2
- wandb/testing/relay.py +32 -17
- wandb/util.py +36 -37
- wandb/wandb_agent.py +3 -3
- wandb/wandb_controller.py +5 -4
- {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/METADATA +8 -10
- {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/RECORD +141 -163
- {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/WHEEL +1 -1
- wandb/apis/reports/v1/_blocks.py +0 -1406
- wandb/apis/reports/v1/_helpers.py +0 -70
- wandb/apis/reports/v1/_panels.py +0 -1282
- wandb/apis/reports/v1/_templates.py +0 -478
- wandb/apis/reports/v1/blocks.py +0 -27
- wandb/apis/reports/v1/helpers.py +0 -2
- wandb/apis/reports/v1/mutations.py +0 -66
- wandb/apis/reports/v1/panels.py +0 -17
- wandb/apis/reports/v1/report.py +0 -268
- wandb/apis/reports/v1/runset.py +0 -144
- wandb/apis/reports/v1/templates.py +0 -7
- wandb/apis/reports/v1/util.py +0 -406
- wandb/apis/reports/v1/validators.py +0 -131
- wandb/apis/reports/v2/blocks.py +0 -25
- wandb/apis/reports/v2/expr_parsing.py +0 -257
- wandb/apis/reports/v2/gql.py +0 -68
- wandb/apis/reports/v2/interface.py +0 -1911
- wandb/apis/reports/v2/internal.py +0 -867
- wandb/apis/reports/v2/metrics.py +0 -6
- wandb/apis/reports/v2/panels.py +0 -15
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -19
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/licenses/LICENSE +0 -0
@@ -78,7 +78,7 @@ class ValidationDataLogger:
|
|
78
78
|
Defaults to `"wb_validation_data"`.
|
79
79
|
artifact_type: The artifact type to use for the validation data.
|
80
80
|
Defaults to `"validation_dataset"`.
|
81
|
-
class_labels: Optional list of
|
81
|
+
class_labels: Optional list of labels to use in the inferred
|
82
82
|
processors. If the model's `target` or `output` is inferred to be a class,
|
83
83
|
we will attempt to map the class to these labels. Defaults to `None`.
|
84
84
|
infer_missing_processors: Determines if processors are inferred if
|
@@ -262,7 +262,7 @@ def _infer_single_example_keyed_processor(
|
|
262
262
|
):
|
263
263
|
np = wandb.util.get_module(
|
264
264
|
"numpy",
|
265
|
-
required="
|
265
|
+
required="Inferring processors require numpy",
|
266
266
|
)
|
267
267
|
# Assume these are logits
|
268
268
|
class_names = class_labels_table.get_column("label")
|
@@ -301,7 +301,7 @@ def _infer_single_example_keyed_processor(
|
|
301
301
|
elif len(shape) == 1:
|
302
302
|
np = wandb.util.get_module(
|
303
303
|
"numpy",
|
304
|
-
required="
|
304
|
+
required="Inferring processors require numpy",
|
305
305
|
)
|
306
306
|
# This could be anything
|
307
307
|
if shape[0] <= 10:
|
@@ -354,7 +354,7 @@ def _infer_validation_row_processor(
|
|
354
354
|
input_col_name: str = "input",
|
355
355
|
target_col_name: str = "target",
|
356
356
|
) -> Callable:
|
357
|
-
"""Infers the
|
357
|
+
"""Infers the composite processor for the validation data."""
|
358
358
|
single_processors = {}
|
359
359
|
if isinstance(example_input, dict):
|
360
360
|
for key in example_input:
|
@@ -431,7 +431,7 @@ def _infer_prediction_row_processor(
|
|
431
431
|
input_col_name: str = "input",
|
432
432
|
output_col_name: str = "output",
|
433
433
|
) -> Callable:
|
434
|
-
"""Infers the
|
434
|
+
"""Infers the composite processor for the prediction output data."""
|
435
435
|
single_processors = {}
|
436
436
|
|
437
437
|
if isinstance(example_prediction, dict):
|
wandb/sdk/interface/interface.py
CHANGED
@@ -387,18 +387,30 @@ class InterfaceBase:
|
|
387
387
|
def _make_partial_source_str(
|
388
388
|
source: Any, job_info: Dict[str, Any], metadata: Dict[str, Any]
|
389
389
|
) -> str:
|
390
|
-
"""Construct use_artifact.partial.source_info.
|
390
|
+
"""Construct use_artifact.partial.source_info.source as str."""
|
391
391
|
source_type = job_info.get("source_type", "").strip()
|
392
392
|
if source_type == "artifact":
|
393
393
|
info_source = job_info.get("source", {})
|
394
394
|
source.artifact.artifact = info_source.get("artifact", "")
|
395
395
|
source.artifact.entrypoint.extend(info_source.get("entrypoint", []))
|
396
396
|
source.artifact.notebook = info_source.get("notebook", False)
|
397
|
+
build_context = info_source.get("build_context")
|
398
|
+
if build_context:
|
399
|
+
source.artifact.build_context = build_context
|
400
|
+
dockerfile = info_source.get("dockerfile")
|
401
|
+
if dockerfile:
|
402
|
+
source.artifact.dockerfile = dockerfile
|
397
403
|
elif source_type == "repo":
|
398
404
|
source.git.git_info.remote = metadata.get("git", {}).get("remote", "")
|
399
405
|
source.git.git_info.commit = metadata.get("git", {}).get("commit", "")
|
400
406
|
source.git.entrypoint.extend(metadata.get("entrypoint", []))
|
401
407
|
source.git.notebook = metadata.get("notebook", False)
|
408
|
+
build_context = metadata.get("build_context")
|
409
|
+
if build_context:
|
410
|
+
source.git.build_context = build_context
|
411
|
+
dockerfile = metadata.get("dockerfile")
|
412
|
+
if dockerfile:
|
413
|
+
source.git.dockerfile = dockerfile
|
402
414
|
elif source_type == "image":
|
403
415
|
source.image.image = metadata.get("docker", "")
|
404
416
|
else:
|
@@ -775,6 +787,7 @@ class InterfaceBase:
|
|
775
787
|
source.file.CopyFrom(
|
776
788
|
pb.JobInputSource.ConfigFileSource(path=file_path),
|
777
789
|
)
|
790
|
+
request.input_source.CopyFrom(source)
|
778
791
|
|
779
792
|
return self._publish_job_input(request)
|
780
793
|
|
@@ -321,7 +321,7 @@ class InterfaceShared(InterfaceBase):
|
|
321
321
|
if result is None:
|
322
322
|
# TODO: friendlier error message here
|
323
323
|
raise wandb.Error(
|
324
|
-
"Couldn't communicate with backend after
|
324
|
+
"Couldn't communicate with backend after {} seconds".format(timeout)
|
325
325
|
)
|
326
326
|
login_response = result.response.login_response
|
327
327
|
assert login_response
|
@@ -14,7 +14,7 @@ from wandb.sdk.lib.paths import LogicalPath
|
|
14
14
|
|
15
15
|
if TYPE_CHECKING:
|
16
16
|
from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
|
17
|
-
from wandb.sdk.artifacts.artifact_saver import SaveFn
|
17
|
+
from wandb.sdk.artifacts.artifact_saver import SaveFn
|
18
18
|
from wandb.sdk.internal import file_stream, internal_api
|
19
19
|
from wandb.sdk.internal.settings_static import SettingsStatic
|
20
20
|
|
@@ -148,11 +148,8 @@ class FilePusher:
|
|
148
148
|
manifest: "ArtifactManifest",
|
149
149
|
artifact_id: str,
|
150
150
|
save_fn: "SaveFn",
|
151
|
-
save_fn_async: "SaveFnAsync",
|
152
151
|
) -> None:
|
153
|
-
event = step_checksum.RequestStoreManifestFiles(
|
154
|
-
manifest, artifact_id, save_fn, save_fn_async
|
155
|
-
)
|
152
|
+
event = step_checksum.RequestStoreManifestFiles(manifest, artifact_id, save_fn)
|
156
153
|
self._incoming_queue.put(event)
|
157
154
|
|
158
155
|
def commit_artifact(
|
@@ -1,4 +1,3 @@
|
|
1
|
-
import base64
|
2
1
|
import functools
|
3
2
|
import itertools
|
4
3
|
import json
|
@@ -53,7 +52,7 @@ logger = logging.getLogger(__name__)
|
|
53
52
|
|
54
53
|
class Chunk(NamedTuple):
|
55
54
|
filename: str
|
56
|
-
data:
|
55
|
+
data: str
|
57
56
|
|
58
57
|
|
59
58
|
class DefaultFilePolicy:
|
@@ -227,7 +226,7 @@ class CRDedupeFilePolicy(DefaultFilePolicy):
|
|
227
226
|
prefix += token + " "
|
228
227
|
return prefix, rest
|
229
228
|
|
230
|
-
def process_chunks(self, chunks: List) -> List["ProcessedChunk"]:
|
229
|
+
def process_chunks(self, chunks: List[Chunk]) -> List["ProcessedChunk"]:
|
231
230
|
r"""Process chunks.
|
232
231
|
|
233
232
|
Args:
|
@@ -300,18 +299,6 @@ class CRDedupeFilePolicy(DefaultFilePolicy):
|
|
300
299
|
return ret
|
301
300
|
|
302
301
|
|
303
|
-
class BinaryFilePolicy(DefaultFilePolicy):
|
304
|
-
def __init__(self) -> None:
|
305
|
-
super().__init__()
|
306
|
-
self._offset: int = 0
|
307
|
-
|
308
|
-
def process_chunks(self, chunks: List[Chunk]) -> "ProcessedBinaryChunk":
|
309
|
-
data = b"".join([c.data for c in chunks])
|
310
|
-
enc = base64.b64encode(data).decode("ascii")
|
311
|
-
self._offset += len(data)
|
312
|
-
return {"offset": self._offset, "content": enc, "encoding": "base64"}
|
313
|
-
|
314
|
-
|
315
302
|
class FileStreamApi:
|
316
303
|
"""Pushes chunks of files to our streaming endpoint.
|
317
304
|
|
@@ -520,7 +507,7 @@ class FileStreamApi:
|
|
520
507
|
wandb.termerror(
|
521
508
|
"Dropped streaming file chunk (see wandb/debug-internal.log)"
|
522
509
|
)
|
523
|
-
logger.exception("dropped chunk
|
510
|
+
logger.exception("dropped chunk {}".format(response))
|
524
511
|
self._dropped_chunks += 1
|
525
512
|
else:
|
526
513
|
parsed: Optional[dict] = None
|
@@ -585,12 +572,12 @@ class FileStreamApi:
|
|
585
572
|
def enqueue_preempting(self) -> None:
|
586
573
|
self._queue.put(self.Preempting())
|
587
574
|
|
588
|
-
def push(self, filename: str, data:
|
575
|
+
def push(self, filename: str, data: str) -> None:
|
589
576
|
"""Push a chunk of a file to the streaming endpoint.
|
590
577
|
|
591
578
|
Arguments:
|
592
|
-
filename: Name of file
|
593
|
-
data:
|
579
|
+
filename: Name of file to append to.
|
580
|
+
data: Text to append to the file.
|
594
581
|
"""
|
595
582
|
self._queue.put(Chunk(filename, data))
|
596
583
|
|
@@ -1,5 +1,4 @@
|
|
1
1
|
import ast
|
2
|
-
import asyncio
|
3
2
|
import base64
|
4
3
|
import datetime
|
5
4
|
import functools
|
@@ -49,7 +48,7 @@ from ..lib import retry
|
|
49
48
|
from ..lib.filenames import DIFF_FNAME, METADATA_FNAME
|
50
49
|
from ..lib.gitlib import GitRepo
|
51
50
|
from . import context
|
52
|
-
from .progress import
|
51
|
+
from .progress import Progress
|
53
52
|
|
54
53
|
logger = logging.getLogger(__name__)
|
55
54
|
|
@@ -121,13 +120,6 @@ if TYPE_CHECKING:
|
|
121
120
|
SweepState = Literal["RUNNING", "PAUSED", "CANCELED", "FINISHED"]
|
122
121
|
Number = Union[int, float]
|
123
122
|
|
124
|
-
# This funny if/else construction is the simplest thing I've found that
|
125
|
-
# works at runtime, satisfies Mypy, and gives autocomplete in VSCode:
|
126
|
-
if TYPE_CHECKING:
|
127
|
-
import httpx
|
128
|
-
else:
|
129
|
-
httpx = util.get_module("httpx")
|
130
|
-
|
131
123
|
# class _MappingSupportsCopy(Protocol):
|
132
124
|
# def copy(self) -> "_MappingSupportsCopy": ...
|
133
125
|
# def keys(self) -> Iterable: ...
|
@@ -161,23 +153,6 @@ def check_httpclient_logger_handler() -> None:
|
|
161
153
|
httpclient_logger.addHandler(root_logger.handlers[0])
|
162
154
|
|
163
155
|
|
164
|
-
def check_httpx_exc_retriable(exc: Exception) -> bool:
|
165
|
-
retriable_codes = (308, 408, 409, 429, 500, 502, 503, 504)
|
166
|
-
return (
|
167
|
-
isinstance(exc, (httpx.TimeoutException, httpx.NetworkError))
|
168
|
-
or (
|
169
|
-
isinstance(exc, httpx.HTTPStatusError)
|
170
|
-
and exc.response.status_code in retriable_codes
|
171
|
-
)
|
172
|
-
or (
|
173
|
-
isinstance(exc, httpx.HTTPStatusError)
|
174
|
-
and exc.response.status_code == 400
|
175
|
-
and "x-amz-meta-md5" in exc.request.headers
|
176
|
-
and "RequestTimeout" in str(exc.response.content)
|
177
|
-
)
|
178
|
-
)
|
179
|
-
|
180
|
-
|
181
156
|
class _ThreadLocalData(threading.local):
|
182
157
|
context: Optional[context.Context]
|
183
158
|
|
@@ -286,10 +261,6 @@ class Api:
|
|
286
261
|
)
|
287
262
|
)
|
288
263
|
|
289
|
-
# httpx is an optional dependency, so we lazily instantiate the client
|
290
|
-
# only when we need it
|
291
|
-
self._async_httpx_client: Optional[httpx.AsyncClient] = None
|
292
|
-
|
293
264
|
self.retry_callback = retry_callback
|
294
265
|
self._retry_gql = retry.Retry(
|
295
266
|
self.execute,
|
@@ -361,7 +332,7 @@ class Api:
|
|
361
332
|
|
362
333
|
def relocate(self) -> None:
|
363
334
|
"""Ensure the current api points to the right server."""
|
364
|
-
self.client.transport.url = "
|
335
|
+
self.client.transport.url = "{}/graphql".format(self.settings("base_url"))
|
365
336
|
|
366
337
|
def execute(self, *args: Any, **kwargs: Any) -> "_Response":
|
367
338
|
"""Wrapper around execute that logs in cases of failure."""
|
@@ -2245,6 +2216,113 @@ class Api:
|
|
2245
2216
|
server_messages,
|
2246
2217
|
)
|
2247
2218
|
|
2219
|
+
@normalize_exceptions
|
2220
|
+
def rewind_run(
|
2221
|
+
self,
|
2222
|
+
run_name: str,
|
2223
|
+
metric_name: str,
|
2224
|
+
metric_value: float,
|
2225
|
+
program_path: Optional[str] = None,
|
2226
|
+
entity: Optional[str] = None,
|
2227
|
+
project: Optional[str] = None,
|
2228
|
+
num_retries: Optional[int] = None,
|
2229
|
+
) -> dict:
|
2230
|
+
"""Rewinds a run to a previous state.
|
2231
|
+
|
2232
|
+
Arguments:
|
2233
|
+
run_name (str): The name of the run to rewind
|
2234
|
+
metric_name (str): The name of the metric to rewind to
|
2235
|
+
metric_value (float): The value of the metric to rewind to
|
2236
|
+
program_path (str, optional): Path to the program
|
2237
|
+
entity (str, optional): The entity to scope this project to
|
2238
|
+
project (str, optional): The name of the project
|
2239
|
+
num_retries (int, optional): Number of retries
|
2240
|
+
|
2241
|
+
Returns:
|
2242
|
+
A dict with the rewound run
|
2243
|
+
|
2244
|
+
{
|
2245
|
+
"id": "run_id",
|
2246
|
+
"name": "run_name",
|
2247
|
+
"displayName": "run_display_name",
|
2248
|
+
"description": "run_description",
|
2249
|
+
"config": "stringified_run_config_json",
|
2250
|
+
"sweepName": "run_sweep_name",
|
2251
|
+
"project": {
|
2252
|
+
"id": "project_id",
|
2253
|
+
"name": "project_name",
|
2254
|
+
"entity": {
|
2255
|
+
"id": "entity_id",
|
2256
|
+
"name": "entity_name"
|
2257
|
+
}
|
2258
|
+
},
|
2259
|
+
"historyLineCount": 100,
|
2260
|
+
}
|
2261
|
+
"""
|
2262
|
+
query_string = """
|
2263
|
+
mutation RewindRun($runName: String!, $entity: String, $project: String, $metricName: String!, $metricValue: Float!) {
|
2264
|
+
rewindRun(input: {runName: $runName, entityName: $entity, projectName: $project, metricName: $metricName, metricValue: $metricValue}) {
|
2265
|
+
rewoundRun {
|
2266
|
+
id
|
2267
|
+
name
|
2268
|
+
displayName
|
2269
|
+
description
|
2270
|
+
config
|
2271
|
+
sweepName
|
2272
|
+
project {
|
2273
|
+
id
|
2274
|
+
name
|
2275
|
+
entity {
|
2276
|
+
id
|
2277
|
+
name
|
2278
|
+
}
|
2279
|
+
}
|
2280
|
+
historyLineCount
|
2281
|
+
}
|
2282
|
+
}
|
2283
|
+
}
|
2284
|
+
"""
|
2285
|
+
|
2286
|
+
mutation = gql(query_string)
|
2287
|
+
|
2288
|
+
kwargs = {}
|
2289
|
+
if num_retries is not None:
|
2290
|
+
kwargs["num_retries"] = num_retries
|
2291
|
+
|
2292
|
+
variable_values = {
|
2293
|
+
"runName": run_name,
|
2294
|
+
"entity": entity or self.settings("entity"),
|
2295
|
+
"project": project or util.auto_project_name(program_path),
|
2296
|
+
"metricName": metric_name,
|
2297
|
+
"metricValue": metric_value,
|
2298
|
+
}
|
2299
|
+
|
2300
|
+
# retry conflict errors for 2 minutes, default to no_auth_retry
|
2301
|
+
check_retry_fn = util.make_check_retry_fn(
|
2302
|
+
check_fn=util.check_retry_conflict_or_gone,
|
2303
|
+
check_timedelta=datetime.timedelta(minutes=2),
|
2304
|
+
fallback_retry_fn=util.no_retry_auth,
|
2305
|
+
)
|
2306
|
+
|
2307
|
+
response = self.gql(
|
2308
|
+
mutation,
|
2309
|
+
variable_values=variable_values,
|
2310
|
+
check_retry_fn=check_retry_fn,
|
2311
|
+
**kwargs,
|
2312
|
+
)
|
2313
|
+
|
2314
|
+
run_obj: Dict[str, Dict[str, Dict[str, str]]] = response.get(
|
2315
|
+
"rewindRun", {}
|
2316
|
+
).get("rewoundRun", {})
|
2317
|
+
project_obj: Dict[str, Dict[str, str]] = run_obj.get("project", {})
|
2318
|
+
if project_obj:
|
2319
|
+
self.set_setting("project", project_obj["name"])
|
2320
|
+
entity_obj = project_obj.get("entity", {})
|
2321
|
+
if entity_obj:
|
2322
|
+
self.set_setting("entity", entity_obj["name"])
|
2323
|
+
|
2324
|
+
return run_obj
|
2325
|
+
|
2248
2326
|
@normalize_exceptions
|
2249
2327
|
def get_run_info(
|
2250
2328
|
self,
|
@@ -2794,105 +2872,6 @@ class Api:
|
|
2794
2872
|
|
2795
2873
|
return response
|
2796
2874
|
|
2797
|
-
async def upload_file_async(
|
2798
|
-
self,
|
2799
|
-
url: str,
|
2800
|
-
file: IO[bytes],
|
2801
|
-
callback: Optional["ProgressFn"] = None,
|
2802
|
-
extra_headers: Optional[Dict[str, str]] = None,
|
2803
|
-
) -> None:
|
2804
|
-
"""An async not-quite-equivalent version of `upload_file`.
|
2805
|
-
|
2806
|
-
Differences from `upload_file`:
|
2807
|
-
- This method doesn't implement Azure uploads. (The Azure SDK supports
|
2808
|
-
async, but it's nontrivial to use it here.) If the upload looks like
|
2809
|
-
it's destined for Azure, this method will delegate to the sync impl.
|
2810
|
-
- Consequently, this method doesn't return the response object.
|
2811
|
-
(Because it might fall back to the sync impl, it would sometimes
|
2812
|
-
return a `requests.Response` and sometimes an `httpx.Response`.)
|
2813
|
-
- This method doesn't wrap retryable errors in `TransientError`.
|
2814
|
-
It leaves that determination to the caller.
|
2815
|
-
"""
|
2816
|
-
check_httpclient_logger_handler()
|
2817
|
-
must_delegate = False
|
2818
|
-
|
2819
|
-
if httpx is None:
|
2820
|
-
wandb.termwarn( # type: ignore[unreachable]
|
2821
|
-
"async file-uploads require `pip install wandb[async]`; falling back to sync implementation",
|
2822
|
-
repeat=False,
|
2823
|
-
)
|
2824
|
-
must_delegate = True
|
2825
|
-
|
2826
|
-
if extra_headers is not None and "x-ms-blob-type" in extra_headers:
|
2827
|
-
wandb.termwarn(
|
2828
|
-
"async file-uploads don't support Azure; falling back to sync implementation",
|
2829
|
-
repeat=False,
|
2830
|
-
)
|
2831
|
-
must_delegate = True
|
2832
|
-
|
2833
|
-
if must_delegate:
|
2834
|
-
await asyncio.get_event_loop().run_in_executor(
|
2835
|
-
None,
|
2836
|
-
lambda: self.upload_file_retry(
|
2837
|
-
url=url,
|
2838
|
-
file=file,
|
2839
|
-
callback=callback,
|
2840
|
-
extra_headers=extra_headers,
|
2841
|
-
),
|
2842
|
-
)
|
2843
|
-
return
|
2844
|
-
|
2845
|
-
if self._async_httpx_client is None:
|
2846
|
-
self._async_httpx_client = httpx.AsyncClient()
|
2847
|
-
|
2848
|
-
progress = AsyncProgress(Progress(file, callback=callback))
|
2849
|
-
|
2850
|
-
try:
|
2851
|
-
response = await self._async_httpx_client.put(
|
2852
|
-
url=url,
|
2853
|
-
content=progress,
|
2854
|
-
headers={
|
2855
|
-
"Content-Length": str(len(progress)),
|
2856
|
-
**(extra_headers if extra_headers is not None else {}),
|
2857
|
-
},
|
2858
|
-
)
|
2859
|
-
response.raise_for_status()
|
2860
|
-
except Exception as e:
|
2861
|
-
progress.rewind()
|
2862
|
-
logger.error(f"upload_file_async exception {url}: {e}")
|
2863
|
-
if isinstance(e, httpx.RequestError):
|
2864
|
-
logger.error(f"upload_file_async request headers: {e.request.headers}")
|
2865
|
-
if isinstance(e, httpx.HTTPStatusError):
|
2866
|
-
logger.error(f"upload_file_async response body: {e.response.content!r}")
|
2867
|
-
raise
|
2868
|
-
|
2869
|
-
async def upload_file_retry_async(
|
2870
|
-
self,
|
2871
|
-
url: str,
|
2872
|
-
file: IO[bytes],
|
2873
|
-
callback: Optional["ProgressFn"] = None,
|
2874
|
-
extra_headers: Optional[Dict[str, str]] = None,
|
2875
|
-
num_retries: int = 100,
|
2876
|
-
) -> None:
|
2877
|
-
backoff = retry.FilteredBackoff(
|
2878
|
-
filter=check_httpx_exc_retriable,
|
2879
|
-
wrapped=retry.ExponentialBackoff(
|
2880
|
-
initial_sleep=datetime.timedelta(seconds=1),
|
2881
|
-
max_sleep=datetime.timedelta(seconds=60),
|
2882
|
-
max_retries=num_retries,
|
2883
|
-
timeout_at=datetime.datetime.now() + datetime.timedelta(days=7),
|
2884
|
-
),
|
2885
|
-
)
|
2886
|
-
|
2887
|
-
await retry.retry_async(
|
2888
|
-
backoff=backoff,
|
2889
|
-
fn=self.upload_file_async,
|
2890
|
-
url=url,
|
2891
|
-
file=file,
|
2892
|
-
callback=callback,
|
2893
|
-
extra_headers=extra_headers,
|
2894
|
-
)
|
2895
|
-
|
2896
2875
|
@normalize_exceptions
|
2897
2876
|
def register_agent(
|
2898
2877
|
self,
|
@@ -3039,9 +3018,10 @@ class Api:
|
|
3039
3018
|
parameter["distribution"] = "uniform"
|
3040
3019
|
else:
|
3041
3020
|
raise ValueError(
|
3042
|
-
"Parameter
|
3043
|
-
"uniform distribution) or ints (for an int_uniform distribution)."
|
3044
|
-
|
3021
|
+
"Parameter {} is ambiguous, please specify bounds as both floats (for a float_"
|
3022
|
+
"uniform distribution) or ints (for an int_uniform distribution).".format(
|
3023
|
+
parameter_name
|
3024
|
+
)
|
3045
3025
|
)
|
3046
3026
|
return config
|
3047
3027
|
|
@@ -3144,7 +3124,9 @@ class Api:
|
|
3144
3124
|
|
3145
3125
|
# Silly, but attr-dicts like EasyDicts don't serialize correctly to yaml.
|
3146
3126
|
# This sanitizes them with a round trip pass through json to get a regular dict.
|
3147
|
-
config_str = yaml.dump(
|
3127
|
+
config_str = yaml.dump(
|
3128
|
+
json.loads(json.dumps(config)), Dumper=util.NonOctalStringDumper
|
3129
|
+
)
|
3148
3130
|
|
3149
3131
|
err: Optional[Exception] = None
|
3150
3132
|
for mutation in mutations:
|
@@ -3887,6 +3869,36 @@ class Api:
|
|
3887
3869
|
response["updateArtifactManifest"]["artifactManifest"]["file"],
|
3888
3870
|
)
|
3889
3871
|
|
3872
|
+
def update_artifact_metadata(
|
3873
|
+
self, artifact_id: str, metadata: Dict[str, Any]
|
3874
|
+
) -> Dict[str, Any]:
|
3875
|
+
"""Set the metadata of the given artifact version."""
|
3876
|
+
mutation = gql(
|
3877
|
+
"""
|
3878
|
+
mutation UpdateArtifact(
|
3879
|
+
$artifactID: ID!,
|
3880
|
+
$metadata: JSONString,
|
3881
|
+
) {
|
3882
|
+
updateArtifact(input: {
|
3883
|
+
artifactID: $artifactID,
|
3884
|
+
metadata: $metadata,
|
3885
|
+
}) {
|
3886
|
+
artifact {
|
3887
|
+
id
|
3888
|
+
}
|
3889
|
+
}
|
3890
|
+
}
|
3891
|
+
"""
|
3892
|
+
)
|
3893
|
+
response = self.gql(
|
3894
|
+
mutation,
|
3895
|
+
variable_values={
|
3896
|
+
"artifactID": artifact_id,
|
3897
|
+
"metadata": json.dumps(metadata),
|
3898
|
+
},
|
3899
|
+
)
|
3900
|
+
return response["updateArtifact"]["artifact"]
|
3901
|
+
|
3890
3902
|
def _resolve_client_id(
|
3891
3903
|
self,
|
3892
3904
|
client_id: str,
|
@@ -4081,9 +4093,9 @@ class Api:
|
|
4081
4093
|
s = self.sweep(sweep=sweep, entity=entity, project=project, specs="{}")
|
4082
4094
|
curr_state = s["state"].upper()
|
4083
4095
|
if state == "PAUSED" and curr_state not in ("PAUSED", "RUNNING"):
|
4084
|
-
raise Exception("Cannot pause
|
4096
|
+
raise Exception("Cannot pause {} sweep.".format(curr_state.lower()))
|
4085
4097
|
elif state != "RUNNING" and curr_state not in ("RUNNING", "PAUSED", "PENDING"):
|
4086
|
-
raise Exception("Sweep already
|
4098
|
+
raise Exception("Sweep already {}.".format(curr_state.lower()))
|
4087
4099
|
sweep_id = s["id"]
|
4088
4100
|
mutation = gql(
|
4089
4101
|
"""
|