wandb 0.18.0rc1__py3-none-win_amd64.whl → 0.18.1__py3-none-win_amd64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (62)
  1. wandb/__init__.py +2 -2
  2. wandb/__init__.pyi +1 -1
  3. wandb/apis/public/runs.py +2 -0
  4. wandb/bin/wandb-core +0 -0
  5. wandb/cli/cli.py +0 -2
  6. wandb/data_types.py +9 -2019
  7. wandb/env.py +0 -5
  8. wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
  9. wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
  10. wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
  11. wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
  12. wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
  13. wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
  14. wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
  15. wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
  16. wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
  17. wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
  18. wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
  19. wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
  20. wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
  21. wandb/{sklearn → integration/sklearn}/utils.py +8 -8
  22. wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
  23. wandb/proto/v3/wandb_base_pb2.py +2 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +2 -1
  25. wandb/proto/v3/wandb_server_pb2.py +2 -1
  26. wandb/proto/v3/wandb_settings_pb2.py +2 -1
  27. wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
  28. wandb/proto/v4/wandb_base_pb2.py +2 -1
  29. wandb/proto/v4/wandb_internal_pb2.py +2 -1
  30. wandb/proto/v4/wandb_server_pb2.py +2 -1
  31. wandb/proto/v4/wandb_settings_pb2.py +2 -1
  32. wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
  33. wandb/proto/v5/wandb_base_pb2.py +3 -2
  34. wandb/proto/v5/wandb_internal_pb2.py +3 -2
  35. wandb/proto/v5/wandb_server_pb2.py +3 -2
  36. wandb/proto/v5/wandb_settings_pb2.py +3 -2
  37. wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
  38. wandb/sdk/data_types/audio.py +165 -0
  39. wandb/sdk/data_types/bokeh.py +70 -0
  40. wandb/sdk/data_types/graph.py +405 -0
  41. wandb/sdk/data_types/image.py +156 -0
  42. wandb/sdk/data_types/table.py +1204 -0
  43. wandb/sdk/data_types/trace_tree.py +2 -2
  44. wandb/sdk/data_types/utils.py +49 -0
  45. wandb/sdk/service/service.py +2 -9
  46. wandb/sdk/service/streams.py +0 -7
  47. wandb/sdk/wandb_init.py +10 -3
  48. wandb/sdk/wandb_run.py +6 -152
  49. wandb/sdk/wandb_setup.py +1 -1
  50. wandb/sklearn.py +35 -0
  51. wandb/util.py +6 -2
  52. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/METADATA +5 -5
  53. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/RECORD +61 -57
  54. wandb/sdk/lib/console.py +0 -39
  55. wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
  56. wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
  57. wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
  58. wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
  59. wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
  60. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/WHEEL +0 -0
  61. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/entry_points.txt +0 -0
  62. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,165 @@
