enncode 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. enncode/__init__.py +0 -0
  2. enncode/compatibility.py +460 -0
  3. enncode/gurobiModelBuilder.py +141 -0
  4. enncode/internalOnnx.py +17 -0
  5. enncode/networkBuilder.py +241 -0
  6. enncode/operators/__init__.py +0 -0
  7. enncode/operators/add.py +107 -0
  8. enncode/operators/averagepool.py +134 -0
  9. enncode/operators/base_operator.py +36 -0
  10. enncode/operators/batch_normalization.py +103 -0
  11. enncode/operators/concat.py +83 -0
  12. enncode/operators/conv.py +124 -0
  13. enncode/operators/div.py +104 -0
  14. enncode/operators/dropout.py +75 -0
  15. enncode/operators/flatten.py +87 -0
  16. enncode/operators/gemm.py +121 -0
  17. enncode/operators/identity.py +54 -0
  18. enncode/operators/matmul.py +122 -0
  19. enncode/operators/maxpool.py +122 -0
  20. enncode/operators/mul.py +109 -0
  21. enncode/operators/operator_factory.py +74 -0
  22. enncode/operators/relu.py +105 -0
  23. enncode/operators/reshape.py +90 -0
  24. enncode/operators/sub.py +105 -0
  25. enncode/operators/unsqueeze.py +90 -0
  26. enncode/parser.py +105 -0
  27. enncode/parsers/__init__.py +0 -0
  28. enncode/parsers/add_parser.py +54 -0
  29. enncode/parsers/averagepool_parser.py +99 -0
  30. enncode/parsers/base_parser.py +23 -0
  31. enncode/parsers/batch_normalization_parser.py +66 -0
  32. enncode/parsers/concat_parser.py +92 -0
  33. enncode/parsers/constant_parser.py +77 -0
  34. enncode/parsers/conv_parser.py +98 -0
  35. enncode/parsers/div_parser.py +82 -0
  36. enncode/parsers/dropout_parser.py +69 -0
  37. enncode/parsers/flatten_parser.py +80 -0
  38. enncode/parsers/gemm_parser.py +62 -0
  39. enncode/parsers/identity_parser.py +65 -0
  40. enncode/parsers/matmul_parser.py +85 -0
  41. enncode/parsers/maxpool_parser.py +99 -0
  42. enncode/parsers/mul_parser.py +80 -0
  43. enncode/parsers/parser_factory.py +69 -0
  44. enncode/parsers/relu_parser.py +53 -0
  45. enncode/parsers/reshape_parser.py +75 -0
  46. enncode/parsers/shape_parser.py +62 -0
  47. enncode/parsers/sub_parser.py +64 -0
  48. enncode/parsers/unsqueeze_parser.py +62 -0
  49. enncode/utils.py +118 -0
  50. enncode-0.1.5.dist-info/METADATA +371 -0
  51. enncode-0.1.5.dist-info/RECORD +54 -0
  52. enncode-0.1.5.dist-info/WHEEL +5 -0
  53. enncode-0.1.5.dist-info/licenses/LICENSE.txt +21 -0
  54. enncode-0.1.5.dist-info/top_level.txt +1 -0
enncode/__init__.py ADDED
File without changes
@@ -0,0 +1,460 @@
1
+ from enncode.parsers.parser_factory import ParserFactory
2
+ from onnx import utils
3
+ from vnnlib.compat import read_vnnlib_simple
4
+ from gurobipy import GRB
5
+ from enncode import gurobiModelBuilder
6
+ import gzip
7
+ import onnxruntime as ort
8
+ import os.path
9
+ import sys
10
+ import os
11
+ import numpy as np
12
+ import onnx
13
+ import glob
14
+
15
def load_gzip_onnx_model(gzipped_path):
    """
    Load a gzip-compressed ONNX model (*.onnx.gz) and return the decompressed model.

    Args:
        gzipped_path: path to the specified compressed model *.onnx.gz

    Returns:
        The loaded onnx model object.

    Raises:
        ValueError: if no onnx model could be decompressed from the given path.
    """
    print(f"\n Loading compressed model from: {gzipped_path} ")

    # gzip.open yields a file-like object that onnx.load can read directly.
    with gzip.open(gzipped_path, 'rb') as compressed_file:
        decompressed_model = onnx.load(compressed_file)

    if decompressed_model is None:
        raise ValueError(f"Decompressing model from given path {gzipped_path} couldn't be done successfully.")
    print(f"Loading compressed model from {gzipped_path} was successful!")

    return decompressed_model
