dlinfer-ascend 0.2.3.post2__cp311-cp311-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dlinfer/__init__.py +5 -0
- dlinfer/framework/__init__.py +1 -0
- dlinfer/framework/lmdeploy_ext/__init__.py +6 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/__init__.py +20 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +391 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/camb_cudagraph.py +133 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/maca_cudagraph.py +128 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/ppu_cudagraph.py +131 -0
- dlinfer/framework/lmdeploy_ext/device/__init__.py +79 -0
- dlinfer/framework/lmdeploy_ext/device/ascend.py +205 -0
- dlinfer/framework/lmdeploy_ext/device/camb.py +24 -0
- dlinfer/framework/lmdeploy_ext/quants/__init__.py +20 -0
- dlinfer/framework/lmdeploy_ext/quants/ascend_awq.py +248 -0
- dlinfer/framework/torch_npu_ext/__init__.py +12 -0
- dlinfer/framework/torch_npu_ext/aclgraph.py +59 -0
- dlinfer/framework/transformers_ext/__init__.py +17 -0
- dlinfer/framework/transformers_ext/cogvlm.py +25 -0
- dlinfer/framework/transformers_ext/internlm2.py +242 -0
- dlinfer/framework/transformers_ext/internvl.py +33 -0
- dlinfer/framework/transformers_ext/patch.py +33 -0
- dlinfer/graph/__init__.py +5 -0
- dlinfer/graph/custom_op.py +147 -0
- dlinfer/graph/dicp/__init__.py +0 -0
- dlinfer/graph/dicp/dynamo_bridge/__init__.py +0 -0
- dlinfer/graph/dicp/dynamo_bridge/compile.py +42 -0
- dlinfer/graph/dicp/dynamo_bridge/compile_fx.py +305 -0
- dlinfer/graph/dicp/dynamo_bridge/conversion.py +75 -0
- dlinfer/graph/dicp/dynamo_bridge/decompositions.py +38 -0
- dlinfer/graph/dicp/dynamo_bridge/graph.py +141 -0
- dlinfer/graph/dicp/dynamo_bridge/op_transformer.py +293 -0
- dlinfer/graph/dicp/dynamo_bridge/operator.py +87 -0
- dlinfer/graph/dicp/dynamo_bridge/pt_patch.py +320 -0
- dlinfer/graph/dicp/dynamo_bridge/torch_version.py +38 -0
- dlinfer/graph/dicp/dynamo_bridge/utils.py +158 -0
- dlinfer/graph/dicp/vendor/AtbGraph/__init__.py +13 -0
- dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py +853 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/__init__.py +0 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb.py +318 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_graph.py +768 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_infer_param.py +763 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py +1279 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/libdicp_model.so +0 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/load_and_run.py +21 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/utils.py +178 -0
- dlinfer/graph/dicp/vendor/AtbGraph/compile_job.py +52 -0
- dlinfer/graph/dicp/vendor/AtbGraph/config.py +36 -0
- dlinfer/graph/dicp/vendor/AtbGraph/conversion.py +908 -0
- dlinfer/graph/dicp/vendor/AtbGraph/ext_ops.py +95 -0
- dlinfer/graph/dicp/vendor/AtbGraph/infer_res_utils.py +200 -0
- dlinfer/graph/dicp/vendor/AtbGraph/opset_convert.py +70 -0
- dlinfer/graph/dicp/vendor/AtbGraph/pattern_replacement.py +152 -0
- dlinfer/graph/dicp/vendor/__init__.py +0 -0
- dlinfer/ops/__init__.py +2 -0
- dlinfer/ops/llm.py +879 -0
- dlinfer/utils/__init__.py +1 -0
- dlinfer/utils/config.py +18 -0
- dlinfer/utils/registry.py +8 -0
- dlinfer/utils/type_annotation.py +3 -0
- dlinfer/vendor/__init__.py +33 -0
- dlinfer/vendor/ascend/__init__.py +5 -0
- dlinfer/vendor/ascend/pytorch_patch.py +55 -0
- dlinfer/vendor/ascend/torch_npu_ops.py +601 -0
- dlinfer/vendor/ascend/utils.py +20 -0
- dlinfer/vendor/vendor.yaml +2 -0
- dlinfer_ascend-0.2.3.post2.dist-info/LICENSE +28 -0
- dlinfer_ascend-0.2.3.post2.dist-info/METADATA +213 -0
- dlinfer_ascend-0.2.3.post2.dist-info/RECORD +70 -0
- dlinfer_ascend-0.2.3.post2.dist-info/WHEEL +5 -0
- dlinfer_ascend-0.2.3.post2.dist-info/entry_points.txt +2 -0
- dlinfer_ascend-0.2.3.post2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
# Copyright (c) 2024, DeepLink. All rights reserved.
|
|
2
|
+
import inspect
|
|
3
|
+
from functools import wraps
|
|
4
|
+
|
|
5
|
+
from torch.library import Library, impl
|
|
6
|
+
|
|
7
|
+
import dlinfer.graph
|
|
8
|
+
from dlinfer.utils.type_annotation import Callable, Optional, Sequence, Dict
|
|
9
|
+
from dlinfer.vendor import dispatch_key, vendor_name
|
|
10
|
+
|
|
11
|
+
# One torch.library "IMPL" Library per op namespace (the part of a qualname
# before "::"), created on first use by register_custom_op and reused after.
library_impl_dict: Dict[str, Library] = dict()
# Vendors for which register_custom_op really registers torch custom ops; for
# any other vendor the decorated function only gets its default values patched.
graph_enabled_backends = ["ascend"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def register_custom_op(
    qualname: str,
    shape_param_keys: Optional[Sequence[str]] = None,
    default_value: Optional[Dict] = None,
    impl_abstract_func: Optional[Callable] = None,
) -> Callable:
    """Return a decorator that exposes a Python function as the torch op ``qualname``.

    If the current ``vendor_name`` is not in ``graph_enabled_backends``, the
    decorator only applies ``default_value`` to the function's signature (via
    ``override_default_value_static``) and returns the function.

    Otherwise the decorator:
      * declares the op with ``torch._custom_ops.custom_op(qualname)``;
      * registers the function as the kernel for ``dispatch_key`` through a
        low-level ``torch.library.Library(lib_name, "IMPL")``;
      * registers an abstract implementation with
        ``torch._custom_ops.impl_abstract``: either ``impl_abstract_func`` or
        one generated from ``shape_param_keys`` that returns
        ``torch.empty_like`` of each named argument;
      * returns a wrapper that calls the Python function when
        ``dlinfer.graph.config.enable_graph_mode`` is false, and the
        ``torch.ops.<lib_name>.<func_name>`` entry otherwise. Both paths have
        ``default_value`` applied.

    Args:
        qualname: op name in ``"<lib_name>::<func_name>"`` form.
        shape_param_keys: names of the parameters whose ``empty_like`` results
            form the generated abstract implementation's output (a single value
            for one key, otherwise a tuple). Required when
            ``impl_abstract_func`` is None.
        default_value: mapping of parameter name -> default value to install,
            or None to keep the function's own defaults.
        impl_abstract_func: explicit abstract implementation; when given,
            ``shape_param_keys`` is not used.

    Returns:
        The decorator.
    """
    # graph-mode registration is only done for vendors in graph_enabled_backends
    disable = vendor_name not in graph_enabled_backends

    def inner_func(func: Callable):
        if disable:
            return override_default_value_static(default_value)(func)
        # imported lazily; this also binds the local name ``torch`` used below
        import torch._custom_ops

        # may be replaced below by an abstract impl generated from shape_param_keys
        nonlocal impl_abstract_func
        lib_name, func_name = qualname.split("::")
        torch._custom_ops.custom_op(qualname)(func)
        # using low level torch.library APIs in case of the registration
        # of fallback kernels which raises error in torch._custom_ops.impl
        if lib_name not in library_impl_dict:
            library_impl_dict[lib_name] = Library(lib_name, "IMPL")
        impl(library_impl_dict[lib_name], func_name, dispatch_key)(func)
        if impl_abstract_func is None:
            assert shape_param_keys is not None
            params_name_list = [name for name in inspect.signature(func).parameters]

            def _impl_abstract_func(*args, **kwargs):
                # every parameter of func must be supplied, positionally or by keyword
                assert len(args) + len(kwargs) == len(params_name_list)
                result = []
                for key in shape_param_keys:
                    # locate the argument by its position in func's signature
                    key_index = params_name_list.index(key)
                    if key_index < len(args):
                        target = args[key_index]
                    else:
                        target = kwargs[key]
                    result.append(torch.empty_like(target))
                if len(result) == 1:
                    return result[0]
                return tuple(result)

            impl_abstract_func = _impl_abstract_func
        torch._custom_ops.impl_abstract(qualname)(impl_abstract_func)
        # the registered op, reachable as torch.ops.<lib_name>.<func_name>
        torch_ops_namespace = getattr(torch.ops, lib_name)
        torch_ops_func = getattr(torch_ops_namespace, func_name)
        assert torch_ops_func is not None
        # override default value
        # (func's own signature is edited in place; the torch op is wrapped
        # by a generated function whose signature mirrors func's)
        func_with_default = override_default_value_static(default_value)(func)
        torch_ops_func_with_default = override_default_value_dynamic(
            default_value, func
        )(torch_ops_func)

        # use config.enable_graph_mode to control func call
        @wraps(func)
        def patched_func(*args, **kwargs):
            if not dlinfer.graph.config.enable_graph_mode:
                return func_with_default(*args, **kwargs)
            else:
                return torch_ops_func_with_default(*args, **kwargs)

        return patched_func

    return inner_func
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def override_default_value_dynamic(
    default_value: Optional[Dict], origin_func: Callable
):
    """Return a decorator that wraps a callable in a function carrying defaults.

    The generated wrapper copies the parameter names and order of
    ``origin_func``. Every parameter named in ``default_value`` gets that value
    as its default. The wrapper forwards all of its parameters positionally,
    in signature order, to the decorated callable. This suits callables whose
    own signature cannot be edited in place, such as ``torch.ops`` entries.

    Default objects are passed into the generated code through its globals
    rather than rendered as source text. This means any value works, including
    strings containing quotes, ``float("inf")``, enums and tensors.

    Args:
        default_value: mapping of parameter name -> default value, or None to
            return the decorated callable unchanged.
        origin_func: function whose parameter names/order are reproduced.

    Returns:
        A decorator taking the callable to wrap. The wrapper is named after
        that callable's ``__name__``.
    """

    def inner_func(func):
        if default_value is None:
            return func
        sig_param_keys = list(inspect.signature(origin_func).parameters.keys())
        exec_globals = {"_dlinfer_original_func": func}
        params_with_default = []
        for idx, name in enumerate(sig_param_keys):
            if name in default_value:
                # bind the default object by name instead of via str()/quotes
                default_name = f"_dlinfer_default_{idx}"
                exec_globals[default_name] = default_value[name]
                params_with_default.append(f"{name}={default_name}")
            else:
                params_with_default.append(name)
        params_str = ", ".join(sig_param_keys)
        params_str_with_default = ", ".join(params_with_default)
        func_code = (
            f"def {func.__name__}({params_str_with_default}):\n"
            f"    return _dlinfer_original_func({params_str})\n"
        )
        exec_namespace = {}
        # it's hard not to use exec here: the wrapper needs a real signature
        exec(func_code, exec_globals, exec_namespace)
        return exec_namespace[func.__name__]

    return inner_func
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def override_default_value_static(default_value: Optional[Dict]):
    """Return a decorator that installs new default values on a function in place.

    The decorated function's ``__signature__``, ``__defaults__`` and
    ``__kwdefaults__`` are rewritten so that each parameter named in
    ``default_value`` gets that default. The same function object is returned.
    Suitable for functions whose signature isn't ``(*args, **kwargs)``.

    Only positional parameters must keep Python's "defaults only at the tail"
    rule. Keyword-only parameters, ``*args`` and ``**kwargs`` may legally
    follow a defaulted positional parameter without having a default.

    Args:
        default_value: mapping of parameter name -> default, or None to return
            the function unchanged.

    Raises:
        SyntaxError: if a positional parameter without a default would follow
            a positional parameter with one.
    """

    def inner_func(func):
        if default_value is None:
            return func
        sig = inspect.signature(func)
        # positional-only + positional-or-keyword parameters come first
        positional_count = func.__code__.co_argcount
        new_params = []
        default_arg = []
        default_kwarg = {}
        seen_positional_default = False
        for idx, (name, param) in enumerate(sig.parameters.items()):
            if name in default_value:
                param = param.replace(default=default_value[name])
            new_params.append(param)
            has_default = param.default is not inspect.Parameter.empty
            if idx < positional_count:
                if has_default:
                    seen_positional_default = True
                    default_arg.append(param.default)
                elif seen_positional_default:
                    raise SyntaxError(
                        f"non-default argument '{name}' follows default argument"
                    )
            elif has_default:
                # keyword-only defaults live in __kwdefaults__, in any order
                default_kwarg[name] = param.default
        func.__signature__ = sig.replace(parameters=new_params)
        func.__defaults__ = tuple(default_arg)
        func.__kwdefaults__ = default_kwarg
        return func

    return inner_func
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from abc import ABCMeta, abstractmethod
|
|
2
|
+
|
|
3
|
+
from dlinfer.graph.dicp.dynamo_bridge.torch_version import is_torch_251_or_higher
|
|
4
|
+
|
|
5
|
+
if is_torch_251_or_higher:
|
|
6
|
+
from torch._inductor.async_compile import AsyncCompile
|
|
7
|
+
else:
|
|
8
|
+
from torch._inductor.codecache import AsyncCompile
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DeviceCompileJob(metaclass=ABCMeta):
    """Interface for a single device-kernel compilation job.

    ``DeviceKernelCache.get_kernel`` calls ``get_key()`` to look up a cached
    kernel and ``get_compile_result()`` to build one on a cache miss.
    """

    # The metaclass is passed in the class statement. A Python 2 style
    # ``__metaclass__`` attribute is ignored by Python 3 and would leave the
    # abstract methods unenforced.

    def __init__(self):
        pass

    @abstractmethod
    def get_key(self):
        """Return the cache key for this job (used as a dict key, so hashable)."""

    @abstractmethod
    def get_compile_result(self):
        """Compile and return the kernel object.

        The cache assigns a ``key`` attribute on the returned object, and
        ``AsyncCompileKernel.compile_kernel`` returns its ``run`` attribute.
        """
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DeviceKernelCache:
    """Process-wide cache of compiled kernels, keyed by ``job.get_key()``."""

    cache = {}
    # exposed as DeviceKernelCache.clear() to empty the shared cache
    clear = staticmethod(cache.clear)

    @classmethod
    def get_kernel(cls, device_compile_job):
        """Return the kernel for ``device_compile_job``, compiling it on a miss.

        The returned kernel object always has its ``key`` attribute (re)set to
        the job's key.
        """
        kernel_key = device_compile_job.get_key()
        if kernel_key in cls.cache:
            kernel = cls.cache[kernel_key]
        else:
            kernel = device_compile_job.get_compile_result()
            cls.cache[kernel_key] = kernel
        kernel.key = kernel_key
        return kernel
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class AsyncCompileKernel(AsyncCompile):
    """AsyncCompile variant that resolves kernels through DeviceKernelCache."""

    def compile_kernel(self, device_compile_job):
        """Return the ``run`` callable of the (possibly cached) compiled kernel."""
        kernel = DeviceKernelCache.get_kernel(device_compile_job)
        return kernel.run
|
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
from torch._dynamo.backends.common import aot_autograd
|
|
2
|
+
from torch._functorch.aot_autograd import make_boxed_func
|
|
3
|
+
from .graph import GraphTransformer
|
|
4
|
+
import functools
|
|
5
|
+
import itertools
|
|
6
|
+
import logging
|
|
7
|
+
import sys
|
|
8
|
+
import functorch
|
|
9
|
+
import torch.fx
|
|
10
|
+
import importlib
|
|
11
|
+
import os
|
|
12
|
+
|
|
13
|
+
from typing import List
|
|
14
|
+
from importlib import import_module
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
from dlinfer.graph.dicp.dynamo_bridge import pt_patch # noqa F401
|
|
19
|
+
from dlinfer.graph.dicp.dynamo_bridge.torch_version import (
|
|
20
|
+
is_torch_200,
|
|
21
|
+
is_torch_210_or_higher,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
log = logging.getLogger(__name__)

# torch._dynamo helper modules, imported by dotted name at module load.
dynamo_logging = import_module("torch._dynamo.logging")
dynamo_utils = import_module("torch._dynamo.utils")

# Module-level alias of dynamo's count_calls; compile_fx_inner treats a result
# of 0 on a graph as "nothing to compile" (presumably it counts call nodes).
count_calls = dynamo_utils.count_calls
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_fake_mode_from_tensors(input_tensors):
    """Return the fake-tensor mode detected from ``input_tensors``.

    Uses ``fake_mode_from_tensors`` when ``is_torch_200`` is set and
    ``detect_fake_mode`` when ``is_torch_210_or_higher`` is set.

    Raises:
        ValueError: if neither torch-version flag is set.
    """
    if is_torch_200:
        from torch._dynamo.utils import fake_mode_from_tensors as _detect
    elif is_torch_210_or_higher:
        from torch._dynamo.utils import detect_fake_mode as _detect
    else:
        raise ValueError(f"unsupported dicp torch version: {torch.__version__}")
    return _detect(input_tensors)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def used_nodes_all_symint(nodes):
    """Return True when every used placeholder and the output are SymInts.

    Only placeholders with at least one user, and ``output`` nodes, are
    inspected. Such a node counts as a SymInt only if its ``meta["val"]`` is a
    ``torch.SymInt``. A node with no ``meta`` or no ``"val"`` entry is treated
    as non-SymInt for both node kinds, instead of raising KeyError for
    placeholders. All other nodes are ignored.
    """
    for node in nodes:
        is_used_input = node.op == "placeholder" and len(node.users) > 0
        if not (is_used_input or node.op == "output"):
            continue
        meta = getattr(node, "meta", None) or {}
        if not isinstance(meta.get("val"), torch.SymInt):
            return False
    return True
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@torch.utils._python_dispatch._disable_current_modes()
def compile_fx_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    num_fixed=0,
    is_backward=False,
    graph_id=None,
    backend=None,
):
    """Compile one FX graph for ``backend`` and return a boxed callable.

    Graphs for which ``count_calls`` reports 0, and graphs whose used inputs
    and output are all SymInts, run eagerly through ``gm.forward``. Every
    other graph is transformed and compiled by ``GraphTransformer``.

    All return paths produce a boxed callable (``_boxed_call`` set), so
    AOTAutograd can always pass inputs as a single list.

    ``num_fixed``, ``is_backward`` and ``graph_id`` are accepted for the
    callers in this module but are not used here.
    """
    if dynamo_utils.count_calls(gm.graph) == 0:
        return make_boxed_func(gm.forward)

    # all symint inputs fallback to eager mode (boxed like the other paths)
    if used_nodes_all_symint(list(gm.graph.nodes)):
        return make_boxed_func(gm.forward)

    # lift the maximum depth of the Python interpreter stack
    # to adapt large/deep models
    sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))

    gt = GraphTransformer(gm, backend)
    gt.transform()
    compiled_fn = gt.compile_to_fn()

    # aot autograd needs to know to pass in inputs as a list
    compiled_fn._boxed_call = True
    return compiled_fn
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Source of graph ids: each compile_fx_200 / compile_fx_210 call takes the next
# value and passes it to inner_compile as ``graph_id``.
_graph_counter = itertools.count(0)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def compile_fx(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    backend: str,
    inner_compile=compile_fx_inner,
):
    """Compile ``model_`` for ``backend`` with the torch-version-specific path.

    Dispatches to ``compile_fx_200`` when ``is_torch_200`` is set and to
    ``compile_fx_210`` when ``is_torch_210_or_higher`` is set.

    Raises:
        ValueError: if neither torch-version flag is set.
    """
    if is_torch_200:
        impl = compile_fx_200
    elif is_torch_210_or_higher:
        impl = compile_fx_210
    else:
        raise ValueError(f"unsupported dicp torch version: {torch.__version__}")
    return impl(model_, example_inputs_, backend, inner_compile)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def compile_fx_200(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    backend: str,
    inner_compile=compile_fx_inner,
):
    """Compile an FX graph through AOTAutograd on the torch 2.0 code path.

    Enables functionalization and fake tensors in ``functorch.compile.config``
    (global side effect). It then runs ``aot_autograd`` with forward/backward
    compilers that delegate to ``inner_compile`` and with the decomposition
    table of ``backend``.
    """
    functorch.compile.config.use_functionalize = True
    functorch.compile.config.use_fake_tensor = True

    num_example_inputs = len(example_inputs_)
    graph_id = next(_graph_counter)
    shared_kwargs = {"graph_id": graph_id, "backend": backend}

    @dynamo_utils.dynamo_timed
    def fw_compiler(model: torch.fx.GraphModule, example_inputs):
        # inputs beyond the original example inputs are counted as "fixed"
        num_fixed = len(example_inputs) - num_example_inputs
        return inner_compile(
            model, example_inputs, num_fixed=num_fixed, **shared_kwargs
        )

    @dynamo_utils.dynamo_timed
    def bw_compiler(model: torch.fx.GraphModule, example_inputs):
        return inner_compile(
            model,
            example_inputs,
            num_fixed=count_tangents(model),
            is_backward=True,
            **shared_kwargs,
        )

    return aot_autograd(
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        decompositions=get_decompositions(backend=backend),
    )(model_, example_inputs_)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def compile_fx_210(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    backend: str,
    inner_compile=compile_fx_inner,
):
    """Compile an FX graph through AOTAutograd on the torch >= 2.1 code path.

    Graphs that do not return a tuple, graphs produced by ``dynamo.export()``
    and graphs with list/tuple/dict example inputs are normalized by the
    inductor helpers. Those helpers call back into ``compile_fx``. All other
    graphs run ``pre_grad_passes`` (GraphModules only) and are then compiled
    with ``aot_autograd``, using ``inner_compile`` for the forward, inference
    and backward graphs.
    """
    import torch._dynamo.config as dynamo_config
    from torch._inductor.compile_fx import (
        flatten_graph_inputs,
        graph_returns_tuple,
        make_graph_return_tuple,
        pre_grad_passes,
        joint_graph_passes,
        min_cut_rematerialization_partition,
        _PyTreeCodeGen,
        handle_dynamo_export_graph,
    )

    decompositions = get_decompositions(backend=backend)

    # The inductor helpers invoke this as compile_gm(gm, inputs). compile_fx
    # requires ``backend`` and has no ``decompositions`` parameter (it looks
    # the table up itself), so forward backend and inner_compile only.
    recursive_compile_fx = functools.partial(
        compile_fx,
        backend=backend,
        inner_compile=inner_compile,
    )

    if not graph_returns_tuple(model_):
        return make_graph_return_tuple(
            model_,
            example_inputs_,
            recursive_compile_fx,
        )

    if isinstance(model_, torch.fx.GraphModule):
        if isinstance(model_.graph._codegen, _PyTreeCodeGen):
            # this graph is the result of dynamo.export()
            return handle_dynamo_export_graph(
                model_,
                example_inputs_,
                recursive_compile_fx,
            )

        # Since handle_dynamo_export_graph will trigger compile_fx again,
        # move these passes after handle_dynamo_export_graph to avoid repeated calls.
        model_ = pre_grad_passes(model_, example_inputs_)

    if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
        return flatten_graph_inputs(
            model_,
            example_inputs_,
            recursive_compile_fx,
        )

    num_example_inputs = len(example_inputs_)

    graph_id = next(_graph_counter)

    def fw_compiler_base(
        model: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        is_inference: bool,
    ):
        # For inference graphs partition_fn won't be called, so
        # joint_graph_passes is not run on them; is_inference changes
        # nothing else here.
        fixed = len(example_inputs) - num_example_inputs
        return inner_compile(
            model,
            example_inputs,
            num_fixed=fixed,
            graph_id=graph_id,
            backend=backend,
        )

    fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
    inference_compiler = functools.partial(fw_compiler_base, is_inference=True)

    def partition_fn(graph, joint_inputs, **kwargs):
        joint_graph_passes(graph)
        return min_cut_rematerialization_partition(
            graph, joint_inputs, **kwargs, compiler="inductor"
        )

    # Save and restore dynamic shapes setting for backwards, as it is
    # sometimes done as a context manager which won't be set when we
    # hit backwards compile
    dynamic_shapes = dynamo_config.dynamic_shapes

    def bw_compiler(model: torch.fx.GraphModule, example_inputs):
        with dynamo_config.patch(dynamic_shapes=dynamic_shapes):
            fixed = count_tangents(model)
            return inner_compile(
                model,
                example_inputs,
                num_fixed=fixed,
                is_backward=True,
                graph_id=graph_id,
                backend=backend,
            )

    # TODO: can add logging before/after the call to create_aot_dispatcher_function
    # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
    # once torchdynamo is merged into pytorch
    return aot_autograd(
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        inference_compiler=inference_compiler,
        decompositions=decompositions,
        partition_fn=partition_fn,
        keep_inference_input_mutations=True,
    )(model_, example_inputs_)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def count_tangents(fx_g: torch.fx.GraphModule):
    """Return how many leading placeholders of a backward graph are static.

    A placeholder is static when its name does not contain ``"tangents"``.
    Static placeholders must all precede the tangent placeholders; otherwise
    an AssertionError is raised.
    """
    placeholders = [node for node in fx_g.graph.nodes if node.op == "placeholder"]
    static_positions = [
        position
        for position, node in enumerate(placeholders)
        if "tangents" not in node.name
    ]
    assert static_positions == list(range(len(static_positions)))
    return len(static_positions)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def get_decompositions(backend):
    """Return the ``decomp`` table of the vendor package matching ``backend``.

    The ``vendor`` directory next to this package's parent directory is
    scanned for a folder whose name equals ``backend`` case-insensitively.
    ``dlinfer.graph.dicp.vendor.<folder>.config.decomp`` is returned for the
    first match.

    Raises:
        AssertionError: if no vendor folder matches ``backend``. It is raised
            explicitly (not via ``assert``) so the check survives ``python -O``.
    """
    vendor_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vendor")
    backend_key = backend.lower()
    for folder in os.listdir(vendor_dir):
        if folder.lower() == backend_key:
            config = importlib.import_module(
                "dlinfer.graph.dicp.vendor." + folder + ".config"
            )
            return config.decomp
    raise AssertionError("Not found decomp table!")
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import torch
|
|
3
|
+
from dlinfer.graph.dicp.dynamo_bridge.operator import Operator
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def args_kwargs_unchange(args, kwargs):
    """Identity post-processor: hand ``args`` and ``kwargs`` back untouched.

    Used by ``register_conversion_impl`` as the default processor paired with
    an Operator singleton.
    """
    unchanged = (args, kwargs)
    return unchanged
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _resolve_op_from_str(fn: str):
    """Resolve ``"torch.ops.<ns>.<op>[.<overload>]"`` to a torch op, or None.

    Prints a ``[dicp] ... is ignored`` notice and returns None when the
    namespace, op or overload cannot be found.
    """
    assert fn.startswith("torch.ops")
    real_fn_name = fn.replace("torch.ops.", "")
    ns, op_overload = real_fn_name.split(".", 1)
    if not hasattr(torch.ops, ns):
        print(f"[dicp] can't find torch.ops.{ns}, conversion for {fn} is ignored")
        return None
    ns_obj = getattr(torch.ops, ns)
    if "." not in op_overload:
        if not hasattr(ns_obj, op_overload):
            print(
                f"[dicp] can't find torch.ops.{ns}.{op_overload}, conversion for {fn} is ignored"
            )
            return None
        return getattr(ns_obj, op_overload)
    op, overload = op_overload.split(".", 1)
    if not hasattr(ns_obj, op):
        print(
            f"[dicp] can't find torch.ops.{ns}.{op}, conversion for {fn} is ignored"
        )
        return None
    op_obj = getattr(ns_obj, op)
    if not hasattr(op_obj, overload):
        print(
            f"[dicp] can't find torch.ops.{ns}.{op}.{overload}, conversion for {fn} is ignored"
        )
        return None
    return getattr(op_obj, overload)


