wandb 0.18.0rc1__py3-none-win32.whl → 0.18.1__py3-none-win32.whl
- wandb/__init__.py +2 -2
- wandb/__init__.pyi +1 -1
- wandb/apis/public/runs.py +2 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +0 -2
- wandb/data_types.py +9 -2019
- wandb/env.py +0 -5
- wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
- wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
- wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
- wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
- wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
- wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
- wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
- wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
- wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
- wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
- wandb/{sklearn → integration/sklearn}/utils.py +8 -8
- wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
- wandb/proto/v3/wandb_base_pb2.py +2 -1
- wandb/proto/v3/wandb_internal_pb2.py +2 -1
- wandb/proto/v3/wandb_server_pb2.py +2 -1
- wandb/proto/v3/wandb_settings_pb2.py +2 -1
- wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
- wandb/proto/v4/wandb_base_pb2.py +2 -1
- wandb/proto/v4/wandb_internal_pb2.py +2 -1
- wandb/proto/v4/wandb_server_pb2.py +2 -1
- wandb/proto/v4/wandb_settings_pb2.py +2 -1
- wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
- wandb/proto/v5/wandb_base_pb2.py +3 -2
- wandb/proto/v5/wandb_internal_pb2.py +3 -2
- wandb/proto/v5/wandb_server_pb2.py +3 -2
- wandb/proto/v5/wandb_settings_pb2.py +3 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
- wandb/sdk/data_types/audio.py +165 -0
- wandb/sdk/data_types/bokeh.py +70 -0
- wandb/sdk/data_types/graph.py +405 -0
- wandb/sdk/data_types/image.py +156 -0
- wandb/sdk/data_types/table.py +1204 -0
- wandb/sdk/data_types/trace_tree.py +2 -2
- wandb/sdk/data_types/utils.py +49 -0
- wandb/sdk/service/service.py +2 -9
- wandb/sdk/service/streams.py +0 -7
- wandb/sdk/wandb_init.py +10 -3
- wandb/sdk/wandb_run.py +6 -152
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sklearn.py +35 -0
- wandb/util.py +6 -2
- {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/METADATA +5 -5
- {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/RECORD +61 -57
- wandb/sdk/lib/console.py +0 -39
- /wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
- /wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
- {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/WHEEL +0 -0
- {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/licenses/LICENSE +0 -0
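Note on the module moves above: the scikit-learn helpers now live under wandb/integration/sklearn, while a new 35-line wandb/sklearn.py appears at the top level. Its contents are not shown in this diff, but the natural shape for such a file is a thin compatibility shim that re-exports the public plotting helpers so existing `wandb.sklearn` imports keep working. The sketch below is a hypothetical illustration of that pattern, not the actual file.

```python
# Hypothetical sketch of a wandb/sklearn.py compatibility shim (the real
# 35-line file is not shown in this diff). The idea: re-export the public
# plotting helpers from their new home so existing `wandb.sklearn` call
# sites keep working after the move to wandb/integration/sklearn.
from wandb.integration.sklearn import (  # noqa: F401  # assumed re-exports
    plot_classifier,
    plot_clusterer,
    plot_regressor,
)
```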
wandb/sdk/data_types/audio.py
@@ -0,0 +1,165 @@
import hashlib
import os
from typing import Optional

from wandb import util
from wandb.sdk.lib import filesystem, runid

from . import _dtypes
from ._private import MEDIA_TMP
from .base_types.media import BatchableMedia


class Audio(BatchableMedia):
    """Wandb class for audio clips.

    Arguments:
        data_or_path: (string or numpy array) A path to an audio file
            or a numpy array of audio data.
        sample_rate: (int) Sample rate, required when passing in raw
            numpy array of audio data.
        caption: (string) Caption to display with audio.
    """

    _log_type = "audio-file"

    def __init__(self, data_or_path, sample_rate=None, caption=None):
        """Accept a path to an audio file or a numpy array of audio data."""
        super().__init__()
        self._duration = None
        self._sample_rate = sample_rate
        self._caption = caption

        if isinstance(data_or_path, str):
            if self.path_is_reference(data_or_path):
                self._path = data_or_path
                self._sha256 = hashlib.sha256(data_or_path.encode("utf-8")).hexdigest()
                self._is_tmp = False
            else:
                self._set_file(data_or_path, is_tmp=False)
        else:
            if sample_rate is None:
                raise ValueError(
                    'Argument "sample_rate" is required when instantiating wandb.Audio with raw data.'
                )

            soundfile = util.get_module(
                "soundfile",
                required='Raw audio requires the soundfile package. To get it, run "pip install soundfile"',
            )

            tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".wav")
            soundfile.write(tmp_path, data_or_path, sample_rate)
            self._duration = len(data_or_path) / float(sample_rate)

            self._set_file(tmp_path, is_tmp=True)

    @classmethod
    def get_media_subdir(cls):
        return os.path.join("media", "audio")

    @classmethod
    def from_json(cls, json_obj, source_artifact):
        return cls(
            source_artifact.get_entry(json_obj["path"]).download(),
            caption=json_obj["caption"],
        )

    def bind_to_run(
        self, run, key, step, id_=None, ignore_copy_err: Optional[bool] = None
    ):
        if self.path_is_reference(self._path):
            raise ValueError(
                "Audio media created by a reference to external storage cannot currently be added to a run"
            )

        return super().bind_to_run(run, key, step, id_, ignore_copy_err)

    def to_json(self, run):
        json_dict = super().to_json(run)
        json_dict.update(
            {
                "_type": self._log_type,
                "caption": self._caption,
            }
        )
        return json_dict

    @classmethod
    def seq_to_json(cls, seq, run, key, step):
        audio_list = list(seq)

        util.get_module(
            "soundfile",
            required="wandb.Audio requires the soundfile package. To get it, run: pip install soundfile",
        )
        base_path = os.path.join(run.dir, "media", "audio")
        filesystem.mkdir_exists_ok(base_path)
        meta = {
            "_type": "audio",
            "count": len(audio_list),
            "audio": [a.to_json(run) for a in audio_list],
        }
        sample_rates = cls.sample_rates(audio_list)
        if sample_rates:
            meta["sampleRates"] = sample_rates
        durations = cls.durations(audio_list)
        if durations:
            meta["durations"] = durations
        captions = cls.captions(audio_list)
        if captions:
            meta["captions"] = captions

        return meta

    @classmethod
    def durations(cls, audio_list):
        return [a._duration for a in audio_list]

    @classmethod
    def sample_rates(cls, audio_list):
        return [a._sample_rate for a in audio_list]

    @classmethod
    def captions(cls, audio_list):
        captions = [a._caption for a in audio_list]
        if all(c is None for c in captions):
            return False
        else:
            return ["" if c is None else c for c in captions]

    def resolve_ref(self):
        if self.path_is_reference(self._path):
            # this object was already created using a ref:
            return self._path
        source_artifact = self._artifact_source.artifact

        resolved_name = source_artifact._local_path_to_name(self._path)
        if resolved_name is not None:
            target_entry = source_artifact.manifest.get_entry_by_path(resolved_name)
            if target_entry is not None:
                return target_entry.ref

        return None

    def __eq__(self, other):
        if self.path_is_reference(self._path) or self.path_is_reference(other._path):
            # one or more of these objects is an unresolved reference -- we'll compare
            # their reference paths instead of their SHAs:
            return (
                self.resolve_ref() == other.resolve_ref()
                and self._caption == other._caption
            )

        return super().__eq__(other) and self._caption == other._caption

    def __ne__(self, other):
        return not self.__eq__(other)


class _AudioFileType(_dtypes.Type):
    name = "audio-file"
    types = [Audio]


_dtypes.TypeRegistry.add(_AudioFileType)
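For reference, a minimal usage sketch of the Audio type added above (it remains exposed as `wandb.Audio`). The project name is a placeholder, and raw arrays require the soundfile package, as the constructor notes.

```python
import numpy as np
import wandb

run = wandb.init(project="audio-demo")  # placeholder project name

# Raw waveform data needs an explicit sample_rate (see __init__ above);
# soundfile must be installed for this path.
sample_rate = 16000
t = np.arange(sample_rate) / sample_rate
tone = np.sin(2 * np.pi * 440 * t)
run.log({"tone": wandb.Audio(tone, sample_rate=sample_rate, caption="440 Hz sine")})

# A path to an existing audio file needs no sample_rate:
# run.log({"clip": wandb.Audio("clip.wav", caption="from file")})

run.finish()
```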
wandb/sdk/data_types/bokeh.py
@@ -0,0 +1,70 @@
import codecs
import json
import os

from wandb import util
from wandb.sdk.lib import runid

from . import _dtypes
from ._private import MEDIA_TMP
from .base_types.media import Media


class Bokeh(Media):
    """Wandb class for Bokeh plots.

    Arguments:
        val: Bokeh plot
    """

    _log_type = "bokeh-file"

    def __init__(self, data_or_path):
        super().__init__()
        bokeh = util.get_module("bokeh", required=True)
        if isinstance(data_or_path, str) and os.path.exists(data_or_path):
            with open(data_or_path) as file:
                b_json = json.load(file)
            self.b_obj = bokeh.document.Document.from_json(b_json)
            self._set_file(data_or_path, is_tmp=False, extension=".bokeh.json")
        elif isinstance(data_or_path, bokeh.model.Model):
            _data = bokeh.document.Document()
            _data.add_root(data_or_path)
            # serialize/deserialize pairing followed by sorting attributes ensures
            # that the file's sha's are equivalent in subsequent calls
            self.b_obj = bokeh.document.Document.from_json(_data.to_json())
            b_json = self.b_obj.to_json()
            if "references" in b_json["roots"]:
                b_json["roots"]["references"].sort(key=lambda x: x["id"])

            tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".bokeh.json")
            with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
                util.json_dump_safer(b_json, fp)
            self._set_file(tmp_path, is_tmp=True, extension=".bokeh.json")
        elif not isinstance(data_or_path, bokeh.document.Document):
            raise TypeError(
                "Bokeh constructor accepts Bokeh document/model or path to Bokeh json file"
            )

    def get_media_subdir(self):
        return os.path.join("media", "bokeh")

    def to_json(self, run):
        # TODO: (tss) this is getting redundant for all the media objects. We can probably
        # pull this into Media#to_json and remove this type override for all the media types.
        # There are only a few cases where the type is different between artifacts and runs.
        json_dict = super().to_json(run)
        json_dict["_type"] = self._log_type
        return json_dict

    @classmethod
    def from_json(cls, json_obj, source_artifact):
        return cls(source_artifact.get_entry(json_obj["path"]).download())


class _BokehFileType(_dtypes.Type):
    name = "bokeh-file"
    types = [Bokeh]


_dtypes.TypeRegistry.add(_BokehFileType)
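A minimal usage sketch for the Bokeh wrapper added above, assuming the bokeh package is installed and that logging it through `run.log` follows the usual Media path. The project name is a placeholder.

```python
import wandb
from bokeh.plotting import figure

from wandb.sdk.data_types.bokeh import Bokeh  # module added in this release

run = wandb.init(project="bokeh-demo")  # placeholder project name

# Any Bokeh Model (a figure is one) is accepted; it is serialized to a
# .bokeh.json file and uploaded with the run.
p = figure(title="example")
p.line([1, 2, 3], [4, 6, 5])
run.log({"my_plot": Bokeh(p)})

run.finish()
```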
wandb/sdk/data_types/graph.py
@@ -0,0 +1,405 @@
import codecs
import os
import pprint

from wandb import util
from wandb.sdk.data_types._private import MEDIA_TMP
from wandb.sdk.data_types.base_types.media import Media, _numpy_arrays_to_lists
from wandb.sdk.data_types.base_types.wb_value import WBValue
from wandb.sdk.lib import runid


def _nest(thing):
    # Use TensorFlow's nest function if available; otherwise just wrap the object in a list.

    tfutil = util.get_module("tensorflow.python.util")
    if tfutil:
        return tfutil.nest.flatten(thing)
    else:
        return [thing]


class Edge(WBValue):
    """Edge used in `Graph`."""

    def __init__(self, from_node, to_node):
        self._attributes = {}
        self.from_node = from_node
        self.to_node = to_node

    def __repr__(self):
        temp_attr = dict(self._attributes)
        del temp_attr["from_node"]
        del temp_attr["to_node"]
        temp_attr["from_id"] = self.from_node.id
        temp_attr["to_id"] = self.to_node.id
        return str(temp_attr)

    def to_json(self, run=None):
        return [self.from_node.id, self.to_node.id]

    @property
    def name(self):
        """Optional, not necessarily unique."""
        return self._attributes.get("name")

    @name.setter
    def name(self, val):
        self._attributes["name"] = val
        return val

    @property
    def from_node(self):
        return self._attributes.get("from_node")

    @from_node.setter
    def from_node(self, val):
        self._attributes["from_node"] = val
        return val

    @property
    def to_node(self):
        return self._attributes.get("to_node")

    @to_node.setter
    def to_node(self, val):
        self._attributes["to_node"] = val
        return val


class Node(WBValue):
    """Node used in `Graph`."""

    def __init__(
        self,
        id=None,
        name=None,
        class_name=None,
        size=None,
        parameters=None,
        output_shape=None,
        is_output=None,
        num_parameters=None,
        node=None,
    ):
        self._attributes = {"name": None}
        self.in_edges = {}  # indexed by source node id
        self.out_edges = {}  # indexed by dest node id
        # optional object (e.g. PyTorch Parameter or Module) that this Node represents
        self.obj = None

        if node is not None:
            self._attributes.update(node._attributes)
            del self._attributes["id"]
            self.obj = node.obj

        if id is not None:
            self.id = id
        if name is not None:
            self.name = name
        if class_name is not None:
            self.class_name = class_name
        if size is not None:
            self.size = size
        if parameters is not None:
            self.parameters = parameters
        if output_shape is not None:
            self.output_shape = output_shape
        if is_output is not None:
            self.is_output = is_output
        if num_parameters is not None:
            self.num_parameters = num_parameters

    def to_json(self, run=None):
        return self._attributes

    def __repr__(self):
        return repr(self._attributes)

    @property
    def id(self):
        """Must be unique in the graph."""
        return self._attributes.get("id")

    @id.setter
    def id(self, val):
        self._attributes["id"] = val
        return val

    @property
    def name(self):
        """Usually the type of layer or sublayer."""
        return self._attributes.get("name")

    @name.setter
    def name(self, val):
        self._attributes["name"] = val
        return val

    @property
    def class_name(self):
        """Usually the type of layer or sublayer."""
        return self._attributes.get("class_name")

    @class_name.setter
    def class_name(self, val):
        self._attributes["class_name"] = val
        return val

    @property
    def functions(self):
        return self._attributes.get("functions", [])

    @functions.setter
    def functions(self, val):
        self._attributes["functions"] = val
        return val

    @property
    def parameters(self):
        return self._attributes.get("parameters", [])

    @parameters.setter
    def parameters(self, val):
        self._attributes["parameters"] = val
        return val

    @property
    def size(self):
        return self._attributes.get("size")

    @size.setter
    def size(self, val):
        """Tensor size."""
        self._attributes["size"] = tuple(val)
        return val

    @property
    def output_shape(self):
        return self._attributes.get("output_shape")

    @output_shape.setter
    def output_shape(self, val):
        """Tensor output_shape."""
        self._attributes["output_shape"] = val
        return val

    @property
    def is_output(self):
        return self._attributes.get("is_output")

    @is_output.setter
    def is_output(self, val):
        """Tensor is_output."""
        self._attributes["is_output"] = val
        return val

    @property
    def num_parameters(self):
        return self._attributes.get("num_parameters")

    @num_parameters.setter
    def num_parameters(self, val):
        """Tensor num_parameters."""
        self._attributes["num_parameters"] = val
        return val

    @property
    def child_parameters(self):
        return self._attributes.get("child_parameters")

    @child_parameters.setter
    def child_parameters(self, val):
        """Tensor child_parameters."""
        self._attributes["child_parameters"] = val
        return val

    @property
    def is_constant(self):
        return self._attributes.get("is_constant")

    @is_constant.setter
    def is_constant(self, val):
        """Tensor is_constant."""
        self._attributes["is_constant"] = val
        return val

    @classmethod
    def from_keras(cls, layer):
        node = cls()

        try:
            output_shape = layer.output_shape
        except AttributeError:
            output_shape = ["multiple"]

        node.id = layer.name
        node.name = layer.name
        node.class_name = layer.__class__.__name__
        node.output_shape = output_shape
        node.num_parameters = layer.count_params()

        return node


class Graph(Media):
    """Wandb class for graphs.

    This class is typically used for saving and displaying neural net models. It
    represents the graph as an array of nodes and edges. The nodes can have
    labels that can be visualized by wandb.

    Examples:
        Import a keras model:
        ```
        Graph.from_keras(keras_model)
        ```

    Attributes:
        format (string): Format to help wandb display the graph nicely.
        nodes ([wandb.Node]): List of wandb.Nodes
        nodes_by_id (dict): dict of ids -> nodes
        edges ([(wandb.Node, wandb.Node)]): List of pairs of nodes interpreted as edges
        loaded (boolean): Flag to tell whether the graph is completely loaded
        root (wandb.Node): root node of the graph
    """

    _log_type = "graph-file"

    def __init__(self, format="keras"):
        super().__init__()
        # LB: TODO: I think we should factor criterion and criterion_passed out
        self.format = format
        self.nodes = []
        self.nodes_by_id = {}
        self.edges = []
        self.loaded = False
        self.criterion = None
        self.criterion_passed = False
        self.root = None  # optional root Node if applicable

    def _to_graph_json(self, run=None):
        # Needs to be its own function for tests
        return {
            "format": self.format,
            "nodes": [node.to_json() for node in self.nodes],
            "edges": [edge.to_json() for edge in self.edges],
        }

    def bind_to_run(self, *args, **kwargs):
        data = self._to_graph_json()
        tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".graph.json")
        data = _numpy_arrays_to_lists(data)
        with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
            util.json_dump_safer(data, fp)
        self._set_file(tmp_path, is_tmp=True, extension=".graph.json")
        if self.is_bound():
            return
        super().bind_to_run(*args, **kwargs)

    @classmethod
    def get_media_subdir(cls):
        return os.path.join("media", "graph")

    def to_json(self, run):
        json_dict = super().to_json(run)
        json_dict["_type"] = self._log_type
        return json_dict

    def __getitem__(self, nid):
        return self.nodes_by_id[nid]

    def pprint(self):
        for edge in self.edges:
            pprint.pprint(edge._attributes)
        for node in self.nodes:
            pprint.pprint(node._attributes)

    def add_node(self, node=None, **node_kwargs):
        if node is None:
            node = Node(**node_kwargs)
        elif node_kwargs:
            raise ValueError(
                f"Only pass one of either node ({node}) or other keyword arguments ({node_kwargs})"
            )
        self.nodes.append(node)
        self.nodes_by_id[node.id] = node

        return node

    def add_edge(self, from_node, to_node):
        edge = Edge(from_node, to_node)
        self.edges.append(edge)

        return edge

    @classmethod
    def from_keras(cls, model):
        # TODO: this method requires a refactor to work with Keras 3.
        graph = cls()
        # Shamelessly copied (then modified) from keras/keras/utils/layer_utils.py
        sequential_like = cls._is_sequential(model)

        relevant_nodes = None
        if not sequential_like:
            relevant_nodes = []
            for v in model._nodes_by_depth.values():
                relevant_nodes += v

        layers = model.layers
        for i in range(len(layers)):
            node = Node.from_keras(layers[i])
            if hasattr(layers[i], "_inbound_nodes"):
                for in_node in layers[i]._inbound_nodes:
                    if relevant_nodes and in_node not in relevant_nodes:
                        # node is not part of the current network
                        continue
                    for in_layer in _nest(in_node.inbound_layers):
                        inbound_keras_node = Node.from_keras(in_layer)

                        if inbound_keras_node.id not in graph.nodes_by_id:
                            graph.add_node(inbound_keras_node)
                        inbound_node = graph.nodes_by_id[inbound_keras_node.id]

                        graph.add_edge(inbound_node, node)
            graph.add_node(node)
        return graph

    @classmethod
    def _is_sequential(cls, model):
        sequential_like = True

        if (
            model.__class__.__name__ != "Sequential"
            and hasattr(model, "_is_graph_network")
            and model._is_graph_network
        ):
            nodes_by_depth = model._nodes_by_depth.values()
            nodes = []
            for v in nodes_by_depth:
                # TensorFlow 2 doesn't ensure inbound is always a list
                inbound = v[0].inbound_layers
                if not hasattr(inbound, "__len__"):
                    inbound = [inbound]
                if (len(v) > 1) or (len(v) == 1 and len(inbound) > 1):
                    # if the model has multiple nodes
                    # or if the nodes have multiple inbound_layers
                    # the model is no longer sequential
                    sequential_like = False
                    break
                nodes += v
            if sequential_like:
                # search for shared layers
                for layer in model.layers:
                    flag = False
                    if hasattr(layer, "_inbound_nodes"):
                        for node in layer._inbound_nodes:
                            if node in nodes:
                                if flag:
                                    sequential_like = False
                                    break
                                else:
                                    flag = True
                    if not sequential_like:
                        break
        return sequential_like
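Finally, a usage sketch for Graph.from_keras. Per the TODO in the method above, this path assumes a legacy tf.keras / Keras 2 style model whose private graph attributes (_inbound_nodes, _nodes_by_depth) still exist; the model below is only illustrative.

```python
import wandb
from tensorflow import keras  # assumes TF with legacy Keras 2 semantics

from wandb.sdk.data_types.graph import Graph  # module added in this release

run = wandb.init(project="graph-demo")  # placeholder project name

# A tiny Sequential model; from_keras walks model.layers and their inbound
# nodes to build wandb Node/Edge objects.
model = keras.Sequential(
    [
        keras.layers.Dense(16, activation="relu", input_shape=(8,)),
        keras.layers.Dense(1),
    ]
)

graph = Graph.from_keras(model)
run.log({"model_graph": graph})

run.finish()
```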