wandb 0.18.0__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. wandb/__init__.py +2 -2
  2. wandb/__init__.pyi +1 -1
  3. wandb/apis/public/runs.py +2 -0
  4. wandb/bin/nvidia_gpu_stats +0 -0
  5. wandb/cli/cli.py +0 -2
  6. wandb/data_types.py +9 -2019
  7. wandb/env.py +0 -5
  8. wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
  9. wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
  10. wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
  11. wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
  12. wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
  13. wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
  14. wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
  15. wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
  16. wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
  17. wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
  18. wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
  19. wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
  20. wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
  21. wandb/{sklearn → integration/sklearn}/utils.py +8 -8
  22. wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
  23. wandb/proto/v3/wandb_base_pb2.py +2 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +2 -1
  25. wandb/proto/v3/wandb_server_pb2.py +2 -1
  26. wandb/proto/v3/wandb_settings_pb2.py +2 -1
  27. wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
  28. wandb/proto/v4/wandb_base_pb2.py +2 -1
  29. wandb/proto/v4/wandb_internal_pb2.py +2 -1
  30. wandb/proto/v4/wandb_server_pb2.py +2 -1
  31. wandb/proto/v4/wandb_settings_pb2.py +2 -1
  32. wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
  33. wandb/proto/v5/wandb_base_pb2.py +3 -2
  34. wandb/proto/v5/wandb_internal_pb2.py +3 -2
  35. wandb/proto/v5/wandb_server_pb2.py +3 -2
  36. wandb/proto/v5/wandb_settings_pb2.py +3 -2
  37. wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
  38. wandb/sdk/data_types/audio.py +165 -0
  39. wandb/sdk/data_types/bokeh.py +70 -0
  40. wandb/sdk/data_types/graph.py +405 -0
  41. wandb/sdk/data_types/image.py +156 -0
  42. wandb/sdk/data_types/table.py +1204 -0
  43. wandb/sdk/data_types/trace_tree.py +2 -2
  44. wandb/sdk/data_types/utils.py +49 -0
  45. wandb/sdk/service/service.py +2 -9
  46. wandb/sdk/service/streams.py +0 -7
  47. wandb/sdk/wandb_init.py +10 -3
  48. wandb/sdk/wandb_run.py +6 -152
  49. wandb/sdk/wandb_setup.py +1 -1
  50. wandb/sklearn.py +35 -0
  51. wandb/util.py +6 -2
  52. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/METADATA +1 -1
  53. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/RECORD +61 -57
  54. wandb/sdk/lib/console.py +0 -39
  55. wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
  56. wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
  57. wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
  58. wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
  59. wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
  60. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/WHEEL +0 -0
  61. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/entry_points.txt +0 -0
  62. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,165 @@