1
+ import hashlib
2
+ import os
3
+ from typing import Optional
4
+
5
+ from wandb import util
6
+ from wandb.sdk.lib import filesystem, runid
7
+
8
+ from . import _dtypes
9
+ from ._private import MEDIA_TMP
10
+ from .base_types.media import BatchableMedia
11
+
12
+
class Audio(BatchableMedia):
    """Wandb class for audio clips.

    Arguments:
        data_or_path: (string or numpy array) A path to an audio file
            or a numpy array of audio data.
        sample_rate: (int) Sample rate, required when passing in raw
            numpy array of audio data.
        caption: (string) Caption to display with audio.
    """

    _log_type = "audio-file"

    def __init__(self, data_or_path, sample_rate=None, caption=None):
        """Accept a path to an audio file or a numpy array of audio data.

        Raises:
            ValueError: if non-string (raw) data is given without `sample_rate`.
        """
        super().__init__()
        self._duration = None
        self._sample_rate = sample_rate
        self._caption = caption

        if isinstance(data_or_path, str):
            if self.path_is_reference(data_or_path):
                # A reference is not read locally, so its identity hash is
                # computed from the reference string itself.
                self._path = data_or_path
                self._sha256 = hashlib.sha256(data_or_path.encode("utf-8")).hexdigest()
                self._is_tmp = False
            else:
                self._set_file(data_or_path, is_tmp=False)
        else:
            if sample_rate is None:
                raise ValueError(
                    'Argument "sample_rate" is required when instantiating wandb.Audio with raw data.'
                )

            soundfile = util.get_module(
                "soundfile",
                required='Raw audio requires the soundfile package. To get it, run "pip install soundfile"',
            )

            # Encode the raw samples into a temporary .wav so the clip is
            # file-backed like any other media.
            tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".wav")
            soundfile.write(tmp_path, data_or_path, sample_rate)
            # NOTE(review): assumes the first axis indexes samples (frames) —
            # confirm for multi-channel arrays.
            self._duration = len(data_or_path) / float(sample_rate)

            self._set_file(tmp_path, is_tmp=True)

    @classmethod
    def get_media_subdir(cls):
        return os.path.join("media", "audio")

    @classmethod
    def from_json(cls, json_obj, source_artifact):
        """Rebuild an Audio by downloading the artifact entry at `json_obj["path"]`.

        Only the file and caption are passed on, so the rebuilt object has no
        sample rate.
        """
        return cls(
            source_artifact.get_entry(json_obj["path"]).download(),
            caption=json_obj["caption"],
        )

    def bind_to_run(
        self, run, key, step, id_=None, ignore_copy_err: Optional[bool] = None
    ):
        """Bind this clip to a run.

        Raises:
            ValueError: if the clip was created from an external reference.
        """
        if self.path_is_reference(self._path):
            raise ValueError(
                "Audio media created by a reference to external storage cannot currently be added to a run"
            )

        return super().bind_to_run(run, key, step, id_, ignore_copy_err)

    def to_json(self, run):
        """Extend the base media JSON with this type's `_type` and `caption`."""
        json_dict = super().to_json(run)
        json_dict.update(
            {
                "_type": self._log_type,
                "caption": self._caption,
            }
        )
        return json_dict

    @classmethod
    def seq_to_json(cls, seq, run, key, step):
        """Serialize a sequence of Audio objects logged together into one metadata dict."""
        audio_list = list(seq)

        util.get_module(
            "soundfile",
            required="wandb.Audio requires the soundfile package. To get it, run: pip install soundfile",
        )
        base_path = os.path.join(run.dir, "media", "audio")
        filesystem.mkdir_exists_ok(base_path)
        meta = {
            "_type": "audio",
            "count": len(audio_list),
            "audio": [a.to_json(run) for a in audio_list],
        }
        sample_rates = cls.sample_rates(audio_list)
        if sample_rates:
            meta["sampleRates"] = sample_rates
        durations = cls.durations(audio_list)
        if durations:
            meta["durations"] = durations
        captions = cls.captions(audio_list)
        if captions:
            meta["captions"] = captions

        return meta

    @classmethod
    def durations(cls, audio_list):
        """Return each clip's duration (None for file-backed clips)."""
        return [a._duration for a in audio_list]

    @classmethod
    def sample_rates(cls, audio_list):
        """Return each clip's sample rate as given at construction (may be None)."""
        return [a._sample_rate for a in audio_list]

    @classmethod
    def captions(cls, audio_list):
        """Return captions with None replaced by "", or False if no clip has one."""
        captions = [a._caption for a in audio_list]
        if all(c is None for c in captions):
            return False
        else:
            return ["" if c is None else c for c in captions]

    def resolve_ref(self):
        """Return the external reference backing this clip, or None if unknown.

        Objects that neither were created from a reference nor came from an
        artifact have nothing to resolve against and yield None.
        """
        if self.path_is_reference(self._path):
            # this object was already created using a ref:
            return self._path

        artifact_source = getattr(self, "_artifact_source", None)
        if artifact_source is None:
            # Not loaded from an artifact: there is no manifest to look in.
            return None
        source_artifact = artifact_source.artifact

        resolved_name = source_artifact._local_path_to_name(self._path)
        if resolved_name is not None:
            target_entry = source_artifact.manifest.get_entry_by_path(resolved_name)
            if target_entry is not None:
                return target_entry.ref

        return None

    def __eq__(self, other):
        # Non-Audio objects have no `_path`/`_caption`; they are never equal.
        if not isinstance(other, Audio):
            return False
        if self.path_is_reference(self._path) or self.path_is_reference(other._path):
            # one or more of these objects is an unresolved reference -- we'll compare
            # their reference paths instead of their SHAs:
            return (
                self.resolve_ref() == other.resolve_ref()
                and self._caption == other._caption
            )

        return super().__eq__(other) and self._caption == other._caption

    def __ne__(self, other):
        return not self.__eq__(other)