def register_conversion_impl(
    conversions: dict, aten_fn, decomp_fn, process_args_kwargs_fn=None
):
    """Register ``decomp_fn`` as the conversion for one or more aten ops.

    Args:
        conversions: dict mapping torch op objects to their conversion; it is
            updated in place.
        aten_fn: a torch op, a ``"torch.ops...."`` string, or a list/tuple of
            them. Strings that cannot be resolved are skipped with a printed
            notice. For an ``OpOverloadPacket``, every overload not already
            present in ``conversions`` is registered as well as the packet.
        decomp_fn: an ``Operator`` subclass or a plain callable.
        process_args_kwargs_fn: argument post-processor stored next to an
            Operator singleton; defaults to ``args_kwargs_unchange``.

    Returns:
        The Operator singleton when ``decomp_fn`` is an Operator subclass,
        otherwise a ``functools.wraps`` wrapper around ``decomp_fn``.
    """
    register_op_singleton_flag = isinstance(decomp_fn, type) and issubclass(
        decomp_fn, Operator
    )
    if register_op_singleton_flag:
        wrapped = (
            decomp_fn.get_singleton(),
            (
                args_kwargs_unchange
                if process_args_kwargs_fn is None
                else process_args_kwargs_fn
            ),
        )
    else:

        @functools.wraps(decomp_fn)
        def wrapped(*args, **kwargs):
            return decomp_fn(*args, **kwargs)

    aten_fns = list(aten_fn) if isinstance(aten_fn, (list, tuple)) else [aten_fn]

    aten_fn_for_key = []
    for fn in aten_fns:
        if isinstance(fn, str):
            fn = _resolve_op_from_str(fn)
            if fn is None:
                continue
        if isinstance(fn, torch._ops.OpOverloadPacket):
            # also register each concrete overload not already claimed
            for overload in fn.overloads():
                other_fn = getattr(fn, overload)
                if other_fn not in conversions:
                    aten_fn_for_key.append(other_fn)
        aten_fn_for_key.append(fn)

    conversions.update({key: wrapped for key in aten_fn_for_key})
    return wrapped[0] if register_op_singleton_flag else wrapped
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import Callable, Dict, Sequence, Union
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch._decomp import register_decomposition
|
|
6
|
+
from torch._ops import OpOverload, OpOverloadPacket
|
|
7
|
+
|
|
8
|
+
# Registry filled by register_decomposition_for_dicp; keyed by op overload
# (get_decompositions reads ``.overloadpacket`` from each key).
dicp_decomposition_table = {}
# Shorthand for the aten op namespace.
aten = torch.ops.aten
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def register_decomposition_for_dicp(fn):
    """Return a decorator registering a decomposition of ``fn`` in the dicp table.

    Thin wrapper over ``torch._decomp.register_decomposition`` that targets
    ``dicp_decomposition_table`` instead of the global registry.
    """
    decorator = register_decomposition(fn, registry=dicp_decomposition_table)
    return decorator
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@register_decomposition_for_dicp(aten.count_nonzero.default)
def count_nonzero_default(x, dim=None):
    """Decompose ``count_nonzero`` into a sum over a ``x != 0`` mask.

    ``dim=None`` is mapped to an empty dim list. The result is summed with
    ``keepdim=False`` and dtype ``torch.int64``.
    """
    reduce_dims = [] if dim is None else dim
    nonzero_mask = x != 0
    return aten.sum.dim_IntList(
        nonzero_mask, dim=reduce_dims, keepdim=False, dtype=torch.int64
    )
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def get_decompositions(
    aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
    target_decomposition_table: Union[Dict[OpOverload, Callable], None] = None,
) -> Dict[OpOverload, Callable]:
    """Select the dicp decompositions registered for ``aten_ops``.

    An ``OpOverloadPacket`` selects every registered overload of that packet.
    An ``OpOverload`` selects itself if registered. Unregistered ops are
    skipped.

    Args:
        aten_ops: ops to look up in ``dicp_decomposition_table``.
        target_decomposition_table: dict to fill in place and return; a fresh
            dict is used only when it is None. A caller-supplied empty dict is
            filled too; a truthiness check would silently replace it.

    Returns:
        The filled decomposition table.
    """
    registry = dicp_decomposition_table
    packets_to_overloads = defaultdict(list)
    for opo in registry:
        packets_to_overloads[opo.overloadpacket].append(opo)
    decompositions = (
        {} if target_decomposition_table is None else target_decomposition_table
    )
    for op in aten_ops:
        if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
            for op_overload in packets_to_overloads[op]:
                decompositions[op_overload] = registry[op_overload]
        elif isinstance(op, OpOverload) and op in registry:
            decompositions[op] = registry[op]
    return decompositions
|