wandb 0.20.2rc20250616__py3-none-macosx_11_0_arm64.whl → 0.21.1__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +16 -14
- wandb/__init__.pyi +450 -472
- wandb/agents/pyagent.py +41 -12
- wandb/analytics/sentry.py +7 -2
- wandb/apis/importers/mlflow.py +1 -1
- wandb/apis/internal.py +3 -0
- wandb/apis/paginator.py +17 -4
- wandb/apis/public/__init__.py +1 -1
- wandb/apis/public/api.py +606 -359
- wandb/apis/public/artifacts.py +214 -16
- wandb/apis/public/automations.py +19 -3
- wandb/apis/public/files.py +177 -38
- wandb/apis/public/history.py +67 -15
- wandb/apis/public/integrations.py +25 -2
- wandb/apis/public/jobs.py +90 -2
- wandb/apis/public/projects.py +161 -69
- wandb/apis/public/query_generator.py +11 -1
- wandb/apis/public/registries/registries_search.py +7 -15
- wandb/apis/public/reports.py +147 -13
- wandb/apis/public/runs.py +315 -128
- wandb/apis/public/sweeps.py +222 -22
- wandb/apis/public/teams.py +41 -4
- wandb/apis/public/users.py +45 -4
- wandb/automations/__init__.py +10 -10
- wandb/automations/_filters/run_metrics.py +0 -2
- wandb/automations/_utils.py +0 -2
- wandb/automations/actions.py +0 -2
- wandb/automations/automations.py +0 -2
- wandb/automations/events.py +0 -2
- wandb/beta/workflows.py +66 -30
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +80 -1
- wandb/env.py +8 -0
- wandb/errors/errors.py +4 -1
- wandb/integration/catboost/catboost.py +6 -2
- wandb/integration/kfp/kfp_patch.py +3 -1
- wandb/integration/lightning/fabric/logger.py +3 -4
- wandb/integration/metaflow/__init__.py +6 -0
- wandb/integration/metaflow/data_pandas.py +74 -0
- wandb/integration/metaflow/errors.py +13 -0
- wandb/integration/metaflow/metaflow.py +205 -190
- wandb/integration/openai/fine_tuning.py +1 -2
- wandb/integration/sb3/sb3.py +3 -3
- wandb/integration/ultralytics/callback.py +6 -2
- wandb/jupyter.py +5 -5
- wandb/plot/__init__.py +2 -0
- wandb/plot/bar.py +30 -29
- wandb/plot/confusion_matrix.py +75 -71
- wandb/plot/custom_chart.py +30 -7
- wandb/plot/histogram.py +26 -25
- wandb/plot/line.py +33 -32
- wandb/plot/line_series.py +100 -103
- wandb/plot/pr_curve.py +33 -32
- wandb/plot/roc_curve.py +38 -38
- wandb/plot/scatter.py +27 -27
- wandb/proto/v3/wandb_internal_pb2.py +366 -385
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v4/wandb_internal_pb2.py +352 -356
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v5/wandb_internal_pb2.py +352 -356
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v6/wandb_internal_pb2.py +352 -356
- wandb/proto/v6/wandb_settings_pb2.py +2 -2
- wandb/proto/v6/wandb_telemetry_pb2.py +4 -4
- wandb/proto/wandb_deprecated.py +6 -0
- wandb/sdk/artifacts/_generated/__init__.py +12 -1
- wandb/sdk/artifacts/_generated/input_types.py +20 -2
- wandb/sdk/artifacts/_generated/link_artifact.py +21 -0
- wandb/sdk/artifacts/_generated/operations.py +9 -0
- wandb/sdk/artifacts/_internal_artifact.py +19 -8
- wandb/sdk/artifacts/_validators.py +48 -2
- wandb/sdk/artifacts/artifact.py +269 -96
- wandb/sdk/data_types/audio.py +38 -10
- wandb/sdk/data_types/base_types/media.py +15 -63
- wandb/sdk/data_types/base_types/wb_value.py +6 -6
- wandb/sdk/data_types/graph.py +48 -14
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +1 -3
- wandb/sdk/data_types/helper_types/image_mask.py +1 -3
- wandb/sdk/data_types/histogram.py +34 -21
- wandb/sdk/data_types/html.py +35 -12
- wandb/sdk/data_types/image.py +104 -68
- wandb/sdk/data_types/molecule.py +32 -19
- wandb/sdk/data_types/object_3d.py +36 -17
- wandb/sdk/data_types/plotly.py +18 -5
- wandb/sdk/data_types/saved_model.py +7 -9
- wandb/sdk/data_types/table.py +99 -70
- wandb/sdk/data_types/trace_tree.py +12 -12
- wandb/sdk/data_types/video.py +53 -26
- wandb/sdk/integration_utils/auto_logging.py +2 -2
- wandb/sdk/interface/interface.py +8 -19
- wandb/sdk/interface/interface_shared.py +7 -16
- wandb/sdk/internal/datastore.py +18 -18
- wandb/sdk/internal/handler.py +3 -5
- wandb/sdk/internal/internal_api.py +60 -0
- wandb/sdk/internal/job_builder.py +6 -0
- wandb/sdk/internal/sender.py +23 -3
- wandb/sdk/internal/sender_config.py +9 -0
- wandb/sdk/launch/_project_spec.py +3 -3
- wandb/sdk/launch/agent/agent.py +11 -4
- wandb/sdk/launch/agent/job_status_tracker.py +3 -1
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +2 -2
- wandb/sdk/launch/create_job.py +3 -1
- wandb/sdk/launch/inputs/internal.py +3 -4
- wandb/sdk/launch/inputs/schema.py +1 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +328 -1
- wandb/sdk/launch/sweeps/scheduler.py +2 -3
- wandb/sdk/launch/utils.py +3 -3
- wandb/sdk/lib/asyncio_compat.py +3 -0
- wandb/sdk/lib/console_capture.py +66 -19
- wandb/sdk/lib/deprecate.py +1 -7
- wandb/sdk/lib/disabled.py +1 -1
- wandb/sdk/lib/hashutil.py +14 -1
- wandb/sdk/lib/module.py +7 -13
- wandb/sdk/lib/progress.py +0 -19
- wandb/sdk/lib/sock_client.py +0 -4
- wandb/sdk/wandb_init.py +67 -93
- wandb/sdk/wandb_login.py +18 -14
- wandb/sdk/wandb_metric.py +2 -0
- wandb/sdk/wandb_require.py +0 -1
- wandb/sdk/wandb_run.py +429 -527
- wandb/sdk/wandb_settings.py +364 -74
- wandb/sdk/wandb_setup.py +28 -28
- wandb/sdk/wandb_sweep.py +14 -13
- wandb/sdk/wandb_watch.py +4 -6
- wandb/sync/sync.py +10 -0
- wandb/util.py +57 -0
- wandb/wandb_run.py +1 -2
- {wandb-0.20.2rc20250616.dist-info → wandb-0.21.1.dist-info}/METADATA +1 -1
- {wandb-0.20.2rc20250616.dist-info → wandb-0.21.1.dist-info}/RECORD +137 -137
- wandb/sdk/wandb_metadata.py +0 -623
- wandb/vendor/pynvml/__init__.py +0 -0
- wandb/vendor/pynvml/pynvml.py +0 -4779
- {wandb-0.20.2rc20250616.dist-info → wandb-0.21.1.dist-info}/WHEEL +0 -0
- {wandb-0.20.2rc20250616.dist-info → wandb-0.21.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.20.2rc20250616.dist-info → wandb-0.21.1.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/interface/interface_shared.py
CHANGED
@@ -59,6 +59,10 @@ class InterfaceShared(InterfaceBase):
|
|
59
59
|
rec = self._make_record(telemetry=telem)
|
60
60
|
self._publish(rec)
|
61
61
|
|
62
|
+
def _publish_environment(self, environment: pb.EnvironmentRecord) -> None:
|
63
|
+
rec = self._make_record(environment=environment)
|
64
|
+
self._publish(rec)
|
65
|
+
|
62
66
|
def _publish_job_input(
|
63
67
|
self, job_input: pb.JobInputRequest
|
64
68
|
) -> MailboxHandle[pb.Result]:
|
@@ -106,11 +110,9 @@ class InterfaceShared(InterfaceBase):
|
|
106
110
|
summary_record: Optional[pb.SummaryRecordRequest] = None,
|
107
111
|
telemetry_record: Optional[pb.TelemetryRecordRequest] = None,
|
108
112
|
get_system_metrics: Optional[pb.GetSystemMetricsRequest] = None,
|
109
|
-
get_system_metadata: Optional[pb.GetSystemMetadataRequest] = None,
|
110
113
|
python_packages: Optional[pb.PythonPackagesRequest] = None,
|
111
114
|
job_input: Optional[pb.JobInputRequest] = None,
|
112
115
|
run_finish_without_exit: Optional[pb.RunFinishWithoutExitRequest] = None,
|
113
|
-
metadata: Optional[pb.MetadataRequest] = None,
|
114
116
|
) -> pb.Record:
|
115
117
|
request = pb.Request()
|
116
118
|
if get_summary:
|
@@ -169,8 +171,6 @@ class InterfaceShared(InterfaceBase):
|
|
169
171
|
request.telemetry_record.CopyFrom(telemetry_record)
|
170
172
|
elif get_system_metrics:
|
171
173
|
request.get_system_metrics.CopyFrom(get_system_metrics)
|
172
|
-
elif get_system_metadata:
|
173
|
-
request.get_system_metadata.CopyFrom(get_system_metadata)
|
174
174
|
elif sync_finish:
|
175
175
|
request.sync_finish.CopyFrom(sync_finish)
|
176
176
|
elif python_packages:
|
@@ -179,8 +179,6 @@ class InterfaceShared(InterfaceBase):
|
|
179
179
|
request.job_input.CopyFrom(job_input)
|
180
180
|
elif run_finish_without_exit:
|
181
181
|
request.run_finish_without_exit.CopyFrom(run_finish_without_exit)
|
182
|
-
elif metadata:
|
183
|
-
request.metadata.CopyFrom(metadata)
|
184
182
|
else:
|
185
183
|
raise Exception("Invalid request")
|
186
184
|
record = self._make_record(request=request)
|
@@ -212,6 +210,7 @@ class InterfaceShared(InterfaceBase):
|
|
212
210
|
use_artifact: Optional[pb.UseArtifactRecord] = None,
|
213
211
|
output: Optional[pb.OutputRecord] = None,
|
214
212
|
output_raw: Optional[pb.OutputRawRecord] = None,
|
213
|
+
environment: Optional[pb.EnvironmentRecord] = None,
|
215
214
|
) -> pb.Record:
|
216
215
|
record = pb.Record()
|
217
216
|
if run:
|
@@ -254,6 +253,8 @@ class InterfaceShared(InterfaceBase):
|
|
254
253
|
record.output.CopyFrom(output)
|
255
254
|
elif output_raw:
|
256
255
|
record.output_raw.CopyFrom(output_raw)
|
256
|
+
elif environment:
|
257
|
+
record.environment.CopyFrom(environment)
|
257
258
|
else:
|
258
259
|
raise Exception("Invalid record")
|
259
260
|
return record
|
@@ -304,10 +305,6 @@ class InterfaceShared(InterfaceBase):
|
|
304
305
|
rec = self._make_record(summary=summary)
|
305
306
|
self._publish(rec)
|
306
307
|
|
307
|
-
def _publish_metadata(self, metadata: pb.MetadataRequest) -> None:
|
308
|
-
rec = self._make_request(metadata=metadata)
|
309
|
-
self._publish(rec)
|
310
|
-
|
311
308
|
def _publish_metric(self, metric: pb.MetricRecord) -> None:
|
312
309
|
rec = self._make_record(metric=metric)
|
313
310
|
self._publish(rec)
|
@@ -422,12 +419,6 @@ class InterfaceShared(InterfaceBase):
|
|
422
419
|
record = self._make_request(get_system_metrics=get_system_metrics)
|
423
420
|
return self._deliver_record(record)
|
424
421
|
|
425
|
-
def _deliver_get_system_metadata(
|
426
|
-
self, get_system_metadata: pb.GetSystemMetadataRequest
|
427
|
-
) -> MailboxHandle[pb.Result]:
|
428
|
-
record = self._make_request(get_system_metadata=get_system_metadata)
|
429
|
-
return self._deliver_record(record)
|
430
|
-
|
431
422
|
def _deliver_exit(
|
432
423
|
self,
|
433
424
|
exit_data: pb.RunExitRecord,
|
wandb/sdk/internal/datastore.py
CHANGED
@@ -124,18 +124,18 @@ class DataStore:
|
|
124
124
|
header = self._fp.read(LEVELDBLOG_HEADER_LEN)
|
125
125
|
if len(header) == 0:
|
126
126
|
return None
|
127
|
-
assert (
|
128
|
-
len(header)
|
129
|
-
)
|
127
|
+
assert len(header) == LEVELDBLOG_HEADER_LEN, (
|
128
|
+
f"record header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
|
129
|
+
)
|
130
130
|
fields = struct.unpack("<IHB", header)
|
131
131
|
checksum, dlength, dtype = fields
|
132
132
|
# check len, better fit in the block
|
133
133
|
self._index += LEVELDBLOG_HEADER_LEN
|
134
134
|
data = self._fp.read(dlength)
|
135
135
|
checksum_computed = zlib.crc32(data, self._crc[dtype]) & 0xFFFFFFFF
|
136
|
-
assert (
|
137
|
-
checksum
|
138
|
-
)
|
136
|
+
assert checksum == checksum_computed, (
|
137
|
+
"record checksum is invalid, data may be corrupt"
|
138
|
+
)
|
139
139
|
self._index += dlength
|
140
140
|
return dtype, data
|
141
141
|
|
@@ -158,9 +158,9 @@ class DataStore:
|
|
158
158
|
if dtype == LEVELDBLOG_FULL:
|
159
159
|
return data
|
160
160
|
|
161
|
-
assert (
|
162
|
-
|
163
|
-
)
|
161
|
+
assert dtype == LEVELDBLOG_FIRST, (
|
162
|
+
f"expected record to be type {LEVELDBLOG_FIRST} but found {dtype}"
|
163
|
+
)
|
164
164
|
while True:
|
165
165
|
offset = self._index % LEVELDBLOG_BLOCK_LEN
|
166
166
|
record = self.scan_record()
|
@@ -170,9 +170,9 @@ class DataStore:
|
|
170
170
|
if dtype == LEVELDBLOG_LAST:
|
171
171
|
data += new_data
|
172
172
|
break
|
173
|
-
assert (
|
174
|
-
|
175
|
-
)
|
173
|
+
assert dtype == LEVELDBLOG_MIDDLE, (
|
174
|
+
f"expected record to be type {LEVELDBLOG_MIDDLE} but found {dtype}"
|
175
|
+
)
|
176
176
|
data += new_data
|
177
177
|
return data
|
178
178
|
|
@@ -183,17 +183,17 @@ class DataStore:
|
|
183
183
|
LEVELDBLOG_HEADER_MAGIC,
|
184
184
|
LEVELDBLOG_HEADER_VERSION,
|
185
185
|
)
|
186
|
-
assert (
|
187
|
-
len(data)
|
188
|
-
)
|
186
|
+
assert len(data) == LEVELDBLOG_HEADER_LEN, (
|
187
|
+
f"header size is {len(data)} bytes, expected {LEVELDBLOG_HEADER_LEN}"
|
188
|
+
)
|
189
189
|
self._fp.write(data)
|
190
190
|
self._index += len(data)
|
191
191
|
|
192
192
|
def _read_header(self):
|
193
193
|
header = self._fp.read(LEVELDBLOG_HEADER_LEN)
|
194
|
-
assert (
|
195
|
-
len(header)
|
196
|
-
)
|
194
|
+
assert len(header) == LEVELDBLOG_HEADER_LEN, (
|
195
|
+
f"header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
|
196
|
+
)
|
197
197
|
ident, magic, version = struct.unpack("<4sHB", header)
|
198
198
|
if ident != strtobytes(LEVELDBLOG_HEADER_IDENT):
|
199
199
|
raise Exception("Invalid header")
|
wandb/sdk/internal/handler.py
CHANGED
@@ -37,7 +37,6 @@ from wandb.proto.wandb_internal_pb2 import (
|
|
37
37
|
|
38
38
|
from ..interface.interface_queue import InterfaceQueue
|
39
39
|
from ..lib import handler_util, proto_util
|
40
|
-
from ..wandb_metadata import Metadata
|
41
40
|
from . import context, sample, tb_watcher
|
42
41
|
from .settings_static import SettingsStatic
|
43
42
|
|
@@ -115,7 +114,6 @@ class HandleManager:
|
|
115
114
|
self._context_keeper = context_keeper
|
116
115
|
|
117
116
|
self._tb_watcher = None
|
118
|
-
self._metadata: Optional[Metadata] = None
|
119
117
|
self._step = 0
|
120
118
|
|
121
119
|
self._track_time = None
|
@@ -173,9 +171,6 @@ class HandleManager:
|
|
173
171
|
def handle_request_cancel(self, record: Record) -> None:
|
174
172
|
self._dispatch_record(record)
|
175
173
|
|
176
|
-
def handle_request_metadata(self, record: Record) -> None:
|
177
|
-
logger.warning("Metadata updates are ignored when using the legacy service.")
|
178
|
-
|
179
174
|
def handle_request_defer(self, record: Record) -> None:
|
180
175
|
defer = record.request.defer
|
181
176
|
state = defer.state
|
@@ -649,6 +644,9 @@ class HandleManager:
|
|
649
644
|
def handle_footer(self, record: Record) -> None:
|
650
645
|
self._dispatch_record(record)
|
651
646
|
|
647
|
+
def handle_metadata(self, record: Record) -> None:
|
648
|
+
self._dispatch_record(record)
|
649
|
+
|
652
650
|
def handle_request_attach(self, record: Record) -> None:
|
653
651
|
result = proto_util._result_from_record(record)
|
654
652
|
attach_id = record.request.attach.attach_id
|
wandb/sdk/internal/internal_api.py
CHANGED
@@ -284,6 +284,7 @@ class Api:
|
|
284
284
|
self._extra_http_headers.update(_thread_local_api_settings.headers or {})
|
285
285
|
|
286
286
|
auth = None
|
287
|
+
api_key = api_key or self.default_settings.get("api_key")
|
287
288
|
if api_key:
|
288
289
|
auth = ("api", api_key)
|
289
290
|
elif self.access_token is not None:
|
@@ -362,6 +363,7 @@ class Api:
|
|
362
363
|
self.server_create_run_queue_supports_priority: Optional[bool] = None
|
363
364
|
self.server_supports_template_variables: Optional[bool] = None
|
364
365
|
self.server_push_to_run_queue_supports_priority: Optional[bool] = None
|
366
|
+
|
365
367
|
self._server_features_cache: Optional[Dict[str, bool]] = None
|
366
368
|
|
367
369
|
def gql(self, *args: Any, **kwargs: Any) -> Any:
|
@@ -3233,6 +3235,7 @@ class Api:
|
|
3233
3235
|
entity: Optional[str] = None,
|
3234
3236
|
state: Optional[str] = None,
|
3235
3237
|
prior_runs: Optional[List[str]] = None,
|
3238
|
+
display_name: Optional[str] = None,
|
3236
3239
|
template_variable_values: Optional[Dict[str, Any]] = None,
|
3237
3240
|
) -> Tuple[str, List[str]]:
|
3238
3241
|
"""Upsert a sweep object.
|
@@ -3247,6 +3250,7 @@ class Api:
|
|
3247
3250
|
entity (str): entity to use
|
3248
3251
|
state (str): state
|
3249
3252
|
prior_runs (list): IDs of existing runs to add to the sweep
|
3253
|
+
display_name (str): display name for the sweep
|
3250
3254
|
template_variable_values (dict): template variable values
|
3251
3255
|
"""
|
3252
3256
|
project_query = """
|
@@ -3270,6 +3274,7 @@ class Api:
|
|
3270
3274
|
$scheduler: JSONString,
|
3271
3275
|
$state: String,
|
3272
3276
|
$priorRunsFilters: JSONString,
|
3277
|
+
$displayName: String,
|
3273
3278
|
) {
|
3274
3279
|
upsertSweep(input: {
|
3275
3280
|
id: $id,
|
@@ -3281,6 +3286,7 @@ class Api:
|
|
3281
3286
|
scheduler: $scheduler,
|
3282
3287
|
state: $state,
|
3283
3288
|
priorRunsFilters: $priorRunsFilters,
|
3289
|
+
displayName: $displayName,
|
3284
3290
|
}) {
|
3285
3291
|
sweep {
|
3286
3292
|
name
|
@@ -3357,6 +3363,7 @@ class Api:
|
|
3357
3363
|
"templateVariableValues": json.dumps(template_variable_values),
|
3358
3364
|
"scheduler": scheduler,
|
3359
3365
|
"priorRunsFilters": filters,
|
3366
|
+
"displayName": display_name,
|
3360
3367
|
}
|
3361
3368
|
if state:
|
3362
3369
|
variables["state"] = state
|
@@ -4661,3 +4668,56 @@ class Api:
|
|
4661
4668
|
success: bool = response["stopRun"].get("success")
|
4662
4669
|
|
4663
4670
|
return success
|
4671
|
+
|
4672
|
+
@normalize_exceptions
|
4673
|
+
def create_custom_chart(
|
4674
|
+
self,
|
4675
|
+
entity: str,
|
4676
|
+
name: str,
|
4677
|
+
display_name: str,
|
4678
|
+
spec_type: str,
|
4679
|
+
access: str,
|
4680
|
+
spec: Union[str, Mapping[str, Any]],
|
4681
|
+
) -> Optional[Dict[str, Any]]:
|
4682
|
+
if not isinstance(spec, str):
|
4683
|
+
spec = json.dumps(spec)
|
4684
|
+
|
4685
|
+
mutation = gql(
|
4686
|
+
"""
|
4687
|
+
mutation CreateCustomChart(
|
4688
|
+
$entity: String!
|
4689
|
+
$name: String!
|
4690
|
+
$displayName: String!
|
4691
|
+
$type: String!
|
4692
|
+
$access: String!
|
4693
|
+
$spec: JSONString!
|
4694
|
+
) {
|
4695
|
+
createCustomChart(
|
4696
|
+
input: {
|
4697
|
+
entity: $entity
|
4698
|
+
name: $name
|
4699
|
+
displayName: $displayName
|
4700
|
+
type: $type
|
4701
|
+
access: $access
|
4702
|
+
spec: $spec
|
4703
|
+
}
|
4704
|
+
) {
|
4705
|
+
chart { id }
|
4706
|
+
}
|
4707
|
+
}
|
4708
|
+
"""
|
4709
|
+
)
|
4710
|
+
|
4711
|
+
variable_values = {
|
4712
|
+
"entity": entity,
|
4713
|
+
"name": name,
|
4714
|
+
"displayName": display_name,
|
4715
|
+
"type": spec_type,
|
4716
|
+
"access": access,
|
4717
|
+
"spec": spec,
|
4718
|
+
}
|
4719
|
+
|
4720
|
+
result: Optional[Dict[str, Any]] = self.gql(mutation, variable_values)[
|
4721
|
+
"createCustomChart"
|
4722
|
+
]
|
4723
|
+
return result
|
wandb/sdk/internal/job_builder.py
CHANGED
@@ -109,6 +109,7 @@ class JobSourceDict(TypedDict, total=False):
|
|
109
109
|
input_types: Dict[str, Any]
|
110
110
|
output_types: Dict[str, Any]
|
111
111
|
runtime: Optional[str]
|
112
|
+
services: Dict[str, str]
|
112
113
|
|
113
114
|
|
114
115
|
class ArtifactInfoForJob(TypedDict):
|
@@ -143,6 +144,7 @@ class JobBuilder:
|
|
143
144
|
_job_version_alias: Optional[str]
|
144
145
|
_is_notebook_run: bool
|
145
146
|
_verbose: bool
|
147
|
+
_services: Dict[str, str]
|
146
148
|
|
147
149
|
def __init__(self, settings: SettingsStatic, verbose: bool = False):
|
148
150
|
self._settings = settings
|
@@ -162,6 +164,7 @@ class JobBuilder:
|
|
162
164
|
self._is_notebook_run = self._get_is_notebook_run()
|
163
165
|
self._verbose = verbose
|
164
166
|
self._partial = False
|
167
|
+
self._services = {}
|
165
168
|
|
166
169
|
def set_config(self, config: Dict[str, Any]) -> None:
|
167
170
|
self._config = config
|
@@ -544,6 +547,9 @@ class JobBuilder:
|
|
544
547
|
"runtime": runtime,
|
545
548
|
}
|
546
549
|
|
550
|
+
if self._services:
|
551
|
+
source_info["services"] = self._services
|
552
|
+
|
547
553
|
assert source_info is not None
|
548
554
|
assert name is not None
|
549
555
|
|
wandb/sdk/internal/sender.py
CHANGED
@@ -63,6 +63,7 @@ if TYPE_CHECKING:
|
|
63
63
|
ArtifactManifest,
|
64
64
|
ArtifactManifestEntry,
|
65
65
|
ArtifactRecord,
|
66
|
+
EnvironmentRecord,
|
66
67
|
HttpResponse,
|
67
68
|
LocalInfo,
|
68
69
|
Record,
|
@@ -212,6 +213,7 @@ class SendManager:
|
|
212
213
|
_context_keeper: context.ContextKeeper
|
213
214
|
|
214
215
|
_telemetry_obj: telemetry.TelemetryRecord
|
216
|
+
_environment_obj: "EnvironmentRecord"
|
215
217
|
_fs: Optional["file_stream.FileStreamApi"]
|
216
218
|
_run: Optional["RunRecord"]
|
217
219
|
_entity: Optional[str]
|
@@ -268,6 +270,7 @@ class SendManager:
|
|
268
270
|
|
269
271
|
self._start_time: int = 0
|
270
272
|
self._telemetry_obj = telemetry.TelemetryRecord()
|
273
|
+
self._environment_obj = wandb_internal_pb2.EnvironmentRecord()
|
271
274
|
self._config_metric_pbdict_list: List[Dict[int, Any]] = []
|
272
275
|
self._metadata_summary: Dict[str, Any] = defaultdict()
|
273
276
|
self._cached_summary: Dict[str, Any] = dict()
|
@@ -790,12 +793,12 @@ class SendManager:
|
|
790
793
|
|
791
794
|
def _config_backend_dict(self) -> sender_config.BackendConfigDict:
|
792
795
|
config = self._consolidated_config or sender_config.ConfigState()
|
793
|
-
|
794
796
|
return config.to_backend_dict(
|
795
797
|
telemetry_record=self._telemetry_obj,
|
796
798
|
framework=self._telemetry_get_framework(),
|
797
799
|
start_time_millis=self._start_time,
|
798
800
|
metric_pbdicts=self._config_metric_pbdict_list,
|
801
|
+
environment_record=self._environment_obj,
|
799
802
|
)
|
800
803
|
|
801
804
|
def _config_save(
|
@@ -1379,11 +1382,11 @@ class SendManager:
|
|
1379
1382
|
next_idx = len(self._config_metric_pbdict_list)
|
1380
1383
|
self._config_metric_pbdict_list.append(md)
|
1381
1384
|
self._config_metric_index_dict[metric.name] = next_idx
|
1382
|
-
self.
|
1385
|
+
self._debounce_config()
|
1383
1386
|
|
1384
1387
|
def _update_telemetry_record(self, telemetry: telemetry.TelemetryRecord) -> None:
|
1385
1388
|
self._telemetry_obj.MergeFrom(telemetry)
|
1386
|
-
self.
|
1389
|
+
self._debounce_config()
|
1387
1390
|
|
1388
1391
|
def send_telemetry(self, record: "Record") -> None:
|
1389
1392
|
self._update_telemetry_record(record.telemetry)
|
@@ -1417,6 +1420,23 @@ class SendManager:
|
|
1417
1420
|
# tbrecord watching threads are handled by handler.py
|
1418
1421
|
pass
|
1419
1422
|
|
1423
|
+
def _update_environment_record(self, environment: "EnvironmentRecord") -> None:
|
1424
|
+
self._environment_obj.MergeFrom(environment)
|
1425
|
+
self._debounce_config()
|
1426
|
+
|
1427
|
+
def send_environment(self, record: "Record") -> None:
|
1428
|
+
"""Inject environment info into config and upload as a JSON file."""
|
1429
|
+
self._update_environment_record(record.environment)
|
1430
|
+
|
1431
|
+
environment_json = json.dumps(proto_util.message_to_dict(self._environment_obj))
|
1432
|
+
|
1433
|
+
with open(
|
1434
|
+
os.path.join(self._settings.files_dir, filenames.METADATA_FNAME), "w"
|
1435
|
+
) as f:
|
1436
|
+
f.write(environment_json)
|
1437
|
+
|
1438
|
+
self._save_file(interface.GlobStr(filenames.METADATA_FNAME), policy="now")
|
1439
|
+
|
1420
1440
|
def send_request_link_artifact(self, record: "Record") -> None:
|
1421
1441
|
if not (record.control.req_resp or record.control.mailbox_slot):
|
1422
1442
|
raise ValueError(
|
wandb/sdk/internal/sender_config.py
CHANGED
@@ -79,6 +79,7 @@ class ConfigState:
|
|
79
79
|
framework: Optional[str],
|
80
80
|
start_time_millis: int,
|
81
81
|
metric_pbdicts: Sequence[Dict[int, Any]],
|
82
|
+
environment_record: wandb_internal_pb2.EnvironmentRecord,
|
82
83
|
) -> BackendConfigDict:
|
83
84
|
"""Returns a dictionary representation expected by the backend.
|
84
85
|
|
@@ -125,6 +126,14 @@ class ConfigState:
|
|
125
126
|
if metric_pbdicts:
|
126
127
|
wandb_internal["m"] = metric_pbdicts
|
127
128
|
|
129
|
+
###################################################
|
130
|
+
# Environment
|
131
|
+
###################################################
|
132
|
+
writer_id = environment_record.writer_id
|
133
|
+
if writer_id:
|
134
|
+
environment_dict = proto_util.message_to_dict(environment_record)
|
135
|
+
wandb_internal["e"] = {writer_id: environment_dict}
|
136
|
+
|
128
137
|
return BackendConfigDict(
|
129
138
|
{
|
130
139
|
key: {
|
wandb/sdk/launch/_project_spec.py
CHANGED
@@ -370,9 +370,9 @@ class LaunchProject:
|
|
370
370
|
|
371
371
|
def set_job_entry_point(self, command: List[str]) -> "EntryPoint":
|
372
372
|
"""Set job entrypoint for the project."""
|
373
|
-
assert (
|
374
|
-
|
375
|
-
)
|
373
|
+
assert self._entry_point is None, (
|
374
|
+
"Cannot set entry point twice. Use LaunchProject.override_entrypoint"
|
375
|
+
)
|
376
376
|
new_entrypoint = EntryPoint(name=command[-1], command=command)
|
377
377
|
self._entry_point = new_entrypoint
|
378
378
|
return new_entrypoint
|
wandb/sdk/launch/agent/agent.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""Implementation of launch agent."""
|
2
2
|
|
3
3
|
import asyncio
|
4
|
+
import copy
|
4
5
|
import logging
|
5
6
|
import os
|
6
7
|
import pprint
|
@@ -76,9 +77,9 @@ class JobSpecAndQueue:
|
|
76
77
|
def _convert_access(access: str) -> str:
|
77
78
|
"""Convert access string to a value accepted by wandb."""
|
78
79
|
access = access.upper()
|
79
|
-
assert (
|
80
|
-
access
|
81
|
-
)
|
80
|
+
assert access == "PROJECT" or access == "USER", (
|
81
|
+
"Queue access must be either project or user"
|
82
|
+
)
|
82
83
|
return access
|
83
84
|
|
84
85
|
|
@@ -421,6 +422,7 @@ class LaunchAgent:
|
|
421
422
|
"""Removes the job from our list for now."""
|
422
423
|
with self._jobs_lock:
|
423
424
|
job_and_run_status = self._jobs[thread_id]
|
425
|
+
|
424
426
|
if (
|
425
427
|
job_and_run_status.entity is not None
|
426
428
|
and job_and_run_status.entity != self._entity
|
@@ -516,7 +518,11 @@ class LaunchAgent:
|
|
516
518
|
Arguments:
|
517
519
|
job: Job to run.
|
518
520
|
"""
|
519
|
-
|
521
|
+
job_copy = copy.deepcopy(job)
|
522
|
+
if "runSpec" in job_copy and "_wandb_api_key" in job_copy["runSpec"]:
|
523
|
+
job_copy["runSpec"]["_wandb_api_key"] = "<redacted>"
|
524
|
+
|
525
|
+
_msg = f"{LOG_PREFIX}Launch agent received job:\n{pprint.pformat(job_copy)}\n"
|
520
526
|
wandb.termlog(_msg)
|
521
527
|
_logger.info(_msg)
|
522
528
|
# update agent status
|
@@ -727,6 +733,7 @@ class LaunchAgent:
|
|
727
733
|
backend = loader.runner_from_config(
|
728
734
|
resource, api, backend_config, environment, registry
|
729
735
|
)
|
736
|
+
|
730
737
|
if not (
|
731
738
|
project.docker_image
|
732
739
|
or project.job_base_image
|
wandb/sdk/launch/agent/job_status_tracker.py
CHANGED
@@ -44,7 +44,9 @@ class JobAndRunStatusTracker:
|
|
44
44
|
self.run_id is not None
|
45
45
|
and self.project is not None
|
46
46
|
and self.entity is not None
|
47
|
-
),
|
47
|
+
), (
|
48
|
+
"Job tracker does not contain run info. Update with run info before checking if run stopped"
|
49
|
+
)
|
48
50
|
check_stop = event_loop_thread_exec(api.api.check_stop_requested)
|
49
51
|
try:
|
50
52
|
return bool(await check_stop(self.project, self.entity, self.run_id))
|
wandb/sdk/launch/agent/run_queue_item_file_saver.py
CHANGED
@@ -11,7 +11,7 @@ FileSubtypes = Literal["warning", "error"]
|
|
11
11
|
class RunQueueItemFileSaver:
|
12
12
|
def __init__(
|
13
13
|
self,
|
14
|
-
agent_run: Optional["wandb.
|
14
|
+
agent_run: Optional["wandb.Run"],
|
15
15
|
run_queue_item_id: str,
|
16
16
|
):
|
17
17
|
self.run_queue_item_id = run_queue_item_id
|
@@ -20,7 +20,7 @@ class RunQueueItemFileSaver:
|
|
20
20
|
def save_contents(
|
21
21
|
self, contents: str, fname: str, file_sub_type: FileSubtypes
|
22
22
|
) -> Optional[List[str]]:
|
23
|
-
if not isinstance(self.run, wandb.
|
23
|
+
if not isinstance(self.run, wandb.Run):
|
24
24
|
wandb.termwarn("Not saving file contents because agent has no run")
|
25
25
|
return None
|
26
26
|
root_dir = self.run._settings.files_dir
|
wandb/sdk/launch/create_job.py
CHANGED
@@ -115,6 +115,7 @@ def _create_job(
|
|
115
115
|
build_context: Optional[str] = None,
|
116
116
|
dockerfile: Optional[str] = None,
|
117
117
|
base_image: Optional[str] = None,
|
118
|
+
services: Optional[Dict[str, str]] = None,
|
118
119
|
) -> Tuple[Optional[Artifact], str, List[str]]:
|
119
120
|
wandb.termlog(f"Creating launch job of type: {job_type}...")
|
120
121
|
|
@@ -169,6 +170,7 @@ def _create_job(
|
|
169
170
|
|
170
171
|
job_builder = _configure_job_builder_for_partial(tempdir.name, job_source=job_type)
|
171
172
|
job_builder._settings.job_name = name
|
173
|
+
job_builder._services = services or {}
|
172
174
|
if job_type == "code":
|
173
175
|
assert entrypoint is not None
|
174
176
|
job_name = _make_code_artifact(
|
@@ -421,7 +423,7 @@ def _configure_job_builder_for_partial(tmpdir: str, job_source: str) -> JobBuild
|
|
421
423
|
def _make_code_artifact(
|
422
424
|
api: Api,
|
423
425
|
job_builder: JobBuilder,
|
424
|
-
run: "wandb.
|
426
|
+
run: "wandb.Run",
|
425
427
|
path: str,
|
426
428
|
entrypoint: str,
|
427
429
|
entity: Optional[str],
|
wandb/sdk/launch/inputs/internal.py
CHANGED
@@ -17,7 +17,6 @@ import wandb
|
|
17
17
|
import wandb.data_types
|
18
18
|
from wandb.sdk.launch.errors import LaunchError
|
19
19
|
from wandb.sdk.launch.inputs.schema import META_SCHEMA
|
20
|
-
from wandb.sdk.wandb_run import Run
|
21
20
|
from wandb.util import get_module
|
22
21
|
|
23
22
|
from .files import config_path_is_valid, override_file
|
@@ -93,7 +92,7 @@ class StagedLaunchInputs:
|
|
93
92
|
):
|
94
93
|
self._staged_inputs.append(input_arguments)
|
95
94
|
|
96
|
-
def apply(self, run: Run):
|
95
|
+
def apply(self, run: wandb.Run):
|
97
96
|
"""Apply the staged inputs to the given run."""
|
98
97
|
for input in self._staged_inputs:
|
99
98
|
_publish_job_input(input, run)
|
@@ -101,13 +100,13 @@ class StagedLaunchInputs:
|
|
101
100
|
|
102
101
|
def _publish_job_input(
|
103
102
|
input: JobInputArguments,
|
104
|
-
run: Run,
|
103
|
+
run: wandb.Run,
|
105
104
|
) -> None:
|
106
105
|
"""Publish a job input to the backend interface of the given run.
|
107
106
|
|
108
107
|
Arguments:
|
109
108
|
input (JobInputArguments): The arguments for the job input.
|
110
|
-
run (Run): The run to publish the job input to.
|
109
|
+
run (wandb.Run): The run to publish the job input to.
|
111
110
|
"""
|
112
111
|
assert run._backend is not None
|
113
112
|
assert run._backend.interface is not None
|
wandb/sdk/launch/inputs/schema.py
CHANGED
@@ -7,6 +7,7 @@ META_SCHEMA = {
|
|
7
7
|
},
|
8
8
|
"title": {"type": "string"},
|
9
9
|
"description": {"type": "string"},
|
10
|
+
"format": {"type": "string"},
|
10
11
|
"enum": {"type": "array", "items": {"type": ["integer", "number", "string"]}},
|
11
12
|
"properties": {"type": "object", "patternProperties": {".*": {"$ref": "#"}}},
|
12
13
|
"allOf": {"type": "array", "items": {"$ref": "#"}},
|
wandb/sdk/launch/runner/kubernetes_monitor.py
CHANGED
@@ -27,6 +27,7 @@ WANDB_K8S_LABEL_NAMESPACE = "wandb.ai"
|
|
27
27
|
WANDB_K8S_RUN_ID = f"{WANDB_K8S_LABEL_NAMESPACE}/run-id"
|
28
28
|
WANDB_K8S_LABEL_AGENT = f"{WANDB_K8S_LABEL_NAMESPACE}/agent"
|
29
29
|
WANDB_K8S_LABEL_MONITOR = f"{WANDB_K8S_LABEL_NAMESPACE}/monitor"
|
30
|
+
WANDB_K8S_LABEL_AUXILIARY_RESOURCE = f"{WANDB_K8S_LABEL_NAMESPACE}/auxiliary-resource"
|
30
31
|
|
31
32
|
|
32
33
|
class Resources:
|