158
+
159
+
class _AudioFileType(_dtypes.Type):
    """Registry type that ties the name "audio-file" to the `Audio` class."""

    types = [Audio]
    name = "audio-file"


# Make the type known to the `_dtypes` registry at import time.
_dtypes.TypeRegistry.add(_AudioFileType)
@@ -0,0 +1,70 @@
1
+ import codecs
2
+ import json
3
+ import os
4
+
5
+ from wandb import util
6
+ from wandb.sdk.lib import runid
7
+
8
+ from . import _dtypes
9
+ from ._private import MEDIA_TMP
10
+ from .base_types.media import Media
11
+
12
+
class Bokeh(Media):
    """Wandb class for Bokeh plots.

    Arguments:
        data_or_path: A Bokeh model, a Bokeh document, or a path to an
            existing Bokeh JSON file.
    """

    _log_type = "bokeh-file"

    def __init__(self, data_or_path):
        """Wrap a Bokeh model/document, or load a Bokeh JSON file from disk.

        Raises:
            TypeError: if `data_or_path` is not an existing file path, a Bokeh
                model, or a Bokeh document.
        """
        super().__init__()
        bokeh = util.get_module("bokeh", required=True)
        if isinstance(data_or_path, str) and os.path.exists(data_or_path):
            # Read as UTF-8 to match how this class writes its JSON files; the
            # platform default encoding (e.g. on Windows) may differ.
            with open(data_or_path, encoding="utf-8") as fh:
                b_json = json.load(fh)
            self.b_obj = bokeh.document.Document.from_json(b_json)
            self._set_file(data_or_path, is_tmp=False, extension=".bokeh.json")
            return

        if isinstance(data_or_path, bokeh.model.Model):
            document = bokeh.document.Document()
            document.add_root(data_or_path)
        elif isinstance(data_or_path, bokeh.document.Document):
            # Serialize documents exactly like models so the media object
            # always ends up with a backing file.
            document = data_or_path
        else:
            raise TypeError(
                "Bokeh constructor accepts Bokeh document/model or path to Bokeh json file"
            )

        # serialize/deserialize pairing followed by sorting attributes ensures
        # that the file's sha's are equivalent in subsequent calls
        self.b_obj = bokeh.document.Document.from_json(document.to_json())
        b_json = self.b_obj.to_json()
        if "references" in b_json["roots"]:
            b_json["roots"]["references"].sort(key=lambda x: x["id"])

        tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".bokeh.json")
        with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
            util.json_dump_safer(b_json, fp)
        self._set_file(tmp_path, is_tmp=True, extension=".bokeh.json")

    @classmethod
    def get_media_subdir(cls):
        return os.path.join("media", "bokeh")

    def to_json(self, run):
        """Extend the base media JSON with this type's `_type`."""
        # TODO: (tss) this is getting redundant for all the media objects. We can probably
        # pull this into Media#to_json and remove this type override for all the media types.
        # There are only a few cases where the type is different between artifacts and runs.
        json_dict = super().to_json(run)
        json_dict["_type"] = self._log_type
        return json_dict

    @classmethod
    def from_json(cls, json_obj, source_artifact):
        """Rebuild a Bokeh object by downloading the artifact entry at `json_obj["path"]`."""
        return cls(source_artifact.get_entry(json_obj["path"]).download())
63
+
64
+
class _BokehFileType(_dtypes.Type):
    """Registry type that ties the name "bokeh-file" to the `Bokeh` class."""

    types = [Bokeh]
    name = "bokeh-file"