1
+ import hashlib
2
+ import os
3
+ from typing import Optional
4
+
5
+ from wandb import util
6
+ from wandb.sdk.lib import filesystem, runid
7
+
8
+ from . import _dtypes
9
+ from ._private import MEDIA_TMP
10
+ from .base_types.media import BatchableMedia
11
+
12
+
13
class Audio(BatchableMedia):
    """Wandb class for audio clips.

    Arguments:
        data_or_path: (string or numpy array) A path to an audio file
            or a numpy array of audio data.
        sample_rate: (int) Sample rate, required when passing in raw
            numpy array of audio data.
        caption: (string) Caption to display with audio.
    """

    _log_type = "audio-file"

    def __init__(self, data_or_path, sample_rate=None, caption=None):
        """Accept a path to an audio file or a numpy array of audio data."""
        super().__init__()
        # Only computed when the clip is built from raw sample data.
        self._duration = None
        self._sample_rate = sample_rate
        self._caption = caption

        if isinstance(data_or_path, str):
            if self.path_is_reference(data_or_path):
                # External reference: there is no local file to hash, so the
                # reference string itself is used as the content digest.
                self._path = data_or_path
                self._sha256 = hashlib.sha256(data_or_path.encode("utf-8")).hexdigest()
                self._is_tmp = False
            else:
                self._set_file(data_or_path, is_tmp=False)
        else:
            if sample_rate is None:
                raise ValueError(
                    'Argument "sample_rate" is required when instantiating wandb.Audio with raw data.'
                )

            soundfile = util.get_module(
                "soundfile",
                required='Raw audio requires the soundfile package. To get it, run "pip install soundfile"',
            )

            # Raw data is written to a temporary .wav file which becomes the media file.
            tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".wav")
            soundfile.write(tmp_path, data_or_path, sample_rate)
            # len() is taken over the first axis of the raw data.
            self._duration = len(data_or_path) / float(sample_rate)

            self._set_file(tmp_path, is_tmp=True)

    @classmethod
    def get_media_subdir(cls):
        return os.path.join("media", "audio")

    @classmethod
    def from_json(cls, json_obj, source_artifact):
        """Rebuild an Audio from its JSON form by downloading the artifact entry."""
        return cls(
            source_artifact.get_entry(json_obj["path"]).download(),
            caption=json_obj["caption"],
        )

    def bind_to_run(
        self, run, key, step, id_=None, ignore_copy_err: Optional[bool] = None
    ):
        """Bind to a run; reference-backed audio cannot be bound.

        Raises:
            ValueError: if this audio was created from an external reference.
        """
        if self.path_is_reference(self._path):
            raise ValueError(
                "Audio media created by a reference to external storage cannot currently be added to a run"
            )

        return super().bind_to_run(run, key, step, id_, ignore_copy_err)

    def to_json(self, run):
        json_dict = super().to_json(run)
        json_dict.update(
            {
                "_type": self._log_type,
                "caption": self._caption,
            }
        )
        return json_dict

    @classmethod
    def seq_to_json(cls, seq, run, key, step):
        """Serialize a sequence of Audio objects into a single metadata dict."""
        audio_list = list(seq)

        util.get_module(
            "soundfile",
            required="wandb.Audio requires the soundfile package. To get it, run: pip install soundfile",
        )
        base_path = os.path.join(run.dir, "media", "audio")
        filesystem.mkdir_exists_ok(base_path)
        meta = {
            "_type": "audio",
            "count": len(audio_list),
            "audio": [a.to_json(run) for a in audio_list],
        }
        sample_rates = cls.sample_rates(audio_list)
        if sample_rates:
            meta["sampleRates"] = sample_rates
        durations = cls.durations(audio_list)
        if durations:
            meta["durations"] = durations
        captions = cls.captions(audio_list)
        if captions:
            meta["captions"] = captions

        return meta

    @classmethod
    def durations(cls, audio_list):
        return [a._duration for a in audio_list]

    @classmethod
    def sample_rates(cls, audio_list):
        return [a._sample_rate for a in audio_list]

    @classmethod
    def captions(cls, audio_list):
        """Return captions with None mapped to "", or False if no clip has a caption."""
        captions = [a._caption for a in audio_list]
        if all(c is None for c in captions):
            return False
        else:
            return ["" if c is None else c for c in captions]

    def resolve_ref(self):
        """Return the external reference this audio points to, if any, else None."""
        if self.path_is_reference(self._path):
            # this object was already created using a ref:
            return self._path
        source_artifact = self._artifact_source.artifact

        resolved_name = source_artifact._local_path_to_name(self._path)
        if resolved_name is not None:
            target_entry = source_artifact.manifest.get_entry_by_path(resolved_name)
            if target_entry is not None:
                return target_entry.ref

        return None

    def __eq__(self, other):
        # Non-Audio operands have no ``_path``; accessing it used to raise
        # AttributeError. Let Python fall back to its default comparison instead.
        if not isinstance(other, Audio):
            return NotImplemented
        if self.path_is_reference(self._path) or self.path_is_reference(other._path):
            # one or more of these objects is an unresolved reference -- we'll compare
            # their reference paths instead of their SHAs:
            return (
                self.resolve_ref() == other.resolve_ref()
                and self._caption == other._caption
            )

        return super().__eq__(other) and self._caption == other._caption

    def __ne__(self, other):
        result = self.__eq__(other)
        # Negating NotImplemented would be wrong (and is deprecated); propagate it.
        if result is NotImplemented:
            return result
        return not result