38
+
39
+
40
def load_vnnlib_conditions(vnnlib_path, onnx_model):
    """
    Initializes one parser for a *.vnnlib.gz with the given path. The parser is a dictionary which stores defined
    preconditions at parser[0][0] and postconditions at parser[0][1], usually stored at the named indices.
    The onnx model is necessary for determining the number of inputs and outputs.

    Args:
        vnnlib_path: path to the specified compressed model *.vnnlib.gz specifications
        onnx_model: the onnx model for which the specs should be analyzed

    Returns:
        spec_parser: a vnnlib parser initialized with the given path

    Raises:
        ValueError (or internal vnnlib error): if loading the spec parser object failed
            because of a spec/model mismatch or path issues.
    """
    graph = onnx_model.graph
    # Fix: previously the first graph input was bound to a local named `input`,
    # shadowing the builtin of the same name.
    input_tensor = graph.input[0]
    input_shape = [dim.dim_value for dim in input_tensor.type.tensor_type.shape.dim]
    num_inputs = int(np.prod(input_shape))
    print(f"Registered number of inputs from given onnx model is: {num_inputs}")

    output_tensor = graph.output[0]
    output_shape = [dim.dim_value for dim in output_tensor.type.tensor_type.shape.dim]
    num_outputs = int(np.prod(output_shape))
    print(f"Registered number of outputs from given onnx model is: {num_outputs}")

    spec_parser = read_vnnlib_simple(vnnlib_path, num_inputs, num_outputs)
    if spec_parser is None:
        raise ValueError(
            f"Loading specs from {vnnlib_path} of vnnlib.data couldn't be done successfully.")
    print(f"Loading specifications from {vnnlib_path} was successful! \n")

    return spec_parser
76
+
77
def run_onnx_model(model_path, input_data, input_tensor_name='input', output_tensor_name='output'):
    """
    Runs an ONNX model with the given input and returns its first output.
    """
    # A fresh inference session per call; all graph outputs are requested (None)
    # and only the first one is returned.
    inference_session = ort.InferenceSession(model_path)
    all_outputs = inference_session.run(None, {input_tensor_name: input_data})
    return all_outputs[0]
