dlinfer_ascend-0.2.3.post2-cp311-cp311-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. dlinfer/__init__.py +5 -0
  2. dlinfer/framework/__init__.py +1 -0
  3. dlinfer/framework/lmdeploy_ext/__init__.py +6 -0
  4. dlinfer/framework/lmdeploy_ext/cudagraph/__init__.py +20 -0
  5. dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +391 -0
  6. dlinfer/framework/lmdeploy_ext/cudagraph/camb_cudagraph.py +133 -0
  7. dlinfer/framework/lmdeploy_ext/cudagraph/maca_cudagraph.py +128 -0
  8. dlinfer/framework/lmdeploy_ext/cudagraph/ppu_cudagraph.py +131 -0
  9. dlinfer/framework/lmdeploy_ext/device/__init__.py +79 -0
  10. dlinfer/framework/lmdeploy_ext/device/ascend.py +205 -0
  11. dlinfer/framework/lmdeploy_ext/device/camb.py +24 -0
  12. dlinfer/framework/lmdeploy_ext/quants/__init__.py +20 -0
  13. dlinfer/framework/lmdeploy_ext/quants/ascend_awq.py +248 -0
  14. dlinfer/framework/torch_npu_ext/__init__.py +12 -0
  15. dlinfer/framework/torch_npu_ext/aclgraph.py +59 -0
  16. dlinfer/framework/transformers_ext/__init__.py +17 -0
  17. dlinfer/framework/transformers_ext/cogvlm.py +25 -0
  18. dlinfer/framework/transformers_ext/internlm2.py +242 -0
  19. dlinfer/framework/transformers_ext/internvl.py +33 -0
  20. dlinfer/framework/transformers_ext/patch.py +33 -0
  21. dlinfer/graph/__init__.py +5 -0
  22. dlinfer/graph/custom_op.py +147 -0
  23. dlinfer/graph/dicp/__init__.py +0 -0
  24. dlinfer/graph/dicp/dynamo_bridge/__init__.py +0 -0
  25. dlinfer/graph/dicp/dynamo_bridge/compile.py +42 -0
  26. dlinfer/graph/dicp/dynamo_bridge/compile_fx.py +305 -0
  27. dlinfer/graph/dicp/dynamo_bridge/conversion.py +75 -0
  28. dlinfer/graph/dicp/dynamo_bridge/decompositions.py +38 -0
  29. dlinfer/graph/dicp/dynamo_bridge/graph.py +141 -0
  30. dlinfer/graph/dicp/dynamo_bridge/op_transformer.py +293 -0
  31. dlinfer/graph/dicp/dynamo_bridge/operator.py +87 -0
  32. dlinfer/graph/dicp/dynamo_bridge/pt_patch.py +320 -0
  33. dlinfer/graph/dicp/dynamo_bridge/torch_version.py +38 -0
  34. dlinfer/graph/dicp/dynamo_bridge/utils.py +158 -0
  35. dlinfer/graph/dicp/vendor/AtbGraph/__init__.py +13 -0
  36. dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py +853 -0
  37. dlinfer/graph/dicp/vendor/AtbGraph/codegen/__init__.py +0 -0
  38. dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb.py +318 -0
  39. dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_graph.py +768 -0
  40. dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_infer_param.py +763 -0
  41. dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py +1279 -0
  42. dlinfer/graph/dicp/vendor/AtbGraph/codegen/libdicp_model.so +0 -0
  43. dlinfer/graph/dicp/vendor/AtbGraph/codegen/load_and_run.py +21 -0
  44. dlinfer/graph/dicp/vendor/AtbGraph/codegen/utils.py +178 -0
  45. dlinfer/graph/dicp/vendor/AtbGraph/compile_job.py +52 -0
  46. dlinfer/graph/dicp/vendor/AtbGraph/config.py +36 -0
  47. dlinfer/graph/dicp/vendor/AtbGraph/conversion.py +908 -0
  48. dlinfer/graph/dicp/vendor/AtbGraph/ext_ops.py +95 -0
  49. dlinfer/graph/dicp/vendor/AtbGraph/infer_res_utils.py +200 -0
  50. dlinfer/graph/dicp/vendor/AtbGraph/opset_convert.py +70 -0
  51. dlinfer/graph/dicp/vendor/AtbGraph/pattern_replacement.py +152 -0
  52. dlinfer/graph/dicp/vendor/__init__.py +0 -0
  53. dlinfer/ops/__init__.py +2 -0
  54. dlinfer/ops/llm.py +879 -0
  55. dlinfer/utils/__init__.py +1 -0
  56. dlinfer/utils/config.py +18 -0
  57. dlinfer/utils/registry.py +8 -0
  58. dlinfer/utils/type_annotation.py +3 -0
  59. dlinfer/vendor/__init__.py +33 -0
  60. dlinfer/vendor/ascend/__init__.py +5 -0
  61. dlinfer/vendor/ascend/pytorch_patch.py +55 -0
  62. dlinfer/vendor/ascend/torch_npu_ops.py +601 -0
  63. dlinfer/vendor/ascend/utils.py +20 -0
  64. dlinfer/vendor/vendor.yaml +2 -0
  65. dlinfer_ascend-0.2.3.post2.dist-info/LICENSE +28 -0
  66. dlinfer_ascend-0.2.3.post2.dist-info/METADATA +213 -0
  67. dlinfer_ascend-0.2.3.post2.dist-info/RECORD +70 -0
  68. dlinfer_ascend-0.2.3.post2.dist-info/WHEEL +5 -0
  69. dlinfer_ascend-0.2.3.post2.dist-info/entry_points.txt +2 -0
  70. dlinfer_ascend-0.2.3.post2.dist-info/top_level.txt +1 -0
dlinfer/graph/custom_op.py
@@ -0,0 +1,147 @@
+ # Copyright (c) 2024, DeepLink. All rights reserved.
+ import inspect
+ from functools import wraps
+
+ from torch.library import Library, impl
+
+ import dlinfer.graph
+ from dlinfer.utils.type_annotation import Callable, Optional, Sequence, Dict
+ from dlinfer.vendor import dispatch_key, vendor_name
+
+ library_impl_dict: Dict[str, Library] = dict()
+ graph_enabled_backends = ["ascend"]
+
+
+ def register_custom_op(
+     qualname: str,
+     shape_param_keys: Optional[Sequence[str]] = None,
+     default_value: Optional[Dict] = None,
+     impl_abstract_func: Optional[Callable] = None,
+ ) -> Callable:
+     disable = vendor_name not in graph_enabled_backends
+
+     def inner_func(func: Callable):
+         if disable:
+             return override_default_value_static(default_value)(func)
+         import torch._custom_ops
+
+         nonlocal impl_abstract_func
+         lib_name, func_name = qualname.split("::")
+         torch._custom_ops.custom_op(qualname)(func)
+         # using low level torch.library APIs in case of the registration
+         # of fallback kernels which raises error in torch._custom_ops.impl
+         if lib_name not in library_impl_dict:
+             library_impl_dict[lib_name] = Library(lib_name, "IMPL")
+         impl(library_impl_dict[lib_name], func_name, dispatch_key)(func)
+         if impl_abstract_func is None:
+             assert shape_param_keys is not None
+             params_name_list = [name for name in inspect.signature(func).parameters]
+
+             def _impl_abstract_func(*args, **kwargs):
+                 assert len(args) + len(kwargs) == len(params_name_list)
+                 result = []
+                 for key in shape_param_keys:
+                     key_index = params_name_list.index(key)
+                     if key_index < len(args):
+                         target = args[key_index]
+                     else:
+                         target = kwargs[key]
+                     result.append(torch.empty_like(target))
+                 if len(result) == 1:
+                     return result[0]
+                 return tuple(result)
+
+             impl_abstract_func = _impl_abstract_func
+         torch._custom_ops.impl_abstract(qualname)(impl_abstract_func)
+         torch_ops_namespace = getattr(torch.ops, lib_name)
+         torch_ops_func = getattr(torch_ops_namespace, func_name)
+         assert torch_ops_func is not None
+         # override default value
+         func_with_default = override_default_value_static(default_value)(func)
+         torch_ops_func_with_default = override_default_value_dynamic(
+             default_value, func
+         )(torch_ops_func)
+
+         # use config.enable_graph_mode to control func call
+         @wraps(func)
+         def patched_func(*args, **kwargs):
+             if not dlinfer.graph.config.enable_graph_mode:
+                 return func_with_default(*args, **kwargs)
+             else:
+                 return torch_ops_func_with_default(*args, **kwargs)
+
+         return patched_func
+
+     return inner_func
+
+
+ def override_default_value_dynamic(
+     default_value: Optional[Dict], origin_func: Callable
+ ):
+     def inner_func(func):
+         if default_value is None:
+             return func
+         sig = inspect.signature(origin_func)
+         sig_param_keys = sig.parameters.keys()
+         params_str = ", ".join(sig_param_keys)
+         params_with_default = []
+         for name in sig_param_keys:
+             if name in default_value:
+                 if isinstance(default_value[name], str):
+                     params_with_default.append(f"{name}='{default_value[name]}'")
+                 else:
+                     params_with_default.append(f"{name}={default_value[name]}")
+             else:
+                 params_with_default.append(name)
+         params_str_with_default = ", ".join(params_with_default)
+         func_code = f"""
+ def {func.__name__}({params_str_with_default}):
+     return original_func({params_str})
+ """
+         exec_namespace = {}
+         # it's hard not to use exec here
+         exec(func_code, {"original_func": func}, exec_namespace)
+         dynamic_func = exec_namespace[func.__name__]
+
+         return dynamic_func
+
+     return inner_func
+
+
+ def override_default_value_static(default_value: Optional[Dict]):
+     # suitable for the function which signature isn't (*args, **kwargs)
+     def inner_func(func):
+         if default_value is None:
+             return func
+         sig = inspect.signature(func)
+         old_params = sig.parameters
+         new_params = []
+         default_arg = []
+         default_kwarg = []
+         func_co_argcount = func.__code__.co_argcount
+         param_has_default_value = False
+         for idx, (name, param) in enumerate(old_params.items()):
+             if name in default_value:
+                 new_param = param.replace(default=default_value[name])
+             else:
+                 new_param = param
+             new_params.append(new_param)
+             if new_param.default is not inspect._empty:
+                 if not param_has_default_value:
+                     param_has_default_value = True
+                 if idx < func_co_argcount:
+                     default_arg.append(new_param.default)
+                 else:
+                     default_kwarg.append((name, new_param.default))
+             else:
+                 if param_has_default_value:
+                     raise SyntaxError(
+                         f"non-default argument '{name}' follows default argument"
+                     )
+         new_signature = sig.replace(parameters=new_params)
+         func.__signature__ = new_signature
+         func.__defaults__ = tuple(default_arg)
+         func.__kwdefaults__ = dict(default_kwarg)
+         return func
+
+     return inner_func
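For readers unfamiliar with the decorator above, the following is a minimal usage sketch, assuming this hunk is dlinfer/graph/custom_op.py (the +147 entry in the file list). The qualname "myvendor::rms_norm", the parameters, and the body are hypothetical and do not appear in the package; on backends outside graph_enabled_backends the decorator only injects the default values and returns the eager function.

```python
# Hypothetical usage of register_custom_op; the qualname, signature and body
# are made up for illustration and are not part of this package.
import torch

from dlinfer.graph.custom_op import register_custom_op


@register_custom_op(
    "myvendor::rms_norm",                # hypothetical lib::op qualname
    shape_param_keys=["hidden_states"],  # abstract impl mirrors this input's shape
    default_value={"epsilon": 1e-6},     # default injected into the signature
)
def rms_norm(
    hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float
) -> torch.Tensor:
    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
    normed = hidden_states * torch.rsqrt(variance + epsilon)
    return normed.to(hidden_states.dtype) * weight


# Callers may now omit epsilon; with enable_graph_mode set, the call is routed
# through torch.ops.myvendor.rms_norm instead of the eager function.
# out = rms_norm(torch.randn(2, 8), torch.ones(8))
```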
dlinfer/graph/dicp/__init__.py: File without changes
dlinfer/graph/dicp/dynamo_bridge/__init__.py: File without changes
dlinfer/graph/dicp/dynamo_bridge/compile.py
@@ -0,0 +1,42 @@
+ from abc import ABCMeta, abstractmethod
+
+ from dlinfer.graph.dicp.dynamo_bridge.torch_version import is_torch_251_or_higher
+
+ if is_torch_251_or_higher:
+     from torch._inductor.async_compile import AsyncCompile
+ else:
+     from torch._inductor.codecache import AsyncCompile
+
+
+ class DeviceCompileJob:
+     __metaclass__ = ABCMeta
+
+     def __init__(self):
+         pass
+
+     @abstractmethod
+     def get_key():
+         pass
+
+     @abstractmethod
+     def get_compile_result():
+         pass
+
+
+ class DeviceKernelCache:
+     cache = dict()
+     clear = staticmethod(cache.clear)
+
+     @classmethod
+     def get_kernel(cls, device_compile_job):
+         key = device_compile_job.get_key()
+         if key not in cls.cache:
+             loaded = device_compile_job.get_compile_result()
+             cls.cache[key] = loaded
+             cls.cache[key].key = key
+         return cls.cache[key]
+
+
+ class AsyncCompileKernel(AsyncCompile):
+     def compile_kernel(self, device_compile_job):
+         return DeviceKernelCache.get_kernel(device_compile_job).run
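Assuming this hunk is dlinfer/graph/dicp/dynamo_bridge/compile.py (the +42 entry in the file list), here is a hedged sketch of a concrete DeviceCompileJob. The class names, the hash-based key, and the stand-in kernel are invented for illustration; a real job would compile device code and return a loaded artifact exposing .run.

```python
# Hypothetical concrete compile job for the interface above; names and
# behaviour are illustrative only.
import hashlib

from dlinfer.graph.dicp.dynamo_bridge.compile import (
    AsyncCompileKernel,
    DeviceCompileJob,
)


class EchoKernel:
    # Stand-in for a loaded device kernel: it only needs a callable `run`
    # and must accept attribute assignment (DeviceKernelCache sets `.key`).
    def run(self, *args, **kwargs):
        return args


class EchoCompileJob(DeviceCompileJob):
    def __init__(self, source: str):
        super().__init__()
        self._source = source

    def get_key(self):
        # Stable cache key derived from the "source" being compiled.
        return hashlib.sha256(self._source.encode()).hexdigest()

    def get_compile_result(self):
        # A real job would compile self._source and load the device artifact.
        return EchoKernel()


# run_fn = AsyncCompileKernel().compile_kernel(EchoCompileJob("fake source"))
# run_fn(1, 2)  # -> (1, 2); a second identical job is served from the cache
```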
dlinfer/graph/dicp/dynamo_bridge/compile_fx.py
@@ -0,0 +1,305 @@
+ from torch._dynamo.backends.common import aot_autograd
+ from torch._functorch.aot_autograd import make_boxed_func
+ from .graph import GraphTransformer
+ import functools
+ import itertools
+ import logging
+ import sys
+ import functorch
+ import torch.fx
+ import importlib
+ import os
+
+ from typing import List
+ from importlib import import_module
+
+ import torch
+
+ from dlinfer.graph.dicp.dynamo_bridge import pt_patch  # noqa F401
+ from dlinfer.graph.dicp.dynamo_bridge.torch_version import (
+     is_torch_200,
+     is_torch_210_or_higher,
+ )
+
+
+ log = logging.getLogger(__name__)
+
+ dynamo_logging = import_module("torch._dynamo.logging")
+ dynamo_utils = import_module("torch._dynamo.utils")
+
+ count_calls = dynamo_utils.count_calls
+
+
+ def get_fake_mode_from_tensors(input_tensors):
+     if is_torch_200:
+         from torch._dynamo.utils import fake_mode_from_tensors
+
+         return fake_mode_from_tensors(input_tensors)
+     elif is_torch_210_or_higher:
+         from torch._dynamo.utils import detect_fake_mode
+
+         return detect_fake_mode(input_tensors)
+     else:
+         raise ValueError(f"unsupported dicp torch version: {torch.__version__}")
+
+
+ def used_nodes_all_symint(nodes):
+     for node in nodes:
+         if node.op == "placeholder" and len(node.users) > 0:
+             if hasattr(node, "meta"):
+                 node = node.meta["val"]
+                 if not isinstance(node, torch.SymInt):
+                     return False
+         elif node.op == "output":
+             if hasattr(node, "meta") and "val" in node.meta:
+                 node = node.meta["val"]
+                 if not isinstance(node, torch.SymInt):
+                     return False
+     return True
+
+
+ @torch.utils._python_dispatch._disable_current_modes()
+ def compile_fx_inner(
+     gm: torch.fx.GraphModule,
+     example_inputs: List[torch.Tensor],
+     num_fixed=0,
+     is_backward=False,
+     graph_id=None,
+     backend=None,
+ ):
+     if dynamo_utils.count_calls(gm.graph) == 0:
+         return make_boxed_func(gm.forward)
+
+     # all symint inputs fallback to eager mode
+     if used_nodes_all_symint(list(gm.graph.nodes)):
+         return gm
+
+     # lift the maximum depth of the Python interpreter stack
+     # to adapt large/deep models
+     sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))
+
+     gt = GraphTransformer(gm, backend)
+     gt.transform()
+     compiled_fn = gt.compile_to_fn()
+
+     # aot autograd needs to know to pass in inputs as a list
+     compiled_fn._boxed_call = True
+     return compiled_fn
+
+
+ _graph_counter = itertools.count(0)
+
+
+ def compile_fx(
+     model_: torch.fx.GraphModule,
+     example_inputs_: List[torch.Tensor],
+     backend: str,
+     inner_compile=compile_fx_inner,
+ ):
+     if is_torch_200:
+         return compile_fx_200(model_, example_inputs_, backend, inner_compile)
+     elif is_torch_210_or_higher:
+         return compile_fx_210(model_, example_inputs_, backend, inner_compile)
+     else:
+         raise ValueError(f"unsupported dicp torch version: {torch.__version__}")
+
+
+ def compile_fx_200(
+     model_: torch.fx.GraphModule,
+     example_inputs_: List[torch.Tensor],
+     backend: str,
+     inner_compile=compile_fx_inner,
+ ):
+     """Main entrypoint to a compile given FX graph"""
+     functorch.compile.config.use_functionalize = True
+     functorch.compile.config.use_fake_tensor = True
+
+     num_example_inputs = len(example_inputs_)
+
+     graph_id = next(_graph_counter)
+
+     @dynamo_utils.dynamo_timed
+     def fw_compiler(model: torch.fx.GraphModule, example_inputs):
+         fixed = len(example_inputs) - num_example_inputs
+         return inner_compile(
+             model,
+             example_inputs,
+             num_fixed=fixed,
+             graph_id=graph_id,
+             backend=backend,
+         )
+
+     @dynamo_utils.dynamo_timed
+     def bw_compiler(model: torch.fx.GraphModule, example_inputs):
+         fixed = count_tangents(model)
+         return inner_compile(
+             model,
+             example_inputs,
+             num_fixed=fixed,
+             is_backward=True,
+             graph_id=graph_id,
+             backend=backend,
+         )
+
+     decompositions = get_decompositions(backend=backend)
+     return aot_autograd(
+         fw_compiler=fw_compiler, bw_compiler=bw_compiler, decompositions=decompositions
+     )(model_, example_inputs_)
+
+
+ def compile_fx_210(
+     model_: torch.fx.GraphModule,
+     example_inputs_: List[torch.Tensor],
+     backend: str,
+     inner_compile=compile_fx_inner,
+ ):
+     import torch._dynamo.config as dynamo_config
+     from torch._inductor.compile_fx import (
+         flatten_graph_inputs,
+         graph_returns_tuple,
+         make_graph_return_tuple,
+         pre_grad_passes,
+         joint_graph_passes,
+         min_cut_rematerialization_partition,
+         _PyTreeCodeGen,
+         handle_dynamo_export_graph,
+     )
+
+     decompositions = get_decompositions(backend=backend)
+
+     recursive_compile_fx = functools.partial(
+         compile_fx,
+         inner_compile=inner_compile,
+         decompositions=decompositions,
+     )
+
+     if not graph_returns_tuple(model_):
+         return make_graph_return_tuple(
+             model_,
+             example_inputs_,
+             recursive_compile_fx,
+         )
+
+     if isinstance(model_, torch.fx.GraphModule):
+         if isinstance(model_.graph._codegen, _PyTreeCodeGen):
+             # this graph is the result of dynamo.export()
+             return handle_dynamo_export_graph(
+                 model_,
+                 example_inputs_,
+                 recursive_compile_fx,
+             )
+
+         # Since handle_dynamo_export_graph will trigger compile_fx again,
+         # Move these passes after handle_dynamo_export_graph to avoid repeated calls.
+         model_ = pre_grad_passes(model_, example_inputs_)
+
+     if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
+         return flatten_graph_inputs(
+             model_,
+             example_inputs_,
+             recursive_compile_fx,
+         )
+
+     # assert not config._raise_error_for_testing
+     num_example_inputs = len(example_inputs_)
+
+     graph_id = next(_graph_counter)
+
+     def fw_compiler_base(
+         model: torch.fx.GraphModule,
+         example_inputs: List[torch.Tensor],
+         is_inference: bool,
+     ):
+         return _fw_compiler_base(model, example_inputs, is_inference)
+
+     def _fw_compiler_base(
+         model: torch.fx.GraphModule,
+         example_inputs: List[torch.Tensor],
+         is_inference: bool,
+     ):
+         if is_inference:
+             # partition_fn won't be called
+             # joint_graph_passes(model)
+             pass
+
+         fixed = len(example_inputs) - num_example_inputs
+         return inner_compile(
+             model,
+             example_inputs,
+             num_fixed=fixed,
+             graph_id=graph_id,
+             backend=backend,
+         )
+
+     fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
+     inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
+
+     def partition_fn(graph, joint_inputs, **kwargs):
+         joint_graph_passes(graph)
+         return min_cut_rematerialization_partition(
+             graph, joint_inputs, **kwargs, compiler="inductor"
+         )
+
+     # Save and restore dynamic shapes setting for backwards, as it is
+     # sometimes done as a context manager which won't be set when we
+     # hit backwards compile
+     dynamic_shapes = dynamo_config.dynamic_shapes
+
+     def bw_compiler(model: torch.fx.GraphModule, example_inputs):
+         with dynamo_config.patch(dynamic_shapes=dynamic_shapes):
+             fixed = count_tangents(model)
+             return inner_compile(
+                 model,
+                 example_inputs,
+                 num_fixed=fixed,
+                 is_backward=True,
+                 graph_id=graph_id,
+                 backend=backend,
+             )
+
+     # TODO: can add logging before/after the call to create_aot_dispatcher_function
+     # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
+     # once torchdynamo is merged into pytorch
+     return aot_autograd(
+         fw_compiler=fw_compiler,
+         bw_compiler=bw_compiler,
+         inference_compiler=inference_compiler,
+         decompositions=decompositions,
+         partition_fn=partition_fn,
+         keep_inference_input_mutations=True,
+     )(model_, example_inputs_)
+
+
+ def count_tangents(fx_g: torch.fx.GraphModule):
+     """
+     Infers which inputs are static for a backwards graph
+     """
+
+     def is_not_gradout(x):
+         return "tangents" not in x.name
+
+     arg_count = 0
+     static_arg_idxs = []
+     for n in fx_g.graph.nodes:
+         if n.op == "placeholder":
+             if is_not_gradout(n):
+                 static_arg_idxs.append(arg_count)
+             arg_count += 1
+
+     assert static_arg_idxs == list(range(len(static_arg_idxs)))
+     return len(static_arg_idxs)
+
+
+ def get_decompositions(backend):
+     decompositions = {}
+     folder_list = os.listdir(os.path.dirname(os.path.dirname(__file__)) + "/vendor")
+     found_decomp = False
+     for folder in folder_list:
+         if backend.lower() == folder.lower():
+             config = importlib.import_module(
+                 "dlinfer.graph.dicp.vendor." + folder + ".config"
+             )
+             decompositions = config.decomp
+             found_decomp = True
+     assert found_decomp, "Not found decomp table!"
+     return decompositions
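Assuming this hunk is dlinfer/graph/dicp/dynamo_bridge/compile_fx.py (the +305 entry in the file list), a hedged sketch of wiring compile_fx into torch.compile as a custom Dynamo backend follows. The backend string must match a folder under dicp/vendor, such as the AtbGraph package listed above; the toy function is illustrative and the call is left commented out because it needs an Ascend environment.

```python
# Hypothetical wiring of compile_fx into torch.compile; running it requires an
# Ascend environment with the AtbGraph vendor backend, so the call stays commented.
import functools

import torch

from dlinfer.graph.dicp.dynamo_bridge.compile_fx import compile_fx

# Dynamo passes (gm, example_inputs); the extra argument is the vendor folder
# name that get_decompositions() resolves under dicp/vendor.
dicp_backend = functools.partial(compile_fx, backend="AtbGraph")


def toy(x):
    return torch.relu(x) + 1


compiled_toy = torch.compile(toy, backend=dicp_backend)
# compiled_toy(torch.randn(4))  # would reach fw_compiler -> compile_fx_inner
```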
dlinfer/graph/dicp/dynamo_bridge/conversion.py
@@ -0,0 +1,75 @@
+ import functools
+ import torch
+ from dlinfer.graph.dicp.dynamo_bridge.operator import Operator
+
+
+ def args_kwargs_unchange(args, kwargs):
+     return args, kwargs
+
+
+ def register_conversion_impl(
+     conversions: list, aten_fn, decomp_fn, process_args_kwargs_fn=None
+ ):
+     register_op_singleton_flag = isinstance(decomp_fn, type) and issubclass(
+         decomp_fn, Operator
+     )
+     if register_op_singleton_flag:
+         wrapped = (
+             decomp_fn.get_singleton(),
+             (
+                 args_kwargs_unchange
+                 if process_args_kwargs_fn is None
+                 else process_args_kwargs_fn
+             ),
+         )
+     else:
+
+         @functools.wraps(decomp_fn)
+         def wrapped(*args, **kwargs):
+             return decomp_fn(*args, **kwargs)
+
+     if not isinstance(aten_fn, (list, tuple)):
+         aten_fn = [aten_fn]
+     else:
+         aten_fn = list(aten_fn)
+
+     aten_fn_for_key = []
+     for fn in list(aten_fn):
+         if isinstance(fn, str):
+             assert fn.startswith("torch.ops")
+             real_fn_name = fn.replace("torch.ops.", "")
+             ns, op_overload = real_fn_name.split(".", 1)
+             if not hasattr(torch.ops, ns):
+                 print(
+                     f"[dicp] can't find torch.ops.{ns}, conversion for {fn} is ignored"
+                 )
+                 continue
+             ns_obj = getattr(torch.ops, ns)
+             if "." in op_overload:
+                 op, overload = op_overload.split(".", 1)
+                 if not hasattr(ns_obj, op):
+                     print(
+                         f"[dicp] can't find torch.ops.{ns}.{op}, conversion for {fn} is ignored"
+                     )
+                     continue
+                 op_obj = getattr(ns_obj, op)
+                 fn = getattr(op_obj, overload)
+             else:
+                 if not hasattr(ns_obj, op_overload):
+                     print(
+                         f"[dicp] can't find torch.ops.{ns}.{op_overload}, conversion for {fn} is ignored"
+                     )
+                     continue
+                 fn = getattr(ns_obj, op_overload)
+         if isinstance(fn, torch._ops.OpOverloadPacket):
+             for overload in fn.overloads():
+                 other_fn = getattr(fn, overload)
+                 if other_fn not in conversions:
+                     aten_fn_for_key.append(other_fn)
+         aten_fn_for_key.append(fn)
+
+     conversions.update({fn: wrapped for fn in aten_fn_for_key})
+     if register_op_singleton_flag:
+         return wrapped[0]
+     else:
+         return wrapped
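Assuming this hunk is dlinfer/graph/dicp/dynamo_bridge/conversion.py (the +75 entry in the file list), the sketch below shows one way a vendor conversion table could be built on register_conversion_impl. The conversions dict, the register_conversion helper, and the converted aten op are invented for illustration; the package's vendor-specific table presumably lives in dlinfer/graph/dicp/vendor/AtbGraph/conversion.py (+908 in the file list).

```python
# Hypothetical vendor-side registration built on register_conversion_impl;
# the dict, helper and converted op are invented for illustration.
import functools

import torch

from dlinfer.graph.dicp.dynamo_bridge.conversion import register_conversion_impl

conversions = {}


def register_conversion(aten_fn):
    # Partial application: the decorated function becomes decomp_fn.
    return functools.partial(register_conversion_impl, conversions, aten_fn)


@register_conversion(torch.ops.aten.add.Tensor)
def aten_add(x, y):
    # A real backend would emit its own graph op; the eager call is a placeholder.
    return torch.ops.aten.add.Tensor(x, y)


assert torch.ops.aten.add.Tensor in conversions
```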
dlinfer/graph/dicp/dynamo_bridge/decompositions.py
@@ -0,0 +1,38 @@
+ from collections import defaultdict
+ from typing import Callable, Dict, Sequence, Union
+
+ import torch
+ from torch._decomp import register_decomposition
+ from torch._ops import OpOverload, OpOverloadPacket
+
+ dicp_decomposition_table = {}
+ aten = torch.ops.aten
+
+
+ def register_decomposition_for_dicp(fn):
+     return register_decomposition(fn, registry=dicp_decomposition_table)
+
+
+ @register_decomposition_for_dicp(aten.count_nonzero.default)
+ def count_nonzero_default(x, dim=None):
+     cond = x != 0
+     dim = [] if dim is None else dim
+     return aten.sum.dim_IntList(cond, dim=dim, keepdim=False, dtype=torch.int64)
+
+
+ def get_decompositions(
+     aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
+     target_decomposition_table: Dict[OpOverload, Callable] = None,
+ ) -> Dict[OpOverload, Callable]:
+     registry = dicp_decomposition_table
+     packets_to_overloads = defaultdict(list)
+     for opo in registry:
+         packets_to_overloads[opo.overloadpacket].append(opo)
+     decompositions = target_decomposition_table if target_decomposition_table else {}
+     for op in aten_ops:
+         if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
+             for op_overload in packets_to_overloads[op]:
+                 decompositions[op_overload] = registry[op_overload]
+         elif isinstance(op, OpOverload) and op in registry:
+             decompositions[op] = registry[op]
+     return decompositions
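Assuming this hunk is dlinfer/graph/dicp/dynamo_bridge/decompositions.py (the +38 entry in the file list), a short illustrative use of get_decompositions: request the registered count_nonzero decomposition and evaluate it eagerly.

```python
# Illustrative only: fetch the dicp count_nonzero decomposition and run it eagerly.
import torch

from dlinfer.graph.dicp.dynamo_bridge.decompositions import get_decompositions

table = get_decompositions([torch.ops.aten.count_nonzero.default])
decomp = table[torch.ops.aten.count_nonzero.default]

x = torch.tensor([0.0, 1.5, 0.0, 3.0])
print(decomp(x))  # tensor(2): sum over the boolean mask, as defined above
```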