158
+
159
+
160
class _AudioFileType(_dtypes.Type):
    """dtype descriptor for `Audio` objects, registered under the name "audio-file"."""

    types = [Audio]
    name = "audio-file"


# Register the descriptor with the dtype registry when this module is imported.
_dtypes.TypeRegistry.add(_AudioFileType)
@@ -0,0 +1,70 @@
1
+ import codecs
2
+ import json
3
+ import os
4
+
5
+ from wandb import util
6
+ from wandb.sdk.lib import runid
7
+
8
+ from . import _dtypes
9
+ from ._private import MEDIA_TMP
10
+ from .base_types.media import Media
11
+
12
+
13
class Bokeh(Media):
    """Wandb class for Bokeh plots.

    Arguments:
        data_or_path: A Bokeh document, a Bokeh model, or a path to a
            Bokeh JSON file.

    Raises:
        TypeError: if ``data_or_path`` is none of the accepted inputs.
    """

    _log_type = "bokeh-file"

    def __init__(self, data_or_path):
        super().__init__()
        bokeh = util.get_module("bokeh", required=True)
        if isinstance(data_or_path, str) and os.path.exists(data_or_path):
            # JSON documents are UTF-8; don't depend on the platform default encoding.
            with open(data_or_path, encoding="utf-8") as file:
                b_json = json.load(file)
            self.b_obj = bokeh.document.Document.from_json(b_json)
            self._set_file(data_or_path, is_tmp=False, extension=".bokeh.json")
        elif isinstance(data_or_path, (bokeh.model.Model, bokeh.document.Document)):
            # A Document was previously accepted silently but never serialized,
            # leaving the media object without a file. Treat it like a model,
            # just without wrapping it in a fresh Document.
            if isinstance(data_or_path, bokeh.document.Document):
                document = data_or_path
            else:
                document = bokeh.document.Document()
                document.add_root(data_or_path)
            # serialize/deserialize pairing followed by sorting attributes ensures
            # that the file's sha's are equivalent in subsequent calls
            self.b_obj = bokeh.document.Document.from_json(document.to_json())
            b_json = self.b_obj.to_json()
            if "references" in b_json["roots"]:
                b_json["roots"]["references"].sort(key=lambda x: x["id"])

            tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".bokeh.json")
            with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
                util.json_dump_safer(b_json, fp)
            self._set_file(tmp_path, is_tmp=True, extension=".bokeh.json")
        else:
            raise TypeError(
                "Bokeh constructor accepts Bokeh document/model or path to Bokeh json file"
            )

    @classmethod
    def get_media_subdir(cls):
        # classmethod for consistency with the other media types; still callable
        # on instances, so existing callers keep working.
        return os.path.join("media", "bokeh")

    def to_json(self, run):
        # TODO: (tss) this is getting redundant for all the media objects. We can probably
        # pull this into Media#to_json and remove this type override for all the media types.
        # There are only a few cases where the type is different between artifacts and runs.
        json_dict = super().to_json(run)
        json_dict["_type"] = self._log_type
        return json_dict

    @classmethod
    def from_json(cls, json_obj, source_artifact):
        """Rebuild a Bokeh object by downloading its JSON file from the artifact."""
        return cls(source_artifact.get_entry(json_obj["path"]).download())
63
+
64
+
65
class _BokehFileType(_dtypes.Type):
    """dtype descriptor for `Bokeh` objects, registered under the name "bokeh-file"."""

    types = [Bokeh]
    name = "bokeh-file"