84
+
85
+
86
def solve_gurobi_model(model_path, input_data, input_tensor_name='input', output_tensor_name='output',
                       expand_batch_dim=False):
    """
    Converts an ONNX model to a Gurobi model, assigns input values, optimizes, and returns the output.

    Args:
        model_path: path to the onnx file that is converted to a Gurobi model
        input_data: numpy array with the concrete input values to pin in the Gurobi model
        input_tensor_name: name of the model's input tensor
        output_tensor_name: name of the model's output tensor
        expand_batch_dim: if True, a leading batch dimension is added to input_data
            (and removed again from the returned output)

    Returns:
        numpy array with the optimized output values of the Gurobi model

    Raises:
        ValueError: if input/output variables cannot be found, optimization does not
            reach GRB.OPTIMAL, or batch-dimension handling goes wrong.
    """
    converter = gurobiModelBuilder.GurobiModelBuilder(model_path)
    converter.build_model()

    dummy_input = input_data
    # Expand with batch dimension if not included in onnx-file
    if expand_batch_dim:
        dummy_input = np.expand_dims(dummy_input, axis=0)
    input_shape = dummy_input.shape

    # Pin input values in the Gurobi model by setting lb == ub on each input variable.
    input_vars = converter.variables.get(input_tensor_name)
    if input_vars is None:
        raise ValueError(f"No input variables found for '{input_tensor_name}'.")

    for idx, var in input_vars.items():
        if isinstance(idx, int):
            md_idx = np.unravel_index(idx, input_shape[1:])  # Exclude batch dimension
        elif isinstance(idx, tuple):
            if len(idx) < len(input_shape) - 1:
                # Pad with leading zeros so the index matches the (batch-less) rank.
                idx = (0,) * (len(input_shape) - 1 - len(idx)) + idx
            md_idx = idx
        else:
            raise ValueError(f"Unexpected index type: {type(idx)}")
        # Fix: `dummy_input[0, *md_idx]` uses a starred subscript, which is only
        # valid on Python >= 3.11; build the full index tuple explicitly instead.
        value = float(dummy_input[(0,) + tuple(md_idx)])
        var.lb = value
        var.ub = value

    gurobi_model = converter.get_gurobi_model()
    gurobi_model.optimize()
    if gurobi_model.status != GRB.OPTIMAL:
        raise ValueError(f"Optimization ended with status {gurobi_model.status}.")

    # Extract the output from the Gurobi model.
    output_vars = converter.variables.get(output_tensor_name)
    if output_vars is None:
        raise ValueError(f"No output variables found for '{output_tensor_name}'.")
    # Fix: avoid allocating a throwaway array (`np.empty(shape).shape`) just to
    # normalize the shape — a plain tuple() conversion is equivalent.
    output_shape = tuple(converter.in_out_tensors_shapes[output_tensor_name])
    gurobi_outputs = np.zeros((1,) + output_shape, dtype=np.float32)

    # Flatten the variable dict in sorted key order and reshape into the tensor layout.
    flat_vars = [output_vars[k] for k in sorted(output_vars.keys())]
    vars_array = np.array(flat_vars, dtype=object).reshape(output_shape)
    # np.ndindex always yields tuples, so the previous int/tuple case split was dead code.
    for md_idx in np.ndindex(vars_array.shape):
        gurobi_outputs[(0,) + md_idx] = vars_array[md_idx].x

    # If an artificial batch dim. was added (cause not included in onnx file), it has to be removed
    if expand_batch_dim:
        if gurobi_outputs.shape[0] != 1:
            raise ValueError("Something went wrong handling the batch-dimension expansion.")
        gurobi_outputs = np.reshape(gurobi_outputs, gurobi_outputs.shape[1:])

    return gurobi_outputs
151
+
152
def get_unsupported_node_types(onnx_path):
    """
    Helper function to collect node types used in the onnx file that are NOT currently
    supported by GurobiModelBuilder.

    Args:
        onnx_path: Path to the (original) onnx file which is analyzed

    Returns:
        List of op types present in the given onnx file that have no registered parser.
    """
    # A throwaway factory instance exposes the set of registered (supported) parsers.
    dummy_factory = ParserFactory()
    graph = onnx.load(onnx_path).graph
    used_ops_in_model = {node.op_type for node in graph.node}
    supported_ops = set(dummy_factory.parsers.keys())
    return [op for op in used_ops_in_model if op not in supported_ops]
172
+
173
+
174
def get_nodes_of_graph(onnx_path):
    """
    Helper function for extracting node names, primarily for writing the log file.

    Args:
        onnx_path: path to the current onnx model

    Returns:
        Tuple (inputs, outputs, nodes) of name lists for the loaded graph; unnamed
        nodes are reported as "Unnamed_<op_type>_<index>".
    """
    graph = onnx.load(onnx_path).graph
    input_names = [inp.name for inp in graph.input]
    output_names = [out.name for out in graph.output]
    node_names = [
        node.name if node.name else f"Unnamed_{node.op_type}_{i}"
        for i, node in enumerate(graph.node)
    ]
    return input_names, output_names, node_names
