wandb 0.17.4__py3-none-win32.whl → 0.17.5__py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/bin/wandb-core +0 -0
- wandb/filesync/upload_job.py +1 -1
- wandb/proto/v3/wandb_internal_pb2.py +339 -328
- wandb/proto/v4/wandb_internal_pb2.py +326 -323
- wandb/proto/v5/wandb_internal_pb2.py +326 -323
- wandb/sdk/artifacts/artifact.py +11 -24
- wandb/sdk/interface/interface.py +12 -5
- wandb/sdk/interface/interface_shared.py +9 -7
- wandb/sdk/internal/handler.py +1 -1
- wandb/sdk/internal/internal_api.py +4 -4
- wandb/sdk/internal/sender.py +9 -2
- wandb/sdk/launch/builder/kaniko_builder.py +30 -9
- wandb/sdk/launch/inputs/internal.py +79 -2
- wandb/sdk/launch/inputs/manage.py +21 -3
- wandb/sdk/lib/tracelog.py +2 -2
- wandb/sdk/wandb_manager.py +9 -5
- wandb/sdk/wandb_run.py +100 -75
- wandb/util.py +29 -11
- {wandb-0.17.4.dist-info → wandb-0.17.5.dist-info}/METADATA +1 -1
- {wandb-0.17.4.dist-info → wandb-0.17.5.dist-info}/RECORD +24 -24
- {wandb-0.17.4.dist-info → wandb-0.17.5.dist-info}/WHEEL +0 -0
- {wandb-0.17.4.dist-info → wandb-0.17.5.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.4.dist-info → wandb-0.17.5.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/artifacts/artifact.py
CHANGED
@@ -1300,7 +1300,8 @@ class Artifact:
|
|
1300
1300
|
automatic integrity validation. Disabling checksumming will speed up
|
1301
1301
|
artifact creation but reference directories will not be iterated through so the
|
1302
1302
|
objects in the directory will not be saved to the artifact. We recommend
|
1303
|
-
adding reference objects in
|
1303
|
+
setting `checksum=False` when adding reference objects, in which case
|
1304
|
+
a new version will only be created if the reference URI changes.
|
1304
1305
|
max_objects: The maximum number of objects to consider when adding a
|
1305
1306
|
reference that points to directory or bucket store prefix. By default,
|
1306
1307
|
the maximum number of objects allowed for Amazon S3,
|
@@ -1627,16 +1628,19 @@ class Artifact:
|
|
1627
1628
|
) -> FilePathStr:
|
1628
1629
|
"""Download the contents of the artifact to the specified root directory.
|
1629
1630
|
|
1630
|
-
Existing files located within `root` are not modified. Explicitly delete
|
1631
|
-
|
1632
|
-
|
1631
|
+
Existing files located within `root` are not modified. Explicitly delete `root`
|
1632
|
+
before you call `download` if you want the contents of `root` to exactly match
|
1633
|
+
the artifact.
|
1633
1634
|
|
1634
1635
|
Arguments:
|
1635
1636
|
root: The directory W&B stores the artifact's files.
|
1636
1637
|
allow_missing_references: If set to `True`, any invalid reference paths
|
1637
1638
|
will be ignored while downloading referenced files.
|
1638
|
-
skip_cache: If set to `True`, the artifact cache will be skipped when
|
1639
|
-
and W&B will download each file into the default root or
|
1639
|
+
skip_cache: If set to `True`, the artifact cache will be skipped when
|
1640
|
+
downloading and W&B will download each file into the default root or
|
1641
|
+
specified download directory.
|
1642
|
+
path_prefix: If specified, only files with a path that starts with the given
|
1643
|
+
prefix will be downloaded. Uses unix format (forward slashes).
|
1640
1644
|
|
1641
1645
|
Returns:
|
1642
1646
|
The path to the downloaded contents.
|
@@ -1663,23 +1667,6 @@ class Artifact:
|
|
1663
1667
|
path_prefix=path_prefix,
|
1664
1668
|
)
|
1665
1669
|
|
1666
|
-
@classmethod
|
1667
|
-
def path_contains_dir_prefix(cls, path: StrPath, dir_path: StrPath) -> bool:
|
1668
|
-
"""Returns true if `path` contains `dir_path` as a prefix."""
|
1669
|
-
if not dir_path:
|
1670
|
-
return True
|
1671
|
-
path_parts = PurePosixPath(path).parts
|
1672
|
-
dir_parts = PurePosixPath(dir_path).parts
|
1673
|
-
return path_parts[: len(dir_parts)] == dir_parts
|
1674
|
-
|
1675
|
-
@classmethod
|
1676
|
-
def should_download_entry(
|
1677
|
-
cls, entry: ArtifactManifestEntry, prefix: Optional[StrPath]
|
1678
|
-
) -> bool:
|
1679
|
-
if prefix is None:
|
1680
|
-
return True
|
1681
|
-
return cls.path_contains_dir_prefix(entry.path, prefix)
|
1682
|
-
|
1683
1670
|
def _download_using_core(
|
1684
1671
|
self,
|
1685
1672
|
root: str,
|
@@ -1816,7 +1803,7 @@ class Artifact:
|
|
1816
1803
|
# Handled by core
|
1817
1804
|
continue
|
1818
1805
|
entry._download_url = edge["node"]["directUrl"]
|
1819
|
-
if
|
1806
|
+
if (not path_prefix) or entry.path.startswith(str(path_prefix)):
|
1820
1807
|
active_futures.add(executor.submit(download_entry, entry))
|
1821
1808
|
# Wait for download threads to catch up.
|
1822
1809
|
max_backlog = fetch_url_batch_size
|
wandb/sdk/interface/interface.py
CHANGED
@@ -358,7 +358,7 @@ class InterfaceBase:
|
|
358
358
|
proto_extra.value_json = json.dumps(v)
|
359
359
|
return proto_manifest
|
360
360
|
|
361
|
-
def
|
361
|
+
def deliver_link_artifact(
|
362
362
|
self,
|
363
363
|
run: "Run",
|
364
364
|
artifact: "Artifact",
|
@@ -366,8 +366,8 @@ class InterfaceBase:
|
|
366
366
|
aliases: Iterable[str],
|
367
367
|
entity: Optional[str] = None,
|
368
368
|
project: Optional[str] = None,
|
369
|
-
) ->
|
370
|
-
link_artifact = pb.
|
369
|
+
) -> MailboxHandle:
|
370
|
+
link_artifact = pb.LinkArtifactRequest()
|
371
371
|
if artifact.is_draft():
|
372
372
|
link_artifact.client_id = artifact._client_id
|
373
373
|
else:
|
@@ -377,10 +377,12 @@ class InterfaceBase:
|
|
377
377
|
link_artifact.portfolio_project = project or run.project
|
378
378
|
link_artifact.portfolio_aliases.extend(aliases)
|
379
379
|
|
380
|
-
self.
|
380
|
+
return self._deliver_link_artifact(link_artifact)
|
381
381
|
|
382
382
|
@abstractmethod
|
383
|
-
def
|
383
|
+
def _deliver_link_artifact(
|
384
|
+
self, link_artifact: pb.LinkArtifactRequest
|
385
|
+
) -> MailboxHandle:
|
384
386
|
raise NotImplementedError
|
385
387
|
|
386
388
|
@staticmethod
|
@@ -749,6 +751,7 @@ class InterfaceBase:
|
|
749
751
|
self,
|
750
752
|
include_paths: List[List[str]],
|
751
753
|
exclude_paths: List[List[str]],
|
754
|
+
input_schema: Optional[dict],
|
752
755
|
run_config: bool = False,
|
753
756
|
file_path: str = "",
|
754
757
|
):
|
@@ -766,6 +769,8 @@ class InterfaceBase:
|
|
766
769
|
Args:
|
767
770
|
include_paths: paths within config to include as job inputs.
|
768
771
|
exclude_paths: paths within config to exclude as job inputs.
|
772
|
+
input_schema: A JSON Schema describing which attributes will be
|
773
|
+
editable from the Launch drawer.
|
769
774
|
run_config: bool indicating whether wandb.config is the input source.
|
770
775
|
file_path: path to file to include as a job input.
|
771
776
|
"""
|
@@ -788,6 +793,8 @@ class InterfaceBase:
|
|
788
793
|
pb.JobInputSource.ConfigFileSource(path=file_path),
|
789
794
|
)
|
790
795
|
request.input_source.CopyFrom(source)
|
796
|
+
if input_schema:
|
797
|
+
request.input_schema = json_dumps_safer(input_schema)
|
791
798
|
|
792
799
|
return self._publish_job_input(request)
|
793
800
|
|
wandb/sdk/interface/interface_shared.py
CHANGED
@@ -137,6 +137,7 @@ class InterfaceShared(InterfaceBase):
|
|
137
137
|
check_version: Optional[pb.CheckVersionRequest] = None,
|
138
138
|
log_artifact: Optional[pb.LogArtifactRequest] = None,
|
139
139
|
download_artifact: Optional[pb.DownloadArtifactRequest] = None,
|
140
|
+
link_artifact: Optional[pb.LinkArtifactRequest] = None,
|
140
141
|
defer: Optional[pb.DeferRequest] = None,
|
141
142
|
attach: Optional[pb.AttachRequest] = None,
|
142
143
|
server_info: Optional[pb.ServerInfoRequest] = None,
|
@@ -184,6 +185,8 @@ class InterfaceShared(InterfaceBase):
|
|
184
185
|
request.log_artifact.CopyFrom(log_artifact)
|
185
186
|
elif download_artifact:
|
186
187
|
request.download_artifact.CopyFrom(download_artifact)
|
188
|
+
elif link_artifact:
|
189
|
+
request.link_artifact.CopyFrom(link_artifact)
|
187
190
|
elif defer:
|
188
191
|
request.defer.CopyFrom(defer)
|
189
192
|
elif attach:
|
@@ -242,7 +245,6 @@ class InterfaceShared(InterfaceBase):
|
|
242
245
|
request: Optional[pb.Request] = None,
|
243
246
|
telemetry: Optional[tpb.TelemetryRecord] = None,
|
244
247
|
preempting: Optional[pb.RunPreemptingRecord] = None,
|
245
|
-
link_artifact: Optional[pb.LinkArtifactRecord] = None,
|
246
248
|
use_artifact: Optional[pb.UseArtifactRecord] = None,
|
247
249
|
output: Optional[pb.OutputRecord] = None,
|
248
250
|
output_raw: Optional[pb.OutputRawRecord] = None,
|
@@ -282,8 +284,6 @@ class InterfaceShared(InterfaceBase):
|
|
282
284
|
record.metric.CopyFrom(metric)
|
283
285
|
elif preempting:
|
284
286
|
record.preempting.CopyFrom(preempting)
|
285
|
-
elif link_artifact:
|
286
|
-
record.link_artifact.CopyFrom(link_artifact)
|
287
287
|
elif use_artifact:
|
288
288
|
record.use_artifact.CopyFrom(use_artifact)
|
289
289
|
elif output:
|
@@ -393,10 +393,6 @@ class InterfaceShared(InterfaceBase):
|
|
393
393
|
rec = self._make_record(files=files)
|
394
394
|
self._publish(rec)
|
395
395
|
|
396
|
-
def _publish_link_artifact(self, link_artifact: pb.LinkArtifactRecord) -> Any:
|
397
|
-
rec = self._make_record(link_artifact=link_artifact)
|
398
|
-
self._publish(rec)
|
399
|
-
|
400
396
|
def _publish_use_artifact(self, use_artifact: pb.UseArtifactRecord) -> Any:
|
401
397
|
rec = self._make_record(use_artifact=use_artifact)
|
402
398
|
self._publish(rec)
|
@@ -411,6 +407,12 @@ class InterfaceShared(InterfaceBase):
|
|
411
407
|
rec = self._make_request(download_artifact=download_artifact)
|
412
408
|
return self._deliver_record(rec)
|
413
409
|
|
410
|
+
def _deliver_link_artifact(
|
411
|
+
self, link_artifact: pb.LinkArtifactRequest
|
412
|
+
) -> MailboxHandle:
|
413
|
+
rec = self._make_request(link_artifact=link_artifact)
|
414
|
+
return self._deliver_record(rec)
|
415
|
+
|
414
416
|
def _publish_artifact(self, proto_artifact: pb.ArtifactRecord) -> None:
|
415
417
|
rec = self._make_record(artifact=proto_artifact)
|
416
418
|
self._publish(rec)
|
wandb/sdk/internal/handler.py
CHANGED
@@ -230,7 +230,7 @@ class HandleManager:
|
|
230
230
|
def handle_files(self, record: Record) -> None:
|
231
231
|
self._dispatch_record(record)
|
232
232
|
|
233
|
-
def
|
233
|
+
def handle_request_link_artifact(self, record: Record) -> None:
|
234
234
|
self._dispatch_record(record)
|
235
235
|
|
236
236
|
def handle_use_artifact(self, record: Record) -> None:
|
wandb/sdk/internal/internal_api.py
CHANGED
@@ -232,14 +232,14 @@ class Api:
|
|
232
232
|
|
233
233
|
# todo: remove these hacky hacks after settings refactor is complete
|
234
234
|
# keeping this code here to limit scope and so that it is easy to remove later
|
235
|
-
|
235
|
+
self._extra_http_headers = self.settings("_extra_http_headers") or json.loads(
|
236
236
|
self._environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}")
|
237
237
|
)
|
238
|
-
|
238
|
+
self._extra_http_headers.update(_thread_local_api_settings.headers or {})
|
239
239
|
|
240
240
|
auth = None
|
241
241
|
if self.access_token is not None:
|
242
|
-
|
242
|
+
self._extra_http_headers["Authorization"] = f"Bearer {self.access_token}"
|
243
243
|
elif _thread_local_api_settings.cookies is None:
|
244
244
|
auth = ("api", self.api_key or "")
|
245
245
|
|
@@ -253,7 +253,7 @@ class Api:
|
|
253
253
|
"User-Agent": self.user_agent,
|
254
254
|
"X-WANDB-USERNAME": env.get_username(env=self._environ),
|
255
255
|
"X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
|
256
|
-
**
|
256
|
+
**self._extra_http_headers,
|
257
257
|
},
|
258
258
|
use_json=True,
|
259
259
|
# this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
|
wandb/sdk/internal/sender.py
CHANGED
@@ -1473,8 +1473,13 @@ class SendManager:
|
|
1473
1473
|
# tbrecord watching threads are handled by handler.py
|
1474
1474
|
pass
|
1475
1475
|
|
1476
|
-
def
|
1477
|
-
|
1476
|
+
def send_request_link_artifact(self, record: "Record") -> None:
|
1477
|
+
if not (record.control.req_resp or record.control.mailbox_slot):
|
1478
|
+
raise ValueError(
|
1479
|
+
f"Expected either `req_resp` or `mailbox_slot`, got: {record.control!r}"
|
1480
|
+
)
|
1481
|
+
result = proto_util._result_from_record(record)
|
1482
|
+
link = record.request.link_artifact
|
1478
1483
|
client_id = link.client_id
|
1479
1484
|
server_id = link.server_id
|
1480
1485
|
portfolio_name = link.portfolio_name
|
@@ -1490,7 +1495,9 @@ class SendManager:
|
|
1490
1495
|
client_id, server_id, portfolio_name, entity, project, aliases
|
1491
1496
|
)
|
1492
1497
|
except Exception as e:
|
1498
|
+
result.response.log_artifact_response.error_message = f'error linking artifact to "{entity}/{project}/{portfolio_name}"; error: {e}'
|
1493
1499
|
logger.warning("Failed to link artifact to portfolio: %s", e)
|
1500
|
+
self._respond_result(result)
|
1494
1501
|
|
1495
1502
|
def send_use_artifact(self, record: "Record") -> None:
|
1496
1503
|
"""Pretend to send a used artifact.
|
wandb/sdk/launch/builder/kaniko_builder.py
CHANGED
@@ -263,11 +263,17 @@ class KanikoBuilder(AbstractBuilder):
|
|
263
263
|
repo_uri = await self.registry.get_repo_uri()
|
264
264
|
image_uri = repo_uri + ":" + image_tag
|
265
265
|
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
):
|
270
|
-
|
266
|
+
# The DOCKER_CONFIG_SECRET option is mutually exclusive with the
|
267
|
+
# registry classes, so we must skip the check for image existence in
|
268
|
+
# that case.
|
269
|
+
if not launch_project.build_required():
|
270
|
+
if DOCKER_CONFIG_SECRET:
|
271
|
+
wandb.termlog(
|
272
|
+
f"Skipping check for existing image {image_uri} due to custom dockerconfig."
|
273
|
+
)
|
274
|
+
else:
|
275
|
+
if await self.registry.check_image_exists(image_uri):
|
276
|
+
return image_uri
|
271
277
|
|
272
278
|
_logger.info(f"Building image {image_uri}...")
|
273
279
|
_, api_client = await get_kube_context_and_api_client(
|
@@ -286,7 +292,12 @@ class KanikoBuilder(AbstractBuilder):
|
|
286
292
|
wandb.termlog(f"{LOG_PREFIX}Created kaniko job {build_job_name}")
|
287
293
|
|
288
294
|
try:
|
289
|
-
|
295
|
+
# DOCKER_CONFIG_SECRET is a user provided dockerconfigjson. Skip our
|
296
|
+
# dockerconfig handling if it's set.
|
297
|
+
if (
|
298
|
+
isinstance(self.registry, AzureContainerRegistry)
|
299
|
+
and not DOCKER_CONFIG_SECRET
|
300
|
+
):
|
290
301
|
dockerfile_config_map = client.V1ConfigMap(
|
291
302
|
metadata=client.V1ObjectMeta(
|
292
303
|
name=f"docker-config-{build_job_name}"
|
@@ -344,7 +355,10 @@ class KanikoBuilder(AbstractBuilder):
|
|
344
355
|
finally:
|
345
356
|
wandb.termlog(f"{LOG_PREFIX}Cleaning up resources")
|
346
357
|
try:
|
347
|
-
if
|
358
|
+
if (
|
359
|
+
isinstance(self.registry, AzureContainerRegistry)
|
360
|
+
and not DOCKER_CONFIG_SECRET
|
361
|
+
):
|
348
362
|
await core_v1.delete_namespaced_config_map(
|
349
363
|
f"docker-config-{build_job_name}", "wandb"
|
350
364
|
)
|
@@ -498,7 +512,10 @@ class KanikoBuilder(AbstractBuilder):
|
|
498
512
|
"readOnly": True,
|
499
513
|
}
|
500
514
|
)
|
501
|
-
if
|
515
|
+
if (
|
516
|
+
isinstance(self.registry, AzureContainerRegistry)
|
517
|
+
and not DOCKER_CONFIG_SECRET
|
518
|
+
):
|
502
519
|
# Add the docker config map
|
503
520
|
volumes.append(
|
504
521
|
{
|
@@ -533,7 +550,11 @@ class KanikoBuilder(AbstractBuilder):
|
|
533
550
|
# Apply the rest of our defaults
|
534
551
|
pod_labels["wandb"] = "launch"
|
535
552
|
# This annotation is required to enable azure workload identity.
|
536
|
-
if
|
553
|
+
# Don't add this label if using a docker config secret for auth.
|
554
|
+
if (
|
555
|
+
isinstance(self.registry, AzureContainerRegistry)
|
556
|
+
and not DOCKER_CONFIG_SECRET
|
557
|
+
):
|
537
558
|
pod_labels["azure.workload.identity/use"] = "true"
|
538
559
|
pod_spec["restartPolicy"] = pod_spec.get("restartPolicy", "Never")
|
539
560
|
pod_spec["activeDeadlineSeconds"] = pod_spec.get(
|
wandb/sdk/launch/inputs/internal.py
CHANGED
@@ -11,7 +11,7 @@ import os
|
|
11
11
|
import pathlib
|
12
12
|
import shutil
|
13
13
|
import tempfile
|
14
|
-
from typing import List, Optional
|
14
|
+
from typing import Any, Dict, List, Optional
|
15
15
|
|
16
16
|
import wandb
|
17
17
|
import wandb.data_types
|
@@ -62,11 +62,13 @@ class JobInputArguments:
|
|
62
62
|
self,
|
63
63
|
include: Optional[List[str]] = None,
|
64
64
|
exclude: Optional[List[str]] = None,
|
65
|
+
schema: Optional[dict] = None,
|
65
66
|
file_path: Optional[str] = None,
|
66
67
|
run_config: Optional[bool] = None,
|
67
68
|
):
|
68
69
|
self.include = include
|
69
70
|
self.exclude = exclude
|
71
|
+
self.schema = schema
|
70
72
|
self.file_path = file_path
|
71
73
|
self.run_config = run_config
|
72
74
|
|
@@ -121,15 +123,66 @@ def _publish_job_input(
|
|
121
123
|
exclude_paths=[_split_on_unesc_dot(path) for path in input.exclude]
|
122
124
|
if input.exclude
|
123
125
|
else [],
|
126
|
+
input_schema=input.schema,
|
124
127
|
run_config=input.run_config,
|
125
128
|
file_path=input.file_path or "",
|
126
129
|
)
|
127
130
|
|
128
131
|
|
132
|
+
def _replace_refs_and_allofs(schema: dict, defs: dict) -> dict:
|
133
|
+
"""Recursively fix JSON schemas with common issues.
|
134
|
+
|
135
|
+
1. Replaces any instances of $ref with their associated definition in defs
|
136
|
+
2. Removes any "allOf" lists that only have one item, "lifting" the item up
|
137
|
+
See test_internal.py for examples
|
138
|
+
"""
|
139
|
+
ret: Dict[str, Any] = {}
|
140
|
+
if "$ref" in schema:
|
141
|
+
# Reference found, replace it with its definition
|
142
|
+
def_key = schema["$ref"].split("#/$defs/")[1]
|
143
|
+
# Also run recursive replacement in case a ref contains more refs
|
144
|
+
return _replace_refs_and_allofs(defs.pop(def_key), defs)
|
145
|
+
for key, val in schema.items():
|
146
|
+
if isinstance(val, dict):
|
147
|
+
# Step into dicts recursively
|
148
|
+
new_val_dict = _replace_refs_and_allofs(val, defs)
|
149
|
+
ret[key] = new_val_dict
|
150
|
+
elif isinstance(val, list):
|
151
|
+
# Step into each item in the list
|
152
|
+
new_val_list = []
|
153
|
+
for item in val:
|
154
|
+
if isinstance(item, dict):
|
155
|
+
new_val_list.append(_replace_refs_and_allofs(item, defs))
|
156
|
+
else:
|
157
|
+
new_val_list.append(item)
|
158
|
+
# Lift up allOf blocks with only one item
|
159
|
+
if (
|
160
|
+
key == "allOf"
|
161
|
+
and len(new_val_list) == 1
|
162
|
+
and isinstance(new_val_list[0], dict)
|
163
|
+
):
|
164
|
+
ret.update(new_val_list[0])
|
165
|
+
else:
|
166
|
+
ret[key] = new_val_list
|
167
|
+
else:
|
168
|
+
# For anything else (str, int, etc) keep it as-is
|
169
|
+
ret[key] = val
|
170
|
+
return ret
|
171
|
+
|
172
|
+
|
173
|
+
def _convert_pydantic_model_to_jsonschema(model: Any) -> dict:
|
174
|
+
schema = model.model_json_schema()
|
175
|
+
defs = schema.pop("$defs")
|
176
|
+
if not defs:
|
177
|
+
return schema
|
178
|
+
return _replace_refs_and_allofs(schema, defs)
|
179
|
+
|
180
|
+
|
129
181
|
def handle_config_file_input(
|
130
182
|
path: str,
|
131
183
|
include: Optional[List[str]] = None,
|
132
184
|
exclude: Optional[List[str]] = None,
|
185
|
+
schema: Optional[Any] = None,
|
133
186
|
):
|
134
187
|
"""Declare an overridable configuration file for a launch job.
|
135
188
|
|
@@ -151,9 +204,20 @@ def handle_config_file_input(
|
|
151
204
|
path,
|
152
205
|
dest,
|
153
206
|
)
|
207
|
+
# This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
|
208
|
+
# or the BaseModel class itself (e.g. schema=MySchema)
|
209
|
+
if hasattr(schema, "model_json_schema") and callable(
|
210
|
+
schema.model_json_schema # type: ignore
|
211
|
+
):
|
212
|
+
schema = _convert_pydantic_model_to_jsonschema(schema)
|
213
|
+
if schema and not isinstance(schema, dict):
|
214
|
+
raise LaunchError(
|
215
|
+
"schema must be a dict, Pydantic model instance, or Pydantic model class."
|
216
|
+
)
|
154
217
|
arguments = JobInputArguments(
|
155
218
|
include=include,
|
156
219
|
exclude=exclude,
|
220
|
+
schema=schema,
|
157
221
|
file_path=path,
|
158
222
|
run_config=False,
|
159
223
|
)
|
@@ -165,7 +229,9 @@ def handle_config_file_input(
|
|
165
229
|
|
166
230
|
|
167
231
|
def handle_run_config_input(
|
168
|
-
include: Optional[List[str]] = None,
|
232
|
+
include: Optional[List[str]] = None,
|
233
|
+
exclude: Optional[List[str]] = None,
|
234
|
+
schema: Optional[Any] = None,
|
169
235
|
):
|
170
236
|
"""Declare wandb.config as an overridable configuration for a launch job.
|
171
237
|
|
@@ -175,9 +241,20 @@ def handle_run_config_input(
|
|
175
241
|
If there is no active run, the include and exclude paths are staged and sent
|
176
242
|
when a run is created.
|
177
243
|
"""
|
244
|
+
# This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
|
245
|
+
# or the BaseModel class itself (e.g. schema=MySchema)
|
246
|
+
if hasattr(schema, "model_json_schema") and callable(
|
247
|
+
schema.model_json_schema # type: ignore
|
248
|
+
):
|
249
|
+
schema = _convert_pydantic_model_to_jsonschema(schema)
|
250
|
+
if schema and not isinstance(schema, dict):
|
251
|
+
raise LaunchError(
|
252
|
+
"schema must be a dict, Pydantic model instance, or Pydantic model class."
|
253
|
+
)
|
178
254
|
arguments = JobInputArguments(
|
179
255
|
include=include,
|
180
256
|
exclude=exclude,
|
257
|
+
schema=schema,
|
181
258
|
run_config=True,
|
182
259
|
file_path=None,
|
183
260
|
)
|
wandb/sdk/launch/inputs/manage.py
CHANGED
@@ -1,12 +1,13 @@
|
|
1
1
|
"""Functions for declaring overridable configuration for launch jobs."""
|
2
2
|
|
3
|
-
from typing import List, Optional
|
3
|
+
from typing import Any, List, Optional
|
4
4
|
|
5
5
|
|
6
6
|
def manage_config_file(
|
7
7
|
path: str,
|
8
8
|
include: Optional[List[str]] = None,
|
9
9
|
exclude: Optional[List[str]] = None,
|
10
|
+
schema: Optional[Any] = None,
|
10
11
|
):
|
11
12
|
r"""Declare an overridable configuration file for a launch job.
|
12
13
|
|
@@ -43,18 +44,27 @@ def manage_config_file(
|
|
43
44
|
relative and must not contain backwards traversal, i.e. `..`.
|
44
45
|
include (List[str]): A list of keys to include in the configuration file.
|
45
46
|
exclude (List[str]): A list of keys to exclude from the configuration file.
|
47
|
+
schema (dict | Pydantic model): A JSON Schema or Pydantic model describing
|
48
|
+
which attributes will be editable from the Launch drawer.
|
49
|
+
Accepts either an instance of a Pydantic BaseModel class or the BaseModel
|
50
|
+
class itself.
|
46
51
|
|
47
52
|
Raises:
|
48
53
|
LaunchError: If the path is not valid, or if there is no active run.
|
49
54
|
"""
|
55
|
+
# note: schema's Any type is because in the case where a BaseModel class is
|
56
|
+
# provided, its type is a pydantic internal type that we don't want our typing
|
57
|
+
# to depend on. schema's type should be considered
|
58
|
+
# "Optional[dict | <something with a .model_json_schema() method>]"
|
50
59
|
from .internal import handle_config_file_input
|
51
60
|
|
52
|
-
return handle_config_file_input(path, include, exclude)
|
61
|
+
return handle_config_file_input(path, include, exclude, schema)
|
53
62
|
|
54
63
|
|
55
64
|
def manage_wandb_config(
|
56
65
|
include: Optional[List[str]] = None,
|
57
66
|
exclude: Optional[List[str]] = None,
|
67
|
+
schema: Optional[Any] = None,
|
58
68
|
):
|
59
69
|
r"""Declare wandb.config as an overridable configuration for a launch job.
|
60
70
|
|
@@ -86,10 +96,18 @@ def manage_wandb_config(
|
|
86
96
|
Args:
|
87
97
|
include (List[str]): A list of subtrees to include in the configuration.
|
88
98
|
exclude (List[str]): A list of subtrees to exclude from the configuration.
|
99
|
+
schema (dict | Pydantic model): A JSON Schema or Pydantic model describing
|
100
|
+
which attributes will be editable from the Launch drawer.
|
101
|
+
Accepts either an instance of a Pydantic BaseModel class or the BaseModel
|
102
|
+
class itself.
|
89
103
|
|
90
104
|
Raises:
|
91
105
|
LaunchError: If there is no active run.
|
92
106
|
"""
|
107
|
+
# note: schema's Any type is because in the case where a BaseModel class is
|
108
|
+
# provided, its type is a pydantic internal type that we don't want our typing
|
109
|
+
# to depend on. schema's type should be considered
|
110
|
+
# "Optional[dict | <something with a .model_json_schema() method>]"
|
93
111
|
from .internal import handle_run_config_input
|
94
112
|
|
95
|
-
handle_run_config_input(include, exclude)
|
113
|
+
handle_run_config_input(include, exclude, schema)
|
wandb/sdk/lib/tracelog.py
CHANGED
@@ -45,8 +45,8 @@ logger = logging.getLogger(__name__)
|
|
45
45
|
ANNOTATE_QUEUE_NAME = "_DEBUGLOG_QUEUE_NAME"
|
46
46
|
|
47
47
|
# capture stdout and stderr before anyone messes with them
|
48
|
-
stdout_write = sys.__stdout__.write
|
49
|
-
stderr_write = sys.__stderr__.write
|
48
|
+
stdout_write = sys.__stdout__.write # type: ignore
|
49
|
+
stderr_write = sys.__stderr__.write # type: ignore
|
50
50
|
|
51
51
|
|
52
52
|
def _log(
|
wandb/sdk/wandb_manager.py
CHANGED
@@ -114,17 +114,21 @@ class _Manager:
|
|
114
114
|
|
115
115
|
try:
|
116
116
|
svc_iface._svc_connect(port=port)
|
117
|
+
|
117
118
|
except ConnectionRefusedError as e:
|
118
119
|
if not psutil.pid_exists(self._token.pid):
|
119
120
|
message = (
|
120
|
-
"Connection to wandb service failed
|
121
|
-
"
|
121
|
+
"Connection to wandb service failed"
|
122
|
+
" because the process is not available."
|
122
123
|
)
|
123
124
|
else:
|
124
|
-
message =
|
125
|
-
raise ManagerConnectionRefusedError(message)
|
125
|
+
message = "Connection to wandb service failed."
|
126
|
+
raise ManagerConnectionRefusedError(message) from e
|
127
|
+
|
126
128
|
except Exception as e:
|
127
|
-
raise ManagerConnectionError(
|
129
|
+
raise ManagerConnectionError(
|
130
|
+
"Connection to wandb service failed.",
|
131
|
+
) from e
|
128
132
|
|
129
133
|
def __init__(self, settings: "Settings") -> None:
|
130
134
|
"""Connects to the internal service, starting it if necessary."""
|