onnx-diagnostic 0.7.10__py3-none-any.whl → 0.7.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +13 -3
- onnx_diagnostic/helpers/cache_helper.py +8 -6
- onnx_diagnostic/helpers/log_helper.py +65 -12
- onnx_diagnostic/helpers/rt_helper.py +53 -36
- onnx_diagnostic/tasks/__init__.py +4 -2
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +11 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +5 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +73 -32
- onnx_diagnostic/torch_models/hghub/hub_data.py +3 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +70 -38
- onnx_diagnostic/torch_models/hghub/model_specific.py +27 -0
- onnx_diagnostic/torch_models/validate.py +329 -88
- {onnx_diagnostic-0.7.10.dist-info → onnx_diagnostic-0.7.12.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.7.10.dist-info → onnx_diagnostic-0.7.12.dist-info}/RECORD +19 -18
- {onnx_diagnostic-0.7.10.dist-info → onnx_diagnostic-0.7.12.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.10.dist-info → onnx_diagnostic-0.7.12.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.10.dist-info → onnx_diagnostic-0.7.12.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED

onnx_diagnostic/_command_lines_parser.py
CHANGED
```diff
@@ -474,7 +474,7 @@ def get_parser_validate() -> ArgumentParser:
     )
     parser.add_argument(
         "--runtime",
-        choices=["onnxruntime", "torch", "ref"],
+        choices=["onnxruntime", "torch", "ref", "orteval", "orteval10"],
         default="onnxruntime",
         help="onnx runtime to use, `onnxruntime` by default",
     )
@@ -542,6 +542,12 @@ def get_parser_validate() -> ArgumentParser:
         "the onnx exporter should use.",
         default="",
     )
+    parser.add_argument(
+        "--ort-logs",
+        default=False,
+        action=BooleanOptionalAction,
+        help="Enables onnxruntime logging when the session is created",
+    )
     return parser
```
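The new flag relies on `argparse.BooleanOptionalAction` (available since Python 3.9), which registers both the flag and its negated twin, so `--ort-logs` and `--no-ort-logs` work out of the box. A standalone sketch of the resulting behavior:

```python
from argparse import ArgumentParser, BooleanOptionalAction

parser = ArgumentParser()
# BooleanOptionalAction creates both --ort-logs and --no-ort-logs.
parser.add_argument("--ort-logs", default=False, action=BooleanOptionalAction)

print(parser.parse_args(["--ort-logs"]).ort_logs)     # True
print(parser.parse_args(["--no-ort-logs"]).ort_logs)  # False
print(parser.parse_args([]).ort_logs)                 # False (the default)
```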
```diff
@@ -575,6 +581,7 @@ def _cmd_validate(argv: List[Any]):
     ):
         print(f"validate - unsupported args: export={args.export!r}, opt={args.opt!r}")
         return
+    patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
     summary, _data = validate_model(
         model_id=args.mid,
         task=args.task,
@@ -585,8 +592,8 @@ def _cmd_validate(argv: List[Any]):
         use_pretrained=args.trained,
         dtype=args.dtype,
         device=args.device,
-        patch=
-        rewrite=args.rewrite,
+        patch=patch_dict,
+        rewrite=args.rewrite and patch_dict.get("patch", True),
         stop_if_static=args.stop_if_static,
         optimization=args.opt,
         exporter=args.export,
@@ -601,6 +608,7 @@ def _cmd_validate(argv: List[Any]):
         repeat=args.repeat,
         warmup=args.warmup,
         inputs2=args.inputs2,
+        ort_logs=args.ort_logs,
         output_names=(
             None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
         ),
```
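`--patch` can now carry either a boolean or a dictionary of patch options; the `patch_dict` line normalizes both forms so `validate_model` always receives a dictionary, and `rewrite` is forced off when patching is explicitly disabled. A minimal sketch of that normalization (the helper name is illustrative, not part of the package):

```python
def normalize_patch(patch):
    # A bare boolean becomes {"patch": bool}; a dict is passed through.
    return patch if isinstance(patch, dict) else {"patch": patch}

assert normalize_patch(True) == {"patch": True}
assert normalize_patch({"patch_transformers": True}) == {"patch_transformers": True}

# rewrite is disabled whenever patching is explicitly turned off:
patch_dict = normalize_patch(False)
print(True and patch_dict.get("patch", True))  # False
```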
```diff
@@ -820,6 +828,8 @@ def get_parser_agg() -> ArgumentParser:
         "n_model_running,n_model_acc01,n_model_acc001,n_model_dynamic,"
         "n_model_pass,n_model_faster,"
         "n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
+        "n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
+        "n_node_layer_normalization,n_node_layer_normalization23,"
         "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
         "n_node_constant,n_node_shape,n_node_expand,"
         "n_node_function,n_node_initializer,n_node_scatter,"
```
onnx_diagnostic/helpers/cache_helper.py
CHANGED

```diff
@@ -4,11 +4,6 @@ import torch
 import transformers
 import transformers.cache_utils
 
-try:
-    from transformers.models.mamba.modeling_mamba import MambaCache
-except ImportError:
-    from transformers.cache_utils import MambaCache
-
 
 class CacheKeyValue:
     """
@@ -354,8 +349,15 @@ def make_encoder_decoder_cache(
     )
 
 
-def make_mamba_cache(
+def make_mamba_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+) -> "MambaCache":  # noqa: F821
     "Creates a ``MambaCache``."
+    # import is moved here because this part is slow.
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+    except ImportError:
+        from transformers.cache_utils import MambaCache
     dtype = key_value_pairs[0][0].dtype
 
     class _config:
```
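Moving the `try`/`except` import inside the function is the standard deferred-import idiom: the module-loading cost is paid on the first call instead of at `import onnx_diagnostic.helpers.cache_helper` time. A minimal standalone sketch of the idiom (the function body is abbreviated; only the fallback mirrors the diff):

```python
def make_cache(*key_value_pairs):
    # Deferred import: transformers' Mamba modeling code is only loaded on
    # first call. Newer transformers exposes MambaCache from the modeling
    # module, older versions from cache_utils, hence the fallback.
    try:
        from transformers.models.mamba.modeling_mamba import MambaCache
    except ImportError:
        from transformers.cache_utils import MambaCache
    return MambaCache  # the real helper builds and returns an instance
```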
onnx_diagnostic/helpers/log_helper.py
CHANGED

```diff
@@ -285,7 +285,8 @@ class CubePlot:
         nn = df.shape[1] // n_cols
         nn += int(df.shape[1] % n_cols != 0)
         ratio = float(os.environ.get("FIGSIZEH", "1"))
-
+        figsize = (6 * n_cols, nn * (2.5 + df.shape[0] / 15) * ratio)
+        fig, axs = plt.subplots(nn, n_cols, figsize=figsize)
         pos = 0
         imgs = []
         for c in self._make_loop(df.columns, verbose):
@@ -332,10 +333,12 @@ class CubePlot:
         n_cols = len(groups)
 
         title_suffix = f"\n{title_suffix}" if title_suffix else ""
+        ratio = float(os.environ.get("FIGSIZEH", "1"))
+        figsize = (5 * n_cols, max(len(g) for g in groups) * (2 + df.shape[1] / 2) * ratio)
         fig, axs = plt.subplots(
             df.shape[1],
             n_cols,
-            figsize=
+            figsize=figsize,
             sharex=True,
             sharey="row" if n_cols > 1 else False,
         )
```
```diff
@@ -877,7 +880,11 @@ class CubeLogs:
             print(f"[CubeLogs.view] key_columns={key_columns}")
         g = data[[*key_index, *key_columns]].copy()
         g["count"] = 1
-        r =
+        r = (
+            g.copy()
+            if not key_index and not key_columns
+            else g.groupby([*key_index, *key_columns], dropna=False).sum()
+        )
         not_unique = r[r["count"] > 1]
         assert not_unique.shape[0] == 0, (
             f"view_def.name={view_def.name!r}, "
```
```diff
@@ -1505,6 +1512,11 @@ class CubeLogsPerformance(CubeLogs):
             "n_model_faster3x",
             "n_model_faster4x",
             "n_node_attention",
+            "n_node_attention23",
+            "n_node_rotary_embedding",
+            "n_node_rotary_embedding23",
+            "n_node_layer_normalization",
+            "n_node_layer_normalization23",
             "n_node_control_flow",
             "n_node_scatter",
             "n_node_function",
```
```diff
@@ -1568,7 +1580,9 @@ class CubeLogsPerformance(CubeLogs):
 
         def gdf(df, cname, default_value=np.nan):
             if cname in df.columns:
-
+                if np.isnan(default_value):
+                    return df[cname]
+                return df[cname].fillna(default_value)
             return pandas.Series(default_value, index=df.index)
 
         def ghas_value(df, cname):
```
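With this change an explicit `default_value` also fills NaNs in a column that does exist, which matters when these per-operator counters are summed. A standalone sketch of the updated logic:

```python
import numpy as np
import pandas

def gdf(df, cname, default_value=np.nan):
    # Mirrors the patched helper: missing entries are filled with
    # default_value unless the default is NaN, which keeps them as-is.
    if cname in df.columns:
        if np.isnan(default_value):
            return df[cname]
        return df[cname].fillna(default_value)
    # Absent column: a constant series keeps later arithmetic aligned.
    return pandas.Series(default_value, index=df.index)

df = pandas.DataFrame({"op_onnx__If": [1.0, np.nan]})
print(gdf(df, "op_onnx__If", 0).tolist())  # [1.0, 0.0]
print(gdf(df, "missing", 0).tolist())      # [0, 0]
```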
```diff
@@ -1676,15 +1690,54 @@ class CubeLogsPerformance(CubeLogs):
                 "time_latency",
                 gdf(df, "time_latency_eager") > gdf(df, "time_latency", np.inf) * 3.98,
             ),
+            n_node_attention23=lambda df: gpreserve(
+                df, "time_latency_eager", gdf(df, "op_onnx__Attention")
+            ),
+            n_node_rotary_embedding23=lambda df: gpreserve(
+                df, "time_latency_eager", gdf(df, "op_onnx__RotaryEmbedding")
+            ),
+            n_node_layer_normalization23=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx__LayerNormalization", 0)
+                + gdf(df, "op_onnx__RMSNormalization", 0)
+                + gdf(df, "op_onnx__BatchNormlization", 0)
+                + gdf(df, "op_onnx__InstanceNormlization", 0)
+                + gdf(df, "op_onnx__GroupNormalization", 0),
+            ),
             n_node_attention=lambda df: gpreserve(
                 df,
-                "
-                gdf(df, "op_onnx_com.microsoft_Attention")
-                + gdf(df, "op_onnx_com.microsoft_MultiHeadAttention")
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_Attention", 0)
+                + gdf(df, "op_onnx_com.microsoft_MultiHeadAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_PackedAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_PackedMultiHeadAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_GroupQueryAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_PagedAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_DecoderAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_LongformerAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_DecoderMaskedSelfAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_DecoderMaskedMultiHeadAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_SparseAttention", 0),
+            ),
+            n_node_layer_normalization=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_EmbedLayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_SkipLayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_LayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_SkipSimplifiedLayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_SimplifiedLayerNormalization", 0),
+            ),
+            n_node_rotary_embedding=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_GemmaRotaryEmbedding", 0)
+                + gdf(df, "op_onnx_com.microsoft_RotaryEmbedding", 0),
             ),
             n_node_control_flow=lambda df: gpreserve(
                 df,
-                "
+                "time_latency_eager",
                 (
                     gdf(df, "op_onnx__If", 0)
                     + gdf(df, "op_onnx__Scan", 0)
@@ -1693,7 +1746,7 @@ class CubeLogsPerformance(CubeLogs):
             ),
             n_node_scatter=lambda df: gpreserve(
                 df,
-                "
+                "time_latency_eager",
                 gdf(df, "op_onnx__ScatterND", 0) + gdf(df, "op_onnx__ScatterElements", 0),
             ),
             n_node_function=lambda df: gpreserve(
@@ -1706,13 +1759,13 @@ class CubeLogsPerformance(CubeLogs):
                 df, "onnx_n_initializer", gdf(df, "onnx_n_initializer")
             ),
             n_node_constant=lambda df: gpreserve(
-                df, "
+                df, "time_latency_eager", gdf(df, "op_onnx__Constant")
             ),
             n_node_shape=lambda df: gpreserve(
-                df, "
+                df, "time_latency_eager", gdf(df, "op_onnx__Shape")
             ),
             n_node_expand=lambda df: gpreserve(
-                df, "
+                df, "time_latency_eager", gdf(df, "op_onnx__Expand")
             ),
         )
         assert (
```
onnx_diagnostic/helpers/rt_helper.py
CHANGED

```diff
@@ -3,7 +3,6 @@ import numpy as np
 import onnx
 import torch
 from .helper import string_type, flatten_object
-from .onnx_helper import dtype_to_tensor_dtype
 from .cache_helper import is_cache_dynamic_registered
 
 
@@ -23,6 +22,7 @@ def make_feeds(
     use_numpy: bool = False,
     copy: bool = False,
     check_flatten: bool = True,
+    is_modelbuilder: bool = False,
 ) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
     """
     Serializes the inputs to produce feeds expected
@@ -35,10 +35,15 @@ def make_feeds(
         by ``OrtValue``
     :param check_flatten: if True, checks the ``torch.utils._pytree.tree_flatten``
         returns the same number of outputs
+    :param is_modelbuilder: if True, the exporter is ModelBuilder, and we need to reorder
+        the past_key_values inputs to match the expected order, and get rid of position_ids.
     :return: feeds dictionary
     """
-    # position_ids is a special case because ModelBuilder does not usually use it
-    #
+    # NOTE: position_ids is a special case because ModelBuilder does not usually use it,
+    # because it's fused into rotary embedding in GQA.
+    if is_modelbuilder and isinstance(inputs, dict):
+        inputs.pop("position_ids", None)  # Ensure 'position_ids' absent before removing.
+
     flat = flatten_object(inputs, drop_keys=True)
     assert (
         not check_flatten
```
```diff
@@ -76,39 +81,6 @@ def make_feeds(
         f"\n-- inputs={string_type(inputs, with_shape=True)}"
         f"\n-- names={names}"
     )
-    if len(names) < len(flat) and (
-        isinstance(proto, onnx.ModelProto) or hasattr(proto, "get_inputs")
-    ):
-
-        typed_names = (
-            [(i.name, i.type.tensor_type.elem_type) for i in proto.graph.input]
-            if isinstance(proto, onnx.ModelProto)
-            else [(i.name, name_type_to_onnx_dtype(i.type)) for i in proto.get_inputs()]
-        )
-
-        new_flat = []
-        pos = 0
-        for _name, dtype in typed_names:
-            assert isinstance(
-                dtype, int
-            ), f"Unexpected value for dtype={dtype!r}, type(proto)={type(proto)}"
-            itype = dtype_to_tensor_dtype(flat[pos].dtype)
-            while dtype != itype:
-                pos += 1
-                if pos >= len(flat):
-                    break
-                itype = dtype_to_tensor_dtype(flat[pos].dtype)
-            if pos >= len(flat):
-                break
-            new_flat.append(flat[pos])
-            pos += 1
-        assert len(new_flat) == len(names), (
-            f"Unable to align expected input {names} with the given input, "
-            f"type(proto)={type(proto)}"
-            f"\n-- inputs: {string_type(inputs, with_shape=True)}"
-            f"\n-- typed_names: {typed_names}"
-        )
-        flat = new_flat
 
     if copy:
         flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
```
```diff
@@ -122,4 +94,49 @@ def make_feeds(
         elif isinstance(i, float):
             i = np.array(i, dtype=np.float32)
         new_flat.append(i)
+
+    # NOTE: model builder has a different order for past_key_values
+    # we need to reorder them to match the expected order
+    if is_modelbuilder:
+        # We assume that if "past_key_values" is in the names when it's
+        # modelbuilder
+        non_past_kv_input_names = [n for n in names if "past_key_values" not in n]
+        past_kv_names = [n for n in names if "past_key_values" in n]
+        reorder_past_kv_names = reorder_modelbuilder_cache_to_torch(past_kv_names)
+        names = non_past_kv_input_names + reorder_past_kv_names
     return dict(zip(names, new_flat))
+
+
+def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]:
+    """
+    Reorders the past_kvs for ModelBuilder to match the expected order
+    by PyTorch exported models.
+
+    .. note::
+        This function can take either the names or the actual tensors
+        as long as they are in a list.
+
+    Conceptually,
+
+    From::
+
+        [past_key_values.0.key, past_key_values.0.value,
+         past_key_values.1.key, past_key_values.1.value, ...]
+
+    To::
+
+        [past_key_values.0.key, past_key_values.1.key,
+         ..., past_key_values.0.value, past_key_values.1.value, ...]
+
+    :param past_kv: list of flattened inputs
+    :return: reordered list of flattened inputs
+    """
+    total_len = len(past_kv)
+    if total_len % 2 != 0:
+        raise ValueError("The length of past_key_values should be even.")
+    keys = []
+    values = []
+    for i in range(0, total_len, 2):
+        keys.append(past_kv[i])
+        values.append(past_kv[i + 1])
+    return keys + values
```
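A quick check of the reordering, using the name layout from the docstring:

```python
from onnx_diagnostic.helpers.rt_helper import reorder_modelbuilder_cache_to_torch

# Interleaved (key, value) per layer, as torch-exported models flatten a cache.
past_kv = [
    "past_key_values.0.key", "past_key_values.0.value",
    "past_key_values.1.key", "past_key_values.1.value",
]
print(reorder_modelbuilder_cache_to_torch(past_kv))
# ['past_key_values.0.key', 'past_key_values.1.key',
#  'past_key_values.0.value', 'past_key_values.1.value']
```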
onnx_diagnostic/tasks/__init__.py
CHANGED

```diff
@@ -5,6 +5,8 @@ from . import (
     fill_mask,
     image_classification,
     image_text_to_text,
+    image_to_video,
+    mask_generation,
     mixture_of_expert,
     object_detection,
     sentence_similarity,
@@ -14,7 +16,6 @@ from . import (
     text_to_image,
     text2text_generation,
     zero_shot_image_classification,
-    mask_generation,
 )
 
 __TASKS__ = [
@@ -23,6 +24,8 @@ __TASKS__ = [
     fill_mask,
     image_classification,
     image_text_to_text,
+    image_to_video,
+    mask_generation,
     mixture_of_expert,
     object_detection,
     sentence_similarity,
@@ -32,7 +35,6 @@ __TASKS__ = [
     text_to_image,
     text2text_generation,
     zero_shot_image_classification,
-    mask_generation,
 ]
```
onnx_diagnostic/tasks/image_to_video.py
ADDED

```diff
@@ -0,0 +1,127 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
+
+__TASK__ = "image-to-video"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    if not hasattr(config, "num_hidden_layers") and not hasattr(config, "num_layers"):
+        # We cannot reduce.
+        return {}
+    check_hasattr(config, ("num_hidden_layers", "num_layers"))
+    kwargs = {}
+    if hasattr(config, "num_layers"):
+        kwargs["num_layers"] = min(config.num_layers, nhl())
+    if hasattr(config, "num_hidden_layers"):
+        kwargs["num_hidden_layers"] = min(config.num_hidden_layers, nhl())
+
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    text_embed_dim: int,
+    latent_channels: int,
+    batch_size: int = 2,
+    image_height: int = 704,
+    image_width: int = 1280,
+    latent_frames: int = 1,
+    text_maxlen: int = 512,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``image-to-video``.
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    latent_height = image_height // 8
+    latent_width = image_width // 8
+    dtype = torch.float32
+
+    inputs = dict(
+        hidden_states=torch.randn(
+            batch_size,
+            latent_channels,
+            latent_frames,
+            latent_height,
+            latent_width,
+            dtype=dtype,
+        ),
+        timestep=torch.tensor([1.0] * batch_size, dtype=dtype),
+        encoder_hidden_states=torch.randn(
+            batch_size, text_maxlen, text_embed_dim, dtype=dtype
+        ),
+        padding_mask=torch.ones(1, 1, image_height, image_width, dtype=dtype),
+        fps=torch.tensor([16] * batch_size, dtype=dtype),
+        condition_mask=torch.randn(
+            batch_size, 1, latent_frames, latent_height, latent_width, dtype=dtype
+        ),
+    )
+    shapes = dict(
+        hidden_states={
+            0: "batch_size",
+            2: "latent_frames",
+            3: "latent_height",
+            4: "latent_width",
+        },
+        timestep={0: "batch_size"},
+        encoder_hidden_states={0: "batch_size"},
+        padding_mask={0: "batch_size", 2: "height", 3: "width"},
+        fps={0: "batch_size"},
+        condition_mask={
+            0: "batch_size",
+            2: "latent_frames",
+            3: "latent_height",
+            4: "latent_width",
+        },
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            text_embed_dim=text_embed_dim,
+            latent_channels=latent_channels,
+            batch_size=batch_size,
+            image_height=image_height,
+            image_width=image_width,
+            latent_frames=latent_frames,
+            text_maxlen=text_maxlen,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "in_channels", "text_embed_dim"),
+    kwargs = dict(
+        text_embed_dim=1024 if config is None else config.text_embed_dim,
+        latent_channels=16 if config is None else config.in_channels - 1,
+        batch_size=1,
+        image_height=8 * 50,
+        image_width=8 * 80,
+        latent_frames=1,
+        text_maxlen=512,
+    )
+    return kwargs, get_inputs
```
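A hedged sketch of how the two helpers compose; with `config=None` the typical dimensions are used, and passing `model=None` is an assumption that works here only because `get_inputs` merely forwards the model argument:

```python
from onnx_diagnostic.tasks.image_to_video import random_input_kwargs

kwargs, fn = random_input_kwargs(None)  # typical dims: text_embed_dim=1024, 16 channels
data = fn(model=None, config=None, **kwargs)
print(data["inputs"]["hidden_states"].shape)
# torch.Size([1, 16, 1, 50, 80]) -> (batch, latent_channels, frames, h/8, w/8)
print(sorted(data["dynamic_shapes"]))  # the axes exported as dynamic
```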
onnx_diagnostic/torch_export_patches/onnx_export_errors.py
CHANGED

```diff
@@ -254,6 +254,17 @@ def torch_export_patches(
     may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``.
     It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`.
     """
+    if verbose:
+        print(f"[torch_export_patches] patch_sympy={patch_sympy!r}")
+        print(f" . patch_torch={patch_torch!r}")
+        print(f" . patch_transformers={patch_transformers!r}")
+        print(f" . patch_diffusers={patch_diffusers!r}")
+        print(f" . catch_constraints={catch_constraints!r}")
+        print(f" . stop_if_static={stop_if_static!r}")
+        print(f" . patch={patch!r}")
+        print(f" . custom_patches={custom_patches!r}")
+        print(f"[torch_export_patches] dump_rewriting={dump_rewriting!r}")
+
     if rewrite:
         from .patch_module import torch_export_rewrite
```
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
CHANGED

```diff
@@ -35,6 +35,9 @@ except ImportError:
     from ...ext_test_case import has_transformers
     from ...helpers.torch_helper import is_torchdynamo_exporting
 
+patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99")
+
+
 if patch_masking_utils:
     # Introduced in 4.52
     from transformers.masking_utils import (
@@ -213,6 +216,8 @@ if patch_DynamicLayer:
         new_shape[-2] = 0
         self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
         self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+        if patch_is_initialized:
+            self.is_initialized = True
 
 
 def _patch_make_causal_mask(
```
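Comparing against `4.56.99` is a common trick to gate on "anything from 4.57 on", including 4.57 pre-releases; the second hunk then sets `is_initialized`, presumably because recent `DynamicLayer` versions track that flag. A standalone sketch of the gating idiom:

```python
import packaging.version as pv
import transformers

# True for transformers 4.57.x and later; 4.56.99 sits just below that line,
# so 4.57 dev and rc builds also compare greater.
patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99")
print(patch_is_initialized)
```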
|