onnx-diagnostic 0.7.1__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +22 -5
- onnx_diagnostic/ext_test_case.py +31 -0
- onnx_diagnostic/helpers/cache_helper.py +23 -12
- onnx_diagnostic/helpers/config_helper.py +16 -1
- onnx_diagnostic/helpers/log_helper.py +308 -83
- onnx_diagnostic/helpers/torch_helper.py +6 -2
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/text_generation.py +17 -8
- onnx_diagnostic/tasks/text_to_image.py +91 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +144 -349
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +87 -7
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +15 -4
- onnx_diagnostic/torch_models/hghub/hub_data.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -5
- onnx_diagnostic/torch_models/validate.py +36 -12
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.2.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.2.dist-info}/RECORD +26 -22
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.2.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.2.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.2.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/serialization/__init__.py (new file)

@@ -0,0 +1,46 @@
+import re
+from typing import Any, Callable, List, Set, Tuple
+import torch
+
+
+def _lower_name_with_(name):
+    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
+    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
+
+
+def make_serialization_function_for_dataclass(
+    cls: type, supported_classes: Set[type]
+) -> Tuple[Callable, Callable, Callable]:
+    """
+    Automatically creates serialization function for a class decorated with
+    ``dataclasses.dataclass``.
+    """
+
+    def flatten_cls(obj: cls) -> Tuple[List[Any], torch.utils._pytree.Context]:  # type: ignore[valid-type]
+        """Serializes a ``%s`` with python objects."""
+        return list(obj.values()), list(obj.keys())
+
+    def flatten_with_keys_cls(
+        obj: cls,  # type: ignore[valid-type]
+    ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+        """Serializes a ``%s`` with python objects with keys."""
+        values, context = list(obj.values()), list(obj.keys())
+        return [
+            (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)
+        ], context
+
+    def unflatten_cls(
+        values: List[Any], context: torch.utils._pytree.Context, output_type=None
+    ) -> cls:  # type: ignore[valid-type]
+        """Restores an instance of ``%s`` from python objects."""
+        return cls(**dict(zip(context, values)))
+
+    name = _lower_name_with_(cls.__name__)
+    flatten_cls.__name__ = f"flatten_{name}"
+    flatten_with_keys_cls.__name__ = f"flatten_with_keys_{name}"
+    unflatten_cls.__name__ = f"unflatten_{name}"
+    flatten_cls.__doc__ = flatten_cls.__doc__ % cls.__name__
+    flatten_with_keys_cls.__doc__ = flatten_with_keys_cls.__doc__ % cls.__name__
+    unflatten_cls.__doc__ = unflatten_cls.__doc__ % cls.__name__
+    supported_classes.add(cls)
+    return flatten_cls, flatten_with_keys_cls, unflatten_cls
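Note: the helper above only builds the flatten/unflatten triple; registering it with torch's pytree happens elsewhere in the package. As a hedged illustration only (the ``DummyOutput`` class and the registration call below are not part of this diff, and the ``flatten_with_keys_fn`` keyword assumes a recent PyTorch), the generated functions could be wired in like this:

```python
import dataclasses
from typing import Set

import torch

from onnx_diagnostic.torch_export_patches.serialization import (
    make_serialization_function_for_dataclass,
)


@dataclasses.dataclass
class DummyOutput:
    # hypothetical dataclass, stands in for outputs such as UNet2DConditionOutput
    logits: torch.Tensor
    hidden: torch.Tensor

    def keys(self):
        # the generated flatten function expects dict-like access
        return ["logits", "hidden"]

    def values(self):
        return [self.logits, self.hidden]


supported: Set[type] = set()
flat, flat_with_keys, unflat = make_serialization_function_for_dataclass(
    DummyOutput, supported
)
torch.utils._pytree.register_pytree_node(
    DummyOutput,
    flat,
    unflat,
    serialized_type_name="DummyOutput",
    flatten_with_keys_fn=flat_with_keys,
)

values, context = flat(DummyOutput(torch.zeros(2), torch.ones(2)))
restored = unflat(values, context)  # round-trips back to a DummyOutput
```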
onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py (new file)

@@ -0,0 +1,34 @@
+from typing import Dict, Optional, Set
+
+try:
+    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
+except ImportError as e:
+    try:
+        import diffusers
+    except ImportError:
+        diffusers = None
+    UNet2DConditionOutput = None
+    if diffusers:
+        raise e
+
+from . import make_serialization_function_for_dataclass
+
+
+def _make_wrong_registrations() -> Dict[type, Optional[str]]:
+    res: Dict[type, Optional[str]] = {}
+    for c in [UNet2DConditionOutput]:
+        if c is not None:
+            res[c] = None
+    return res
+
+
+SUPPORTED_DATACLASSES: Set[type] = set()
+WRONG_REGISTRATIONS = _make_wrong_registrations()
+
+
+if UNet2DConditionOutput is not None:
+    (
+        flatten_u_net2_d_condition_output,
+        flatten_with_keys_u_net2_d_condition_output,
+        unflatten_u_net2_d_condition_output,
+    ) = make_serialization_function_for_dataclass(UNet2DConditionOutput, SUPPORTED_DATACLASSES)
onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py (new file)

@@ -0,0 +1,259 @@
+from typing import Any, List, Set, Tuple
+import torch
+import transformers
+from transformers.cache_utils import (
+    DynamicCache,
+    MambaCache,
+    EncoderDecoderCache,
+    SlidingWindowCache,
+    StaticCache,
+)
+from transformers.modeling_outputs import BaseModelOutput
+from ...helpers.cache_helper import make_static_cache
+from . import make_serialization_function_for_dataclass
+
+
+SUPPORTED_DATACLASSES: Set[type] = set()
+WRONG_REGISTRATIONS = {
+    DynamicCache: "4.50",
+    BaseModelOutput: None,
+}
+
+
+############
+# MambaCache
+############
+
+
+def flatten_mamba_cache(
+    mamba_cache: MambaCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
+    flat = [
+        ("conv_states", mamba_cache.conv_states),
+        ("ssm_states", mamba_cache.ssm_states),
+    ]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def unflatten_mamba_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> MambaCache:
+    """Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
+    conv_states, ssm_states = values
+
+    class _config:
+        def __init__(self):
+            if isinstance(conv_states, list):
+                self.intermediate_size = conv_states[0].shape[1]
+                self.state_size = ssm_states[0].shape[2]
+                self.conv_kernel = conv_states[0].shape[2]
+                self.num_hidden_layers = len(conv_states)
+            else:
+                self.intermediate_size = conv_states.shape[2]
+                self.state_size = ssm_states.shape[3]
+                self.conv_kernel = conv_states.shape[3]
+                self.num_hidden_layers = conv_states.shape[0]
+
+    cache = MambaCache(
+        _config(),
+        max_batch_size=1,
+        dtype=values[-1][0].dtype,
+        device="cpu" if values[-1][0].get_device() < 0 else "cuda",
+    )
+    values = dict(zip(context, values))
+    for k, v in values.items():
+        setattr(cache, k, v)
+    return cache
+
+
+def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
+    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
+    torch.utils._pytree.Context,
+]:
+    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
+    values, context = flatten_mamba_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+##############
+# DynamicCache
+##############
+
+
+def flatten_dynamic_cache(
+    dynamic_cache: DynamicCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+    if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
+        return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache)
+    flat = [("key_cache", dynamic_cache.key_cache), ("value_cache", dynamic_cache.value_cache)]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def flatten_with_keys_dynamic_cache(
+    dynamic_cache: DynamicCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+    if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
+        return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache)
+    values, context = flatten_dynamic_cache(dynamic_cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_dynamic_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> DynamicCache:
+    """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
+    if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
+        assert output_type is None, f"output_type={output_type} not supported"
+        return transformers.cache_utils._unflatten_dynamic_cache(values, context)
+
+    cache = transformers.cache_utils.DynamicCache()
+    values = dict(zip(context, values))
+    for k, v in values.items():
+        setattr(cache, k, v)
+    return cache
+
+
+#############
+# StaticCache
+#############
+
+
+def flatten_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    assert not cache.key_cache or cache.max_cache_len == cache.key_cache[0].shape[2], (
+        f"Serialization doet not work when "
+        f"cache.max_cache_len={cache.max_cache_len} != "
+        f"cache.key_cache[0].shape[2]={cache.key_cache[0].shape[2]}"
+    )
+    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def flatten_with_keys_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    values, context = flatten_static_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_static_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> StaticCache:
+    """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
+    return make_static_cache(
+        list(zip(values[0], values[1])), max_cache_len=values[0][0].shape[2]
+    )
+
+
+####################
+# SlidingWindowCache
+####################
+
+
+def flatten_sliding_window_cache(
+    cache: SlidingWindowCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
+    with python objects.
+    """
+    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def flatten_with_keys_sliding_window_cache(
+    cache: SlidingWindowCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
+    with python objects.
+    """
+    values, context = flatten_sliding_window_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_sliding_window_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> SlidingWindowCache:
+    """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
+    key_cache, value_cache = values
+
+    class _config:
+        def __init__(self):
+            self.head_dim = key_cache[0].shape[-1]
+            self.num_attention_heads = key_cache[0].shape[1]
+            self.num_hidden_layers = len(key_cache)
+            self.sliding_window = key_cache[0].shape[2]
+
+    cache = SlidingWindowCache(
+        _config(),
+        max_batch_size=key_cache[0].shape[0],
+        max_cache_len=key_cache[0].shape[2],  # sligding window
+        device=key_cache[0].device,
+        dtype=key_cache[0].dtype,
+    )
+
+    values = dict(zip(context, values))
+    for k, v in values.items():
+        setattr(cache, k, v)
+    return cache
+
+
+#####################
+# EncoderDecoderCache
+#####################
+
+
+def flatten_encoder_decoder_cache(
+    ec_cache: EncoderDecoderCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
+    with python objects.
+    """
+    dictionary = {
+        "self_attention_cache": ec_cache.self_attention_cache,
+        "cross_attention_cache": ec_cache.cross_attention_cache,
+    }
+    return torch.utils._pytree._dict_flatten(dictionary)
+
+
+def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[
+    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
+    torch.utils._pytree.Context,
+]:
+    """
+    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
+    with python objects.
+    """
+    dictionary = {
+        "self_attention_cache": ec_cache.self_attention_cache,
+        "cross_attention_cache": ec_cache.cross_attention_cache,
+    }
+    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
+
+
+def unflatten_encoder_decoder_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> EncoderDecoderCache:
+    """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
+    dictionary = torch.utils._pytree._dict_unflatten(values, context)
+    return EncoderDecoderCache(**dictionary)
+
+
+#############
+# dataclasses
+#############
+
+
+(
+    flatten_base_model_output,
+    flatten_with_keys_base_model_output,
+    unflatten_base_model_output,
+) = make_serialization_function_for_dataclass(BaseModelOutput, SUPPORTED_DATACLASSES)
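Note: as a quick, hypothetical sanity check (not part of the diff; cache shapes are made up), the DynamicCache helpers above can be exercised with a simple round trip:

```python
import torch
from transformers.cache_utils import DynamicCache

from onnx_diagnostic.torch_export_patches.serialization.transformers_impl import (
    flatten_dynamic_cache,
    unflatten_dynamic_cache,
)

# build a one-layer cache with arbitrary key/value tensors
cache = DynamicCache()
cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), layer_idx=0)

# flatten to (values, context) the way a pytree traversal would, then rebuild
values, context = flatten_dynamic_cache(cache)
restored = unflatten_dynamic_cache(values, context)
assert isinstance(restored, DynamicCache)
```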
onnx_diagnostic/torch_models/hghub/hub_api.py

@@ -140,7 +140,10 @@ def _guess_task_from_config(config: Any) -> Optional[str]:
 
 @functools.cache
 def task_from_arch(
-    arch: str,
+    arch: str,
+    default_value: Optional[str] = None,
+    model_id: Optional[str] = None,
+    subfolder: Optional[str] = None,
 ) -> str:
     """
     This function relies on stored information. That information needs to be refresh.
@@ -148,6 +151,7 @@ def task_from_arch(
     :param arch: architecture name
     :param default_value: default value in case the task cannot be determined
     :param model_id: unused unless the architecture does not help.
+    :param subfolder: subfolder
     :return: task
 
     .. runpython::
@@ -162,7 +166,7 @@ def task_from_arch(
     data = load_architecture_task()
     if arch not in data and model_id:
         # Let's try with the model id.
-        return task_from_id(model_id)
+        return task_from_id(model_id, subfolder=subfolder)
     if default_value is not None:
         return data.get(arch, default_value)
     assert arch in data, (
@@ -178,6 +182,7 @@ def task_from_id(
     default_value: Optional[str] = None,
     pretrained: bool = False,
     fall_back_to_pretrained: bool = True,
+    subfolder: Optional[str] = None,
 ) -> str:
     """
     Returns the task attached to a model id.
@@ -187,7 +192,7 @@ def task_from_id(
         if the task cannot be determined
     :param pretrained: uses the config
     :param fall_back_to_pretrained: falls back to pretrained config
-    :param
+    :param subfolder: subfolder
     :return: task
     """
     if not pretrained:
@@ -196,7 +201,7 @@ def task_from_id(
         except RuntimeError:
             if not fall_back_to_pretrained:
                 raise
-            config = get_pretrained_config(model_id)
+            config = get_pretrained_config(model_id, subfolder=subfolder)
     try:
         return config.pipeline_tag
     except AttributeError:
@@ -206,6 +211,12 @@ def task_from_id(
     data = load_architecture_task()
     if model_id in data:
         return data[model_id]
+    if type(config) is dict and "_class_name" in config:
+        return task_from_arch(config["_class_name"], default_value=default_value)
+    if not config.architectures or not config.architectures:
+        # Some hardcoded values until a better solution is found.
+        if model_id.startswith("google/bert_"):
+            return "fill-mask"
     assert config.architectures is not None and len(config.architectures) == 1, (
         f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
        f"architectures={config.architectures} in config={config}. "
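Note: the new ``subfolder`` arguments let task resolution follow models that live in a repository subfolder, such as the cached UNet config further below. A hypothetical, illustrative call (not taken from the diff; whether the fallback succeeds depends on the architecture table and the cached configuration):

```python
from onnx_diagnostic.torch_models.hghub.hub_api import task_from_arch

# architecture listed in the __data_arch__ table shipped with the package
task_from_arch("CLIPTextModel")

# architecture unknown: fall back to the model id, now honouring a subfolder
task_from_arch(
    "UNet2DConditionModel",
    model_id="diffusers/tiny-torch-full-checker",
    subfolder="unet",
)
```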
onnx_diagnostic/torch_models/hghub/hub_data.py

@@ -22,6 +22,7 @@ __data_arch__ = textwrap.dedent(
     BlenderbotModel,feature-extraction
     BloomModel,feature-extraction
     CLIPModel,zero-shot-image-classification
+    CLIPTextModel,feature-extraction
     CLIPVisionModel,feature-extraction
     CamembertModel,feature-extraction
     CodeGenModel,feature-extraction
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

@@ -4302,3 +4302,31 @@ def _ccached_microsoft_phi_35_mini_instruct():
         "vocab_size": 32064,
     }
 )
+
+
+def _ccached_diffusers_tiny_torch_full_checker_unet():
+    "diffusers/tiny-torch-full-checker/unet"
+    return {
+        "_class_name": "UNet2DConditionModel",
+        "_diffusers_version": "0.8.0",
+        "_name_or_path": "https://huggingface.co/diffusers/tiny-torch-full-checker/blob/main/unet/config.json",
+        "act_fn": "silu",
+        "attention_head_dim": 8,
+        "block_out_channels": [32, 64],
+        "center_input_sample": false,
+        "cross_attention_dim": 32,
+        "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D"],
+        "downsample_padding": 1,
+        "dual_cross_attention": false,
+        "flip_sin_to_cos": true,
+        "freq_shift": 0,
+        "in_channels": 4,
+        "layers_per_block": 2,
+        "mid_block_scale_factor": 1,
+        "norm_eps": 1e-05,
+        "norm_num_groups": 32,
+        "out_channels": 4,
+        "sample_size": 32,
+        "up_block_types": ["CrossAttnUpBlock2D", "UpBlock2D"],
+        "use_linear_projection": false,
+    }
onnx_diagnostic/torch_models/hghub/model_inputs.py

@@ -106,7 +106,7 @@ def get_untrained_model_with_inputs(
         print(f"[get_untrained_model_with_inputs] architectures={archs!r}")
         print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
     if task is None:
-        task = task_from_arch(archs[0], model_id=model_id)
+        task = task_from_arch(archs[0], model_id=model_id, subfolder=subfolder)
     if verbose:
         print(f"[get_untrained_model_with_inputs] task={task!r}")
 
@@ -145,12 +145,19 @@ def get_untrained_model_with_inputs(
             f"{config._attn_implementation!r}"  # type: ignore[union-attr]
         )
 
+    if type(config) is dict and "_diffusers_version" in config:
+        import diffusers
+
+        package_source = diffusers
+    else:
+        package_source = transformers
+
     if use_pretrained:
         model = transformers.AutoModel.from_pretrained(model_id, **mkwargs)
     else:
         if archs is not None:
             try:
-
+                cls_model = getattr(package_source, archs[0])
             except AttributeError as e:
                 # The code of the models is not in transformers but in the
                 # repository of the model. We need to download it.
@@ -174,10 +181,12 @@ def get_untrained_model_with_inputs(
                         f"[get_untrained_model_with_inputs] from folder "
                         f"{os.path.split(pyfiles[0])[0]!r}"
                     )
-
-
+                cls_model = (
+                    transformers.dynamic_module_utils.get_class_from_dynamic_module(
+                        cls_name,
+                        pretrained_model_name_or_path=os.path.split(pyfiles[0])[0],
+                    )
                 )
-                model = cls(config)
             else:
                 raise AttributeError(
                     f"Unable to find class 'tranformers.{archs[0]}'. "
@@ -191,6 +200,16 @@ def get_untrained_model_with_inputs(
                     f"and use_pretrained=True."
                 )
 
+        try:
+            if type(config) is dict:
+                model = cls_model(**config)
+            else:
+                model = cls_model(config)
+        except RuntimeError as e:
+            raise RuntimeError(
+                f"Unable to instantiate class {cls_model.__name__} with\n{config}"
+            ) from e
+
     # input kwargs
     kwargs, fct = random_input_kwargs(config, task)
     if verbose:
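Note: a hypothetical call exercising the diffusers branch added above. The model id matches the cached UNet config; the positional model id, the ``subfolder`` keyword and the returned dictionary keys are assumptions made for illustration, not taken from this diff:

```python
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs

# with a dict config carrying "_diffusers_version", the class is looked up in
# diffusers instead of transformers and instantiated with cls_model(**config)
data = get_untrained_model_with_inputs(
    "diffusers/tiny-torch-full-checker",
    subfolder="unet",
    verbose=1,
)
model, inputs = data["model"], data["inputs"]
```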
onnx_diagnostic/torch_models/validate.py

@@ -263,7 +263,7 @@ def validate_model(
     use_pretrained: bool = False,
     optimization: Optional[str] = None,
     quiet: bool = False,
-    patch: bool = False,
+    patch: Union[bool, str, Dict[str, bool]] = False,
     rewrite: bool = False,
     stop_if_static: int = 1,
     dump_folder: Optional[str] = None,
@@ -301,8 +301,10 @@ def validate_model(
     :param optimization: optimization to apply to the exported model,
         depend on the the exporter
     :param quiet: if quiet, catches exception if any issue
-    :param patch: applies patches (``patch_transformers=True``)
-
+    :param patch: applies patches (``patch_transformers=True, path_diffusers=True``)
+        if True before exporting
+        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
+        a string can be used to specify only one of them
     :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
         see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
     :param stop_if_static: stops if a dynamic dimension becomes static,
@@ -346,9 +348,26 @@ def validate_model(
         exported model returns the same outputs as the original one, otherwise,
         :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
     """
-
-
-
+    if isinstance(patch, bool):
+        patch_kwargs = (
+            dict(patch_transformers=True, patch_diffusers=True, patch=True)
+            if patch
+            else dict(patch=False)
+        )
+    elif isinstance(patch, str):
+        patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}}  # noqa: C420
+    else:
+        assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
+        patch_kwargs = patch.copy()
+    if "patch" not in patch_kwargs:
+        if any(patch_kwargs.values()):
+            patch_kwargs["patch"] = True
+
+    assert not rewrite or patch_kwargs.get("patch", False), (
+        f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
+        f"patch must be True to enable rewriting, "
+        f"if --no-patch was specified on the command line, --no-rewrite must be added."
+    )
     summary = version_summary()
     summary.update(
         dict(
@@ -361,6 +380,7 @@ def validate_model(
             version_optimization=optimization or "",
             version_quiet=str(quiet),
             version_patch=str(patch),
+            version_patch_kwargs=str(patch_kwargs).replace(" ", ""),
             version_rewrite=str(rewrite),
             version_dump_folder=dump_folder or "",
             version_drop_inputs=str(list(drop_inputs or "")),
@@ -396,7 +416,7 @@ def validate_model(
         print(f"[validate_model] model_options={model_options!r}")
         print(f"[validate_model] get dummy inputs with input_options={input_options}...")
         print(
-            f"[validate_model] rewrite={rewrite},
+            f"[validate_model] rewrite={rewrite}, patch_kwargs={patch_kwargs}, "
             f"stop_if_static={stop_if_static}"
         )
         print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
@@ -538,9 +558,13 @@ def validate_model(
     if summary["model_module"] in sys.modules:
         summary["model_file"] = str(sys.modules[summary["model_module"]].__file__)  # type: ignore[index]
     summary["model_config_class"] = data["configuration"].__class__.__name__
-    summary["model_config"] = str(
-
-
+    summary["model_config"] = str(
+        shrink_config(
+            data["configuration"]
+            if type(data["configuration"]) is dict
+            else data["configuration"].to_dict()
+        )
+    ).replace(" ", "")
     summary["model_id"] = model_id
 
     if verbose:
@@ -568,18 +592,18 @@ def validate_model(
             f"[validate_model] -- export the model with {exporter!r}, "
             f"optimization={optimization!r}"
         )
-    if
+    if patch_kwargs:
         if verbose:
             print(
                 f"[validate_model] applies patches before exporting "
                 f"stop_if_static={stop_if_static}"
             )
         with torch_export_patches(  # type: ignore
-            patch_transformers=True,
             stop_if_static=stop_if_static,
             verbose=max(0, verbose - 1),
             rewrite=data.get("rewrite", None),
             dump_rewriting=(os.path.join(dump_folder, "rewrite") if dump_folder else None),
+            **patch_kwargs,  # type: ignore[arg-type]
         ) as modificator:
             data["inputs_export"] = modificator(data["inputs"])  # type: ignore
 
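Note: the normalization added above accepts three forms for ``patch``. The calls below are hypothetical (the model id and the assumption that ``validate_model`` takes it as its first argument are illustrative only, not taken from this diff):

```python
from onnx_diagnostic.torch_models.validate import validate_model

# bool: True expands to patch_transformers=True, patch_diffusers=True
validate_model("hf-internal-testing/tiny-random-gpt2", patch=True)

# str: comma-separated flag names, e.g. only the transformers patches
validate_model("hf-internal-testing/tiny-random-gpt2", patch="patch_transformers")

# dict: explicit flags; "patch" itself is inferred when any flag is set
validate_model(
    "hf-internal-testing/tiny-random-gpt2",
    patch={"patch_transformers": True, "patch_diffusers": False},
)
```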