190
+
191
+
192
def extract_subgraph(onnx_path, subgraph_filename, model_input_names, target_output_names, log_file_path):
    """
    Helper function for extracting a subgraph and handling possible exceptions.

    Args:
        onnx_path: path to the current onnx model
        subgraph_filename: name/path under which the new subgraph is stored
        model_input_names: input names of the subgraph to be extracted (determined by caller method)
        target_output_names: output names of the subgraph to be extracted (determined by caller method)
        log_file_path: path to the log file of the currently analyzed (base) onnx file
    """
    ind = " " * 5
    try:
        utils.extract_model(onnx_path, subgraph_filename,
                            input_names=model_input_names,
                            output_names=target_output_names)
        with open(log_file_path, 'a', encoding="utf-8") as log:
            log.write(f"{ind}[PASSED]: Extracting current subgraph was successful, stored at {subgraph_filename}. \n")
    except Exception as e:
        # Extraction failures are fatal: record them and abort the analysis run.
        print(f"Something went wrong extracting subgraph {subgraph_filename}. \n")
        print(e)
        with open(log_file_path, 'a', encoding="utf-8") as log:
            log.write(f"{ind}[FAILED]: Current subgraph couldn't be extracted {subgraph_filename}. \n")
            log.write(f"{ind}{ind}{e}")
        sys.exit(1)
221
+
222
+
223
def add_dynamic_batch_dim(onnx_model, dynamic_name="batch_size"):
    """
    Marks the first dimension of every multi-dimensional graph input and output
    tensor as a dynamic (named) batch dimension, in place.

    Args:
        onnx_model: loaded onnx model whose dynamic batch dim is set (inplace).
        dynamic_name: name of the new dynamic axis
    """
    graph = onnx_model.graph
    # Inputs and outputs receive the exact same treatment, so handle them in one pass.
    for tensor in list(graph.input) + list(graph.output):
        dims = tensor.type.tensor_type.shape.dim
        # Tensors with a single dimension are left untouched (no batch axis to rename).
        if len(dims) > 1:
            dims[0].dim_param = dynamic_name
241
+
242
+
243
def check_equivalence(onnx_path, model_input, model_input_names, target_output_names, log_file_path, rtol=1e-05, atol=1e-08):
    """
    Runs inference on onnx model and checks compatibility with GurobiModelBuilder.

    Args:
        onnx_path: path to the current onnx model
        model_input: input from caller method for compatibility check
        model_input_names: name of model input
        target_output_names: determined output names from caller method
        log_file_path: path to the log file of currently analyzed (base) onnx file
        rtol: relative tolerance parameter passed to np.allclose
        atol: absolute tolerance parameter passed to np.allclose

    Returns:
        True: if model is compatible and shows equivalence for given random input
        False: if model is compatible but didn't show equivalence for given random input
        None: if model isn't compatible or gurobi output has different shape as onnx output for given random input

    NOTE(review): on an unrecoverable shape mismatch this function actually calls
    sys.exit(1) after logging, rather than returning None as the Returns section
    suggests — confirm which behavior is intended.
    """
    # Create the log file lazily if the caller has not created one yet.
    if not os.path.isfile(log_file_path):
        with open(log_file_path, "w", encoding="utf-8") as f:
            f.write(f"Log-file from analysing of {onnx_path}. \n \n")

    # Reference output from plain ONNX inference ([0] drops the batch dimension).
    onnx_output = run_onnx_model(onnx_path, model_input, input_tensor_name=model_input_names[0])[0]
    try:
        gurobi_output = solve_gurobi_model(
            onnx_path,
            model_input,
            input_tensor_name=model_input_names[0],
            output_tensor_name=target_output_names[0],
            expand_batch_dim=True
        )
        if onnx_output.shape != gurobi_output.shape:
            # In that case, the shape mismatch might be caused by expanding batch dimension.
            if len(gurobi_output.shape) - 1 == len(onnx_output.shape) and gurobi_output.shape[0] == 1:
                gurobi_output = gurobi_output[0]
            else:
                raise ValueError(f"Shape mismatch: ONNX {onnx_output.shape} vs Gurobi {gurobi_output.shape}")

        equivalence = np.allclose(onnx_output, gurobi_output, rtol=rtol, atol=atol)

        return equivalence

    except NotImplementedError as e:
        # Raised while building the Gurobi model when an operator has no parser:
        # the subgraph is incompatible, which maps to the None return value.
        print("\n Current subgraph has unsupported node types.\n")
        print("GurobiModelBuilder misses support for following node type:")
        print(e)
        ind=" " * 5
        with open(log_file_path, 'a', encoding="utf-8") as f:
            f.write(f"{ind}[FAILED]: Compatibility is missing. Subgraph has unsupported node type! \n")
            f.write(f"{ind}{ind}{e} \n")
        return None
    except ValueError as v:
        # Unrecoverable shape mismatch between ONNX and Gurobi outputs: the model
        # was parseable (compatible), so log PASSED/FAILED accordingly and abort.
        ind = " " * 5
        with open(log_file_path, 'a', encoding="utf-8") as f:
            f.write(f"{ind}[PASSED]: Compatibility. \n")
            f.write(f"{ind}[FAILED]: Equivalence check results in different shapes for ONNX and Gurobi output. \n")
            f.write(f"{ind}{ind} {v} \n")
        sys.exit(1)
