leap-model-parser 0.1.184__tar.gz → 0.1.185.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/PKG-INFO +2 -1
- leap_model_parser-0.1.185.dev2/leap_model_parser/leap_graph_editor.py +365 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/model_parser.py +12 -2
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/pyproject.toml +2 -1
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/LICENSE +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/README.md +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/__init__.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/contract/__init__.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/contract/graph.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/contract/importmodelresponse.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/contract/nodedata.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/contract/ui_components.json +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/keras_json_model_import.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/utils/__init__.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/utils/layerpedia/__init__.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/utils/layerpedia/layerpedia.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/utils/tlinspection/__init__.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/utils/tlinspection/leapinspection.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/utils/uicomponents/__init__.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/utils/uicomponents/generatenodedata.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/utils/uicomponents/tensorflowinscpection.py +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/utils/uicomponents/ui_components.json +0 -0
- {leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/utils/uicomponents/ui_components_config.yaml +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: leap-model-parser
|
3
|
-
Version: 0.1.184
|
3
|
+
Version: 0.1.185.dev2
|
4
4
|
Summary:
|
5
5
|
Home-page: https://github.com/tensorleap/leap-model-parser
|
6
6
|
License: MIT
|
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Python :: 3
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.8
|
13
13
|
Classifier: Programming Language :: Python :: 3.9
|
14
14
|
Classifier: Programming Language :: Python :: 3.10
|
15
|
+
Requires-Dist: code-loader (==1.0.87.dev2)
|
15
16
|
Requires-Dist: keras-data-format-converter (==0.1.22)
|
16
17
|
Requires-Dist: leap-model-rebuilder (==0.1.7)
|
17
18
|
Requires-Dist: numpy (>=1.22.3,<2.0.0)
|
@@ -0,0 +1,365 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
from typing import Optional, Dict, Any, List
|
3
|
+
|
4
|
+
from code_loader.contract.mapping import NodeConnection, NodeMappingType, NodeMapping
|
5
|
+
from keras import Model
|
6
|
+
|
7
|
+
from leap_model_parser.contract.graph import Node as Node, OutputData, ConnectionOutput, ConnectionInput, InputData
|
8
|
+
|
9
|
+
|
10
|
+
# class NodeMappingType(Enum):
|
11
|
+
# Visualizer = 'Visualizer'
|
12
|
+
# Metric = 'Metric'
|
13
|
+
# GroundTruth = 'GroundTruth'
|
14
|
+
# Input = 'Input'
|
15
|
+
# Layer = 'Layer'
|
16
|
+
# Loss = 'Loss'
|
17
|
+
# CustomLoss = 'CustomLoss'
|
18
|
+
# Optimizer = 'Optimizer'
|
19
|
+
# Prediction0 = 'Prediction0'
|
20
|
+
# Prediction1 = 'Prediction1'
|
21
|
+
# Prediction2 = 'Prediction2'
|
22
|
+
# Prediction3 = 'Prediction3'
|
23
|
+
# Input0 = 'Input0'
|
24
|
+
# Input1 = 'Input1'
|
25
|
+
# Input2 = 'Input2'
|
26
|
+
# Input3 = 'Input3'
|
27
|
+
# Input4 = 'Input4'
|
28
|
+
# Input5 = 'Input5'
|
29
|
+
#
|
30
|
+
#
|
31
|
+
#
|
32
|
+
#
|
33
|
+
#
|
34
|
+
# @dataclass
|
35
|
+
# class NodeMapping:
|
36
|
+
# name: str
|
37
|
+
# type: NodeMappingType
|
38
|
+
# user_unique_name: Optional[str] = None
|
39
|
+
# sub_type: Optional[str] = None
|
40
|
+
# arg_names: Optional[List[str]] = None
|
41
|
+
#
|
42
|
+
#
|
43
|
+
# @dataclass
|
44
|
+
# class NodeConnection:
|
45
|
+
# node: NodeMapping
|
46
|
+
# node_inputs: Optional[Dict[str, NodeMapping]]
|
47
|
+
# prediction_type_name: Optional[str] = None
|
48
|
+
|
49
|
+
|
50
|
+
class LeapGraphEditor:
    """Adds visualizer, metric and loss nodes to an already parsed model graph.

    The editor works in place on ``model_graph`` (a mapping of node id to
    ``Node``). New nodes receive consecutive integer ids, stored as strings,
    that continue after the largest id already present in the graph.
    """

    def __init__(self, model_graph: Dict[str, Node], keras_model: Model):
        """Store the graph and the Keras model and prepare the id counter.

        Args:
            model_graph: Graph to edit; its keys must be parseable by ``int``.
            keras_model: Model whose ``outputs`` are used to resolve
                ``Prediction<N>`` mappings to layer names.
        """
        self.model_graph = model_graph
        self.keras_model = keras_model

        # ``default=-1`` keeps an empty graph from raising ValueError; the
        # first generated id is then '0'.
        self._next_node_id_index = max((int(node_id) for node_id in model_graph), default=-1) + 1

    def add_connections_to_graph(self, connections: List[NodeConnection]):
        """Validate ``connections`` and add each of them to the graph in order."""
        for connection in self._validate_and_reorder_connections_list(connections):
            self._add_node_connection_to_graph(connection)

    def _add_node_connection_to_graph(self, node_connection: NodeConnection):
        """Create the node described by ``node_connection`` and wire its inputs.

        Raises:
            Exception: If the node type is not Visualizer, Metric, Loss or
                CustomLoss.
        """
        node = node_connection.node
        if node.type == NodeMappingType.Visualizer:
            new_node_id = self._add_visualizer_node(
                node.name, node.sub_type, node.user_unique_name, node.arg_names)
        elif node.type == NodeMappingType.Metric:
            new_node_id = self._add_metric_node(node.name, node.user_unique_name, node.arg_names)
        elif node.type in (NodeMappingType.Loss, NodeMappingType.CustomLoss):
            new_node_id = self._add_loss_node(node.name, node.type == NodeMappingType.CustomLoss)
        else:
            # TODO: Optimizer connections are not supported yet.
            raise Exception(f"Can't add node of type {node.type.name}")

        self._connect_inputs(new_node_id, node_connection.node_inputs)

    def _connect_inputs(self, node_id: str, node_inputs: Optional[Dict[str, NodeMapping]]):
        """Connect every mapped input in ``node_inputs`` to the node ``node_id``.

        ``node_inputs`` may be ``None`` (the validation step allows it); it is
        then treated as "no inputs".
        """
        for input_name, input_node in (node_inputs or {}).items():
            input_node_id = self._find_or_add_input_node(input_node)
            self._add_connection_to_node(node_id, input_name, input_node_id)

    def model_graph_dict(self) -> Dict[str, Any]:
        """Return a dict mapping each node id to that node's ``__dict__``.

        The conversion is shallow: values inside each node dict are not
        converted further.
        """
        return {node_id: node.__dict__ for node_id, node in self.model_graph.items()}

    def _find_node_by_origin_name(self, origin_name: str) -> Optional[Node]:
        """Return the first node whose ``data['origin_name']`` matches, else ``None``."""
        for node in self.model_graph.values():
            if node.data.get('origin_name') == origin_name:
                return node
        return None

    def _find_input_node_by_origin_name(self, origin_name: str) -> Optional[Node]:
        """Return the first node whose ``data['output_name']`` matches, else ``None``."""
        for node in self.model_graph.values():
            if node.data.get('output_name') == origin_name:
                return node
        return None

    def _validate_and_reorder_connections_list(self, connections: List[NodeConnection]) -> List[NodeConnection]:
        """Resolve ``Prediction<N>`` inputs to Keras output layer names.

        For every input mapping whose type value contains ``'Prediction'``, the
        numeric suffix is used as an index into ``self.keras_model.outputs`` and
        the mapping's ``name`` is overwritten, in place, with that output's
        layer name. Connections without inputs are left as they are.

        Returns:
            The same ``connections`` list, in its original order.
        """
        for connection in connections:
            if connection.node_inputs is None:
                continue
            for input_node in connection.node_inputs.values():
                type_value = input_node.type.value
                if 'Prediction' in type_value:
                    prediction_index = int(type_value.replace('Prediction', ''))
                    input_node.name = self.keras_model.outputs[prediction_index].node.layer.name

        # TODO: no reordering or validation (e.g. "at least one loss") is done yet.
        return connections

    def _find_encoder_node_id(self, encoder_name: str) -> Optional[str]:
        """Return the id of the Input/GroundTruth node exposing ``encoder_name``.

        A node matches when its ``data['type']`` is ``'Input'`` or
        ``'GroundTruth'`` and it has an output keyed ``'<node_id>-<encoder_name>'``.
        Returns ``None`` when no node matches.
        """
        for node_id, node in self.model_graph.items():
            if node.data.get('type') in ('Input', 'GroundTruth') \
                    and f'{node_id}-{encoder_name}' in node.outputs:
                return node_id
        return None

    def _find_layer_node_id(self, layer_name: str) -> str:
        """Return the id of the ``'Layer'`` node whose origin name is ``layer_name``.

        Raises:
            Exception: If no such layer node exists.
        """
        for node_id, node in self.model_graph.items():
            if node.data.get('type') == 'Layer' and node.data['origin_name'] == layer_name:
                return node_id
        raise Exception(f"Couldn't find node for layer {layer_name}")

    def _generate_new_node_id(self) -> str:
        """Return the next unused node id as a string and advance the counter."""
        new_node_id = str(self._next_node_id_index)
        self._next_node_id_index += 1
        return new_node_id

    def _add_ground_truth_node(self, ground_truth_name: str) -> str:
        """Add a GroundTruth node with one empty output and return its id."""
        new_node_id = self._generate_new_node_id()
        self.model_graph[new_node_id] = Node(
            new_node_id,
            'GroundTruth',
            position=[0, 0],
            data={'name': ground_truth_name, 'output_name': ground_truth_name,
                  'type': 'GroundTruth', "selected": ground_truth_name},
            inputs={},
            outputs={
                f'{new_node_id}-{ground_truth_name}': ConnectionOutput([])
            }
        )
        return new_node_id

    def _add_visualizer_node(self, visualizer_name: str, visualizer_type: Optional[str],
                             user_unique_name: Optional[str], arg_names: Optional[List[str]]) -> str:
        """Add a Visualizer node without inputs or outputs and return its id."""
        new_node_id = self._generate_new_node_id()
        self.model_graph[new_node_id] = Node(
            new_node_id,
            'Visualizer',
            position=[0, 0],
            data={'visualizer_name': visualizer_name, 'type': 'Visualizer',
                  'selected': visualizer_name, 'name': visualizer_name, 'visualizer_type': visualizer_type,
                  'arg_names': arg_names, "user_unique_name": user_unique_name},
            inputs={},
            outputs={})
        return new_node_id

    def _add_metric_node(self, metric_name: str,
                         user_unique_name: Optional[str], arg_names: Optional[List[str]]) -> str:
        """Add a Metric node without inputs or outputs and return its id."""
        new_node_id = self._generate_new_node_id()
        self.model_graph[new_node_id] = Node(
            new_node_id,
            'Metric',
            position=[0, 0],
            data={'metric_name': metric_name, 'type': 'Metric', 'name': metric_name,
                  'arg_names': arg_names, "user_unique_name": user_unique_name},
            inputs={},
            outputs={})
        return new_node_id

    def _add_loss_node(self, loss_name: str, is_custom_loss: bool) -> str:
        """Add a Loss/CustomLoss node with a single ``'<id>-loss'`` output.

        For a custom loss the node's second positional argument is the literal
        ``'CustomLoss'`` rather than ``loss_name``; ``data['name']`` and
        ``data['selected']`` always hold ``loss_name``.
        """
        new_node_id = self._generate_new_node_id()

        loss_type = 'CustomLoss' if is_custom_loss else 'Loss'
        loss_node_name = 'CustomLoss' if is_custom_loss else loss_name

        self.model_graph[new_node_id] = Node(
            new_node_id,
            loss_node_name,
            position=[0, 0],
            data={'type': loss_type, 'selected': loss_name, 'name': loss_name},
            inputs={},
            outputs={
                f'{new_node_id}-loss': ConnectionOutput([])
            }
        )
        return new_node_id

    def _get_output_name_from_node_id(self, input_node_id: str, input_name: Optional[str] = None) -> str:
        """Pick the output key of node ``input_node_id`` to connect from.

        * No outputs: a ``'<id>-feature_map'`` output is created and returned.
        * Exactly one output: that output is returned.
        * Several outputs: ``'<id>-<input_name>'`` is returned if it exists.

        Raises:
            Exception: If the node has several outputs and none can be chosen.
        """
        outputs = self.model_graph[input_node_id].outputs
        if len(outputs) == 0:
            output_name_to_add = f'{input_node_id}-feature_map'
            outputs[output_name_to_add] = ConnectionOutput([])
            return output_name_to_add
        if len(outputs) == 1:
            return next(iter(outputs))
        if input_name is not None:
            guessed_output_name = f'{input_node_id}-{input_name}'
            if guessed_output_name in outputs:
                return guessed_output_name

        # todo: layers with multiple outputs
        raise Exception("Can't decide on output name")

    def _add_connection_to_node(self, node_id: str, input_name: str, input_node_id: str):
        """Connect an output of ``input_node_id`` to input ``input_name`` of ``node_id``.

        Both ends are updated: the target node gets an input entry keyed
        ``'<node_id>-<input_name>'`` (replacing any existing entry with that
        key), and the source output's connection list gets a matching record.
        """
        # todo: layers with multiple outputs
        output_name = self._get_output_name_from_node_id(input_node_id, input_name)
        input_key = f'{node_id}-{input_name}'
        self.model_graph[node_id].inputs[input_key] = ConnectionInput([InputData(input_node_id, output_name)])
        self.model_graph[input_node_id].outputs[output_name].connections.append(OutputData(node_id, input_key))

    def _find_or_add_input_node(self, input_node: NodeMapping) -> str:
        """Return the id of the graph node that ``input_node`` refers to.

        * Input / GroundTruth mappings: looked up by encoder name; when nothing
          is found a new GroundTruth node is added.
          NOTE(review): this also happens for a missing ``Input`` mapping -
          confirm that creating a GroundTruth node is intended in that case.
        * ``Prediction<N>`` mappings: looked up by ``origin_name``.
        * Anything else: looked up as a layer by ``origin_name``.

        Raises:
            Exception: If a prediction or layer mapping matches no node.
        """
        if input_node.type in (NodeMappingType.Input, NodeMappingType.GroundTruth):
            input_node_id = self._find_encoder_node_id(input_node.name)
            if input_node_id is None:
                input_node_id = self._add_ground_truth_node(input_node.name)
            return input_node_id

        if input_node.type.value.startswith('Prediction'):
            prediction_node = self._find_node_by_origin_name(input_node.name)
            if prediction_node is None:
                # Fail with a clear message instead of an AttributeError on None.
                raise Exception(f"Couldn't find node for prediction layer {input_node.name}")
            return prediction_node.id

        return self._find_layer_node_id(input_node.name)

    def _find_prediction_node(self, prediction_index):
        """Placeholder; not implemented and currently returns ``None``."""
        pass

    def _get_all_loss_node_ids(self):
        """Return the ids of all nodes whose ``data['type']`` is Loss or CustomLoss."""
        return [
            node_id for node_id, node in self.model_graph.items()
            if node.data.get('type') in ('CustomLoss', 'Loss')
        ]

    @staticmethod
    def _convert_dataclass_to_json_dict(_dataclass):
        """Recursively convert an object into plain dicts, lists and scalars.

        Enum members become their ``name``; objects with a ``__dict__`` become
        a dict of their converted attributes; lists are converted element by
        element; every other value is returned unchanged.
        """
        if isinstance(_dataclass, Enum):
            return _dataclass.name
        if hasattr(_dataclass, '__dict__'):
            return {
                key: LeapGraphEditor._convert_dataclass_to_json_dict(value)
                for key, value in _dataclass.__dict__.items()
            }
        if isinstance(_dataclass, list):
            return [
                LeapGraphEditor._convert_dataclass_to_json_dict(element)
                for element in _dataclass
            ]
        return _dataclass
|
349
|
+
|
350
|
+
# @staticmethod
|
351
|
+
# def _add_setup_to_metadata(dataset_version_metadata: Dict[str, Any],
|
352
|
+
# dataset_parse_result: DatasetIntegParseResult):
|
353
|
+
# setup_json = LeapGraphEditor._convert_dataclass_to_json_dict(dataset_parse_result.setup)
|
354
|
+
#
|
355
|
+
# dataset_version_metadata['setup'] = setup_json
|
356
|
+
|
357
|
+
# def _add_arg_names_to_visualizers(self, dataset_parse_result: DatasetIntegParseResult):
|
358
|
+
# visualizer_instance_by_name: Dict[str, VisualizerInstance] = {
|
359
|
+
# visualizer_instance.name: visualizer_instance
|
360
|
+
# for visualizer_instance in dataset_parse_result.setup.visualizers
|
361
|
+
# }
|
362
|
+
#
|
363
|
+
# for _, node_response in self.model_graph.items():
|
364
|
+
# if node_response.data['type'] == 'Visualizer':
|
365
|
+
# node_response.data['arg_names'] = visualizer_instance_by_name[node_response.data['selected']].arg_names
|
{leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/model_parser.py
RENAMED
@@ -8,6 +8,7 @@ from pathlib import Path
|
|
8
8
|
from typing import Callable, Optional, List, Dict, Tuple, Type
|
9
9
|
|
10
10
|
import tensorflow as tf # type: ignore
|
11
|
+
from code_loader.contract.mapping import NodeConnection
|
11
12
|
from keras import Model # type: ignore
|
12
13
|
from keras_data_format_converter import convert_channels_first_to_last # type: ignore
|
13
14
|
from leap_model_rebuilder import rebuild_model # type: ignore
|
@@ -20,6 +21,7 @@ from tensorflow.keras.models import load_model # type: ignore
|
|
20
21
|
from leap_model_parser.contract.graph import Node, InputInfo
|
21
22
|
from leap_model_parser.contract.importmodelresponse import ImportModelTypeEnum
|
22
23
|
from leap_model_parser.keras_json_model_import import KerasJsonModelImport
|
24
|
+
from leap_model_parser.leap_graph_editor import LeapGraphEditor
|
23
25
|
|
24
26
|
onnx_imported = False
|
25
27
|
package_name = 'onnx'
|
@@ -32,13 +34,16 @@ if spec is not None:
|
|
32
34
|
|
33
35
|
class ModelParser:
|
34
36
|
def __init__(self, should_transform_inputs_and_outputs=False,
|
35
|
-
custom_layers: Optional[Dict[str, Type[tf.keras.layers.Layer]]] = None
|
37
|
+
custom_layers: Optional[Dict[str, Type[tf.keras.layers.Layer]]] = None,
|
38
|
+
mapping_connections: Optional[List[NodeConnection]] = None):
|
36
39
|
self._should_transform_inputs_and_outputs = should_transform_inputs_and_outputs
|
37
40
|
self.custom_layers = custom_layers
|
38
41
|
if custom_layers is None:
|
39
42
|
self.custom_layers = {}
|
40
43
|
|
41
|
-
self.custom_layers = {**self.custom_layers, **onnx_custom_layers}
|
44
|
+
self.custom_layers = {**self.custom_layers, **onnx_custom_layers}
|
45
|
+
|
46
|
+
self.mapping_connections = mapping_connections
|
42
47
|
|
43
48
|
self._model_types_converter = {
|
44
49
|
ImportModelTypeEnum.JSON_TF2.value: self.convert_json_model,
|
@@ -70,6 +75,11 @@ class ModelParser:
|
|
70
75
|
|
71
76
|
graph, connected_inputs = model_generator.generate_graph(
|
72
77
|
model_schema, layer_name_to_inbound_nodes)
|
78
|
+
|
79
|
+
if self.mapping_connections is not None:
|
80
|
+
leap_graph_editor = LeapGraphEditor(graph, keras_model_with_weights)
|
81
|
+
leap_graph_editor.add_connections_to_graph(self.mapping_connections)
|
82
|
+
|
73
83
|
return graph, connected_inputs, keras_model_with_weights, error_info
|
74
84
|
except Exception as e:
|
75
85
|
if model_type.value in [ImportModelTypeEnum.H5_TF2.value, ImportModelTypeEnum.PB_TF2.value]:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[tool.poetry]
|
2
2
|
name = "leap-model-parser"
|
3
|
-
version = "0.1.184"
|
3
|
+
version = "0.1.185.dev2"
|
4
4
|
description = ""
|
5
5
|
authors = ["idan <idan.yogev@tensorleap.ai>"]
|
6
6
|
license = "MIT"
|
@@ -22,6 +22,7 @@ onnx2kerastl = "0.0.174"
|
|
22
22
|
keras-data-format-converter = "0.1.22"
|
23
23
|
leap-model-rebuilder = "0.1.7"
|
24
24
|
tensorflow-io-gcs-filesystem = "0.34.0"
|
25
|
+
code-loader = "1.0.87.dev2"
|
25
26
|
|
26
27
|
[tool.poetry.dev-dependencies]
|
27
28
|
pytest = "^7.1.1"
|
File without changes
|
File without changes
|
File without changes
|
{leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/contract/__init__.py
RENAMED
File without changes
|
{leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/contract/graph.py
RENAMED
File without changes
|
File without changes
|
{leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/contract/nodedata.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
{leap_model_parser-0.1.184 → leap_model_parser-0.1.185.dev2}/leap_model_parser/utils/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|