onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +78 -22
  3. onnx_diagnostic/export/api.py +124 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +2 -1
  5. onnx_diagnostic/export/shape_helper.py +47 -70
  6. onnx_diagnostic/ext_test_case.py +11 -0
  7. onnx_diagnostic/helpers/cache_helper.py +38 -7
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
  9. onnx_diagnostic/helpers/helper.py +27 -33
  10. onnx_diagnostic/helpers/log_helper.py +109 -5
  11. onnx_diagnostic/helpers/memory_peak.py +2 -0
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +132 -2
  14. onnx_diagnostic/helpers/onnx_helper.py +1 -1
  15. onnx_diagnostic/helpers/ort_session.py +4 -0
  16. onnx_diagnostic/helpers/rt_helper.py +393 -43
  17. onnx_diagnostic/helpers/torch_helper.py +20 -1
  18. onnx_diagnostic/tasks/__init__.py +7 -0
  19. onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
  20. onnx_diagnostic/tasks/feature_extraction.py +2 -8
  21. onnx_diagnostic/tasks/image_text_to_text.py +10 -8
  22. onnx_diagnostic/tasks/summarization.py +2 -8
  23. onnx_diagnostic/tasks/text2text_generation.py +3 -8
  24. onnx_diagnostic/tasks/text_generation.py +86 -65
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
  26. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  27. onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
  28. onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
  29. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
  31. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
  32. onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
  33. onnx_diagnostic/torch_models/validate.py +626 -228
  34. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
  35. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
  36. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
  37. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
  38. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
--- a/onnx_diagnostic/helpers/log_helper.py
+++ b/onnx_diagnostic/helpers/log_helper.py
@@ -42,6 +42,8 @@ class CubeViewDef:
     :param name: name of the view, used mostly to debug
     :param plots: adds plot to the Excel sheet
     :param no_index: remove the index (but keeps the columns)
+    :param fix_aggregation_change: columns among the keys whose values changed
+        across dates; they are normalized before aggregating
 
     Some examples of views. First example is an aggregated view
     for many metrics.
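
The new parameter lists key columns to normalize before aggregation. A minimal construction sketch (keyword arguments other than fix_aggregation_change are assumptions inferred from the hunks that follow):

    view = CubeViewDef(
        key_index=["suite", "model_name"],           # assumed key columns
        values=["time_latency_eager"],               # assumed metric column
        name="agg-suite",
        fix_aggregation_change=["model_test_with"],  # normalized before aggregating
    )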
@@ -106,6 +108,7 @@ class CubeViewDef:
         name: Optional[str] = None,
         no_index: bool = False,
         plots: bool = False,
+        fix_aggregation_change: Optional[List[str]] = None,
     ):
         self.key_index = key_index
         self.values = values
@@ -123,6 +126,7 @@ class CubeViewDef:
         self.name = name
         self.no_index = no_index
         self.plots = plots
+        self.fix_aggregation_change = fix_aggregation_change
 
     def __repr__(self) -> str:
         "usual"
@@ -750,6 +754,17 @@ class CubeLogs:
             f"values={sorted(self.values)}"
         )
 
+        if view_def.fix_aggregation_change and (
+            set(view_def.fix_aggregation_change) & set(self.keys_no_time)
+        ):
+            # before aggregation, let's fix some keys whose values changed over time
+            data_to_process = self._fix_aggregation_change(
+                self.data,
+                list(set(view_def.fix_aggregation_change) & set(self.keys_no_time)),
+            )
+        else:
+            data_to_process = self.data
+
         # aggregation
         if key_agg:
             final_stack = True
@@ -763,7 +778,7 @@ class CubeLogs:
             print(f"[CubeLogs.view] aggregation of {set_key_agg}")
             print(f"[CubeLogs.view] groupby {keys_no_agg}")
 
-        data_red = self.data[[*keys_no_agg, *values]]
+        data_red = data_to_process[[*keys_no_agg, *values]]
         assert set(key_index) <= set(data_red.columns), (
             f"view_def.name={view_def.name!r}, "
            f"unable to find {set(key_index) - set(data_red.columns)}, "
@@ -792,7 +807,7 @@ class CubeLogs:
         key_index = self._filter_column(view_def.key_index, self.keys_time)
         if verbose:
             print(f"[CubeLogs.view] no aggregation, index={key_index}")
-        data = self.data[[*self.keys_time, *values]]
+        data = data_to_process[[*self.keys_time, *values]]
         set_all_keys = set(self.keys_time)
         final_stack = False
 
@@ -829,7 +844,7 @@ class CubeLogs:
         key_columns = sorted(set_key_columns)
         unique = set()
 
-        _md = lambda s: {k: v for k, v in self.values_for_key.items() if k in s}  # noqa: E731
+        # md = lambda s: {k: v for k, v in self.values_for_key.items() if k in s}  # noqa: E731
         all_cols = set(key_columns) | set(key_index) | set(key_agg) | unique
         assert all_cols == set(self.keys_time), (
             f"view_def.name={view_def.name!r}, "
@@ -892,7 +907,7 @@ class CubeLogs:
             f"key={sorted(key_columns)}, key_agg={key_agg}, values={sorted(values)}, "
             f"columns={sorted(data.columns)}, ignored={view_def.ignore_columns}, "
             f"not unique={set(data.columns) - unique}"
-            f"\n--\n{not_unique.head()}"
+            f"\n--\n{not_unique.head(10)}"
         )
 
         # pivot
@@ -961,6 +976,70 @@ class CubeLogs:
             print(f"[CubeLogs.view] -- done view {view_def.name!r}")
         return (piv, view_def) if return_view_def else piv
 
+    def _fix_aggregation_change(
+        self,
+        data: pandas.DataFrame,
+        columns_to_fix: Union[str, List[str]],
+        overwrite_or_merge: bool = True,
+    ) -> pandas.DataFrame:
+        """
+        Fixes columns used to aggregate values because their meaning changed over time.
+
+        :param data: data to fix
+        :param columns_to_fix: list of columns to fix
+        :param overwrite_or_merge: if True, overwrites every value with the
+            concatenation of all existing values; if False, merges the existing
+            values found, grouped by the other keys
+        :return: fixed data
+        """
+        if not isinstance(columns_to_fix, str):
+            for c in columns_to_fix:
+                data = self._fix_aggregation_change(data, c)
+            return data
+        # Let's process one column.
+        keys = set(self.keys_time) - {columns_to_fix}
+        select = data[self.keys_time]
+        select_agg = select.groupby(list(keys)).count()
+        assert select_agg[columns_to_fix].max() <= 1, (
+            f"Column {columns_to_fix!r} has at least two distinct values for one date\n"
+            f"{select_agg[select_agg[columns_to_fix] > 1]}"
+        )
+
+        # unique value (to fill NaN)
+        unique = "-".join(sorted(set(data[columns_to_fix].dropna())))
+
+        keys = set(self.keys_no_time) - {columns_to_fix}
+        select = data[self.keys_no_time]
+        select_agg = select.groupby(list(keys), as_index=True).apply(
+            lambda x: "-".join(sorted(set(x[columns_to_fix].dropna()))), include_groups=False
+        )
+        select_agg = select_agg.to_frame(name=columns_to_fix)
+        res = pandas.merge(
+            data.drop([columns_to_fix], axis=1),
+            select_agg,
+            how="left",
+            left_on=list(keys),
+            right_index=True,
+        )
+        val = f"?{unique}?"
+        res[columns_to_fix] = res[columns_to_fix].fillna(val).replace("", val)
+        assert (
+            data.shape == res.shape
+            and sorted(data.columns) == sorted(res.columns)
+            and sorted(data.index) == sorted(res.index)
+        ), (
+            f"Shape should match, data.shape={data.shape}, res.shape={res.shape}, "
+            f"lost={set(data.columns) - set(res.columns)}, "
+            f"added={set(res.columns) - set(data.columns)}"
+        )
+        res = res[data.columns]
+        assert data.columns.equals(res.columns) and data.index.equals(res.index), (
+            f"Columns or index mismatch "
+            f"data.columns.equals(res.columns)={data.columns.equals(res.columns)}, "
+            f"data.index.equals(res.index)={data.index.equals(res.index)}, "
+        )
+        return res
+
     def _dropna(
         self,
         data: pandas.DataFrame,
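
The effect of _fix_aggregation_change can be reproduced on a toy frame: when a key column takes different values on different dates, every row receives the concatenation of all values observed for the same non-time keys, so the rows aggregate together again. A self-contained sketch of that idea (pandas only, not the library code itself):

    import pandas

    data = pandas.DataFrame(
        dict(
            date=["2024-01-01", "2024-02-01"],
            model=["llama", "llama"],
            test_with=["v1", "v2"],  # the key whose meaning changed over time
            latency=[1.0, 0.9],
        )
    )
    # concatenate every value observed for the same non-time keys
    merged = (
        data.groupby("model")["test_with"]
        .apply(lambda s: "-".join(sorted(set(s.dropna()))))
        .rename("test_with")
    )
    fixed = data.drop(columns=["test_with"]).merge(
        merged, left_on="model", right_index=True
    )
    print(fixed)  # both rows now carry "v1-v2" and fall into the same group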
@@ -1090,7 +1169,8 @@ class CubeLogs:
             assuming they should remain stable
         :param sbs: configurations to compare side-by-side, this adds two tabs,
             one gathering raw data about the two configurations, the other one
-            is aggregated by metrics
+            is aggregated by metrics, for example:
+            ``dict(CFA=dict(exporter="E1", opt="O"), CFB=dict(exporter="E2", opt="O"))``
         """
         if verbose:
             print(f"[CubeLogs.to_excel] create Excel file {output}, shape={self.shape}")
@@ -1532,6 +1612,7 @@ class CubeLogsPerformance(CubeLogs):
         "n_node_initializer_small",
         "n_node_layer_normalization",
         "n_node_layer_normalization23",
+        "n_node_random",
         "n_node_reshape",
         "n_node_rotary_embedding",
         "n_node_rotary_embedding23",
@@ -1723,6 +1804,16 @@ class CubeLogsPerformance(CubeLogs):
             + gdf(df, "op_onnx__InstanceNormlization", 0)
             + gdf(df, "op_onnx__GroupNormalization", 0),
         ),
+        n_node_random=lambda df: gpreserve(
+            df,
+            "time_latency_eager",
+            gdf(df, "op_onnx__RandomNormal", 0)
+            + gdf(df, "op_onnx__RandomNormalLike", 0)
+            + gdf(df, "op_onnx__RandomUniform", 0)
+            + gdf(df, "op_onnx__RandomUniformLike", 0)
+            + gdf(df, "op_onnx__Multinomial", 0)
+            + gdf(df, "op_onnx__Bernoulli", 0),
+        ),
         n_node_attention=lambda df: gpreserve(
             df,
             "time_latency_eager",
@@ -1886,6 +1977,7 @@ class CubeLogsPerformance(CubeLogs):
         * **cmd:** command lines
         * **raw-short:** raw data without all the unused columns
         """
+        fix_aggregation_change = ["model_speedup_input_set", "model_test_with"]
         fs = ["suite", "model_suite", "task", "model_name", "model_task"]
         index_cols = self._filter_column(fs, self.keys_time)
         assert index_cols, (
@@ -1984,6 +2076,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="agg-suite",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "agg-all": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2014,6 +2107,7 @@ class CubeLogsPerformance(CubeLogs):
                 name="agg-all",
                 order=order,
                 plots=True,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "disc": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2023,6 +2117,7 @@ class CubeLogsPerformance(CubeLogs):
                 f_highlight=f_disc,
                 name="disc",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "speedup": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2032,6 +2127,7 @@ class CubeLogsPerformance(CubeLogs):
                 f_highlight=f_speedup,
                 name="speedup",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "counts": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2048,6 +2144,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="peak-gpu",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "time": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2058,6 +2155,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="time",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "time_export": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2066,6 +2164,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="time_export",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "err": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2076,6 +2175,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="err",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "bucket-speedup": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2085,6 +2185,7 @@ class CubeLogsPerformance(CubeLogs):
                 name="bucket-speedup",
                 f_highlight=f_bucket,
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "onnx": lambda: CubeViewDef(
                 key_index=index_cols,
@@ -2103,6 +2204,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="onnx",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             ),
             "raw-short": lambda: CubeViewDef(
                 key_index=self.keys_time,
@@ -2111,6 +2213,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="raw-short",
                 no_index=True,
+                fix_aggregation_change=fix_aggregation_change,
             ),
         }
 
@@ -2123,6 +2226,7 @@ class CubeLogsPerformance(CubeLogs):
                 keep_columns_in_index=["suite"],
                 name="cmd",
                 order=order,
+                fix_aggregation_change=fix_aggregation_change,
             )
 
         assert name in implemented_views or name in {"cmd"}, (
--- a/onnx_diagnostic/helpers/memory_peak.py
+++ b/onnx_diagnostic/helpers/memory_peak.py
@@ -47,6 +47,8 @@ class Monitor:
 
     @property
     def delta_avg(self):
+        if self.n_measures == 0:
+            return 0
         return self.average / self.n_measures - self.begin
 
     def __repr__(self):
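
The guard makes the property safe to read before any measure is recorded. A standalone sketch (only the property follows the hunk; the attribute initialization is an assumption):

    class Monitor:
        def __init__(self):
            self.begin = 0
            self.average = 0
            self.n_measures = 0

        @property
        def delta_avg(self):
            if self.n_measures == 0:
                return 0  # previously this raised ZeroDivisionError
            return self.average / self.n_measures - self.begin

    print(Monitor().delta_avg)  # 0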
--- a/onnx_diagnostic/helpers/mini_onnx_builder.py
+++ b/onnx_diagnostic/helpers/mini_onnx_builder.py
@@ -52,7 +52,7 @@ def proto_from_array(
 
     tensor = TensorProto()
     tensor.dims.extend(arr_cpu.shape)
-    tensor.name = name
+    tensor.name = name or ""
     itype = dtype_to_tensor_dtype(arr_cpu.dtype)
     assert not hasattr(TensorProto, "INT4") or itype not in {
         TensorProto.INT4,
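
The `name or ""` fallback matters because protobuf string fields reject None. A minimal repro, independent of proto_from_array:

    from onnx import TensorProto

    tensor = TensorProto()
    tensor.name = "weight"  # a plain string is fine
    try:
        tensor.name = None  # protobuf raises TypeError for None
    except TypeError as exc:
        print("rejected:", exc)
    tensor.name = None or ""  # the fixed assignment falls back to ""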
--- a/onnx_diagnostic/helpers/model_builder_helper.py
+++ b/onnx_diagnostic/helpers/model_builder_helper.py
@@ -1,11 +1,13 @@
+import copy
 import importlib.util
 import os
+import re
 import requests
 import sys
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 from urllib.parse import urlparse
-from onnx import ModelProto, TensorProto
+from onnx import ModelProto, TensorProto, load as load_model
 
 CACHE_SUBDIR = "onnx-diagnostic"
 
@@ -337,3 +339,131 @@ def create_model_builder(
     # onnx_model.make_genai_config(hf_name, extra_kwargs, output_dir)
     # onnx_model.save_processing(hf_name, extra_kwargs, output_dir)
     return onnx_model
+
+
+def find_names_pattern(names: List[str]) -> str:
+    """
+    Finds a repeatable pattern in a list of names.
+    It tries to locate the digits.
+
+    .. runpython::
+        :showcode:
+
+        from onnx_diagnostic.helpers.model_builder_helper import find_names_pattern
+        pattern = find_names_pattern(["past_key_values_key_0", "past_key_values_key_1"])
+        print(pattern)
+    """
+    patterns = [re.sub(r"(\d+)", r"%d", t) for t in names]
+    unique = set(patterns)
+    assert (
+        len(unique) == 1
+    ), f"Unable to guess a pattern from {names} which led to the unique patterns {unique}"
+    return patterns[0]
+
+
+def make_genai_config(
+    config,
+    onnx_filename: str,
+) -> Dict:
+    """
+    Creates a genai config file for a model.
+
+    :param config: configuration from transformers
+    :param onnx_filename: path to the onnx model
+    :return: configuration
+    """
+    onx = load_model(onnx_filename, load_external_data=False)
+    config = copy.deepcopy(config)
+    defaults = {
+        "bos_token_id": None,
+        "do_sample": False,
+        "eos_token_id": None,
+        "pad_token_id": None,
+        "temperature": 1.0,
+        "top_k": 50,
+        "top_p": 1.0,
+    }
+    for key, default_val in defaults.items():
+        if not hasattr(config, key):
+            setattr(config, key, default_val)
+
+    bos_token_id = (
+        config.bos_token_id
+        if hasattr(config, "bos_token_id") and config.bos_token_id is not None
+        else 1
+    )
+    eos_token_id = config.eos_token_id
+    pad_token_id = (
+        config.pad_token_id
+        if hasattr(config, "pad_token_id") and config.pad_token_id is not None
+        else (
+            config.eos_token_id[0]
+            if isinstance(config.eos_token_id, list)
+            else config.eos_token_id
+        )
+    )
+    input_names = [i.name for i in onx.graph.input]
+    output_names = [i.name for i in onx.graph.output]
+    past_key_values = [s for s in input_names if s.startswith("past_key_value")]
+    first = [i for i in onx.graph.input if i.name == past_key_values[0]][0]  # noqa: RUF015
+    shape = tuple(d.dim_value or d.dim_param for d in first.type.tensor_type.shape.dim)
+    return {
+        "model": {
+            "bos_token_id": bos_token_id,
+            "context_length": config.max_position_embeddings,
+            "decoder": {
+                "session_options": {
+                    "log_id": "onnxruntime-genai",
+                    "provider_options": [],
+                },
+                "filename": os.path.split(onnx_filename)[-1],
+                "head_size": shape[-1],
+                "hidden_size": config.hidden_size,
+                "inputs": {
+                    "input_ids": input_names[0],
+                    "attention_mask": input_names[1],
+                    "past_key_names": find_names_pattern(input_names[2::2]),
+                    "past_value_names": find_names_pattern(input_names[3::2]),
+                },
+                "outputs": {
+                    "logits": output_names[0],
+                    "present_key_names": find_names_pattern(output_names[1::2]),
+                    "present_value_names": find_names_pattern(output_names[2::2]),
+                },
+                "num_attention_heads": config.num_attention_heads,
+                "num_hidden_layers": len(past_key_values) // 2,
+                "num_key_value_heads": shape[1],
+            },
+            "eos_token_id": eos_token_id,
+            "pad_token_id": pad_token_id,
+            "type": config.model_type,
+            # if "For" in self.model_type else len(self.model_type)].lower(),
+            "vocab_size": config.vocab_size,
+        },
+        "search": {
+            "diversity_penalty": (
+                config.diversity_penalty if hasattr(config, "diversity_penalty") else 0.0
+            ),
+            "do_sample": config.do_sample if hasattr(config, "do_sample") else False,
+            "early_stopping": True,
+            "length_penalty": (
+                config.length_penalty if hasattr(config, "length_penalty") else 1.0
+            ),
+            "max_length": config.max_position_embeddings,
+            "min_length": 0,
+            "no_repeat_ngram_size": (
+                config.no_repeat_ngram_size if hasattr(config, "no_repeat_ngram_size") else 0
+            ),
+            "num_beams": config.num_beams if hasattr(config, "num_beams") else 1,
+            "num_return_sequences": (
+                config.num_return_sequences if hasattr(config, "num_return_sequences") else 1
+            ),
+            "past_present_share_buffer": False,
+            "repetition_penalty": (
+                config.repetition_penalty if hasattr(config, "repetition_penalty") else 1.0
+            ),
+            "temperature": config.temperature if hasattr(config, "temperature") else 1.0,
+            "top_k": config.top_k if hasattr(config, "top_k") else 50,
+            "top_p": config.top_p if hasattr(config, "top_p") else 1.0,
+        },
+    }
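
A hedged usage sketch of the new helper: the model id and file paths are placeholders, and make_genai_config assumes the ONNX graph orders its inputs as input_ids, attention_mask, then interleaved past_key_values_* tensors:

    import json
    from transformers import AutoConfig
    from onnx_diagnostic.helpers.model_builder_helper import make_genai_config

    config = AutoConfig.from_pretrained("arnir0/Tiny-LLM")  # placeholder model id
    genai = make_genai_config(config, "model.onnx")  # placeholder onnx file
    with open("genai_config.json", "w") as f:
        json.dump(genai, f, indent=2)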
--- a/onnx_diagnostic/helpers/onnx_helper.py
+++ b/onnx_diagnostic/helpers/onnx_helper.py
@@ -331,7 +331,7 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str:
         print(onnx_dtype_name(7))
     """
     for k in dir(TensorProto):
-        if "FLOAT" in k or "INT" in k or "TEXT" in k or "BOOL" in k:
+        if k.upper() == k and k != "EXTERNAL":
             v = getattr(TensorProto, k)
             if v == itype:
                 return k
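
The EXTERNAL exclusion is needed because TensorProto.EXTERNAL is a DataLocation constant, not an element type, and its integer value collides with one. A quick check (assuming a recent onnx package):

    from onnx import TensorProto

    # EXTERNAL (a DataLocation flag) shares the value 1 with FLOAT,
    # so the lookup above must skip it.
    print(TensorProto.EXTERNAL, TensorProto.FLOAT)  # 1 1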
--- a/onnx_diagnostic/helpers/ort_session.py
+++ b/onnx_diagnostic/helpers/ort_session.py
@@ -135,6 +135,10 @@ class _InferenceSession:
         self.sess = sess
         self.input_names = [i.name for i in sess.get_inputs()]
         self.output_names = [i.name for i in sess.get_outputs()]
+        self.input_shapes = [i.shape for i in sess.get_inputs()]
+        self.output_shapes = [i.shape for i in sess.get_outputs()]
+        self.input_types = [i.type for i in sess.get_inputs()]
+        self.output_types = [i.type for i in sess.get_outputs()]
         self.torch = torch
         self.nvtx = nvtx
         self.run_options = onnxruntime.RunOptions()
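
The cached lists come straight from onnxruntime's NodeArg metadata. A short sketch of what those fields contain (model.onnx is a placeholder):

    import onnxruntime

    sess = onnxruntime.InferenceSession(
        "model.onnx", providers=["CPUExecutionProvider"]
    )
    for i in sess.get_inputs():
        # e.g. input_ids ['batch', 'seq_length'] tensor(int64)
        print(i.name, i.shape, i.type)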