onnx-diagnostic 0.7.1__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +22 -5
- onnx_diagnostic/ext_test_case.py +31 -0
- onnx_diagnostic/helpers/cache_helper.py +23 -12
- onnx_diagnostic/helpers/config_helper.py +16 -1
- onnx_diagnostic/helpers/log_helper.py +308 -83
- onnx_diagnostic/helpers/torch_helper.py +6 -2
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/text_generation.py +17 -8
- onnx_diagnostic/tasks/text_to_image.py +91 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +144 -349
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +87 -7
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +15 -4
- onnx_diagnostic/torch_models/hghub/hub_data.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -5
- onnx_diagnostic/torch_models/validate.py +36 -12
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.2.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.2.dist-info}/RECORD +26 -22
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.2.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.2.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.2.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
(hunk not captured in this listing: +1 -1)
onnx_diagnostic/_command_lines_parser.py
CHANGED
|
@@ -718,13 +718,13 @@ def get_parser_agg() -> ArgumentParser:
|
|
|
718
718
|
"peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
|
|
719
719
|
"n_node_constant,n_node_shape,n_node_expand,"
|
|
720
720
|
"n_node_function,n_node_initializer,n_node_scatter,"
|
|
721
|
-
"time_export_unbiased",
|
|
721
|
+
"time_export_unbiased,onnx_n_nodes_no_cst,n_node_initializer_small",
|
|
722
722
|
help="Columns to compute after the aggregation was done.",
|
|
723
723
|
)
|
|
724
724
|
parser.add_argument(
|
|
725
725
|
"--views",
|
|
726
726
|
default="agg-suite,agg-all,disc,speedup,time,time_export,err,cmd,"
|
|
727
|
-
"bucket-speedup,raw-short,counts,peak-gpu",
|
|
727
|
+
"bucket-speedup,raw-short,counts,peak-gpu,onnx",
|
|
728
728
|
help="Views to add to the output files.",
|
|
729
729
|
)
|
|
730
730
|
parser.add_argument(
|
|
@@ -733,11 +733,28 @@ def get_parser_agg() -> ArgumentParser:
|
|
|
733
733
|
help="Views to dump as csv files.",
|
|
734
734
|
)
|
|
735
735
|
parser.add_argument("-v", "--verbose", type=int, default=0, help="verbosity")
|
|
736
|
+
parser.add_argument(
|
|
737
|
+
"--filter-in",
|
|
738
|
+
default="",
|
|
739
|
+
help="adds a filter to filter in data, syntax is\n"
|
|
740
|
+
'``"<column1>:<value1>;<value2>/<column2>:<value3>"`` ...',
|
|
741
|
+
)
|
|
742
|
+
parser.add_argument(
|
|
743
|
+
"--filter-out",
|
|
744
|
+
default="",
|
|
745
|
+
help="adds a filter to filter out data, syntax is\n"
|
|
746
|
+
'``"<column1>:<value1>;<value2>/<column2>:<value3>"`` ...',
|
|
747
|
+
)
|
|
736
748
|
return parser
|
|
737
749
|
|
|
738
750
|
|
|
739
751
|
def _cmd_agg(argv: List[Any]):
|
|
740
|
-
from .helpers.log_helper import
|
|
752
|
+
from .helpers.log_helper import (
|
|
753
|
+
CubeLogsPerformance,
|
|
754
|
+
open_dataframe,
|
|
755
|
+
enumerate_csv_files,
|
|
756
|
+
filter_data,
|
|
757
|
+
)
|
|
741
758
|
|
|
742
759
|
parser = get_parser_agg()
|
|
743
760
|
args = parser.parse_args(argv[1:])
|
|
@@ -748,7 +765,7 @@ def _cmd_agg(argv: List[Any]):
|
|
|
748
765
|
args.inputs, verbose=args.verbose, filtering=lambda name: bool(reg.search(name))
|
|
749
766
|
)
|
|
750
767
|
)
|
|
751
|
-
assert csv, f"No csv files in {args.inputs}, csv={csv}"
|
|
768
|
+
assert csv, f"No csv files in {args.inputs}, args.filter={args.filter!r}, csv={csv}"
|
|
752
769
|
if args.verbose:
|
|
753
770
|
from tqdm import tqdm
|
|
754
771
|
|
|
@@ -761,7 +778,7 @@ def _cmd_agg(argv: List[Any]):
|
|
|
761
778
|
assert (
|
|
762
779
|
args.time in df.columns
|
|
763
780
|
), f"Missing time column {args.time!r} in {c!r}\n{df.head()}\n{sorted(df.columns)}"
|
|
764
|
-
dfs.append(df)
|
|
781
|
+
dfs.append(filter_data(df, filter_in=args.filter_in, filter_out=args.filter_out))
|
|
765
782
|
|
|
766
783
|
drop_keys = set(args.drop_keys.split(","))
|
|
767
784
|
cube = CubeLogsPerformance(
|
onnx_diagnostic/ext_test_case.py
CHANGED
|
@@ -756,6 +756,18 @@ class ExtTestCase(unittest.TestCase):
|
|
|
756
756
|
"Adds a todo printed when all test are run."
|
|
757
757
|
cls._todos.append((f, msg))
|
|
758
758
|
|
|
759
|
+
@classmethod
|
|
760
|
+
def ort(cls):
|
|
761
|
+
import onnxruntime
|
|
762
|
+
|
|
763
|
+
return onnxruntime
|
|
764
|
+
|
|
765
|
+
@classmethod
|
|
766
|
+
def to_onnx(self, *args, **kwargs):
|
|
767
|
+
from experimental_experiment.torch_interpreter import to_onnx
|
|
768
|
+
|
|
769
|
+
return to_onnx(*args, **kwargs)
|
|
770
|
+
|
|
759
771
|
def print_model(self, model: "ModelProto"): # noqa: F821
|
|
760
772
|
"Prints a ModelProto"
|
|
761
773
|
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
|
|
@@ -917,6 +929,15 @@ class ExtTestCase(unittest.TestCase):
|
|
|
917
929
|
]
|
|
918
930
|
raise AssertionError("\n".join(rows)) # noqa: B904
|
|
919
931
|
|
|
932
|
+
def assertEqualDataFrame(self, d1, d2, **kwargs):
|
|
933
|
+
"""
|
|
934
|
+
Checks that two dataframes are equal.
|
|
935
|
+
Calls :func:`pandas.testing.assert_frame_equal`.
|
|
936
|
+
"""
|
|
937
|
+
from pandas.testing import assert_frame_equal
|
|
938
|
+
|
|
939
|
+
assert_frame_equal(d1, d2, **kwargs)
|
|
940
|
+
|
|
920
941
|
def assertEqualTrue(self, value: Any, msg: str = ""):
|
|
921
942
|
if value is True:
|
|
922
943
|
return
|
|
@@ -967,6 +988,16 @@ class ExtTestCase(unittest.TestCase):
|
|
|
967
988
|
atol=atol,
|
|
968
989
|
rtol=rtol,
|
|
969
990
|
)
|
|
991
|
+
elif expected.__class__.__name__ == "StaticCache":
|
|
992
|
+
self.assertEqual(type(expected), type(value), msg=msg)
|
|
993
|
+
self.assertEqual(expected.max_cache_len, value.max_cache_len)
|
|
994
|
+
atts = ["key_cache", "value_cache"]
|
|
995
|
+
self.assertEqualAny(
|
|
996
|
+
{k: expected.__dict__.get(k, None) for k in atts},
|
|
997
|
+
{k: value.__dict__.get(k, None) for k in atts},
|
|
998
|
+
atol=atol,
|
|
999
|
+
rtol=rtol,
|
|
1000
|
+
)
|
|
970
1001
|
elif expected.__class__.__name__ == "EncoderDecoderCache":
|
|
971
1002
|
self.assertEqual(type(expected), type(value), msg=msg)
|
|
972
1003
|
atts = ["self_attention_cache", "cross_attention_cache"]
|
|
onnx_diagnostic/helpers/cache_helper.py
CHANGED
|
@@ -154,10 +154,12 @@ else:
|
|
|
154
154
|
|
|
155
155
|
def make_static_cache(
|
|
156
156
|
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
|
|
157
|
+
max_cache_len: Optional[int] = None,
|
|
157
158
|
) -> transformers.cache_utils.DynamicCache:
|
|
158
159
|
"""
|
|
159
160
|
Creates an instance of :class:`transformers.cache_utils.StaticCache`.
|
|
160
161
|
:param key_value_pairs: list of pairs of (key, values)
|
|
162
|
+
:param max_cache_len: max_cache_length or something inferred from the vector
|
|
161
163
|
:return: :class:`transformers.cache_utils.StaticCache`
|
|
162
164
|
|
|
163
165
|
Example:
|
|
@@ -179,7 +181,8 @@ def make_static_cache(
|
|
|
179
181
|
torch.randn(bsize, nheads, slen, dim),
|
|
180
182
|
)
|
|
181
183
|
for i in range(n_layers)
|
|
182
|
-
]
|
|
184
|
+
],
|
|
185
|
+
max_cache_len=10,
|
|
183
186
|
)
|
|
184
187
|
print(string_type(past_key_values, with_shape=True))
|
|
185
188
|
"""
|
|
@@ -190,24 +193,32 @@ def make_static_cache(
|
|
|
190
193
|
self.num_attention_heads = key_value_pairs[0][0].shape[1]
|
|
191
194
|
self.num_hidden_layers = len(key_value_pairs)
|
|
192
195
|
|
|
196
|
+
assert max_cache_len is not None, (
|
|
197
|
+
f"max_cache_len={max_cache_len} cannot be setup "
|
|
198
|
+
f"automatically yet from shape {key_value_pairs[0][0].shape}"
|
|
199
|
+
)
|
|
200
|
+
torch._check(
|
|
201
|
+
max_cache_len >= key_value_pairs[0][0].shape[2],
|
|
202
|
+
(
|
|
203
|
+
f"max_cache_len={max_cache_len} cannot be smaller "
|
|
204
|
+
f"shape[2]={key_value_pairs[0][0].shape[2]} in shape "
|
|
205
|
+
f"{key_value_pairs[0][0].shape}"
|
|
206
|
+
),
|
|
207
|
+
)
|
|
193
208
|
cache = transformers.cache_utils.StaticCache(
|
|
194
209
|
_config(),
|
|
195
210
|
max_batch_size=key_value_pairs[0][0].shape[0],
|
|
196
211
|
device=key_value_pairs[0][0].device,
|
|
197
212
|
dtype=key_value_pairs[0][0].dtype,
|
|
198
|
-
max_cache_len=
|
|
213
|
+
max_cache_len=max_cache_len,
|
|
199
214
|
)
|
|
200
215
|
for i in range(len(key_value_pairs)):
|
|
201
|
-
assert
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
cache.key_cache[i][:, :,
|
|
206
|
-
|
|
207
|
-
f"Shape mismatch, expected {cache.value_cache[i].shape}, "
|
|
208
|
-
f"got {key_value_pairs[i][1].shape}"
|
|
209
|
-
)
|
|
210
|
-
cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
|
|
216
|
+
assert (
|
|
217
|
+
key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
|
|
218
|
+
), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
|
|
219
|
+
d = key_value_pairs[i][1].shape[2]
|
|
220
|
+
cache.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
|
|
221
|
+
cache.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
|
|
211
222
|
return cache
|
|
212
223
|
|
|
213
224
|
|
|
onnx_diagnostic/helpers/config_helper.py
CHANGED
|
@@ -43,7 +43,10 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
|
|
|
43
43
|
else:
|
|
44
44
|
update_config(getattr(config, k), v)
|
|
45
45
|
continue
|
|
46
|
-
|
|
46
|
+
if type(config) is dict:
|
|
47
|
+
config[k] = v
|
|
48
|
+
else:
|
|
49
|
+
setattr(config, k, v)
|
|
47
50
|
|
|
48
51
|
|
|
49
52
|
def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
|
|
@@ -66,6 +69,18 @@ def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
|
|
|
66
69
|
raise AssertionError(f"Unable to find any of these {atts!r} in {config}")
|
|
67
70
|
|
|
68
71
|
|
|
72
|
+
def pick(config, name: str, default_value: Any) -> Any:
|
|
73
|
+
"""
|
|
74
|
+
Returns the value of a attribute if config has it
|
|
75
|
+
otherwise the default value.
|
|
76
|
+
"""
|
|
77
|
+
if not config:
|
|
78
|
+
return default_value
|
|
79
|
+
if type(config) is dict:
|
|
80
|
+
return config.get(name, default_value)
|
|
81
|
+
return getattr(config, name, default_value)
|
|
82
|
+
|
|
83
|
+
|
|
69
84
|
@functools.cache
|
|
70
85
|
def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[type]:
|
|
71
86
|
"""
|