onnx_diagnostic-0.6.0-py3-none-any.whl → onnx_diagnostic-0.6.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +18 -0
  3. onnx_diagnostic/api.py +15 -0
  4. onnx_diagnostic/ext_test_case.py +3 -1
  5. onnx_diagnostic/helpers/args_helper.py +1 -1
  6. onnx_diagnostic/helpers/doc_helper.py +143 -0
  7. onnx_diagnostic/helpers/helper.py +6 -5
  8. onnx_diagnostic/helpers/model_builder_helper.py +24 -8
  9. onnx_diagnostic/helpers/rt_helper.py +5 -1
  10. onnx_diagnostic/helpers/torch_helper.py +2 -0
  11. onnx_diagnostic/reference/__init__.py +1 -0
  12. onnx_diagnostic/reference/torch_evaluator.py +648 -0
  13. onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
  14. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  15. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  16. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  17. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  18. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  19. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  20. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  21. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  22. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  23. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  24. onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
  25. onnx_diagnostic/tasks/__init__.py +22 -1
  26. onnx_diagnostic/tasks/image_classification.py +2 -2
  27. onnx_diagnostic/tasks/text_generation.py +3 -3
  28. onnx_diagnostic/torch_export_patches/eval/__init__.py +106 -37
  29. onnx_diagnostic/torch_export_patches/eval/model_cases.py +12 -25
  30. onnx_diagnostic/torch_export_patches/patch_module_helper.py +130 -16
  31. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +88 -0
  32. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
  33. onnx_diagnostic/torch_models/test_helper.py +133 -16
  34. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  35. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/METADATA +1 -1
  36. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/RECORD +39 -23
  37. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/WHEEL +1 -1
  38. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/licenses/LICENSE.txt +0 -0
  39. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patch_module_helper.py

@@ -1,5 +1,6 @@
 import ast
-from typing import Any, List, Optional
+import functools
+from typing import Any, Dict, List, Optional
 
 
 class OrToBitOrTransformer(ast.NodeTransformer):
@@ -19,10 +20,129 @@ def ast_or_into_bitor(node: "ast.Node") -> "ast.Node":
     return new_node
 
 
-def _rewrite_bart_encoder_layer():
-    "BartEncoderLayer, PLBartEncoderLayer"
+@functools.lru_cache
+def _rewrite_forward_clamp_float16() -> Dict[str, List[type]]:
+
     import transformers
 
+    _known = {
+        "AutoformerEncoderLayer": [
+            transformers.models.autoformer.modeling_autoformer.AutoformerEncoderLayer
+        ],
+        "BartEncoderLayer": [
+            transformers.models.bart.modeling_bart.BartEncoderLayer,
+            transformers.models.plbart.modeling_plbart.PLBartEncoderLayer,
+        ],
+        "BigBirdPegasusEncoderLayer": [
+            transformers.models.bigbird_pegasus.modeling_bigbird_pegasus.BigBirdPegasusEncoderLayer
+        ],
+        "BlenderbotSmallEncoderLayer": [
+            transformers.models.blenderbot_small.modeling_blenderbot_small.BlenderbotSmallEncoderLayer
+        ],
+        "InformerEncoderLayer": [
+            transformers.models.informer.modeling_informer.InformerEncoderLayer
+        ],
+        "LEDEncoderLayer": [transformers.models.led.modeling_led.LEDEncoderLayer],
+        "MarianEncoderLayer": [transformers.models.marian.modeling_marian.MarianEncoderLayer],
+        "MvpEncoderLayer": [transformers.models.mvp.modeling_mvp.MvpEncoderLayer],
+        "NllbMoeEncoderLayer": [
+            transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeEncoderLayer
+        ],
+        "TimeSeriesTransformerEncoderLayer": [
+            transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerEncoderLayer
+        ],
+    }
+    return _known
+
+
+@functools.lru_cache
+def known_transformers_rewritings_clamp_float16() -> Dict[str, str]:
+    """
+    This functions returns the list of known classes to be rewritten.
+    in :epkg:`transformers`. Each class is mapped to an alias,
+    this alias is then given to :func:`rewritings_transformers_clamp_float16`
+    to rewrite the encoder layers because of a specific control flow.
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        from onnx_diagnostic.torch_export_patches.patch_module_helper import (
+            known_transformers_rewritings_clamp_float16,
+        )
+
+        pprint.pprint(known_transformers_rewritings_clamp_float16())
+    """
+    _alias = {
+        "AutoformerEncoder": "AutoformerEncoderLayer",
+        "AutoformerEncoderLayer": "AutoformerEncoderLayer",
+        "AutoformerForPrediction": "AutoformerEncoderLayer",
+        "AutoformerModel": "AutoformerEncoderLayer",
+        "BartEncoderLayer": "BartEncoderLayer",
+        "BartForConditionalGeneration": "BartEncoderLayer",
+        "BigBirdPegasusForConditionalGeneration": "BigBirdPegasusEncoderLayer",
+        "BigBirdPegasusForQuestionAnswering": "BigBirdPegasusEncoderLayer",
+        "BigBirdPegasusForCausalLM": "BigBirdPegasusEncoderLayer",
+        "BlenderbotSmallEncoderLayer": "BlenderbotSmallEncoderLayer",
+        "BlenderbotSmallForConditionalGeneration": "BlenderbotSmallEncoderLayer",
+        "BlenderbotSmallForCausalLM": "BlenderbotSmallEncoderLayer",
+        "InformerEncoderLayer": "InformerEncoderLayer",
+        "InformerForPrediction": "InformerEncoderLayer",
+        "LEDEncoderLayer": "LEDEncoderLayer",
+        "LEDClassificationHead": "LEDEncoderLayer",
+        "LEDForConditionalGeneration": "LEDEncoderLayer",
+        "MarianEncoderLayer": "MarianEncoderLayer",
+        "MarianEncoder": "MarianEncoderLayer",
+        "MarianModel": "MarianEncoderLayer",
+        "MarianMTModel": "MarianEncoderLayer",
+        "MvpEncoderLayer": "MvpEncoderLayer",
+        "MvpPrompt": "MvpEncoderLayer",
+        "MvpForConditionalGeneration": "MvpEncoderLayer",
+        "MvpForSequenceClassification": "MvpEncoderLayer",
+        "MvpForQuestionAnswering": "MvpEncoderLayer",
+        "MvpForCausalLM": "MvpEncoderLayer",
+        "NllbMoeEncoderLayer": "NllbMoeEncoderLayer",
+        "NllbMoeForConditionalGeneration": "NllbMoeEncoderLayer",
+        "PLBartEncoderLayer": "BartEncoderLayer",
+        "PLBartForConditionalGeneration": "BartEncoderLayer",
+        "TimeSeriesTransformerEncoderLayer": "TimeSeriesTransformerEncoderLayer",
+        "TimeSeriesTransformerForPrediction": "TimeSeriesTransformerEncoderLayer",
+    }
+    return _alias
+
+
+def rewritings_transformers_clamp_float16(cls_name) -> List[type]:
+    """
+    Rewrites known control flows equal to this:
+
+    .. code-block:: python
+
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+    *cls_name* is the class name. It is mapped with a list of other class names
+    to rename. Here is the known list:
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        from onnx_diagnostic.torch_export_patches.patch_module_helper import (
+            _rewrite_forward_clamp_float16,
+        )
+
+        pprint.pprint(_rewrite_forward_clamp_float16())
+
+    Function `_rewrite_forward_clamp_float16` collects
+    all model classes using those layers.
+    """
+    _known = _rewrite_forward_clamp_float16()
+
+    assert cls_name in _known, f"cls_name={cls_name!r} unknown in {sorted(_known)}."
+
     bd = dict(
         filter_node=(
             lambda node: isinstance(node, ast.If) and not isinstance(node.test, ast.Name)
@@ -35,16 +155,13 @@ def _rewrite_bart_encoder_layer():
         g["function"] = f
         return g
 
-    return [
-        _add(transformers.models.bart.modeling_bart.BartEncoderLayer.forward),
-        _add(transformers.models.plbart.modeling_plbart.PLBartEncoderLayer.forward),
-    ]
+    return [_add(cls.forward) for cls in _known[cls_name]]
 
 
 def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
     """
-    Returns a known list of methods or functions to rewrite because of control flow
-    for a specific model class.
+    Returns a known list of classes mapped to a known rewritings
+    because of control flow. See :func:`known_transformers_rewritings_clamp_float16`.
 
     :param cls_name: name of the class
     :return: a list of rewriting
@@ -59,11 +176,8 @@ def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
 
         pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
     """
-    if cls_name in {
-        "BartEncoderLayer",
-        "BartForConditionalGeneration",
-        "PLBartEncoderLayer",
-        "PLBartForConditionalGeneration",
-    }:
-        return _rewrite_bart_encoder_layer()
+    aliases = known_transformers_rewritings_clamp_float16()
+    if cls_name in aliases:
+        alias = aliases[cls_name]
+        return rewritings_transformers_clamp_float16(alias)
     return None
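
This hunk generalizes the former Bart-only rewriting to every transformers encoder layer that shares the same float16 clamp control flow. A minimal usage sketch of the new lookup chain, assuming transformers is installed with the models listed above; the calls mirror the runpython examples embedded in the diff:

    import pprint
    from onnx_diagnostic.torch_export_patches.patch_module_helper import (
        code_needing_rewriting,
        known_transformers_rewritings_clamp_float16,
    )

    # Model classes known to contain the float16 clamp branch, each mapped
    # to the encoder-layer alias that selects the rewriting.
    aliases = known_transformers_rewritings_clamp_float16()
    print(aliases["BartForConditionalGeneration"])  # "BartEncoderLayer"

    # Returns the rewriting descriptors for a known class, None otherwise.
    pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
    print(code_needing_rewriting("NotARealModel"))  # None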
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -1,5 +1,6 @@
 import inspect
 from dataclasses import dataclass
+from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 import transformers
@@ -531,3 +532,90 @@ class patched_GenerationMixin:
         # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
         model_inputs.pop("labels", None)
         return model_inputs
+
+
+def patched_dynamic_rope_update(rope_forward):
+    """
+    patch:transformers.modeling_rope_utils.dynamic_rope_update
+    """
+
+    def longrope_frequency_update(self, position_ids, device):
+        seq_len = torch.max(position_ids) + 1
+        if hasattr(self.config, "original_max_position_embeddings"):
+            original_max_position_embeddings = self.config.original_max_position_embeddings
+        else:
+            original_max_position_embeddings = self.config.max_position_embeddings
+        # At export time, seq_len is unknown.
+        long_inv_freq, _ = self.rope_init_fn(
+            self.config, device, seq_len=original_max_position_embeddings + 1
+        )
+        original_inv_freq = self.original_inv_freq.to(device)
+
+        cond = (seq_len > original_max_position_embeddings).item()
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq, original_inv_freq],
+        )
+        self.inv_freq = inv_freq
+        # if seq_len > original_max_position_embeddings:
+        #     self.inv_freq = self.long_inv_freq
+        # else:
+        #     self.inv_freq = self.original_inv_freq
+
+    def dynamic_frequency_update(self, position_ids, device):
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
+            self.max_seq_len_cached = seq_len
+
+        if (
+            seq_len < self.original_max_seq_len
+            and self.max_seq_len_cached > self.original_max_seq_len
+        ):
+            self.original_inv_freq = self.original_inv_freq.to(device)
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
+
+    @wraps(rope_forward)
+    def wrapper(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            dynamic_frequency_update(self, position_ids, device=x.device)
+        elif self.rope_type == "longrope":
+            longrope_frequency_update(self, position_ids, device=x.device)
+        return rope_forward(self, x, position_ids)
+
+    return wrapper
+
+
+class patched_Phi3RotaryEmbedding(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
+
+    @torch.no_grad()
+    @patched_dynamic_rope_update
+    def forward(self, x, position_ids):
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None]
+            .float()
+            .expand(position_ids.shape[0], -1, 1)
+            .to(x.device)
+        )
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = (
+            x.device.type
+            if isinstance(x.device.type, str) and x.device.type != "mps"
+            else "cpu"
+        )
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
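
The key change in `patched_dynamic_rope_update` is replacing a data-dependent Python `if` (which `torch.export` cannot trace through) with `torch.cond`, so both candidate frequency tensors are computed and the choice stays in the exported graph. A standalone sketch of that pattern; the shapes and the 4096 threshold are illustrative, not taken from a real config:

    import torch

    long_inv_freq = torch.rand(16)
    original_inv_freq = torch.rand(16)
    position_ids = torch.arange(5000).unsqueeze(0)
    seq_len = torch.max(position_ids) + 1

    # Eager equivalent: inv_freq = long_inv_freq if seq_len > 4096 else original_inv_freq.
    # torch.cond traces both branches, so the exported graph keeps the runtime choice.
    inv_freq = torch.cond(
        (seq_len > 4096).item(),
        lambda x, y: x.clone(),  # taken when the condition is True
        lambda x, y: y.clone(),  # taken when the condition is False
        [long_inv_freq, original_inv_freq],
    )
    assert inv_freq.shape == (16,)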
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

@@ -3951,3 +3951,145 @@ def _ccached_facebook_bart_large_cnn():
             "vocab_size": 50264,
         }
     )
+
+
+def _ccached_microsoft_phi4_reasoning():
+    "microsoft/Phi-4-mini-reasoning"
+    return transformers.Phi3Config(
+        **{
+            "architectures": ["Phi3ForCausalLM"],
+            "attention_bias": false,
+            "attention_dropout": 0.0,
+            "bos_token_id": 199999,
+            "embd_pdrop": 0.0,
+            "eos_token_id": 199999,
+            "full_attn_mod": 1,
+            "hidden_act": "silu",
+            "hidden_size": 3072,
+            "initializer_range": 0.02,
+            "intermediate_size": 8192,
+            "interpolate_factor": 1,
+            "lm_head_bias": false,
+            "max_position_embeddings": 131072,
+            "mlp_bias": false,
+            "model_type": "phi3",
+            "num_attention_heads": 24,
+            "num_hidden_layers": 32,
+            "num_key_value_heads": 8,
+            "original_max_position_embeddings": 4096,
+            "pad_token_id": 199999,
+            "partial_rotary_factor": 0.75,
+            "resid_pdrop": 0.0,
+            "rms_norm_eps": 1e-05,
+            "rope_scaling": {
+                "long_factor": [
+                    1,
+                    1.118320672,
+                    1.250641126,
+                    1.398617824,
+                    1.564103225,
+                    1.74916897,
+                    1.956131817,
+                    2.187582649,
+                    2.446418898,
+                    2.735880826,
+                    3.059592084,
+                    3.421605075,
+                    3.826451687,
+                    4.279200023,
+                    4.785517845,
+                    5.351743533,
+                    5.984965424,
+                    6.693110555,
+                    7.485043894,
+                    8.370679318,
+                    9.36110372,
+                    10.4687158,
+                    11.70738129,
+                    13.09260651,
+                    14.64173252,
+                    16.37415215,
+                    18.31155283,
+                    20.47818807,
+                    22.90118105,
+                    25.61086418,
+                    28.64115884,
+                    32.03,
+                    32.1,
+                    32.13,
+                    32.23,
+                    32.6,
+                    32.61,
+                    32.64,
+                    32.66,
+                    32.7,
+                    32.71,
+                    32.93,
+                    32.97,
+                    33.28,
+                    33.49,
+                    33.5,
+                    44.16,
+                    47.77,
+                ],
+                "short_factor": [
+                    1,
+                    1.118320672,
+                    1.250641126,
+                    1.398617824,
+                    1.564103225,
+                    1.74916897,
+                    1.956131817,
+                    2.187582649,
+                    2.446418898,
+                    2.735880826,
+                    3.059592084,
+                    3.421605075,
+                    3.826451687,
+                    4.279200023,
+                    4.785517845,
+                    5.351743533,
+                    5.984965424,
+                    6.693110555,
+                    7.485043894,
+                    8.370679318,
+                    9.36110372,
+                    10.4687158,
+                    11.70738129,
+                    13.09260651,
+                    14.64173252,
+                    16.37415215,
+                    18.31155283,
+                    20.47818807,
+                    22.90118105,
+                    25.61086418,
+                    28.64115884,
+                    32.03,
+                    32.1,
+                    32.13,
+                    32.23,
+                    32.6,
+                    32.61,
+                    32.64,
+                    32.66,
+                    32.7,
+                    32.71,
+                    32.93,
+                    32.97,
+                    33.28,
+                    33.49,
+                    33.5,
+                    44.16,
+                    47.77,
+                ],
+                "type": "longrope",
+            },
+            "rope_theta": 10000.0,
+            "sliding_window": 262144,
+            "tie_word_embeddings": true,
+            "torch_dtype": "bfloat16",
+            "transformers_version": "4.50.0",
+            "use_cache": true,
+            "vocab_size": 200064,
+        }
+    )
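
This cached config lets the test helpers build the microsoft/Phi-4-mini-reasoning architecture without querying the Hugging Face hub. A minimal sketch of what it yields, assuming the private helper keeps the name shown in this diff:

    from onnx_diagnostic.torch_models.hghub.hub_data_cached_configs import (
        _ccached_microsoft_phi4_reasoning,
    )

    config = _ccached_microsoft_phi4_reasoning()
    print(type(config).__name__)        # Phi3Config
    print(config.rope_scaling["type"])  # longrope, the path patched above
    print(config.num_hidden_layers)     # 32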