299
+
300
+
301
def iterative_analyze_subgraphs(onnx_path, output_dir, model_input, model_input_names, rtol, atol, save_subgraphs=True):
    """
    This method is used for analyzing GurobiModelBuilder parsing errors. It iteratively extracts subgraphs from given
    onnx model and tries to identify nodes responsible for misconduct while parsing.

    Args:
        onnx_path: path to the initial onnx model
        output_dir: directory where the log file and extracted subgraphs are stored
        model_input: input from caller method for compatibility check
        model_input_names: name of model input
        rtol: relative tolerance passed to the equivalence check
        atol: absolute tolerance passed to the equivalence check
        save_subgraphs: if False, previously written subgraph files are deleted in each iteration
    """
    model = onnx.load(onnx_path)
    graph = model.graph

    # Checks for validity of output dir, otherwise, directory of onnx path is taken as output dir
    if not os.path.isdir(output_dir):
        output_dir = os.path.dirname(onnx_path)
    # Fix: use os.path.join instead of `output_dir += "subgraphs"` — plain string
    # concatenation drops the separator and produced paths like "<dir>subgraphs".
    output_dir = os.path.join(output_dir, "subgraphs")
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # Create log file for clearer overview
    log_file_path = os.path.join(output_dir, "subgraphs_log.txt")
    with open(log_file_path, "w", encoding="utf-8") as f:
        f.write(f"Log-file from analysing of {onnx_path}. \n \n")

    for i, node in enumerate(graph.node):
        # For constants nodes, no subgraph is evaluated
        if node.op_type == "Constant":
            continue

        node_name = node.name if node.name else f"Node_{i}_{node.op_type}"
        target_output_names = list(node.output)
        subgraph_filename = os.path.join(output_dir, f"node_{i:03d}_{node_name.replace('/', '_')}.onnx")

        ind = " " * 5
        with open(log_file_path, "a", encoding="utf-8") as f:
            f.write(f"[NODE {i}]: {node_name}\n")

        # If subgraphs should not be stored permanently while analysing, they are removed in next iteration
        if not save_subgraphs:
            for file in glob.glob(os.path.join(output_dir, "*.onnx")):
                os.remove(file)

        extract_subgraph(onnx_path, subgraph_filename, model_input_names, target_output_names, log_file_path)

        equivalence = check_equivalence(
            subgraph_filename, model_input, model_input_names, target_output_names, log_file_path, rtol, atol
        )
        # Update log file (encoding added for consistency with the other open() calls)
        with open(log_file_path, "a", encoding="utf-8") as f:
            subgraph_inputs, subgraph_outputs, subgraph_nodes = get_nodes_of_graph(subgraph_filename)

            if equivalence:
                f.write(f"{ind}[PASSED]: Compatibility. \n")
                f.write(f"{ind}[PASSED]: Equivalence check for ({subgraph_filename}) has been successful. \n \n")
            else:
                # In that case, check_equivalence must have failed, caused by unsupported nodes
                if equivalence is None:
                    unsupported_types = get_unsupported_node_types(onnx_path)
                    f.write("\n Note: - Furthermore, the original model contains "
                            "following unsupported operation types, which are likely to cause further incompatibility: \n")
                    f.write(str(unsupported_types))
                    f.write("\n (Please see documentation for currently supported node-operations)")
                # If equivalence is neither true nor none, it is false, indicating compatibility but missing equivalence
                else:
                    f.write(f"{ind}[PASSED]: Compatibility. \n")
                    f.write(f"{ind}[FAILED]: Equivalence check for ({subgraph_filename}) failed. \n")
                f.write(f"{ind} Inputs: {subgraph_inputs} \n")
                f.write(f"{ind} Outputs: {subgraph_outputs} \n")
                f.write(f"{ind} Included nodes: {subgraph_nodes} \n \n")

                sys.exit(1)
