onnx-diagnostic 0.7.13__py3-none-any.whl → 0.7.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +15 -3
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/config_helper.py +2 -1
- onnx_diagnostic/helpers/log_helper.py +53 -17
- onnx_diagnostic/helpers/rt_helper.py +3 -3
- onnx_diagnostic/tasks/image_text_to_text.py +6 -5
- onnx_diagnostic/tasks/text_generation.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +7 -1
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +1 -4
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +31 -13
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +445 -9
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +79 -28
- onnx_diagnostic/torch_models/hghub/model_inputs.py +31 -5
- onnx_diagnostic/torch_models/validate.py +41 -28
- {onnx_diagnostic-0.7.13.dist-info → onnx_diagnostic-0.7.15.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.13.dist-info → onnx_diagnostic-0.7.15.dist-info}/RECORD +21 -21
- {onnx_diagnostic-0.7.13.dist-info → onnx_diagnostic-0.7.15.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.13.dist-info → onnx_diagnostic-0.7.15.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.13.dist-info → onnx_diagnostic-0.7.15.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py CHANGED

onnx_diagnostic/_command_lines_parser.py CHANGED

@@ -400,12 +400,17 @@ def get_parser_validate() -> ArgumentParser:
 
     position_ids is usually not needed, they can be removed by adding:
 
-
+        --drop position_ids
 
     The behaviour may be modified compare the original configuration,
     the following argument can be rope_scaling to dynamic:
 
-
+        --mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\""
+
+    You can profile the command line by running:
+
+        pyinstrument -m onnx_diagnostic validate ...
+        pyinstrument -r html -o profile.html -m onnx_diagnostic validate ...
     """
     ),
     formatter_class=RawTextHelpFormatter,
@@ -548,6 +553,12 @@ def get_parser_validate() -> ArgumentParser:
         action=BooleanOptionalAction,
         help="Enables onnxruntime logging when the session is created",
     )
+    parser.add_argument(
+        "--quiet-input-sets",
+        default="",
+        help="Avoids raising an exception when an input sets does not work with "
+        "the exported model, example: --quiet-input-sets=inputs,inputs22",
+    )
     return parser
@@ -609,6 +620,7 @@ def _cmd_validate(argv: List[Any]):
         warmup=args.warmup,
         inputs2=args.inputs2,
         ort_logs=args.ort_logs,
+        quiet_input_sets=set(args.quiet_input_sets.split(",")),
         output_names=(
             None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
         ),
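One detail worth noting about the line above: `"".split(",")` returns `[""]`, so with the flag's default value the set is `{""}` rather than an empty set. That is harmless as long as callers only test membership against real input-set names. A quick standalone check:

```python
# Quick check of how the new --quiet-input-sets value is parsed
# (standalone illustration, not library code).
for raw in ["", "inputs,inputs22"]:
    print(repr(raw), "->", set(raw.split(",")))
# '' -> {''}
# 'inputs,inputs22' -> {'inputs', 'inputs22'}
```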
@@ -829,7 +841,7 @@ def get_parser_agg() -> ArgumentParser:
         "n_model_pass,n_model_faster,"
         "n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
         "n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
-        "n_node_layer_normalization,n_node_layer_normalization23,"
+        "n_node_gqa,n_node_layer_normalization,n_node_layer_normalization23,"
         "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
         "n_node_constant,n_node_shape,n_node_expand,"
         "n_node_function,n_node_initializer,n_node_scatter,"
onnx_diagnostic/helpers/cache_helper.py CHANGED

@@ -108,7 +108,7 @@ def flatten_unflatten_for_dynamic_shapes(
 
 def is_cache_dynamic_registered(fast: bool = False) -> bool:
     """
-    Tells class :class:`transformers.cache_utils.DynamicCache` can be
+    Tells if class :class:`transformers.cache_utils.DynamicCache` can be
     serialized and deserialized. Only then, :func:`torch.export.export`
     can export a model.
onnx_diagnostic/helpers/config_helper.py CHANGED

@@ -95,7 +95,8 @@ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[type]:
     mod_name = cls.__module__
     mod = importlib.import_module(mod_name)
     source = inspect.getsource(mod)
-    reg = re.compile("config: ([A-Za-z0-9]+)")
+    # [^O] avoids capturing Optional[Something]
+    reg = re.compile("config: ([^O][A-Za-z0-9]+)")
    fall = reg.findall(source)
    if len(fall) == 0:
        assert not exc, (
|
|
|
1167
1167
|
df.to_excel(
|
|
1168
1168
|
writer,
|
|
1169
1169
|
sheet_name=name,
|
|
1170
|
-
freeze_panes=(df.columns.nlevels +
|
|
1170
|
+
freeze_panes=(df.columns.nlevels + 1, df.index.nlevels),
|
|
1171
1171
|
)
|
|
1172
1172
|
f_highlights[name] = tview.f_highlight
|
|
1173
1173
|
if tview.plots:
|
|
@@ -1210,7 +1210,7 @@ class CubeLogs:
             for k, v in sbs.items():
                 print(f"[CubeLogs.to_excel] sbs {k}: {v}")
         name = "∧".join(sbs)
-        sbs_raw, sbs_agg = self.sbs(sbs)
+        sbs_raw, sbs_agg, sbs_col = self.sbs(sbs)
         if verbose:
             print(f"[CubeLogs.to_excel] add sheet {name!r} with shape {sbs_raw.shape}")
             print(
@@ -1222,7 +1222,7 @@ class CubeLogs:
                 writer,
                 sheet_name=name,
                 freeze_panes=(
-                    sbs_raw.columns.nlevels,
+                    sbs_raw.columns.nlevels + 1,
                     sbs_raw.index.nlevels,
                 ),
             )
@@ -1230,10 +1230,18 @@ class CubeLogs:
                 writer,
                 sheet_name=f"{name}-AGG",
                 freeze_panes=(
-                    sbs_agg.columns.nlevels,
+                    sbs_agg.columns.nlevels + 1,
                     sbs_agg.index.nlevels,
                 ),
             )
+            sbs_col.to_excel(
+                writer,
+                sheet_name=f"{name}-COL",
+                freeze_panes=(
+                    sbs_col.columns.nlevels + 1,
+                    sbs_col.index.nlevels,
+                ),
+            )
 
         if plots:
             from openpyxl.drawing.image import Image
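The recurring `+ 1` makes sense given how pandas typically lays out a sheet: with MultiIndex columns, `to_excel` writes `columns.nlevels` header rows plus one extra row carrying the index names, so freezing only `nlevels` rows leaves the last header row scrolling. A minimal sketch (file name arbitrary, requires openpyxl):

```python
import pandas as pd

# Minimal sketch: with MultiIndex columns, to_excel emits an index-name row
# below the column headers, so the data starts at row columns.nlevels + 1;
# freezing that many rows pins the full header in place.
df = pd.DataFrame(
    [[1.0, 2.0], [3.0, 4.0]],
    index=pd.Index(["model_a", "model_b"], name="model"),
    columns=pd.MultiIndex.from_tuples([("latency", "mean"), ("latency", "std")]),
)
with pd.ExcelWriter("cube_demo.xlsx") as writer:  # needs openpyxl installed
    df.to_excel(
        writer,
        sheet_name="demo",
        freeze_panes=(df.columns.nlevels + 1, df.index.nlevels),
    )
```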
@@ -1314,7 +1322,7 @@ class CubeLogs:
 
     def sbs(
         self, configs: Dict[str, Dict[str, Any]], column_name: str = "CONF"
-    ) -> Tuple[pandas.DataFrame, pandas.DataFrame]:
+    ) -> Tuple[pandas.DataFrame, pandas.DataFrame, pandas.DataFrame]:
         """
         Creates a side-by-side for two configurations.
         Every configuration a dictionary column:value which filters in
@@ -1325,7 +1333,7 @@ class CubeLogs:
         :param configs: example
             ``dict(CFA=dict(exporter="E1", opt="O"), CFB=dict(exporter="E2", opt="O"))``
         :param column_name: column to add with the name of the configuration
-        :return: data
+        :return: data, aggregated date, data with a row per model
         """
         assert (
             len(configs) >= 2
@@ -1433,6 +1441,8 @@ class CubeLogs:
                 _mkc(m, f"{n1}<{n2}"): (si < sj).astype(int),
                 _mkc(m, f"{n1}=={n2}"): (si == sj).astype(int),
                 _mkc(m, f"{n1}>{n2}"): (si > sj).astype(int),
+                _mkc(m, f"{n1}*({n1}∧{n2})"): si * (~sinan & ~sjnan).astype(float),
+                _mkc(m, f"{n2}*({n1}∧{n2})"): sj * (~sinan & ~sjnan).astype(float),
             }
         )
         nas.columns.names = view_res.columns.names
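The two new columns weight each configuration's metric by an indicator that both configurations produced a value, which makes column sums comparable over the shared subset of runs. A standalone illustration (variable names mirror the snippet above, data is invented):

```python
import numpy as np
import pandas as pd

# si/sj mimic one metric for two configurations; sinan/sjnan flag missing runs.
si = pd.Series([1.0, np.nan, 3.0])
sj = pd.Series([2.0, 5.0, np.nan])
sinan, sjnan = si.isna(), sj.isna()
mask = (~sinan & ~sjnan).astype(float)  # 1.0 where both ran, 0.0 otherwise
# Only row 0 ran under both configurations, so the sums compare like with like.
print((si * mask).sum(), (sj * mask).sum())  # 1.0 2.0
```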
@@ -1452,13 +1462,11 @@ class CubeLogs:
         }
         flat = view_res.groupby(self.time).agg(aggs)
         flat = flat.stack("METRICS", future_stack=True)
-        return res, flat
+        return res, flat, view_res.T.sort_index().T
 
 
 class CubeLogsPerformance(CubeLogs):
-    """
-    Processes logs coming from experiments.
-    """
+    """Processes logs coming from experiments."""
 
     def __init__(
         self,
@@ -1511,20 +1519,25 @@ class CubeLogsPerformance(CubeLogs):
             "n_model_faster2x",
             "n_model_faster3x",
             "n_model_faster4x",
+            "n_model_faster5x",
             "n_node_attention",
             "n_node_attention23",
-            "n_node_rotary_embedding",
-            "n_node_rotary_embedding23",
-            "n_node_layer_normalization",
-            "n_node_layer_normalization23",
+            "n_node_causal_mask",
+            "n_node_constant",
             "n_node_control_flow",
-            "n_node_constant",
+            "n_node_expand",
             "n_node_function",
+            "n_node_gqa",
             "n_node_initializer",
             "n_node_initializer_small",
-            "n_node_scatter",
+            "n_node_layer_normalization",
+            "n_node_layer_normalization23",
+            "n_node_reshape",
+            "n_node_rotary_embedding",
+            "n_node_rotary_embedding23",
+            "n_node_scatter",
+            "n_node_sequence",
             "n_node_shape",
-            "n_node_expand",
             "onnx_n_nodes_no_cst",
             "peak_gpu_torch",
             "peak_gpu_nvidia",
@@ -1690,6 +1703,11 @@ class CubeLogsPerformance(CubeLogs):
                 "time_latency",
                 gdf(df, "time_latency_eager") > gdf(df, "time_latency", np.inf) * 3.98,
             ),
+            n_model_faster5x=lambda df: gpreserve(
+                df,
+                "time_latency",
+                gdf(df, "time_latency_eager") > gdf(df, "time_latency", np.inf) * 4.98,
+            ),
             n_node_attention23=lambda df: gpreserve(
                 df, "time_latency_eager", gdf(df, "op_onnx__Attention")
             ),
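The thresholds are 3.98 and 4.98 rather than 4 and 5, presumably to absorb measurement noise right at the boundary. The predicate itself is plain pandas; the latencies below are made up:

```python
import pandas as pd

# Made-up latencies for two models: the first is 6x faster than eager,
# the second only 2.5x, so only the first lands in the "faster5x" bucket.
time_latency = pd.Series([0.010, 0.020])
time_latency_eager = pd.Series([0.060, 0.050])
faster5x = time_latency_eager > time_latency * 4.98
print(faster5x.tolist())  # [True, False]
```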
@@ -1720,6 +1738,11 @@ class CubeLogsPerformance(CubeLogs):
                 + gdf(df, "op_onnx_com.microsoft_DecoderMaskedMultiHeadAttention", 0)
                 + gdf(df, "op_onnx_com.microsoft_SparseAttention", 0),
             ),
+            n_node_gqa=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_GroupQueryAttention", 0),
+            ),
             n_node_layer_normalization=lambda df: gpreserve(
                 df,
                 "time_latency_eager",
@@ -1764,9 +1787,22 @@ class CubeLogsPerformance(CubeLogs):
             n_node_shape=lambda df: gpreserve(
                 df, "time_latency_eager", gdf(df, "op_onnx__Shape")
             ),
+            n_node_reshape=lambda df: gpreserve(
+                df, "time_latency_eager", gdf(df, "op_onnx__Reshape")
+            ),
             n_node_expand=lambda df: gpreserve(
                 df, "time_latency_eager", gdf(df, "op_onnx__Expand")
             ),
+            n_node_causal_mask=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx__CausalMask", 0),
+            ),
+            n_node_sequence=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx__SequenceAt", 0) + gdf(df, "op_onnx__SplitToSequence", 0),
+            ),
         )
         assert (
             formula in lambdas
onnx_diagnostic/helpers/rt_helper.py CHANGED

@@ -3,8 +3,6 @@ import numpy as np
 import onnx
 import torch
 from .helper import string_type, flatten_object
-from .torch_helper import to_numpy
-from .cache_helper import is_cache_dynamic_registered
 
 
 def name_type_to_onnx_dtype(name: str) -> int:
@@ -49,7 +47,7 @@ def make_feeds(
     assert (
         not check_flatten
         or not all(isinstance(obj, torch.Tensor) for obj in flat)
-        or not is_cache_dynamic_registered(fast=True)
+        # or not is_cache_dynamic_registered(fast=True)
         or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
     ), (
         f"Unexpected number of flattened objects, "
@@ -57,6 +55,8 @@ def make_feeds(
         f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
     )
     if use_numpy:
+        from .torch_helper import to_numpy
+
         flat = [to_numpy(t) if isinstance(t, torch.Tensor) else t for t in flat]
     names = (
         [i.name for i in proto.graph.input]
onnx_diagnostic/tasks/image_text_to_text.py CHANGED

@@ -186,12 +186,13 @@ def _get_inputs_gemma3(
         f"total_sequence_length={total_sequence_length} != 860 "
         f"for model {model.__class__.__name__}"
     )
-    assert (
-        head_dim == 256
-    ), f"head_dim={head_dim} != 256 for model {model.__class__.__name__}"
+    assert head_dim in (
+        256,
+        32,
+    ), f"head_dim={head_dim} not in (32, 256) for model {model.__class__.__name__}"
     assert n_images == 1, f"n_images={n_images} != 1 for model {model.__class__.__name__}"
-    assert num_key_value_heads, (
-        f"num_key_value_heads={num_key_value_heads} "
+    assert num_key_value_heads in (1, 4), (
+        f"num_key_value_heads={num_key_value_heads} not in (1, 4) "
         f"for this model {model.__class__.__name__}"
     )
onnx_diagnostic/tasks/text_generation.py CHANGED

@@ -19,6 +19,9 @@ __TASK__ = "text-generation"
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     # FalconMambaConfig: use_mambapy
+    if hasattr(config, "text_config"):
+        # The model is probably of mixture of models used only for text.
+        config = config.text_config
     check_hasattr(
         config,
         ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
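Multimodal checkpoints (vision plus text) typically nest the language-model hyperparameters under `text_config`, which is why both functions in this file now descend into it before calling `check_hasattr`. A self-contained stand-in, no transformers required:

```python
from types import SimpleNamespace

# Stand-in for a multimodal configuration: the text submodel's dimensions
# live one level down, under text_config.
config = SimpleNamespace(
    text_config=SimpleNamespace(hidden_size=256, num_attention_heads=4)
)
if hasattr(config, "text_config"):
    config = config.text_config
print(config.hidden_size, config.num_attention_heads)  # 256 4
```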
@@ -284,6 +287,21 @@ def get_inputs(
         add_second_input=0,
         **kwargs,
     )["inputs"]
+    res["inputs_batch1"] = get_inputs(
+        model=model,
+        config=config,
+        dummy_max_token_id=dummy_max_token_id,
+        num_hidden_layers=num_hidden_layers,
+        batch_size=1,
+        sequence_length=sequence_length,
+        sequence_length2=sequence_length2,
+        dynamic_rope=dynamic_rope,
+        num_key_value_heads=num_key_value_heads,
+        head_dim=head_dim,
+        cls_cache=cls_cache,
+        add_second_input=0,
+        **kwargs,
+    )["inputs"]
     return res
@@ -293,6 +311,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
 
     If the configuration is None, the function selects typical dimensions.
     """
+    if hasattr(config, "text_config"):
+        # The model is probably of mixture of models used only for text.
+        config = config.text_config
     if config is not None:
         check_hasattr(
             config,
onnx_diagnostic/torch_export_patches/eval/__init__.py CHANGED

@@ -676,7 +676,13 @@ def run_exporter(
 
     if dynamic and len(inputs) > 1:
         for index, i in enumerate(inputs):
-            expected = model(*_clone(i))
+            if quiet:
+                try:
+                    expected = model(*_clone(i))
+                except Exception as e:
+                    return dict(error=str(e), success=0, error_step=f"run0.{index}")
+            else:
+                expected = model(*_clone(i))
             try:
                 got = mod(*i)
             except Exception as e:
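The pattern mirrors the rest of `run_exporter`: under `quiet`, a failure becomes a result dictionary tagged with the step that failed instead of a raised exception. A reduced sketch of that contract (names are illustrative, not the library's API):

```python
def call_quietly(fn, *args, step="run0.0", quiet=True):
    # Under quiet, failures become a result row tagged with the failing step.
    try:
        return {"value": fn(*args), "success": 1}
    except Exception as e:
        if not quiet:
            raise
        return {"error": str(e), "success": 0, "error_step": step}
```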
onnx_diagnostic/torch_export_patches/eval/model_cases.py CHANGED

@@ -353,12 +353,9 @@ class ControlFlowCondNonZero(torch.nn.Module):
 
 
 class ControlFlowCondIdentity_153832(torch.nn.Module):
-    """
-    `#153832 <https://github.com/pytorch/pytorch/issues/153832>`_
-    """
+    """`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_"""
 
     def forward(self, x, y):
-
         def branch_cond_then_1(x):
             x = torch.abs(x) + 1
             return x
onnx_diagnostic/torch_export_patches/onnx_export_errors.py CHANGED

@@ -340,6 +340,7 @@ def torch_export_patches(
     ###############
 
     if patch_torch:
+        from torch.fx.experimental.symbolic_shapes import ShapeEnv
         from .patches.patch_torch import (
             patched_infer_size,
             patched_vmap,
@@ -347,6 +348,9 @@ def torch_export_patches(
             patched__constrain_user_specified_dimhint_range,
             _catch_produce_guards_and_solve_constraints,
             patch__check_input_constraints_for_graph,
+            patched__broadcast_in_dim_meta,
+            patched__maybe_broadcast,
+            patched_ShapeEnv,
         )
 
         if verbose:
@@ -383,6 +387,20 @@ def torch_export_patches(
             patched__constrain_user_specified_dimhint_range
         )
 
+        # torch._prims._broadcast_in_dim_meta
+        f_broadcast_in_dim = torch._prims.broadcast_in_dim
+        f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
+        torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
+        torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta
+
+        # torch._refs._maybe_broadcast
+        f__maybe_broadcast = torch._refs._maybe_broadcast
+        torch._refs._maybe_broadcast = patched__maybe_broadcast
+
+        # ShapeEnv
+        f_shape_env__evaluate_expr = ShapeEnv._evaluate_expr
+        ShapeEnv._evaluate_expr = patched_ShapeEnv._evaluate_expr
+
     # torch._export.non_strict_utils.produce_guards_and_solve_constraints
     if patch_torch and catch_constraints:
         if verbose:
@@ -404,10 +422,7 @@ def torch_export_patches(
             )
         )
 
-    if stop_if_static:
-        from torch.fx.experimental.symbolic_shapes import ShapeEnv
-        from .patches.patch_torch import patched_ShapeEnv
-
+    if patch_torch and stop_if_static:
         ShapeEnv._log_guard_remember = ShapeEnv._log_guard
 
         if verbose:
@@ -584,6 +599,10 @@ def torch_export_patches(
         torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
             f___constrain_user_specified_dimhint_range
         )
+        torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
+        torch._prims.broadcast_in_dim = f_broadcast_in_dim
+        torch._refs._maybe_broadcast = f__maybe_broadcast
+        ShapeEnv._evaluate_expr = f_shape_env__evaluate_expr
 
         if verbose:
             print("[torch_export_patches] restored pytorch functions")
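The recurring pattern in this file is save the original, install the patch, restore on exit; the library does it by hand so several patches can be installed at once and undone together when `torch_export_patches` exits. A generic sketch of the same idea as a context manager (this helper is not part of onnx-diagnostic):

```python
import contextlib


@contextlib.contextmanager
def temporary_patch(owner, name, replacement):
    """Swap owner.name for replacement, restoring the original on exit.

    owner can be a module (e.g. torch._prims, torch._refs) or a class
    (e.g. ShapeEnv); the finally block guarantees restoration even when
    export fails halfway through.
    """
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield original
    finally:
        setattr(owner, name, original)
```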
@@ -723,9 +742,7 @@ def torch_export_patches(
 
 
 def replacement_before_exporting(args: Any) -> Any:
-    """
-    Does replacements on the given inputs if needed.
-    """
+    """Does replacements on the given inputs if needed."""
     if args is None:
         return None
     if isinstance(args, (int, float)):
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py CHANGED

@@ -12,17 +12,26 @@ from transformers.cache_utils import (
     StaticCache,
 )
 
-try:
-    from transformers.models.mamba.modeling_mamba import MambaCache
-except ImportError:
-    from transformers.cache_utils import MambaCache
-
 from ..helpers import string_type
 from .serialization import _lower_name_with_
 
 PATCH_OF_PATCHES: Set[Any] = set()
 
 
+def get_mamba_cache_cls() -> type:
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+
+        return MambaCache
+    except ImportError:
+        try:
+            from transformers.cache_utils import MambaCache
+
+            return MambaCache
+        except ImportError:
+            return None
+
+
 def register_class_serialization(
     cls,
     f_flatten: Callable,
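Deferring the import into `get_mamba_cache_cls` keeps the module importable on transformers builds where `MambaCache` has moved or disappeared entirely; callers get `None` in that case and simply skip the Mamba registration, as the hunks below show. A hypothetical caller:

```python
# Hypothetical caller: registration degrades gracefully when the class is gone.
MambaCache = get_mamba_cache_cls()
if MambaCache is None:
    print("MambaCache unavailable, skipping its serialization functions")
```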
@@ -203,13 +212,6 @@ def serialization_functions(
             # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
             verbose=verbose,
         ),
-        MambaCache: lambda verbose=verbose: register_class_serialization(
-            MambaCache,
-            flatten_mamba_cache,
-            unflatten_mamba_cache,
-            flatten_with_keys_mamba_cache,
-            verbose=verbose,
-        ),
         EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
             EncoderDecoderCache,
             flatten_encoder_decoder_cache,
@@ -232,6 +234,17 @@ def serialization_functions(
             verbose=verbose,
         ),
     }
+    MambaCache = get_mamba_cache_cls()
+    if MambaCache:
+        transformers_classes[MambaCache] = (
+            lambda verbose=verbose: register_class_serialization(
+                MambaCache,
+                flatten_mamba_cache,
+                unflatten_mamba_cache,
+                flatten_with_keys_mamba_cache,
+                verbose=verbose,
+            )
+        )
     classes.update(transformers_classes)
 
     if patch_diffusers:
@@ -287,7 +300,12 @@ def unregister_class_serialization(cls: type, verbose: int = 0):
 
 def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
     """Undo all registrations."""
-    cls_ensemble = {DynamicCache, EncoderDecoderCache, MambaCache} | set(undo)
+    MambaCache = get_mamba_cache_cls()
+    cls_ensemble = (
+        {DynamicCache, EncoderDecoderCache}
+        | set(undo)
+        | ({MambaCache} if MambaCache else set())
+    )
     for cls in cls_ensemble:
         if undo.get(cls.__name__, False):
             unregister_class_serialization(cls, verbose)