# Register the descriptor with the dtype registry when this module is imported.
_dtypes.TypeRegistry.add(_BokehFileType)
@@ -0,0 +1,405 @@
1
+ import codecs
2
+ import os
3
+ import pprint
4
+
5
+ from wandb import util
6
+ from wandb.sdk.data_types._private import MEDIA_TMP
7
+ from wandb.sdk.data_types.base_types.media import Media, _numpy_arrays_to_lists
8
+ from wandb.sdk.data_types.base_types.wb_value import WBValue
9
+ from wandb.sdk.lib import runid
10
+
11
+
12
def _nest(thing):
    """Flatten ``thing`` via TensorFlow's ``nest.flatten`` when available.

    When ``tensorflow.python.util`` cannot be loaded, ``thing`` is simply
    wrapped in a one-element list.
    """
    tf_util_module = util.get_module("tensorflow.python.util")
    return tf_util_module.nest.flatten(thing) if tf_util_module else [thing]
20
+
21
+
22
class Edge(WBValue):
    """Directed edge between two nodes, used in `Graph`.

    All fields live in ``self._attributes``; the properties below are thin
    views onto that dict.
    """

    def __init__(self, from_node, to_node):
        self._attributes = {}
        self.from_node = from_node
        self.to_node = to_node

    def __repr__(self):
        # Show endpoint ids rather than the full endpoint node objects.
        shown = dict(self._attributes)
        for endpoint_key in ("from_node", "to_node"):
            del shown[endpoint_key]
        shown["from_id"] = self.from_node.id
        shown["to_id"] = self.to_node.id
        return str(shown)

    def to_json(self, run=None):
        """Serialize as a ``[from_id, to_id]`` pair."""
        source_id = self.from_node.id
        target_id = self.to_node.id
        return [source_id, target_id]

    @property
    def name(self):
        """Optional, not necessarily unique."""
        return self._attributes.get("name")

    @name.setter
    def name(self, value):
        self._attributes["name"] = value
        return value

    @property
    def from_node(self):
        """Source node of the edge."""
        return self._attributes.get("from_node")

    @from_node.setter
    def from_node(self, value):
        self._attributes["from_node"] = value
        return value

    @property
    def to_node(self):
        """Destination node of the edge."""
        return self._attributes.get("to_node")

    @to_node.setter
    def to_node(self, value):
        self._attributes["to_node"] = value
        return value
