onnx-diagnostic 0.5.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. onnx_diagnostic/__init__.py +2 -2
  2. onnx_diagnostic/_command_lines_parser.py +39 -1
  3. onnx_diagnostic/api.py +15 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +14 -5
  5. onnx_diagnostic/ext_test_case.py +15 -1
  6. onnx_diagnostic/helpers/args_helper.py +1 -1
  7. onnx_diagnostic/helpers/graph_helper.py +386 -0
  8. onnx_diagnostic/helpers/helper.py +30 -5
  9. onnx_diagnostic/helpers/model_builder_helper.py +349 -0
  10. onnx_diagnostic/helpers/rt_helper.py +69 -1
  11. onnx_diagnostic/helpers/torch_helper.py +2 -0
  12. onnx_diagnostic/reference/__init__.py +1 -0
  13. onnx_diagnostic/reference/torch_evaluator.py +518 -0
  14. onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
  15. onnx_diagnostic/reference/torch_ops/_op_run.py +326 -0
  16. onnx_diagnostic/reference/torch_ops/access_ops.py +84 -0
  17. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  18. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +118 -0
  19. onnx_diagnostic/reference/torch_ops/generator_ops.py +35 -0
  20. onnx_diagnostic/reference/torch_ops/nn_ops.py +176 -0
  21. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  22. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  23. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  24. onnx_diagnostic/reference/torch_ops/shape_ops.py +120 -0
  25. onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
  26. onnx_diagnostic/tasks/__init__.py +22 -1
  27. onnx_diagnostic/tasks/image_classification.py +2 -2
  28. onnx_diagnostic/tasks/text_generation.py +3 -3
  29. onnx_diagnostic/torch_export_patches/eval/__init__.py +690 -0
  30. onnx_diagnostic/torch_export_patches/eval/model_cases.py +883 -0
  31. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +34 -1
  32. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +6 -1
  33. onnx_diagnostic/torch_export_patches/patch_module_helper.py +148 -28
  34. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +91 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +117 -1
  36. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
  37. onnx_diagnostic/torch_models/test_helper.py +225 -22
  38. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  39. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/METADATA +1 -1
  40. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/RECORD +43 -24
  41. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/WHEEL +1 -1
  42. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/licenses/LICENSE.txt +0 -0
  43. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/onnx_export_errors.py

@@ -107,7 +107,7 @@ def torch_export_patches(
  ) -> Callable:
  """
  Tries to bypass some situations :func:`torch.export.export` does not support.
- See also :ref:`l-patches-explained`.
+ See also :ref:`l-patches-explained` and :ref:`l-patch-coverage`.

  :param patch_sympy: fix missing method ``name`` for IntegerConstant
  :param patch_torch: patches :epkg:`torch` with supported implementation
@@ -140,6 +140,7 @@ def torch_export_patches(
  * ``torch.jit.isinstance``
  * ``torch._dynamo.mark_static_address``
  * ``torch._subclasses.fake_impls.infer_size``
+ * ``torch.vmap``
  * fix missing method ``name`` for ``sympy.S.IntegerConstant``
  * ``AttentionMaskConverter._make_causal_mask``
  * Serialization of ``MambaCache`` (in :epkg:`transformers`)
@@ -251,6 +252,7 @@ def torch_export_patches(
  if patch_torch:
  from .patches.patch_torch import (
  patched_infer_size,
+ patched_vmap,
  patched__broadcast_shapes,
  _catch_produce_guards_and_solve_constraints,
  patch__check_input_constraints_for_graph,
@@ -261,6 +263,10 @@ def torch_export_patches(
  print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
  print("[torch_export_patches] patch pytorch")

+ # torch.vmap
+ f_vmap = torch.vmap
+ torch.vmap = patched_vmap
+
  # torch.jit.isinstance
  f_jit_isinstance = torch.jit.isinstance
  torch.jit.isinstance = isinstance
@@ -328,6 +334,11 @@ def torch_export_patches(
  ####################

  if patch_transformers:
+ try:
+ import transformers.masking_utils as masking_utils
+ except ImportError:
+ masking_utils = None
+
  if verbose:
  import transformers

@@ -339,6 +350,15 @@ def torch_export_patches(
  patch_transformers_list, verbose=verbose
  )

+ if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
+ if verbose:
+ print(
+ "[torch_export_patches] patches "
+ "transformers.masking_utils._vmap_for_bhqkv"
+ )
+ f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
+ masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
+
  if custom_patches:
  if verbose:
  print("[torch_export_patches] applies custom patches")
@@ -381,6 +401,7 @@ def torch_export_patches(

  if patch_torch:
  # this should disappear when torch.jit is removed
+ torch.vmap = f_vmap
  torch.jit.isinstance = f_jit_isinstance
  torch._dynamo.mark_static_address = f_mark_static_address
  # tracked by https://github.com/pytorch/pytorch/issues/143495
@@ -430,12 +451,24 @@ def torch_export_patches(
  ##############

  if patch_transformers:
+ try:
+ import transformers.masking_utils as masking_utils
+ except ImportError:
+ masking_utils = None
  if verbose:
  print("[torch_export_patches] unpatch transformers")
  unpatch_module_or_classes(
  patch_transformers_list, revert_patches_info, verbose=verbose
  )

+ if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
+ if verbose:
+ print(
+ "[torch_export_patches] unpatch "
+ "transformers.masking_utils._vmap_for_bhqkv"
+ )
+ masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
+
  ########
  # caches
  ########
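
The hunks above patch, and later restore, ``torch.vmap`` and ``transformers.masking_utils._vmap_for_bhqkv`` inside the ``torch_export_patches`` context manager. A minimal usage sketch follows, assuming the public import path and a toy model (both illustrative, not taken from this diff):

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x * 2

    x = torch.randn(4, 8)
    # Inside the context, torch.vmap points to patched_vmap and the
    # transformers mask helper is swapped; originals are restored on exit.
    with torch_export_patches(patch_torch=True, patch_transformers=True, verbose=1):
        ep = torch.export.export(Toy(), (x,))
    print(ep)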
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

@@ -1,5 +1,5 @@
  import pprint
- from typing import Any, Callable, Dict, List, Optional, Set, Tuple
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
  import packaging.version as pv
  import optree
  import torch
@@ -133,6 +133,11 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
  # To avoid doing it multiple times.
  PATCH_OF_PATCHES.add(BaseModelOutput)

+ return serialization_functions(verbose=verbose)
+
+
+ def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]:
+ """Returns the list of serialization functions."""
  return dict(
  DynamicCache=register_class_serialization(
  DynamicCache,
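
This refactoring extracts the dictionary previously built inline in ``register_cache_serialization`` into a reusable ``serialization_functions``. A hedged sketch of calling it directly; the module path comes from the file list above and printing the returned statuses is illustrative only:

    from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
        serialization_functions,
    )

    # Maps cache class names (e.g. "DynamicCache") to the value returned by
    # register_class_serialization for that class.
    registered = serialization_functions(verbose=1)
    for name, status in registered.items():
        print(name, "->", status)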
onnx_diagnostic/torch_export_patches/patch_module_helper.py

@@ -1,5 +1,6 @@
  import ast
- from typing import Any, List, Optional
+ import functools
+ from typing import Any, Dict, List, Optional


  class OrToBitOrTransformer(ast.NodeTransformer):
@@ -19,10 +20,148 @@ def ast_or_into_bitor(node: "ast.Node") -> "ast.Node":
  return new_node


+ @functools.lru_cache
+ def _rewrite_forward_clamp_float16() -> Dict[str, List[type]]:
+
+ import transformers
+
+ _known = {
+ "AutoformerEncoderLayer": [
+ transformers.models.autoformer.modeling_autoformer.AutoformerEncoderLayer
+ ],
+ "BartEncoderLayer": [
+ transformers.models.bart.modeling_bart.BartEncoderLayer,
+ transformers.models.plbart.modeling_plbart.PLBartEncoderLayer,
+ ],
+ "BigBirdPegasusEncoderLayer": [
+ transformers.models.bigbird_pegasus.modeling_bigbird_pegasus.BigBirdPegasusEncoderLayer
+ ],
+ "BlenderbotSmallEncoderLayer": [
+ transformers.models.blenderbot_small.modeling_blenderbot_small.BlenderbotSmallEncoderLayer
+ ],
+ "InformerEncoderLayer": [
+ transformers.models.informer.modeling_informer.InformerEncoderLayer
+ ],
+ "LEDEncoderLayer": [transformers.models.led.modeling_led.LEDEncoderLayer],
+ "MarianEncoderLayer": [transformers.models.marian.modeling_marian.MarianEncoderLayer],
+ "MvpEncoderLayer": [transformers.models.mvp.modeling_mvp.MvpEncoderLayer],
+ "NllbMoeEncoderLayer": [
+ transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeEncoderLayer
+ ],
+ "TimeSeriesTransformerEncoderLayer": [
+ transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerEncoderLayer
+ ],
+ }
+ return _known
+
+
+ @functools.lru_cache
+ def known_transformers_rewritings_clamp_float16() -> Dict[str, str]:
+ """
+ This function returns the list of known classes to be rewritten
+ in :epkg:`transformers`. Each class is mapped to an alias,
+ this alias is then given to :func:`rewritings_transformers_clamp_float16`
+ to rewrite the encoder layers because of a specific control flow.
+
+ .. runpython::
+ :showcode:
+
+ import pprint
+ from onnx_diagnostic.torch_export_patches.patch_module_helper import (
+ known_transformers_rewritings_clamp_float16,
+ )
+
+ pprint.pprint(known_transformers_rewritings_clamp_float16())
+ """
+ _alias = {
+ "AutoformerEncoder": "AutoformerEncoderLayer",
+ "AutoformerEncoderLayer": "AutoformerEncoderLayer",
+ "AutoformerForPrediction": "AutoformerEncoderLayer",
+ "AutoformerModel": "AutoformerEncoderLayer",
+ "BartEncoderLayer": "BartEncoderLayer",
+ "BartForConditionalGeneration": "BartEncoderLayer",
+ "BigBirdPegasusForConditionalGeneration": "BigBirdPegasusEncoderLayer",
+ "BigBirdPegasusForQuestionAnswering": "BigBirdPegasusEncoderLayer",
+ "BigBirdPegasusForCausalLM": "BigBirdPegasusEncoderLayer",
+ "BlenderbotSmallEncoderLayer": "BlenderbotSmallEncoderLayer",
+ "BlenderbotSmallForConditionalGeneration": "BlenderbotSmallEncoderLayer",
+ "BlenderbotSmallForCausalLM": "BlenderbotSmallEncoderLayer",
+ "InformerEncoderLayer": "InformerEncoderLayer",
+ "InformerForPrediction": "InformerEncoderLayer",
+ "LEDEncoderLayer": "LEDEncoderLayer",
+ "LEDClassificationHead": "LEDEncoderLayer",
+ "LEDForConditionalGeneration": "LEDEncoderLayer",
+ "MarianEncoderLayer": "MarianEncoderLayer",
+ "MarianEncoder": "MarianEncoderLayer",
+ "MarianModel": "MarianEncoderLayer",
+ "MarianMTModel": "MarianEncoderLayer",
+ "MvpEncoderLayer": "MvpEncoderLayer",
+ "MvpPrompt": "MvpEncoderLayer",
+ "MvpForConditionalGeneration": "MvpEncoderLayer",
+ "MvpForSequenceClassification": "MvpEncoderLayer",
+ "MvpForQuestionAnswering": "MvpEncoderLayer",
+ "MvpForCausalLM": "MvpEncoderLayer",
+ "NllbMoeEncoderLayer": "NllbMoeEncoderLayer",
+ "NllbMoeForConditionalGeneration": "NllbMoeEncoderLayer",
+ "PLBartEncoderLayer": "BartEncoderLayer",
+ "PLBartForConditionalGeneration": "BartEncoderLayer",
+ "TimeSeriesTransformerEncoderLayer": "TimeSeriesTransformerEncoderLayer",
+ "TimeSeriesTransformerForPrediction": "TimeSeriesTransformerEncoderLayer",
+ }
+ return _alias
+
+
+ def rewritings_transformers_clamp_float16(cls_name) -> List[type]:
+ """
+ Rewrites known control flows equivalent to this one:
+
+ .. code-block:: python
+
+ if hidden_states.dtype == torch.float16 and (
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+ ):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ *cls_name* is the class name. It is mapped to a list of classes whose
+ forward method is rewritten. Here is the known list:
+
+ .. runpython::
+ :showcode:
+
+ import pprint
+ from onnx_diagnostic.torch_export_patches.patch_module_helper import (
+ _rewrite_forward_clamp_float16,
+ )
+
+ pprint.pprint(_rewrite_forward_clamp_float16())
+
+ Function `_rewrite_forward_clamp_float16` collects
+ all model classes using those layers.
+ """
+ _known = _rewrite_forward_clamp_float16()
+
+ assert cls_name in _known, f"cls_name={cls_name!r} unknown in {sorted(_known)}."
+
+ bd = dict(
+ filter_node=(
+ lambda node: isinstance(node, ast.If) and not isinstance(node.test, ast.Name)
+ ),
+ pre_rewriter=ast_or_into_bitor,
+ )
+
+ def _add(f):
+ g = bd.copy()
+ g["function"] = f
+ return g
+
+ return [_add(cls.forward) for cls in _known[cls_name]]
+
+
  def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
  """
- Returns a known list of methods or functions to rewrite because of control flow
- for a specific model class.
+ Returns a known list of classes mapped to known rewritings
+ because of control flow. See :func:`known_transformers_rewritings_clamp_float16`.

  :param cls_name: name of the class
  :return: a list of rewriting
@@ -30,34 +169,15 @@ def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
  .. runpython::
  :showcode:

+ import pprint
  from onnx_diagnostic.torch_export_patches.patch_module_helper import (
  code_needing_rewriting,
  )

- print(code_needing_rewriting("BartForConditionalGeneration"))
+ pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
  """
- if cls_name in {
- "BartEncoderLayer",
- "BartForConditionalGeneration",
- "PLBartEncoderLayer",
- "PLBartForConditionalGeneration",
- }:
- import transformers
-
- bd = dict(
- filter_node=(
- lambda node: isinstance(node, ast.If) and not isinstance(node.test, ast.Name)
- ),
- pre_rewriter=ast_or_into_bitor,
- )
-
- def _add(f):
- g = bd.copy()
- g["function"] = f
- return g
-
- return [
- _add(transformers.models.bart.modeling_bart.BartEncoderLayer.forward),
- _add(transformers.models.plbart.modeling_plbart.PLBartEncoderLayer.forward),
- ]
+ aliases = known_transformers_rewritings_clamp_float16()
+ if cls_name in aliases:
+ alias = aliases[cls_name]
+ return rewritings_transformers_clamp_float16(alias)
  return None
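
The lookup is now two-step: a model class name resolves to an encoder-layer alias, and that alias resolves to the forward methods whose float16 clamp control flow must be rewritten before export. A short sketch using only the functions added in this diff (values shown in comments are indicative):

    from onnx_diagnostic.torch_export_patches.patch_module_helper import (
        code_needing_rewriting,
        known_transformers_rewritings_clamp_float16,
    )

    # Step 1: model class -> encoder-layer alias.
    alias = known_transformers_rewritings_clamp_float16()["PLBartForConditionalGeneration"]
    print(alias)  # "BartEncoderLayer"

    # Step 2: alias -> rewriting specs, each a dict with a "function" entry.
    rewrites = code_needing_rewriting("PLBartForConditionalGeneration")
    print([r["function"] for r in rewrites])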
onnx_diagnostic/torch_export_patches/patches/patch_torch.py

@@ -370,3 +370,94 @@ class patched_ShapeEnv:
  # RuntimeWarning,
  # stacklevel=0,
  # )
+
+
+ def patched_vmap(func, in_dims=0, out_dims=0):
+ """
+ Python implementation of :func:`torch.vmap`.
+ The implementation raises an issue when it is being exported with
+ :func:`torch.export.export` when the function is called with
+ non-tensor arguments and the batch size is dynamic.
+ """
+ from ...helpers import string_type
+
+ def wrapped(*args):
+ assert all(not isinstance(a, dict) for a in args), (
+ f"dictionaries are not implemented in "
+ f"args={string_type(args, with_shape=True)}"
+ )
+
+ in_dims_ = (
+ ([in_dims] * len(args))
+ if not isinstance(in_dims, (list, tuple))
+ else list(in_dims)
+ )
+ assert len(in_dims_) == len(args), (
+ f"Mismatch between in_dims={in_dims_} and "
+ f"args={string_type(args, with_shape=True)}"
+ )
+
+ batch_size = None
+ batched_args = []
+ for arg, in_dim in zip(args, in_dims_):
+ if in_dim is None:
+ batched_args.append(arg)
+ continue
+
+ assert batch_size is None or batch_size == arg.size(in_dim), (
+ f"Unable to continue, batch_size={batch_size}, in_dim={in_dim}, "
+ f"arg.size(in_dim)={arg.size(in_dim)}"
+ )
+ if batch_size is None:
+ batch_size = arg.size(in_dim)
+ arg = arg.movedim(in_dim, 0)
+ batched_args.append(arg)
+
+ if all(isinstance(a, torch.Tensor) for a in args) and isinstance(
+ batch_size, torch.SymInt
+ ):
+ batched_tensors = [
+ (
+ arg
+ if (isinstance(arg, torch.Tensor) and in_dim is not None)
+ else arg.unsqueeze(0).expand((batch_size, *arg.shape))
+ )
+ for arg, in_dim in zip(batched_args, in_dims_)
+ ]
+ results = torch.ops.higher_order.scan(func, [], batched_tensors, [])
+ stacked = results[0]
+ if out_dims != 0:
+ return stacked.movedim(0, out_dims)
+ return stacked
+
+ else:
+ torch._check(
+ not isinstance(batch_size, torch.SymInt),
+ lambda: (
+ f"patched_vmap supports dynamic batch_size only if all arguments "
+ f"are tensors but types are {[type(a) for a in args]}"
+ ),
+ )
+ batched_tensors = [
+ (
+ (None, arg)
+ if (isinstance(arg, torch.Tensor) and in_dim is not None)
+ else (arg, arg)
+ )
+ for arg, in_dim in zip(batched_args, in_dims_)
+ ]
+
+ results = []
+ for i in range(batch_size):
+ input_slice = [v if v is not None else arg[i] for v, arg in batched_tensors]
+ result = func(*input_slice)
+ results.append(result)
+
+ if isinstance(results[0], torch.Tensor):
+ stacked = torch.stack(results)
+ if out_dims != 0:
+ return stacked.movedim(0, out_dims)
+ return stacked
+ return results
+
+ return wrapped
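
Outside of export, ``patched_vmap`` should agree with ``torch.vmap`` on simple per-row functions (it takes the explicit Python loop branch above). A hedged sanity-check sketch; the toy ``row_sum`` function is illustrative only:

    import torch
    from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap

    def row_sum(row):
        # Reduces one row of the batched input.
        return row.sum()

    x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    expected = torch.vmap(row_sum)(x)   # tensor([ 6., 22., 38.])
    patched = patched_vmap(row_sum)(x)  # same values, computed with a plain loop
    print(torch.allclose(expected, patched))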
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -1,6 +1,7 @@
  import inspect
  from dataclasses import dataclass
- from typing import Any, Dict, List, Optional, Tuple
+ from functools import wraps
+ from typing import Any, Callable, Dict, List, Optional, Tuple
  import torch
  import transformers
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -9,6 +10,34 @@ from ...ext_test_case import has_transformers
  from ...helpers.torch_helper import is_torchdynamo_exporting


+ def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
+ """Patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
+ from ...helpers import string_type
+
+ dimensions: List[Tuple[Optional[int], ...]] = [
+ (None, None, None, 0),
+ (None, None, 0, None),
+ ]
+ if bh_indices:
+ dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+ dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
+ dimensions = tuple(reversed(dimensions))
+ indices = tuple(shape.index(-1) for shape in dimensions)
+
+ def vector_mask_function(
+ *args, mask_function=mask_function, dimensions=dimensions, indices=indices
+ ):
+ assert len(args) == len(
+ dimensions
+ ), f"Mismatch between args={string_type(args)} and dimensions={dimensions}"
+ new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
+ max_shape = tuple(args[i].shape[0] for i in indices)
+ expanded_args = [a.expand(max_shape) for a in new_args]
+ return mask_function(*expanded_args)
+
+ return vector_mask_function
+
+
  def _patch_make_causal_mask(
  input_ids_shape: torch.Size,
  dtype: torch.dtype,
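
The patch replaces the nested ``torch.vmap`` calls used by ``transformers.masking_utils._vmap_for_bhqkv`` with reshape-and-expand broadcasting, which exports more easily. A hedged sketch with a toy causal mask function (not the one :epkg:`transformers` uses); it assumes :epkg:`transformers` is installed so the patch module imports cleanly:

    import torch
    from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
        patched__vmap_for_bhqkv,
    )

    def causal_fn(batch_idx, head_idx, q_idx, kv_idx):
        # Toy mask: a query position may attend to itself and earlier positions.
        return kv_idx <= q_idx

    batch, heads = torch.arange(2), torch.arange(4)
    q_idx, kv_idx = torch.arange(3), torch.arange(5)
    mask = patched__vmap_for_bhqkv(causal_fn)(batch, heads, q_idx, kv_idx)
    print(mask.shape)  # torch.Size([2, 4, 3, 5])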
@@ -503,3 +532,90 @@ class patched_GenerationMixin:
  # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
  model_inputs.pop("labels", None)
  return model_inputs
+
+
+ def patched_dynamic_rope_update(rope_forward):
+ """
+ Patch for ``transformers.modeling_rope_utils.dynamic_rope_update``.
+ """
+
+ def longrope_frequency_update(self, position_ids, device):
+ seq_len = torch.max(position_ids) + 1
+ if hasattr(self.config, "original_max_position_embeddings"):
+ original_max_position_embeddings = self.config.original_max_position_embeddings
+ else:
+ original_max_position_embeddings = self.config.max_position_embeddings
+ # At export time, seq_len is unknown.
+ long_inv_freq, _ = self.rope_init_fn(
+ self.config, device, seq_len=original_max_position_embeddings + 1
+ )
+ original_inv_freq = self.original_inv_freq.to(device)
+
+ cond = (seq_len > original_max_position_embeddings).item()
+ inv_freq = torch.cond(
+ cond,
+ (lambda x, y: x.clone()),
+ (lambda x, y: y.clone()),
+ [long_inv_freq, original_inv_freq],
+ )
+ self.inv_freq = inv_freq
+ # if seq_len > original_max_position_embeddings:
+ # self.inv_freq = self.long_inv_freq
+ # else:
+ # self.inv_freq = self.original_inv_freq
+
+ def dynamic_frequency_update(self, position_ids, device):
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_seq_len_cached: # growth
+ inv_freq, self.attention_scaling = self.rope_init_fn(
+ self.config, device, seq_len=seq_len
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.max_seq_len_cached = seq_len
+
+ if (
+ seq_len < self.original_max_seq_len
+ and self.max_seq_len_cached > self.original_max_seq_len
+ ):
+ self.original_inv_freq = self.original_inv_freq.to(device)
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ self.max_seq_len_cached = self.original_max_seq_len
+
+ @wraps(rope_forward)
+ def wrapper(self, x, position_ids):
+ if "dynamic" in self.rope_type:
+ dynamic_frequency_update(self, position_ids, device=x.device)
+ elif self.rope_type == "longrope":
+ longrope_frequency_update(self, position_ids, device=x.device)
+ return rope_forward(self, x, position_ids)
+
+ return wrapper
+
+
+ class patched_Phi3RotaryEmbedding(torch.nn.Module):
+ _PATCHES_ = ["forward"]
+ _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
+
+ @torch.no_grad()
+ @patched_dynamic_rope_update
+ def forward(self, x, position_ids):
+ inv_freq_expanded = (
+ self.inv_freq[None, :, None]
+ .float()
+ .expand(position_ids.shape[0], -1, 1)
+ .to(x.device)
+ )
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = (
+ x.device.type
+ if isinstance(x.device.type, str) and x.device.type != "mps"
+ else "cpu"
+ )
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
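
The key export trick in ``longrope_frequency_update`` is replacing a data-dependent ``if seq_len > threshold`` branch, which :func:`torch.export.export` cannot trace, with :func:`torch.cond` so both branches stay in the graph. A hedged standalone sketch of that pattern (the tensors and threshold below are made up):

    import torch

    long_inv_freq = torch.rand(8)
    original_inv_freq = torch.rand(8)

    def pick_inv_freq(position_ids, threshold=4096):
        seq_len = torch.max(position_ids) + 1
        # Both branches are kept; the predicate selects one at runtime.
        return torch.cond(
            seq_len > threshold,
            lambda x, y: x.clone(),
            lambda x, y: y.clone(),
            [long_inv_freq, original_inv_freq],
        )

    print(pick_inv_freq(torch.arange(10)).shape)  # torch.Size([8])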
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

@@ -3951,3 +3951,145 @@ def _ccached_facebook_bart_large_cnn():
  "vocab_size": 50264,
  }
  )
+
+
+ def _ccached_microsoft_phi4_reasoning():
+ "microsoft/Phi-4-mini-reasoning"
+ return transformers.Phi3Config(
+ **{
+ "architectures": ["Phi3ForCausalLM"],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 199999,
+ "embd_pdrop": 0.0,
+ "eos_token_id": 199999,
+ "full_attn_mod": 1,
+ "hidden_act": "silu",
+ "hidden_size": 3072,
+ "initializer_range": 0.02,
+ "intermediate_size": 8192,
+ "interpolate_factor": 1,
+ "lm_head_bias": false,
+ "max_position_embeddings": 131072,
+ "mlp_bias": false,
+ "model_type": "phi3",
+ "num_attention_heads": 24,
+ "num_hidden_layers": 32,
+ "num_key_value_heads": 8,
+ "original_max_position_embeddings": 4096,
+ "pad_token_id": 199999,
+ "partial_rotary_factor": 0.75,
+ "resid_pdrop": 0.0,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": {
+ "long_factor": [
+ 1,
+ 1.118320672,
+ 1.250641126,
+ 1.398617824,
+ 1.564103225,
+ 1.74916897,
+ 1.956131817,
+ 2.187582649,
+ 2.446418898,
+ 2.735880826,
+ 3.059592084,
+ 3.421605075,
+ 3.826451687,
+ 4.279200023,
+ 4.785517845,
+ 5.351743533,
+ 5.984965424,
+ 6.693110555,
+ 7.485043894,
+ 8.370679318,
+ 9.36110372,
+ 10.4687158,
+ 11.70738129,
+ 13.09260651,
+ 14.64173252,
+ 16.37415215,
+ 18.31155283,
+ 20.47818807,
+ 22.90118105,
+ 25.61086418,
+ 28.64115884,
+ 32.03,
+ 32.1,
+ 32.13,
+ 32.23,
+ 32.6,
+ 32.61,
+ 32.64,
+ 32.66,
+ 32.7,
+ 32.71,
+ 32.93,
+ 32.97,
+ 33.28,
+ 33.49,
+ 33.5,
+ 44.16,
+ 47.77,
+ ],
+ "short_factor": [
+ 1,
+ 1.118320672,
+ 1.250641126,
+ 1.398617824,
+ 1.564103225,
+ 1.74916897,
+ 1.956131817,
+ 2.187582649,
+ 2.446418898,
+ 2.735880826,
+ 3.059592084,
+ 3.421605075,
+ 3.826451687,
+ 4.279200023,
+ 4.785517845,
+ 5.351743533,
+ 5.984965424,
+ 6.693110555,
+ 7.485043894,
+ 8.370679318,
+ 9.36110372,
+ 10.4687158,
+ 11.70738129,
+ 13.09260651,
+ 14.64173252,
+ 16.37415215,
+ 18.31155283,
+ 20.47818807,
+ 22.90118105,
+ 25.61086418,
+ 28.64115884,
+ 32.03,
+ 32.1,
+ 32.13,
+ 32.23,
+ 32.6,
+ 32.61,
+ 32.64,
+ 32.66,
+ 32.7,
+ 32.71,
+ 32.93,
+ 32.97,
+ 33.28,
+ 33.49,
+ 33.5,
+ 44.16,
+ 47.77,
+ ],
+ "type": "longrope",
+ },
+ "rope_theta": 10000.0,
+ "sliding_window": 262144,
+ "tie_word_embeddings": true,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.50.0",
+ "use_cache": true,
+ "vocab_size": 200064,
+ }
+ )