onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/tasks/text_to_image.py
@@ -0,0 +1,95 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import update_config, check_hasattr, pick
+
+__TASK__ = "text-to-image"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    check_hasattr(config, "sample_size", "cross_attention_dim")
+    kwargs = dict(
+        sample_size=min(config["sample_size"], 32),
+        cross_attention_dim=min(config["cross_attention_dim"], 64),
+    )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    batch_size: int,
+    sequence_length: int,
+    cache_length: int,
+    in_channels: int,
+    sample_size: int,
+    cross_attention_dim: int,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``text-to-image``.
+    Example:
+
+    ::
+
+        sample:T10s2x4x96x96[-3.7734375,4.359375:A-0.043463995395642184]
+        timestep:T7s=101
+        encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257]
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    batch = "batch"
+    shapes = {
+        "sample": {0: batch},
+        "timestep": {},
+        "encoder_hidden_states": {0: batch, 1: "encoder_length"},
+    }
+    inputs = dict(
+        sample=torch.randn((batch_size, sequence_length, sample_size, sample_size)).to(
+            torch.float32
+        ),
+        timestep=torch.tensor([101], dtype=torch.int64),
+        encoder_hidden_states=torch.randn(
+            (batch_size, sequence_length, cross_attention_dim)
+        ).to(torch.float32),
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length,
+            cache_length=cache_length + add_second_input,
+            in_channels=in_channels,
+            sample_size=sample_size,
+            cross_attention_dim=cross_attention_dim,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "sample_size", "cross_attention_dim", "in_channels")
+    kwargs = dict(
+        batch_size=2,
+        sequence_length=pick(config, "in_channels", 4),
+        cache_length=77,
+        in_channels=pick(config, "in_channels", 4),
+        sample_size=pick(config, "sample_size", 32),
+        cross_attention_dim=pick(config, "cross_attention_dim", 64),
+    )
+    return kwargs, get_inputs
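
For orientation, a minimal sketch of how this task module is typically driven. The explicit dimensions below mirror the defaults that random_input_kwargs picks when config is None; model and config are unused by this particular get_inputs, so passing None is safe.

    from onnx_diagnostic.tasks.text_to_image import get_inputs

    # Build one dummy input set plus a second one with a bumped batch size,
    # which is what the "inputs2" entry is for.
    data = get_inputs(
        model=None,              # not used by this task
        config=None,
        batch_size=2,
        sequence_length=4,       # latent channels for a UNet-style model
        cache_length=77,
        in_channels=4,
        sample_size=32,
        cross_attention_dim=64,
    )
    print(data["inputs"]["sample"].shape)    # torch.Size([2, 4, 32, 32])
    print(data["dynamic_shapes"])            # {'sample': {0: 'batch'}, 'timestep': {}, ...}
    print(data["inputs2"]["sample"].shape)   # torch.Size([3, 4, 32, 32]), i.e. batch_size + 1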
onnx_diagnostic/tasks/zero_shot_image_classification.py
@@ -0,0 +1,128 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import update_config, check_hasattr
+
+__TASK__ = "zero-shot-image-classification"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    check_hasattr(config, "vision_config", "text_config")
+    check_hasattr(config.vision_config, "num_hidden_layers", "num_attention_heads")
+    check_hasattr(config.text_config, "num_hidden_layers", "num_attention_heads")
+    kwargs = dict(
+        vision_config=dict(
+            num_hidden_layers=min(2, config.vision_config.num_hidden_layers),
+            num_attention_heads=min(2, config.vision_config.num_attention_heads),
+        ),
+        text_config=dict(
+            num_hidden_layers=min(2, config.text_config.num_hidden_layers),
+            num_attention_heads=min(2, config.text_config.num_attention_heads),
+        ),
+    )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    batch_size: int = 2,
+    sequence_length: int = 30,
+    input_width: int = 224,
+    input_height: int = 224,
+    input_channels: int = 3,
+    batch_size_image=3,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``zero-shot-image-classification``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param dummy_max_token_id: vocabulary size
+    :param batch_size: batch size
+    :param sequence_length: sequence length
+    :param batch_size_image: number of images
+    :param input_channels: input channels
+    :param input_width: input width
+    :param input_height: input height
+    :return: dictionary
+
+    # input_ids:T7s2x7
+    # attention_mask:T7s2x7
+    # pixel_values:T1s2x3x224x224
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    assert isinstance(
+        input_width, int
+    ), f"Unexpected type for input_width {type(input_width)}{config}"
+    assert isinstance(
+        input_height, int
+    ), f"Unexpected type for input_height {type(input_height)}{config}"
+
+    batch = "batch"
+    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "attention_mask": {0: batch, 1: seq_length},
+        "pixel_values": {
+            0: torch.export.Dim("batch_img", min=1, max=1024),
+            # 2: torch.export.Dim("width", min=1, max=4096),
+            # 3: torch.export.Dim("height", min=1, max=4096),
+        },
+    }
+    inputs = dict(
+        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
+            torch.int64
+        ),
+        attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
+        pixel_values=torch.randn(
+            batch_size_image, input_channels, input_width, input_height
+        ).clamp(-1, 1),
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + add_second_input,
+            input_width=input_width,
+            input_height=input_height,
+            input_channels=input_channels,
+            batch_size_image=batch_size_image + 1,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "vision_config", "text_config")
+        check_hasattr(config.vision_config, "image_size", "num_channels")
+        check_hasattr(config.text_config, "vocab_size")
+    kwargs = dict(
+        batch_size=2,
+        batch_size_image=3,
+        sequence_length=30,
+        dummy_max_token_id=(49408 if config is None else (config.text_config.vocab_size - 1)),
+        input_width=224 if config is None else config.vision_config.image_size,
+        input_height=224 if config is None else config.vision_config.image_size,
+        input_channels=3 if config is None else config.vision_config.num_channels,
+    )
+    return kwargs, get_inputs
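
The same calling convention applies here; note that the dynamic shapes mix plain string names ("batch", "seq_length") with a torch.export.Dim for the image batch. A minimal sketch, again without a configuration (49408 is the fallback vocabulary size random_input_kwargs uses when config is None):

    from onnx_diagnostic.tasks.zero_shot_image_classification import get_inputs

    data = get_inputs(
        model=None,                # not used by this task
        config=None,
        dummy_max_token_id=49408,  # token ids are drawn in [0, dummy_max_token_id)
        batch_size=2,
        sequence_length=30,
        batch_size_image=3,
    )
    print({k: v.shape for k, v in data["inputs"].items()})
    # {'input_ids': torch.Size([2, 30]), 'attention_mask': torch.Size([2, 30]),
    #  'pixel_values': torch.Size([3, 3, 224, 224])}
    print(data["dynamic_shapes"]["pixel_values"])  # the Dim('batch_img') entry; exact repr depends on the torch version

    # "inputs2" bumps every dynamic dimension (text batch, sequence length, image
    # batch), which is useful to verify that an exported model really is dynamic.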
onnx_diagnostic/torch_export_patches/__init__.py
@@ -0,0 +1,21 @@
+from .onnx_export_errors import (
+    torch_export_patches,
+    register_additional_serialization_functions,
+)
+from .patch_module import torch_export_rewrite
+
+
+# bypass_export_some_errors is the first name given to the patches.
+bypass_export_some_errors = torch_export_patches  # type: ignore
+
+
+def register_flattening_functions(verbose: int = 0):
+    """
+    Registers functions to serialize and deserialize caches or other classes
+    implemented in :epkg:`transformers` and used as inputs.
+    This is needed whenever a model must be exported through
+    :func:`torch.export.export`.
+    """
+    from .onnx_export_serialization import register_cache_serialization
+
+    return register_cache_serialization(verbose=verbose)
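
The names exported here are meant to wrap a call to torch.export.export. A minimal sketch with a toy module (the Tiny class is a hypothetical stand-in, and torch_export_patches is called with its defaults, since its keyword arguments live in onnx_export_errors.py rather than in this file):

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    class Tiny(torch.nn.Module):  # stand-in for a real model
        def forward(self, x):
            return x.sin() + 1

    x = torch.randn(2, 8)
    # Patches are applied only for the duration of the `with` block, then undone.
    with torch_export_patches():
        ep = torch.export.export(
            Tiny(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
        )
    print(ep.graph)

As the comment in the file notes, bypass_export_some_errors remains a backward-compatible alias for the same context manager, while register_flattening_functions only registers the transformers cache (de)serialization hooks without applying the other patches.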