onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
--- /dev/null
+++ b/onnx_diagnostic/reference/torch_ops/shape_ops.py
@@ -0,0 +1,121 @@
+from typing import Optional, Tuple
+import onnx
+import torch
+from . import OpRunKernel, OpRunTensor
+
+
+class ConstantOfShape_9(OpRunKernel):
+    "ConstantOfShape"
+
+    @classmethod
+    def device_dependent(cls) -> bool:
+        """
+        Returns True if the kernel needs a device to be efficiently initialized.
+        """
+        return True
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        device: Optional[torch.device] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        value = self.get_attribute_tensor(node, "value")
+        if value is None:
+            value = torch.tensor([0], dtype=torch.float32)
+        self.dtype = value.dtype
+        self.device = device
+        self.value = value[0]
+
+    def run(self, shape: OpRunTensor) -> OpRunTensor:
+        # The device is unknown here: shape tensors usually live on CPU.
+        return OpRunTensor(
+            torch.full(
+                shape.as_tuple_int, fill_value=self.value, dtype=self.dtype, device=self.device
+            )
+        )
+
+
+class Expand_8(OpRunKernel):
+    "Expand"
+
+    def run(self, data: OpRunTensor, shape: OpRunTensor) -> OpRunTensor:
+        ishape = tuple(-1 if i == 1 else i for i in shape.as_tuple_int)
+        return OpRunTensor(data.tensor.expand(ishape))
+
+
+class Reshape_14(OpRunKernel):
+    "Reshape"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        self.allowzero = self.get_attribute_int(node, "allowzero", 0)
+
+    def run(self, data: OpRunTensor, shape: OpRunTensor) -> OpRunTensor:
+        ishape = shape.as_tuple_int
+        assert ishape is not None, f"Unexpected return for shape={shape!r}"
+        if not self.allowzero and 0 in ishape:
+            xshape = data.tensor.shape
+            new_shape = []
+            for i, s in enumerate(ishape):
+                new_shape.append(xshape[i] if s == 0 else s)
+            return OpRunTensor(data.tensor.reshape(new_shape))
+        return OpRunTensor(data.tensor.reshape(ishape))
+
+
+class Shape_15(OpRunKernel):
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        self.start = self.get_attribute_int(node, "start", 0)
+        self.end = self.get_attribute_int(node, "end", None)
+
+    def run(self, data: OpRunTensor) -> OpRunTensor:
+        shape = data.shape
+        sh = shape[self.start :] if self.end is None else shape[self.start : self.end]
+        return OpRunTensor(torch.tensor(sh, dtype=torch.int64), is_constant=True)
+
+
+class Split_18(OpRunKernel):
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        self.axis = self.get_attribute_int(node, "axis", 0)
+        self.num_outputs = self.get_attribute_int(node, "num_outputs", None)
+
+    def run(
+        self, data: OpRunTensor, split: Optional[OpRunTensor] = None
+    ) -> Tuple[OpRunTensor, ...]:
+        if split is None:
+            assert isinstance(
+                self.num_outputs, int
+            ), f"Incompatibilities: split is None and num_outputs={self.num_outputs}"
+            size = data.tensor.shape[self.axis]
+            split_size = (
+                size // self.num_outputs
+                if size % self.num_outputs == 0
+                else size // self.num_outputs + 1
+            )
+            spl = torch.split(data.tensor, split_size, dim=self.axis)
+        else:
+            spl = torch.split(data.tensor, split.as_tuple_int, dim=self.axis)
+        return tuple(OpRunTensor(t) for t in spl)
+
+
+class Squeeze_13(OpRunKernel):
+    "Squeeze"
+
+    def run(self, data: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
+        if axes is None:
+            return OpRunTensor(data.tensor.squeeze())
+        return OpRunTensor(data.tensor.squeeze(axes.as_tuple_int))
+
+
+class Unsqueeze_13(OpRunKernel):
+    "Unsqueeze"
+
+    def run(self, data: OpRunTensor, axes: OpRunTensor) -> OpRunTensor:
+        t = data.tensor
+        for i in axes.as_tuple_int:
+            t = t.unsqueeze(i)
+        return OpRunTensor(t)
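These shape kernels can be exercised on their own, outside a full model evaluation. A minimal sketch for Reshape_14, assuming the kernels can be instantiated standalone with just a NodeProto (as their `__init__` signatures suggest) and that OpRunTensor wraps a plain torch.Tensor the way the code above uses it; this is illustrative usage, not the library's documented API::

    import onnx
    import torch
    from onnx_diagnostic.reference.torch_ops import OpRunTensor
    from onnx_diagnostic.reference.torch_ops.shape_ops import Reshape_14

    # Build the ONNX node the kernel is constructed from.
    node = onnx.helper.make_node("Reshape", ["data", "shape"], ["reshaped"], allowzero=0)
    kernel = Reshape_14(node)

    out = kernel.run(
        OpRunTensor(torch.arange(6, dtype=torch.float32)),
        OpRunTensor(torch.tensor([2, 3], dtype=torch.int64)),
    )
    print(out.tensor.shape)  # torch.Size([2, 3])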
--- /dev/null
+++ b/onnx_diagnostic/reference/torch_ops/unary_ops.py
@@ -0,0 +1,93 @@
+import torch
+from . import OpRunKernel, OpRunTensor
+
+
+class Abs_1(OpRunKernel):
+    """Abs"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(torch.abs(x.tensor))
+
+
+class Cos_1(OpRunKernel):
+    """Cos"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor.cos())
+
+
+class Erf_9(OpRunKernel):
+    """Erf"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor.erf())
+
+
+class Exp_1(OpRunKernel):
+    """Exp"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor.exp())
+
+
+class Identity_1(OpRunKernel):
+    "Identity"
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor)
+
+
+class IsNaN_9(OpRunKernel):
+    """IsNaN"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor.isnan())
+
+
+class Log_1(OpRunKernel):
+    """Log"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor.log())
+
+
+class Neg_1(OpRunKernel):
+    """Neg"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(-x.tensor)
+
+
+class Not_1(OpRunKernel):
+    """Not"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(~x.tensor)
+
+
+class Reciprocal_1(OpRunKernel):
+    """Reciprocal"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(1 / x.tensor)
+
+
+class Sigmoid_6(OpRunKernel):
+    """Sigmoid"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(torch.sigmoid(x.tensor))
+
+
+class Sin_1(OpRunKernel):
+    """Sin"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor.sin())
+
+
+class Sqrt_1(OpRunKernel):
+    """Sqrt"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor.sqrt())
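The unary kernels follow the same one-call pattern. A sketch with Sigmoid_6, under the same standalone-instantiation assumptions as above::

    import onnx
    import torch
    from onnx_diagnostic.reference.torch_ops import OpRunTensor
    from onnx_diagnostic.reference.torch_ops.unary_ops import Sigmoid_6

    kernel = Sigmoid_6(onnx.helper.make_node("Sigmoid", ["X"], ["Y"]))
    res = kernel.run(OpRunTensor(torch.tensor([0.0, 1.0])))
    print(res.tensor)  # tensor([0.5000, 0.7311])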
--- /dev/null
+++ b/onnx_diagnostic/tasks/__init__.py
@@ -0,0 +1,90 @@
+from typing import Any, Callable, Dict, List, Tuple
+from . import (
+    automatic_speech_recognition,
+    feature_extraction,
+    fill_mask,
+    image_classification,
+    image_text_to_text,
+    image_to_video,
+    mask_generation,
+    mixture_of_expert,
+    object_detection,
+    sentence_similarity,
+    summarization,
+    text_classification,
+    text_generation,
+    text_to_image,
+    text2text_generation,
+    zero_shot_image_classification,
+)
+
+__TASKS__ = [
+    automatic_speech_recognition,
+    feature_extraction,
+    fill_mask,
+    image_classification,
+    image_text_to_text,
+    image_to_video,
+    mask_generation,
+    mixture_of_expert,
+    object_detection,
+    sentence_similarity,
+    summarization,
+    text_classification,
+    text_generation,
+    text_to_image,
+    text2text_generation,
+    zero_shot_image_classification,
+]
+
+
+def supported_tasks() -> List[str]:
+    "Returns the list of supported tasks."
+    return sorted(mod.__TASK__ for mod in __TASKS__)
+
+
+def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
+    """Reduces a model size."""
+    head_size0 = (
+        config.head_dim
+        if hasattr(config, "head_dim") and config.head_dim
+        else (
+            config.hidden_size // config.num_attention_heads
+            if hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads")
+            else None
+        )
+    )
+    tasks = {mod.__TASK__: mod.reduce_model_config for mod in __TASKS__}
+    assert task in tasks, f"Task {task!r} not found in {sorted(tasks)}"
+    res = tasks[task](config)
+    if head_size0 and "head_dim" in res:
+        head_size = (
+            config.head_dim
+            if hasattr(config, "head_dim") and config.head_dim
+            else config.hidden_size // config.num_attention_heads
+        )
+        assert head_size0 == head_size or head_size % 16 == 0, (
+            f"head_size={head_size} should equal head_size0={head_size0} "
+            f"or be a multiple of 16, res={res}, "
+            f"config=\n{config}"
+        )
+    return res
+
+
+def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Returns input kwargs for a task.
+    If the configuration is None, the function selects typical dimensions.
+    It returns parameters and a function; calling the function with those
+    parameters produces dummy inputs.
+
+    .. code-block:: python
+
+        config = get_pretrained_config(model_id)
+        task = task_from_id(name)
+        kwargs, fct = random_input_kwargs(config, task)
+        res = fct(model, config, add_second_input=False, **kwargs)
+    """
+    tasks = {mod.__TASK__: mod.random_input_kwargs for mod in __TASKS__}
+    assert task in tasks, f"Task {task!r} not found in {sorted(tasks)}"
+    return tasks[task](config)
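A sketch of how this registry composes; it assumes a task module that tolerates config=None, which the feature-extraction module further below does::

    from onnx_diagnostic.tasks import random_input_kwargs, supported_tasks

    print(supported_tasks())  # sorted __TASK__ strings, e.g. "feature-extraction"

    # Dispatches to feature_extraction.random_input_kwargs(None); fct is that
    # module's get_inputs and accepts the returned kwargs back.
    kwargs, fct = random_input_kwargs(None, "feature-extraction")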
--- /dev/null
+++ b/onnx_diagnostic/tasks/automatic_speech_recognition.py
@@ -0,0 +1,188 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+import transformers
+from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
+
+__TASK__ = "automatic-speech-recognition"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    kwargs: Dict[str, Any] = {}
+    if hasattr(config, "num_decoder_layers"):
+        config.num_decoder_layers = min(config.num_decoder_layers, 2)
+    if hasattr(config, "decoder_layers"):
+        config.decoder_layers = min(config.decoder_layers, 2)
+    if hasattr(config, "num_hidden_layers"):
+        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    max_source_positions: int,
+    d_model: int,
+    num_hidden_layers: int,
+    encoder_attention_heads: int,
+    encoder_layers: int,
+    decoder_layers: int,
+    head_dim: int,
+    batch_size: int = 2,
+    sequence_length: int = 30,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``automatic-speech-recognition``.
+    Example:
+
+    ::
+
+        dict(
+            cache_position:T7s4,
+            past_key_values:EncoderDecoderCache(
+                self_attention_cache=DynamicCache[serialized](#2[#0[],#0[]]),
+                cross_attention_cache=DynamicCache[serialized](#2[#0[],#0[]])
+            ),
+            decoder_input_ids:T7s1x4,
+            encoder_outputs:BaseModelOutput(last_hidden_state:T1s1x1500x384),
+            use_cache:bool,return_dict:bool
+        )
+        dict(
+            cache_position:T7s1,
+            past_key_values:EncoderDecoderCache(
+                self_attention_cache=DynamicCache[serialized](#2[
+                    #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64],
+                    #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64]
+                ]),
+                cross_attention_cache=DynamicCache[serialized](#2[
+                    #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64],
+                    #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64]
+                ]),
+            ),
+            decoder_input_ids:T7s1x1,
+            encoder_outputs:BaseModelOutput(last_hidden_state:T1s1x1500x384),
+            use_cache:bool,return_dict:bool
+        )
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    batch = "batch"
+    seq_length = "seq_length"
+
+    shapes = {
+        "decoder_input_ids": {0: batch, 1: seq_length},
+        "cache_position": {0: seq_length},
+        "encoder_outputs": [{0: batch}],  # last_hidden_state
+        "past_key_values": [
+            [{0: batch} for _ in range(num_hidden_layers * 2)],
+            [{0: batch} for _ in range(num_hidden_layers * 2)],
+        ],
+    }
+    inputs = dict(
+        decoder_input_ids=torch.randint(
+            0, dummy_max_token_id, (batch_size, sequence_length)
+        ).to(torch.int64),
+        cache_position=(torch.arange(sequence_length) + 5).to(torch.int64),
+        encoder_outputs=transformers.modeling_outputs.BaseModelOutput(
+            last_hidden_state=torch.randn(batch_size, max_source_positions, d_model)
+        ),
+        past_key_values=make_encoder_decoder_cache(
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, encoder_attention_heads, encoder_layers, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, encoder_attention_heads, encoder_layers, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, encoder_attention_heads, max_source_positions, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, encoder_attention_heads, max_source_positions, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        ),
+        # one of these is selected based on the forward method signature
+        # encoder_last_hidden_state=torch.randn(batch_size, sequence_length2, encoder_dim),
+        # encoder_outputs=torch.randn(batch_size, sequence_length2, encoder_dim),
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            max_source_positions=max_source_positions,
+            d_model=d_model,
+            num_hidden_layers=num_hidden_layers,
+            encoder_attention_heads=encoder_attention_heads,
+            encoder_layers=encoder_layers,
+            decoder_layers=decoder_layers,
+            head_dim=head_dim,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + add_second_input,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Returns input kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(
+            config,
+            "d_model",
+            "decoder_attention_heads",
+            "decoder_layers",
+            "encoder_attention_heads",
+            "encoder_layers",
+            "max_source_positions",
+            "num_hidden_layers",
+            "vocab_size",
+        )
+    kwargs = dict(
+        batch_size=2,
+        sequence_length=30,
+        dummy_max_token_id=31000 if config is None else config.vocab_size,
+        max_source_positions=1500 if config is None else config.max_source_positions,
+        d_model=384 if config is None else config.d_model,
+        num_hidden_layers=4 if config is None else config.num_hidden_layers,
+        encoder_attention_heads=6 if config is None else config.encoder_attention_heads,
+        encoder_layers=4 if config is None else config.encoder_layers,
+        decoder_attention_heads=6 if config is None else config.decoder_attention_heads,
+        decoder_layers=4 if config is None else config.decoder_layers,
+        head_dim=(
+            64 if config is None else (config.d_model // config.encoder_attention_heads)
+        ),
+    )
+    return kwargs, get_inputs
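A sketch of driving this module directly with config=None, assuming the model argument is only threaded through (as the code above suggests) so None is acceptable for a dry run::

    from onnx_diagnostic.tasks import automatic_speech_recognition as asr

    kwargs, fct = asr.random_input_kwargs(None)  # typical dimensions
    res = fct(None, None, **kwargs)              # add_second_input=1 by default
    print(res["inputs"]["decoder_input_ids"].shape)   # torch.Size([2, 30])
    print(res["inputs2"]["decoder_input_ids"].shape)  # torch.Size([3, 31])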
--- /dev/null
+++ b/onnx_diagnostic/tasks/data/__init__.py
@@ -0,0 +1,13 @@
+import os
+
+
+def get_data(name: str):
+    """Returns data stored in this folder."""
+    filename = os.path.join(os.path.dirname(__file__), name)
+    assert os.path.exists(
+        filename
+    ), f"Unable to find a file with {name!r}, looked for {filename!r}"
+
+    from ...helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
+
+    return create_input_tensors_from_onnx_model(filename)
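A usage sketch; the only data file shipped in this release is the gemma3 dummy-input container listed in the file table above, and the .onnx file serves as a container of serialized dummy tensors::

    from onnx_diagnostic.tasks.data import get_data

    data = get_data("dummies_imagetext2text_generation_gemma3.onnx")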
--- /dev/null
+++ b/onnx_diagnostic/tasks/feature_extraction.py
@@ -0,0 +1,162 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
+from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
+
+
+__TASK__ = "feature-extraction"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    check_hasattr(config, "num_hidden_layers")
+    kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, nhl()))
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    batch_size: int,
+    sequence_length: int,
+    dummy_max_token_id: int,
+    sequence_length2: int = 3,
+    decoder_attention_heads: Optional[int] = None,
+    encoder_attention_heads: Optional[int] = None,
+    encoder_ffn_dim: Optional[int] = None,
+    decoder_ffn_dim: Optional[int] = None,
+    num_hidden_layers: Optional[int] = None,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``feature-extraction``.
+    Example:
+
+    ::
+
+        input_ids:T7s1x13[101,72654:A16789.23076923077],
+        token_type_ids:T7s1x13[0,0:A0.0],
+        attention_mask:T7s1x13[1,1:A1.0])
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    batch = "batch"
+    seq_length = "sequence_length"
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "attention_mask": {0: batch, 1: seq_length},
+    }
+    inputs = dict(
+        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
+            torch.int64
+        ),
+        attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
+    )
+    if (
+        encoder_attention_heads
+        and decoder_attention_heads
+        and encoder_ffn_dim
+        and decoder_ffn_dim
+        and num_hidden_layers
+    ):
+        inputs["past_key_values"] = make_encoder_decoder_cache(
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
+        cache_length = "cache_length_key"
+        cache_length2 = "cache_length_val"
+        shapes["past_key_values"] = [  # type: ignore[assignment]
+            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
+            [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
+        ]
+
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + add_second_input,
+            dummy_max_token_id=dummy_max_token_id,
+            sequence_length2=sequence_length2,
+            decoder_attention_heads=decoder_attention_heads,
+            encoder_attention_heads=encoder_attention_heads,
+            encoder_ffn_dim=encoder_ffn_dim,
+            decoder_ffn_dim=decoder_ffn_dim,
+            num_hidden_layers=num_hidden_layers,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Returns input kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "vocab_size")
+    kwargs = dict(
+        batch_size=2,
+        sequence_length=30,
+        dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
+    )
+    for att in [
+        "decoder_attention_heads",
+        "encoder_attention_heads",
+        "encoder_ffn_dim",
+        "decoder_ffn_dim",
+        "num_hidden_layers",
+    ]:
+        if hasattr(config, att):
+            kwargs[att] = getattr(config, att)
+    kwargs["decoder_ffn_dim"] = kwargs["encoder_ffn_dim"] = 64
+    return kwargs, get_inputs
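With config=None the cache branch above is skipped (the attention-head counts stay None), so only plain input_ids and attention_mask are produced; a minimal sketch, again assuming None is acceptable for the pass-through model and config arguments::

    from onnx_diagnostic.tasks import feature_extraction

    kwargs, fct = feature_extraction.random_input_kwargs(None)
    res = fct(None, None, add_second_input=0, **kwargs)
    print(res["inputs"]["input_ids"].shape)         # torch.Size([2, 30])
    print(res["dynamic_shapes"]["attention_mask"])  # {0: 'batch', 1: 'sequence_length'}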