# Make the type known to the `_dtypes` registry at import time.
_dtypes.TypeRegistry.add(_BokehFileType)
@@ -0,0 +1,405 @@
1
+ import codecs
2
+ import os
3
+ import pprint
4
+
5
+ from wandb import util
6
+ from wandb.sdk.data_types._private import MEDIA_TMP
7
+ from wandb.sdk.data_types.base_types.media import Media, _numpy_arrays_to_lists
8
+ from wandb.sdk.data_types.base_types.wb_value import WBValue
9
+ from wandb.sdk.lib import runid
10
+
11
+
def _nest(thing):
    """Flatten `thing` with TensorFlow's `nest.flatten` when TensorFlow is available.

    If `util.get_module("tensorflow.python.util")` gives back nothing usable,
    `thing` is simply wrapped in a one-element list.
    """
    tf_util_module = util.get_module("tensorflow.python.util")
    if not tf_util_module:
        return [thing]
    return tf_util_module.nest.flatten(thing)
20
+
21
+
class Edge(WBValue):
    """A directed connection from one `Node` to another inside a `Graph`.

    All state lives in the `_attributes` dict; `name`, `from_node` and
    `to_node` are property views over it.
    """

    def __init__(self, from_node, to_node):
        self._attributes = {}
        self.from_node = from_node
        self.to_node = to_node

    def __repr__(self):
        # Show every attribute except the endpoint Node objects themselves,
        # which are replaced by their ids.
        shown = dict(self._attributes)
        shown.pop("from_node")
        shown.pop("to_node")
        shown.update(from_id=self.from_node.id, to_id=self.to_node.id)
        return str(shown)

    def to_json(self, run=None):
        """Serialize as `[source node id, destination node id]`."""
        return [endpoint.id for endpoint in (self.from_node, self.to_node)]

    @property
    def name(self):
        """Optional, not necessarily unique."""
        return self._attributes.get("name")

    @name.setter
    def name(self, val):
        self._attributes["name"] = val
        return val

    @property
    def from_node(self):
        """Node the edge starts at."""
        return self._attributes.get("from_node")

    @from_node.setter
    def from_node(self, val):
        self._attributes["from_node"] = val
        return val

    @property
    def to_node(self):
        """Node the edge ends at."""
        return self._attributes.get("to_node")

    @to_node.setter
    def to_node(self, val):
        self._attributes["to_node"] = val
        return val
