wandb 0.17.4__py3-none-any.whl → 0.17.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/filesync/upload_job.py +1 -1
- wandb/proto/v3/wandb_internal_pb2.py +339 -328
- wandb/proto/v4/wandb_internal_pb2.py +326 -323
- wandb/proto/v5/wandb_internal_pb2.py +326 -323
- wandb/sdk/artifacts/artifact.py +11 -24
- wandb/sdk/interface/interface.py +12 -5
- wandb/sdk/interface/interface_shared.py +9 -7
- wandb/sdk/internal/handler.py +1 -1
- wandb/sdk/internal/internal_api.py +4 -4
- wandb/sdk/internal/sender.py +9 -2
- wandb/sdk/launch/builder/kaniko_builder.py +30 -9
- wandb/sdk/launch/inputs/internal.py +79 -2
- wandb/sdk/launch/inputs/manage.py +21 -3
- wandb/sdk/lib/tracelog.py +2 -2
- wandb/sdk/wandb_manager.py +9 -5
- wandb/sdk/wandb_run.py +100 -75
- wandb/util.py +29 -11
- {wandb-0.17.4.dist-info → wandb-0.17.5.dist-info}/METADATA +1 -1
- {wandb-0.17.4.dist-info → wandb-0.17.5.dist-info}/RECORD +23 -23
- {wandb-0.17.4.dist-info → wandb-0.17.5.dist-info}/WHEEL +0 -0
- {wandb-0.17.4.dist-info → wandb-0.17.5.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.4.dist-info → wandb-0.17.5.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/artifacts/artifact.py
CHANGED
@@ -1300,7 +1300,8 @@ class Artifact:
|
|
1300
1300
|
automatic integrity validation. Disabling checksumming will speed up
|
1301
1301
|
artifact creation but reference directories will not iterated through so the
|
1302
1302
|
objects in the directory will not be saved to the artifact. We recommend
|
1303
|
-
adding reference objects in
|
1303
|
+
setting `checksum=False` when adding reference objects, in which case
|
1304
|
+
a new version will only be created if the reference URI changes.
|
1304
1305
|
max_objects: The maximum number of objects to consider when adding a
|
1305
1306
|
reference that points to directory or bucket store prefix. By default,
|
1306
1307
|
the maximum number of objects allowed for Amazon S3,
|
@@ -1627,16 +1628,19 @@ class Artifact:
|
|
1627
1628
|
) -> FilePathStr:
|
1628
1629
|
"""Download the contents of the artifact to the specified root directory.
|
1629
1630
|
|
1630
|
-
Existing files located within `root` are not modified. Explicitly delete
|
1631
|
-
|
1632
|
-
|
1631
|
+
Existing files located within `root` are not modified. Explicitly delete `root`
|
1632
|
+
before you call `download` if you want the contents of `root` to exactly match
|
1633
|
+
the artifact.
|
1633
1634
|
|
1634
1635
|
Arguments:
|
1635
1636
|
root: The directory W&B stores the artifact's files.
|
1636
1637
|
allow_missing_references: If set to `True`, any invalid reference paths
|
1637
1638
|
will be ignored while downloading referenced files.
|
1638
|
-
skip_cache: If set to `True`, the artifact cache will be skipped when
|
1639
|
-
and W&B will download each file into the default root or
|
1639
|
+
skip_cache: If set to `True`, the artifact cache will be skipped when
|
1640
|
+
downloading and W&B will download each file into the default root or
|
1641
|
+
specified download directory.
|
1642
|
+
path_prefix: If specified, only files with a path that starts with the given
|
1643
|
+
prefix will be downloaded. Uses unix format (forward slashes).
|
1640
1644
|
|
1641
1645
|
Returns:
|
1642
1646
|
The path to the downloaded contents.
|
@@ -1663,23 +1667,6 @@ class Artifact:
|
|
1663
1667
|
path_prefix=path_prefix,
|
1664
1668
|
)
|
1665
1669
|
|
1666
|
-
@classmethod
|
1667
|
-
def path_contains_dir_prefix(cls, path: StrPath, dir_path: StrPath) -> bool:
|
1668
|
-
"""Returns true if `path` contains `dir_path` as a prefix."""
|
1669
|
-
if not dir_path:
|
1670
|
-
return True
|
1671
|
-
path_parts = PurePosixPath(path).parts
|
1672
|
-
dir_parts = PurePosixPath(dir_path).parts
|
1673
|
-
return path_parts[: len(dir_parts)] == dir_parts
|
1674
|
-
|
1675
|
-
@classmethod
|
1676
|
-
def should_download_entry(
|
1677
|
-
cls, entry: ArtifactManifestEntry, prefix: Optional[StrPath]
|
1678
|
-
) -> bool:
|
1679
|
-
if prefix is None:
|
1680
|
-
return True
|
1681
|
-
return cls.path_contains_dir_prefix(entry.path, prefix)
|
1682
|
-
|
1683
1670
|
def _download_using_core(
|
1684
1671
|
self,
|
1685
1672
|
root: str,
|
@@ -1816,7 +1803,7 @@ class Artifact:
|
|
1816
1803
|
# Handled by core
|
1817
1804
|
continue
|
1818
1805
|
entry._download_url = edge["node"]["directUrl"]
|
1819
|
-
if
|
1806
|
+
if (not path_prefix) or entry.path.startswith(str(path_prefix)):
|
1820
1807
|
active_futures.add(executor.submit(download_entry, entry))
|
1821
1808
|
# Wait for download threads to catch up.
|
1822
1809
|
max_backlog = fetch_url_batch_size
|
wandb/sdk/interface/interface.py
CHANGED
@@ -358,7 +358,7 @@ class InterfaceBase:
|
|
358
358
|
proto_extra.value_json = json.dumps(v)
|
359
359
|
return proto_manifest
|
360
360
|
|
361
|
-
def
|
361
|
+
def deliver_link_artifact(
|
362
362
|
self,
|
363
363
|
run: "Run",
|
364
364
|
artifact: "Artifact",
|
@@ -366,8 +366,8 @@ class InterfaceBase:
|
|
366
366
|
aliases: Iterable[str],
|
367
367
|
entity: Optional[str] = None,
|
368
368
|
project: Optional[str] = None,
|
369
|
-
) ->
|
370
|
-
link_artifact = pb.
|
369
|
+
) -> MailboxHandle:
|
370
|
+
link_artifact = pb.LinkArtifactRequest()
|
371
371
|
if artifact.is_draft():
|
372
372
|
link_artifact.client_id = artifact._client_id
|
373
373
|
else:
|
@@ -377,10 +377,12 @@ class InterfaceBase:
|
|
377
377
|
link_artifact.portfolio_project = project or run.project
|
378
378
|
link_artifact.portfolio_aliases.extend(aliases)
|
379
379
|
|
380
|
-
self.
|
380
|
+
return self._deliver_link_artifact(link_artifact)
|
381
381
|
|
382
382
|
@abstractmethod
|
383
|
-
def
|
383
|
+
def _deliver_link_artifact(
|
384
|
+
self, link_artifact: pb.LinkArtifactRequest
|
385
|
+
) -> MailboxHandle:
|
384
386
|
raise NotImplementedError
|
385
387
|
|
386
388
|
@staticmethod
|
@@ -749,6 +751,7 @@ class InterfaceBase:
|
|
749
751
|
self,
|
750
752
|
include_paths: List[List[str]],
|
751
753
|
exclude_paths: List[List[str]],
|
754
|
+
input_schema: Optional[dict],
|
752
755
|
run_config: bool = False,
|
753
756
|
file_path: str = "",
|
754
757
|
):
|
@@ -766,6 +769,8 @@ class InterfaceBase:
|
|
766
769
|
Args:
|
767
770
|
include_paths: paths within config to include as job inputs.
|
768
771
|
exclude_paths: paths within config to exclude as job inputs.
|
772
|
+
input_schema: A JSON Schema describing which attributes will be
|
773
|
+
editable from the Launch drawer.
|
769
774
|
run_config: bool indicating whether wandb.config is the input source.
|
770
775
|
file_path: path to file to include as a job input.
|
771
776
|
"""
|
@@ -788,6 +793,8 @@ class InterfaceBase:
|
|
788
793
|
pb.JobInputSource.ConfigFileSource(path=file_path),
|
789
794
|
)
|
790
795
|
request.input_source.CopyFrom(source)
|
796
|
+
if input_schema:
|
797
|
+
request.input_schema = json_dumps_safer(input_schema)
|
791
798
|
|
792
799
|
return self._publish_job_input(request)
|
793
800
|
|
@@ -137,6 +137,7 @@ class InterfaceShared(InterfaceBase):
|
|
137
137
|
check_version: Optional[pb.CheckVersionRequest] = None,
|
138
138
|
log_artifact: Optional[pb.LogArtifactRequest] = None,
|
139
139
|
download_artifact: Optional[pb.DownloadArtifactRequest] = None,
|
140
|
+
link_artifact: Optional[pb.LinkArtifactRequest] = None,
|
140
141
|
defer: Optional[pb.DeferRequest] = None,
|
141
142
|
attach: Optional[pb.AttachRequest] = None,
|
142
143
|
server_info: Optional[pb.ServerInfoRequest] = None,
|
@@ -184,6 +185,8 @@ class InterfaceShared(InterfaceBase):
|
|
184
185
|
request.log_artifact.CopyFrom(log_artifact)
|
185
186
|
elif download_artifact:
|
186
187
|
request.download_artifact.CopyFrom(download_artifact)
|
188
|
+
elif link_artifact:
|
189
|
+
request.link_artifact.CopyFrom(link_artifact)
|
187
190
|
elif defer:
|
188
191
|
request.defer.CopyFrom(defer)
|
189
192
|
elif attach:
|
@@ -242,7 +245,6 @@ class InterfaceShared(InterfaceBase):
|
|
242
245
|
request: Optional[pb.Request] = None,
|
243
246
|
telemetry: Optional[tpb.TelemetryRecord] = None,
|
244
247
|
preempting: Optional[pb.RunPreemptingRecord] = None,
|
245
|
-
link_artifact: Optional[pb.LinkArtifactRecord] = None,
|
246
248
|
use_artifact: Optional[pb.UseArtifactRecord] = None,
|
247
249
|
output: Optional[pb.OutputRecord] = None,
|
248
250
|
output_raw: Optional[pb.OutputRawRecord] = None,
|
@@ -282,8 +284,6 @@ class InterfaceShared(InterfaceBase):
|
|
282
284
|
record.metric.CopyFrom(metric)
|
283
285
|
elif preempting:
|
284
286
|
record.preempting.CopyFrom(preempting)
|
285
|
-
elif link_artifact:
|
286
|
-
record.link_artifact.CopyFrom(link_artifact)
|
287
287
|
elif use_artifact:
|
288
288
|
record.use_artifact.CopyFrom(use_artifact)
|
289
289
|
elif output:
|
@@ -393,10 +393,6 @@ class InterfaceShared(InterfaceBase):
|
|
393
393
|
rec = self._make_record(files=files)
|
394
394
|
self._publish(rec)
|
395
395
|
|
396
|
-
def _publish_link_artifact(self, link_artifact: pb.LinkArtifactRecord) -> Any:
|
397
|
-
rec = self._make_record(link_artifact=link_artifact)
|
398
|
-
self._publish(rec)
|
399
|
-
|
400
396
|
def _publish_use_artifact(self, use_artifact: pb.UseArtifactRecord) -> Any:
|
401
397
|
rec = self._make_record(use_artifact=use_artifact)
|
402
398
|
self._publish(rec)
|
@@ -411,6 +407,12 @@ class InterfaceShared(InterfaceBase):
|
|
411
407
|
rec = self._make_request(download_artifact=download_artifact)
|
412
408
|
return self._deliver_record(rec)
|
413
409
|
|
410
|
+
def _deliver_link_artifact(
|
411
|
+
self, link_artifact: pb.LinkArtifactRequest
|
412
|
+
) -> MailboxHandle:
|
413
|
+
rec = self._make_request(link_artifact=link_artifact)
|
414
|
+
return self._deliver_record(rec)
|
415
|
+
|
414
416
|
def _publish_artifact(self, proto_artifact: pb.ArtifactRecord) -> None:
|
415
417
|
rec = self._make_record(artifact=proto_artifact)
|
416
418
|
self._publish(rec)
|
wandb/sdk/internal/handler.py
CHANGED
@@ -230,7 +230,7 @@ class HandleManager:
|
|
230
230
|
def handle_files(self, record: Record) -> None:
|
231
231
|
self._dispatch_record(record)
|
232
232
|
|
233
|
-
def
|
233
|
+
def handle_request_link_artifact(self, record: Record) -> None:
|
234
234
|
self._dispatch_record(record)
|
235
235
|
|
236
236
|
def handle_use_artifact(self, record: Record) -> None:
|
@@ -232,14 +232,14 @@ class Api:
|
|
232
232
|
|
233
233
|
# todo: remove these hacky hacks after settings refactor is complete
|
234
234
|
# keeping this code here to limit scope and so that it is easy to remove later
|
235
|
-
|
235
|
+
self._extra_http_headers = self.settings("_extra_http_headers") or json.loads(
|
236
236
|
self._environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}")
|
237
237
|
)
|
238
|
-
|
238
|
+
self._extra_http_headers.update(_thread_local_api_settings.headers or {})
|
239
239
|
|
240
240
|
auth = None
|
241
241
|
if self.access_token is not None:
|
242
|
-
|
242
|
+
self._extra_http_headers["Authorization"] = f"Bearer {self.access_token}"
|
243
243
|
elif _thread_local_api_settings.cookies is None:
|
244
244
|
auth = ("api", self.api_key or "")
|
245
245
|
|
@@ -253,7 +253,7 @@ class Api:
|
|
253
253
|
"User-Agent": self.user_agent,
|
254
254
|
"X-WANDB-USERNAME": env.get_username(env=self._environ),
|
255
255
|
"X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
|
256
|
-
**
|
256
|
+
**self._extra_http_headers,
|
257
257
|
},
|
258
258
|
use_json=True,
|
259
259
|
# this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
|
wandb/sdk/internal/sender.py
CHANGED
@@ -1473,8 +1473,13 @@ class SendManager:
|
|
1473
1473
|
# tbrecord watching threads are handled by handler.py
|
1474
1474
|
pass
|
1475
1475
|
|
1476
|
-
def
|
1477
|
-
|
1476
|
+
def send_request_link_artifact(self, record: "Record") -> None:
|
1477
|
+
if not (record.control.req_resp or record.control.mailbox_slot):
|
1478
|
+
raise ValueError(
|
1479
|
+
f"Expected either `req_resp` or `mailbox_slot`, got: {record.control!r}"
|
1480
|
+
)
|
1481
|
+
result = proto_util._result_from_record(record)
|
1482
|
+
link = record.request.link_artifact
|
1478
1483
|
client_id = link.client_id
|
1479
1484
|
server_id = link.server_id
|
1480
1485
|
portfolio_name = link.portfolio_name
|
@@ -1490,7 +1495,9 @@ class SendManager:
|
|
1490
1495
|
client_id, server_id, portfolio_name, entity, project, aliases
|
1491
1496
|
)
|
1492
1497
|
except Exception as e:
|
1498
|
+
result.response.log_artifact_response.error_message = f'error linking artifact to "{entity}/{project}/{portfolio_name}"; error: {e}'
|
1493
1499
|
logger.warning("Failed to link artifact to portfolio: %s", e)
|
1500
|
+
self._respond_result(result)
|
1494
1501
|
|
1495
1502
|
def send_use_artifact(self, record: "Record") -> None:
|
1496
1503
|
"""Pretend to send a used artifact.
|
@@ -263,11 +263,17 @@ class KanikoBuilder(AbstractBuilder):
|
|
263
263
|
repo_uri = await self.registry.get_repo_uri()
|
264
264
|
image_uri = repo_uri + ":" + image_tag
|
265
265
|
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
):
|
270
|
-
|
266
|
+
# The DOCKER_CONFIG_SECRET option is mutually exclusive with the
|
267
|
+
# registry classes, so we must skip the check for image existence in
|
268
|
+
# that case.
|
269
|
+
if not launch_project.build_required():
|
270
|
+
if DOCKER_CONFIG_SECRET:
|
271
|
+
wandb.termlog(
|
272
|
+
f"Skipping check for existing image {image_uri} due to custom dockerconfig."
|
273
|
+
)
|
274
|
+
else:
|
275
|
+
if await self.registry.check_image_exists(image_uri):
|
276
|
+
return image_uri
|
271
277
|
|
272
278
|
_logger.info(f"Building image {image_uri}...")
|
273
279
|
_, api_client = await get_kube_context_and_api_client(
|
@@ -286,7 +292,12 @@ class KanikoBuilder(AbstractBuilder):
|
|
286
292
|
wandb.termlog(f"{LOG_PREFIX}Created kaniko job {build_job_name}")
|
287
293
|
|
288
294
|
try:
|
289
|
-
|
295
|
+
# DOCKER_CONFIG_SECRET is a user provided dockerconfigjson. Skip our
|
296
|
+
# dockerconfig handling if it's set.
|
297
|
+
if (
|
298
|
+
isinstance(self.registry, AzureContainerRegistry)
|
299
|
+
and not DOCKER_CONFIG_SECRET
|
300
|
+
):
|
290
301
|
dockerfile_config_map = client.V1ConfigMap(
|
291
302
|
metadata=client.V1ObjectMeta(
|
292
303
|
name=f"docker-config-{build_job_name}"
|
@@ -344,7 +355,10 @@ class KanikoBuilder(AbstractBuilder):
|
|
344
355
|
finally:
|
345
356
|
wandb.termlog(f"{LOG_PREFIX}Cleaning up resources")
|
346
357
|
try:
|
347
|
-
if
|
358
|
+
if (
|
359
|
+
isinstance(self.registry, AzureContainerRegistry)
|
360
|
+
and not DOCKER_CONFIG_SECRET
|
361
|
+
):
|
348
362
|
await core_v1.delete_namespaced_config_map(
|
349
363
|
f"docker-config-{build_job_name}", "wandb"
|
350
364
|
)
|
@@ -498,7 +512,10 @@ class KanikoBuilder(AbstractBuilder):
|
|
498
512
|
"readOnly": True,
|
499
513
|
}
|
500
514
|
)
|
501
|
-
if
|
515
|
+
if (
|
516
|
+
isinstance(self.registry, AzureContainerRegistry)
|
517
|
+
and not DOCKER_CONFIG_SECRET
|
518
|
+
):
|
502
519
|
# Add the docker config map
|
503
520
|
volumes.append(
|
504
521
|
{
|
@@ -533,7 +550,11 @@ class KanikoBuilder(AbstractBuilder):
|
|
533
550
|
# Apply the rest of our defaults
|
534
551
|
pod_labels["wandb"] = "launch"
|
535
552
|
# This annotation is required to enable azure workload identity.
|
536
|
-
if
|
553
|
+
# Don't add this label if using a docker config secret for auth.
|
554
|
+
if (
|
555
|
+
isinstance(self.registry, AzureContainerRegistry)
|
556
|
+
and not DOCKER_CONFIG_SECRET
|
557
|
+
):
|
537
558
|
pod_labels["azure.workload.identity/use"] = "true"
|
538
559
|
pod_spec["restartPolicy"] = pod_spec.get("restartPolicy", "Never")
|
539
560
|
pod_spec["activeDeadlineSeconds"] = pod_spec.get(
|
@@ -11,7 +11,7 @@ import os
|
|
11
11
|
import pathlib
|
12
12
|
import shutil
|
13
13
|
import tempfile
|
14
|
-
from typing import List, Optional
|
14
|
+
from typing import Any, Dict, List, Optional
|
15
15
|
|
16
16
|
import wandb
|
17
17
|
import wandb.data_types
|
@@ -62,11 +62,13 @@ class JobInputArguments:
|
|
62
62
|
self,
|
63
63
|
include: Optional[List[str]] = None,
|
64
64
|
exclude: Optional[List[str]] = None,
|
65
|
+
schema: Optional[dict] = None,
|
65
66
|
file_path: Optional[str] = None,
|
66
67
|
run_config: Optional[bool] = None,
|
67
68
|
):
|
68
69
|
self.include = include
|
69
70
|
self.exclude = exclude
|
71
|
+
self.schema = schema
|
70
72
|
self.file_path = file_path
|
71
73
|
self.run_config = run_config
|
72
74
|
|
@@ -121,15 +123,66 @@ def _publish_job_input(
|
|
121
123
|
exclude_paths=[_split_on_unesc_dot(path) for path in input.exclude]
|
122
124
|
if input.exclude
|
123
125
|
else [],
|
126
|
+
input_schema=input.schema,
|
124
127
|
run_config=input.run_config,
|
125
128
|
file_path=input.file_path or "",
|
126
129
|
)
|
127
130
|
|
128
131
|
|
132
|
+
def _replace_refs_and_allofs(schema: dict, defs: dict) -> dict:
|
133
|
+
"""Recursively fix JSON schemas with common issues.
|
134
|
+
|
135
|
+
1. Replaces any instances of $ref with their associated definition in defs
|
136
|
+
2. Removes any "allOf" lists that only have one item, "lifting" the item up
|
137
|
+
See test_internal.py for examples
|
138
|
+
"""
|
139
|
+
ret: Dict[str, Any] = {}
|
140
|
+
if "$ref" in schema:
|
141
|
+
# Reference found, replace it with its definition
|
142
|
+
def_key = schema["$ref"].split("#/$defs/")[1]
|
143
|
+
# Also run recursive replacement in case a ref contains more refs
|
144
|
+
return _replace_refs_and_allofs(defs.pop(def_key), defs)
|
145
|
+
for key, val in schema.items():
|
146
|
+
if isinstance(val, dict):
|
147
|
+
# Step into dicts recursively
|
148
|
+
new_val_dict = _replace_refs_and_allofs(val, defs)
|
149
|
+
ret[key] = new_val_dict
|
150
|
+
elif isinstance(val, list):
|
151
|
+
# Step into each item in the list
|
152
|
+
new_val_list = []
|
153
|
+
for item in val:
|
154
|
+
if isinstance(item, dict):
|
155
|
+
new_val_list.append(_replace_refs_and_allofs(item, defs))
|
156
|
+
else:
|
157
|
+
new_val_list.append(item)
|
158
|
+
# Lift up allOf blocks with only one item
|
159
|
+
if (
|
160
|
+
key == "allOf"
|
161
|
+
and len(new_val_list) == 1
|
162
|
+
and isinstance(new_val_list[0], dict)
|
163
|
+
):
|
164
|
+
ret.update(new_val_list[0])
|
165
|
+
else:
|
166
|
+
ret[key] = new_val_list
|
167
|
+
else:
|
168
|
+
# For anything else (str, int, etc) keep it as-is
|
169
|
+
ret[key] = val
|
170
|
+
return ret
|
171
|
+
|
172
|
+
|
173
|
+
def _convert_pydantic_model_to_jsonschema(model: Any) -> dict:
|
174
|
+
schema = model.model_json_schema()
|
175
|
+
defs = schema.pop("$defs")
|
176
|
+
if not defs:
|
177
|
+
return schema
|
178
|
+
return _replace_refs_and_allofs(schema, defs)
|
179
|
+
|
180
|
+
|
129
181
|
def handle_config_file_input(
|
130
182
|
path: str,
|
131
183
|
include: Optional[List[str]] = None,
|
132
184
|
exclude: Optional[List[str]] = None,
|
185
|
+
schema: Optional[Any] = None,
|
133
186
|
):
|
134
187
|
"""Declare an overridable configuration file for a launch job.
|
135
188
|
|
@@ -151,9 +204,20 @@ def handle_config_file_input(
|
|
151
204
|
path,
|
152
205
|
dest,
|
153
206
|
)
|
207
|
+
# This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
|
208
|
+
# or the BaseModel class itself (e.g. schema=MySchema)
|
209
|
+
if hasattr(schema, "model_json_schema") and callable(
|
210
|
+
schema.model_json_schema # type: ignore
|
211
|
+
):
|
212
|
+
schema = _convert_pydantic_model_to_jsonschema(schema)
|
213
|
+
if schema and not isinstance(schema, dict):
|
214
|
+
raise LaunchError(
|
215
|
+
"schema must be a dict, Pydantic model instance, or Pydantic model class."
|
216
|
+
)
|
154
217
|
arguments = JobInputArguments(
|
155
218
|
include=include,
|
156
219
|
exclude=exclude,
|
220
|
+
schema=schema,
|
157
221
|
file_path=path,
|
158
222
|
run_config=False,
|
159
223
|
)
|
@@ -165,7 +229,9 @@ def handle_config_file_input(
|
|
165
229
|
|
166
230
|
|
167
231
|
def handle_run_config_input(
|
168
|
-
include: Optional[List[str]] = None,
|
232
|
+
include: Optional[List[str]] = None,
|
233
|
+
exclude: Optional[List[str]] = None,
|
234
|
+
schema: Optional[Any] = None,
|
169
235
|
):
|
170
236
|
"""Declare wandb.config as an overridable configuration for a launch job.
|
171
237
|
|
@@ -175,9 +241,20 @@ def handle_run_config_input(
|
|
175
241
|
If there is no active run, the include and exclude paths are staged and sent
|
176
242
|
when a run is created.
|
177
243
|
"""
|
244
|
+
# This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
|
245
|
+
# or the BaseModel class itself (e.g. schema=MySchema)
|
246
|
+
if hasattr(schema, "model_json_schema") and callable(
|
247
|
+
schema.model_json_schema # type: ignore
|
248
|
+
):
|
249
|
+
schema = _convert_pydantic_model_to_jsonschema(schema)
|
250
|
+
if schema and not isinstance(schema, dict):
|
251
|
+
raise LaunchError(
|
252
|
+
"schema must be a dict, Pydantic model instance, or Pydantic model class."
|
253
|
+
)
|
178
254
|
arguments = JobInputArguments(
|
179
255
|
include=include,
|
180
256
|
exclude=exclude,
|
257
|
+
schema=schema,
|
181
258
|
run_config=True,
|
182
259
|
file_path=None,
|
183
260
|
)
|
@@ -1,12 +1,13 @@
|
|
1
1
|
"""Functions for declaring overridable configuration for launch jobs."""
|
2
2
|
|
3
|
-
from typing import List, Optional
|
3
|
+
from typing import Any, List, Optional
|
4
4
|
|
5
5
|
|
6
6
|
def manage_config_file(
|
7
7
|
path: str,
|
8
8
|
include: Optional[List[str]] = None,
|
9
9
|
exclude: Optional[List[str]] = None,
|
10
|
+
schema: Optional[Any] = None,
|
10
11
|
):
|
11
12
|
r"""Declare an overridable configuration file for a launch job.
|
12
13
|
|
@@ -43,18 +44,27 @@ def manage_config_file(
|
|
43
44
|
relative and must not contain backwards traversal, i.e. `..`.
|
44
45
|
include (List[str]): A list of keys to include in the configuration file.
|
45
46
|
exclude (List[str]): A list of keys to exclude from the configuration file.
|
47
|
+
schema (dict | Pydantic model): A JSON Schema or Pydantic model describing
|
48
|
+
describing which attributes will be editable from the Launch drawer.
|
49
|
+
Accepts both an instance of a Pydantic BaseModel class or the BaseModel
|
50
|
+
class itself.
|
46
51
|
|
47
52
|
Raises:
|
48
53
|
LaunchError: If the path is not valid, or if there is no active run.
|
49
54
|
"""
|
55
|
+
# note: schema's Any type is because in the case where a BaseModel class is
|
56
|
+
# provided, its type is a pydantic internal type that we don't want our typing
|
57
|
+
# to depend on. schema's type should be considered
|
58
|
+
# "Optional[dict | <something with a .model_json_schema() method>]"
|
50
59
|
from .internal import handle_config_file_input
|
51
60
|
|
52
|
-
return handle_config_file_input(path, include, exclude)
|
61
|
+
return handle_config_file_input(path, include, exclude, schema)
|
53
62
|
|
54
63
|
|
55
64
|
def manage_wandb_config(
|
56
65
|
include: Optional[List[str]] = None,
|
57
66
|
exclude: Optional[List[str]] = None,
|
67
|
+
schema: Optional[Any] = None,
|
58
68
|
):
|
59
69
|
r"""Declare wandb.config as an overridable configuration for a launch job.
|
60
70
|
|
@@ -86,10 +96,18 @@ def manage_wandb_config(
|
|
86
96
|
Args:
|
87
97
|
include (List[str]): A list of subtrees to include in the configuration.
|
88
98
|
exclude (List[str]): A list of subtrees to exclude from the configuration.
|
99
|
+
schema (dict | Pydantic model): A JSON Schema or Pydantic model describing
|
100
|
+
describing which attributes will be editable from the Launch drawer.
|
101
|
+
Accepts both an instance of a Pydantic BaseModel class or the BaseModel
|
102
|
+
class itself.
|
89
103
|
|
90
104
|
Raises:
|
91
105
|
LaunchError: If there is no active run.
|
92
106
|
"""
|
107
|
+
# note: schema's Any type is because in the case where a BaseModel class is
|
108
|
+
# provided, its type is a pydantic internal type that we don't want our typing
|
109
|
+
# to depend on. schema's type should be considered
|
110
|
+
# "Optional[dict | <something with a .model_json_schema() method>]"
|
93
111
|
from .internal import handle_run_config_input
|
94
112
|
|
95
|
-
handle_run_config_input(include, exclude)
|
113
|
+
handle_run_config_input(include, exclude, schema)
|
wandb/sdk/lib/tracelog.py
CHANGED
@@ -45,8 +45,8 @@ logger = logging.getLogger(__name__)
|
|
45
45
|
ANNOTATE_QUEUE_NAME = "_DEBUGLOG_QUEUE_NAME"
|
46
46
|
|
47
47
|
# capture stdout and stderr before anyone messes with them
|
48
|
-
stdout_write = sys.__stdout__.write
|
49
|
-
stderr_write = sys.__stderr__.write
|
48
|
+
stdout_write = sys.__stdout__.write # type: ignore
|
49
|
+
stderr_write = sys.__stderr__.write # type: ignore
|
50
50
|
|
51
51
|
|
52
52
|
def _log(
|
wandb/sdk/wandb_manager.py
CHANGED
@@ -114,17 +114,21 @@ class _Manager:
|
|
114
114
|
|
115
115
|
try:
|
116
116
|
svc_iface._svc_connect(port=port)
|
117
|
+
|
117
118
|
except ConnectionRefusedError as e:
|
118
119
|
if not psutil.pid_exists(self._token.pid):
|
119
120
|
message = (
|
120
|
-
"Connection to wandb service failed
|
121
|
-
"
|
121
|
+
"Connection to wandb service failed"
|
122
|
+
" because the process is not available."
|
122
123
|
)
|
123
124
|
else:
|
124
|
-
message =
|
125
|
-
raise ManagerConnectionRefusedError(message)
|
125
|
+
message = "Connection to wandb service failed."
|
126
|
+
raise ManagerConnectionRefusedError(message) from e
|
127
|
+
|
126
128
|
except Exception as e:
|
127
|
-
raise ManagerConnectionError(
|
129
|
+
raise ManagerConnectionError(
|
130
|
+
"Connection to wandb service failed.",
|
131
|
+
) from e
|
128
132
|
|
129
133
|
def __init__(self, settings: "Settings") -> None:
|
130
134
|
"""Connects to the internal service, starting it if necessary."""
|