wandb 0.17.0rc2__py3-none-any.whl → 0.17.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -2
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/wandb.py +12 -7
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +213 -79
- wandb/apis/public/artifacts.py +335 -100
- wandb/apis/public/files.py +9 -9
- wandb/apis/public/jobs.py +16 -4
- wandb/apis/public/projects.py +26 -28
- wandb/apis/public/query_generator.py +1 -1
- wandb/apis/public/runs.py +163 -65
- wandb/apis/public/sweeps.py +2 -2
- wandb/apis/reports/__init__.py +1 -7
- wandb/apis/reports/v1/__init__.py +5 -27
- wandb/apis/reports/v2/__init__.py +7 -19
- wandb/apis/workspaces/__init__.py +8 -0
- wandb/beta/workflows.py +8 -3
- wandb/cli/cli.py +131 -59
- wandb/docker/__init__.py +1 -1
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +5 -107
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/gym/__init__.py +35 -15
- wandb/integration/openai/fine_tuning.py +21 -3
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/jupyter.py +16 -17
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +54 -54
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +54 -54
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_base_pb2.py +30 -0
- wandb/proto/v5/wandb_internal_pb2.py +355 -0
- wandb/proto/v5/wandb_server_pb2.py +63 -0
- wandb/proto/v5/wandb_settings_pb2.py +45 -0
- wandb/proto/v5/wandb_telemetry_pb2.py +41 -0
- wandb/proto/wandb_base_pb2.py +2 -0
- wandb/proto/wandb_deprecated.py +9 -1
- wandb/proto/wandb_generate_deprecated.py +34 -0
- wandb/proto/{wandb_internal_codegen.py → wandb_generate_proto.py} +1 -35
- wandb/proto/wandb_internal_pb2.py +2 -0
- wandb/proto/wandb_server_pb2.py +2 -0
- wandb/proto/wandb_settings_pb2.py +2 -0
- wandb/proto/wandb_telemetry_pb2.py +2 -0
- wandb/sdk/artifacts/artifact.py +68 -22
- wandb/sdk/artifacts/artifact_manifest.py +1 -1
- wandb/sdk/artifacts/artifact_manifest_entry.py +6 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -1
- wandb/sdk/artifacts/artifact_saver.py +1 -10
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +6 -2
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +2 -42
- wandb/sdk/artifacts/storage_policy.py +1 -12
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +4 -2
- wandb/sdk/interface/interface.py +13 -0
- wandb/sdk/interface/interface_shared.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +6 -19
- wandb/sdk/internal/internal_api.py +148 -136
- wandb/sdk/internal/job_builder.py +207 -135
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/sender.py +102 -39
- wandb/sdk/internal/settings_static.py +8 -1
- wandb/sdk/internal/system/assets/trainium.py +3 -3
- wandb/sdk/internal/system/system_info.py +4 -2
- wandb/sdk/internal/update.py +1 -1
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +4 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +184 -224
- wandb/sdk/launch/agent/agent.py +58 -18
- wandb/sdk/launch/agent/config.py +0 -3
- wandb/sdk/launch/builder/abstract.py +67 -0
- wandb/sdk/launch/builder/build.py +165 -576
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +7 -23
- wandb/sdk/launch/builder/kaniko_builder.py +10 -23
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +51 -45
- wandb/sdk/launch/environment/aws_environment.py +26 -1
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +224 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/runner/abstract.py +2 -2
- wandb/sdk/launch/runner/kubernetes_monitor.py +45 -12
- wandb/sdk/launch/runner/kubernetes_runner.py +6 -8
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +20 -14
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +2 -0
- wandb/sdk/launch/sweeps/utils.py +2 -2
- wandb/sdk/launch/utils.py +16 -138
- wandb/sdk/lib/_settings_toposort_generated.py +2 -5
- wandb/sdk/lib/apikey.py +4 -2
- wandb/sdk/lib/config_util.py +3 -3
- wandb/sdk/lib/proto_util.py +22 -1
- wandb/sdk/lib/redirect.py +1 -1
- wandb/sdk/service/service.py +2 -1
- wandb/sdk/service/streams.py +5 -5
- wandb/sdk/wandb_init.py +25 -59
- wandb/sdk/wandb_login.py +28 -25
- wandb/sdk/wandb_run.py +112 -45
- wandb/sdk/wandb_settings.py +33 -64
- wandb/sdk/wandb_watch.py +1 -1
- wandb/sklearn/plot/classifier.py +4 -6
- wandb/sync/sync.py +2 -2
- wandb/testing/relay.py +32 -17
- wandb/util.py +36 -37
- wandb/wandb_agent.py +3 -3
- wandb/wandb_controller.py +3 -2
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/METADATA +7 -9
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/RECORD +124 -146
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/WHEEL +1 -1
- wandb/apis/reports/v1/_blocks.py +0 -1406
- wandb/apis/reports/v1/_helpers.py +0 -70
- wandb/apis/reports/v1/_panels.py +0 -1282
- wandb/apis/reports/v1/_templates.py +0 -478
- wandb/apis/reports/v1/blocks.py +0 -27
- wandb/apis/reports/v1/helpers.py +0 -2
- wandb/apis/reports/v1/mutations.py +0 -66
- wandb/apis/reports/v1/panels.py +0 -17
- wandb/apis/reports/v1/report.py +0 -268
- wandb/apis/reports/v1/runset.py +0 -144
- wandb/apis/reports/v1/templates.py +0 -7
- wandb/apis/reports/v1/util.py +0 -406
- wandb/apis/reports/v1/validators.py +0 -131
- wandb/apis/reports/v2/blocks.py +0 -25
- wandb/apis/reports/v2/expr_parsing.py +0 -257
- wandb/apis/reports/v2/gql.py +0 -68
- wandb/apis/reports/v2/interface.py +0 -1911
- wandb/apis/reports/v2/internal.py +0 -867
- wandb/apis/reports/v2/metrics.py +0 -6
- wandb/apis/reports/v2/panels.py +0 -15
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -19
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/licenses/LICENSE +0 -0
wandb/apis/public/projects.py
CHANGED
@@ -22,23 +22,22 @@ class Projects(Paginator):
|
|
22
22
|
|
23
23
|
QUERY = gql(
|
24
24
|
"""
|
25
|
-
query Projects($entity: String, $cursor: String, $perPage: Int = 50) {
|
26
|
-
models(entityName: $entity, after: $cursor, first: $perPage) {
|
27
|
-
edges {
|
28
|
-
node {
|
25
|
+
query Projects($entity: String, $cursor: String, $perPage: Int = 50) {{
|
26
|
+
models(entityName: $entity, after: $cursor, first: $perPage) {{
|
27
|
+
edges {{
|
28
|
+
node {{
|
29
29
|
...ProjectFragment
|
30
|
-
}
|
30
|
+
}}
|
31
31
|
cursor
|
32
|
-
}
|
33
|
-
pageInfo {
|
32
|
+
}}
|
33
|
+
pageInfo {{
|
34
34
|
endCursor
|
35
35
|
hasNextPage
|
36
|
-
}
|
37
|
-
}
|
38
|
-
}
|
39
|
-
|
40
|
-
"""
|
41
|
-
% PROJECT_FRAGMENT
|
36
|
+
}}
|
37
|
+
}}
|
38
|
+
}}
|
39
|
+
{}
|
40
|
+
""".format(PROJECT_FRAGMENT)
|
42
41
|
)
|
43
42
|
|
44
43
|
def __init__(self, client, entity, per_page=50):
|
@@ -118,26 +117,25 @@ class Project(Attrs):
|
|
118
117
|
def sweeps(self):
|
119
118
|
query = gql(
|
120
119
|
"""
|
121
|
-
query GetSweeps($project: String!, $entity: String!) {
|
122
|
-
project(name: $project, entityName: $entity) {
|
120
|
+
query GetSweeps($project: String!, $entity: String!) {{
|
121
|
+
project(name: $project, entityName: $entity) {{
|
123
122
|
totalSweeps
|
124
|
-
sweeps {
|
125
|
-
edges {
|
126
|
-
node {
|
123
|
+
sweeps {{
|
124
|
+
edges {{
|
125
|
+
node {{
|
127
126
|
...SweepFragment
|
128
|
-
}
|
127
|
+
}}
|
129
128
|
cursor
|
130
|
-
}
|
131
|
-
pageInfo {
|
129
|
+
}}
|
130
|
+
pageInfo {{
|
132
131
|
endCursor
|
133
132
|
hasNextPage
|
134
|
-
}
|
135
|
-
}
|
136
|
-
}
|
137
|
-
}
|
138
|
-
|
139
|
-
"""
|
140
|
-
% public.SWEEP_FRAGMENT
|
133
|
+
}}
|
134
|
+
}}
|
135
|
+
}}
|
136
|
+
}}
|
137
|
+
{}
|
138
|
+
""".format(public.SWEEP_FRAGMENT)
|
141
139
|
)
|
142
140
|
variable_values = {"project": self.name, "entity": self.entity}
|
143
141
|
ret = self.client.execute(query, variable_values)
|
@@ -60,7 +60,7 @@ class QueryGenerator:
|
|
60
60
|
return key["name"]
|
61
61
|
elif key["section"] == "tags":
|
62
62
|
return "tags." + key["name"]
|
63
|
-
raise ValueError("Invalid key:
|
63
|
+
raise ValueError("Invalid key: {}".format(key))
|
64
64
|
|
65
65
|
def server_path_to_key(self, path):
|
66
66
|
if path.startswith("config."):
|
wandb/apis/public/runs.py
CHANGED
@@ -2,10 +2,16 @@
|
|
2
2
|
|
3
3
|
import json
|
4
4
|
import os
|
5
|
+
import sys
|
5
6
|
import tempfile
|
6
7
|
import time
|
7
8
|
import urllib
|
8
|
-
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional
|
9
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
|
10
|
+
|
11
|
+
if sys.version_info >= (3, 8):
|
12
|
+
from typing import Literal
|
13
|
+
else:
|
14
|
+
from typing_extensions import Literal
|
9
15
|
|
10
16
|
from wandb_gql import gql
|
11
17
|
|
@@ -60,27 +66,26 @@ class Runs(Paginator):
|
|
60
66
|
|
61
67
|
QUERY = gql(
|
62
68
|
"""
|
63
|
-
query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {
|
64
|
-
project(name: $project, entityName: $entity) {
|
69
|
+
query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
|
70
|
+
project(name: $project, entityName: $entity) {{
|
65
71
|
runCount(filters: $filters)
|
66
72
|
readOnly
|
67
|
-
runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {
|
68
|
-
edges {
|
69
|
-
node {
|
73
|
+
runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
|
74
|
+
edges {{
|
75
|
+
node {{
|
70
76
|
...RunFragment
|
71
|
-
}
|
77
|
+
}}
|
72
78
|
cursor
|
73
|
-
}
|
74
|
-
pageInfo {
|
79
|
+
}}
|
80
|
+
pageInfo {{
|
75
81
|
endCursor
|
76
82
|
hasNextPage
|
77
|
-
}
|
78
|
-
}
|
79
|
-
}
|
80
|
-
}
|
81
|
-
|
82
|
-
"""
|
83
|
-
% RUN_FRAGMENT
|
83
|
+
}}
|
84
|
+
}}
|
85
|
+
}}
|
86
|
+
}}
|
87
|
+
{}
|
88
|
+
""".format(RUN_FRAGMENT)
|
84
89
|
)
|
85
90
|
|
86
91
|
def __init__(
|
@@ -131,7 +136,7 @@ class Runs(Paginator):
|
|
131
136
|
def convert_objects(self):
|
132
137
|
objs = []
|
133
138
|
if self.last_response is None or self.last_response.get("project") is None:
|
134
|
-
raise ValueError("Could not find project
|
139
|
+
raise ValueError("Could not find project {}".format(self.project))
|
135
140
|
for run_response in self.last_response["project"]["runs"]["edges"]:
|
136
141
|
run = Run(
|
137
142
|
self.client,
|
@@ -162,6 +167,104 @@ class Runs(Paginator):
|
|
162
167
|
|
163
168
|
return objs
|
164
169
|
|
170
|
+
@normalize_exceptions
|
171
|
+
def histories(
|
172
|
+
self,
|
173
|
+
samples: int = 500,
|
174
|
+
keys: Optional[List[str]] = None,
|
175
|
+
x_axis: str = "_step",
|
176
|
+
format: Literal["default", "pandas", "polars"] = "default",
|
177
|
+
stream: Literal["default", "system"] = "default",
|
178
|
+
):
|
179
|
+
"""Return sampled history metrics for all runs that fit the filters conditions.
|
180
|
+
|
181
|
+
Arguments:
|
182
|
+
samples : (int, optional) The number of samples to return per run
|
183
|
+
keys : (list[str], optional) Only return metrics for specific keys
|
184
|
+
x_axis : (str, optional) Use this metric as the xAxis defaults to _step
|
185
|
+
format : (Literal, optional) Format to return data in, options are "default", "pandas", "polars"
|
186
|
+
stream : (Literal, optional) "default" for metrics, "system" for machine metrics
|
187
|
+
Returns:
|
188
|
+
pandas.DataFrame: If format="pandas", returns a `pandas.DataFrame` of history metrics.
|
189
|
+
polars.DataFrame: If format="polars", returns a `polars.DataFrame` of history metrics.
|
190
|
+
list of dicts: If format="default", returns a list of dicts containing history metrics with a run_id key.
|
191
|
+
"""
|
192
|
+
if format not in ("default", "pandas", "polars"):
|
193
|
+
raise ValueError(
|
194
|
+
f"Invalid format: {format}. Must be one of 'default', 'pandas', 'polars'"
|
195
|
+
)
|
196
|
+
|
197
|
+
histories = []
|
198
|
+
|
199
|
+
if format == "default":
|
200
|
+
for run in self:
|
201
|
+
history_data = run.history(
|
202
|
+
samples=samples,
|
203
|
+
keys=keys,
|
204
|
+
x_axis=x_axis,
|
205
|
+
pandas=False,
|
206
|
+
stream=stream,
|
207
|
+
)
|
208
|
+
if not history_data:
|
209
|
+
continue
|
210
|
+
for entry in history_data:
|
211
|
+
entry["run_id"] = run.id
|
212
|
+
histories.extend(history_data)
|
213
|
+
|
214
|
+
return histories
|
215
|
+
|
216
|
+
if format == "pandas":
|
217
|
+
pd = util.get_module(
|
218
|
+
"pandas", required="Exporting pandas DataFrame requires pandas"
|
219
|
+
)
|
220
|
+
for run in self:
|
221
|
+
history_data = run.history(
|
222
|
+
samples=samples,
|
223
|
+
keys=keys,
|
224
|
+
x_axis=x_axis,
|
225
|
+
pandas=False,
|
226
|
+
stream=stream,
|
227
|
+
)
|
228
|
+
if not history_data:
|
229
|
+
continue
|
230
|
+
df = pd.DataFrame.from_records(history_data)
|
231
|
+
df["run_id"] = run.id
|
232
|
+
histories.append(df)
|
233
|
+
if not histories:
|
234
|
+
return pd.DataFrame()
|
235
|
+
combined_df = pd.concat(histories)
|
236
|
+
combined_df.sort_values("run_id", inplace=True)
|
237
|
+
combined_df.reset_index(drop=True, inplace=True)
|
238
|
+
# sort columns for consistency
|
239
|
+
combined_df = combined_df[(sorted(combined_df.columns))]
|
240
|
+
|
241
|
+
return combined_df
|
242
|
+
|
243
|
+
if format == "polars":
|
244
|
+
pl = util.get_module(
|
245
|
+
"polars", required="Exporting polars DataFrame requires polars"
|
246
|
+
)
|
247
|
+
for run in self:
|
248
|
+
history_data = run.history(
|
249
|
+
samples=samples,
|
250
|
+
keys=keys,
|
251
|
+
x_axis=x_axis,
|
252
|
+
pandas=False,
|
253
|
+
stream=stream,
|
254
|
+
)
|
255
|
+
if not history_data:
|
256
|
+
continue
|
257
|
+
df = pl.from_records(history_data)
|
258
|
+
df = df.with_columns(pl.lit(run.id).alias("run_id"))
|
259
|
+
histories.append(df)
|
260
|
+
if not histories:
|
261
|
+
return pl.DataFrame()
|
262
|
+
combined_df = pl.concat(histories, how="align")
|
263
|
+
# sort columns for consistency
|
264
|
+
combined_df = combined_df.select(sorted(combined_df.columns)).sort("run_id")
|
265
|
+
|
266
|
+
return combined_df
|
267
|
+
|
165
268
|
def __repr__(self):
|
166
269
|
return f"<Runs {self.entity}/{self.project}>"
|
167
270
|
|
@@ -310,16 +413,15 @@ class Run(Attrs):
|
|
310
413
|
def load(self, force=False):
|
311
414
|
query = gql(
|
312
415
|
"""
|
313
|
-
query Run($project: String!, $entity: String!, $name: String!) {
|
314
|
-
project(name: $project, entityName: $entity) {
|
315
|
-
run(name: $name) {
|
416
|
+
query Run($project: String!, $entity: String!, $name: String!) {{
|
417
|
+
project(name: $project, entityName: $entity) {{
|
418
|
+
run(name: $name) {{
|
316
419
|
...RunFragment
|
317
|
-
}
|
318
|
-
}
|
319
|
-
}
|
320
|
-
|
321
|
-
"""
|
322
|
-
% RUN_FRAGMENT
|
420
|
+
}}
|
421
|
+
}}
|
422
|
+
}}
|
423
|
+
{}
|
424
|
+
""".format(RUN_FRAGMENT)
|
323
425
|
)
|
324
426
|
if force or not self._attrs:
|
325
427
|
response = self._exec(query)
|
@@ -328,7 +430,7 @@ class Run(Attrs):
|
|
328
430
|
or response.get("project") is None
|
329
431
|
or response["project"].get("run") is None
|
330
432
|
):
|
331
|
-
raise ValueError("Could not find run
|
433
|
+
raise ValueError("Could not find run {}".format(self))
|
332
434
|
self._attrs = response["project"]["run"]
|
333
435
|
self._state = self._attrs["state"]
|
334
436
|
|
@@ -402,16 +504,15 @@ class Run(Attrs):
|
|
402
504
|
"""Persist changes to the run object to the wandb backend."""
|
403
505
|
mutation = gql(
|
404
506
|
"""
|
405
|
-
mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String) {
|
406
|
-
upsertBucket(input: {id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName}) {
|
407
|
-
bucket {
|
507
|
+
mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String) {{
|
508
|
+
upsertBucket(input: {{id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName}}) {{
|
509
|
+
bucket {{
|
408
510
|
...RunFragment
|
409
|
-
}
|
410
|
-
}
|
411
|
-
}
|
412
|
-
|
413
|
-
"""
|
414
|
-
% RUN_FRAGMENT
|
511
|
+
}}
|
512
|
+
}}
|
513
|
+
}}
|
514
|
+
{}
|
515
|
+
""".format(RUN_FRAGMENT)
|
415
516
|
)
|
416
517
|
_ = self._exec(
|
417
518
|
mutation,
|
@@ -491,13 +592,12 @@ class Run(Attrs):
|
|
491
592
|
node = "history" if stream == "default" else "events"
|
492
593
|
query = gql(
|
493
594
|
"""
|
494
|
-
query RunFullHistory($project: String!, $entity: String!, $name: String!, $samples: Int) {
|
495
|
-
project(name: $project, entityName: $entity) {
|
496
|
-
run(name: $name) {
|
497
|
-
}
|
498
|
-
}
|
499
|
-
"""
|
500
|
-
% node
|
595
|
+
query RunFullHistory($project: String!, $entity: String!, $name: String!, $samples: Int) {{
|
596
|
+
project(name: $project, entityName: $entity) {{
|
597
|
+
run(name: $name) {{ {}(samples: $samples) }}
|
598
|
+
}}
|
599
|
+
}}
|
600
|
+
""".format(node)
|
501
601
|
)
|
502
602
|
|
503
603
|
response = self._exec(query, samples=samples)
|
@@ -587,9 +687,9 @@ class Run(Attrs):
|
|
587
687
|
else:
|
588
688
|
lines = self._full_history(samples=samples, stream=stream)
|
589
689
|
if pandas:
|
590
|
-
|
591
|
-
if
|
592
|
-
lines =
|
690
|
+
pd = util.get_module("pandas")
|
691
|
+
if pd:
|
692
|
+
lines = pd.DataFrame.from_records(lines)
|
593
693
|
else:
|
594
694
|
print("Unable to load pandas, call history with pandas=False")
|
595
695
|
return lines
|
@@ -607,10 +707,11 @@ class Run(Attrs):
|
|
607
707
|
losses = [row["Loss"] for row in history]
|
608
708
|
```
|
609
709
|
|
610
|
-
|
611
710
|
Arguments:
|
612
711
|
keys ([str], optional): only fetch these keys, and only fetch rows that have all of keys defined.
|
613
|
-
page_size (int, optional): size of pages to fetch from the api
|
712
|
+
page_size (int, optional): size of pages to fetch from the api.
|
713
|
+
min_step (int, optional): the minimum number of pages to scan at a time.
|
714
|
+
max_step (int, optional): the maximum number of pages to scan at a time.
|
614
715
|
|
615
716
|
Returns:
|
616
717
|
An iterable collection over history records (dict).
|
@@ -707,27 +808,24 @@ class Run(Attrs):
|
|
707
808
|
)
|
708
809
|
api.set_current_run_id(self.id)
|
709
810
|
|
710
|
-
if isinstance(artifact, wandb.Artifact)
|
711
|
-
|
712
|
-
|
713
|
-
or self.project != artifact.source_project
|
714
|
-
):
|
715
|
-
raise ValueError("A run can't log an artifact to a different project.")
|
716
|
-
artifact_collection_name = artifact.source_name.split(":")[0]
|
717
|
-
api.create_artifact(
|
718
|
-
artifact.type,
|
719
|
-
artifact_collection_name,
|
720
|
-
artifact.digest,
|
721
|
-
aliases=aliases,
|
722
|
-
)
|
723
|
-
return artifact
|
724
|
-
elif isinstance(artifact, wandb.Artifact) and artifact.is_draft():
|
811
|
+
if not isinstance(artifact, wandb.Artifact):
|
812
|
+
raise ValueError("You must pass a wandb.Api().artifact() to use_artifact")
|
813
|
+
if artifact.is_draft():
|
725
814
|
raise ValueError(
|
726
815
|
"Only existing artifacts are accepted by this api. "
|
727
816
|
"Manually create one with `wandb artifact put`"
|
728
817
|
)
|
729
|
-
|
730
|
-
|
818
|
+
if (
|
819
|
+
self.entity != artifact.source_entity
|
820
|
+
or self.project != artifact.source_project
|
821
|
+
):
|
822
|
+
raise ValueError("A run can't log an artifact to a different project.")
|
823
|
+
|
824
|
+
artifact_collection_name = artifact.source_name.split(":")[0]
|
825
|
+
api.create_artifact(
|
826
|
+
artifact.type, artifact_collection_name, artifact.digest, aliases=aliases
|
827
|
+
)
|
828
|
+
return artifact
|
731
829
|
|
732
830
|
@property
|
733
831
|
def summary(self):
|
wandb/apis/public/sweeps.py
CHANGED
@@ -107,7 +107,7 @@ class Sweep(Attrs):
|
|
107
107
|
if force or not self._attrs:
|
108
108
|
sweep = self.get(self.client, self.entity, self.project, self.id)
|
109
109
|
if sweep is None:
|
110
|
-
raise ValueError("Could not find sweep
|
110
|
+
raise ValueError("Could not find sweep {}".format(self))
|
111
111
|
self._attrs = sweep._attrs
|
112
112
|
self.runs = sweep.runs
|
113
113
|
|
@@ -133,7 +133,7 @@ class Sweep(Attrs):
|
|
133
133
|
"No order specified and couldn't find metric in sweep config, returning most recent run"
|
134
134
|
)
|
135
135
|
else:
|
136
|
-
wandb.termlog("Sorting runs by
|
136
|
+
wandb.termlog("Sorting runs by {}".format(order))
|
137
137
|
filters = {"$and": [{"sweep": self.id}]}
|
138
138
|
try:
|
139
139
|
return public.Runs(
|
wandb/apis/reports/__init__.py
CHANGED
@@ -1,30 +1,8 @@
|
|
1
|
-
import os
|
2
|
-
from inspect import cleandoc
|
3
|
-
|
4
1
|
import wandb
|
5
2
|
|
6
|
-
|
7
|
-
from .
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
from .runset import Runset
|
12
|
-
from .templates import * # noqa: F403
|
13
|
-
from .util import InlineCode, InlineLaTeX, Link
|
14
|
-
|
15
|
-
|
16
|
-
def show_welcome_message():
|
17
|
-
if os.getenv("WANDB_REPORT_API_DISABLE_MESSAGE"):
|
18
|
-
return
|
19
|
-
|
20
|
-
wandb.termwarn(
|
21
|
-
cleandoc(
|
22
|
-
"""
|
23
|
-
The v1 API is deprecated and will be removed in a future release. Please move to v2 by setting the env var WANDB_REPORT_API_ENABLE_V2=True. This will be on by default in a future release.
|
24
|
-
You can disable this message by setting the env var WANDB_REPORT_API_DISABLE_MESSAGE=True
|
25
|
-
"""
|
26
|
-
)
|
3
|
+
try:
|
4
|
+
from wandb_workspaces.reports.v1 import * # noqa: F403
|
5
|
+
except ImportError:
|
6
|
+
wandb.termerror(
|
7
|
+
"Failed to import wandb_workspaces. To edit reports programatically, please install it using `pip install wandb[workspaces]`."
|
27
8
|
)
|
28
|
-
|
29
|
-
|
30
|
-
show_welcome_message()
|
@@ -1,20 +1,8 @@
|
|
1
|
-
import
|
2
|
-
from inspect import cleandoc
|
1
|
+
import wandb
|
3
2
|
|
4
|
-
|
5
|
-
from . import
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
InlineLatex,
|
11
|
-
Layout,
|
12
|
-
Link,
|
13
|
-
ParallelCoordinatesPlotColumn,
|
14
|
-
Report,
|
15
|
-
Runset,
|
16
|
-
RunsetGroup,
|
17
|
-
RunsetGroupKey,
|
18
|
-
)
|
19
|
-
from .metrics import * # noqa
|
20
|
-
from .panels import * # noqa
|
3
|
+
try:
|
4
|
+
from wandb_workspaces.reports.v2 import * # noqa: F403
|
5
|
+
except ImportError:
|
6
|
+
wandb.termerror(
|
7
|
+
"Failed to import wandb_workspaces. To edit reports programatically, please install it using `pip install wandb[workspaces]`."
|
8
|
+
)
|
wandb/beta/workflows.py
CHANGED
@@ -183,7 +183,7 @@ def log_model(
|
|
183
183
|
return model
|
184
184
|
|
185
185
|
|
186
|
-
def use_model(aliased_path: str) -> "_SavedModel":
|
186
|
+
def use_model(aliased_path: str, unsafe: bool = False) -> "_SavedModel":
|
187
187
|
"""Fetch a saved model from an alias.
|
188
188
|
|
189
189
|
Under the hood, we use the alias to fetch the model artifact containing the
|
@@ -193,17 +193,22 @@ def use_model(aliased_path: str) -> "_SavedModel":
|
|
193
193
|
Args:
|
194
194
|
aliased_path: `str` - the following forms are valid: "name:version",
|
195
195
|
"name:alias". May be prefixed with "entity/project".
|
196
|
+
unsafe: `bool` - must be True to indicate the user understands the risks
|
197
|
+
associated with loading external models.
|
196
198
|
|
197
199
|
Returns:
|
198
200
|
_SavedModel instance
|
199
201
|
|
200
202
|
Example:
|
201
203
|
```python
|
202
|
-
# Assuming
|
203
|
-
sm = use_model("my-simple-model:latest")
|
204
|
+
# Assuming the model with the name "my-simple-model" is trusted:
|
205
|
+
sm = use_model("my-simple-model:latest", unsafe=True)
|
204
206
|
model = sm.model_obj()
|
205
207
|
```
|
206
208
|
"""
|
209
|
+
if not unsafe:
|
210
|
+
raise ValueError("The 'unsafe' parameter must be set to True to load a model.")
|
211
|
+
|
207
212
|
if ":" not in aliased_path:
|
208
213
|
raise ValueError(
|
209
214
|
"aliased_path must be of the form 'name:alias' or 'name:version'."
|