68
+
69
+
class Node(WBValue):
    """Node used in `Graph`.

    Node properties are stored in the `_attributes` dict, which is also what
    `to_json` returns.
    """

    def __init__(
        self,
        id=None,
        name=None,
        class_name=None,
        size=None,
        parameters=None,
        output_shape=None,
        is_output=None,
        num_parameters=None,
        node=None,
    ):
        """Create a node, optionally copying attributes from an existing `node`.

        When `node` is given, its attributes (except its id) and its `obj` are
        copied. Any keyword argument that is not None then overrides them.
        """
        self._attributes = {"name": None}
        self.in_edges = {}  # indexed by source node id
        self.out_edges = {}  # indexed by dest node id
        # optional object (e.g. PyTorch Parameter or Module) that this Node represents
        self.obj = None

        if node is not None:
            self._attributes.update(node._attributes)
            # The copy must not inherit the source's id. The source may not
            # have an id at all, so tolerate a missing key.
            self._attributes.pop("id", None)
            self.obj = node.obj

        if id is not None:
            self.id = id
        if name is not None:
            self.name = name
        if class_name is not None:
            self.class_name = class_name
        if size is not None:
            self.size = size
        if parameters is not None:
            self.parameters = parameters
        if output_shape is not None:
            self.output_shape = output_shape
        if is_output is not None:
            self.is_output = is_output
        if num_parameters is not None:
            self.num_parameters = num_parameters

    def to_json(self, run=None):
        # Returns the live attribute dict (not a copy).
        return self._attributes

    def __repr__(self):
        return repr(self._attributes)

    @property
    def id(self):
        """Must be unique in the graph."""
        return self._attributes.get("id")

    @id.setter
    def id(self, val):
        self._attributes["id"] = val
        return val

    @property
    def name(self):
        """Display name of the node (e.g. the layer name in `from_keras`)."""
        return self._attributes.get("name")

    @name.setter
    def name(self, val):
        self._attributes["name"] = val
        return val

    @property
    def class_name(self):
        """Usually the type of layer or sublayer (class name in `from_keras`)."""
        return self._attributes.get("class_name")

    @class_name.setter
    def class_name(self, val):
        self._attributes["class_name"] = val
        return val

    @property
    def functions(self):
        """Functions list; a fresh empty list is returned when unset.

        NOTE: that fresh list is not stored, so assign rather than mutate it.
        """
        return self._attributes.get("functions", [])

    @functions.setter
    def functions(self, val):
        self._attributes["functions"] = val
        return val

    @property
    def parameters(self):
        """Parameters list; a fresh empty list is returned when unset.

        NOTE: that fresh list is not stored, so assign rather than mutate it.
        """
        return self._attributes.get("parameters", [])

    @parameters.setter
    def parameters(self, val):
        self._attributes["parameters"] = val
        return val

    @property
    def size(self):
        """Tensor size, stored as a tuple by the setter."""
        return self._attributes.get("size")

    @size.setter
    def size(self, val):
        self._attributes["size"] = tuple(val)
        return val

    @property
    def output_shape(self):
        """Output shape (from `from_keras`: the layer's `output_shape` or ["multiple"])."""
        return self._attributes.get("output_shape")

    @output_shape.setter
    def output_shape(self, val):
        self._attributes["output_shape"] = val
        return val

    @property
    def is_output(self):
        """Flag marking the node as an output (stored as given)."""
        return self._attributes.get("is_output")

    @is_output.setter
    def is_output(self, val):
        self._attributes["is_output"] = val
        return val

    @property
    def num_parameters(self):
        """Parameter count (from `from_keras`: `layer.count_params()`)."""
        return self._attributes.get("num_parameters")

    @num_parameters.setter
    def num_parameters(self, val):
        self._attributes["num_parameters"] = val
        return val

    @property
    def child_parameters(self):
        """Child-parameter information (stored as given)."""
        return self._attributes.get("child_parameters")

    @child_parameters.setter
    def child_parameters(self, val):
        self._attributes["child_parameters"] = val
        return val

    @property
    def is_constant(self):
        """Constant flag (stored as given)."""
        return self._attributes.get("is_constant")

    @is_constant.setter
    def is_constant(self, val):
        self._attributes["is_constant"] = val
        return val

    @classmethod
    def from_keras(cls, layer):
        """Build a Node describing a Keras layer.

        Layers without `output_shape` get `["multiple"]`.
        """
        node = cls()

        try:
            output_shape = layer.output_shape
        except AttributeError:
            output_shape = ["multiple"]

        node.id = layer.name
        node.name = layer.name
        node.class_name = layer.__class__.__name__
        node.output_shape = output_shape
        node.num_parameters = layer.count_params()

        return node
