onnx-diagnostic 0.8.10__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +136 -140
- onnx_diagnostic/ci_models/data/Blanca_Lake_Hudak.jpg +0 -0
- onnx_diagnostic/ci_models/data/Ice_worm_glacier.jpg +0 -0
- onnx_diagnostic/ci_models/data/__init__.py +0 -0
- onnx_diagnostic/ci_models/export_phi4_mm.py +10 -7
- onnx_diagnostic/export/api.py +13 -4
- onnx_diagnostic/export/dynamic_shapes.py +1 -1
- onnx_diagnostic/export/validate.py +2 -0
- onnx_diagnostic/ext_test_case.py +32 -15
- onnx_diagnostic/helpers/args_helper.py +1 -0
- onnx_diagnostic/helpers/bench_run.py +0 -1
- onnx_diagnostic/helpers/cache_helper.py +102 -36
- onnx_diagnostic/helpers/doc_helper.py +7 -4
- onnx_diagnostic/helpers/graph_helper.py +6 -6
- onnx_diagnostic/helpers/helper.py +39 -0
- onnx_diagnostic/helpers/log_helper.py +37 -14
- onnx_diagnostic/helpers/memory_peak.py +5 -1
- onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
- onnx_diagnostic/helpers/model_builder_helper.py +1 -1
- onnx_diagnostic/helpers/onnx_helper.py +283 -110
- onnx_diagnostic/helpers/ort_session.py +5 -2
- onnx_diagnostic/helpers/rt_helper.py +53 -9
- onnx_diagnostic/helpers/torch_helper.py +15 -11
- onnx_diagnostic/investigate/__init__.py +0 -0
- onnx_diagnostic/investigate/input_observer.py +970 -0
- onnx_diagnostic/reference/evaluator.py +0 -1
- onnx_diagnostic/reference/ort_evaluator.py +0 -1
- onnx_diagnostic/reference/report_results_comparison.py +9 -3
- onnx_diagnostic/reference/torch_evaluator.py +5 -1
- onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
- onnx_diagnostic/tasks/feature_extraction.py +0 -1
- onnx_diagnostic/torch_export_patches/__init__.py +0 -1
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +32 -14
- onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +107 -6
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +13 -3
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +70 -23
- onnx_diagnostic/torch_models/code_sample.py +5 -10
- onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +6 -12
- onnx_diagnostic/torch_models/validate.py +1 -1
- onnx_diagnostic/torch_onnx/compare.py +0 -1
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
- onnx_diagnostic/torch_onnx/sbs.py +1 -1
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
- onnx_diagnostic/typing.py +15 -0
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/RECORD +55 -50
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/WHEEL +1 -1
- onnx_diagnostic/api.py +0 -15
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/top_level.txt +0 -0
|
@@ -29,10 +29,10 @@ class CubeViewDef:
|
|
|
29
29
|
:param order: to reorder key in columns index
|
|
30
30
|
:param key_agg: aggregate according to these columns before
|
|
31
31
|
creating the view
|
|
32
|
-
:param agg_args: see :meth:`pandas.
|
|
32
|
+
:param agg_args: see :meth:`pandas.api.typing.DataFrameGroupBy.agg`,
|
|
33
33
|
it can be also a callable to return a different aggregation
|
|
34
34
|
method depending on the column name
|
|
35
|
-
:param agg_kwargs: see :meth:`pandas.
|
|
35
|
+
:param agg_kwargs: see :meth:`pandas.api.typing.DataFrameGroupBy.agg`
|
|
36
36
|
:param agg_multi: aggregation over multiple columns
|
|
37
37
|
:param ignore_columns: ignore the following columns if known to overload the view
|
|
38
38
|
:param keep_columns_in_index: keeps the columns even if there is only one unique value
|
|
@@ -98,7 +98,7 @@ class CubeViewDef:
|
|
|
98
98
|
agg_args: Union[Sequence[Any], Callable[[str], Any]] = ("sum",),
|
|
99
99
|
agg_kwargs: Optional[Dict[str, Any]] = None,
|
|
100
100
|
agg_multi: Optional[
|
|
101
|
-
Dict[str, Callable[[pandas.
|
|
101
|
+
Dict[str, Callable[[pandas.api.typing.DataFrameGroupBy], pandas.Series]]
|
|
102
102
|
] = None,
|
|
103
103
|
ignore_columns: Optional[Sequence[str]] = None,
|
|
104
104
|
keep_columns_in_index: Optional[Sequence[str]] = None,
|
|
@@ -365,6 +365,7 @@ class CubePlot:
|
|
|
365
365
|
# This is very slow
|
|
366
366
|
# ddd.plot(ax=axs[row, ii],linewidth=3)
|
|
367
367
|
for jj in range(ddd.shape[1]):
|
|
368
|
+
# pyrefly: ignore[bad-index]
|
|
368
369
|
axs[row, ii].plot(x, ddd.iloc[:, jj], lw=3, label=ddd.columns[jj])
|
|
369
370
|
axs[row, ii].set_title(f"{c}{title_suffix}")
|
|
370
371
|
rotate_align(axs[row, ii])
|
|
@@ -480,7 +481,9 @@ class CubeLogs:
|
|
|
480
481
|
elif isinstance(self._data, list) and all(isinstance(r, dict) for r in self._data):
|
|
481
482
|
if verbose:
|
|
482
483
|
print(f"[CubeLogs.load] load from list of dicts, n={len(self._data)}")
|
|
483
|
-
self.data = pandas.DataFrame(
|
|
484
|
+
self.data = pandas.DataFrame(
|
|
485
|
+
self.post_load_process_piece(pandas.DataFrame(self._data), unique=True)
|
|
486
|
+
)
|
|
484
487
|
if verbose:
|
|
485
488
|
print(f"[CubeLogs.load] after postprocessing shape={self.data.shape}")
|
|
486
489
|
elif isinstance(self._data, list) and all(
|
|
@@ -614,7 +617,7 @@ class CubeLogs:
|
|
|
614
617
|
|
|
615
618
|
def _process_formula(
|
|
616
619
|
self, formula: Union[str, Callable[[pandas.DataFrame], pandas.Series]]
|
|
617
|
-
) -> Callable[[pandas.DataFrame], pandas.Series]:
|
|
620
|
+
) -> Callable[[pandas.DataFrame], Optional[pandas.Series]]:
|
|
618
621
|
assert callable(formula), f"formula={formula!r} is not supported."
|
|
619
622
|
return formula
|
|
620
623
|
|
|
@@ -625,9 +628,11 @@ class CubeLogs:
|
|
|
625
628
|
return self.data.shape
|
|
626
629
|
|
|
627
630
|
@property
|
|
628
|
-
def columns(self) -> Sequence[
|
|
631
|
+
def columns(self) -> Sequence[Any]:
|
|
629
632
|
"Returns the columns."
|
|
630
633
|
assert hasattr(self, "data"), "Method load was not called"
|
|
634
|
+
assert isinstance(self.data, pandas.DataFrame) # type checking
|
|
635
|
+
# pyrefly: ignore[bad-return]
|
|
631
636
|
return self.data.columns
|
|
632
637
|
|
|
633
638
|
def _preprocess(self):
|
|
@@ -647,7 +652,7 @@ class CubeLogs:
|
|
|
647
652
|
)
|
|
648
653
|
assert gr.shape[0] > 0, (
|
|
649
654
|
f"Something went wrong after the groupby.\n"
|
|
650
|
-
f"{cp[[*self.
|
|
655
|
+
f"{cp[[*self.keys_no_time, self.time, '__index__']].head().T}"
|
|
651
656
|
)
|
|
652
657
|
filtered = pandas.merge(cp, gr, on=["__index__", *self.keys_time])
|
|
653
658
|
assert filtered.shape[0] <= self.data.shape[0], (
|
|
@@ -797,6 +802,7 @@ class CubeLogs:
|
|
|
797
802
|
if view_def.agg_multi:
|
|
798
803
|
append = []
|
|
799
804
|
for k, f in view_def.agg_multi.items():
|
|
805
|
+
# pyrefly: ignore[no-matching-overload]
|
|
800
806
|
cv = grouped_data.apply(f, include_groups=False)
|
|
801
807
|
append.append(cv.to_frame(k))
|
|
802
808
|
data = pandas.concat([data, *append], axis=1)
|
|
@@ -1020,8 +1026,10 @@ class CubeLogs:
|
|
|
1020
1026
|
|
|
1021
1027
|
keys = set(self.keys_no_time) - {columns_to_fix}
|
|
1022
1028
|
select = data[self.keys_no_time]
|
|
1029
|
+
# pyrefly: ignore[no-matching-overload]
|
|
1023
1030
|
select_agg = select.groupby(list(keys), as_index=True).apply(
|
|
1024
|
-
lambda x: "-".join(sorted(set(x[columns_to_fix].dropna()))),
|
|
1031
|
+
lambda x: "-".join(sorted(set(x[columns_to_fix].dropna()))),
|
|
1032
|
+
include_groups=False,
|
|
1025
1033
|
)
|
|
1026
1034
|
select_agg = select_agg.to_frame(name=columns_to_fix)
|
|
1027
1035
|
res = pandas.merge(
|
|
@@ -1137,6 +1145,7 @@ class CubeLogs:
|
|
|
1137
1145
|
if len(nonan) > 0:
|
|
1138
1146
|
obs.update(dict(count=len(nonan)))
|
|
1139
1147
|
if is_numeric_dtype(nonan) and not pandas.api.types.is_object_dtype(nonan):
|
|
1148
|
+
# pyrefly: ignore[no-matching-overload]
|
|
1140
1149
|
obs.update(
|
|
1141
1150
|
dict(
|
|
1142
1151
|
min=nonan.min(),
|
|
@@ -1208,12 +1217,15 @@ class CubeLogs:
|
|
|
1208
1217
|
df.to_excel(writer, sheet_name=main, freeze_panes=(1, 1))
|
|
1209
1218
|
|
|
1210
1219
|
time_mask_view: Dict[str, pandas.DataFrame] = {}
|
|
1220
|
+
df = None
|
|
1211
1221
|
for name, view in views.items():
|
|
1212
1222
|
if view is None:
|
|
1213
1223
|
continue
|
|
1214
1224
|
df, tview = self.view(view, return_view_def=True, verbose=max(verbose - 1, 0))
|
|
1215
1225
|
if cube_time is not None:
|
|
1216
1226
|
cube_mask = cube_time.view(view)
|
|
1227
|
+
assert isinstance(cube_mask, pandas.DataFrame) # type checking
|
|
1228
|
+
assert isinstance(df, pandas.DataFrame) # type checking
|
|
1217
1229
|
aligned = align_dataframe_with(cube_mask, df)
|
|
1218
1230
|
if aligned is not None:
|
|
1219
1231
|
assert aligned.shape == df.shape, (
|
|
@@ -1228,6 +1240,7 @@ class CubeLogs:
|
|
|
1228
1240
|
)
|
|
1229
1241
|
if tview is None:
|
|
1230
1242
|
continue
|
|
1243
|
+
assert isinstance(df, pandas.DataFrame) # type checking
|
|
1231
1244
|
memory = df.memory_usage(deep=True).sum()
|
|
1232
1245
|
if verbose:
|
|
1233
1246
|
print(
|
|
@@ -1269,7 +1282,9 @@ class CubeLogs:
|
|
|
1269
1282
|
sheet_name=name,
|
|
1270
1283
|
freeze_panes=(df.columns.nlevels + 1, df.index.nlevels),
|
|
1271
1284
|
)
|
|
1285
|
+
# pyrefly: ignore[missing-attribute]
|
|
1272
1286
|
f_highlights[name] = tview.f_highlight
|
|
1287
|
+
# pyrefly: ignore[missing-attribute]
|
|
1273
1288
|
if tview.plots:
|
|
1274
1289
|
plots.append(
|
|
1275
1290
|
CubePlot(
|
|
@@ -1282,6 +1297,7 @@ class CubeLogs:
|
|
|
1282
1297
|
if self.time in df.columns.names
|
|
1283
1298
|
else CubePlot(df, kind="barh", orientation="row", split=True)
|
|
1284
1299
|
)
|
|
1300
|
+
assert isinstance(df, pandas.DataFrame) # type checking
|
|
1285
1301
|
if raw:
|
|
1286
1302
|
assert main not in views, f"{main!r} is duplicated in views {sorted(views)}"
|
|
1287
1303
|
# Too long.
|
|
@@ -1439,7 +1455,7 @@ class CubeLogs:
|
|
|
1439
1455
|
len(configs) >= 2
|
|
1440
1456
|
), f"A side by side needs at least two configs but configs={configs}"
|
|
1441
1457
|
set_keys_time = set(self.keys_time)
|
|
1442
|
-
columns_index = None
|
|
1458
|
+
columns_index: Optional[List[str]] = None
|
|
1443
1459
|
data_list = []
|
|
1444
1460
|
for name_conf, conf in configs.items():
|
|
1445
1461
|
if columns_index is None:
|
|
@@ -1478,9 +1494,11 @@ class CubeLogs:
|
|
|
1478
1494
|
|
|
1479
1495
|
# add metrics
|
|
1480
1496
|
index_column_name = list(view_res.columns.names).index(column_name)
|
|
1497
|
+
# pyrefly: ignore[missing-attribute]
|
|
1481
1498
|
index_metrics = list(view_res.columns.names).index("METRICS")
|
|
1482
1499
|
|
|
1483
1500
|
def _mkc(m, s):
|
|
1501
|
+
# pyrefly: ignore[missing-attribute]
|
|
1484
1502
|
c = ["" for c in view_res.columns.names]
|
|
1485
1503
|
c[index_column_name] = s
|
|
1486
1504
|
c[index_metrics] = m
|
|
@@ -1515,7 +1533,9 @@ class CubeLogs:
|
|
|
1515
1533
|
ci["CONF"] = iname
|
|
1516
1534
|
cj["CONF"] = jname
|
|
1517
1535
|
|
|
1536
|
+
# pyrefly: ignore[bad-index]
|
|
1518
1537
|
ci_name = tuple(ci[n] for n in view_res.columns.names)
|
|
1538
|
+
# pyrefly: ignore[bad-index]
|
|
1519
1539
|
cj_name = tuple(cj[n] for n in view_res.columns.names)
|
|
1520
1540
|
assert ci_name in view_res.columns or cj_name in view_res.columns, (
|
|
1521
1541
|
f"Unable to find column {ci_name} or {cj_name} "
|
|
@@ -1562,6 +1582,7 @@ class CubeLogs:
|
|
|
1562
1582
|
}
|
|
1563
1583
|
flat = view_res.groupby(self.time).agg(aggs)
|
|
1564
1584
|
flat = flat.stack("METRICS", future_stack=True)
|
|
1585
|
+
# pyrefly: ignore[bad-return, missing-attribute]
|
|
1565
1586
|
return res, flat, view_res.T.sort_index().T
|
|
1566
1587
|
|
|
1567
1588
|
|
|
@@ -1679,7 +1700,7 @@ class CubeLogsPerformance(CubeLogs):
|
|
|
1679
1700
|
|
|
1680
1701
|
def _process_formula(
|
|
1681
1702
|
self, formula: Union[str, Callable[[pandas.DataFrame], pandas.Series]]
|
|
1682
|
-
) -> Callable[[pandas.DataFrame], pandas.Series]:
|
|
1703
|
+
) -> Callable[[pandas.DataFrame], Optional[pandas.Series]]:
|
|
1683
1704
|
"""
|
|
1684
1705
|
Processes a formula, converting it into a function.
|
|
1685
1706
|
|
|
@@ -1726,6 +1747,7 @@ class CubeLogsPerformance(CubeLogs):
|
|
|
1726
1747
|
f"{pprint.pformat(sorted(columns))}"
|
|
1727
1748
|
)
|
|
1728
1749
|
# return lambda df: df["time_latency_eager"] / df["time_latency"]
|
|
1750
|
+
# pyrefly: ignore[no-matching-overload]
|
|
1729
1751
|
return lambda df: pandas.cut(
|
|
1730
1752
|
df["speedup"], bins=BUCKET_SCALES, right=False, duplicates="raise"
|
|
1731
1753
|
)
|
|
@@ -1733,9 +1755,9 @@ class CubeLogsPerformance(CubeLogs):
|
|
|
1733
1755
|
if formula == "ERR1":
|
|
1734
1756
|
columns = set(self._filter_column(["^ERR_.*"], self.data.columns))
|
|
1735
1757
|
if not columns:
|
|
1736
|
-
return lambda df:
|
|
1758
|
+
return lambda df: None
|
|
1737
1759
|
|
|
1738
|
-
def first_err(df: pandas.DataFrame) -> pandas.Series:
|
|
1760
|
+
def first_err(df: pandas.DataFrame) -> Optional[pandas.Series]:
|
|
1739
1761
|
ordered = [
|
|
1740
1762
|
c
|
|
1741
1763
|
for c in [
|
|
@@ -1752,7 +1774,7 @@ class CubeLogsPerformance(CubeLogs):
|
|
|
1752
1774
|
]
|
|
1753
1775
|
if c in df.columns
|
|
1754
1776
|
]
|
|
1755
|
-
res = None
|
|
1777
|
+
res: Optional[pandas.Series] = None
|
|
1756
1778
|
for c in ordered:
|
|
1757
1779
|
if res is None:
|
|
1758
1780
|
res = df[c].fillna("")
|
|
@@ -1949,6 +1971,7 @@ class CubeLogsPerformance(CubeLogs):
|
|
|
1949
1971
|
f"{pprint.pformat(sorted(self.data.columns))}"
|
|
1950
1972
|
)
|
|
1951
1973
|
|
|
1974
|
+
# pyrefly: ignore[bad-override]
|
|
1952
1975
|
def view(
|
|
1953
1976
|
self,
|
|
1954
1977
|
view_def: Optional[Union[str, CubeViewDef]],
|
|
@@ -2265,7 +2288,7 @@ class CubeLogsPerformance(CubeLogs):
|
|
|
2265
2288
|
if unique:
|
|
2266
2289
|
return df
|
|
2267
2290
|
cols = self._filter_column(self._keys, df)
|
|
2268
|
-
res = None
|
|
2291
|
+
res: Optional[pandas.DataFrame] = None
|
|
2269
2292
|
for c in cols:
|
|
2270
2293
|
if df[c].isna().any():
|
|
2271
2294
|
# Missing values for keys are not supposed to happen.
|
|
@@ -103,6 +103,7 @@ def _process_memory_spy(conn):
|
|
|
103
103
|
process = psutil.Process(pid)
|
|
104
104
|
|
|
105
105
|
if cuda:
|
|
106
|
+
# pyrefly: ignore[missing-import]
|
|
106
107
|
from pynvml import (
|
|
107
108
|
nvmlDeviceGetCount,
|
|
108
109
|
nvmlDeviceGetHandleByIndex,
|
|
@@ -131,6 +132,7 @@ def _process_memory_spy(conn):
|
|
|
131
132
|
mem = process.memory_info().rss
|
|
132
133
|
cpu.update(mem)
|
|
133
134
|
if cuda:
|
|
135
|
+
# pyrefly: ignore[unbound-name]
|
|
134
136
|
for r, g in zip(gpu_used(), gpus):
|
|
135
137
|
g.update(r)
|
|
136
138
|
if conn.poll(timeout=timeout):
|
|
@@ -142,6 +144,7 @@ def _process_memory_spy(conn):
|
|
|
142
144
|
end = process.memory_info().rss
|
|
143
145
|
cpu.update(end)
|
|
144
146
|
if cuda:
|
|
147
|
+
# pyrefly: ignore[unbound-name]
|
|
145
148
|
for r, g in zip(gpu_used(), gpus):
|
|
146
149
|
g.update(r)
|
|
147
150
|
|
|
@@ -151,6 +154,7 @@ def _process_memory_spy(conn):
|
|
|
151
154
|
for g in gpus:
|
|
152
155
|
g.send(conn)
|
|
153
156
|
if cuda:
|
|
157
|
+
# pyrefly: ignore[unbound-name]
|
|
154
158
|
nvmlShutdown()
|
|
155
159
|
conn.close()
|
|
156
160
|
|
|
@@ -217,7 +221,7 @@ def start_spying_on(
|
|
|
217
221
|
Starts the memory spy. The function starts another
|
|
218
222
|
process spying on the one sent as an argument.
|
|
219
223
|
|
|
220
|
-
:param pid: process id to spy or the
|
|
224
|
+
:param pid: process id to spy or the current one.
|
|
221
225
|
:param delay: delay between two measures.
|
|
222
226
|
:param cuda: True or False to get memory for cuda devices
|
|
223
227
|
|
|
@@ -8,11 +8,6 @@ import torch
|
|
|
8
8
|
from .onnx_helper import dtype_to_tensor_dtype, tensor_dtype_to_np_dtype, from_array_extended
|
|
9
9
|
from . import string_type
|
|
10
10
|
|
|
11
|
-
STORAGE_TYPE = {
|
|
12
|
-
TensorProto.FLOAT16: np.int16,
|
|
13
|
-
TensorProto.BFLOAT16: np.int16,
|
|
14
|
-
}
|
|
15
|
-
|
|
16
11
|
|
|
17
12
|
def proto_from_array(
|
|
18
13
|
arr: torch.Tensor,
|
|
@@ -67,13 +62,13 @@ def proto_from_array(
|
|
|
67
62
|
byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
|
|
68
63
|
tensor.raw_data = bytes(byte_data)
|
|
69
64
|
if sys.byteorder == "big":
|
|
70
|
-
np_dtype = tensor_dtype_to_np_dtype(
|
|
71
|
-
np.
|
|
65
|
+
np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
|
|
66
|
+
np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True)
|
|
72
67
|
else:
|
|
73
68
|
tensor.raw_data = np_arr.tobytes()
|
|
74
69
|
if sys.byteorder == "big":
|
|
75
70
|
np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
|
|
76
|
-
np.
|
|
71
|
+
np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True)
|
|
77
72
|
|
|
78
73
|
return tensor
|
|
79
74
|
|
|
@@ -133,6 +128,7 @@ class MiniOnnxBuilder:
|
|
|
133
128
|
}
|
|
134
129
|
shape = tuple(map(int, tensor.shape))
|
|
135
130
|
self.nodes.append(
|
|
131
|
+
# pyrefly: ignore[bad-argument-type]
|
|
136
132
|
oh.make_node(op_type, [], [name], dtype=dtype, shape=shape, **kwargs)
|
|
137
133
|
)
|
|
138
134
|
self.outputs.append(oh.make_tensor_value_info(name, dtype, shape))
|
|
@@ -632,6 +628,7 @@ def create_input_tensors_from_onnx_model(
|
|
|
632
628
|
raise AssertionError(f"Unexpected value for engine={engine!r}")
|
|
633
629
|
|
|
634
630
|
got = sess.run(None, {})
|
|
631
|
+
assert isinstance(got, list) # type checking
|
|
635
632
|
if len(names) == 1:
|
|
636
633
|
name = names[0]
|
|
637
634
|
output = got[0]
|
|
@@ -639,12 +636,10 @@ def create_input_tensors_from_onnx_model(
|
|
|
639
636
|
return None
|
|
640
637
|
if name == "array":
|
|
641
638
|
return output
|
|
642
|
-
if name
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
return
|
|
646
|
-
if name == "float":
|
|
647
|
-
return float(output[0])
|
|
639
|
+
if name in {"bool", "int", "float"}:
|
|
640
|
+
cvt = {"bool": bool, "int": int, "float": float}[name]
|
|
641
|
+
# pyrefly: ignore[bad-index]
|
|
642
|
+
return cvt(output[0])
|
|
648
643
|
if name == "tensor":
|
|
649
644
|
return torch.from_numpy(output).to(device)
|
|
650
645
|
assert name.startswith(
|
|
@@ -14,7 +14,7 @@ CACHE_SUBDIR = "onnx-diagnostic"
|
|
|
14
14
|
|
|
15
15
|
def download_model_builder_to_cache(
|
|
16
16
|
url: str = "https://raw.githubusercontent.com/microsoft/onnxruntime-genai/refs/heads/main/src/python/py/models/builder.py",
|
|
17
|
-
):
|
|
17
|
+
) -> Path:
|
|
18
18
|
"""
|
|
19
19
|
Downloads ``builder.py`` from the
|
|
20
20
|
``https://github.com/microsoft/onnxruntime-genai/blob/main/src/python/py/models/builder.py``.
|