onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,143 @@
1
+ from typing import Any, Callable, Dict, Optional, Tuple
2
+ import torch
3
+ from ..helpers.config_helper import (
4
+ update_config,
5
+ check_hasattr,
6
+ default_num_hidden_layers as nhl,
7
+ )
8
+
9
+ __TASK__ = "mask-generation"
10
+
11
+
12
def reduce_model_config(config: Any) -> Dict[str, Any]:
    """Reduces a model size by capping the layer counts in place."""
    kwargs: Dict[str, Any] = {}
    # Cap the main stack at the project-wide default number of layers.
    if hasattr(config, "num_hidden_layers"):
        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
    # The vision tower is reduced more aggressively (2 layers).
    vision = getattr(config, "vision_config", None)
    if vision is not None and hasattr(vision, "num_hidden_layers"):
        vision.num_hidden_layers = min(vision.num_hidden_layers, 2)
    # kwargs stays empty here: the config is mutated directly above.
    update_config(config, kwargs)
    return kwargs
21
+
22
+
23
+ def get_inputs(
24
+ model: torch.nn.Module,
25
+ config: Optional[Any],
26
+ batch_size: int,
27
+ width: int,
28
+ height: int,
29
+ num_channels: int,
30
+ output_channels: int,
31
+ window_size: int,
32
+ add_second_input: bool = True,
33
+ **kwargs, # unused
34
+ ):
35
+ """
36
+ Generates input for task ``mask-generation``.
37
+
38
+ :param model: model to get the missing information
39
+ :param config: configuration used to generate the model
40
+ :param batch_size: batch size
41
+ :param width: width of the image
42
+ :param height: height of the image
43
+ :param num_channels: number of channels in the image
44
+ :param output_channels: number of output channels
45
+ :param window_size: size of the window for the vision model
46
+ :return: dictionary with inputs and dynamic shapes
47
+
48
+ """
49
+ assert (
50
+ "cls_cache" not in kwargs
51
+ ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
52
+
53
+ # TODO(anyone): input_masks is weirdly failing all the time with mismatch channels
54
+ # with Conv or embedding_size. I guess maybe the model is too implicit on the
55
+ # input_masks shape.
56
+
57
+ # TODO(titaiwang): modeling code specifically requires the height and width of inputs
58
+ # should be the same as the config.vision_config.image_size. Does that make sense?
59
+
60
+ shapes = {
61
+ "pixel_values": {0: "batch"}, # 1: num_channels is static
62
+ "input_points": {0: "batch", 1: "point_batch_size", 2: "nb_points_per_image"},
63
+ "input_boxes": {0: "batch", 1: "point_batch_size"},
64
+ # "input_masks": {0: "batch", 2: "height", 3: "width"},
65
+ }
66
+ inputs = dict(
67
+ pixel_values=torch.randn(
68
+ (batch_size, num_channels, height, width), dtype=torch.float32
69
+ ).clamp(-1, 1),
70
+ input_points=torch.randn(
71
+ (batch_size, 2, 10, 2), dtype=torch.float32
72
+ ), # 10 points per image
73
+ input_boxes=torch.randn((batch_size, 2, 4), dtype=torch.float32), # 1 box per image
74
+ # input_masks=torch.randn(
75
+ # (batch_size, 1, height, width), dtype=torch.float32
76
+ # ), # mask for the image
77
+ )
78
+
79
+ res = dict(inputs=inputs, dynamic_shapes=shapes)
80
+ if add_second_input:
81
+ assert (
82
+ add_second_input > 0
83
+ ), f"Not implemented for add_second_input={add_second_input}."
84
+ res["inputs2"] = get_inputs(
85
+ model=model,
86
+ config=config,
87
+ batch_size=batch_size + 1,
88
+ width=width,
89
+ height=height,
90
+ num_channels=num_channels,
91
+ output_channels=output_channels,
92
+ window_size=window_size,
93
+ add_second_input=False,
94
+ **kwargs,
95
+ )["inputs"]
96
+ return res
97
+
98
+
99
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.
    """
    if config is not None:
        # generates mask as outputs
        # Each sub-configuration, when present, must expose these attributes.
        required = {
            "mask_decoder_config": (
                "hidden_size",
                "iou_head_hidden_dim",
                "iou_head_depth",
                "num_hidden_layers",
                "num_multimask_outputs",
            ),
            "prompt_encoder_config": (
                "hidden_size",
                "image_embedding_size",
                "image_size",
                "mask_input_channels",
            ),
            "vision_config": (
                "image_size",
                "hidden_size",
                "intermediate_size",
                "num_hidden_layers",
                "output_channels",
                "num_channels",
                "window_size",
            ),
        }
        for sub_name, attrs in required.items():
            if hasattr(config, sub_name):
                check_hasattr(getattr(config, sub_name), *attrs)
    if config is None:
        kwargs = dict(
            batch_size=2,
            width=1024,
            height=1024,
            num_channels=3,
            output_channels=256,
            window_size=14,
        )
    else:
        vision = config.vision_config
        kwargs = dict(
            batch_size=2,
            width=vision.image_size,
            height=vision.image_size,
            num_channels=vision.num_channels,
            output_channels=vision.output_channels,
            window_size=vision.window_size,
        )
    return kwargs, get_inputs
@@ -0,0 +1,79 @@
1
+ from typing import Any, Callable, Dict, Optional, Tuple
2
+ import torch
3
+
4
+ # from ..helpers.cache_helper import make_dynamic_cache
5
+ from ..helpers.config_helper import update_config, default_num_hidden_layers as nhl
6
+
7
+ __TASK__ = "MoE"
8
+
9
+
10
def reduce_model_config(config: Any) -> Dict[str, Any]:
    """Reduces a model size by capping layer counts on every sub-configuration."""
    kwargs: Dict[str, Any] = {}
    if hasattr(config, "num_hidden_layers"):
        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
    # Vision tower: cap at 2 layers when present.
    vision = getattr(config, "vision_config", None)
    if vision is not None and hasattr(vision, "num_hidden_layers"):
        vision.num_hidden_layers = min(vision.num_hidden_layers, 2)
    # Audio processor: cap layers and attention width at 2 when present.
    audio = getattr(config, "audio_processor", None)
    if audio is not None:
        if hasattr(audio, "num_hidden_layers"):
            audio.num_hidden_layers = min(audio.num_hidden_layers, 2)
        if hasattr(audio, "attention_dim"):
            audio.attention_dim = min(audio.attention_dim, 2)
    # kwargs stays empty here: the config is mutated directly above.
    update_config(config, kwargs)
    return kwargs
27
+
28
+
29
def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    num_key_value_heads: int,
    num_hidden_layers: int,
    head_dim: int,
    width: int,
    height: int,
    num_channels: int,
    batch_size: int = 2,
    sequence_length: int = 30,
    sequence_length2: int = 3,
    n_images: int = 2,
    dynamic_rope: bool = False,
    add_second_input: int = 1,
    **kwargs,  # unused
):
    """
    Generates input for task ``MoE``.

    :param model: model to get the missing information
    :param config: configuration used to generate the model
    :param dummy_max_token_id: dummy max token id
    :param num_key_value_heads: number of key/value heads
    :param num_hidden_layers: number of hidden layers
    :param head_dim: last dimension of the cache
    :param width: width of the image
    :param height: height of the image
    :param num_channels: number of channels
    :param batch_size: batch size
    :param sequence_length: sequence length
    :param sequence_length2: new sequence length
    :param n_images: number of images
    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
    :param add_second_input: must be falsy, a second input set is not implemented
    :return: dictionary
    :raises NotImplementedError: always, inputs for this task are not implemented yet
    """
    # Guard the unsupported options first so callers get a precise message.
    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
    assert not add_second_input, "add_second_input=True not yet implemented"
    raise NotImplementedError(f"get_inputs not yet implemented for task {__TASK__!r}.")
69
+
70
+
71
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.

    :raises NotImplementedError: always, this task is not implemented yet
    """
    msg = f"random_input_kwargs not yet implemented for task {__TASK__!r}."
    raise NotImplementedError(msg)
@@ -0,0 +1,134 @@
1
+ from typing import Any, Callable, Dict, Optional, Tuple
2
+ import torch
3
+ from ..helpers.config_helper import (
4
+ update_config,
5
+ check_hasattr,
6
+ default_num_hidden_layers as nhl,
7
+ )
8
+
9
+ __TASK__ = "object-detection"
10
+
11
+
12
def reduce_model_config(config: Any) -> Dict[str, Any]:
    """Reduces a model size by limiting the number of hidden layers."""
    # At least one of the two attributes must exist on the configuration.
    check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
    if hasattr(config, "num_hidden_layers"):
        n_layers = min(config.num_hidden_layers, nhl())
    else:
        # Fall back on the number of stages declared through hidden_sizes.
        n_layers = len(config.hidden_sizes)
    kwargs = dict(num_hidden_layers=n_layers)
    update_config(config, kwargs)
    return kwargs
24
+
25
+
26
+ def get_inputs(
27
+ model: torch.nn.Module,
28
+ config: Optional[Any],
29
+ input_width: int,
30
+ input_height: int,
31
+ input_channels: int,
32
+ batch_size: int = 2,
33
+ dynamic_rope: bool = False,
34
+ add_second_input: int = 1,
35
+ **kwargs, # unused
36
+ ):
37
+ """
38
+ Generates inputs for task ``object-detection``.
39
+
40
+ :param model: model to get the missing information
41
+ :param config: configuration used to generate the model
42
+ :param batch_size: batch size
43
+ :param input_channels: input channel
44
+ :param input_width: input width
45
+ :param input_height: input height
46
+ :return: dictionary
47
+ """
48
+ assert (
49
+ "cls_cache" not in kwargs
50
+ ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
51
+ assert isinstance(
52
+ input_width, int
53
+ ), f"Unexpected type for input_width {type(input_width)}{config}"
54
+ assert isinstance(
55
+ input_width, int
56
+ ), f"Unexpected type for input_height {type(input_height)}{config}"
57
+
58
+ shapes = {
59
+ "pixel_values": {
60
+ 0: torch.export.Dim("batch", min=1, max=1024),
61
+ 2: "width",
62
+ 3: "height",
63
+ }
64
+ }
65
+ inputs = dict(
66
+ pixel_values=torch.randn(batch_size, input_channels, input_width, input_height).clamp(
67
+ -1, 1
68
+ ),
69
+ )
70
+ res = dict(inputs=inputs, dynamic_shapes=shapes)
71
+ if add_second_input:
72
+ assert (
73
+ add_second_input > 0
74
+ ), f"Not implemented for add_second_input={add_second_input}."
75
+ res["inputs2"] = get_inputs(
76
+ model=model,
77
+ config=config,
78
+ input_width=input_width + add_second_input,
79
+ input_height=input_height + add_second_input,
80
+ input_channels=input_channels,
81
+ batch_size=batch_size + 1,
82
+ dynamic_rope=dynamic_rope,
83
+ add_second_input=0,
84
+ **kwargs,
85
+ )["inputs"]
86
+ return res
87
+
88
+
89
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.
    """
    if (
        config is not None
        and getattr(config, "model_type", None) == "timm_wrapper"
        and not hasattr(config, "num_hidden_layers")
    ):
        # timm wrappers expose the geometry through pretrained_cfg["input_size"],
        # laid out as (channels, width, height).
        size = config.pretrained_cfg["input_size"]
        kwargs = dict(
            batch_size=2,
            input_width=size[-2],
            input_height=size[-1],
            input_channels=size[-3],
        )
        return kwargs, get_inputs

    check_hasattr(config, ("image_size", "architectures"), "num_channels")
    if config is None:
        # Typical dimensions for an image model.
        return (
            dict(batch_size=2, input_width=224, input_height=224, input_channels=3),
            get_inputs,
        )
    if hasattr(config, "image_size"):
        image_size = config.image_size
    else:
        assert config.architectures, f"empty architecture in {config}"
        from ..torch_models.hghub.hub_api import get_architecture_default_values

        image_size = get_architecture_default_values(config.architectures[0])["image_size"]
    if isinstance(image_size, int):
        kwargs = dict(
            batch_size=2,
            input_width=image_size,
            input_height=image_size,
            input_channels=config.num_channels,
        )
    else:
        kwargs = dict(
            batch_size=2,
            input_width=config.image_size[0],
            input_height=config.image_size[1],
            input_channels=config.num_channels,
        )
    return kwargs, get_inputs
@@ -0,0 +1,89 @@
1
+ from typing import Any, Callable, Dict, Optional, Tuple
2
+ import torch
3
+ from ..helpers.config_helper import (
4
+ update_config,
5
+ check_hasattr,
6
+ default_num_hidden_layers as nhl,
7
+ )
8
+
9
+ __TASK__ = "sentence-similarity"
10
+
11
+
12
def reduce_model_config(config: Any) -> Dict[str, Any]:
    """Reduces a model size by capping layers and attention heads."""
    check_hasattr(config, "num_attention_heads", "num_hidden_layers")
    kwargs = {
        "num_hidden_layers": min(config.num_hidden_layers, nhl()),
        "num_attention_heads": min(config.num_attention_heads, 4),
    }
    update_config(config, kwargs)
    return kwargs
21
+
22
+
23
def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    batch_size: int,
    sequence_length: int,
    dummy_max_token_id: int,
    add_second_input: int = 1,
    **kwargs,  # unused
):
    """
    Generates inputs for task ``sentence-similarity``.
    Example:

    ::

        input_ids:T7s1x13[101,72654:A16789.23076923077],
        token_type_ids:T7s1x13[0,0:A0.0],
        attention_mask:T7s1x13[1,1:A1.0])
    """
    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
    # All three inputs share the same two dynamic dimensions.
    dyn = {0: "batch", 1: "seq_length"}
    shapes = {
        name: dict(dyn)
        for name in ("input_ids", "token_type_ids", "attention_mask")
    }
    inputs = dict(
        input_ids=torch.randint(
            0, dummy_max_token_id, (batch_size, sequence_length)
        ).to(torch.int64),
        token_type_ids=torch.zeros((batch_size, sequence_length)).to(torch.int64),
        attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
    )
    res = dict(inputs=inputs, dynamic_shapes=shapes)
    if add_second_input:
        assert (
            add_second_input > 0
        ), f"Not implemented for add_second_input={add_second_input}."
        res["inputs2"] = get_inputs(
            model=model,
            config=config,
            batch_size=batch_size + 1,
            sequence_length=sequence_length + add_second_input,
            dummy_max_token_id=dummy_max_token_id,
            add_second_input=0,
            **kwargs,
        )["inputs"]
    return res
74
+
75
+
76
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.
    """
    if config is not None:
        check_hasattr(config, "vocab_size")
    max_token = 31999 if config is None else config.vocab_size - 1
    kwargs = dict(
        batch_size=2,
        sequence_length=30,
        dummy_max_token_id=max_token,
    )
    return kwargs, get_inputs
@@ -0,0 +1,227 @@
1
+ from typing import Any, Callable, Dict, Optional, Tuple
2
+ import torch
3
+ from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
4
+ from ..helpers.config_helper import (
5
+ update_config,
6
+ check_hasattr,
7
+ _pick,
8
+ default_num_hidden_layers as nhl,
9
+ )
10
+
11
+ __TASK__ = "summarization"
12
+
13
+
14
def reduce_model_config(config: Any) -> Dict[str, Any]:
    """Reduces a model size by capping the decoder and hidden layer counts."""
    kwargs: Dict[str, Any] = {}
    # The decoder stack is capped harder (2 layers) than the main stack.
    if hasattr(config, "num_decoder_layers"):
        config.num_decoder_layers = min(config.num_decoder_layers, 2)
    if hasattr(config, "num_hidden_layers"):
        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
    # kwargs stays empty here: the config is mutated directly above.
    update_config(config, kwargs)
    return kwargs
23
+
24
+
25
def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    num_key_value_heads_encoder: int,
    num_key_value_heads_decoder: int,
    num_hidden_layers: int,
    head_dim_encoder: int,
    head_dim_decoder: int,
    batch_size: int = 2,
    sequence_length: int = 30,
    sequence_length2: int = 3,
    add_second_input: int = 1,
    **kwargs,  # unused
):
    """
    Generates input for task ``summarization``.

    :param model: model to get the missing information
    :param config: configuration used to generate the model
    :param dummy_max_token_id: dummy max token id
    :param num_key_value_heads_encoder: number of heads for the encoder
    :param num_key_value_heads_decoder: number of heads for the decoder
    :param num_hidden_layers: number of layers, each one contributes a
        (key, value) pair to both caches
    :param head_dim_encoder: last dimension of the cache for the encoder
    :param head_dim_decoder: last dimension of the cache for the decoder
    :param batch_size: batch size
    :param sequence_length: sequence length
    :param sequence_length2: new sequence length
    :param add_second_input: when positive, also returns ``inputs2`` built with
        a bigger batch and longer sequences
    :return: dictionary

    Stolen inputs for one model.

    ::

        cache_position:T7s1
        past_key_values:EncoderDecoderCache(
            self_attention_cache=DynamicCache(
                key_cache=#6[T1s1x8x1x64,...],
                value_cache=#6[T1s1x8x1x64,...]),
            cross_attention_cache=DynamicCache(
                key_cache=#6[T1s1x8x16x64,...],
                value_cache=#6[T1s1x8x16x64,...])),
        decoder_input_ids:T7s1x1,
        encoder_outputs:dict(last_hidden_state:T1s1x16x512)
    """
    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."

    def _layer_cache(n_heads: int, seq: int, head_dim: int):
        # One cache holding ``num_hidden_layers`` random (key, value) pairs.
        return make_dynamic_cache(
            [
                (
                    torch.randn(batch_size, n_heads, seq, head_dim),
                    torch.randn(batch_size, n_heads, seq, head_dim),
                )
                for _ in range(num_hidden_layers)
            ]
        )

    batch = "batch"
    shapes = {
        "input_ids": {0: batch, 1: "seq_length"},
        "decoder_input_ids": {0: batch, 1: "seq_ids"},
        "attention_mask": {0: batch, 1: "seq_mask"},
        # self-attention entries grow along "cache_length_key",
        # cross-attention entries along "cache_length_val"
        "past_key_values": [
            [{0: batch, 2: "cache_length_key"} for _ in range(num_hidden_layers * 2)],
            [{0: batch, 2: "cache_length_val"} for _ in range(num_hidden_layers * 2)],
        ],
    }

    inputs = dict(
        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
            torch.int64
        ),
        decoder_input_ids=torch.randint(
            0, dummy_max_token_id, (batch_size, sequence_length2)
        ).to(torch.int64),
        attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
        past_key_values=make_encoder_decoder_cache(
            _layer_cache(num_key_value_heads_encoder, sequence_length, head_dim_encoder),
            _layer_cache(num_key_value_heads_decoder, sequence_length2, head_dim_decoder),
        ),
    )
    res = dict(inputs=inputs, dynamic_shapes=shapes)
    if add_second_input:
        assert (
            add_second_input > 0
        ), f"Not implemented for add_second_input={add_second_input}."
        res["inputs2"] = get_inputs(
            model=model,
            config=config,
            dummy_max_token_id=dummy_max_token_id,
            num_key_value_heads_encoder=num_key_value_heads_encoder,
            num_key_value_heads_decoder=num_key_value_heads_decoder,
            num_hidden_layers=num_hidden_layers,
            head_dim_encoder=head_dim_encoder,
            head_dim_decoder=head_dim_decoder,
            batch_size=batch_size + 1,
            sequence_length=sequence_length + add_second_input,
            sequence_length2=sequence_length2 + 1,
            add_second_input=0,
            **kwargs,
        )["inputs"]
    return res