243
+
244
+
class Graph(Media):
    """Wandb class for graphs.

    This class is typically used for saving and displaying neural net models. It
    represents the graph as an array of nodes and edges. The nodes can have
    labels that can be visualized by wandb.

    Examples:
        Import a keras model:
        ```
        Graph.from_keras(keras_model)
        ```

    Attributes:
        format (string): Format to help wandb display the graph nicely.
        nodes ([wandb.Node]): List of wandb.Nodes
        nodes_by_id (dict): dict of ids -> nodes
        edges ([wandb.Edge]): List of `Edge` objects, each joining a pair of nodes
        loaded (boolean): Flag to tell whether the graph is completely loaded
        root (wandb.Node): root node of the graph
    """

    _log_type = "graph-file"

    def __init__(self, format="keras"):
        super().__init__()
        # LB: TODO: I think we should factor criterion and criterion_passed out
        self.format = format
        self.nodes = []
        self.nodes_by_id = {}
        self.edges = []
        self.loaded = False
        self.criterion = None
        self.criterion_passed = False
        self.root = None  # optional root Node if applicable

    def _to_graph_json(self, run=None):
        """Return the graph as a plain dict with format, nodes, and edges."""
        # Needs to be its own function for tests
        return {
            "format": self.format,
            "nodes": [node.to_json() for node in self.nodes],
            "edges": [edge.to_json() for edge in self.edges],
        }

    def bind_to_run(self, *args, **kwargs):
        """Serialize the graph to a temporary JSON file and bind it to a run.

        An already-bound graph is left untouched. Rewriting the file at that
        point would only re-point this media at an unbound temp file, because
        the base-class bind is skipped.
        """
        if self.is_bound():
            return
        data = _numpy_arrays_to_lists(self._to_graph_json())
        tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".graph.json")
        with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
            util.json_dump_safer(data, fp)
        self._set_file(tmp_path, is_tmp=True, extension=".graph.json")
        super().bind_to_run(*args, **kwargs)

    @classmethod
    def get_media_subdir(cls):
        return os.path.join("media", "graph")

    def to_json(self, run):
        """Extend the base media JSON with this type's `_type`."""
        json_dict = super().to_json(run)
        json_dict["_type"] = self._log_type
        return json_dict

    def __getitem__(self, nid):
        """Look up a node by id (raises KeyError if absent)."""
        return self.nodes_by_id[nid]

    def pprint(self):
        """Pretty-print the attributes of every edge, then of every node."""
        for edge in self.edges:
            pprint.pprint(edge._attributes)
        for node in self.nodes:
            pprint.pprint(node._attributes)

    def add_node(self, node=None, **node_kwargs):
        """Add `node`, or a new `Node(**node_kwargs)`, to the graph and return it.

        A node whose id is already present replaces it in `nodes_by_id`.

        Raises:
            ValueError: if both `node` and keyword arguments are given.
        """
        if node is None:
            node = Node(**node_kwargs)
        elif node_kwargs:
            raise ValueError(
                f"Only pass one of either node ({node}) or other keyword arguments ({node_kwargs})"
            )
        self.nodes.append(node)
        self.nodes_by_id[node.id] = node

        return node

    def add_edge(self, from_node, to_node):
        """Create an `Edge` from `from_node` to `to_node`, record it, and return it."""
        edge = Edge(from_node, to_node)
        self.edges.append(edge)

        return edge

    @classmethod
    def from_keras(cls, model):
        """Build a Graph from a Keras model's layers and their inbound connections."""
        # TODO: this method requires a refactor to work with Keras 3.
        graph = cls()
        # Shamelessly copied (then modified) from keras/keras/utils/layer_utils.py
        sequential_like = cls._is_sequential(model)

        # For non-sequential models, only inbound nodes that appear in the
        # model's `_nodes_by_depth` produce edges.
        relevant_nodes = None
        if not sequential_like:
            relevant_nodes = []
            for depth_nodes in model._nodes_by_depth.values():
                relevant_nodes += depth_nodes

        for layer in model.layers:
            node = Node.from_keras(layer)
            for in_node in getattr(layer, "_inbound_nodes", ()):
                if relevant_nodes and in_node not in relevant_nodes:
                    # node is not part of the current network
                    continue
                for in_layer in _nest(in_node.inbound_layers):
                    inbound_keras_node = Node.from_keras(in_layer)

                    # Reuse the graph's existing node for this layer id, if any.
                    if inbound_keras_node.id not in graph.nodes_by_id:
                        graph.add_node(inbound_keras_node)
                    inbound_node = graph.nodes_by_id[inbound_keras_node.id]

                    graph.add_edge(inbound_node, node)
            graph.add_node(node)
        return graph

    @classmethod
    def _is_sequential(cls, model):
        """Return whether `model` can be treated as a simple stack of layers.

        A model is always treated as sequential when its class is named
        `Sequential` or when it is not a graph network. Otherwise it counts as
        sequential only if both of these hold:

        - every depth has a single node with at most one inbound layer;
        - no layer has more than one inbound node among those depth nodes.
        """
        sequential_like = True

        if (
            model.__class__.__name__ != "Sequential"
            and hasattr(model, "_is_graph_network")
            and model._is_graph_network
        ):
            nodes_by_depth = model._nodes_by_depth.values()
            nodes = []
            for v in nodes_by_depth:
                # TensorFlow2 doesn't ensure inbound is always a list
                inbound = v[0].inbound_layers
                if not hasattr(inbound, "__len__"):
                    inbound = [inbound]
                if (len(v) > 1) or (len(v) == 1 and len(inbound) > 1):
                    # if the model has multiple nodes
                    # or if the nodes have multiple inbound_layers
                    # the model is no longer sequential
                    sequential_like = False
                    break
                nodes += v
            if sequential_like:
                # search for shared layers
                for layer in model.layers:
                    flag = False
                    if hasattr(layer, "_inbound_nodes"):
                        for node in layer._inbound_nodes:
                            if node in nodes:
                                if flag:
                                    sequential_like = False
                                    break
                                else:
                                    flag = True
                    if not sequential_like:
                        break
        return sequential_like