onnx-diagnostic 0.8.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +78 -22
- onnx_diagnostic/helpers/helper.py +4 -2
- onnx_diagnostic/helpers/log_helper.py +13 -1
- onnx_diagnostic/helpers/memory_peak.py +2 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
- onnx_diagnostic/helpers/onnx_helper.py +1 -1
- onnx_diagnostic/helpers/rt_helper.py +32 -15
- onnx_diagnostic/tasks/text2text_generation.py +1 -0
- onnx_diagnostic/tasks/text_generation.py +84 -54
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -1
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2 -2
- onnx_diagnostic/torch_models/validate.py +620 -213
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +18 -18
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED

onnx_diagnostic/_command_lines_parser.py
CHANGED
@@ -265,7 +265,7 @@ def get_parser_config() -> ArgumentParser:
         "--mop",
         metavar="KEY=VALUE",
         nargs="*",
-        help="Additional model options,
+        help="Additional model options, used to change some parameters of the model, "
         "example:\n --mop attn_implementation=sdpa or --mop attn_implementation=eager",
         action=_ParseDict,
     )
@@ -442,11 +442,17 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
         default=True,
         action=_BoolOrParseDictPatch,
         nargs="*",
-        help=
+        help=textwrap.dedent(
+            """
+            Applies patches before exporting, it can be a boolean
+            to enable to disable the patches or be more finetuned
+            (default is True). It is possible to disable patch for torch
+            by adding:
+            --patch "patch_sympy=False" --patch "patch_torch=False"
+            """.strip(
+                "\n"
+            )
+        ),
     )
     parser.add_argument(
         "--rewrite",
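All of the reworked help strings in this release follow the same `textwrap.dedent("""...""".strip("\n"))` pattern. A minimal standalone sketch (not the package's code) of what that pattern evaluates to:

    import textwrap

    help_text = textwrap.dedent(
        """
        Applies patches before exporting, it can be a boolean
        to enable to disable the patches or be more finetuned
        (default is True).
        """.strip(
            "\n"
        )
    )
    # strip("\n") removes the newlines around the triple-quoted block,
    # textwrap.dedent then drops the common leading indentation.
    print(help_text)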
@@ -476,10 +482,16 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
         "--inputs2",
         default=1,
         type=int,
-        help=
+        help=textwrap.dedent(
+            """
+            Validates or exports the model on a second set of inputs
+            to check the exported model supports dynamism. The values is used
+            as an increment to the first set of inputs. A high value may trick
+            a different behavior in the model and missed by the exporter.
+            """.strip(
+                "\n"
+            )
+        ),
     )
     parser.add_argument(
         "--runtime",
@@ -512,9 +524,15 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
     parser.add_argument(
         "--ortfusiontype",
         required=False,
-        help=
+        help=textwrap.dedent(
+            """
+            Applies onnxruntime fusion, this parameter should contain the
+            model type or multiple values separated by `|`. `ALL` can be used
+            to run them all.
+            """.strip(
+                "\n"
+            )
+        ),
     )
     parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
     parser.add_argument("--dtype", help="Changes dtype if necessary.")
@@ -523,18 +541,32 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
         "--iop",
         metavar="KEY=VALUE",
         nargs="*",
-        help=
+        help=textwrap.dedent(
+            """
+            Additional input options, used to change the default
+            inputs use to export. Examples:
+            --iop cls_cache=SlidingWindowCache
+            --iop cls_cache=StaticCache
+            """.strip(
+                "\n"
+            )
+        ),
         action=_ParseDict,
     )
     parser.add_argument(
         "--mop",
         metavar="KEY=VALUE",
         nargs="*",
-        help=
+        help=textwrap.dedent(
+            """
+            Additional model options, used to change some parameters
+            of the model. Example:
+            --mop attn_implementation=sdpa --mop attn_implementation=eager"
+            --mop "rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}"
+            """.strip(
+                "\n"
+            )
+        ),
         action=_ParseDict,
     )
     if name == "validate":
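The `--iop`, `--mop` and the new `--expop` flags all collect KEY=VALUE pairs through the `_ParseDict` action. The package's implementation is not shown in this diff; the following is a hypothetical sketch of how such an argparse action typically turns those tokens into a dictionary (class name and parsing details are assumptions):

    import argparse
    import ast

    class ParseDictSketch(argparse.Action):
        # hypothetical stand-in for _ParseDict
        def __call__(self, parser, namespace, values, option_string=None):
            d = getattr(namespace, self.dest, None) or {}
            for item in values or []:
                key, _, value = item.partition("=")
                try:
                    d[key] = ast.literal_eval(value)  # "True" -> True, "{...}" -> dict
                except (ValueError, SyntaxError):
                    d[key] = value  # keep plain strings such as "sdpa"
            setattr(namespace, self.dest, d)

    parser = argparse.ArgumentParser()
    parser.add_argument("--mop", metavar="KEY=VALUE", nargs="*", action=ParseDictSketch)
    print(parser.parse_args(["--mop", "attn_implementation=sdpa"]).mop)
    # {'attn_implementation': 'sdpa'}
    print(parser.parse_args(["--mop", "rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}"]).mop)
    # {'rope_scaling': {'rope_type': 'dynamic', 'factor': 10.0}}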
@@ -566,9 +598,32 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
     parser.add_argument(
         "--quiet-input-sets",
         default="",
-        help=
+        help=textwrap.dedent(
+            """
+            Avoids raising an exception when an input sets does not work with
+            the exported model. Example:
+            --quiet-input-sets=inputs,inputs22
+            """.strip(
+                "\n"
+            )
+        ),
     )
+    parser.add_argument(
+        "--expop",
+        metavar="KEY=VALUE",
+        nargs="*",
+        help=textwrap.dedent(
+            """
+            Additional exporter options, use to change some parameters
+            of the model. Examples:
+            --expop report=True
+            --expop report=True --expop verify=True
+            """.strip(
+                "\n"
+            )
+        ),
+        action=_ParseDict,
+    )
     return parser
@@ -634,6 +689,7 @@ def _cmd_validate(argv: List[Any]):
         output_names=(
             None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
         ),
+        exporter_options=args.expop,
     )
     print("")
     print("-- summary --")
@@ -940,7 +996,7 @@ def get_parser_agg() -> ArgumentParser:
         "n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
         "n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
         "n_node_gqa,n_node_layer_normalization,n_node_layer_normalization23,"
-        "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
+        "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,n_node_random,"
         "n_node_constant,n_node_shape,n_node_expand,"
         "n_node_function,n_node_initializer,n_node_scatter,"
         "time_export_unbiased,onnx_n_nodes_no_cst,n_node_initializer_small",
onnx_diagnostic/helpers/helper.py
CHANGED

@@ -1016,6 +1016,8 @@ def max_diff(

     You may use :func:`string_diff` to display the discrepancies in one string.
     """
+    if verbose >= 10:
+        print(f"[max_diff] {type(expected)} ? {type(got)}")
     if expected is None and got is None:
         return dict(abs=0, rel=0, sum=0, n=0, dnan=0)

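The new trace only fires at a high verbosity level. A minimal usage sketch, assuming both helpers live in this module as the docstring above suggests and that `max_diff` accepts plain tensors:

    import torch
    from onnx_diagnostic.helpers.helper import max_diff, string_diff

    expected = torch.tensor([1.0, 2.0, 3.0])
    got = torch.tensor([1.0, 2.1, 3.0])
    diff = max_diff(expected, got, verbose=10)  # verbose >= 10 now also prints the compared types
    print(string_diff(diff))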
@@ -1061,8 +1063,8 @@ def max_diff(
     if expected.__class__.__name__ == "CausalLMOutputWithPast":
         if verbose >= 6:
             print(
-                f"[max_diff] CausalLMOutputWithPast: {string_type(expected)} "
-                f"? {string_type(got)}"
+                f"[max_diff] CausalLMOutputWithPast: {string_type(expected, with_shape=True)} "
+                f"? {string_type(got, with_shape=True)}"
             )
         if got.__class__.__name__ == "CausalLMOutputWithPast":
             return max_diff(
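`string_type(..., with_shape=True)` adds tensor shapes to the printed summary, which is what makes the more verbose message useful. A small sketch, assuming the usual import path for this helper:

    import torch
    from onnx_diagnostic.helpers import string_type

    inputs = (torch.randn(2, 3), torch.ones(2, dtype=torch.int64))
    print(string_type(inputs))                   # types only
    print(string_type(inputs, with_shape=True))  # types plus shapes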
onnx_diagnostic/helpers/log_helper.py
CHANGED

@@ -1169,7 +1169,8 @@ class CubeLogs:
         assuming they should remain stale
     :param sbs: configurations to compare side-by-side, this adds two tabs,
         one gathering raw data about the two configurations, the other one
-        is aggregated by metrics
+        is aggregated by metrics, example:
+        ``=dict(CFA=dict(exporter="E1", opt="O"), CFB=dict(exporter="E2", opt="O"))``
     """
     if verbose:
         print(f"[CubeLogs.to_excel] create Excel file {output}, shape={self.shape}")
@@ -1611,6 +1612,7 @@ class CubeLogsPerformance(CubeLogs):
             "n_node_initializer_small",
             "n_node_layer_normalization",
             "n_node_layer_normalization23",
+            "n_node_random",
             "n_node_reshape",
             "n_node_rotary_embedding",
             "n_node_rotary_embedding23",
@@ -1802,6 +1804,16 @@ class CubeLogsPerformance(CubeLogs):
                 + gdf(df, "op_onnx__InstanceNormlization", 0)
                 + gdf(df, "op_onnx__GroupNormalization", 0),
             ),
+            n_node_random=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx__RandomNormal", 0)
+                + gdf(df, "op_onnx__RandomNormalLike", 0)
+                + gdf(df, "op_onnx__RandomUniform", 0)
+                + gdf(df, "op_onnx__RandomUniformLike", 0)
+                + gdf(df, "op_onnx__Multinomial", 0)
+                + gdf(df, "op_onnx__Bernoulli", 0),
+            ),
             n_node_attention=lambda df: gpreserve(
                 df,
                 "time_latency_eager",
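The new `n_node_random` metric sums the per-model counts of the random ONNX operators. The `gdf` helper is internal to the package; assuming it returns the named column when present and a constant otherwise, the aggregation behaves like this sketch:

    import pandas as pd

    def gdf_sketch(df, col, default=0):
        # assumed behavior of the package's gdf helper: column if present, else a constant
        return df[col] if col in df.columns else default

    df = pd.DataFrame({"op_onnx__RandomNormal": [2, 0], "op_onnx__Bernoulli": [1, 1]})
    random_ops = [
        "op_onnx__RandomNormal", "op_onnx__RandomNormalLike", "op_onnx__RandomUniform",
        "op_onnx__RandomUniformLike", "op_onnx__Multinomial", "op_onnx__Bernoulli",
    ]
    n_node_random = sum(gdf_sketch(df, c, 0) for c in random_ops)
    print(n_node_random.tolist())  # [3, 1]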
onnx_diagnostic/helpers/mini_onnx_builder.py
CHANGED

@@ -52,7 +52,7 @@ def proto_from_array(

     tensor = TensorProto()
     tensor.dims.extend(arr_cpu.shape)
-    tensor.name = name
+    tensor.name = name or ""
     itype = dtype_to_tensor_dtype(arr_cpu.dtype)
     assert not hasattr(TensorProto, "INT4") or itype not in {
         TensorProto.INT4,
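The `or ""` guard matters because protobuf string fields reject `None`; a minimal sketch:

    from onnx import TensorProto

    t = TensorProto()
    name = None          # proto_from_array may now be called without a name
    t.name = name or ""  # assigning None directly would raise a TypeError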
onnx_diagnostic/helpers/onnx_helper.py
CHANGED

@@ -331,7 +331,7 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str:
         print(onnx_dtype_name(7))
     """
     for k in dir(TensorProto):
-        if
+        if k.upper() == k and k != "EXTERNAL":
             v = getattr(TensorProto, k)
             if v == itype:
                 return k
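The new condition skips `TensorProto.EXTERNAL` while scanning the upper-case `TensorProto` attributes; `EXTERNAL` is a `DataLocation` value whose integer happens to equal the element-type code of `FLOAT`. Usage, following the docstring shown above:

    from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name

    print(onnx_dtype_name(7))  # 'INT64'
    print(onnx_dtype_name(1))  # 'FLOAT', not 'EXTERNAL'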
onnx_diagnostic/helpers/rt_helper.py
CHANGED

@@ -10,13 +10,9 @@ from .ort_session import InferenceSessionForTorch


 def name_type_to_onnx_dtype(name: str) -> int:
-        return onnx.TensorProto.FLOAT
-    if name == "tensor(float16)":
-        return onnx.TensorProto.FLOAT16
-    raise AssertionError(f"Unexpected value {name!r}")
+    assert name.startswith("tensor(") and name.endswith(")"), f"Invalid value name={name!r}"
+    look = name[7:-1]
+    return getattr(onnx.TensorProto, look.upper())


 def make_feeds(
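After this rewrite the helper accepts any `tensor(<name>)` string rather than only float and float16, simply upper-casing the inner name and looking it up on `TensorProto`:

    import onnx
    from onnx_diagnostic.helpers.rt_helper import name_type_to_onnx_dtype

    assert name_type_to_onnx_dtype("tensor(float16)") == onnx.TensorProto.FLOAT16
    assert name_type_to_onnx_dtype("tensor(int64)") == onnx.TensorProto.INT64
    assert name_type_to_onnx_dtype("tensor(bfloat16)") == onnx.TensorProto.BFLOAT16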
@@ -153,7 +149,7 @@ def make_empty_cache(
 def generate_and_validate(
     model,
     input_ids: torch.Tensor,
-    eos_token_id: int,
+    eos_token_id: int = 2,
     max_new_tokens: int = 100,
     session: Optional[Union[InferenceSessionForTorch, onnx.ModelProto, str]] = None,
     atol: float = 0.1,
@@ -262,10 +258,10 @@ def generate_and_validate(
 def onnx_generate(
     model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
     input_ids: torch.Tensor,
-    eos_token_id: int,
+    eos_token_id: int = 2,
     max_new_tokens=100,
     return_session: bool = False,
-) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch]]:
+) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch, Dict[str, Any]]]:
     """
     Implements a simple method ``generate`` for an ONNX model.
     The function does not expect any ``position_ids`` as input.
@@ -277,7 +273,7 @@ def onnx_generate(
     :param return_session: returns the instance of class
         :class:`InferenceSessionForTorch
        <onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
-        created if necessary
+        created if necessary, the function returns the feeds for the next iteration
     :return: input tokens concatenated with new tokens

     .. runpython::
@@ -353,12 +349,19 @@ def onnx_generate(
     input_shapes = session.input_shapes
     input_names = session.input_names
     input_types = session.input_types
+    has_position_ids = "position_ids" in session.input_names

     assert (
         len(input_names) > 2
         and input_names[:2] == ["input_ids", "attention_mask"]
-        and input_names[2].startswith("past_key_values")
-    ),
+        and input_names[3 if has_position_ids else 2].startswith("past_key_values")
+    ), (
+        f"Only text generation is supported but input_names == {input_names}, "
+        f"has_position_ids={has_position_ids}"
+    )
+    assert (
+        not has_position_ids or input_names[2] == "position_ids"
+    ), f"position_ids must the third input but input_names={input_names}"

     # First call: prefill
     feeds = dict(
@@ -370,6 +373,10 @@ def onnx_generate(
             input_ids.shape[0], input_names[2:], input_shapes[2:], input_types[2:]
         ),
     )
+    if has_position_ids:
+        feeds["position_ids"] = torch.unsqueeze(
+            torch.arange(input_ids.shape[1], dtype=torch.int64, device=input_ids.device), 0
+        )

     outputs = session.run(None, feeds)
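When the exported model declares a `position_ids` input, `onnx_generate` now feeds explicit positions: `0..seq_len-1` for the prefill call above, and the single next position at every decode step (next hunk). A standalone sketch of the two tensors being built:

    import torch

    seq_len = 5
    prefill = torch.unsqueeze(torch.arange(seq_len, dtype=torch.int64), 0)
    decode = torch.unsqueeze(torch.arange(seq_len, seq_len + 1, dtype=torch.int64), 0)
    print(prefill)  # tensor([[0, 1, 2, 3, 4]])
    print(decode)   # tensor([[5]])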
@@ -393,11 +400,21 @@ def onnx_generate(
                 input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
             ),
         )
+        if has_position_ids:
+            feeds["position_ids"] = torch.unsqueeze(
+                torch.arange(
+                    input_ids.shape[1],
+                    input_ids.shape[1] + 1,
+                    dtype=torch.int64,
+                    device=input_ids.device,
+                ),
+                0,
+            )
+        feeds.update(dict(zip(input_names[3 if has_position_ids else 2 :], outputs[1:])))
         outputs = session.run(None, feeds)

     if return_session:
-        return input_ids, session
+        return input_ids, session, feeds
     return input_ids
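With `return_session=True`, 0.8.1 also returns the feeds prepared for the next decoding step, so a caller can resume generation from where the loop stopped. A hedged usage sketch; the model path is a placeholder for an exported decoder-only model:

    import torch
    from onnx_diagnostic.helpers.rt_helper import onnx_generate

    input_ids = torch.randint(0, 1000, (1, 11), dtype=torch.int64)
    ids, session, feeds = onnx_generate(
        "exported_model.onnx",  # hypothetical path
        input_ids,
        eos_token_id=2,
        max_new_tokens=4,
        return_session=True,
    )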
onnx_diagnostic/tasks/text2text_generation.py
CHANGED

@@ -151,6 +151,7 @@ def get_inputs(
         assert (
             add_second_input > 0
         ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs_prompt"] = dict(input_ids=torch.randint(1000, 30000, (1, 11)))
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
onnx_diagnostic/tasks/text_generation.py
CHANGED

@@ -56,6 +56,74 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     return kwargs


+def _get_input_falcon_mamba(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_hidden_layers: int,
+    batch_size: int = 2,
+    sequence_length: int = 30,
+    sequence_length2: int = 3,
+    dynamic_rope: bool = False,
+    num_key_value_heads: Optional[int] = None,
+    head_dim: Optional[int] = None,
+    cls_cache: Optional[Union[type, str]] = None,
+    **kwargs,  # unused
+):
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+    except ImportError:
+        from transformers.cache_utils import MambaCache
+
+    assert cls_cache in (
+        "MambaCache",
+        MambaCache,
+    ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
+
+    batch = "batch"
+    seq_length_multiple = 8
+    sequence_length = (
+        (sequence_length + seq_length_multiple) // seq_length_multiple * seq_length_multiple
+    )
+    # sequence_inc = seq_length_multiple
+    sequence_length2 = seq_length_multiple
+
+    shapes = {
+        "input_ids": {0: batch, 1: "sequence_length"},
+        "attention_mask": {
+            0: batch,
+            1: "cache+seq",  # cache_length + seq_length
+        },
+        "cache_position": {
+            0: batch,
+            1: "cache+seq",  # cache_length + seq_length
+        },
+        "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)],
+    }
+    inputs = dict(
+        input_ids=torch.randint(
+            0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
+        ).to(torch.int64),
+        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+            torch.int64
+        ),
+        cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
+        # .expand((batch_size, -1))
+        cache_params=make_mamba_cache(
+            [
+                (
+                    torch.randn(
+                        batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
+                    ),
+                    torch.randn(batch_size, kwargs["intermediate_size"], kwargs["state_size"]),
+                )
+                for i in range(num_hidden_layers)
+            ]
+        ),
+    )
+    return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
 def get_inputs(
     model: torch.nn.Module,
     config: Optional[Any],
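One detail of the helper above: the dummy sequence length is padded to a multiple of `seq_length_multiple` (8) before the FalconMamba inputs are built. The expression always lands on a multiple of 8 that is at least the requested length:

    seq_length_multiple = 8
    for sequence_length in (30, 32, 33):
        rounded = (
            (sequence_length + seq_length_multiple) // seq_length_multiple * seq_length_multiple
        )
        print(sequence_length, "->", rounded)  # 30 -> 32, 32 -> 40, 33 -> 40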
@@ -68,7 +136,7 @@ def get_inputs(
     num_key_value_heads: Optional[int] = None,
     head_dim: Optional[int] = None,
     cls_cache: Optional[Union[type, str]] = None,
-    add_second_input: int =
+    add_second_input: Optional[int] = None,
     **kwargs,  # unused
 ):
     """
@@ -84,6 +152,7 @@ def get_inputs(
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :param cls_cache: cache class, by default it is
         :class:`transformers.cache_utils.DynamicCache`
+    :param add_second_input: adds other kinds of inputs
     :return: dictionary
     """
     batch = "batch"
@@ -91,60 +160,20 @@ def get_inputs(
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)

     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
-        ... (inlined FalconMamba shapes/inputs construction, now factored out
-             into _get_input_falcon_mamba above) ...
-        res = dict(inputs=inputs, dynamic_shapes=shapes)
+        res = _get_input_falcon_mamba(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_hidden_layers=num_hidden_layers,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            sequence_length2=sequence_length2,
+            dynamic_rope=dynamic_rope,
+            num_key_value_heads=num_key_value_heads,
+            head_dim=head_dim,
+            cls_cache=cls_cache,
+            **kwargs,  # unused
+        )
     else:
         if head_dim is None:
             assert config, "head_dim is None, the value cannot be set without a configuration"
@@ -244,6 +273,7 @@ def get_inputs(
     )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        res["inputs_prompt"] = dict(input_ids=torch.randint(1000, 30000, (1, 11)))
        res["inputs2"] = get_inputs(
            model=model,
            config=config,
onnx_diagnostic/torch_export_patches/patches/patch_torch.py
CHANGED

@@ -195,9 +195,12 @@ class patched_ShapeEnv:
         if self.frozen:
             self.counter["ignored_backward_guard"] += 1
             # PATCHED: raised an exception instead of logging.
+            import transformers
+
             raise AssertionError(
                 f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
-                f"this could result in accuracy problems"
+                f"this could result in accuracy problems, transformers.__version__="
+                f"{transformers.__version__!r}"
             )

     def _set_replacement(
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
CHANGED

@@ -1452,7 +1452,7 @@ def patched_sdpa_attention_forward(
             scale=scaling,
             is_causal=True,
             **sdpa_kwargs,
-        ),
+        ).contiguous(),
         lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
             query,
             key,
@@ -1461,7 +1461,7 @@ def patched_sdpa_attention_forward(
             scale=scaling,
             is_causal=False,
             **sdpa_kwargs,
-        ),
+        ).contiguous(),
         [query, key, value],
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
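Both branches of the patched attention (causal and non-causal) now return contiguous tensors, so the two callables produce outputs with the same memory layout. A minimal sketch of the added call, using standard `torch.nn.functional.scaled_dot_product_attention`; the reason stated above is an inference from the patch, not from its comments:

    import torch

    q = torch.randn(1, 2, 4, 8)
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True).contiguous()
    print(out.shape, out.is_contiguous())  # torch.Size([1, 2, 4, 8]) True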