wandb 0.19.0__py3-none-win32.whl → 0.19.1__py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -7
- wandb/__init__.pyi +211 -209
- wandb/apis/attrs.py +15 -4
- wandb/apis/public/api.py +8 -4
- wandb/apis/public/files.py +65 -12
- wandb/apis/public/runs.py +52 -7
- wandb/apis/public/sweeps.py +1 -1
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +2 -1
- wandb/env.py +1 -1
- wandb/errors/term.py +60 -1
- wandb/integration/keras/callbacks/tables_builder.py +3 -1
- wandb/integration/kfp/kfp_patch.py +25 -15
- wandb/integration/lightning/fabric/logger.py +3 -1
- wandb/integration/tensorboard/monkeypatch.py +3 -2
- wandb/jupyter.py +4 -5
- wandb/plot/bar.py +5 -6
- wandb/plot/histogram.py +1 -1
- wandb/plot/line_series.py +3 -3
- wandb/plot/pr_curve.py +7 -3
- wandb/plot/scatter.py +2 -1
- wandb/proto/v3/wandb_settings_pb2.py +25 -15
- wandb/proto/v4/wandb_settings_pb2.py +17 -15
- wandb/proto/v5/wandb_settings_pb2.py +17 -15
- wandb/sdk/artifacts/_validators.py +1 -3
- wandb/sdk/artifacts/artifact_manifest_entry.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +12 -2
- wandb/sdk/data_types/helper_types/image_mask.py +8 -2
- wandb/sdk/data_types/histogram.py +3 -3
- wandb/sdk/data_types/image.py +3 -1
- wandb/sdk/interface/interface.py +34 -5
- wandb/sdk/interface/interface_sock.py +2 -2
- wandb/sdk/internal/file_stream.py +4 -1
- wandb/sdk/internal/sender.py +4 -1
- wandb/sdk/internal/settings_static.py +17 -4
- wandb/sdk/launch/utils.py +1 -0
- wandb/sdk/lib/ipython.py +5 -27
- wandb/sdk/lib/printer.py +33 -20
- wandb/sdk/lib/progress.py +7 -1
- wandb/sdk/lib/sparkline.py +1 -2
- wandb/sdk/wandb_config.py +2 -2
- wandb/sdk/wandb_init.py +236 -243
- wandb/sdk/wandb_run.py +172 -231
- wandb/sdk/wandb_settings.py +104 -15
- {wandb-0.19.0.dist-info → wandb-0.19.1.dist-info}/METADATA +1 -1
- {wandb-0.19.0.dist-info → wandb-0.19.1.dist-info}/RECORD +50 -50
- {wandb-0.19.0.dist-info → wandb-0.19.1.dist-info}/WHEEL +0 -0
- {wandb-0.19.0.dist-info → wandb-0.19.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.0.dist-info → wandb-0.19.1.dist-info}/licenses/LICENSE +0 -0
wandb/apis/public/files.py
CHANGED
@@ -49,6 +49,7 @@ class Files(Paginator):
|
|
49
49
|
query RunFiles($project: String!, $entity: String!, $name: String!, $fileCursor: String,
|
50
50
|
$fileLimit: Int = 50, $fileNames: [String] = [], $upload: Boolean = false) {{
|
51
51
|
project(name: $project, entityName: $entity) {{
|
52
|
+
internalId
|
52
53
|
run(name: $name) {{
|
53
54
|
fileCount
|
54
55
|
...RunFilesFragment
|
@@ -98,7 +99,7 @@ class Files(Paginator):
|
|
98
99
|
|
99
100
|
def convert_objects(self):
|
100
101
|
return [
|
101
|
-
File(self.client, r["node"])
|
102
|
+
File(self.client, r["node"], self.run)
|
102
103
|
for r in self.last_response["project"]["run"]["files"]["edges"]
|
103
104
|
]
|
104
105
|
|
@@ -120,9 +121,11 @@ class File(Attrs):
|
|
120
121
|
path_uri (str): path to file in the bucket, currently only available for files stored in S3
|
121
122
|
"""
|
122
123
|
|
123
|
-
def __init__(self, client, attrs):
|
124
|
+
def __init__(self, client, attrs, run=None):
|
124
125
|
self.client = client
|
125
126
|
self._attrs = attrs
|
127
|
+
self.run = run
|
128
|
+
self.server_supports_delete_file_with_project_id: Optional[bool] = None
|
126
129
|
super().__init__(dict(attrs))
|
127
130
|
|
128
131
|
@property
|
@@ -189,18 +192,35 @@ class File(Attrs):
|
|
189
192
|
|
190
193
|
@normalize_exceptions
|
191
194
|
def delete(self):
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
files: $files
|
197
|
-
}) {
|
198
|
-
success
|
199
|
-
}
|
195
|
+
project_id_mutation_fragment = ""
|
196
|
+
project_id_variable_fragment = ""
|
197
|
+
variable_values = {
|
198
|
+
"files": [self.id],
|
200
199
|
}
|
201
|
-
|
200
|
+
|
201
|
+
# Add projectId to mutation and variables if the server supports it.
|
202
|
+
# Otherwise, do not include projectId in mutation for older server versions which do not support it.
|
203
|
+
if self._server_accepts_project_id_for_delete_file():
|
204
|
+
variable_values["projectId"] = self.run._project_internal_id
|
205
|
+
project_id_variable_fragment = ", $projectId: Int"
|
206
|
+
project_id_mutation_fragment = "projectId: $projectId"
|
207
|
+
|
208
|
+
mutation_string = """
|
209
|
+
mutation deleteFiles($files: [ID!]!{}) {{
|
210
|
+
deleteFiles(input: {{
|
211
|
+
files: $files
|
212
|
+
{}
|
213
|
+
}}) {{
|
214
|
+
success
|
215
|
+
}}
|
216
|
+
}}
|
217
|
+
""".format(project_id_variable_fragment, project_id_mutation_fragment)
|
218
|
+
mutation = gql(mutation_string)
|
219
|
+
|
220
|
+
self.client.execute(
|
221
|
+
mutation,
|
222
|
+
variable_values=variable_values,
|
202
223
|
)
|
203
|
-
self.client.execute(mutation, variable_values={"files": [self.id]})
|
204
224
|
|
205
225
|
def __repr__(self):
|
206
226
|
return "<File {} ({}) {}>".format(
|
@@ -208,3 +228,36 @@ class File(Attrs):
|
|
208
228
|
self.mimetype,
|
209
229
|
util.to_human_size(self.size, units=util.POW_2_BYTES),
|
210
230
|
)
|
231
|
+
|
232
|
+
@normalize_exceptions
|
233
|
+
def _server_accepts_project_id_for_delete_file(self) -> bool:
|
234
|
+
"""Returns True if the server supports deleting files with a projectId.
|
235
|
+
|
236
|
+
This check is done by utilizing GraphQL introspection in the avaiable fields on the DeleteFiles API.
|
237
|
+
"""
|
238
|
+
query_string = """
|
239
|
+
query ProbeDeleteFilesProjectIdInput {
|
240
|
+
DeleteFilesProjectIdInputType: __type(name:"DeleteFilesInput") {
|
241
|
+
inputFields{
|
242
|
+
name
|
243
|
+
}
|
244
|
+
}
|
245
|
+
}
|
246
|
+
"""
|
247
|
+
|
248
|
+
# Only perform the query once to avoid extra network calls
|
249
|
+
if self.server_supports_delete_file_with_project_id is None:
|
250
|
+
query = gql(query_string)
|
251
|
+
res = self.client.execute(query)
|
252
|
+
|
253
|
+
# If projectId is in the inputFields, the server supports deleting files with a projectId
|
254
|
+
self.server_supports_delete_file_with_project_id = "projectId" in [
|
255
|
+
x["name"]
|
256
|
+
for x in (
|
257
|
+
res.get("DeleteFilesProjectIdInputType", {}).get(
|
258
|
+
"inputFields", [{}]
|
259
|
+
)
|
260
|
+
)
|
261
|
+
]
|
262
|
+
|
263
|
+
return self.server_supports_delete_file_with_project_id
|
wandb/apis/public/runs.py
CHANGED
@@ -71,6 +71,7 @@ class Runs(Paginator):
|
|
71
71
|
"""
|
72
72
|
query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
|
73
73
|
project(name: $project, entityName: $entity) {{
|
74
|
+
internalId
|
74
75
|
runCount(filters: $filters)
|
75
76
|
readOnly
|
76
77
|
runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
|
@@ -103,6 +104,7 @@ class Runs(Paginator):
|
|
103
104
|
):
|
104
105
|
self.entity = entity
|
105
106
|
self.project = project
|
107
|
+
self._project_internal_id = None
|
106
108
|
self.filters = filters or {}
|
107
109
|
self.order = order
|
108
110
|
self._sweeps = {}
|
@@ -287,6 +289,7 @@ class Run(Attrs):
|
|
287
289
|
Calling update will persist any changes.
|
288
290
|
project (str): the project associated with the run
|
289
291
|
entity (str): the name of the entity associated with the run
|
292
|
+
project_internal_id (int): the internal id of the project
|
290
293
|
user (str): the name of the user who created the run
|
291
294
|
path (str): Unique identifier [entity]/[project]/[run_id]
|
292
295
|
notes (str): Notes about the run
|
@@ -328,6 +331,7 @@ class Run(Attrs):
|
|
328
331
|
self._summary = None
|
329
332
|
self._metadata: Optional[Dict[str, Any]] = None
|
330
333
|
self._state = _attrs.get("state", "not found")
|
334
|
+
self.server_provides_internal_id_field: Optional[bool] = None
|
331
335
|
|
332
336
|
self.load(force=not _attrs)
|
333
337
|
|
@@ -417,13 +421,18 @@ class Run(Attrs):
|
|
417
421
|
"""
|
418
422
|
query Run($project: String!, $entity: String!, $name: String!) {{
|
419
423
|
project(name: $project, entityName: $entity) {{
|
424
|
+
{}
|
420
425
|
run(name: $name) {{
|
421
426
|
...RunFragment
|
422
427
|
}}
|
423
428
|
}}
|
424
429
|
}}
|
425
430
|
{}
|
426
|
-
""".format(
|
431
|
+
""".format(
|
432
|
+
# Only query internalId if the server supports it
|
433
|
+
"internalId" if self._server_provides_internal_id_for_project() else "",
|
434
|
+
RUN_FRAGMENT,
|
435
|
+
)
|
427
436
|
)
|
428
437
|
if force or not self._attrs:
|
429
438
|
response = self._exec(query)
|
@@ -435,7 +444,7 @@ class Run(Attrs):
|
|
435
444
|
raise ValueError("Could not find run {}".format(self))
|
436
445
|
self._attrs = response["project"]["run"]
|
437
446
|
self._state = self._attrs["state"]
|
438
|
-
|
447
|
+
self._project_internal_id = response["project"].get("internalId", None)
|
439
448
|
if self._include_sweeps and self.sweep_name and not self.sweep:
|
440
449
|
# There may be a lot of runs. Don't bother pulling them all
|
441
450
|
# just for the sake of this one.
|
@@ -495,7 +504,6 @@ class Run(Attrs):
|
|
495
504
|
res = self._exec(query)
|
496
505
|
state = res["project"]["run"]["state"]
|
497
506
|
if state in ["finished", "crashed", "failed"]:
|
498
|
-
print(f"Run finished with status: {state}")
|
499
507
|
self._attrs["state"] = state
|
500
508
|
self._state = state
|
501
509
|
return
|
@@ -506,8 +514,8 @@ class Run(Attrs):
|
|
506
514
|
"""Persist changes to the run object to the wandb backend."""
|
507
515
|
mutation = gql(
|
508
516
|
"""
|
509
|
-
mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String) {{
|
510
|
-
upsertBucket(input: {{id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName}}) {{
|
517
|
+
mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String, $jobType: String) {{
|
518
|
+
upsertBucket(input: {{id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName, jobType: $jobType}}) {{
|
511
519
|
bucket {{
|
512
520
|
...RunFragment
|
513
521
|
}}
|
@@ -525,6 +533,7 @@ class Run(Attrs):
|
|
525
533
|
display_name=self.display_name,
|
526
534
|
config=self.json_config,
|
527
535
|
groupName=self.group,
|
536
|
+
jobType=self.job_type,
|
528
537
|
)
|
529
538
|
self.summary.update()
|
530
539
|
|
@@ -770,7 +779,9 @@ class Run(Attrs):
|
|
770
779
|
Example:
|
771
780
|
>>> import wandb
|
772
781
|
>>> import tempfile
|
773
|
-
>>> with tempfile.NamedTemporaryFile(
|
782
|
+
>>> with tempfile.NamedTemporaryFile(
|
783
|
+
... mode="w", delete=False, suffix=".txt"
|
784
|
+
... ) as tmp:
|
774
785
|
... tmp.write("This is a test artifact")
|
775
786
|
... tmp_path = tmp.name
|
776
787
|
>>> run = wandb.init(project="artifact-example")
|
@@ -893,6 +904,37 @@ class Run(Attrs):
|
|
893
904
|
)
|
894
905
|
return artifact
|
895
906
|
|
907
|
+
@normalize_exceptions
|
908
|
+
def _server_provides_internal_id_for_project(self) -> bool:
|
909
|
+
"""Returns True if the server allows us to query the internalId field for a project.
|
910
|
+
|
911
|
+
This check is done by utilizing GraphQL introspection in the avaiable fields on the Project type.
|
912
|
+
"""
|
913
|
+
query_string = """
|
914
|
+
query ProbeProjectInput {
|
915
|
+
ProjectType: __type(name:"Project") {
|
916
|
+
fields {
|
917
|
+
name
|
918
|
+
}
|
919
|
+
}
|
920
|
+
}
|
921
|
+
"""
|
922
|
+
|
923
|
+
# Only perform the query once to avoid extra network calls
|
924
|
+
if self.server_provides_internal_id_field is None:
|
925
|
+
query = gql(query_string)
|
926
|
+
res = self.client.execute(query)
|
927
|
+
print(
|
928
|
+
"internalId"
|
929
|
+
in [x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))]
|
930
|
+
)
|
931
|
+
|
932
|
+
self.server_provides_internal_id_field = "internalId" in [
|
933
|
+
x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))
|
934
|
+
]
|
935
|
+
|
936
|
+
return self.server_provides_internal_id_field
|
937
|
+
|
896
938
|
@property
|
897
939
|
def summary(self):
|
898
940
|
if self._summary is None:
|
@@ -921,7 +963,10 @@ class Run(Attrs):
|
|
921
963
|
if self._metadata is None:
|
922
964
|
try:
|
923
965
|
f = self.file("wandb-metadata.json")
|
924
|
-
|
966
|
+
session = self.client._client.transport.session
|
967
|
+
response = session.get(f.url, timeout=5)
|
968
|
+
response.raise_for_status()
|
969
|
+
contents = response.content
|
925
970
|
self._metadata = json_util.loads(contents)
|
926
971
|
except: # noqa: E722
|
927
972
|
# file doesn't exist, or can't be downloaded, or can't be parsed
|
wandb/apis/public/sweeps.py
CHANGED
wandb/bin/gpu_stats.exe
CHANGED
Binary file
|
wandb/bin/wandb-core
CHANGED
Binary file
|
wandb/cli/cli.py
CHANGED
@@ -33,6 +33,7 @@ from wandb.apis.public import RunQueue
|
|
33
33
|
from wandb.errors.links import url_registry
|
34
34
|
from wandb.sdk.artifacts._validators import is_artifact_registry_project
|
35
35
|
from wandb.sdk.artifacts.artifact_file_cache import get_artifact_file_cache
|
36
|
+
from wandb.sdk.internal.internal_api import Api as SDKInternalApi
|
36
37
|
from wandb.sdk.launch import utils as launch_utils
|
37
38
|
from wandb.sdk.launch._launch_add import _launch_add
|
38
39
|
from wandb.sdk.launch.errors import ExecutionError, LaunchError
|
@@ -2399,7 +2400,7 @@ def get(path, root, type):
|
|
2399
2400
|
settings_entity = public_api.settings["entity"] or public_api.default_entity
|
2400
2401
|
# Registry artifacts are under the org entity. Because we offer a shorthand and alias for this path,
|
2401
2402
|
# we need to fetch the org entity to for the user behind the scenes.
|
2402
|
-
entity =
|
2403
|
+
entity = SDKInternalApi()._resolve_org_entity_name(
|
2403
2404
|
entity=settings_entity, organization=organization
|
2404
2405
|
)
|
2405
2406
|
full_path = f"{entity}/{project}/{artifact_name}:{version}"
|
wandb/env.py
CHANGED
@@ -169,7 +169,7 @@ def error_reporting_enabled() -> bool:
|
|
169
169
|
|
170
170
|
|
171
171
|
def core_debug(default: str | None = None) -> bool:
|
172
|
-
return _env_as_bool(CORE_DEBUG, default=default)
|
172
|
+
return _env_as_bool(CORE_DEBUG, default=default) or is_debug()
|
173
173
|
|
174
174
|
|
175
175
|
def ssl_disabled() -> bool:
|
wandb/errors/term.py
CHANGED
@@ -5,6 +5,8 @@ from __future__ import annotations
|
|
5
5
|
import contextlib
|
6
6
|
import logging
|
7
7
|
import os
|
8
|
+
import re
|
9
|
+
import shutil
|
8
10
|
import sys
|
9
11
|
import threading
|
10
12
|
from typing import TYPE_CHECKING, Iterator, Protocol
|
@@ -271,10 +273,67 @@ class DynamicBlock:
|
|
271
273
|
The lock must be held.
|
272
274
|
"""
|
273
275
|
if self._lines_to_print:
|
274
|
-
|
276
|
+
# Trim lines before printing. This is crucial because the \x1b[Am
|
277
|
+
# (cursor up) sequence used when clearing the text moves up by one
|
278
|
+
# visual line, and the terminal may be wrapping long lines onto
|
279
|
+
# multiple visual lines.
|
280
|
+
#
|
281
|
+
# There is no ANSI escape sequence that moves the cursor up by one
|
282
|
+
# "physical" line instead. Note that the user may resize their
|
283
|
+
# terminal.
|
284
|
+
term_width = _shutil_get_terminal_width()
|
285
|
+
click.echo(
|
286
|
+
"\n".join(
|
287
|
+
_ansi_shorten(line, term_width) #
|
288
|
+
for line in self._lines_to_print
|
289
|
+
),
|
290
|
+
file=sys.stderr,
|
291
|
+
)
|
292
|
+
|
275
293
|
self._num_lines_printed += len(self._lines_to_print)
|
276
294
|
|
277
295
|
|
296
|
+
def _shutil_get_terminal_width() -> int:
|
297
|
+
"""Returns the width of the terminal.
|
298
|
+
|
299
|
+
Defined here for patching in tests.
|
300
|
+
"""
|
301
|
+
columns, _ = shutil.get_terminal_size()
|
302
|
+
return columns
|
303
|
+
|
304
|
+
|
305
|
+
_ANSI_RE = re.compile("\x1b\\[(K|.*?m)")
|
306
|
+
|
307
|
+
|
308
|
+
def _ansi_shorten(text: str, width: int) -> str:
|
309
|
+
"""Shorten text potentially containing ANSI sequences to fit a width."""
|
310
|
+
first_ansi = _ANSI_RE.search(text)
|
311
|
+
|
312
|
+
if not first_ansi:
|
313
|
+
return _raw_shorten(text, width)
|
314
|
+
|
315
|
+
if first_ansi.start() > width - 3:
|
316
|
+
return _raw_shorten(text[: first_ansi.start()], width)
|
317
|
+
|
318
|
+
return text[: first_ansi.end()] + _ansi_shorten(
|
319
|
+
text[first_ansi.end() :],
|
320
|
+
# Key part: the ANSI sequence doesn't reduce the remaining width.
|
321
|
+
width - first_ansi.start(),
|
322
|
+
)
|
323
|
+
|
324
|
+
|
325
|
+
def _raw_shorten(text: str, width: int) -> str:
|
326
|
+
"""Shorten text to fit a width, replacing the end with "...".
|
327
|
+
|
328
|
+
Unlike textwrap.shorten(), this does not drop whitespace or do anything
|
329
|
+
smart.
|
330
|
+
"""
|
331
|
+
if len(text) <= width:
|
332
|
+
return text
|
333
|
+
|
334
|
+
return text[: width - 3] + "..."
|
335
|
+
|
336
|
+
|
278
337
|
def _log(
|
279
338
|
string="",
|
280
339
|
newline=True,
|
@@ -145,7 +145,9 @@ class WandbEvalCallback(Callback, abc.ABC):
|
|
145
145
|
for idx, data in enumerate(dataloader):
|
146
146
|
preds = model.predict(data)
|
147
147
|
self.pred_table.add_data(
|
148
|
-
self.data_table_ref.data[idx][0],
|
148
|
+
self.data_table_ref.data[idx][0],
|
149
|
+
self.data_table_ref.data[idx][1],
|
150
|
+
preds,
|
149
151
|
)
|
150
152
|
```
|
151
153
|
This method is called `on_epoch_end` or equivalent hook.
|
@@ -208,23 +208,24 @@ def create_component_from_func(
|
|
208
208
|
"""Return sum of two arguments"""
|
209
209
|
return a + b
|
210
210
|
|
211
|
+
|
211
212
|
# add_op is a task factory function that creates a task object when given arguments
|
212
213
|
add_op = create_component_from_func(
|
213
214
|
func=add,
|
214
|
-
base_image=
|
215
|
-
output_component_file=
|
216
|
-
packages_to_install=[
|
215
|
+
base_image="python:3.7", # Optional
|
216
|
+
output_component_file="add.component.yaml", # Optional
|
217
|
+
packages_to_install=["pandas==0.24"], # Optional
|
217
218
|
)
|
218
219
|
|
219
220
|
# The component spec can be accessed through the .component_spec attribute:
|
220
|
-
add_op.component_spec.save(
|
221
|
+
add_op.component_spec.save("add.component.yaml")
|
221
222
|
|
222
223
|
# The component function can be called with arguments to create a task:
|
223
224
|
add_task = add_op(1, 3)
|
224
225
|
|
225
226
|
# The resulting task has output references, corresponding to the component outputs.
|
226
227
|
# When the function only has a single anonymous return value, the output name is "Output":
|
227
|
-
sum_output_ref = add_task.outputs[
|
228
|
+
sum_output_ref = add_task.outputs["Output"]
|
228
229
|
|
229
230
|
# These task output references can be passed to other component functions, constructing a computation graph:
|
230
231
|
task2 = add_op(sum_output_ref, 5)
|
@@ -241,17 +242,21 @@ def create_component_from_func(
|
|
241
242
|
|
242
243
|
from typing import NamedTuple
|
243
244
|
|
244
|
-
|
245
|
+
|
246
|
+
def add_multiply_two_numbers(a: float, b: float) -> NamedTuple(
|
247
|
+
"Outputs", [("sum", float), ("product", float)]
|
248
|
+
):
|
245
249
|
"""Return sum and product of two arguments"""
|
246
250
|
return (a + b, a * b)
|
247
251
|
|
252
|
+
|
248
253
|
add_multiply_op = create_component_from_func(add_multiply_two_numbers)
|
249
254
|
|
250
255
|
# The component function can be called with arguments to create a task:
|
251
256
|
add_multiply_task = add_multiply_op(1, 3)
|
252
257
|
|
253
258
|
# The resulting task has output references, corresponding to the component outputs:
|
254
|
-
sum_output_ref = add_multiply_task.outputs[
|
259
|
+
sum_output_ref = add_multiply_task.outputs["sum"]
|
255
260
|
|
256
261
|
# These task output references can be passed to other component functions, constructing a computation graph:
|
257
262
|
task2 = add_multiply_op(sum_output_ref, 5)
|
@@ -266,14 +271,19 @@ def create_component_from_func(
|
|
266
271
|
Example of a component function declaring file input and output::
|
267
272
|
|
268
273
|
def catboost_train_classifier(
|
269
|
-
training_data_path: InputPath(
|
270
|
-
trained_model_path: OutputPath(
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
274
|
+
training_data_path: InputPath("CSV"), # Path to input data file of type "CSV"
|
275
|
+
trained_model_path: OutputPath(
|
276
|
+
"CatBoostModel"
|
277
|
+
), # Path to output data file of type "CatBoostModel"
|
278
|
+
number_of_trees: int = 100, # Small output of type "Integer"
|
279
|
+
) -> NamedTuple(
|
280
|
+
"Outputs",
|
281
|
+
[
|
282
|
+
("Accuracy", float), # Small output of type "Float"
|
283
|
+
("Precision", float), # Small output of type "Float"
|
284
|
+
("JobUri", "URI"), # Small output of type "URI"
|
285
|
+
],
|
286
|
+
):
|
277
287
|
"""Train CatBoost classification model"""
|
278
288
|
...
|
279
289
|
|
@@ -192,7 +192,9 @@ class WandbLogger(Logger):
|
|
192
192
|
wandb_logger.log_image(key="samples", images=[img1, img2])
|
193
193
|
|
194
194
|
# adding captions
|
195
|
-
wandb_logger.log_image(
|
195
|
+
wandb_logger.log_image(
|
196
|
+
key="samples", images=[img1, img2], caption=["tree", "person"]
|
197
|
+
)
|
196
198
|
|
197
199
|
# using file path
|
198
200
|
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])
|
@@ -30,8 +30,9 @@ def patch(
|
|
30
30
|
) -> None:
|
31
31
|
if len(wandb.patched["tensorboard"]) > 0:
|
32
32
|
raise ValueError(
|
33
|
-
"Tensorboard already patched
|
34
|
-
"
|
33
|
+
"Tensorboard already patched. Call `wandb.tensorboard.unpatch()` first; "
|
34
|
+
"remove `sync_tensorboard=True` from `wandb.init`; "
|
35
|
+
"or only call `wandb.tensorboard.patch` once."
|
35
36
|
)
|
36
37
|
|
37
38
|
# TODO: Some older versions of tensorflow don't require tensorboard to be present.
|
wandb/jupyter.py
CHANGED
@@ -15,10 +15,9 @@ import wandb.util
|
|
15
15
|
from wandb.sdk.lib import filesystem
|
16
16
|
|
17
17
|
try:
|
18
|
-
|
18
|
+
import IPython
|
19
19
|
from IPython.core.magic import Magics, line_cell_magic, magics_class
|
20
20
|
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
|
21
|
-
from IPython.display import display
|
22
21
|
except ImportError:
|
23
22
|
wandb.termwarn("ipython is not supported in python 2.7, upgrade to 3.x")
|
24
23
|
|
@@ -66,7 +65,7 @@ class IFrame:
|
|
66
65
|
|
67
66
|
def maybe_display(self) -> bool:
|
68
67
|
if not self.displayed and (self.path or wandb.run):
|
69
|
-
display(self)
|
68
|
+
IPython.display.display(self)
|
70
69
|
return self.displayed
|
71
70
|
|
72
71
|
def _repr_html_(self):
|
@@ -155,7 +154,7 @@ class WandBMagics(Magics):
|
|
155
154
|
+ cell
|
156
155
|
+ "\nwandb.jupyter.__IFrame = None"
|
157
156
|
)
|
158
|
-
get_ipython().run_cell(cell)
|
157
|
+
IPython.get_ipython().run_cell(cell)
|
159
158
|
|
160
159
|
|
161
160
|
def notebook_metadata_from_jupyter_servers_and_kernel_id():
|
@@ -343,7 +342,7 @@ class Notebook:
|
|
343
342
|
def __init__(self, settings):
|
344
343
|
self.outputs = {}
|
345
344
|
self.settings = settings
|
346
|
-
self.shell = get_ipython()
|
345
|
+
self.shell = IPython.get_ipython()
|
347
346
|
|
348
347
|
def save_display(self, exc_count, data_with_metadata):
|
349
348
|
self.outputs[exc_count] = self.outputs.get(exc_count, [])
|
wandb/plot/bar.py
CHANGED
@@ -38,10 +38,10 @@ def bar(
|
|
38
38
|
|
39
39
|
# Generate random data for the table
|
40
40
|
data = [
|
41
|
-
[
|
42
|
-
[
|
43
|
-
[
|
44
|
-
[
|
41
|
+
["car", random.uniform(0, 1)],
|
42
|
+
["bus", random.uniform(0, 1)],
|
43
|
+
["road", random.uniform(0, 1)],
|
44
|
+
["person", random.uniform(0, 1)],
|
45
45
|
]
|
46
46
|
|
47
47
|
# Create a table with the data
|
@@ -49,7 +49,6 @@ def bar(
|
|
49
49
|
|
50
50
|
# Initialize a W&B run and log the bar plot
|
51
51
|
with wandb.init(project="bar_chart") as run:
|
52
|
-
|
53
52
|
# Create a bar plot from the table
|
54
53
|
bar_plot = wandb.plot.bar(
|
55
54
|
table=table,
|
@@ -59,7 +58,7 @@ def bar(
|
|
59
58
|
)
|
60
59
|
|
61
60
|
# Log the bar chart to W&B
|
62
|
-
run.log({
|
61
|
+
run.log({"bar_plot": bar_plot})
|
63
62
|
```
|
64
63
|
"""
|
65
64
|
return plot_table(
|
wandb/plot/histogram.py
CHANGED
wandb/plot/line_series.py
CHANGED
@@ -53,9 +53,9 @@ def line_series(
|
|
53
53
|
|
54
54
|
# Multiple y series to plot
|
55
55
|
ys = [
|
56
|
-
[i for i in range(10)],
|
57
|
-
[i**2 for i in range(10)],
|
58
|
-
[i**3 for i in range(10)],
|
56
|
+
[i for i in range(10)], # y = x
|
57
|
+
[i**2 for i in range(10)], # y = x^2
|
58
|
+
[i**3 for i in range(10)], # y = x^3
|
59
59
|
]
|
60
60
|
|
61
61
|
# Generate and log the line series chart
|
wandb/plot/pr_curve.py
CHANGED
@@ -74,10 +74,10 @@ def pr_curve(
|
|
74
74
|
[0.2, 0.8], # Second sample (spam), and so on
|
75
75
|
[0.1, 0.9],
|
76
76
|
[0.8, 0.2],
|
77
|
-
[0.3, 0.7]
|
77
|
+
[0.3, 0.7],
|
78
78
|
]
|
79
79
|
|
80
|
-
labels = [
|
80
|
+
labels = ["not spam", "spam"] # Optional class names for readability
|
81
81
|
|
82
82
|
with wandb.init(project="spam-detection") as run:
|
83
83
|
pr_curve = wandb.plot.pr_curve(
|
@@ -176,6 +176,10 @@ def pr_curve(
|
|
176
176
|
"y": "precision",
|
177
177
|
"class": "class",
|
178
178
|
},
|
179
|
-
string_fields={
|
179
|
+
string_fields={
|
180
|
+
"title": title,
|
181
|
+
"x-axis-title": "Recall",
|
182
|
+
"y-axis-title": "Precision",
|
183
|
+
},
|
180
184
|
split_table=split_table,
|
181
185
|
)
|
wandb/plot/scatter.py
CHANGED
@@ -38,7 +38,8 @@ def scatter(
|
|
38
38
|
|
39
39
|
# Simulate temperature variations at different altitudes over time
|
40
40
|
data = [
|
41
|
-
[i, random.uniform(-10, 20) - 0.005 * i + 5 * math.sin(i / 50)]
|
41
|
+
[i, random.uniform(-10, 20) - 0.005 * i + 5 * math.sin(i / 50)]
|
42
|
+
for i in range(300)
|
42
43
|
]
|
43
44
|
|
44
45
|
# Create W&B table with altitude (m) and temperature (°C) columns
|