onnxruntime-directml 1.17.0-cp38-cp38-win_amd64.whl → 1.17.3-cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/__init__.py +1 -1
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +1 -1
- onnxruntime/quantization/matmul_4bits_quantizer.py +8 -1
- onnxruntime/quantization/onnx_quantizer.py +29 -6
- onnxruntime/quantization/qdq_quantizer.py +2 -0
- onnxruntime/tools/symbolic_shape_infer.py +28 -0
- onnxruntime/transformers/convert_generation.py +143 -11
- onnxruntime/transformers/fusion_bart_attention.py +229 -10
- onnxruntime/transformers/models/llama/benchmark.py +20 -18
- onnxruntime/transformers/models/llama/benchmark_all.py +22 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +581 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +5 -0
- onnxruntime/transformers/models/llama/dist_settings.py +5 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +200 -4
- onnxruntime/transformers/models/llama/llama_parity.py +8 -3
- onnxruntime/transformers/models/llama/llama_torch.py +5 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +5 -0
- onnxruntime/transformers/models/whisper/benchmark.py +19 -3
- onnxruntime/transformers/models/whisper/benchmark_all.py +6 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +165 -131
- onnxruntime/transformers/models/whisper/whisper_chain.py +166 -117
- onnxruntime/transformers/models/whisper/whisper_decoder.py +17 -3
- onnxruntime/transformers/models/whisper/whisper_encoder.py +13 -4
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +22 -9
- onnxruntime/transformers/models/whisper/whisper_helper.py +209 -59
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_model.py +48 -0
- onnxruntime/transformers/onnx_model_bart.py +1 -1
- onnxruntime/transformers/quantize_helper.py +2 -1
- onnxruntime/transformers/torch_onnx_export_helper.py +2 -1
- {onnxruntime_directml-1.17.0.dist-info → onnxruntime_directml-1.17.3.dist-info}/METADATA +16 -1
- {onnxruntime_directml-1.17.0.dist-info → onnxruntime_directml-1.17.3.dist-info}/RECORD +38 -36
- {onnxruntime_directml-1.17.0.dist-info → onnxruntime_directml-1.17.3.dist-info}/WHEEL +1 -1
- {onnxruntime_directml-1.17.0.dist-info → onnxruntime_directml-1.17.3.dist-info}/entry_points.txt +0 -0
- {onnxruntime_directml-1.17.0.dist-info → onnxruntime_directml-1.17.3.dist-info}/top_level.txt +0 -0
onnxruntime/__init__.py
CHANGED

@@ -7,7 +7,7 @@ ONNX Runtime is a performance-focused scoring engine for Open Neural Network Exchange (ONNX) models.
 For more information on ONNX Runtime, please see `aka.ms/onnxruntime <https://aka.ms/onnxruntime/>`_
 or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
 """
-__version__ = "1.17.0"
+__version__ = "1.17.3"
 __author__ = "Microsoft"
 
 # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
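The only functional change in this file is the version string. A quick post-install sanity check (a minimal sketch; the provider name in the comment is an assumption for the DirectML build):

```python
import onnxruntime

print(onnxruntime.__version__)               # expected: "1.17.3"
print(onnxruntime.get_available_providers()) # the DirectML wheel is expected to include "DmlExecutionProvider"
```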
onnxruntime/capi/DirectML.dll
CHANGED
Binary file

onnxruntime/capi/onnxruntime_providers_shared.dll
CHANGED
Binary file

onnxruntime/capi/onnxruntime_pybind11_state.pyd
CHANGED
Binary file
onnxruntime/capi/onnxruntime_validation.py
CHANGED

@@ -22,7 +22,7 @@ def check_distro_info():
         __my_distro__ = __my_system__
         __my_distro_ver__ = platform.release().lower()
 
-        if __my_distro_ver__ != "10":
+        if __my_distro_ver__ not in ["10", "11"]:
             warnings.warn(
                 "Unsupported Windows version (%s). ONNX Runtime supports Windows 10 and above, only."
                 % __my_distro_ver__
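This only affects the warning emitted at import time: `platform.release()` reports "11" on Windows 11, so the previous check warned spuriously on a supported OS. A standalone sketch of the corrected logic:

```python
import platform
import warnings

release = platform.release().lower()
if release not in ["10", "11"]:
    # Windows 10 and 11 are both supported; anything else still triggers the warning.
    warnings.warn(
        "Unsupported Windows version (%s). ONNX Runtime supports Windows 10 and above, only." % release
    )
```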
onnxruntime/quantization/matmul_4bits_quantizer.py
CHANGED

@@ -349,6 +349,10 @@ class MatMul4BitsQuantizer:
             self.int4_quant_algo()
 
 
+def ort_convert_str_to_bool(value):
+    return value.lower() in ("true", "1")
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="""Blockwise int4 quantization for MatMul 2D weight matrices.
@@ -366,7 +370,10 @@ set of 4b integers with a scaling factor and an optional offset.
         "--symmetric",
         required=False,
         default=True,
-        type=bool,
+        const=True,
+        nargs="?",
+        type=ort_convert_str_to_bool,
+        choices=[True, False],
         help="Indicate whether to quantize the model symmetrically",
     )
     parser.add_argument(
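The `--symmetric` flag was previously parsed with a plain `bool` converter, and argparse's `bool("False")` is `True` because any non-empty string is truthy, so symmetric quantization could not actually be disabled from the command line. A standalone sketch of the new parsing behaviour:

```python
import argparse


def ort_convert_str_to_bool(value):
    return value.lower() in ("true", "1")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--symmetric",
    required=False,
    default=True,
    const=True,
    nargs="?",
    type=ort_convert_str_to_bool,
    choices=[True, False],
    help="Indicate whether to quantize the model symmetrically",
)

print(bool("False"))                                          # True  -- why the old conversion was broken
print(parser.parse_args(["--symmetric", "False"]).symmetric)  # False
print(parser.parse_args(["--symmetric"]).symmetric)           # True  (bare flag falls back to const)
print(parser.parse_args([]).symmetric)                        # True  (default)
```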
onnxruntime/quantization/onnx_quantizer.py
CHANGED

@@ -389,7 +389,7 @@ class ONNXQuantizer:
     def quantize_model(self):
         if self.has_QDQ_nodes():
             logging.warning(
-                "Please check if the model is already quantized."
+                "Please check if the model is already quantized. "
                 "Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly."
             )
 
@@ -446,6 +446,23 @@ class ONNXQuantizer:
             return False
         return self.parent.is_valid_quantize_weight(weight_name)
 
+    def _get_default_tensor_type(self, tensor_name):
+        if "DefaultTensorType" in self.extra_options:
+            logging.info(
+                "get_tensor_type returns DefaultTensorType for tensor name %r, use %d",
+                tensor_name,
+                self.extra_options["DefaultTensorType"],
+            )
+            return self.extra_options["DefaultTensorType"]
+        raise RuntimeError(
+            f"Unable to find data type for weight_name={tensor_name!r}. "
+            f"shape_inference failed to return a type probably this node is "
+            f"from a different domain or using an input produced by such an operator. "
+            f"This may happen if you quantize a model already quantized. "
+            f"You may use extra_options `DefaultTensorType` to indicate "
+            f"the default weight type, usually `onnx.TensorProto.FLOAT`."
+        )
+
     def get_tensor_type(self, tensor_name, mandatory=False):
         weight = find_by_name(tensor_name, self.model.initializer())
         if weight is not None:
@@ -454,11 +471,11 @@ class ONNXQuantizer:
             vi = self.value_infos[tensor_name]
             if vi.type.HasField("tensor_type"):
                 if mandatory and vi.type.tensor_type.elem_type == 0:
-                    raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
+                    return self._get_default_tensor_type(tensor_name)
                 return vi.type.tensor_type.elem_type
         if (not self.enable_subgraph_quantization) or (self.parent is None):
             if mandatory:
-                raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
+                return self._get_default_tensor_type(tensor_name)
             return None
         otype = self.parent.is_valid_quantize_weight(tensor_name)
         if otype is not None:
@@ -468,7 +485,7 @@ class ONNXQuantizer:
         if res is not None:
             return res
         if mandatory:
-            raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
+            return self._get_default_tensor_type(tensor_name)
         return None
 
     def is_float_tensor(self, tensor_name):
@@ -1336,9 +1353,15 @@ class ONNXQuantizer:
         if (value_name in self.quantized_value_map) and (value_name not in self.generated_value_names):
             quantized_value = self.quantized_value_map[value_name]
             # Add DequantizeLinear Node for this input
+
             scale_init = find_by_name(quantized_value.scale_name, self.model.initializer())
-            # axis is not specified so scale_init must be a scalar.
-            assert onnx.numpy_helper.to_array(scale_init).size == 1
+
+            # In case we are working with subgraphs, the graph `producer_name` is set to `"onnx-quantizer"` in the `quantize_subgraph` method. In this case, the scale initializer may be on the top level graph, so the check below can not be done.
+            if self.model.model.producer_name != "onnx-quantizer" or (
+                self.model.model.producer_name == "onnx-quantizer" and scale_init is not None
+            ):
+                # axis is not specified so scale_init must be a scalar.
+                assert onnx.numpy_helper.to_array(scale_init).size == 1
 
             dqlinear_name = value_name + "_DequantizeLinear"
             dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph())
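The new `DefaultTensorType` entry in `extra_options` gives the quantizer a fallback element type when shape inference cannot determine a tensor's type (for example, nodes from another operator domain, or re-quantizing an already quantized model), instead of raising immediately. A usage sketch with placeholder paths, going through the public `quantize_dynamic` entry point, which forwards `extra_options` to `ONNXQuantizer`:

```python
import onnx
from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input="model_fp32.onnx",   # placeholder path
    model_output="model_int8.onnx",  # placeholder path
    weight_type=QuantType.QInt8,
    # Fall back to float32 for tensors whose type shape inference could not resolve.
    extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
)
```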
onnxruntime/quantization/qdq_quantizer.py
CHANGED

@@ -270,6 +270,8 @@ class QDQQuantizer(ONNXQuantizer):
 
         self.model.model.producer_name = __producer__
         self.model.model.producer_version = __version__
+        if self.qdq_op_domain == ms_domain:
+            self.model.set_opset_import(ms_domain, 1)
 
         return self.model.model
 
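When the QDQ quantizer emits its QuantizeLinear/DequantizeLinear nodes in the `com.microsoft` domain, the produced model now also declares that domain in its opset imports, which downstream tooling expects. A small inspection sketch (the model path is a placeholder, and `UseQDQContribOps` as the option that selects the contrib domain is an assumption):

```python
import onnx

# Assumes a QDQ-quantized model produced with extra_options={"UseQDQContribOps": True}.
model = onnx.load("model_int8_qdq.onnx")  # placeholder path
print({imp.domain or "ai.onnx": imp.version for imp in model.opset_import})
# With the change above, an entry for "com.microsoft" is expected in this mapping.
```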
onnxruntime/tools/symbolic_shape_infer.py
CHANGED

@@ -197,6 +197,7 @@ class SymbolicShapeInference:
             "BiasGelu": self._infer_BiasGelu,
             "BiasSplitGelu": self._infer_BiasSplitGelu,
             "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention,
+            "DequantizeLinear": self._infer_DequantizeLinear,
             "EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
             "FastGelu": self._infer_FastGelu,
             "GatedRelativePositionBias": self._infer_GatedRelativePositionBias,
@@ -212,6 +213,7 @@ class SymbolicShapeInference:
             "PackedAttention": self._infer_PackedAttention,
             "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
             "PythonOp": self._infer_PythonOp,
+            "QuantizeLinear": self._infer_QuantizeLinear,
             "QuickGelu": self._infer_FastGelu,
             "RelativePositionBias": self._infer_RelativePositionBias,
             "RemovePadding": self._infer_RemovePadding,
@@ -238,6 +240,7 @@ class SymbolicShapeInference:
             "upsample_nearest1d": self._infer_aten_upsample,
             "upsample_nearest2d": self._infer_aten_upsample,
             "upsample_nearest3d": self._infer_aten_upsample,
+            "upsample_bicubic2d": self._infer_aten_upsample,
         }
         self.run_ = True
         self.suggested_merge_ = {}
@@ -457,6 +460,8 @@ class SymbolicShapeInference:
             "GemmFastGelu",
             "LayerNormalization",
             "LongformerAttention",
+            "DequantizeLinear",
+            "QuantizeLinear",
             "RelativePositionBias",
             "RemovePadding",
             "RestorePadding",
@@ -979,6 +984,29 @@ class SymbolicShapeInference:
                 )
             )
 
+    def _infer_DequantizeLinear(self, node):  # noqa: N802
+        # Get the output data type from the scale input (index 1, required).
+        output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type
+
+        # Get the output shape from the first input.
+        output_shape = self._get_shape(node, 0)
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def _infer_QuantizeLinear(self, node):  # noqa: N802
+        # Get the output data type from the zero-point input (index 2, optional).
+        # Otherwise, default to uint8
+        output_dtype = onnx.TensorProto.UINT8
+        if len(node.input) > 2 and node.input[2]:
+            output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type
+
+        # Get the output shape from the first input.
+        output_shape = self._get_shape(node, 0)
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
     def _infer_Einsum(self, node):  # noqa: N802
         # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
         equation = get_attribute(node, "equation")
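With these handlers registered, symbolic shape inference can propagate shapes and element types through QuantizeLinear/DequantizeLinear (output type taken from the zero-point or scale input, output shape from the first input) rather than stopping at them. A minimal usage sketch with a placeholder model path:

```python
import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("model_int8_qdq.onnx")  # placeholder path
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
onnx.save(inferred, "model_int8_qdq_shaped.onnx")
```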
onnxruntime/transformers/convert_generation.py
CHANGED

@@ -1273,7 +1273,7 @@ def find_past_seq_len_usage(subg: GraphProto):
 
 
 def replace_mha_with_gqa(
-    model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int =
+    model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = -1
 ):
     # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes
     #
@@ -1339,31 +1339,163 @@ def replace_mha_with_gqa(
     )
 
     # Replace MultiHeadAttention with GroupQueryAttention
+    #
+    # When replacing, fuse the following subgraph:
+    #
+    #               root_input
+    #              /    |     \
+    #        MatMul  MatMul  MatMul
+    #           |       |       |
+    #          Add     Add     Add      (optional Adds)
+    #           |       |       |
+    #        RotEmb  RotEmb     |
+    #            \      |      /
+    #          MultiHeadAttention
+    #
+    # to this new subgraph:
+    #
+    #               root_input
+    #                   |
+    #             PackedMatMul (if possible)
+    #                   |
+    #              PackedAdd (if possible)
+    #                   |
+    #          GroupQueryAttention
+    #
+
     mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
-    for node in mha_nodes:
-
+    for idx, node in enumerate(mha_nodes):
+        # Detect Q path to MHA
+        q_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [0, 0, 0])
+        q_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [0, 0])
+
+        q_rotary, q_add, q_matmul = None, None, None
+        if q_path_1 is not None:
+            q_rotary, q_add, q_matmul = q_path_1
+        elif q_path_2 is not None:
+            q_rotary, q_matmul = q_path_2
+
+        # Detect K path to MHA
+        k_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [1, 0, 0])
+        k_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [1, 0])
+
+        k_rotary, k_add, k_matmul = None, None, None
+        if k_path_1 is not None:
+            k_rotary, k_add, k_matmul = k_path_1
+        elif k_path_2 is not None:
+            k_rotary, k_matmul = k_path_2
+
+        # Detect V path to MHA
+        v_path_1 = model.match_parent_path(node, ["Add", "MatMul"], [2, 0])
+        v_path_2 = model.match_parent_path(node, ["MatMul"], [2])
+
+        v_add, v_matmul = None, None
+        if v_path_1 is not None:
+            v_add, v_matmul = v_path_1
+        elif v_path_2 is not None:
+            v_matmul = v_path_2[0]
+
+        # Get `interleaved` attribute from RotaryEmbedding
+        interleaved = 0
+        if q_rotary is not None and k_rotary is not None:
+            for att in q_rotary.attribute:
+                if att.name == "interleaved":
+                    interleaved = att.i
+
+        # Get `num_heads` attribute from MHA
+        num_heads = 0
         for att in node.attribute:
             if att.name == "num_heads":
-
+                num_heads = att.i
+
+        # Check if root_input to Q/K/V paths is the same
+        root_input_is_same = q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]
+
+        # Check if Q/K/V paths all have bias or all don't have bias
+        all_paths_have_bias = q_add is not None and k_add is not None and v_add is not None
+        all_paths_have_no_bias = q_add is None and k_add is None and v_add is None
+
+        # Make PackedMatMul node if possible
+        q_input_to_attention, k_input_to_attention, v_input_to_attention = "", "", ""
+        if root_input_is_same and (all_paths_have_bias or all_paths_have_no_bias):
+            qw = NumpyHelper.to_array(model.get_initializer(q_matmul.input[1]))
+            kw = NumpyHelper.to_array(model.get_initializer(k_matmul.input[1]))
+            vw = NumpyHelper.to_array(model.get_initializer(v_matmul.input[1]))
+
+            dim = qw.shape[-1]
+            qkv_weight = np.stack((qw, kw, vw), axis=1).reshape(dim, 3 * dim)
+            qkv_weight = onnx.numpy_helper.from_array(qkv_weight, name=f"QKV_Weight_{idx}")
+            model.add_initializer(qkv_weight)
+
+            packed_matmul_node = onnx.helper.make_node(
+                "MatMul",
+                inputs=[q_matmul.input[0], qkv_weight.name],
+                outputs=[f"{qkv_weight.name}_output"],
+                name=model.create_node_name("MatMul"),
+            )
+            model.model.graph.node.extend([packed_matmul_node])
+            model.model.graph.node.remove(q_matmul)
+            model.model.graph.node.remove(k_matmul)
+            model.model.graph.node.remove(v_matmul)
+            q_input_to_attention = packed_matmul_node.output[0]
+
+            # Make PackedAdd node if possible
+            if all_paths_have_bias:
+                qb = NumpyHelper.to_array(model.get_initializer(q_add.input[1]))
+                kb = NumpyHelper.to_array(model.get_initializer(k_add.input[1]))
+                vb = NumpyHelper.to_array(model.get_initializer(v_add.input[1]))
+
+                dim = qb.shape[-1]
+                qkv_bias = np.stack((qb, kb, vb), axis=0).reshape(3 * dim)
+                qkv_bias = onnx.numpy_helper.from_array(qkv_bias, name=f"QKV_Bias_{idx}")
+                model.add_initializer(qkv_bias)
+                packed_add_node = onnx.helper.make_node(
+                    "Add",
+                    inputs=[packed_matmul_node.output[0], qkv_bias.name],
+                    outputs=[f"{qkv_bias.name}_output"],
+                )
+                model.model.graph.node.extend([packed_add_node])
+                model.model.graph.node.remove(q_add)
+                model.model.graph.node.remove(k_add)
+                model.model.graph.node.remove(v_add)
+                q_input_to_attention = packed_add_node.output[0]
+
+        else:
+            q_input_to_attention = q_matmul.output[0]
+            k_input_to_attention = k_matmul.output[0]
+            v_input_to_attention = v_matmul.output[0]
+
+        # Make GQA node
         gqa_node = onnx.helper.make_node(
             "GroupQueryAttention",
             inputs=[
-                node.input[0],  # query
-                node.input[1],  # key
-                node.input[2],  # value
+                q_input_to_attention,  # query
+                k_input_to_attention,  # key
+                v_input_to_attention,  # value
                 node.input[6],  # past_key
                 node.input[7],  # past_value
-
-
+                seqlen_k_cast_node.output[0],  # seqlens_k (for attention mask)
+                total_seqlen_cast_node.output[0],  # total_seq_len (for attention mask)
+                q_rotary.input[2] if q_rotary is not None else "",  # cos_cache (for rotary embeddings)
+                q_rotary.input[3] if q_rotary is not None else "",  # sin_cache (for rotary embeddings)
             ],
             outputs=node.output,
             name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
             domain="com.microsoft",
-            num_heads=
-            kv_num_heads=
+            num_heads=num_heads // world_size,
+            kv_num_heads=num_heads // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
+            local_window_size=window_size,
+            do_rotary=int(q_rotary is not None and k_rotary is not None),
+            rotary_interleaved=interleaved,
         )
         model.model.graph.node.remove(node)
         model.model.graph.node.extend([gqa_node])
+
+        if q_rotary is not None:
+            model.model.graph.node.remove(q_rotary)
+        if k_rotary is not None:
+            model.model.graph.node.remove(k_rotary)
+
     return model
 
 
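Taken together, this rewrite replaces each MultiHeadAttention with a GroupQueryAttention node, packing the Q/K/V MatMul (and optional Add) projections into single fused initializers when they share the same root input, and folding RotaryEmbedding into GQA via `do_rotary`/`rotary_interleaved`. A usage sketch with placeholder file names and mask input name:

```python
import onnx
from onnxruntime.transformers.convert_generation import replace_mha_with_gqa
from onnxruntime.transformers.onnx_model import OnnxModel

# Placeholder paths; assumes a decoder exported with MultiHeadAttention nodes,
# past_key/past_value inputs, and an "attention_mask" graph input.
model = OnnxModel(onnx.load("decoder_model.onnx"))
model = replace_mha_with_gqa(model, attn_mask="attention_mask", kv_num_heads=0, world_size=1, window_size=-1)
model.save_model_to_file("decoder_model_gqa.onnx", use_external_data_format=True)
```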