onnx-diagnostic 0.8.6__py3-none-any.whl → 0.8.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +108 -3
- onnx_diagnostic/ci_models/ci_helpers.py +12 -7
- onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
- onnx_diagnostic/ci_models/export_qwen25_vl.py +12 -4
- onnx_diagnostic/export/api.py +1 -0
- onnx_diagnostic/export/cf_simple_loop_for.py +195 -10
- onnx_diagnostic/ext_test_case.py +9 -2
- onnx_diagnostic/helpers/bench_run.py +1 -1
- onnx_diagnostic/helpers/log_helper.py +1 -3
- onnx_diagnostic/helpers/optim_helper.py +116 -0
- onnx_diagnostic/tasks/image_text_to_text.py +15 -5
- onnx_diagnostic/tasks/text2text_generation.py +84 -48
- onnx_diagnostic/tasks/text_generation.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +28 -2
- onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +12 -1
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +22 -24
- onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/RECORD +29 -26
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/ci_models/export_qwen25_vl.py
CHANGED

@@ -8,10 +8,10 @@ Requirements
 ::
 
     git+https://github.com/sdpython/experimental-experiment.git # optional
-    huggingface_hub
+    huggingface_hub
    onnx-diagnostic>=0.8.6
     onnxruntime>=1.23
-    torch>=2.
+    torch>=2.10 # weekly is better
     tqdm
     transformers>=4.57
 
@@ -59,6 +59,7 @@ It is possible to overwrite this by setting environment variable
 import os
 import sys
 import time
+import warnings
 from typing import Any, Dict, List, Tuple
 from .ci_helpers import (
     check_for_discrepancies_and_log_everything_into_a_json_file,
@@ -97,7 +98,6 @@ def get_untrained_model(model_id: str, second_input: bool, verbose: int) -> Dict
         },
         # "_attn_implementation": "flash_attention_2",
         "_attn_implementation": "sdpa",
-        "dtype": "float16",
     }
 
     config_reduction = _config_reduction
@@ -281,6 +281,10 @@ def main(
         ).eval()
         data = dict(model=model)
         config = model.config
+        if not hasattr(config, "bos_token_id") or not config.bos_token_id:
+            config.bos_token_id = 151643
+        if not hasattr(config, "eos_token_id") or not config.eos_token_id:
+            config.eos_token_id = 151645
     else:
         print("-- random model")
         data = get_untrained_model(model_id, second_input=second_input, verbose=1)
@@ -298,7 +302,11 @@ def main(
     print(f"-- config._attn_implementation={model.config._attn_implementation}")
     print(f"-- model.dtype={model.dtype}")
     print(f"-- model.device={model.device}")
-
+    try:
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+    except OSError as e:
+        warnings.warn(f"Unable to access internet due to {e!r}", ResourceWarning, stacklevel=0)
+        return
     print(f"-- processor={type(processor)}")
 
     export_inputs, other_inputs = None, None
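The bos/eos fallback added above hard-codes the well-known Qwen special token
ids. A small hypothetical repro of that guard (``PretrainedConfig`` stands in
for any config missing the attributes; the ids are the documented Qwen ones)::

    from transformers import PretrainedConfig

    config = PretrainedConfig()
    if not getattr(config, "bos_token_id", None):
        config.bos_token_id = 151643  # Qwen "<|endoftext|>"
    if not getattr(config, "eos_token_id", None):
        config.eos_token_id = 151645  # Qwen "<|im_end|>"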
onnx_diagnostic/export/cf_simple_loop_for.py
CHANGED

@@ -11,6 +11,7 @@ from torch._higher_order_ops.utils import (
     unique_graph_id,
     validate_subgraph_args_types,
 )
+import torch._dynamo.variables.higher_order_ops as hop
 from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
 from torch.utils._python_dispatch import _get_current_dispatch_mode
 
@@ -97,14 +98,18 @@ def _simple_loop_for_fn(
                 f"Unexpected number of results {len(r)} for function {body_fn}, "
                 f"expected {len(res[-1])}"
             )
+            assert all(isinstance(t, torch.Tensor) for t in r), (
+                f"Unexpected type {[type(_) for _ in r]} returned by function {body_fn}, "
+                f"it must be a tuple of Tensor or a Tensor."
+            )
             res.append(r)
         else:
             assert isinstance(r, torch.Tensor), (
-                f"Unexpected type {r}
-                f"it must be a tuple or a Tensor."
+                f"Unexpected type {type(r)} coming from function {body_fn}, "
+                f"it must be a tuple of Tensor or a Tensor."
             )
             assert not res or len(res[-1]) == 1, (
-                f"Unexpected number of results {len(r)}
+                f"Unexpected number of results {len(r)} coming from function {body_fn}, "
                 f"expected {len(res[-1])}"
             )
             res.append((r,))
@@ -126,8 +131,6 @@ def _simple_loop_for_fn(
     )
 
 
-# from torch._functorch.utils import exposed_in
-# @exposed_in("torch")
 def _simple_loop_for(
     n_iter: Union[int, torch.Tensor],
     body_fn: Callable,
@@ -159,7 +162,7 @@ def _simple_loop_for(
 
     if torch.compiler.is_dynamo_compiling():
         return simple_loop_for_op(
-            n_iter, body_fn,
+            n_iter, body_fn, operands, concatenation_dims=concatenation_dims
         )
 
     if isinstance(n_iter, (bool, int, float)):
@@ -181,8 +184,10 @@ def _simple_loop_for(
 
     with setup_compilation_env() as _backend:
         return _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims)
-        #
-        #
+        # This is needed to support a function body that uses module weights or
+        # is defined as a class method. This is yet to be implemented.
+        # cpl = torch.compile(_loop_for_op_wrapper, backend=_backend, fullgraph=True)
+        # return cpl(n_iter, body_fn, operands, concatenation_dims)
 
 
 def trace_simple_loop_for(
@@ -236,9 +241,15 @@ def loop_for_op_dense(n_iter, body_fn, operands, concatenation_dims=None):
     )
     mode = _get_current_dispatch_mode()
     assert mode is None, "Mode should never be enabled for CPU/CUDA key"
-
-
+    is_fake = isinstance(n_iter, torch._subclasses.fake_tensor.FakeTensor)
+    res = _simple_loop_for_fn(n_iter, body_fn, operands, concatenation_dims=concatenation_dims)
+    assert is_fake or not any(
+        isinstance(r, torch._subclasses.fake_tensor.FakeTensor) for r in res
+    ), (
+        f"One result is a fake tensor but the inputs were not, type(n_iter)={type(n_iter)}, "
+        f"operands: {[type(_) for _ in operands]}, res: {[type(_) for _ in res]}"
     )
+    return res
 
 
 @simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
@@ -267,6 +278,180 @@ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
 simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)
 
 
+class SimpleLoopForHigherOrderVariable(hop.TorchHigherOrderOperatorVariable):
+    """
+    Replicates the same pattern found for other higher order operators.
+    This enables recursive compilation and the use of modules inside a function.
+    """
+
+    _HOP_NAME = "simple_loop_for"
+    _ALLOW_FALLBACK_TO_EAGER = False
+    supports_input_mutation = False
+    supports_aliasing = False
+
+    def _call_function(
+        self,
+        tx: torch._dynamo.symbolic_convert.InstructionTranslator,
+        args: list[hop.VariableTracker],
+        kwargs: dict[str, hop.VariableTracker],
+    ) -> hop.VariableTracker:
+        """Main function."""
+        args, kwargs = hop.LazyVariableTracker.realize_all((args, kwargs))
+
+        for i, k in enumerate(["n_iter", "body_fn", "operands", "concatenated_dims"]):
+            if v := kwargs.pop(k, None):
+                assert i == len(args), "did not provide the right number of non-keyword args"
+                args.append(v)
+
+        if len(args) != 4 or kwargs:
+            hop.unimplemented(
+                gb_type="simple_loop_for: improper args/kwargs",
+                context=f"args: {args}, kwargs: {kwargs}",
+                explanation=f"simple_loop_for expects 4 positional arguments (got {len(args)}) "
+                f"and no keyword arguments (got {len(kwargs)})",
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        # unpack the four positional arguments
+        n_iter, body_fn, operands, _concatenated_dims = args
+        assert type(n_iter) is not hop.ConstantVariable, (
+            f"n_iter is a {type(n_iter)}. When n_iter is constant, "
+            f"simple_loop_for unrolls the loop. A SymInt should be used."
+        )
+
+        # predicate
+        if type(n_iter.realize()) not in (
+            hop.ConstantVariable,
+            hop.TensorVariable,
+            hop.SymNodeVariable,
+        ):
+            hop.unimplemented(
+                gb_type="simple_loop_for: improper predicate",
+                context=str(n_iter),
+                explanation=(
+                    f"Expected `n_iter` to be an int or an integer "
+                    f"tensor with a single item "
+                    f"but got {str(type(n_iter))} with original python type "
+                    f"{str(n_iter.python_type())}."
+                ),
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        # operands
+        if not isinstance(operands, (hop.ListVariable, hop.TupleVariable)):
+            hop.unimplemented(
+                gb_type="simple_loop_for: improper operands",
+                context=str(operands),
+                explanation="Expected `operands` to be a list/tuple "
+                f"but got {operands.python_type()}.",
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        operands_seq = operands.unpack_var_sequence(tx)
+        if not hop.only_consist_of(
+            operands, (hop.TensorVariable, hop.ConstantVariable, hop.SymNodeVariable)
+        ):
+            hop.unimplemented(
+                gb_type="simple_loop_for: improper operands contents",
+                context=str(operands),
+                explanation=(
+                    "Expected `operands` to be a list/tuple of pytrees "
+                    "that only consists of tensor leaves."
+                ),
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        # branches
+        hop._check_supported_callable_arg(tx, body_fn, "body_fn")
+
+        def speculate_body():
+            (
+                (ret_val, ret_spec),
+                ret_graph,
+                ret_lifted_freevars,
+            ) = hop.speculate_subgraph(
+                tx,
+                args[1],
+                (args[0], *operands_seq),
+                {},
+                self._HOP_NAME,
+                source_target=self.value,
+                should_flatten_outputs=True,
+                # TODO - removing consts from control flow ops need more work
+                remove_consts_from_outputs=False,
+                supports_input_mutation=self.supports_input_mutation,
+                supports_aliasing=self.supports_aliasing,
+            )
+
+            # need to ensure we increase epoch so we don't memoize unbacked bindings
+            # across different subgraphs which can interfere with runtime assertion
+            # generation.
+            tx.fake_mode.epoch += 1
+
+            if not hop.only_consist_of(ret_val, (hop.TensorVariable, hop.ConstantVariable)):
+                hop.unimplemented(
+                    gb_type="simple_loop_for: unsupported branch return type",
+                    context=str(ret_val),
+                    explanation=(
+                        "Expected branches to return a possibly nested "
+                        "pytree of tensors or constant ints."
+                    ),
+                    hints=[*hop.graph_break_hints.USER_ERROR],
+                )
+            for ret in ret_val.unpack_var_sequence(tx):
+                if ret.is_python_constant() and not isinstance(ret.as_python_constant(), int):
+                    hop.unimplemented(
+                        gb_type=(
+                            "simple_loop_for: unsupported branch return type "
+                            "(constant non-int)"
+                        ),
+                        context=str(ret_val),
+                        explanation="Constants returned from branches must be ints.",
+                        hints=[*hop.graph_break_hints.USER_ERROR],
+                    )
+            return ret_val, ret_spec, ret_graph, ret_lifted_freevars
+
+        body_r, body_spec, body_graph, body_lifted_freevars = speculate_body()
+        body_nn_modules = dict(tx.output.nn_modules)
+
+        same_spec = body_spec.treespec.as_python_constant()
+        if same_spec is not NotImplemented and not same_spec:
+            hop.unimplemented(
+                gb_type="simple_loop_for: differing branch outputs",
+                context=(
+                    f"body_spec: {body_spec.treespec}, "
+                    f"same_spec: {same_spec}"
+                ),
+                explanation="Expected branches to return the same pytree structure.",
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        body_name = tx.output.install_subgraph(
+            "loop_body", torch.fx.GraphModule(body_nn_modules, body_graph)
+        )
+        body_node = hop.make_attr(tx, body_name)
+        p_args = (
+            n_iter.as_proxy(),
+            body_node,
+            # append the lifted free variables to the operands
+            operands.as_proxy() + tuple(body_lifted_freevars.keys()),
+        )
+
+        return hop._call_function_and_unflatten_output(
+            tx,
+            simple_loop_for,
+            p_args,
+            {},
+            None,
+            body_spec,
+            body_r,
+        )
+
+
+hop._hop_name_to_variable_class["simple_loop_for"] = SimpleLoopForHigherOrderVariable
+
+
+# @torch._functorch.utils.exposed_in("torch")
 def simple_loop_for(
     n_iter: Union[int, torch.Tensor],
     body_fn: Callable,
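The new ``SimpleLoopForHigherOrderVariable`` registers ``simple_loop_for`` with
Dynamo, so a compiled graph keeps the loop as a single higher-order op instead
of unrolling it. A hedged usage sketch (the concatenation semantics and the
keyword name ``concatenation_dims`` are inferred from this diff, not from the
package documentation)::

    import torch
    from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for

    def body_fn(i, x):
        # one tensor per iteration; results are assumed to be concatenated
        # along the dimension listed in concatenation_dims
        return (x * i).reshape((1, -1))

    x = torch.arange(4.0)
    n = torch.tensor(5)  # a tensor (or SymInt) avoids unrolling the loop
    y = simple_loop_for(n, body_fn, (x,), concatenation_dims=[0])
    # expected shape under the assumed semantics: (5, 4)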
onnx_diagnostic/ext_test_case.py
CHANGED

@@ -1267,6 +1267,7 @@ class ExtTestCase(unittest.TestCase):
         :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
         """
         from .helpers import string_type, string_diff, max_diff
+        from .helpers.torch_helper import torch_deepcopy
         from .helpers.rt_helper import make_feeds
         from .helpers.ort_session import InferenceSessionForTorch
 
@@ -1283,6 +1284,12 @@ class ExtTestCase(unittest.TestCase):
             model_file = proto
             name = proto
             proto = onnx.load(name)
+        elif hasattr(proto, "save"):
+            name = f"{test_name}.onnx"
+            proto.save(name)
+            proto = onnx.load(name)
+        elif hasattr(proto, "model_proto"):
+            proto = proto.model_proto
         elif not self.unit_test_going():
             assert isinstance(
                 proto, onnx.ModelProto
@@ -1341,9 +1348,9 @@
         if copy_inputs:
             expected = [
                 (
-                    model(*
+                    model(*torch_deepcopy(inp))
                     if isinstance(inp, tuple)
-                    else model(**
+                    else model(**torch_deepcopy(inp))
                 )
                 for inp in inputs
             ]
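The switch to ``torch_deepcopy`` above protects the reference run: a model that
mutates its inputs in place (a KV cache, for instance) would otherwise corrupt
the feeds reused for the ONNX session. A minimal sketch of the assumed
semantics::

    import torch
    from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

    inp = {"x": torch.zeros((2, 3))}
    copy = torch_deepcopy(inp)  # deep copy aware of tensors and cache classes
    copy["x"] += 1              # mutating the copy...
    assert inp["x"].sum().item() == 0  # ...leaves the original feeds intact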
onnx_diagnostic/helpers/log_helper.py
CHANGED

@@ -1921,9 +1921,7 @@ class CubeLogsPerformance(CubeLogs):
             return lambdas[formula]
 
         if formula == "onnx_n_nodes_no_cst":
-            return lambda df: gdf(df, "onnx_n_nodes", 0) - gdf(
-                df, "op_onnx__Constant", 0
-            ).fillna(0)
+            return lambda df: gdf(df, "onnx_n_nodes", 0) - gdf(df, "op_onnx__Constant", 0)
         if formula == "peak_gpu_torch":
             return lambda df: gdf(df, "mema_gpu_5_after_export") - gdf(df, "mema_gpu_4_reset")
         if formula == "peak_gpu_nvidia":
onnx_diagnostic/helpers/optim_helper.py
ADDED

@@ -0,0 +1,116 @@
+from typing import Optional, Union
+import pprint
+import onnx
+
+
+def optimize_model(
+    algorithm: str,
+    model: Union[onnx.ModelProto, str],
+    output: Optional[str] = None,
+    processor: Optional[str] = None,
+    infer_shapes: bool = True,
+    remove_shape_info: bool = False,
+    verbose: int = 1,
+):
+    """
+    Optimizes an onnx model by fusing nodes. It looks for patterns in the graph
+    and replaces them with the corresponding fused nodes. It also does basic
+    optimization such as removing identity nodes or unused nodes.
+
+    :param algorithm: algorithm to choose
+    :param model: model to optimize as a proto or a filename
+    :param output: if not empty, the optimized model is saved
+    :param processor: optimizations are done for this processor
+    :param infer_shapes: infer shapes before optimizing, this might not be
+        available for all algorithms
+    :param remove_shape_info: remove shape information before saving the model
+    :param verbose: verbosity level
+    :return: optimized model
+
+    The goal is to make the model faster.
+    Argument patterns defines the patterns to apply or the set of patterns.
+    It is possible to show statistics or to remove a particular pattern.
+    Here are some environment variables which can be used to trigger
+    these displays.
+
+    Available for algorithms ``default`` and ``default+onnxruntime``:
+
+    - ``DROPPATTERN=<pattern1,pattern2,...>``: do not apply
+      those patterns when optimizing a model
+    - ``DUMPPATTERNS=<folder>``: dumps all matched and applied nodes when a pattern is applied
+    - ``PATTERN=<pattern1,pattern2,...>``: increase verbosity
+      for specific patterns to understand why one pattern was not applied,
+      this shows which line is rejecting a pattern if it seems one pattern was missed
+    """
+    if isinstance(model, str):
+        if verbose:
+            print(f"[optimize_model] load {model!r}")
+        proto = onnx.load(model)
+        if verbose:
+            print("[optimize_model] done loading.")
+    else:
+        proto = model
+
+    if verbose:
+        print(f"[optimize_model] optimize with {algorithm!r}")
+    if algorithm in {"default", "default+onnxruntime"}:
+        from experimental_experiment.xoptim import get_pattern_list
+        from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions
+
+        pats = get_pattern_list(algorithm)
+
+        gr = GraphBuilder(
+            proto,
+            infer_shapes_options=infer_shapes,
+            optimization_options=OptimizationOptions(
+                patterns=pats,
+                verbose=verbose,
+                remove_unused=True,
+                constant_folding=True,
+                remove_identity=True,
+                max_iter=max(100, len(proto.graph.node) // 2),
+                processor=processor or "CPU",
+            ),
+        )
+        if verbose:
+            print(f"[optimize_model] starts optimizing with {len(pats)} patterns")
+            print(f"[optimize_model] model has {len(proto.graph.node)} nodes")
+        opt_onx, report = gr.to_onnx(optimize=True, return_optimize_report=True)
+        if verbose:
+            print("[optimize_model] optimization report")
+            pprint.pprint(report)
+            print("[optimize_model] done")
+
+    elif algorithm == "slim":
+        import onnxslim
+
+        opt_onx = onnxslim.slim(proto, no_shape_infer=not infer_shapes)
+    elif algorithm in {"ir", "os_ort"}:
+        import onnx_ir
+        import onnxscript.optimizer
+        from onnxscript.rewriter.ort_fusions import optimize_for_ort
+
+        model_ir = onnx_ir.from_proto(proto)
+        if algorithm == "ir":
+            onnxscript.optimizer.optimize(model_ir)
+        else:
+            optimize_for_ort(model_ir)
+        opt_onx = onnx_ir.serde.serialize_model(model_ir)
+
+    del proto
+    if verbose:
+        print(f"[optimize_model] done optimizing, model has {len(opt_onx.graph.node)} nodes")
+    if remove_shape_info:
+        if verbose:
+            print(f"[optimize_model] remove shape information {len(opt_onx.graph.value_info)}")
+        del opt_onx.graph.value_info[:]
+        if verbose:
+            print("[optimize_model] done removing shape info")
+
+    if output:
+        if verbose:
+            print(f"[optimize_model] save file into {output!r}")
+        onnx.save(opt_onx, output, save_as_external_data=True)
+        if verbose:
+            print("[optimize_model] done saving")
+    return opt_onx
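A minimal usage sketch of the new helper (the file names are placeholders)::

    import onnx
    from onnx_diagnostic.helpers.optim_helper import optimize_model

    proto = onnx.load("model.onnx")  # placeholder path
    optimized = optimize_model(
        "slim",  # also: "default", "default+onnxruntime", "ir", "os_ort"
        proto,
        output="model.opt.onnx",  # written with save_as_external_data=True
        verbose=0,
    )
    print(f"{len(optimized.graph.node)} nodes after optimization")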
onnx_diagnostic/tasks/image_text_to_text.py
CHANGED

@@ -13,6 +13,10 @@ from .data import get_data
 __TASK__ = "image-text-to-text"
 
 
+def should_have_vision_config(config):
+    return config.architectures != ["FuyuForCausalLM"]
+
+
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
@@ -477,7 +481,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "hidden_size",
             "pad_token_id",
         )
-
+        if should_have_vision_config(config):
+            check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
         text_config = True
     else:
         check_hasattr(
@@ -491,7 +496,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "vision_config",
         )
         text_config = False
-
+    if should_have_vision_config(config):
+        check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
     kwargs = dict(
         head_dim=(
             16
@@ -552,17 +558,21 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         ),
         width=(
             224
-            if config is None
+            if config is None
+            or not should_have_vision_config(config)
+            or not hasattr(config.vision_config, "image_size")
             else config.vision_config.image_size
         ),
         height=(
             224
-            if config is None
+            if config is None
+            or not should_have_vision_config(config)
+            or not hasattr(config.vision_config, "image_size")
             else config.vision_config.image_size
         ),
         num_channels=(
             3
-            if config is None
+            if config is None or not should_have_vision_config(config)
             else _pick(config.vision_config, "num_channels", "in_chans", "in_channels")
         ),
         pad_token_id=(