onnx-diagnostic 0.8.0-py3-none-any.whl → 0.8.1-py3-none-any.whl

This diff compares the two published versions of the package as they appear in the public registry.
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
  Functions, classes to dig into a model when this one is right, slow, wrong...
  """

- __version__ = "0.8.0"
+ __version__ = "0.8.1"
  __author__ = "Xavier Dupré"
@@ -265,7 +265,7 @@ def get_parser_config() -> ArgumentParser:
  "--mop",
  metavar="KEY=VALUE",
  nargs="*",
- help="Additional model options, use to change some parameters of the model, "
+ help="Additional model options, used to change some parameters of the model, "
  "example:\n --mop attn_implementation=sdpa or --mop attn_implementation=eager",
  action=_ParseDict,
  )
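
This option and several below rely on `action=_ParseDict` to fold repeated KEY=VALUE pairs into a dictionary. The action's implementation is not part of this diff; a minimal sketch of such an argparse action, assuming values are decoded with `ast.literal_eval` when possible, could look like:

    import argparse
    import ast

    class ParseDictSketch(argparse.Action):
        # Hypothetical stand-in for _ParseDict: folds KEY=VALUE pairs into a dict.
        def __call__(self, parser, namespace, values, option_string=None):
            d = getattr(namespace, self.dest, None) or {}
            for item in values:
                key, _, value = item.partition("=")
                try:
                    # literal_eval handles ints, floats, booleans and dict literals
                    # such as "rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}"
                    d[key] = ast.literal_eval(value)
                except (ValueError, SyntaxError):
                    d[key] = value  # fall back to the raw string, e.g. "sdpa"
            setattr(namespace, self.dest, d)

    parser = argparse.ArgumentParser()
    parser.add_argument("--mop", metavar="KEY=VALUE", nargs="*", action=ParseDictSketch)
    print(parser.parse_args(["--mop", "attn_implementation=sdpa"]).mop)
    # {'attn_implementation': 'sdpa'}

Decoding values with literal_eval is what would make an example such as the ``--mop "rope_scaling=..."`` line later in this diff usable as an actual dictionary.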
@@ -442,11 +442,17 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
  default=True,
  action=_BoolOrParseDictPatch,
  nargs="*",
- help="Applies patches before exporting, it can be a boolean "
- "to enable to disable the patches or be more finetuned. It is possible to "
- "disable patch for torch by adding "
- '--patch "patch_sympy=False" --patch "patch_torch=False", '
- "default is True.",
+ help=textwrap.dedent(
+ """
+ Applies patches before exporting. It can be a boolean
+ to enable or disable the patches, or be more fine-tuned
+ (default is True). It is possible to disable the patches for torch
+ by adding:
+ --patch "patch_sympy=False" --patch "patch_torch=False"
+ """.strip(
+ "\n"
+ )
+ ),
  )
  parser.add_argument(
  "--rewrite",
@@ -476,10 +482,16 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
  "--inputs2",
  default=1,
  type=int,
- help="Validates or exports the model on a second set of inputs\n"
- "to check the exported model supports dynamism. The values is used "
- "as an increment to the first set of inputs. A high value may trick "
- "a different behavior in the model and missed by the exporter.",
+ help=textwrap.dedent(
+ """
+ Validates or exports the model on a second set of inputs
+ to check the exported model supports dynamism. The value is used
+ as an increment to the first set of inputs. A high value may trigger
+ a different behavior in the model, missed by the exporter.
+ """.strip(
+ "\n"
+ )
+ ),
  )
  parser.add_argument(
  "--runtime",
@@ -512,9 +524,15 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
  parser.add_argument(
  "--ortfusiontype",
  required=False,
- help="Applies onnxruntime fusion, this parameter should contain the\n"
- "model type or multiple values separated by `|`. `ALL` can be used\n"
- "to run them all.",
+ help=textwrap.dedent(
+ """
+ Applies onnxruntime fusions; this parameter should contain the
+ model type or multiple values separated by `|`. `ALL` can be used
+ to run them all.
+ """.strip(
+ "\n"
+ )
+ ),
  )
  parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
  parser.add_argument("--dtype", help="Changes dtype if necessary.")
@@ -523,18 +541,32 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
  "--iop",
  metavar="KEY=VALUE",
  nargs="*",
- help="Additional input options, use to change the default"
- "inputs use to export, example:\n --iop cls_cache=SlidingWindowCache"
- "\n --iop cls_cache=StaticCache",
+ help=textwrap.dedent(
+ """
+ Additional input options, used to change the default
+ inputs used to export. Examples:
+ --iop cls_cache=SlidingWindowCache
+ --iop cls_cache=StaticCache
+ """.strip(
+ "\n"
+ )
+ ),
  action=_ParseDict,
  )
  parser.add_argument(
  "--mop",
  metavar="KEY=VALUE",
  nargs="*",
- help="Additional model options, use to change some parameters of the model, "
- "example:\n --mop attn_implementation=sdpa --mop attn_implementation=eager\n "
- "--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"",
+ help=textwrap.dedent(
+ """
+ Additional model options, used to change some parameters
+ of the model. Examples:
+ --mop attn_implementation=sdpa --mop attn_implementation=eager
+ --mop "rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}"
+ """.strip(
+ "\n"
+ )
+ ),
  action=_ParseDict,
  )
  if name == "validate":
@@ -566,9 +598,32 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
  parser.add_argument(
  "--quiet-input-sets",
  default="",
- help="Avoids raising an exception when an input sets does not work with "
- "the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
+ help=textwrap.dedent(
+ """
+ Avoids raising an exception when an input set does not work with
+ the exported model. Example:
+ --quiet-input-sets=inputs,inputs22
+ """.strip(
+ "\n"
+ )
+ ),
  )
+ parser.add_argument(
+ "--expop",
+ metavar="KEY=VALUE",
+ nargs="*",
+ help=textwrap.dedent(
+ """
+ Additional exporter options, used to change some parameters
+ of the exporter. Examples:
+ --expop report=True
+ --expop report=True --expop verify=True
+ """.strip(
+ "\n"
+ )
+ ),
+ action=_ParseDict,
+ )
  return parser


@@ -634,6 +689,7 @@ def _cmd_validate(argv: List[Any]):
  output_names=(
  None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
  ),
+ exporter_options=args.expop,
  )
  print("")
  print("-- summary --")
@@ -940,7 +996,7 @@ def get_parser_agg() -> ArgumentParser:
  "n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
  "n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
  "n_node_gqa,n_node_layer_normalization,n_node_layer_normalization23,"
- "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
+ "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,n_node_random,"
  "n_node_constant,n_node_shape,n_node_expand,"
  "n_node_function,n_node_initializer,n_node_scatter,"
  "time_export_unbiased,onnx_n_nodes_no_cst,n_node_initializer_small",
@@ -1016,6 +1016,8 @@ def max_diff(

  You may use :func:`string_diff` to display the discrepancies in one string.
  """
+ if verbose >= 10:
+ print(f"[max_diff] {type(expected)} ? {type(got)}")
  if expected is None and got is None:
  return dict(abs=0, rel=0, sum=0, n=0, dnan=0)

@@ -1061,8 +1063,8 @@ def max_diff(
  if expected.__class__.__name__ == "CausalLMOutputWithPast":
  if verbose >= 6:
  print(
- f"[max_diff] CausalLMOutputWithPast: {string_type(expected)} "
- f"? {string_type(got)}"
+ f"[max_diff] CausalLMOutputWithPast: {string_type(expected, with_shape=True)} "
+ f"? {string_type(got, with_shape=True)}"
  )
  if got.__class__.__name__ == "CausalLMOutputWithPast":
  return max_diff(
@@ -1169,7 +1169,8 @@ class CubeLogs:
  assuming they should remain stale
  :param sbs: configurations to compare side-by-side, this adds two tabs,
  one gathering raw data about the two configurations, the other one
- is aggregated by metrics
+ is aggregated by metrics, for example:
+ ``sbs=dict(CFA=dict(exporter="E1", opt="O"), CFB=dict(exporter="E2", opt="O"))``
  """
  if verbose:
  print(f"[CubeLogs.to_excel] create Excel file {output}, shape={self.shape}")
@@ -1611,6 +1612,7 @@ class CubeLogsPerformance(CubeLogs):
  "n_node_initializer_small",
  "n_node_layer_normalization",
  "n_node_layer_normalization23",
+ "n_node_random",
  "n_node_reshape",
  "n_node_rotary_embedding",
  "n_node_rotary_embedding23",
@@ -1802,6 +1804,16 @@ class CubeLogsPerformance(CubeLogs):
  + gdf(df, "op_onnx__InstanceNormlization", 0)
  + gdf(df, "op_onnx__GroupNormalization", 0),
  ),
+ n_node_random=lambda df: gpreserve(
+ df,
+ "time_latency_eager",
+ gdf(df, "op_onnx__RandomNormal", 0)
+ + gdf(df, "op_onnx__RandomNormalLike", 0)
+ + gdf(df, "op_onnx__RandomUniform", 0)
+ + gdf(df, "op_onnx__RandomUniformLike", 0)
+ + gdf(df, "op_onnx__Multinomial", 0)
+ + gdf(df, "op_onnx__Bernoulli", 0),
+ ),
  n_node_attention=lambda df: gpreserve(
  df,
  "time_latency_eager",
@@ -47,6 +47,8 @@ class Monitor:

  @property
  def delta_avg(self):
+ if self.n_measures == 0:
+ return 0
  return self.average / self.n_measures - self.begin

  def __repr__(self):
@@ -52,7 +52,7 @@ def proto_from_array(

  tensor = TensorProto()
  tensor.dims.extend(arr_cpu.shape)
- tensor.name = name
+ tensor.name = name or ""
  itype = dtype_to_tensor_dtype(arr_cpu.dtype)
  assert not hasattr(TensorProto, "INT4") or itype not in {
  TensorProto.INT4,
@@ -331,7 +331,7 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str:
  print(onnx_dtype_name(7))
  """
  for k in dir(TensorProto):
- if "FLOAT" in k or "INT" in k or "TEXT" in k or "BOOL" in k:
+ if k.upper() == k and k != "EXTERNAL":
  v = getattr(TensorProto, k)
  if v == itype:
  return k
@@ -10,13 +10,9 @@ from .ort_session import InferenceSessionForTorch


  def name_type_to_onnx_dtype(name: str) -> int:
- if name == "tensor(int64)":
- return onnx.TensorProto.INT64
- if name == "tensor(float)":
- return onnx.TensorProto.FLOAT
- if name == "tensor(float16)":
- return onnx.TensorProto.FLOAT16
- raise AssertionError(f"Unexpected value {name!r}")
+ assert name.startswith("tensor(") and name.endswith(")"), f"Invalid value name={name!r}"
+ look = name[7:-1]
+ return getattr(onnx.TensorProto, look.upper())


  def make_feeds(
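
The rewritten helper now accepts any onnxruntime type string whose inner name, once upper-cased, matches a ``TensorProto`` attribute, instead of the three hard-coded cases. A quick check of the behavior (values verified against the ``onnx`` enum):

    import onnx

    def name_type_to_onnx_dtype(name: str) -> int:
        # mirrors the new implementation shown in the diff above
        assert name.startswith("tensor(") and name.endswith(")"), f"Invalid value name={name!r}"
        look = name[7:-1]
        return getattr(onnx.TensorProto, look.upper())

    assert name_type_to_onnx_dtype("tensor(int64)") == onnx.TensorProto.INT64      # 7
    assert name_type_to_onnx_dtype("tensor(float16)") == onnx.TensorProto.FLOAT16  # 10
    assert name_type_to_onnx_dtype("tensor(double)") == onnx.TensorProto.DOUBLE    # 11

Note that an unknown name now raises ``AttributeError`` from ``getattr`` rather than the previous ``AssertionError``.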
@@ -153,7 +149,7 @@ def make_empty_cache(
  def generate_and_validate(
  model,
  input_ids: torch.Tensor,
- eos_token_id: int,
+ eos_token_id: int = 2,
  max_new_tokens: int = 100,
  session: Optional[Union[InferenceSessionForTorch, onnx.ModelProto, str]] = None,
  atol: float = 0.1,
@@ -262,10 +258,10 @@ def generate_and_validate(
  def onnx_generate(
  model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
  input_ids: torch.Tensor,
- eos_token_id: int,
+ eos_token_id: int = 2,
  max_new_tokens=100,
  return_session: bool = False,
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch]]:
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch, Dict[str, Any]]]:
  """
  Implements a simple method ``generate`` for an ONNX model.
  The function does not expect any ``position_ids`` as input.
@@ -277,7 +273,7 @@ def onnx_generate(
  :param return_session: returns the instance of class
  :class:`InferenceSessionForTorch
  <onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
- created if necessary
+ created if necessary; the function also returns the feeds for the next iteration
  :return: input tokens concatenated with new tokens

  .. runpython::
@@ -353,12 +349,19 @@ def onnx_generate(
  input_shapes = session.input_shapes
  input_names = session.input_names
  input_types = session.input_types
+ has_position_ids = "position_ids" in session.input_names

  assert (
  len(input_names) > 2
  and input_names[:2] == ["input_ids", "attention_mask"]
- and input_names[2].startswith("past_key_values")
- ), f"Only text generation is supported but input_names == {input_names}"
+ and input_names[3 if has_position_ids else 2].startswith("past_key_values")
+ ), (
+ f"Only text generation is supported but input_names == {input_names}, "
+ f"has_position_ids={has_position_ids}"
+ )
+ assert (
+ not has_position_ids or input_names[2] == "position_ids"
+ ), f"position_ids must be the third input but input_names={input_names}"

  # First call: prefill
  feeds = dict(
@@ -370,6 +373,10 @@ def onnx_generate(
  input_ids.shape[0], input_names[2:], input_shapes[2:], input_types[2:]
  ),
  )
+ if has_position_ids:
+ feeds["position_ids"] = torch.unsqueeze(
+ torch.arange(input_ids.shape[1], dtype=torch.int64, device=input_ids.device), 0
+ )

  outputs = session.run(None, feeds)

@@ -393,11 +400,21 @@ def onnx_generate(
  input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
  ),
  )
- feeds.update(dict(zip(input_names[2:], outputs[1:])))
+ if has_position_ids:
+ feeds["position_ids"] = torch.unsqueeze(
+ torch.arange(
+ input_ids.shape[1],
+ input_ids.shape[1] + 1,
+ dtype=torch.int64,
+ device=input_ids.device,
+ ),
+ 0,
+ )
+ feeds.update(dict(zip(input_names[3 if has_position_ids else 2 :], outputs[1:])))
  outputs = session.run(None, feeds)

  if return_session:
- return input_ids, session
+ return input_ids, session, feeds
  return input_ids

@@ -151,6 +151,7 @@ def get_inputs(
  assert (
  add_second_input > 0
  ), f"Not implemented for add_second_input={add_second_input}."
+ res["inputs_prompt"] = dict(input_ids=torch.randint(1000, 30000, (1, 11)))
  res["inputs2"] = get_inputs(
  model=model,
  config=config,
@@ -56,6 +56,74 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
  return kwargs


+ def _get_input_falcon_mamba(
+ model: torch.nn.Module,
+ config: Optional[Any],
+ dummy_max_token_id: int,
+ num_hidden_layers: int,
+ batch_size: int = 2,
+ sequence_length: int = 30,
+ sequence_length2: int = 3,
+ dynamic_rope: bool = False,
+ num_key_value_heads: Optional[int] = None,
+ head_dim: Optional[int] = None,
+ cls_cache: Optional[Union[type, str]] = None,
+ **kwargs, # unused
+ ):
+ try:
+ from transformers.models.mamba.modeling_mamba import MambaCache
+ except ImportError:
+ from transformers.cache_utils import MambaCache
+
+ assert cls_cache in (
+ "MambaCache",
+ MambaCache,
+ ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
+
+ batch = "batch"
+ seq_length_multiple = 8
+ sequence_length = (
+ (sequence_length + seq_length_multiple) // seq_length_multiple * seq_length_multiple
+ )
+ # sequence_inc = seq_length_multiple
+ sequence_length2 = seq_length_multiple
+
+ shapes = {
+ "input_ids": {0: batch, 1: "sequence_length"},
+ "attention_mask": {
+ 0: batch,
+ 1: "cache+seq", # cache_length + seq_length
+ },
+ "cache_position": {
+ 0: batch,
+ 1: "cache+seq", # cache_length + seq_length
+ },
+ "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)],
+ }
+ inputs = dict(
+ input_ids=torch.randint(
+ 0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
+ ).to(torch.int64),
+ attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+ torch.int64
+ ),
+ cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
+ # .expand((batch_size, -1))
+ cache_params=make_mamba_cache(
+ [
+ (
+ torch.randn(
+ batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
+ ),
+ torch.randn(batch_size, kwargs["intermediate_size"], kwargs["state_size"]),
+ )
+ for i in range(num_hidden_layers)
+ ]
+ ),
+ )
+ return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
  def get_inputs(
  model: torch.nn.Module,
  config: Optional[Any],
@@ -68,7 +136,7 @@ def get_inputs(
  num_key_value_heads: Optional[int] = None,
  head_dim: Optional[int] = None,
  cls_cache: Optional[Union[type, str]] = None,
- add_second_input: int = 1,
+ add_second_input: Optional[int] = None,
  **kwargs, # unused
  ):
  """
@@ -84,6 +152,7 @@ def get_inputs(
  :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
  :param cls_cache: cache class, by default it is
  :class:`transformers.cache_utils.DynamicCache`
+ :param add_second_input: if set, adds other kinds of inputs (``inputs2``, ``inputs_prompt``)
  :return: dictionary
  """
  batch = "batch"
@@ -91,60 +160,20 @@ def get_inputs(
  cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)

  if config is not None and config.__class__.__name__ == "FalconMambaConfig":
- try:
- from transformers.models.mamba.modeling_mamba import MambaCache
- except ImportError:
- from transformers.cache_utils import MambaCache
-
- assert cls_cache in (
- "MambaCache",
- MambaCache,
- ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
- seq_length_multiple = 8
- sequence_length = (
- (sequence_length + seq_length_multiple)
- // seq_length_multiple
- * seq_length_multiple
- )
- # sequence_inc = seq_length_multiple
- sequence_length2 = seq_length_multiple
-
- shapes = {
- "input_ids": {0: batch, 1: "sequence_length"},
- "attention_mask": {
- 0: batch,
- 1: "cache+seq", # cache_length + seq_length
- },
- "cache_position": {
- 0: batch,
- 1: "cache+seq", # cache_length + seq_length
- },
- "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)],
- }
- inputs = dict(
- input_ids=torch.randint(
- 0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
- ).to(torch.int64),
- attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
- torch.int64
- ),
- cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
- # .expand((batch_size, -1))
- cache_params=make_mamba_cache(
- [
- (
- torch.randn(
- batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
- ),
- torch.randn(
- batch_size, kwargs["intermediate_size"], kwargs["state_size"]
- ),
- )
- for i in range(num_hidden_layers)
- ]
- ),
+ res = _get_input_falcon_mamba(
+ model=model,
+ config=config,
+ dummy_max_token_id=dummy_max_token_id,
+ num_hidden_layers=num_hidden_layers,
+ batch_size=batch_size,
+ sequence_length=sequence_length,
+ sequence_length2=sequence_length2,
+ dynamic_rope=dynamic_rope,
+ num_key_value_heads=num_key_value_heads,
+ head_dim=head_dim,
+ cls_cache=cls_cache,
+ **kwargs, # unused
  )
- res = dict(inputs=inputs, dynamic_shapes=shapes)
  else:
  if head_dim is None:
  assert config, "head_dim is None, the value cannot be set without a configuration"
@@ -244,6 +273,7 @@ def get_inputs(
  )
  res = dict(inputs=inputs, dynamic_shapes=shapes)
  if add_second_input:
+ res["inputs_prompt"] = dict(input_ids=torch.randint(1000, 30000, (1, 11)))
  res["inputs2"] = get_inputs(
  model=model,
  config=config,
@@ -195,9 +195,12 @@ class patched_ShapeEnv:
  if self.frozen:
  self.counter["ignored_backward_guard"] += 1
  # PATCHED: raised an exception instead of logging.
+ import transformers
+
  raise AssertionError(
  f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
- f"this could result in accuracy problems"
+ f"this could result in accuracy problems, transformers.__version__="
+ f"{transformers.__version__!r}"
  )

  def _set_replacement(
@@ -1452,7 +1452,7 @@ def patched_sdpa_attention_forward(
  scale=scaling,
  is_causal=True,
  **sdpa_kwargs,
- ),
+ ).contiguous(),
  lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
  query,
  key,
@@ -1461,7 +1461,7 @@ def patched_sdpa_attention_forward(
  scale=scaling,
  is_causal=False,
  **sdpa_kwargs,
- ),
+ ).contiguous(),
  [query, key, value],
  )
  attn_output = attn_output.transpose(1, 2).contiguous()
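
The two lambdas receiving ``query, key, value`` are the branches of what looks like a ``torch.cond`` call (the call itself is outside this diff). ``torch.cond`` requires both branches to return tensors with matching metadata, including strides, so forcing ``.contiguous()`` on each branch avoids layout mismatches between the causal and non-causal paths. A minimal sketch of the pattern:

    import torch

    def sdpa_branch(query, key, value, is_causal):
        # each branch returns a contiguous tensor so that both branches
        # of torch.cond agree on memory layout
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, is_causal=is_causal
        ).contiguous()

    q = k = v = torch.randn(1, 2, 4, 8)
    out = torch.cond(
        torch.tensor(True),                           # predicate
        lambda q, k, v: sdpa_branch(q, k, v, True),   # causal branch
        lambda q, k, v: sdpa_branch(q, k, v, False),  # non-causal branch
        [q, k, v],
    )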