onnx-diagnostic 0.8.4__py3-none-any.whl → 0.8.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +67 -9
- onnx_diagnostic/ci_models/__init__.py +0 -0
- onnx_diagnostic/ci_models/ci_helpers.py +430 -0
- onnx_diagnostic/ci_models/export_qwen25_vl.py +560 -0
- onnx_diagnostic/export/api.py +15 -4
- onnx_diagnostic/export/cf_simple_loop_for.py +352 -0
- onnx_diagnostic/export/control_flow_onnx.py +23 -17
- onnx_diagnostic/export/onnx_plug.py +60 -6
- onnx_diagnostic/ext_test_case.py +14 -0
- onnx_diagnostic/helpers/helper.py +26 -27
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +16 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +10 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +103 -31
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +29 -8
- onnx_diagnostic/torch_onnx/compare.py +357 -0
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/RECORD +22 -19
- onnx_diagnostic/export/control_flow.py +0 -214
- onnx_diagnostic/export/control_flow_research.py +0 -140
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/top_level.txt +0 -0
onnx_diagnostic/export/control_flow_research.py (deleted)

```diff
@@ -1,140 +0,0 @@
-from typing import Any, Callable, Union
-import torch
-from torch._C import DispatchKey
-
-# from torch._higher_order_ops import BaseHOP
-from torch._ops import HigherOrderOperator
-from torch._functorch.utils import exposed_in
-import torch.utils._pytree as pytree
-from torch._higher_order_ops.utils import (
-    check_input_alias_and_mutation_return_outputs,
-    reenter_make_fx,
-    unique_graph_id,
-    validate_subgraph_args_types,
-)
-from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
-from torch.utils._python_dispatch import _get_current_dispatch_mode
-from .control_flow_onnx import _loop_for_onnx_fn
-
-
-class SimpleLoopForOp(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("simple_loop_for")
-
-    def __call__(self, n_iter, body_fn, operands):
-        validate_subgraph_args_types(operands)
-        return super().__call__(n_iter, body_fn, operands)
-
-    def gen_schema(self, n_iter, body_fn, operands):
-        from torch._higher_order_ops.schema import HopSchemaGenerator
-        from torch._higher_order_ops.utils import materialize_as_graph
-
-        body_gm: torch.fx.GraphModule = materialize_as_graph(  # type: ignore[annotation-unchecked]
-            body_fn, (torch.tensor(0, dtype=torch.int64), *operands)
-        )
-        (
-            _,
-            _,
-            _,
-            body_mutated_inputs,
-            body_outputs,
-        ) = check_input_alias_and_mutation_return_outputs(body_gm)
-        mutated_inputs = body_mutated_inputs
-
-        schema_gen = HopSchemaGenerator(self)
-        schema_gen.add_arg("n_iter", n_iter)
-        schema_gen.add_arg("body_fn", body_gm)
-        for idx, arg in enumerate(operands):
-            schema_gen.add_arg(f"operand{idx}", arg, is_mutated=idx in mutated_inputs)
-
-        for out in body_outputs:
-            schema_gen.add_output(out)
-        schema_gen.add_schema_tree_spec(n_iter, body_fn, operands)
-        return schema_gen.gen_schema()
-
-
-simple_loop_for_op = SimpleLoopForOp()
-
-
-@exposed_in("torch")
-def simple_loop_for(
-    n_iter: Union[int, torch.Tensor],
-    body_fn: Callable,
-    operands: Union[tuple, list] = (),
-) -> Any:
-    if torch.compiler.is_dynamo_compiling():
-        return simple_loop_for_op(n_iter, body_fn, (n_iter, *operands))
-
-    if isinstance(n_iter, (bool, int, float)):
-        return _loop_for_onnx_fn(body_fn, n_iter, None, *operands)
-
-    def _validate_input(n_iter, body_fn, operands):
-        assert isinstance(
-            n_iter, (int, torch.Tensor, torch.SymInt)
-        ), f"Expected pred to be bool or tensor, but got {n_iter}."
-        assert (
-            not isinstance(n_iter, torch.Tensor) or n_iter.numel() == 1
-        ), f"Expected pred to be bool or single-element tensor, but got {n_iter}."
-        assert callable(body_fn), "Expect both branches to be callable."
-        assert isinstance(operands, (tuple, list)) and pytree.tree_all(
-            lambda t: isinstance(t, torch.Tensor), operands
-        ), (
-            "Expect operands to be a tuple of possibly nested dict/list/tuple that only "
-            f"consists of tensor leaves, but got {operands}."
-        )
-
-    _validate_input(n_iter, body_fn, operands)
-
-    assert torch._dynamo.is_dynamo_supported(), "torch.cond requires dynamo support."
-
-    def _loop_for_op_wrapper(*args, **kwargs):
-        return simple_loop_for_op(*args, **kwargs)
-
-    from torch._higher_order_ops.utils import setup_compilation_env
-
-    with setup_compilation_env() as _backend:
-        return _loop_for_op_wrapper(n_iter, body_fn, *operands)
-        # return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
-        #     n_iter, body_fn, operands
-        # )
-
-
-def trace_loop_for(proxy_mode, func_overload, n_iter, body_fn, operands):
-    assert isinstance(
-        operands, (list, tuple)
-    ), f"Cond operands must be a list or tuple of tensors and SymInts {operands}"
-
-    body_graph = reenter_make_fx(body_fn)(n_iter, *operands)
-
-    body_outs = []
-    for node in body_graph.graph.nodes:
-        if node.op == "output":
-            body_outs.extend(node.args)
-
-    # flat_body_outs = pytree.arg_tree_leaves(*body_outs)
-    _i, body_name = unique_graph_id(proxy_mode, prefix="body_graph")
-    proxy_mode.tracer.root.register_module(body_name, body_graph)
-    args = (n_iter, body_graph, body_graph, operands)
-    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
-    out_proxy = proxy_mode.tracer.create_proxy("call_function", func_overload, proxy_args, {})
-    out = func_overload(n_iter, body_graph, operands)
-    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
-
-
-@simple_loop_for_op.py_impl(DispatchKey.CompositeExplicitAutograd)
-def loop_for_op_dense(n_iter, body_fn, operands):
-    assert all(
-        isinstance(o, (torch.Tensor, int)) for o in operands
-    ), f"Dense implementation operands must be a list of tensors and ints {operands}"
-    mode = _get_current_dispatch_mode()
-    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
-    return _loop_for_onnx_fn(body_fn, n_iter, None, operands)
-
-
-@simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
-def inner(mode, n_iter, body_fn, operands):
-    return trace_loop_for(mode, simple_loop_for_op, n_iter, body_fn, operands)
-
-
-simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
-simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)
```
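For orientation only: the deleted `control_flow_research.py` above registered a `simple_loop_for` higher-order operator that iterates a body function over loop-carried operands. The sketch below is a minimal, self-contained illustration of that loop-carried semantics as suggested by the deleted code; `simple_loop_for_reference`, the body function, and the tensor values are hypothetical and do not reproduce `_loop_for_onnx_fn` or the new `cf_simple_loop_for.py` module added in 0.8.6.

```python
import torch


def simple_loop_for_reference(n_iter, body_fn, operands):
    # Hypothetical reference semantics inferred from the deleted module:
    # call body_fn(i, *operands) n_iter times, feeding each iteration's
    # outputs back in as the next iteration's operands. This is NOT the
    # actual _loop_for_onnx_fn implementation.
    carried = tuple(operands)
    for i in range(int(n_iter)):
        out = body_fn(torch.tensor(i, dtype=torch.int64), *carried)
        carried = out if isinstance(out, tuple) else (out,)
    return carried


# Illustrative use: accumulate the loop counter into a tensor.
x = torch.zeros(3)
(final,) = simple_loop_for_reference(4, lambda i, t: t + i.to(t.dtype), (x,))
print(final)  # tensor([6., 6., 6.])
```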