onnx-diagnostic 0.8.4__py3-none-any.whl → 0.8.6__py3-none-any.whl

This diff compares the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,140 +0,0 @@
- from typing import Any, Callable, Union
- import torch
- from torch._C import DispatchKey
-
- # from torch._higher_order_ops import BaseHOP
- from torch._ops import HigherOrderOperator
- from torch._functorch.utils import exposed_in
- import torch.utils._pytree as pytree
- from torch._higher_order_ops.utils import (
-     check_input_alias_and_mutation_return_outputs,
-     reenter_make_fx,
-     unique_graph_id,
-     validate_subgraph_args_types,
- )
- from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
- from torch.utils._python_dispatch import _get_current_dispatch_mode
- from .control_flow_onnx import _loop_for_onnx_fn
-
-
- class SimpleLoopForOp(HigherOrderOperator):
-     def __init__(self):
-         super().__init__("simple_loop_for")
-
-     def __call__(self, n_iter, body_fn, operands):
-         validate_subgraph_args_types(operands)
-         return super().__call__(n_iter, body_fn, operands)
-
-     def gen_schema(self, n_iter, body_fn, operands):
-         from torch._higher_order_ops.schema import HopSchemaGenerator
-         from torch._higher_order_ops.utils import materialize_as_graph
-
-         body_gm: torch.fx.GraphModule = materialize_as_graph(  # type: ignore[annotation-unchecked]
-             body_fn, (torch.tensor(0, dtype=torch.int64), *operands)
-         )
-         (
-             _,
-             _,
-             _,
-             body_mutated_inputs,
-             body_outputs,
-         ) = check_input_alias_and_mutation_return_outputs(body_gm)
-         mutated_inputs = body_mutated_inputs
-
-         schema_gen = HopSchemaGenerator(self)
-         schema_gen.add_arg("n_iter", n_iter)
-         schema_gen.add_arg("body_fn", body_gm)
-         for idx, arg in enumerate(operands):
-             schema_gen.add_arg(f"operand{idx}", arg, is_mutated=idx in mutated_inputs)
-
-         for out in body_outputs:
-             schema_gen.add_output(out)
-         schema_gen.add_schema_tree_spec(n_iter, body_fn, operands)
-         return schema_gen.gen_schema()
-
-
- simple_loop_for_op = SimpleLoopForOp()
-
-
- @exposed_in("torch")
- def simple_loop_for(
-     n_iter: Union[int, torch.Tensor],
-     body_fn: Callable,
-     operands: Union[tuple, list] = (),
- ) -> Any:
-     if torch.compiler.is_dynamo_compiling():
-         return simple_loop_for_op(n_iter, body_fn, (n_iter, *operands))
-
-     if isinstance(n_iter, (bool, int, float)):
-         return _loop_for_onnx_fn(body_fn, n_iter, None, *operands)
-
-     def _validate_input(n_iter, body_fn, operands):
-         assert isinstance(
-             n_iter, (int, torch.Tensor, torch.SymInt)
-         ), f"Expected n_iter to be an int or a tensor, but got {n_iter}."
-         assert (
-             not isinstance(n_iter, torch.Tensor) or n_iter.numel() == 1
-         ), f"Expected n_iter to be an int or a single-element tensor, but got {n_iter}."
-         assert callable(body_fn), "Expected body_fn to be callable."
-         assert isinstance(operands, (tuple, list)) and pytree.tree_all(
-             lambda t: isinstance(t, torch.Tensor), operands
-         ), (
-             "Expected operands to be a tuple of possibly nested dict/list/tuple that only "
-             f"consists of tensor leaves, but got {operands}."
-         )
-
-     _validate_input(n_iter, body_fn, operands)
-
-     assert torch._dynamo.is_dynamo_supported(), "simple_loop_for requires dynamo support."
-
-     def _loop_for_op_wrapper(*args, **kwargs):
-         return simple_loop_for_op(*args, **kwargs)
-
-     from torch._higher_order_ops.utils import setup_compilation_env
-
-     with setup_compilation_env() as _backend:
-         return _loop_for_op_wrapper(n_iter, body_fn, operands)
-     # return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
-     #     n_iter, body_fn, operands
-     # )
-
-
- def trace_loop_for(proxy_mode, func_overload, n_iter, body_fn, operands):
-     assert isinstance(
-         operands, (list, tuple)
-     ), f"simple_loop_for operands must be a list or tuple of tensors and SymInts {operands}"
-
-     body_graph = reenter_make_fx(body_fn)(n_iter, *operands)
-
-     body_outs = []
-     for node in body_graph.graph.nodes:
-         if node.op == "output":
-             body_outs.extend(node.args)
-
-     # flat_body_outs = pytree.arg_tree_leaves(*body_outs)
-     _i, body_name = unique_graph_id(proxy_mode, prefix="body_graph")
-     proxy_mode.tracer.root.register_module(body_name, body_graph)
-     args = (n_iter, body_graph, operands)
-     proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
-     out_proxy = proxy_mode.tracer.create_proxy("call_function", func_overload, proxy_args, {})
-     out = func_overload(n_iter, body_graph, operands)
-     return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
-
-
- @simple_loop_for_op.py_impl(DispatchKey.CompositeExplicitAutograd)
- def loop_for_op_dense(n_iter, body_fn, operands):
-     assert all(
-         isinstance(o, (torch.Tensor, int)) for o in operands
-     ), f"Dense implementation operands must be a list of tensors and ints {operands}"
-     mode = _get_current_dispatch_mode()
-     assert mode is None, "Mode should never be enabled for CPU/CUDA key"
-     return _loop_for_onnx_fn(body_fn, n_iter, None, operands)
-
-
- @simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
- def inner(mode, n_iter, body_fn, operands):
-     return trace_loop_for(mode, simple_loop_for_op, n_iter, body_fn, operands)
-
-
- simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
- simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)
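
For readers trying to understand what this removed module expected of callers: gen_schema materializes the body with a 0-dim int64 counter prepended to the operands, so the body function evidently took (i, *operands) and returned the values carried into the next iteration. Below is a minimal sketch of that contract using a plain Python loop as stand-in semantics; the exact behaviour of _loop_for_onnx_fn is not visible in this diff, and the import path of simple_loop_for inside onnx-diagnostic is assumed rather than confirmed, so the final call is shown commented out.

import torch

def body_fn(i, x, acc):
    # i is a 0-dim int64 iteration counter; x and acc are the carried tensors.
    # The body returns the tensors carried into the next iteration.
    return x, acc + x * i

# Reference semantics as a plain Python loop (stand-in only; the removed code
# delegated the eager path to _loop_for_onnx_fn, whose behaviour is not shown here).
x, acc = torch.ones(3), torch.zeros(3)
carried = (x, acc)
for i in range(4):
    carried = body_fn(torch.tensor(i, dtype=torch.int64), *carried)
print(carried[1])  # tensor([6., 6., 6.])

# Call shape of the removed operator (import path not shown in this diff):
# out = simple_loop_for(4, body_fn, (x, acc))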