onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,89 @@
1
+ from typing import Any, Callable, Dict, Optional, Tuple
2
+ import torch
3
+ from ..helpers.config_helper import (
4
+ update_config,
5
+ check_hasattr,
6
+ default_num_hidden_layers as nhl,
7
+ )
8
+
9
+ __TASK__ = "fill-mask"
10
+
11
+
12
def reduce_model_config(config: Any) -> Dict[str, Any]:
    """
    Reduces a model size for task ``fill-mask``.

    The number of hidden layers is capped by the value returned by
    ``default_num_hidden_layers`` and the number of attention heads by 4.
    The new values are handed to ``update_config`` together with the
    configuration.

    :param config: model configuration, it must define
        ``num_attention_heads`` and ``num_hidden_layers``
    :return: the reduced values, as a dictionary
    """
    check_hasattr(config, "num_attention_heads", "num_hidden_layers")
    reduced: Dict[str, Any] = {
        "num_hidden_layers": min(config.num_hidden_layers, nhl()),
        "num_attention_heads": min(config.num_attention_heads, 4),
    }
    update_config(config, reduced)
    return reduced
21
+
22
+
23
def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    batch_size: int,
    sequence_length: int,
    dummy_max_token_id: int,
    add_second_input: int = 1,
    **kwargs,  # unused
):
    """
    Generates inputs for task ``fill-mask``.

    :param model: model, only forwarded to the recursive call
    :param config: configuration, only forwarded to the recursive call
    :param batch_size: batch size
    :param sequence_length: sequence length
    :param dummy_max_token_id: exclusive upper bound for the random token ids
    :param add_second_input: if not zero, a second set of inputs is added
        under key ``inputs2`` with one more row and
        ``add_second_input`` more tokens, it must be positive in that case
    :param kwargs: unused, except that ``cls_cache`` is rejected
    :return: dictionary with keys ``inputs``, ``dynamic_shapes`` and
        optionally ``inputs2``

    Example of generated inputs:

    ::

        input_ids:T7s1x13[101,72654:A16789.23076923077],
        token_type_ids:T7s1x13[0,0:A0.0],
        attention_mask:T7s1x13[1,1:A1.0])
    """
    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."

    dims = (batch_size, sequence_length)
    input_names = ("input_ids", "token_type_ids", "attention_mask")

    # Every input shares the same two dynamic axes.
    dynamic_shapes = {name: {0: "batch", 1: "sequence_length"} for name in input_names}

    input_ids = torch.randint(0, dummy_max_token_id, dims).to(torch.int64)
    token_type_ids = torch.zeros(dims).to(torch.int64)
    attention_mask = torch.ones(dims).to(torch.int64)
    inputs = {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
    }

    res = {"inputs": inputs, "dynamic_shapes": dynamic_shapes}
    if not add_second_input:
        return res

    assert (
        add_second_input > 0
    ), f"Not implemented for add_second_input={add_second_input}."
    # Second set of inputs with different dimensions,
    # the recursion stops because add_second_input is set to 0.
    second = get_inputs(
        model=model,
        config=config,
        batch_size=batch_size + 1,
        sequence_length=sequence_length + add_second_input,
        dummy_max_token_id=dummy_max_token_id,
        add_second_input=0,
        **kwargs,
    )
    res["inputs2"] = second["inputs"]
    return res