68
+
69
+
70
class Node(WBValue):
    """Node used in `Graph`.

    All node fields are stored in ``self._attributes``, which is also what
    ``to_json`` returns.
    """

    def __init__(
        self,
        id=None,
        name=None,
        class_name=None,
        size=None,
        parameters=None,
        output_shape=None,
        is_output=None,
        num_parameters=None,
        node=None,
    ):
        """Create a node, optionally copying attributes from another node.

        Args:
            id: Unique id within the graph.
            name: Display name.
            class_name: Class name of the represented object.
            size: Tensor size (stored as a tuple).
            parameters: Parameter list.
            output_shape: Output shape.
            is_output: Whether the node is an output.
            num_parameters: Parameter count.
            node: Optional Node whose attributes (except ``id``) and ``obj``
                are copied before the explicit arguments above are applied.
        """
        self._attributes = {"name": None}
        self.in_edges = {}  # indexed by source node id
        self.out_edges = {}  # indexed by dest node id
        # optional object (e.g. PyTorch Parameter or Module) that this Node represents
        self.obj = None

        if node is not None:
            self._attributes.update(node._attributes)
            # Ids must be unique in the graph, so the copy never inherits one.
            # pop() instead of del: the source node may never have had an id,
            # which previously raised KeyError.
            self._attributes.pop("id", None)
            self.obj = node.obj

        if id is not None:
            self.id = id
        if name is not None:
            self.name = name
        if class_name is not None:
            self.class_name = class_name
        if size is not None:
            self.size = size
        if parameters is not None:
            self.parameters = parameters
        if output_shape is not None:
            self.output_shape = output_shape
        if is_output is not None:
            self.is_output = is_output
        if num_parameters is not None:
            self.num_parameters = num_parameters

    def to_json(self, run=None):
        return self._attributes

    def __repr__(self):
        return repr(self._attributes)

    # Property docstrings belong on the getter: ``property`` ignores a
    # docstring placed on the setter function.

    @property
    def id(self):
        """Must be unique in the graph."""
        return self._attributes.get("id")

    @id.setter
    def id(self, val):
        self._attributes["id"] = val

    @property
    def name(self):
        """Usually the type of layer or sublayer."""
        return self._attributes.get("name")

    @name.setter
    def name(self, val):
        self._attributes["name"] = val

    @property
    def class_name(self):
        """Class name of the represented object (e.g. the Keras layer class)."""
        return self._attributes.get("class_name")

    @class_name.setter
    def class_name(self, val):
        self._attributes["class_name"] = val

    @property
    def functions(self):
        """Function list; defaults to ``[]`` when unset."""
        return self._attributes.get("functions", [])

    @functions.setter
    def functions(self, val):
        self._attributes["functions"] = val

    @property
    def parameters(self):
        """Parameter list; defaults to ``[]`` when unset."""
        return self._attributes.get("parameters", [])

    @parameters.setter
    def parameters(self, val):
        self._attributes["parameters"] = val

    @property
    def size(self):
        """Tensor size, stored as a tuple."""
        return self._attributes.get("size")

    @size.setter
    def size(self, val):
        self._attributes["size"] = tuple(val)

    @property
    def output_shape(self):
        """Tensor output_shape."""
        return self._attributes.get("output_shape")

    @output_shape.setter
    def output_shape(self, val):
        self._attributes["output_shape"] = val

    @property
    def is_output(self):
        """Tensor is_output."""
        return self._attributes.get("is_output")

    @is_output.setter
    def is_output(self, val):
        self._attributes["is_output"] = val

    @property
    def num_parameters(self):
        """Tensor num_parameters."""
        return self._attributes.get("num_parameters")

    @num_parameters.setter
    def num_parameters(self, val):
        self._attributes["num_parameters"] = val

    @property
    def child_parameters(self):
        """Tensor child_parameters."""
        return self._attributes.get("child_parameters")

    @child_parameters.setter
    def child_parameters(self, val):
        self._attributes["child_parameters"] = val

    @property
    def is_constant(self):
        """Tensor is_constant."""
        return self._attributes.get("is_constant")

    @is_constant.setter
    def is_constant(self, val):
        self._attributes["is_constant"] = val

    @classmethod
    def from_keras(cls, layer):
        """Build a Node from a Keras layer (id and name are the layer's name)."""
        node = cls()

        try:
            output_shape = layer.output_shape
        except AttributeError:
            output_shape = ["multiple"]

        node.id = layer.name
        node.name = layer.name
        node.class_name = layer.__class__.__name__
        node.output_shape = output_shape
        node.num_parameters = layer.count_params()

        return node