165
+
166
+
167
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.
    """
    if config is not None:
        check_hasattr(
            config,
            "vocab_size",
            "hidden_size",
            "num_attention_heads",
            ("num_hidden_layers", "num_layers"),
            ("n_positions", "d_model"),
            (
                "num_key_value_heads",
                "num_heads",
                ("decoder_attention_heads", "encoder_attention_heads"),
            ),
        )
    if config is None:
        # Typical dimensions when no configuration is available.
        kwargs = dict(
            batch_size=2,
            sequence_length=30,
            sequence_length2=3,
            head_dim_encoder=16,
            head_dim_decoder=16,
            dummy_max_token_id=31999,
            num_hidden_layers=8,
            num_key_value_heads_encoder=16,
            num_key_value_heads_decoder=16,
        )
    else:
        kwargs = dict(
            batch_size=2,
            sequence_length=30,
            sequence_length2=3,
            # head dims derived from the feed-forward widths
            head_dim_encoder=int(_pick(config, "encoder_ffn_dim") ** 0.5),
            head_dim_decoder=int(_pick(config, "decoder_ffn_dim") ** 0.5),
            dummy_max_token_id=config.vocab_size - 1,
            num_hidden_layers=_pick(config, "num_hidden_layers", "num_layers"),
            num_key_value_heads_encoder=_pick(
                config,
                "encoder_attention_heads",
                "num_key_value_heads",
                "num_heads",
            ),
            num_key_value_heads_decoder=_pick(
                config,
                "decoder_attention_heads",
                "num_key_value_heads",
                "num_heads",
            ),
        )
    return kwargs, get_inputs