onnx-diagnostic 0.7.11__py3-none-any.whl → 0.7.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +5 -2
- onnx_diagnostic/export/dynamic_shapes.py +11 -2
- onnx_diagnostic/helpers/helper.py +11 -5
- onnx_diagnostic/helpers/log_helper.py +65 -12
- onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
- onnx_diagnostic/helpers/model_builder_helper.py +1 -0
- onnx_diagnostic/helpers/rt_helper.py +55 -37
- onnx_diagnostic/helpers/torch_helper.py +31 -7
- onnx_diagnostic/reference/torch_evaluator.py +2 -2
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/image_text_to_text.py +256 -141
- onnx_diagnostic/tasks/text_generation.py +15 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +177 -150
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +19 -1
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +40 -14
- onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +116 -10
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
- onnx_diagnostic/torch_models/hghub/hub_api.py +4 -10
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +32 -4
- onnx_diagnostic/torch_models/validate.py +337 -113
- onnx_diagnostic/torch_onnx/sbs.py +2 -1
- {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/METADATA +11 -31
- {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/RECORD +30 -28
- {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py CHANGED

onnx_diagnostic/_command_lines_parser.py CHANGED

@@ -581,6 +581,7 @@ def _cmd_validate(argv: List[Any]):
     ):
         print(f"validate - unsupported args: export={args.export!r}, opt={args.opt!r}")
         return
+    patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
     summary, _data = validate_model(
         model_id=args.mid,
         task=args.task,
@@ -591,8 +592,8 @@ def _cmd_validate(argv: List[Any]):
         use_pretrained=args.trained,
         dtype=args.dtype,
         device=args.device,
-        patch=args.patch,
-        rewrite=args.rewrite,
+        patch=patch_dict,
+        rewrite=args.rewrite and patch_dict.get("patch", True),
         stop_if_static=args.stop_if_static,
         optimization=args.opt,
         exporter=args.export,
@@ -827,6 +828,8 @@ def get_parser_agg() -> ArgumentParser:
         "n_model_running,n_model_acc01,n_model_acc001,n_model_dynamic,"
         "n_model_pass,n_model_faster,"
         "n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
+        "n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
+        "n_node_layer_normalization,n_node_layer_normalization23,"
         "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
         "n_node_constant,n_node_shape,n_node_expand,"
         "n_node_function,n_node_initializer,n_node_scatter,"
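The first two hunks normalize the `--patch` argument: a bare boolean is wrapped into a dictionary of patch options, and `rewrite` is forced off whenever patching itself is disabled. A minimal sketch of that normalization, outside argparse and with hypothetical values (not the package's API):

```python
def normalize_patch(patch, rewrite):
    # Mirrors the added lines: wrap a bare flag into a dict of patch options.
    patch_dict = patch if isinstance(patch, dict) else {"patch": patch}
    # Rewriting is pointless when patching is disabled altogether.
    return patch_dict, rewrite and patch_dict.get("patch", True)

print(normalize_patch(True, True))    # ({'patch': True}, True)
print(normalize_patch(False, True))   # ({'patch': False}, False)
print(normalize_patch({"patch_transformers": True}, True))
# ({'patch_transformers': True}, True)
```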
onnx_diagnostic/export/dynamic_shapes.py CHANGED

@@ -56,6 +56,14 @@ class CoupleInputsDynamicShapes:
         self.kwargs = kwargs
         self.dynamic_shapes = dynamic_shapes
         self.args_names = args_names
+        if not self.kwargs and isinstance(self.dynamic_shapes, dict):
+            # This assumes the dictionary for the dynamic shapes is ordered
+            # the same way the args are. The input names are not known.
+            assert len(self.dynamic_shapes) == len(self.args), (
+                f"Length mismatch, kwargs is empty, len(dynamic_shapes)="
+                f"{len(self.dynamic_shapes)}, len(args)={len(self.args)}"
+            )
+            self.dynamic_shapes = tuple(self.dynamic_shapes.values())

     def __str__(self) -> str:
         return "\n".join(
@@ -232,8 +240,9 @@ class CoupleInputsDynamicShapes:
         """
         if not self.args:
             assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), (
-                f"Type mismatch, args={string_type(self.args)}
-                f"
+                f"Type mismatch, args={string_type(self.args)}, "
+                f"kwargs={string_type(self.kwargs)} and dynamic_shapes="
+                f"{string_type(self.dynamic_shapes)} should have the same type."
             )
         res = self._generic_walker_step(
             processor,
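The constructor change leans on Python's dict insertion order: when only positional args are present and the dynamic shapes arrive as a dict, the dict values are assumed to be listed in the same order as the args. A small illustration of that assumption (hypothetical input names and axes):

```python
import torch

args = (torch.zeros(2, 8), torch.ones(2, 8))  # positional inputs, names unknown
dynamic_shapes = {  # insertion order must match the args
    "input_ids": {0: "batch", 1: "seq"},
    "attention_mask": {0: "batch", 1: "seq"},
}

assert len(dynamic_shapes) == len(args)  # the new constructor check
shapes_as_tuple = tuple(dynamic_shapes.values())  # names dropped, order kept
print(shapes_as_tuple)
# ({0: 'batch', 1: 'seq'}, {0: 'batch', 1: 'seq'})
```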
onnx_diagnostic/helpers/helper.py CHANGED

@@ -397,7 +397,7 @@ def string_type(
         return "AUTO"
     if verbose:
         print(f"[string_type] Y7:{type(obj)}")
-    return str(obj)
+    return str(obj).replace("DimHint(DYNAMIC)", "DYNAMIC").replace("DimHint(AUTO)", "AUTO")

     if isinstance(obj, bool):
         if with_min_max:
@@ -516,8 +516,10 @@ def string_type(
             print(f"[string_type] V2:{type(obj)}")
         return "OV(NOTENSOR)"
     if with_min_max:
+        from .torch_helper import to_numpy
+
         try:
-            t = obj
+            t = to_numpy(obj)
         except Exception:
             # pass unable to convert into numpy (bfloat16, ...)
             if verbose:
@@ -939,7 +941,7 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
             return flatten_object(list(x.values()), drop_keys=drop_keys)
         return flatten_object(list(x.items()), drop_keys=drop_keys)

-    if x.__class__.__name__ in {"DynamicCache", "StaticCache"}:
+    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
         from .cache_helper import CacheKeyValue

         kc = CacheKeyValue(x)
@@ -1233,9 +1235,13 @@ def max_diff(

     if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray):
         if isinstance(expected, torch.Tensor):
-
+            from .torch_helper import to_numpy
+
+            expected = to_numpy(expected)
         if isinstance(got, torch.Tensor):
-
+            from .torch_helper import to_numpy
+
+            got = to_numpy(got)
         if verbose >= 6:
             print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
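These call sites now route tensors through `to_numpy` (patched in `torch_helper.py` below) instead of relying on a bare `.numpy()`, which fails for tensors that require grad, live on an accelerator, or use a dtype numpy does not know. A short demonstration of the failure mode the change avoids (only `torch` required):

```python
import torch

t = torch.ones(2, 2, requires_grad=True)
try:
    t.numpy()  # raises: can't call numpy() on a tensor that requires grad
except RuntimeError as exc:
    print("plain .numpy() failed:", exc)

# The detach().cpu() route taken by to_numpy succeeds:
print(t.detach().cpu().numpy())
```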
onnx_diagnostic/helpers/log_helper.py CHANGED

@@ -285,7 +285,8 @@ class CubePlot:
         nn = df.shape[1] // n_cols
         nn += int(df.shape[1] % n_cols != 0)
         ratio = float(os.environ.get("FIGSIZEH", "1"))
-
+        figsize = (6 * n_cols, nn * (2.5 + df.shape[0] / 15) * ratio)
+        fig, axs = plt.subplots(nn, n_cols, figsize=figsize)
         pos = 0
         imgs = []
         for c in self._make_loop(df.columns, verbose):
@@ -332,10 +333,12 @@ class CubePlot:
         n_cols = len(groups)

         title_suffix = f"\n{title_suffix}" if title_suffix else ""
+        ratio = float(os.environ.get("FIGSIZEH", "1"))
+        figsize = (5 * n_cols, max(len(g) for g in groups) * (2 + df.shape[1] / 2) * ratio)
         fig, axs = plt.subplots(
             df.shape[1],
             n_cols,
-            figsize=
+            figsize=figsize,
             sharex=True,
             sharey="row" if n_cols > 1 else False,
         )
@@ -877,7 +880,11 @@ class CubeLogs:
             print(f"[CubeLogs.view] key_columns={key_columns}")
         g = data[[*key_index, *key_columns]].copy()
         g["count"] = 1
-        r =
+        r = (
+            g.copy()
+            if not key_index and not key_columns
+            else g.groupby([*key_index, *key_columns], dropna=False).sum()
+        )
         not_unique = r[r["count"] > 1]
         assert not_unique.shape[0] == 0, (
             f"view_def.name={view_def.name!r}, "
@@ -1505,6 +1512,11 @@ class CubeLogsPerformance(CubeLogs):
             "n_model_faster3x",
             "n_model_faster4x",
             "n_node_attention",
+            "n_node_attention23",
+            "n_node_rotary_embedding",
+            "n_node_rotary_embedding23",
+            "n_node_layer_normalization",
+            "n_node_layer_normalization23",
             "n_node_control_flow",
             "n_node_scatter",
             "n_node_function",
@@ -1568,7 +1580,9 @@ class CubeLogsPerformance(CubeLogs):

         def gdf(df, cname, default_value=np.nan):
             if cname in df.columns:
-                return df[cname]
+                if np.isnan(default_value):
+                    return df[cname]
+                return df[cname].fillna(default_value)
             return pandas.Series(default_value, index=df.index)

         def ghas_value(df, cname):
@@ -1676,15 +1690,54 @@ class CubeLogsPerformance(CubeLogs):
                 "time_latency",
                 gdf(df, "time_latency_eager") > gdf(df, "time_latency", np.inf) * 3.98,
             ),
+            n_node_attention23=lambda df: gpreserve(
+                df, "time_latency_eager", gdf(df, "op_onnx__Attention")
+            ),
+            n_node_rotary_embedding23=lambda df: gpreserve(
+                df, "time_latency_eager", gdf(df, "op_onnx__RotaryEmbedding")
+            ),
+            n_node_layer_normalization23=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx__LayerNormalization", 0)
+                + gdf(df, "op_onnx__RMSNormalization", 0)
+                + gdf(df, "op_onnx__BatchNormlization", 0)
+                + gdf(df, "op_onnx__InstanceNormlization", 0)
+                + gdf(df, "op_onnx__GroupNormalization", 0),
+            ),
             n_node_attention=lambda df: gpreserve(
                 df,
-                "
-                gdf(df, "op_onnx_com.microsoft_Attention")
-                + gdf(df, "op_onnx_com.microsoft_MultiHeadAttention")
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_Attention", 0)
+                + gdf(df, "op_onnx_com.microsoft_MultiHeadAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_PackedAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_PackedMultiHeadAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_GroupQueryAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_PagedAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_DecoderAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_LongformerAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_DecoderMaskedSelfAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_DecoderMaskedMultiHeadAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_SparseAttention", 0),
+            ),
+            n_node_layer_normalization=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_EmbedLayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_SkipLayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_LayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_SkipSimplifiedLayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_SimplifiedLayerNormalization", 0),
+            ),
+            n_node_rotary_embedding=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_GemmaRotaryEmbedding", 0)
+                + gdf(df, "op_onnx_com.microsoft_RotaryEmbedding", 0),
             ),
             n_node_control_flow=lambda df: gpreserve(
                 df,
-                "
+                "time_latency_eager",
                 (
                     gdf(df, "op_onnx__If", 0)
                     + gdf(df, "op_onnx__Scan", 0)
@@ -1693,7 +1746,7 @@ class CubeLogsPerformance(CubeLogs):
             ),
             n_node_scatter=lambda df: gpreserve(
                 df,
-                "
+                "time_latency_eager",
                 gdf(df, "op_onnx__ScatterND", 0) + gdf(df, "op_onnx__ScatterElements", 0),
             ),
             n_node_function=lambda df: gpreserve(
@@ -1706,13 +1759,13 @@ class CubeLogsPerformance(CubeLogs):
                 df, "onnx_n_initializer", gdf(df, "onnx_n_initializer")
             ),
             n_node_constant=lambda df: gpreserve(
-                df, "
+                df, "time_latency_eager", gdf(df, "op_onnx__Constant")
             ),
             n_node_shape=lambda df: gpreserve(
-                df, "
+                df, "time_latency_eager", gdf(df, "op_onnx__Shape")
             ),
             n_node_expand=lambda df: gpreserve(
-                df, "
+                df, "time_latency_eager", gdf(df, "op_onnx__Expand")
             ),
         )
         assert (
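The reworked `gdf` is what makes the operator-count columns above robust: a NaN `default_value` means "return the column as-is", while an explicit default both stands in for a missing column and fills NaN cells, so the long sums over `op_onnx_*` columns never propagate NaN. A standalone sketch of the same contract (hypothetical frame and column names):

```python
import numpy as np
import pandas as pd

def gdf(df, cname, default_value=np.nan):
    # NaN default: hand the column back untouched; otherwise fill the holes
    # and substitute for a column that is absent entirely.
    if cname in df.columns:
        if np.isnan(default_value):
            return df[cname]
        return df[cname].fillna(default_value)
    return pd.Series(default_value, index=df.index)

df = pd.DataFrame({"op_onnx__If": [1.0, np.nan]})
total = gdf(df, "op_onnx__If", 0) + gdf(df, "op_onnx__Scan", 0)
print(total.tolist())  # [1.0, 0.0]: no NaN leaks into the sum
```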
onnx_diagnostic/helpers/mini_onnx_builder.py CHANGED

@@ -381,6 +381,23 @@ def _flatten_iterator(obj: Any, sep: str) -> Iterator:
         else:
             for p, o in _flatten_iterator(getattr(obj, att), sep):
                 yield f"DynamicCache_{att}{sep}{p}", o
+    elif obj.__class__.__name__ == "StaticCache":
+        # transformers
+        import transformers
+        from .cache_helper import CacheKeyValue
+
+        assert isinstance(
+            obj, transformers.cache_utils.StaticCache
+        ), f"Unexpected type {type(obj)}"
+        obj = CacheKeyValue(obj)
+        atts = ["key_cache", "value_cache"]
+        for i, att in enumerate(atts):
+            if i == len(atts) - 1:
+                for p, o in _flatten_iterator(getattr(obj, att), sep):
+                    yield f"StaticCache._{att}{sep}{p}", o
+            else:
+                for p, o in _flatten_iterator(getattr(obj, att), sep):
+                    yield f"StaticCache_{att}{sep}{p}", o
     else:
         raise NotImplementedError(f"Unexpected type {type(obj)}")
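The `StaticCache` branch mirrors the existing `DynamicCache` one: each cache attribute is flattened into named entries, and the last attribute gets an extra underscore (`StaticCache._value_cache...`), presumably so the unflattening side can tell where the cache object ends. A toy reproduction of just the naming scheme, using plain lists instead of real caches (the inner recursion is simplified away):

```python
def flatten_static_cache(key_cache, value_cache, sep="."):
    # Same prefixes as the new branch in _flatten_iterator; the final
    # attribute is marked with "._" instead of "_".
    atts = [("key_cache", key_cache), ("value_cache", value_cache)]
    for i, (att, tensors) in enumerate(atts):
        marker = "._" if i == len(atts) - 1 else "_"
        for j, t in enumerate(tensors):
            yield f"StaticCache{marker}{att}{sep}{j}", t

for name, _ in flatten_static_cache(["k0", "k1"], ["v0", "v1"]):
    print(name)
# StaticCache_key_cache.0
# StaticCache_key_cache.1
# StaticCache._value_cache.0
# StaticCache._value_cache.1
```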
onnx_diagnostic/helpers/model_builder_helper.py CHANGED

@@ -203,6 +203,7 @@ def create_model_builder(
     "ChatGLMModel": builder.ChatGLMModel,
     "Ernie4_5_ForCausalLM": builder.ErnieModel,
     "GemmaForCausalLM": builder.Gemma2Model,
+    "Gemma2ForCausalLM": builder.Gemma2Model,
     "Gemma3ForCausalLM": builder.Gemma3Model,
     "Gemma3ForConditionalGeneration": builder.Gemma3Model,
     "GraniteForCausalLM": builder.GraniteModel,
onnx_diagnostic/helpers/rt_helper.py CHANGED

@@ -3,7 +3,7 @@ import numpy as np
 import onnx
 import torch
 from .helper import string_type, flatten_object
-from .
+from .torch_helper import to_numpy
 from .cache_helper import is_cache_dynamic_registered


@@ -23,6 +23,7 @@ def make_feeds(
     use_numpy: bool = False,
     copy: bool = False,
     check_flatten: bool = True,
+    is_modelbuilder: bool = False,
 ) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
     """
     Serializes the inputs to produce feeds expected
@@ -35,10 +36,15 @@ def make_feeds(
         by ``OrtValue``
     :param check_flatten: if True, checks the ``torch.utils._pytree.tree_flatten``
         returns the same number of outputs
+    :param is_modelbuilder: if True, the exporter is ModelBuilder, and we need to reorder
+        the past_key_values inputs to match the expected order, and get rid of position_ids.
     :return: feeds dictionary
     """
-    # position_ids is a special case because ModelBuilder does not usually use it
-    #
+    # NOTE: position_ids is a special case because ModelBuilder does not usually use it,
+    # because it's fued into rotary embedding in GQA.
+    if is_modelbuilder and isinstance(inputs, dict):
+        inputs.pop("position_ids", None)  # Ensure 'position_ids' absent before removing.
+
     flat = flatten_object(inputs, drop_keys=True)
     assert (
         not check_flatten
@@ -51,7 +57,7 @@ def make_feeds(
         f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
     )
     if use_numpy:
-        flat = [t
+        flat = [to_numpy(t) if isinstance(t, torch.Tensor) else t for t in flat]
     names = (
         [i.name for i in proto.graph.input]
         if isinstance(proto, onnx.ModelProto)
@@ -76,39 +82,6 @@ def make_feeds(
         f"\n-- inputs={string_type(inputs, with_shape=True)}"
         f"\n-- names={names}"
     )
-    if len(names) < len(flat) and (
-        isinstance(proto, onnx.ModelProto) or hasattr(proto, "get_inputs")
-    ):
-
-        typed_names = (
-            [(i.name, i.type.tensor_type.elem_type) for i in proto.graph.input]
-            if isinstance(proto, onnx.ModelProto)
-            else [(i.name, name_type_to_onnx_dtype(i.type)) for i in proto.get_inputs()]
-        )
-
-        new_flat = []
-        pos = 0
-        for _name, dtype in typed_names:
-            assert isinstance(
-                dtype, int
-            ), f"Unexpected value for dtype={dtype!r}, type(proto)={type(proto)}"
-            itype = dtype_to_tensor_dtype(flat[pos].dtype)
-            while dtype != itype:
-                pos += 1
-                if pos >= len(flat):
-                    break
-                itype = dtype_to_tensor_dtype(flat[pos].dtype)
-            if pos >= len(flat):
-                break
-            new_flat.append(flat[pos])
-            pos += 1
-        assert len(new_flat) == len(names), (
-            f"Unable to align expected input {names} with the given input, "
-            f"type(proto)={type(proto)}"
-            f"\n-- inputs: {string_type(inputs, with_shape=True)}"
-            f"\n-- typed_names: {typed_names}"
-        )
-        flat = new_flat

     if copy:
         flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
@@ -122,4 +95,49 @@ def make_feeds(
         elif isinstance(i, float):
             i = np.array(i, dtype=np.float32)
         new_flat.append(i)
+
+    # NOTE: model builder has a different order for past_key_values
+    # we need to reorder them to match the expected order
+    if is_modelbuilder:
+        # We assume that if "past_key_values" is in the names when it's
+        # modelbuilder
+        non_past_kv_input_names = [n for n in names if "past_key_values" not in n]
+        past_kv_names = [n for n in names if "past_key_values" in n]
+        reorder_past_kv_names = reorder_modelbuilder_cache_to_torch(past_kv_names)
+        names = non_past_kv_input_names + reorder_past_kv_names
     return dict(zip(names, new_flat))
+
+
+def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]:
+    """
+    Reorders the past_kvs for ModelBuilder to match the expected order
+    by PyTorch exported models.
+
+    .. note::
+        This function can take either the names or the actual tensors
+        as long as they are in a list.
+
+    Conceptually,
+
+    From::
+
+        [past_key_values.0.key, past_key_values.0.value,
+         past_key_values.1.key, past_key_values.1.value, ...]
+
+    To::
+
+        [past_key_values.0.key, past_key_values.1.key,
+         ..., past_key_values.0.value, past_key_values.1.value, ...]
+
+    :param past_kv: list of flattened inputs
+    :return: reordered list of flattened inputs
+    """
+    total_len = len(past_kv)
+    if total_len % 2 != 0:
+        raise ValueError("The length of past_key_values should be even.")
+    keys = []
+    values = []
+    for i in range(0, total_len, 2):
+        keys.append(past_kv[i])
+        values.append(past_kv[i + 1])
+    return keys + values
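Because `reorder_modelbuilder_cache_to_torch` only cares about list order, it can be sanity-checked on input names alone; a quick usage example against the helper added above:

```python
from onnx_diagnostic.helpers.rt_helper import reorder_modelbuilder_cache_to_torch

names = [
    "past_key_values.0.key", "past_key_values.0.value",
    "past_key_values.1.key", "past_key_values.1.value",
]
print(reorder_modelbuilder_cache_to_torch(names))
# ['past_key_values.0.key', 'past_key_values.1.key',
#  'past_key_values.0.value', 'past_key_values.1.value']
```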
onnx_diagnostic/helpers/torch_helper.py CHANGED

@@ -5,7 +5,7 @@ import os
 import sys
 import warnings
 from collections.abc import Iterable
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 import numpy as np
 import onnx
 from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data
@@ -283,9 +283,11 @@ def steal_forward(
     ],
     fprint: Callable = string_type,
     dump_file: Optional[str] = None,
+    dump_drop: Optional[Set[str]] = None,
     submodules: bool = False,
     verbose: int = 0,
     storage_limit: int = 2**27,
+    save_as_external_data: bool = True,
     **kwargs,
 ):
     """
@@ -303,6 +305,9 @@ def steal_forward(
     :param dump_file: dumps stolen inputs and outputs in an onnx model,
         they can be restored with :func:`create_input_tensors_from_onnx_model
         <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
+    :param dump_drop: to drop some inputs too big (only if dump_file is specified)
+    :param save_as_external_data: True by default, but maybe better to have everything
+        in a single file if possible
     :param submodules: if True and model is a module, the list extended with all the submodules
         the module contains
     :param verbose: verbosity
@@ -411,6 +416,15 @@ def steal_forward(
         if verbose:
             size = torch_tensor_size(storage)
             print(f"-- gather stored {len(storage)} objects, size={size // 2 ** 20} Mb")
+        if dump_drop:
+            for k, v in storage.items():
+                if k[-1] == "I":
+                    _args, kwargs = v
+                    ii = set(kwargs) & dump_drop
+                    if ii:
+                        for i in ii:
+                            print("---", i)
+                            del kwargs[i]
         proto = create_onnx_model_from_input_tensors(storage)
         if verbose:
             print("-- dumps stored objects")
@@ -420,7 +434,7 @@ def steal_forward(
         onnx.save(
             proto,
             dump_file,
-            save_as_external_data=
+            save_as_external_data=save_as_external_data,
             all_tensors_to_one_file=True,
             location=location,
         )
@@ -464,10 +478,10 @@ def is_torchdynamo_exporting() -> bool:
     return False


-def to_numpy(tensor: "torch.Tensor"):  # noqa: F821
+def to_numpy(tensor: "torch.Tensor") -> np.ndarray:  # noqa: F821
     """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
     try:
-        return tensor.numpy()
+        return tensor.detach().cpu().numpy()
     except TypeError:
         # We try with ml_dtypes
         pass
@@ -476,7 +490,7 @@ def to_numpy(tensor: "torch.Tensor"):  # noqa: F821

     conv = {torch.bfloat16: ml_dtypes.bfloat16}
     assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
-    return tensor.to(torch.float32).numpy().astype(conv[tensor.dtype])
+    return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])


 def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
@@ -765,7 +779,12 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:


 def torch_deepcopy(value: Any) -> Any:
-    """
+    """
+    Makes a deep copy.
+
+    :param value: any value
+    :return: a deep copy
+    """
     if value is None:
         return None
     if isinstance(value, (int, float, str)):
@@ -794,9 +813,14 @@ def torch_deepcopy(value: Any) -> Any:
     from .cache_helper import CacheKeyValue

     ca = CacheKeyValue(value)
+    if len(ca.key_cache) == 0:
+        # Use of deepcopy.
+        import copy
+
+        return copy.deepcopy(value)
     return make_static_cache(
         torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))),
-        max_cache_len=value.max_cache_len,
+        max_cache_len=max([value.max_cache_len, *[t.shape[2] for t in ca.key_cache]]),
     )
     if value.__class__.__name__ == "HybridCache":
         from .cache_helper import CacheKeyValue
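`to_numpy` now detaches and moves tensors to CPU before converting, and keeps its `ml_dtypes` fallback for `bfloat16`, a dtype numpy has no native support for. A standalone rendering of the patched function, runnable as-is (requires `torch` and `ml_dtypes`):

```python
import ml_dtypes
import numpy as np
import torch

def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    try:
        # Handles grad-tracking and non-CPU tensors for every numpy dtype.
        return tensor.detach().cpu().numpy()
    except TypeError:
        pass  # bfloat16 falls through to the ml_dtypes path below
    conv = {torch.bfloat16: ml_dtypes.bfloat16}
    assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
    # Round-trip through float32, then reinterpret with the ml_dtypes dtype.
    return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])

print(to_numpy(torch.tensor([1.5, 2.25], dtype=torch.bfloat16)).dtype)  # bfloat16
```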
onnx_diagnostic/reference/torch_evaluator.py CHANGED

@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
 import numpy as np
 import onnx
 import torch
-from ..helpers.torch_helper import to_tensor
+from ..helpers.torch_helper import to_tensor, to_numpy
 from ..torch_onnx.runtime_info import first_used_last_used, RuntimeValue
 from .report_results_comparison import ReportResultComparison
 from . import torch_ops
@@ -578,7 +578,7 @@ class TorchOnnxEvaluator:
             print(f"- clean {o}")

         if use_numpy:
-            return [None if a is None else a
+            return [None if a is None else to_numpy(a) for a in fres]
         return fres

     def run_with_values(
onnx_diagnostic/tasks/data/__init__.py ADDED

@@ -0,0 +1,13 @@
+import os
+
+
+def get_data(name: str):
+    """Returns data stored in this folder."""
+    filename = os.path.join(os.path.dirname(__file__), name)
+    assert os.path.exists(
+        filename
+    ), f"Unable to find a file with {name!r}, looked for {filename!r}"
+
+    from ...helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
+
+    return create_input_tensors_from_onnx_model(filename)
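Together with the binary file below, the new module gives tasks a way to ship canned dummy inputs inside the wheel. A usage sketch (the module path and file name come from the file list above; the shape of the returned object depends on what was serialized into the ONNX container):

```python
from onnx_diagnostic.tasks.data import get_data

# Restores whatever tensors were packed into the bundled ONNX container.
dummies = get_data("dummies_imagetext2text_generation_gemma3.onnx")
print(type(dummies))
```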
onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx ADDED
Binary file