74
+
75
+
76
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Returns the arguments to give to :func:`get_inputs` for task ``fill-mask``.

    If the configuration is None, the function selects typical dimensions,
    otherwise the configuration must define ``vocab_size``.

    :param config: model configuration or None
    :return: keyword arguments and the function consuming them
    """
    if config is None:
        max_token_id = 31999
    else:
        check_hasattr(config, "vocab_size")
        max_token_id = config.vocab_size - 1
    kwargs = {
        "batch_size": 2,
        "sequence_length": 30,
        "dummy_max_token_id": max_token_id,
    }
    return kwargs, get_inputs
@@ -0,0 +1,144 @@
1
+ from typing import Any, Callable, Dict, Optional, Tuple
2
+ import torch
3
+ from ..helpers.config_helper import (
4
+ update_config,
5
+ check_hasattr,
6
+ default_num_hidden_layers as nhl,
7
+ )
8
+
9
+ __TASK__ = "image-classification"
10
+
11
+
12
def reduce_model_config(config: Any) -> Dict[str, Any]:
    """
    Reduces a model size for task ``image-classification``.

    A configuration whose ``model_type`` is ``timm_wrapper`` and which has
    no ``num_hidden_layers`` is left untouched and an empty dictionary is
    returned. Otherwise the configuration must define ``num_hidden_layers``
    or ``hidden_sizes`` and the new number of hidden layers is handed to
    ``update_config``.

    :param config: model configuration
    :return: the reduced values, as a dictionary
    """
    has_layers = hasattr(config, "num_hidden_layers")
    is_timm_wrapper = (
        hasattr(config, "model_type") and config.model_type == "timm_wrapper"
    )
    if is_timm_wrapper and not has_layers:
        # We cannot reduce.
        return {}

    check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
    if has_layers:
        n_layers = min(config.num_hidden_layers, nhl())
    else:
        n_layers = len(config.hidden_sizes)
    reduced = {"num_hidden_layers": n_layers}
    update_config(config, reduced)
    return reduced
31
+
32
+
33
+ def get_inputs(
34
+ model: torch.nn.Module,
35
+ config: Optional[Any],
36
+ input_width: int,
37
+ input_height: int,
38
+ input_channels: int,
39
+ batch_size: int = 2,
40
+ dynamic_rope: bool = False,
41
+ add_second_input: int = 1,
42
+ **kwargs, # unused
43
+ ):
44
+ """
45
+ Generates inputs for task ``image-classification``.
46
+
47
+ :param model: model to get the missing information
48
+ :param config: configuration used to generate the model
49
+ :param batch_size: batch size
50
+ :param input_channels: input channel
51
+ :param input_width: input width
52
+ :param input_height: input height
53
+ :return: dictionary
54
+ """
55
+ assert (
56
+ "cls_cache" not in kwargs
57
+ ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
58
+ assert isinstance(
59
+ input_width, int
60
+ ), f"Unexpected type for input_width {type(input_width)}{config}"
61
+ assert isinstance(
62
+ input_height, int
63
+ ), f"Unexpected type for input_height {type(input_height)}{config}"
64
+
65
+ shapes = {
66
+ "pixel_values": {
67
+ 0: torch.export.Dim("batch", min=1, max=1024),
68
+ 2: "width",
69
+ 3: "height",
70
+ },
71
+ }
72
+ inputs = dict(
73
+ pixel_values=torch.randn(batch_size, input_channels, input_width, input_height).clamp(
74
+ -1, 1
75
+ ),
76
+ )
77
+ if model.__class__.__name__ == "ViTForImageClassification":
78
+ inputs["interpolate_pos_encoding"] = True
79
+ shapes["interpolate_pos_encoding"] = None # type: ignore[assignment]
80
+ res = dict(inputs=inputs, dynamic_shapes=shapes)
81
+ if add_second_input:
82
+ assert (
83
+ add_second_input > 0
84
+ ), f"Not implemented for add_second_input={add_second_input}."
85
+ res["inputs2"] = get_inputs(
86
+ model=model,
87
+ config=config,
88
+ input_width=input_width + add_second_input,
89
+ input_height=input_height + add_second_input,
90
+ input_channels=input_channels,
91
+ batch_size=batch_size + 1,
92
+ dynamic_rope=dynamic_rope,
93
+ add_second_input=0,
94
+ **kwargs,
95
+ )["inputs"]
96
+ return res
97
+
98
+
99
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Returns the arguments to give to :func:`get_inputs`
    for task ``image-classification``.

    If the configuration is None, the function selects typical dimensions
    (224x224 images with 3 channels). A configuration whose ``model_type``
    is ``timm_wrapper`` and which has no ``num_hidden_layers`` takes its
    dimensions from ``config.pretrained_cfg["input_size"]``. Any other
    configuration must define ``num_channels`` and either ``image_size``
    or ``architectures``; in the latter case the image size comes from
    the default values registered for the first architecture.

    :param config: model configuration or None
    :return: keyword arguments and the function consuming them
    """
    if config is None:
        kwargs = dict(
            batch_size=2,
            input_width=224,
            input_height=224,
            input_channels=3,
        )
        return kwargs, get_inputs

    if (
        hasattr(config, "model_type")
        and config.model_type == "timm_wrapper"
        and not hasattr(config, "num_hidden_layers")
    ):
        # The last three values of input_size are read as
        # (channels, width, height).
        input_size = config.pretrained_cfg["input_size"]
        kwargs = dict(
            batch_size=2,
            input_width=input_size[-2],
            input_height=input_size[-1],
            input_channels=input_size[-3],
        )
        return kwargs, get_inputs

    check_hasattr(config, ("image_size", "architectures"), "num_channels")
    if hasattr(config, "image_size"):
        image_size = config.image_size
    else:
        assert config.architectures, f"empty architecture in {config}"
        from ..torch_models.hghub.hub_api import get_architecture_default_values

        default_values = get_architecture_default_values(config.architectures[0])
        image_size = default_values["image_size"]

    if isinstance(image_size, int):
        input_width = input_height = image_size
    else:
        # The resolved image_size must be used here, not config.image_size:
        # the attribute does not exist when the size comes from the
        # architecture default values.
        input_width, input_height = image_size[0], image_size[1]
    kwargs = dict(
        batch_size=2,
        input_width=input_width,
        input_height=input_height,
        input_channels=config.num_channels,
    )
    return kwargs, get_inputs