onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +78 -22
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +2 -1
- onnx_diagnostic/export/shape_helper.py +47 -70
- onnx_diagnostic/ext_test_case.py +11 -0
- onnx_diagnostic/helpers/cache_helper.py +38 -7
- onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
- onnx_diagnostic/helpers/helper.py +27 -33
- onnx_diagnostic/helpers/log_helper.py +109 -5
- onnx_diagnostic/helpers/memory_peak.py +2 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
- onnx_diagnostic/helpers/model_builder_helper.py +132 -2
- onnx_diagnostic/helpers/onnx_helper.py +1 -1
- onnx_diagnostic/helpers/ort_session.py +4 -0
- onnx_diagnostic/helpers/rt_helper.py +393 -43
- onnx_diagnostic/helpers/torch_helper.py +20 -1
- onnx_diagnostic/tasks/__init__.py +7 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
- onnx_diagnostic/tasks/feature_extraction.py +2 -8
- onnx_diagnostic/tasks/image_text_to_text.py +10 -8
- onnx_diagnostic/tasks/summarization.py +2 -8
- onnx_diagnostic/tasks/text2text_generation.py +3 -8
- onnx_diagnostic/tasks/text_generation.py +86 -65
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
- onnx_diagnostic/torch_models/validate.py +626 -228
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
**onnx_diagnostic/helpers/log_helper.py**

```diff
@@ -42,6 +42,8 @@ class CubeViewDef:
     :param name: name of the view, used mostly to debug
     :param plots: adds plot to the Excel sheet
     :param no_index: remove the index (but keeps the columns)
+    :param fix_aggregation_change: a column among the keys which changes aggregation value
+        for different dates

     Some examples of views. First example is an aggregated view
     for many metrics.
@@ -106,6 +108,7 @@ class CubeViewDef:
         name: Optional[str] = None,
         no_index: bool = False,
         plots: bool = False,
+        fix_aggregation_change: Optional[List["str"]] = None,
     ):
         self.key_index = key_index
         self.values = values
@@ -123,6 +126,7 @@ class CubeViewDef:
         self.name = name
         self.no_index = no_index
         self.plots = plots
+        self.fix_aggregation_change = fix_aggregation_change

     def __repr__(self) -> str:
         "usual"
```
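The new `fix_aggregation_change` option lets a view name key columns whose values changed meaning over time. As a quick illustration, a view might opt in like this (a minimal sketch; the `key_index`/`values` argument names are inferred from the assignments visible above, and the column names are placeholders):

```python
from onnx_diagnostic.helpers.log_helper import CubeViewDef

view = CubeViewDef(
    key_index=["model_name"],            # placeholder key columns
    values=["time_latency", "speedup"],  # placeholder metric columns
    name="agg-latency",
    # canonicalize this key before aggregating across dates
    fix_aggregation_change=["model_test_with"],
)
```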
```diff
@@ -750,6 +754,17 @@ class CubeLogs:
             f"values={sorted(self.values)}"
         )

+        if view_def.fix_aggregation_change and (
+            set(view_def.fix_aggregation_change) & set(self.keys_no_time)
+        ):
+            # before aggregation, let's fix some keys whose values changed over time
+            data_to_process = self._fix_aggregation_change(
+                self.data,
+                list(set(view_def.fix_aggregation_change) & set(self.keys_no_time)),
+            )
+        else:
+            data_to_process = self.data
+
         # aggregation
         if key_agg:
             final_stack = True
```
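`CubeLogs.view` now routes everything through `data_to_process`: when the view names columns to fix and at least one of them is a non-time key, the data is first rewritten by the new `_fix_aggregation_change` helper (added further down in this diff); otherwise it falls back to `self.data` unchanged. The next two hunks switch the aggregation reads over to `data_to_process`.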
```diff
@@ -763,7 +778,7 @@ class CubeLogs:
                 print(f"[CubeLogs.view] aggregation of {set_key_agg}")
                 print(f"[CubeLogs.view] groupby {keys_no_agg}")

-            data_red = self.data[[*keys_no_agg, *values]]
+            data_red = data_to_process[[*keys_no_agg, *values]]
             assert set(key_index) <= set(data_red.columns), (
                 f"view_def.name={view_def.name!r}, "
                 f"nnable to find {set(key_index) - set(data_red.columns)}, "
@@ -792,7 +807,7 @@ class CubeLogs:
             key_index = self._filter_column(view_def.key_index, self.keys_time)
             if verbose:
                 print(f"[CubeLogs.view] no aggregation, index={key_index}")
-            data = self.data[[*self.keys_time, *values]]
+            data = data_to_process[[*self.keys_time, *values]]
             set_all_keys = set(self.keys_time)
             final_stack = False

@@ -829,7 +844,7 @@ class CubeLogs:
         key_columns = sorted(set_key_columns)
         unique = set()

-
+        # md = lambda s: {k: v for k, v in self.values_for_key.items() if k in s} # noqa: E731
         all_cols = set(key_columns) | set(key_index) | set(key_agg) | unique
         assert all_cols == set(self.keys_time), (
             f"view_def.name={view_def.name!r}, "
@@ -892,7 +907,7 @@ class CubeLogs:
             f"key={sorted(key_columns)}, key_agg={key_agg}, values={sorted(values)}, "
             f"columns={sorted(data.columns)}, ignored={view_def.ignore_columns}, "
             f"not unique={set(data.columns) - unique}"
-            f"\n--\n{not_unique.head()}"
+            f"\n--\n{not_unique.head(10)}"
         )

         # pivot
```
```diff
@@ -961,6 +976,70 @@ class CubeLogs:
             print(f"[CubeLogs.view] -- done view {view_def.name!r}")
         return (piv, view_def) if return_view_def else piv

+    def _fix_aggregation_change(
+        self,
+        data: pandas.DataFrame,
+        columns_to_fix: Union[str, List[str]],
+        overwrite_or_merge: bool = True,
+    ) -> pandas.DataFrame:
+        """
+        Fixes columns used to aggregate values because their meaning changed over time.
+
+        :param data: data to fix
+        :param columns_to_fix: list of columns to fix
+        :param overwrite_or_merge: if True, overwrite all values by the concatenation
+            of all existing values, if merge, merges existing values found
+            and grouped by the other keys
+        :return: fixed data
+        """
+        if not isinstance(columns_to_fix, str):
+            for c in columns_to_fix:
+                data = self._fix_aggregation_change(data, c)
+            return data
+        # Let's process one column.
+        keys = set(self.keys_time) - {columns_to_fix}
+        select = data[self.keys_time]
+        select_agg = select.groupby(list(keys)).count()
+        assert select_agg[columns_to_fix].max() <= 1, (
+            f"Column {columns_to_fix!r} has two distinct values at least for one date\n"
+            f"{select_agg[select_agg[columns_to_fix] > 1]}"
+        )
+
+        # unique value (to fill NaN)
+        unique = "-".join(sorted(set(data[columns_to_fix].dropna())))
+
+        keys = set(self.keys_no_time) - {columns_to_fix}
+        select = data[self.keys_no_time]
+        select_agg = select.groupby(list(keys), as_index=True).apply(
+            lambda x: "-".join(sorted(set(x[columns_to_fix].dropna()))), include_groups=False
+        )
+        select_agg = select_agg.to_frame(name=columns_to_fix)
+        res = pandas.merge(
+            data.drop([columns_to_fix], axis=1),
+            select_agg,
+            how="left",
+            left_on=list(keys),
+            right_index=True,
+        )
+        val = f"?{unique}?"
+        res[columns_to_fix] = res[columns_to_fix].fillna(val).replace("", val)
+        assert (
+            data.shape == res.shape
+            and sorted(data.columns) == sorted(res.columns)
+            and sorted(data.index) == sorted(res.index)
+        ), (
+            f"Shape should match, data.shape={data.shape}, res.shape={res.shape}, "
+            f"lost={set(data.columns) - set(res.columns)}, "
+            f"added={set(res.columns) - set(data.columns)}"
+        )
+        res = res[data.columns]
+        assert data.columns.equals(res.columns) and data.index.equals(res.index), (
+            f"Columns or index mismatch "
+            f"data.columns.equals(res.columns)={data.columns.equals(res.columns)}, "
+            f"data.index.equals(res.columns)={data.index.equals(res.columns)}, "
+        )
+        return res
+
     def _dropna(
         self,
         data: pandas.DataFrame,
```
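The helper's core trick is to compute, per group of the remaining keys, one canonical value for the drifting column (the `-`-joined set of distinct values) and merge it back so every row aggregates under the same label; note the `overwrite_or_merge` flag is accepted but not referenced in the body shown here. A minimal standalone sketch of the same idea, not the package's exact code path (all column names and values are invented):

```python
import pandas

# Toy data: the "test_with" label was renamed between dates, so rows for the
# same model would otherwise land in different aggregation buckets.
df = pandas.DataFrame(
    {
        "date": ["2024-01-01", "2024-02-01", "2024-01-01", "2024-02-01"],
        "model": ["A", "A", "B", "B"],
        "test_with": ["ort", "onnxruntime", "ort", "onnxruntime"],
        "latency": [1.0, 1.1, 2.0, 2.1],
    }
)

# One canonical value per model: the sorted, "-"-joined distinct values.
canon = (
    df.groupby("model")["test_with"]
    .apply(lambda s: "-".join(sorted(set(s.dropna()))))
    .rename("test_with")
)
fixed = pandas.merge(
    df.drop(columns="test_with"), canon, how="left", left_on="model", right_index=True
)
print(fixed)  # every row now carries test_with == "onnxruntime-ort"
```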
```diff
@@ -1090,7 +1169,8 @@ class CubeLogs:
             assuming they should remain stale
         :param sbs: configurations to compare side-by-side, this adds two tabs,
             one gathering raw data about the two configurations, the other one
-            is aggregated by metrics
+            is aggregated by metrics, example:
+            ``=dict(CFA=dict(exporter="E1", opt="O"), CFB=dict(exporter="E2", opt="O"))``
         """
         if verbose:
             print(f"[CubeLogs.to_excel] create Excel file {output}, shape={self.shape}")
@@ -1532,6 +1612,7 @@ class CubeLogsPerformance(CubeLogs):
             "n_node_initializer_small",
             "n_node_layer_normalization",
             "n_node_layer_normalization23",
+            "n_node_random",
             "n_node_reshape",
             "n_node_rotary_embedding",
             "n_node_rotary_embedding23",
@@ -1723,6 +1804,16 @@ class CubeLogsPerformance(CubeLogs):
                 + gdf(df, "op_onnx__InstanceNormlization", 0)
                 + gdf(df, "op_onnx__GroupNormalization", 0),
             ),
+            n_node_random=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx__RandomNormal", 0)
+                + gdf(df, "op_onnx__RandomNormalLike", 0)
+                + gdf(df, "op_onnx__RandomUniform", 0)
+                + gdf(df, "op_onnx__RandomUniformLike", 0)
+                + gdf(df, "op_onnx__Multinomial", 0)
+                + gdf(df, "op_onnx__Bernoulli", 0),
+            ),
             n_node_attention=lambda df: gpreserve(
                 df,
                 "time_latency_eager",
```
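`gdf` and `gpreserve` are local helpers whose definitions sit outside this diff. A plausible reading, with stand-in implementations (assumptions, not the package's code): `gdf(df, col, default)` yields the column or a constant default when it is missing, and `gpreserve(df, col, series)` keeps the computed count only where `col` is populated, so the new random-node counter sums the per-operator columns while staying aligned with rows that actually have an eager latency:

```python
import pandas

def gdf(df, col, default=0):
    # stand-in: the column if present, otherwise a constant series (assumption)
    return df[col] if col in df.columns else pandas.Series(default, index=df.index)

def gpreserve(df, col, series):
    # stand-in: keep values only where `col` is populated (assumption)
    return series.where(df[col].notna())

df = pandas.DataFrame(
    {"time_latency_eager": [0.5, None], "op_onnx__RandomNormal": [2, 3]}
)
n_node_random = gpreserve(
    df,
    "time_latency_eager",
    gdf(df, "op_onnx__RandomNormal", 0) + gdf(df, "op_onnx__Bernoulli", 0),
)
print(n_node_random)  # 2.0, then NaN where the eager latency is missing
```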
```diff
@@ -1886,6 +1977,7 @@ class CubeLogsPerformance(CubeLogs):
         * **cmd:** command lines
         * **raw-short:** raw data without all the unused columns
         """
+        fix_aggregation_change = ["model_speedup_input_set", "model_test_with"]
         fs = ["suite", "model_suite", "task", "model_name", "model_task"]
         index_cols = self._filter_column(fs, self.keys_time)
         assert index_cols, (
```
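These two columns are then passed as `fix_aggregation_change` to every view built below, which is all the remaining hunks in this file do.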
```diff
@@ -1984,6 +2076,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="agg-suite",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "agg-all": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2014,6 +2107,7 @@ class CubeLogsPerformance(CubeLogs):
                 name="agg-all",
                 order=order,
                 plots=True,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "disc": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2023,6 +2117,7 @@ class CubeLogsPerformance(CubeLogs):
                 f_highlight=f_disc,
                 name="disc",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "speedup": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2032,6 +2127,7 @@ class CubeLogsPerformance(CubeLogs):
                 f_highlight=f_speedup,
                 name="speedup",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "counts": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2048,6 +2144,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="peak-gpu",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "time": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2058,6 +2155,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="time",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "time_export": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2066,6 +2164,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="time_export",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "err": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2076,6 +2175,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="err",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "bucket-speedup": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2085,6 +2185,7 @@ class CubeLogsPerformance(CubeLogs):
                 name="bucket-speedup",
                 f_highlight=f_bucket,
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "onnx": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2103,6 +2204,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="onnx",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "raw-short": lambda: CubeViewDef(
                 key_index=self.keys_time,
@@ -2111,6 +2213,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="raw-short",
                 no_index=True,
+                fix_aggregation_change=fix_aggregation_change,
             ),
         }

@@ -2123,6 +2226,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="cmd",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             )

         assert name in implemented_views or name in {"cmd"}, (
```
**onnx_diagnostic/helpers/mini_onnx_builder.py**

```diff
@@ -52,7 +52,7 @@ def proto_from_array(

     tensor = TensorProto()
     tensor.dims.extend(arr_cpu.shape)
-    tensor.name = name
+    tensor.name = name or ""
     itype = dtype_to_tensor_dtype(arr_cpu.dtype)
     assert not hasattr(TensorProto, "INT4") or itype not in {
         TensorProto.INT4,
```
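The `or ""` guard keeps `tensor.name` a string when `name` is `None`; assigning `None` to a protobuf string field raises a `TypeError`.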
**onnx_diagnostic/helpers/model_builder_helper.py**

```diff
@@ -1,11 +1,13 @@
+import copy
 import importlib.util
 import os
+import re
 import requests
 import sys
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 from urllib.parse import urlparse
-from onnx import ModelProto, TensorProto
+from onnx import ModelProto, TensorProto, load as load_model

 CACHE_SUBDIR = "onnx-diagnostic"

@@ -337,3 +339,131 @@ def create_model_builder(
     # onnx_model.make_genai_config(hf_name, extra_kwargs, output_dir)
     # onnx_model.save_processing(hf_name, extra_kwargs, output_dir)
     return onnx_model
+
+
+def find_names_pattern(names: List[str]) -> str:
+    """
+    Finds a repeatable patterns in a list of names.
+    It tries to locate the figures.
+
+    .. runpython::
+        :showcode:
+
+        from onnx_diagnostic.helpers.model_builder_helper import find_names_pattern
+        pattern = find_names_pattern(["past_key_values_key_0", "past_key_values_key_1"])
+        print(pattern)
+    """
+    patterns = [re.sub(r"(\d+)", r"%d", t) for t in names]
+    unique = set(patterns)
+    assert (
+        len(unique) == 1
+    ), f"Unable to guess a pattern from {names} which led to the unique patterns {unique}"
+    return patterns[0]
+
+
+def make_genai_config(
+    config,
+    onnx_filename: str,
+) -> Dict:
+    """
+    Creates genai config file for a model.
+
+    :param config: configuration from transformers
+    :param onnx_filename: onnx configuration
+    :return: configuration
+    """
+    onx = load_model(onnx_filename, load_external_data=False)
+    config = copy.deepcopy(config)
+    defaults = {
+        "bos_token_id": None,
+        "do_sample": False,
+        "eos_token_id": None,
+        "pad_token_id": None,
+        "temperature": 1.0,
+        "top_k": 50,
+        "top_p": 1.0,
+    }
+    for key, default_val in defaults.items():
+        if not hasattr(config, key):
+            setattr(config, key, default_val)
+
+    bos_token_id = (
+        config.bos_token_id
+        if hasattr(config, "bos_token_id") and config.bos_token_id is not None
+        else 1
+    )
+    eos_token_id = config.eos_token_id
+    pad_token_id = (
+        config.pad_token_id
+        if hasattr(config, "pad_token_id") and config.pad_token_id is not None
+        else (
+            config.eos_token_id[0]
+            if isinstance(config.eos_token_id, list)
+            else config.eos_token_id
+        )
+    )
+    input_names = [i.name for i in onx.graph.input]
+    output_names = [i.name for i in onx.graph.output]
+    past_key_values = [s for s in input_names if s.startswith("past_key_value")]
+    first = [i for i in onx.graph.input if i.name == past_key_values[0]][0]  # noqa: RUF015
+    shape = tuple(d.dim_value or d.dim_param for d in first.type.tensor_type.shape.dim)
+    return {
+        "model": {
+            "bos_token_id": bos_token_id,
+            "context_length": config.max_position_embeddings,
+            "decoder": {
+                "session_options": {
+                    "log_id": "onnxruntime-genai",
+                    "provider_options": [],
+                },
+                "filename": os.path.split(onnx_filename)[-1],
+                "head_size": shape[-1],
+                "hidden_size": config.hidden_size,
+                "inputs": {
+                    "input_ids": input_names[0],
+                    "attention_mask": input_names[1],
+                    "past_key_names": find_names_pattern(input_names[2::2]),
+                    "past_value_names": find_names_pattern(input_names[3::2]),
+                },
+                "outputs": {
+                    "logits": output_names[0],
+                    "present_key_names": find_names_pattern(output_names[1::2]),
+                    "present_value_names": find_names_pattern(output_names[2::2]),
+                },
+                "num_attention_heads": config.num_attention_heads,
+                "num_hidden_layers": len(past_key_values) // 2,
+                "num_key_value_heads": shape[1],
+            },
+            "eos_token_id": eos_token_id,
+            "pad_token_id": pad_token_id,
+            "type": config.model_type,
+            # if "For" in self.model_type else len(self.model_type)].lower(),
+            "vocab_size": config.vocab_size,
+        },
+        "search": {
+            "diversity_penalty": (
+                config.diversity_penalty if hasattr(config, "diversity_penalty") else 0.0
+            ),
+            "do_sample": config.do_sample if hasattr(config, "do_sample") else False,
+            "early_stopping": True,
+            "length_penalty": (
+                config.length_penalty if hasattr(config, "length_penalty") else 1.0
+            ),
+            "max_length": config.max_position_embeddings,
+            "min_length": 0,
+            "no_repeat_ngram_size": (
+                config.no_repeat_ngram_size if hasattr(config, "no_repeat_ngram_size") else 0
+            ),
+            "num_beams": config.num_beams if hasattr(config, "num_beams") else 1,
+            "num_return_sequences": (
+                config.num_return_sequences if hasattr(config, "num_return_sequences") else 1
+            ),
+            "past_present_share_buffer": False,
+            "repetition_penalty": (
+                config.repetition_penalty if hasattr(config, "repetition_penalty") else 1.0
+            ),
+            "temperature": config.temperature if hasattr(config, "temperature") else 1.0,
+            "top_k": config.top_k if hasattr(config, "top_k") else 50,
+            "top_p": config.top_p if hasattr(config, "top_p") else 1.0,
+        },
+    }
```
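A sketch of how the new helper might be wired up to emit a `genai_config.json` next to an exported model (the model id and file names are placeholders):

```python
import json
from transformers import AutoConfig
from onnx_diagnostic.helpers.model_builder_helper import make_genai_config

config = AutoConfig.from_pretrained("arnir0/Tiny-LLM")  # placeholder model id
genai = make_genai_config(config, "model.onnx")         # placeholder ONNX file
with open("genai_config.json", "w") as f:
    json.dump(genai, f, indent=2)
```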
**onnx_diagnostic/helpers/onnx_helper.py**

```diff
@@ -331,7 +331,7 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str:
         print(onnx_dtype_name(7))
     """
     for k in dir(TensorProto):
-        if k.upper() == k:
+        if k.upper() == k and k != "EXTERNAL":
             v = getattr(TensorProto, k)
             if v == itype:
                 return k
```
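`dir(TensorProto)` also lists `EXTERNAL`, an all-uppercase member that comes from the nested `DataLocation` enum and equals 1, the same value as `TensorProto.FLOAT`; excluding it keeps `onnx_dtype_name(1)` from answering `"EXTERNAL"` instead of `"FLOAT"`.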
**onnx_diagnostic/helpers/ort_session.py**

```diff
@@ -135,6 +135,10 @@ class _InferenceSession:
         self.sess = sess
         self.input_names = [i.name for i in sess.get_inputs()]
         self.output_names = [i.name for i in sess.get_outputs()]
+        self.input_shapes = [i.shape for i in sess.get_inputs()]
+        self.output_shapes = [i.shape for i in sess.get_outputs()]
+        self.input_types = [i.type for i in sess.get_inputs()]
+        self.output_types = [i.type for i in sess.get_outputs()]
         self.torch = torch
         self.nvtx = nvtx
         self.run_options = onnxruntime.RunOptions()
```
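The cached attributes mirror the metadata onnxruntime already exposes on a session; a minimal sketch of the equivalent raw calls (`"model.onnx"` is a placeholder):

```python
import onnxruntime

sess = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
# the same metadata the wrapper now caches as attributes
input_shapes = [i.shape for i in sess.get_inputs()]    # e.g. ["batch", 8, "seq", 64]
input_types = [i.type for i in sess.get_inputs()]      # e.g. "tensor(float)"
output_shapes = [o.shape for o in sess.get_outputs()]
output_types = [o.type for o in sess.get_outputs()]
print(input_shapes, input_types, output_shapes, output_types)
```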