wandb-0.19.0rc1-py3-none-win_amd64.whl → wandb-0.19.1rc1-py3-none-win_amd64.whl
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +211 -209
- wandb/apis/attrs.py +15 -4
- wandb/apis/public/api.py +8 -4
- wandb/apis/public/files.py +65 -12
- wandb/apis/public/runs.py +52 -7
- wandb/apis/public/sweeps.py +1 -1
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +2 -1
- wandb/env.py +1 -1
- wandb/errors/term.py +60 -1
- wandb/integration/keras/callbacks/tables_builder.py +3 -1
- wandb/integration/kfp/kfp_patch.py +25 -15
- wandb/integration/lightning/fabric/logger.py +3 -1
- wandb/integration/tensorboard/monkeypatch.py +3 -2
- wandb/jupyter.py +4 -5
- wandb/plot/bar.py +5 -6
- wandb/plot/histogram.py +1 -1
- wandb/plot/line_series.py +3 -3
- wandb/plot/pr_curve.py +7 -3
- wandb/plot/scatter.py +2 -1
- wandb/proto/v3/wandb_settings_pb2.py +25 -15
- wandb/proto/v4/wandb_settings_pb2.py +17 -15
- wandb/proto/v5/wandb_settings_pb2.py +17 -15
- wandb/sdk/artifacts/_validators.py +1 -3
- wandb/sdk/artifacts/artifact_manifest_entry.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +12 -2
- wandb/sdk/data_types/helper_types/image_mask.py +8 -2
- wandb/sdk/data_types/histogram.py +3 -3
- wandb/sdk/data_types/image.py +3 -1
- wandb/sdk/interface/interface.py +34 -5
- wandb/sdk/interface/interface_sock.py +2 -2
- wandb/sdk/internal/file_stream.py +4 -1
- wandb/sdk/internal/sender.py +4 -1
- wandb/sdk/internal/settings_static.py +17 -4
- wandb/sdk/launch/utils.py +1 -0
- wandb/sdk/lib/ipython.py +5 -27
- wandb/sdk/lib/printer.py +33 -20
- wandb/sdk/lib/sparkline.py +1 -2
- wandb/sdk/wandb_config.py +2 -2
- wandb/sdk/wandb_init.py +236 -243
- wandb/sdk/wandb_run.py +172 -231
- wandb/sdk/wandb_settings.py +103 -14
- {wandb-0.19.0rc1.dist-info → wandb-0.19.1rc1.dist-info}/METADATA +1 -1
- {wandb-0.19.0rc1.dist-info → wandb-0.19.1rc1.dist-info}/RECORD +49 -49
- {wandb-0.19.0rc1.dist-info → wandb-0.19.1rc1.dist-info}/WHEEL +0 -0
- {wandb-0.19.0rc1.dist-info → wandb-0.19.1rc1.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.0rc1.dist-info → wandb-0.19.1rc1.dist-info}/licenses/LICENSE +0 -0
wandb/apis/public/files.py
CHANGED
@@ -49,6 +49,7 @@ class Files(Paginator):
         query RunFiles($project: String!, $entity: String!, $name: String!, $fileCursor: String,
             $fileLimit: Int = 50, $fileNames: [String] = [], $upload: Boolean = false) {{
             project(name: $project, entityName: $entity) {{
+                internalId
                 run(name: $name) {{
                     fileCount
                     ...RunFilesFragment
@@ -98,7 +99,7 @@ class Files(Paginator):
 
     def convert_objects(self):
         return [
-            File(self.client, r["node"])
+            File(self.client, r["node"], self.run)
             for r in self.last_response["project"]["run"]["files"]["edges"]
         ]
 
@@ -120,9 +121,11 @@ class File(Attrs):
         path_uri (str): path to file in the bucket, currently only available for files stored in S3
     """
 
-    def __init__(self, client, attrs):
+    def __init__(self, client, attrs, run=None):
         self.client = client
         self._attrs = attrs
+        self.run = run
+        self.server_supports_delete_file_with_project_id: Optional[bool] = None
         super().__init__(dict(attrs))
 
     @property
@@ -189,18 +192,35 @@ class File(Attrs):
 
     @normalize_exceptions
     def delete(self):
-
-
-
-
-                    files: $files
-                }) {
-                    success
-                }
+        project_id_mutation_fragment = ""
+        project_id_variable_fragment = ""
+        variable_values = {
+            "files": [self.id],
         }
-
+
+        # Add projectId to mutation and variables if the server supports it.
+        # Otherwise, do not include projectId in mutation for older server versions which do not support it.
+        if self._server_accepts_project_id_for_delete_file():
+            variable_values["projectId"] = self.run._project_internal_id
+            project_id_variable_fragment = ", $projectId: Int"
+            project_id_mutation_fragment = "projectId: $projectId"
+
+        mutation_string = """
+            mutation deleteFiles($files: [ID!]!{}) {{
+                deleteFiles(input: {{
+                    files: $files
+                    {}
+                }}) {{
+                    success
+                }}
+            }}
+            """.format(project_id_variable_fragment, project_id_mutation_fragment)
+        mutation = gql(mutation_string)
+
+        self.client.execute(
+            mutation,
+            variable_values=variable_values,
         )
-        self.client.execute(mutation, variable_values={"files": [self.id]})
 
     def __repr__(self):
         return "<File {} ({}) {}>".format(
@@ -208,3 +228,36 @@ class File(Attrs):
             self.mimetype,
             util.to_human_size(self.size, units=util.POW_2_BYTES),
         )
+
+    @normalize_exceptions
+    def _server_accepts_project_id_for_delete_file(self) -> bool:
+        """Returns True if the server supports deleting files with a projectId.
+
+        This check is done by utilizing GraphQL introspection in the avaiable fields on the DeleteFiles API.
+        """
+        query_string = """
+            query ProbeDeleteFilesProjectIdInput {
+                DeleteFilesProjectIdInputType: __type(name:"DeleteFilesInput") {
+                    inputFields{
+                        name
+                    }
+                }
+            }
+        """
+
+        # Only perform the query once to avoid extra network calls
+        if self.server_supports_delete_file_with_project_id is None:
+            query = gql(query_string)
+            res = self.client.execute(query)
+
+            # If projectId is in the inputFields, the server supports deleting files with a projectId
+            self.server_supports_delete_file_with_project_id = "projectId" in [
+                x["name"]
+                for x in (
+                    res.get("DeleteFilesProjectIdInputType", {}).get(
+                        "inputFields", [{}]
+                    )
+                )
+            ]
+
+        return self.server_supports_delete_file_with_project_id
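Note on the change above: `_server_accepts_project_id_for_delete_file()` is a GraphQL introspection probe. It asks the server which input fields `DeleteFilesInput` declares and caches the answer, so the `deleteFiles` mutation includes `projectId` only when the backend understands it. A minimal standalone sketch of that pattern (the `execute` callable and the module-level cache here are illustrative, not wandb's API):

```python
from typing import Callable, Optional

# Same introspection query as in the diff, as a plain string.
PROBE_QUERY = """
query ProbeDeleteFilesProjectIdInput {
    DeleteFilesProjectIdInputType: __type(name: "DeleteFilesInput") {
        inputFields {
            name
        }
    }
}
"""

_supports_project_id: Optional[bool] = None


def server_accepts_project_id(execute: Callable[[str], dict]) -> bool:
    """Probe once whether DeleteFilesInput accepts a projectId input field.

    `execute` is any callable that sends a GraphQL query string and returns
    the decoded response dict (for example, a thin wrapper around a gql client).
    """
    global _supports_project_id
    if _supports_project_id is None:  # cache to avoid repeated network calls
        result = execute(PROBE_QUERY)
        fields = (result.get("DeleteFilesProjectIdInputType") or {}).get("inputFields") or []
        _supports_project_id = any(f.get("name") == "projectId" for f in fields)
    return _supports_project_id
```

Older servers simply omit the field from the introspection result, so callers fall back to the legacy mutation without `projectId`.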
wandb/apis/public/runs.py
CHANGED
@@ -71,6 +71,7 @@ class Runs(Paginator):
     """
     query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
         project(name: $project, entityName: $entity) {{
+            internalId
             runCount(filters: $filters)
             readOnly
             runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
@@ -103,6 +104,7 @@ class Runs(Paginator):
     ):
         self.entity = entity
         self.project = project
+        self._project_internal_id = None
         self.filters = filters or {}
         self.order = order
         self._sweeps = {}
@@ -287,6 +289,7 @@ class Run(Attrs):
             Calling update will persist any changes.
         project (str): the project associated with the run
         entity (str): the name of the entity associated with the run
+        project_internal_id (int): the internal id of the project
         user (str): the name of the user who created the run
         path (str): Unique identifier [entity]/[project]/[run_id]
         notes (str): Notes about the run
@@ -328,6 +331,7 @@ class Run(Attrs):
         self._summary = None
         self._metadata: Optional[Dict[str, Any]] = None
         self._state = _attrs.get("state", "not found")
+        self.server_provides_internal_id_field: Optional[bool] = None
 
         self.load(force=not _attrs)
 
@@ -417,13 +421,18 @@ class Run(Attrs):
             """
             query Run($project: String!, $entity: String!, $name: String!) {{
                 project(name: $project, entityName: $entity) {{
+                    {}
                     run(name: $name) {{
                         ...RunFragment
                     }}
                 }}
             }}
             {}
-            """.format(
+            """.format(
+                # Only query internalId if the server supports it
+                "internalId" if self._server_provides_internal_id_for_project() else "",
+                RUN_FRAGMENT,
+            )
         )
         if force or not self._attrs:
             response = self._exec(query)
@@ -435,7 +444,7 @@ class Run(Attrs):
                 raise ValueError("Could not find run {}".format(self))
             self._attrs = response["project"]["run"]
             self._state = self._attrs["state"]
-
+            self._project_internal_id = response["project"].get("internalId", None)
             if self._include_sweeps and self.sweep_name and not self.sweep:
                 # There may be a lot of runs. Don't bother pulling them all
                 # just for the sake of this one.
@@ -495,7 +504,6 @@ class Run(Attrs):
             res = self._exec(query)
             state = res["project"]["run"]["state"]
             if state in ["finished", "crashed", "failed"]:
-                print(f"Run finished with status: {state}")
                 self._attrs["state"] = state
                 self._state = state
                 return
@@ -506,8 +514,8 @@ class Run(Attrs):
         """Persist changes to the run object to the wandb backend."""
         mutation = gql(
             """
-        mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String) {{
-            upsertBucket(input: {{id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName}}) {{
+        mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String, $jobType: String) {{
+            upsertBucket(input: {{id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName, jobType: $jobType}}) {{
                 bucket {{
                     ...RunFragment
                 }}
@@ -525,6 +533,7 @@ class Run(Attrs):
             display_name=self.display_name,
             config=self.json_config,
             groupName=self.group,
+            jobType=self.job_type,
         )
         self.summary.update()
 
@@ -770,7 +779,9 @@ class Run(Attrs):
         Example:
             >>> import wandb
             >>> import tempfile
-            >>> with tempfile.NamedTemporaryFile(
+            >>> with tempfile.NamedTemporaryFile(
+            ...     mode="w", delete=False, suffix=".txt"
+            ... ) as tmp:
             ...     tmp.write("This is a test artifact")
            ...     tmp_path = tmp.name
             >>> run = wandb.init(project="artifact-example")
@@ -893,6 +904,37 @@ class Run(Attrs):
         )
         return artifact
 
+    @normalize_exceptions
+    def _server_provides_internal_id_for_project(self) -> bool:
+        """Returns True if the server allows us to query the internalId field for a project.
+
+        This check is done by utilizing GraphQL introspection in the avaiable fields on the Project type.
+        """
+        query_string = """
+            query ProbeProjectInput {
+                ProjectType: __type(name:"Project") {
+                    fields {
+                        name
+                    }
+                }
+            }
+        """
+
+        # Only perform the query once to avoid extra network calls
+        if self.server_provides_internal_id_field is None:
+            query = gql(query_string)
+            res = self.client.execute(query)
+            print(
+                "internalId"
+                in [x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))]
+            )
+
+            self.server_provides_internal_id_field = "internalId" in [
+                x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))
+            ]
+
+        return self.server_provides_internal_id_field
+
     @property
     def summary(self):
         if self._summary is None:
@@ -921,7 +963,10 @@ class Run(Attrs):
         if self._metadata is None:
             try:
                 f = self.file("wandb-metadata.json")
-
+                session = self.client._client.transport.session
+                response = session.get(f.url, timeout=5)
+                response.raise_for_status()
+                contents = response.content
                 self._metadata = json_util.loads(contents)
             except:  # noqa: E722
                 # file doesn't exist, or can't be downloaded, or can't be parsed
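Note on the `Run.metadata` hunk above: the run's `wandb-metadata.json` is now fetched with a plain HTTP GET through the GraphQL transport's `requests` session, with a 5-second timeout and an explicit status check. A rough standalone equivalent (the helper name and the None-on-failure behavior are illustrative, not wandb's code):

```python
import json
from typing import Any, Dict, Optional

import requests


def fetch_run_metadata(metadata_url: str, session: Optional[requests.Session] = None) -> Optional[Dict[str, Any]]:
    """Download and parse a run's wandb-metadata.json; return None on any failure."""
    session = session or requests.Session()
    try:
        response = session.get(metadata_url, timeout=5)
        response.raise_for_status()
        return json.loads(response.content)
    except (requests.RequestException, ValueError):
        # file doesn't exist, can't be downloaded, or can't be parsed
        return None
```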
wandb/apis/public/sweeps.py
CHANGED
wandb/bin/gpu_stats.exe
CHANGED
Binary file
wandb/bin/wandb-core
CHANGED
Binary file
wandb/cli/cli.py
CHANGED
@@ -33,6 +33,7 @@ from wandb.apis.public import RunQueue
 from wandb.errors.links import url_registry
 from wandb.sdk.artifacts._validators import is_artifact_registry_project
 from wandb.sdk.artifacts.artifact_file_cache import get_artifact_file_cache
+from wandb.sdk.internal.internal_api import Api as SDKInternalApi
 from wandb.sdk.launch import utils as launch_utils
 from wandb.sdk.launch._launch_add import _launch_add
 from wandb.sdk.launch.errors import ExecutionError, LaunchError
@@ -2399,7 +2400,7 @@ def get(path, root, type):
         settings_entity = public_api.settings["entity"] or public_api.default_entity
         # Registry artifacts are under the org entity. Because we offer a shorthand and alias for this path,
         # we need to fetch the org entity to for the user behind the scenes.
-        entity =
+        entity = SDKInternalApi()._resolve_org_entity_name(
             entity=settings_entity, organization=organization
         )
         full_path = f"{entity}/{project}/{artifact_name}:{version}"
wandb/env.py
CHANGED
@@ -169,7 +169,7 @@ def error_reporting_enabled() -> bool:
 
 
 def core_debug(default: str | None = None) -> bool:
-    return _env_as_bool(CORE_DEBUG, default=default)
+    return _env_as_bool(CORE_DEBUG, default=default) or is_debug()
 
 
 def ssl_disabled() -> bool:
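With this change, `core_debug()` is true whenever either flag is on, so setting `WANDB_DEBUG` alone now also enables wandb-core debug logging. A simplified sketch of the combined check (the boolean parsing below is illustrative, not wandb's exact `_env_as_bool`):

```python
import os


def _env_as_bool(name: str, default: str = "false") -> bool:
    # Illustrative parser: treat common truthy strings as True.
    return os.environ.get(name, default).strip().lower() in ("1", "true", "yes", "on")


def is_debug() -> bool:
    return _env_as_bool("WANDB_DEBUG")


def core_debug() -> bool:
    # Mirrors the diff: core debug is enabled if either flag is set.
    return _env_as_bool("WANDB_CORE_DEBUG") or is_debug()
```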
wandb/errors/term.py
CHANGED
@@ -5,6 +5,8 @@ from __future__ import annotations
 import contextlib
 import logging
 import os
+import re
+import shutil
 import sys
 import threading
 from typing import TYPE_CHECKING, Iterator, Protocol
@@ -271,10 +273,67 @@ class DynamicBlock:
         The lock must be held.
         """
         if self._lines_to_print:
-
+            # Trim lines before printing. This is crucial because the \x1b[Am
+            # (cursor up) sequence used when clearing the text moves up by one
+            # visual line, and the terminal may be wrapping long lines onto
+            # multiple visual lines.
+            #
+            # There is no ANSI escape sequence that moves the cursor up by one
+            # "physical" line instead. Note that the user may resize their
+            # terminal.
+            term_width = _shutil_get_terminal_width()
+            click.echo(
+                "\n".join(
+                    _ansi_shorten(line, term_width)  #
+                    for line in self._lines_to_print
+                ),
+                file=sys.stderr,
+            )
+
             self._num_lines_printed += len(self._lines_to_print)
 
 
+def _shutil_get_terminal_width() -> int:
+    """Returns the width of the terminal.
+
+    Defined here for patching in tests.
+    """
+    columns, _ = shutil.get_terminal_size()
+    return columns
+
+
+_ANSI_RE = re.compile("\x1b\\[(K|.*?m)")
+
+
+def _ansi_shorten(text: str, width: int) -> str:
+    """Shorten text potentially containing ANSI sequences to fit a width."""
+    first_ansi = _ANSI_RE.search(text)
+
+    if not first_ansi:
+        return _raw_shorten(text, width)
+
+    if first_ansi.start() > width - 3:
+        return _raw_shorten(text[: first_ansi.start()], width)
+
+    return text[: first_ansi.end()] + _ansi_shorten(
+        text[first_ansi.end() :],
+        # Key part: the ANSI sequence doesn't reduce the remaining width.
+        width - first_ansi.start(),
+    )
+
+
+def _raw_shorten(text: str, width: int) -> str:
+    """Shorten text to fit a width, replacing the end with "...".
+
+    Unlike textwrap.shorten(), this does not drop whitespace or do anything
+    smart.
+    """
+    if len(text) <= width:
+        return text
+
+    return text[: width - 3] + "..."
+
+
 def _log(
     string="",
     newline=True,
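The new comments explain the motivation: the cursor-up escape used to clear a dynamic status block moves one visual row, so any line the terminal wraps would throw the count off. A minimal demonstration of the print-then-clear pattern with width trimming (not wandb's implementation; unlike `_ansi_shorten`, the plain truncation here would also cut into ANSI escape sequences):

```python
import shutil
import sys


def print_and_clear_block(lines):
    """Print a block of status lines to stderr, then erase it."""
    width = shutil.get_terminal_size().columns

    # Trim each line so it occupies exactly one visual row, even on narrow terminals.
    trimmed = [line if len(line) <= width else line[: width - 3] + "..." for line in lines]
    sys.stderr.write("\n".join(trimmed) + "\n")
    sys.stderr.flush()

    # One cursor-up per printed line, then erase from the cursor to the end of the screen.
    # Without the trimming above, a wrapped line would need more than one cursor-up.
    sys.stderr.write(f"\x1b[{len(trimmed)}A\x1b[0J")
    sys.stderr.flush()
```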
wandb/integration/keras/callbacks/tables_builder.py
CHANGED
@@ -145,7 +145,9 @@ class WandbEvalCallback(Callback, abc.ABC):
             for idx, data in enumerate(dataloader):
                 preds = model.predict(data)
                 self.pred_table.add_data(
-                    self.data_table_ref.data[idx][0],
+                    self.data_table_ref.data[idx][0],
+                    self.data_table_ref.data[idx][1],
+                    preds,
                 )
         ```
         This method is called `on_epoch_end` or equivalent hook.
wandb/integration/kfp/kfp_patch.py
CHANGED
@@ -208,23 +208,24 @@ def create_component_from_func(
             """Return sum of two arguments"""
             return a + b
 
+
         # add_op is a task factory function that creates a task object when given arguments
         add_op = create_component_from_func(
             func=add,
-            base_image=
-            output_component_file=
-            packages_to_install=[
+            base_image="python:3.7",  # Optional
+            output_component_file="add.component.yaml",  # Optional
+            packages_to_install=["pandas==0.24"],  # Optional
         )
 
         # The component spec can be accessed through the .component_spec attribute:
-        add_op.component_spec.save(
+        add_op.component_spec.save("add.component.yaml")
 
         # The component function can be called with arguments to create a task:
         add_task = add_op(1, 3)
 
         # The resulting task has output references, corresponding to the component outputs.
         # When the function only has a single anonymous return value, the output name is "Output":
-        sum_output_ref = add_task.outputs[
+        sum_output_ref = add_task.outputs["Output"]
 
         # These task output references can be passed to other component functions, constructing a computation graph:
         task2 = add_op(sum_output_ref, 5)
@@ -241,17 +242,21 @@ def create_component_from_func(
 
         from typing import NamedTuple
 
-
+
+        def add_multiply_two_numbers(a: float, b: float) -> NamedTuple(
+            "Outputs", [("sum", float), ("product", float)]
+        ):
             """Return sum and product of two arguments"""
             return (a + b, a * b)
 
+
         add_multiply_op = create_component_from_func(add_multiply_two_numbers)
 
         # The component function can be called with arguments to create a task:
         add_multiply_task = add_multiply_op(1, 3)
 
         # The resulting task has output references, corresponding to the component outputs:
-        sum_output_ref = add_multiply_task.outputs[
+        sum_output_ref = add_multiply_task.outputs["sum"]
 
         # These task output references can be passed to other component functions, constructing a computation graph:
         task2 = add_multiply_op(sum_output_ref, 5)
@@ -266,14 +271,19 @@ def create_component_from_func(
     Example of a component function declaring file input and output::
 
         def catboost_train_classifier(
-            training_data_path: InputPath(
-            trained_model_path: OutputPath(
-
-
-
-
-
-
+            training_data_path: InputPath("CSV"),  # Path to input data file of type "CSV"
+            trained_model_path: OutputPath(
+                "CatBoostModel"
+            ),  # Path to output data file of type "CatBoostModel"
+            number_of_trees: int = 100,  # Small output of type "Integer"
+        ) -> NamedTuple(
+            "Outputs",
+            [
+                ("Accuracy", float),  # Small output of type "Float"
+                ("Precision", float),  # Small output of type "Float"
+                ("JobUri", "URI"),  # Small output of type "URI"
+            ],
+        ):
             """Train CatBoost classification model"""
             ...
 
wandb/integration/lightning/fabric/logger.py
CHANGED
@@ -192,7 +192,9 @@ class WandbLogger(Logger):
         wandb_logger.log_image(key="samples", images=[img1, img2])
 
         # adding captions
-        wandb_logger.log_image(
+        wandb_logger.log_image(
+            key="samples", images=[img1, img2], caption=["tree", "person"]
+        )
 
         # using file path
         wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])
wandb/integration/tensorboard/monkeypatch.py
CHANGED
@@ -30,8 +30,9 @@ def patch(
 ) -> None:
     if len(wandb.patched["tensorboard"]) > 0:
         raise ValueError(
-            "Tensorboard already patched
-            "
+            "Tensorboard already patched. Call `wandb.tensorboard.unpatch()` first; "
+            "remove `sync_tensorboard=True` from `wandb.init`; "
+            "or only call `wandb.tensorboard.patch` once."
         )
 
     # TODO: Some older versions of tensorflow don't require tensorboard to be present.
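The clearer error message spells out how to avoid double patching. A hedged usage sketch of the two setups it points to (project name is arbitrary; use one option, not both):

```python
import wandb

# Option 1: let wandb patch TensorBoard for you.
run = wandb.init(project="tb-sync-example", sync_tensorboard=True)

# Option 2 (instead of Option 1): patch explicitly, exactly once, before
# creating any SummaryWriter, then init without sync_tensorboard.
# wandb.tensorboard.patch()
# run = wandb.init(project="tb-sync-example")
```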
wandb/jupyter.py
CHANGED
@@ -15,10 +15,9 @@ import wandb.util
 from wandb.sdk.lib import filesystem
 
 try:
-
+    import IPython
     from IPython.core.magic import Magics, line_cell_magic, magics_class
     from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
-    from IPython.display import display
 except ImportError:
     wandb.termwarn("ipython is not supported in python 2.7, upgrade to 3.x")
 
@@ -66,7 +65,7 @@ class IFrame:
 
     def maybe_display(self) -> bool:
         if not self.displayed and (self.path or wandb.run):
-            display(self)
+            IPython.display.display(self)
         return self.displayed
 
     def _repr_html_(self):
@@ -155,7 +154,7 @@ class WandBMagics(Magics):
             + cell
             + "\nwandb.jupyter.__IFrame = None"
         )
-        get_ipython().run_cell(cell)
+        IPython.get_ipython().run_cell(cell)
 
 
 def notebook_metadata_from_jupyter_servers_and_kernel_id():
@@ -343,7 +342,7 @@ class Notebook:
     def __init__(self, settings):
         self.outputs = {}
         self.settings = settings
-        self.shell = get_ipython()
+        self.shell = IPython.get_ipython()
 
     def save_display(self, exc_count, data_with_metadata):
         self.outputs[exc_count] = self.outputs.get(exc_count, [])
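The jupyter.py change imports the `IPython` package once inside the existing try/except and switches the bare `display(...)` and `get_ipython()` calls to the qualified `IPython.display.display(...)` and `IPython.get_ipython()`. A tiny sketch of that guarded-import plus qualified-access pattern (the helper function is illustrative):

```python
try:
    import IPython
except ImportError:  # IPython is optional outside notebook environments
    IPython = None


def current_shell():
    # With qualified access there is only one optional name to guard.
    return IPython.get_ipython() if IPython is not None else None
```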
wandb/plot/bar.py
CHANGED
@@ -38,10 +38,10 @@ def bar(
 
         # Generate random data for the table
         data = [
-            [
-            [
-            [
-            [
+            ["car", random.uniform(0, 1)],
+            ["bus", random.uniform(0, 1)],
+            ["road", random.uniform(0, 1)],
+            ["person", random.uniform(0, 1)],
         ]
 
         # Create a table with the data
@@ -49,7 +49,6 @@ def bar(
 
         # Initialize a W&B run and log the bar plot
         with wandb.init(project="bar_chart") as run:
-
             # Create a bar plot from the table
             bar_plot = wandb.plot.bar(
                 table=table,
@@ -59,7 +58,7 @@ def bar(
             )
 
             # Log the bar chart to W&B
-            run.log({
+            run.log({"bar_plot": bar_plot})
         ```
     """
     return plot_table(
wandb/plot/histogram.py
CHANGED
wandb/plot/line_series.py
CHANGED
@@ -53,9 +53,9 @@ def line_series(
 
         # Multiple y series to plot
         ys = [
-            [i for i in range(10)],
-            [i**2 for i in range(10)],
-            [i**3 for i in range(10)],
+            [i for i in range(10)],  # y = x
+            [i**2 for i in range(10)],  # y = x^2
+            [i**3 for i in range(10)],  # y = x^3
         ]
 
         # Generate and log the line series chart
wandb/plot/pr_curve.py
CHANGED
@@ -74,10 +74,10 @@ def pr_curve(
             [0.2, 0.8],  # Second sample (spam), and so on
             [0.1, 0.9],
             [0.8, 0.2],
-            [0.3, 0.7]
+            [0.3, 0.7],
         ]
 
-        labels = [
+        labels = ["not spam", "spam"]  # Optional class names for readability
 
         with wandb.init(project="spam-detection") as run:
             pr_curve = wandb.plot.pr_curve(
@@ -176,6 +176,10 @@ def pr_curve(
             "y": "precision",
             "class": "class",
         },
-        string_fields={
+        string_fields={
+            "title": title,
+            "x-axis-title": "Recall",
+            "y-axis-title": "Precision",
+        },
         split_table=split_table,
     )
wandb/plot/scatter.py
CHANGED
@@ -38,7 +38,8 @@ def scatter(
 
         # Simulate temperature variations at different altitudes over time
         data = [
-            [i, random.uniform(-10, 20) - 0.005 * i + 5 * math.sin(i / 50)]
+            [i, random.uniform(-10, 20) - 0.005 * i + 5 * math.sin(i / 50)]
+            for i in range(300)
         ]
 
         # Create W&B table with altitude (m) and temperature (°C) columns