243
+
244
+
245
class Graph(Media):
    """Wandb class for graphs.

    This class is typically used for saving and displaying neural net models. It
    represents the graph as an array of nodes and edges. The nodes can have
    labels that can be visualized by wandb.

    Examples:
        Import a keras model:
        ```
        Graph.from_keras(keras_model)
        ```

    Attributes:
        format (string): Format to help wandb display the graph nicely.
        nodes ([wandb.Node]): List of wandb.Nodes
        nodes_by_id (dict): dict of ids -> nodes
        edges ([Edge]): List of `Edge` objects, each connecting two nodes
        loaded (boolean): Flag to tell whether the graph is completely loaded
        root (wandb.Node): root node of the graph
    """

    _log_type = "graph-file"

    def __init__(self, format="keras"):
        super().__init__()
        # LB: TODO: I think we should factor criterion and criterion_passed out
        self.format = format
        self.nodes = []
        self.nodes_by_id = {}
        self.edges = []
        self.loaded = False
        self.criterion = None
        self.criterion_passed = False
        self.root = None  # optional root Node if applicable

    def _to_graph_json(self, run=None):
        """Return a dict with the graph's format, node JSON and edge JSON."""
        # Needs to be its own function for tests
        return {
            "format": self.format,
            "nodes": [node.to_json() for node in self.nodes],
            "edges": [edge.to_json() for edge in self.edges],
        }

    def bind_to_run(self, *args, **kwargs):
        """Write the graph to a temporary JSON file and bind it to a run.

        Does nothing if the graph is already bound.
        """
        if self.is_bound():
            # Check first: the old order wrote an unused temp file and called
            # _set_file on an already-bound object, replacing the file it had
            # been bound with.
            return
        data = _numpy_arrays_to_lists(self._to_graph_json())
        tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".graph.json")
        with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
            util.json_dump_safer(data, fp)
        self._set_file(tmp_path, is_tmp=True, extension=".graph.json")
        super().bind_to_run(*args, **kwargs)

    @classmethod
    def get_media_subdir(cls):
        return os.path.join("media", "graph")

    def to_json(self, run):
        json_dict = super().to_json(run)
        json_dict["_type"] = self._log_type
        return json_dict

    def __getitem__(self, nid):
        return self.nodes_by_id[nid]

    def pprint(self):
        """Pretty-print every edge's and node's attribute dict."""
        # Edge and Node keep their data in ``_attributes``; there is no public
        # ``attributes`` attribute (the old code raised AttributeError here).
        for edge in self.edges:
            pprint.pprint(edge._attributes)
        for node in self.nodes:
            pprint.pprint(node._attributes)

    def add_node(self, node=None, **node_kwargs):
        """Add ``node``, or a new Node built from ``node_kwargs``, and return it.

        Raises:
            ValueError: if both ``node`` and keyword arguments are given.
        """
        if node is None:
            node = Node(**node_kwargs)
        elif node_kwargs:
            raise ValueError(
                f"Only pass one of either node ({node}) or other keyword arguments ({node_kwargs})"
            )
        self.nodes.append(node)
        self.nodes_by_id[node.id] = node

        return node

    def add_edge(self, from_node, to_node):
        """Create an Edge between two nodes, append it, and return it."""
        edge = Edge(from_node, to_node)
        self.edges.append(edge)

        return edge

    @classmethod
    def from_keras(cls, model):
        """Build a Graph from a Keras model's layers and their inbound connections."""
        # TODO: this method requires a refactor to work with keras 3.
        graph = cls()
        # Shamelessly copied (then modified) from keras/keras/utils/layer_utils.py
        sequential_like = cls._is_sequential(model)

        relevant_nodes = None
        if not sequential_like:
            relevant_nodes = []
            for v in model._nodes_by_depth.values():
                relevant_nodes += v

        for layer in model.layers:
            node = Node.from_keras(layer)
            for in_node in getattr(layer, "_inbound_nodes", ()):
                if relevant_nodes and in_node not in relevant_nodes:
                    # node is not part of the current network
                    continue
                for in_layer in _nest(in_node.inbound_layers):
                    inbound_keras_node = Node.from_keras(in_layer)

                    # Reuse an existing graph node with the same id, if any.
                    if inbound_keras_node.id not in graph.nodes_by_id:
                        graph.add_node(inbound_keras_node)
                    inbound_node = graph.nodes_by_id[inbound_keras_node.id]

                    graph.add_edge(inbound_node, node)
            graph.add_node(node)
        return graph

    @classmethod
    def _is_sequential(cls, model):
        """Return True if a Keras model can be treated as a plain layer stack."""
        sequential_like = True

        if (
            model.__class__.__name__ != "Sequential"
            and hasattr(model, "_is_graph_network")
            and model._is_graph_network
        ):
            nodes_by_depth = model._nodes_by_depth.values()
            nodes = []
            for v in nodes_by_depth:
                # TensorFlow2 doesn't insure inbound is always a list
                inbound = v[0].inbound_layers
                if not hasattr(inbound, "__len__"):
                    inbound = [inbound]
                if (len(v) > 1) or (len(v) == 1 and len(inbound) > 1):
                    # if the model has multiple nodes
                    # or if the nodes have multiple inbound_layers
                    # the model is no longer sequential
                    sequential_like = False
                    break
                nodes += v
            if sequential_like:
                # search for shared layers: a layer with more than one inbound
                # node among the model's nodes is shared, so not sequential
                for layer in model.layers:
                    inbound_nodes = getattr(layer, "_inbound_nodes", ())
                    if sum(1 for n in inbound_nodes if n in nodes) > 1:
                        sequential_like = False
                        break
        return sequential_like