373
+
374
+
375
def compatibility_check(onnx_path, iterative_analysis=True, output_dir=None, save_subgraphs=True, rtol=1e-03, atol=1e-05):
    """
    This method implements an automated compatibility check for a given onnx file. First, it is adjusted to have a
    dynamic batch dimension and is stored as a new onnx file. The adjusted onnx file is checked for compatibility with
    a random input. If successful, equivalence to the corresponding onnx run is checked with the given rtol/atol
    deviation.

    If not successful and iterative analysis is enabled, every subgraph in topological order is extracted and then
    tested for compatibility and equivalence with the onnx run.

    Results of the iterative analysis are written to a log file, stored in the specified output directory.

    Args:
        onnx_path: path to the onnx file to be checked.
        iterative_analysis: boolean flag, if iterative analysis of subgraphs should be done
        output_dir: path to the directory where log files and subgraphs are stored.
        save_subgraphs: boolean flag, if all subgraphs should be stored permanently or removed if checked
        rtol: the relative tolerance parameter for np.allclose
        atol: the absolute tolerance parameter for np.allclose
    """
    path = onnx_path

    # 1) First the suffix of given path is checked for the type of onnx format
    if path.endswith(".gz"):
        model = load_gzip_onnx_model(path)
        path = path.removesuffix(".gz")
        onnx.save(model, path)
    onnx_model = onnx.load(path)

    # Then we ensure that first dimension is always a dynamic batch dimension
    add_dynamic_batch_dim(onnx_model)
    path = path.removesuffix(".onnx") + "_modified.onnx"
    onnx.save(onnx_model, path)

    # 2) For given network, input and output names have to be filtered
    # (graph inputs that are initializers are weights, not real inputs)
    graph = onnx_model.graph
    input_names = [node.name for node in graph.input]
    initializer_names = {x.name for x in graph.initializer}
    real_inputs = [name for name in input_names if name not in initializer_names]

    input_nodes = [node for node in graph.input if node.name in real_inputs]
    input_tensor = input_nodes[0]
    # Dynamic dims report dim_value 0; clamp to 1 so a concrete input can be built.
    input_shape = [max(dim.dim_value, 1) for dim in input_tensor.type.tensor_type.shape.dim]
    output_names = [node.name for node in graph.output]

    # 3) Basic compatibility check by random input
    dummy_input = np.random.rand(*input_shape).astype(np.float32)
    onnx_output = run_onnx_model(path, dummy_input, input_tensor_name=real_inputs[0])[0]
    try:
        gurobi_output = solve_gurobi_model(
            path,
            dummy_input,
            input_tensor_name=real_inputs[0],
            output_tensor_name=output_names[0],
            expand_batch_dim=True
        )
    except NotImplementedError:
        print("\n An error has occurred by solving gurobi model for given onnx file.")
        if iterative_analysis:
            if output_dir is None:
                # Fix: os.path.dirname("model.onnx") is "", which previously
                # turned the output dir into the filesystem root "/".
                output_dir = (os.path.dirname(onnx_path) or ".") + "/"
            # The subgraphs are iteratively checked for compatibility
            iterative_analyze_subgraphs(path, output_dir, dummy_input, real_inputs, rtol, atol, save_subgraphs)
            # Fix: previously execution fell through to step 4 with gurobi_output
            # unbound if the iterative analysis returned without exiting.
            sys.exit(1)
        else:
            print("\n No iterative check for compatibility is performed. \n")
            print("GurobiModelBuilder misses support for following node types:")
            unsupported_types = get_unsupported_node_types(path)
            print(str(unsupported_types))
            sys.exit(1)

    # 4) Check for equivalent outputs (dead `rtol = rtol; atol = atol` removed)
    equivalence = np.allclose(onnx_output, gurobi_output, rtol=rtol, atol=atol)
    if equivalence:
        print("\n Given network has been compatible with GurobiModelBuilder parsing. \n")
        print(f"It has shown equivalence for rtol={rtol} and atol={atol}.")
        print(f"ONNX-output: {onnx_output}")
        print(f"Gurobi-output: {gurobi_output}")
        sys.exit(0)
    else:
        print("\n Given network has been compatible with GurobiModelBuilder parsing. \n")
        print(f"Unfortunately there is a deviation between ONNX and Gurobi output for rtol={rtol} and atol={atol}.")
        print(f"ONNX-output: {onnx_output}")
        print(f"Gurobi-output: {gurobi_output}")
        sys.exit(1)
@@ -0,0 +1,141 @@
1
+ from gurobipy import Model, GRB
2
+ from .operators.operator_factory import OperatorFactory
3
+ from .parser import ONNXParser
4
+ from .utils import _generate_indices
5
+
6
class GurobiModelBuilder:
    """
    Converts an ONNX model to a Gurobi optimization model by transforming the ONNX
    representation into an internal representation and then constructing the corresponding
    constraints for each operator.

    Attributes:
        model (gurobipy.Model): The Gurobi model being constructed.
        internal_onnx: The internal representation of the parsed ONNX model,
            containing initializers, nodes, and input/output tensor shapes.
        initializers (dict): A dictionary containing the initial values extracted from the ONNX model.
        nodes (list): A list of dictionaries, each representing an ONNX node with its associated data.
        in_out_tensors_shapes (dict): A mapping of input and output tensor names to their shapes.
        operator_factory (OperatorFactory): Factory for creating operator instances based on node types.
        variables (dict): A mapping of tensor names to either Gurobi decision variables or constant values.
    """
    def __init__(self, onnx_model_path: str):
        """
        Initializes the GurobiModelBuilder with the given ONNX model file path.

        This constructor loads the ONNX model, converts it into an internal representation,
        and initializes the attributes required for building the Gurobi model.

        Args:
            onnx_model_path (str): The file path to the ONNX model to be converted.
        """
        self.model = Model("NeuralNetwork")
        self.internal_onnx = ONNXParser(onnx_model_path)._parse_model()
        self.initializers = self.internal_onnx.initializers
        self.nodes = self.internal_onnx.nodes
        self.in_out_tensors_shapes = self.internal_onnx.in_out_tensors_shapes
        self.operator_factory = OperatorFactory()
        self.variables = {}

    def create_variables(self):
        """
        Creates Gurobi variables for the input/output tensors and intermediate nodes.

        Constants are stored as plain values in ``self.variables``; ReLU outputs are
        created with lower bound 0 (their range is non-negative by definition); all
        other tensors get unbounded continuous variables.
        """
        # Create variables for inputs and outputs
        for tensor_name, shape in self.in_out_tensors_shapes.items():
            indices = _generate_indices(shape)
            self.variables[tensor_name] = self.model.addVars(
                indices,
                vtype=GRB.CONTINUOUS,
                lb=-GRB.INFINITY,
                name=tensor_name
            )

        # Create variables for intermediate nodes
        for node in self.nodes:
            output_name = node['output'][0]['name']

            if node['type'] == "Constant":
                # Constants are not model variables
                if 'attributes' in node and node['attributes']:
                    self.variables[output_name] = node['attributes']['value']
                else:
                    self.variables[output_name] = 0

            elif node['type'] == "Relu":
                # NOTE: a commented-out draft with binary indicator variables for a
                # MILP ReLU encoding was removed here; the current encoding only
                # sets lb=0 on the output and leaves the constraints to the Relu
                # operator created in build_model().
                shape = node['output'][0]['shape']
                indices = _generate_indices(shape)
                self.variables[output_name] = self.model.addVars(
                    indices,
                    vtype=GRB.CONTINUOUS,
                    lb=0.0,
                    name=output_name
                )

            else:
                shape = node['output'][0]['shape']
                indices = _generate_indices(shape)
                self.variables[output_name] = self.model.addVars(
                    indices,
                    vtype=GRB.CONTINUOUS,
                    lb=-GRB.INFINITY,
                    name=output_name
                )

    def build_model(self):
        """
        Constructs the Gurobi model by creating variables and applying operator constraints.
        """
        self.create_variables()
        for node in self.nodes:
            if node['type'] != "Constant":
                operator = self.operator_factory.create_operator(node, self.initializers)
                operator.apply_constraints(self.model, self.variables)

    def get_gurobi_model(self):
        """
        Retrieves the Gurobi model after all constraints have been added.

        Returns:
            gurobipy.Model: The constructed Gurobi model reflecting the ONNX graph.
        """
        return self.model

    def get_input_vars(self):
        """
        Returns the variables of the single model input.

        Raises:
            ValueError: if the model has more than one input node or the
                variables cannot be found.
        """
        if len(self.internal_onnx.input_node_name) != 1:
            raise ValueError("The current model seems to have more than one input node, which isn't supported by this function.")
        input_name = self.internal_onnx.input_node_name[0]
        input_vars = self.variables.get(input_name)
        if input_vars is None:
            raise ValueError("Input variables couldn't be accessed.")
        return input_vars

    def get_output_vars(self):
        """
        Returns the variables of the single model output.

        Raises:
            ValueError: if the model has more than one output node or the
                variables cannot be found.
        """
        if len(self.internal_onnx.output_node_name) != 1:
            raise ValueError("The current model seems to have more than one output node, which isn't supported by this function.")
        output_name = self.internal_onnx.output_node_name[0]
        output_vars = self.variables.get(output_name)
        if output_vars is None:
            raise ValueError("Output variables couldn't be accessed.")
        return output_vars
+
@@ -0,0 +1,17 @@
1
class InternalONNXRepresentation:
    """
    Represents the internal ONNX model after parsing.
    This representation supplies the ONNXToGurobi
    with the required attributes for building the Gurobi model.

    Attributes:
        initializers (dict): Contains the initial values from the parsed ONNX model.
        nodes (list): A list of dictionaries, each representing an ONNX node extracted by the parser.
        in_out_tensors_shapes (dict): Stores shapes for all input and output tensors.
        input_node_name: Names of the real (non-initializer) graph inputs.
        output_node_name: Names of the graph outputs.
    """
    def __init__(self, parser):
        # Values are referenced directly from the parser object, not copied.
        self.initializers = parser.initializer_values
        self.nodes = parser.nodes
        self.in_out_tensors_shapes = parser.input_output_tensors_shapes
        # NOTE(review): despite the singular attribute names, these appear to hold
        # lists of names (GurobiModelBuilder takes len() and indexes [0]) — confirm.
        self.input_node_name = parser.real_inputs_names
        self.output_node_name = parser.real_output_names