onnxruntime-directml 1.17.1-cp39-cp39-win_amd64.whl → 1.17.3-cp39-cp39-win_amd64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (26)
  1. onnxruntime/__init__.py +1 -1
  2. onnxruntime/capi/DirectML.dll +0 -0
  3. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  4. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  5. onnxruntime/capi/onnxruntime_validation.py +1 -1
  6. onnxruntime/transformers/convert_generation.py +143 -11
  7. onnxruntime/transformers/models/llama/benchmark.py +20 -18
  8. onnxruntime/transformers/models/llama/benchmark_all.py +22 -0
  9. onnxruntime/transformers/models/llama/benchmark_e2e.py +581 -0
  10. onnxruntime/transformers/models/llama/convert_to_onnx.py +5 -0
  11. onnxruntime/transformers/models/llama/dist_settings.py +5 -0
  12. onnxruntime/transformers/models/llama/llama_inputs.py +200 -4
  13. onnxruntime/transformers/models/llama/llama_parity.py +8 -3
  14. onnxruntime/transformers/models/llama/llama_torch.py +5 -0
  15. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +5 -0
  16. onnxruntime/transformers/models/whisper/convert_to_onnx.py +7 -1
  17. onnxruntime/transformers/models/whisper/whisper_chain.py +2 -2
  18. onnxruntime/transformers/models/whisper/whisper_decoder.py +3 -1
  19. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +5 -5
  20. onnxruntime/transformers/models/whisper/whisper_helper.py +140 -52
  21. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +9 -1
  22. {onnxruntime_directml-1.17.1.dist-info → onnxruntime_directml-1.17.3.dist-info}/METADATA +11 -1
  23. {onnxruntime_directml-1.17.1.dist-info → onnxruntime_directml-1.17.3.dist-info}/RECORD +26 -25
  24. {onnxruntime_directml-1.17.1.dist-info → onnxruntime_directml-1.17.3.dist-info}/WHEEL +1 -1
  25. {onnxruntime_directml-1.17.1.dist-info → onnxruntime_directml-1.17.3.dist-info}/entry_points.txt +0 -0
  26. {onnxruntime_directml-1.17.1.dist-info → onnxruntime_directml-1.17.3.dist-info}/top_level.txt +0 -0
onnxruntime/__init__.py CHANGED
@@ -7,7 +7,7 @@ ONNX Runtime is a performance-focused scoring engine for Open Neural Network Exc
 For more information on ONNX Runtime, please see `aka.ms/onnxruntime <https://aka.ms/onnxruntime/>`_
 or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
 """
-__version__ = "1.17.1"
+__version__ = "1.17.3"
 __author__ = "Microsoft"
 
 # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
Binary files changed: onnxruntime/capi/DirectML.dll, onnxruntime/capi/onnxruntime_providers_shared.dll, onnxruntime/capi/onnxruntime_pybind11_state.pyd
onnxruntime/capi/onnxruntime_validation.py CHANGED
@@ -22,7 +22,7 @@ def check_distro_info():
         __my_distro__ = __my_system__
         __my_distro_ver__ = platform.release().lower()
 
-        if __my_distro_ver__ != "10":
+        if __my_distro_ver__ not in ["10", "11"]:
             warnings.warn(
                 "Unsupported Windows version (%s). ONNX Runtime supports Windows 10 and above, only."
                 % __my_distro_ver__
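
Note: the check above is relaxed so that Windows 11 no longer triggers the unsupported-version warning. A minimal standalone sketch of the new behavior, assuming a Python build whose platform.release() reports "11" on Windows 11:

import platform
import warnings

# Sketch only: mirrors the 1.17.3 check in check_distro_info().
# On Windows 10, platform.release() returns "10"; recent Python builds report "11" on Windows 11,
# so the membership test accepts both and no warning is emitted.
release = platform.release().lower()
if release not in ["10", "11"]:
    warnings.warn(
        "Unsupported Windows version (%s). ONNX Runtime supports Windows 10 and above, only." % release
    )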
onnxruntime/transformers/convert_generation.py CHANGED
@@ -1273,7 +1273,7 @@ def find_past_seq_len_usage(subg: GraphProto):
 
 
 def replace_mha_with_gqa(
-    model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = 0
+    model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = -1
 ):
     # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes
     #
@@ -1339,31 +1339,163 @@ def replace_mha_with_gqa(
     )
 
     # Replace MultiHeadAttention with GroupQueryAttention
+    #
+    # When replacing, fuse the following subgraph:
+    #
+    #               root_input
+    #              /    |     \
+    #         MatMul  MatMul  MatMul
+    #            |       |       |
+    #           Add     Add     Add     (optional Adds)
+    #            |       |       |
+    #         RotEmb  RotEmb     |
+    #              \     |      /
+    #          MultiHeadAttention
+    #
+    # to this new subgraph:
+    #
+    #               root_input
+    #                   |
+    #             PackedMatMul (if possible)
+    #                   |
+    #              PackedAdd (if possible)
+    #                   |
+    #         GroupQueryAttention
+    #
+
     mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
-    for node in mha_nodes:
-        num_heads_mha = 0
+    for idx, node in enumerate(mha_nodes):
+        # Detect Q path to MHA
+        q_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [0, 0, 0])
+        q_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [0, 0])
+
+        q_rotary, q_add, q_matmul = None, None, None
+        if q_path_1 is not None:
+            q_rotary, q_add, q_matmul = q_path_1
+        elif q_path_2 is not None:
+            q_rotary, q_matmul = q_path_2
+
+        # Detect K path to MHA
+        k_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [1, 0, 0])
+        k_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [1, 0])
+
+        k_rotary, k_add, k_matmul = None, None, None
+        if k_path_1 is not None:
+            k_rotary, k_add, k_matmul = k_path_1
+        elif k_path_2 is not None:
+            k_rotary, k_matmul = k_path_2
+
+        # Detect V path to MHA
+        v_path_1 = model.match_parent_path(node, ["Add", "MatMul"], [2, 0])
+        v_path_2 = model.match_parent_path(node, ["MatMul"], [2])
+
+        v_add, v_matmul = None, None
+        if v_path_1 is not None:
+            v_add, v_matmul = v_path_1
+        elif v_path_2 is not None:
+            v_matmul = v_path_2[0]
+
+        # Get `interleaved` attribute from RotaryEmbedding
+        interleaved = 0
+        if q_rotary is not None and k_rotary is not None:
+            for att in q_rotary.attribute:
+                if att.name == "interleaved":
+                    interleaved = att.i
+
+        # Get `num_heads` attribute from MHA
+        num_heads = 0
         for att in node.attribute:
             if att.name == "num_heads":
-                num_heads_mha = att.i
+                num_heads = att.i
+
+        # Check if root_input to Q/K/V paths is the same
+        root_input_is_same = q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]
+
+        # Check if Q/K/V paths all have bias or all don't have bias
+        all_paths_have_bias = q_add is not None and k_add is not None and v_add is not None
+        all_paths_have_no_bias = q_add is None and k_add is None and v_add is None
+
+        # Make PackedMatMul node if possible
+        q_input_to_attention, k_input_to_attention, v_input_to_attention = "", "", ""
+        if root_input_is_same and (all_paths_have_bias or all_paths_have_no_bias):
+            qw = NumpyHelper.to_array(model.get_initializer(q_matmul.input[1]))
+            kw = NumpyHelper.to_array(model.get_initializer(k_matmul.input[1]))
+            vw = NumpyHelper.to_array(model.get_initializer(v_matmul.input[1]))
+
+            dim = qw.shape[-1]
+            qkv_weight = np.stack((qw, kw, vw), axis=1).reshape(dim, 3 * dim)
+            qkv_weight = onnx.numpy_helper.from_array(qkv_weight, name=f"QKV_Weight_{idx}")
+            model.add_initializer(qkv_weight)
+
+            packed_matmul_node = onnx.helper.make_node(
+                "MatMul",
+                inputs=[q_matmul.input[0], qkv_weight.name],
+                outputs=[f"{qkv_weight.name}_output"],
+                name=model.create_node_name("MatMul"),
+            )
+            model.model.graph.node.extend([packed_matmul_node])
+            model.model.graph.node.remove(q_matmul)
+            model.model.graph.node.remove(k_matmul)
+            model.model.graph.node.remove(v_matmul)
+            q_input_to_attention = packed_matmul_node.output[0]
+
+            # Make PackedAdd node if possible
+            if all_paths_have_bias:
+                qb = NumpyHelper.to_array(model.get_initializer(q_add.input[1]))
+                kb = NumpyHelper.to_array(model.get_initializer(k_add.input[1]))
+                vb = NumpyHelper.to_array(model.get_initializer(v_add.input[1]))
+
+                dim = qb.shape[-1]
+                qkv_bias = np.stack((qb, kb, vb), axis=0).reshape(3 * dim)
+                qkv_bias = onnx.numpy_helper.from_array(qkv_bias, name=f"QKV_Bias_{idx}")
+                model.add_initializer(qkv_bias)
+                packed_add_node = onnx.helper.make_node(
+                    "Add",
+                    inputs=[packed_matmul_node.output[0], qkv_bias.name],
+                    outputs=[f"{qkv_bias.name}_output"],
+                )
+                model.model.graph.node.extend([packed_add_node])
+                model.model.graph.node.remove(q_add)
+                model.model.graph.node.remove(k_add)
+                model.model.graph.node.remove(v_add)
+                q_input_to_attention = packed_add_node.output[0]
+
+        else:
+            q_input_to_attention = q_matmul.output[0]
+            k_input_to_attention = k_matmul.output[0]
+            v_input_to_attention = v_matmul.output[0]
+
+        # Make GQA node
         gqa_node = onnx.helper.make_node(
             "GroupQueryAttention",
             inputs=[
-                node.input[0],  # query
-                node.input[1],  # key
-                node.input[2],  # value
+                q_input_to_attention,  # query
+                k_input_to_attention,  # key
+                v_input_to_attention,  # value
                 node.input[6],  # past_key
                 node.input[7],  # past_value
-                "seqlens_k",  # seqlens_k (for attention_mask)
-                "total_seq_len",  # total_seq_len (for attention_mask)
+                seqlen_k_cast_node.output[0],  # seqlens_k (for attention mask)
+                total_seqlen_cast_node.output[0],  # total_seq_len (for attention mask)
+                q_rotary.input[2] if q_rotary is not None else "",  # cos_cache (for rotary embeddings)
+                q_rotary.input[3] if q_rotary is not None else "",  # sin_cache (for rotary embeddings)
             ],
             outputs=node.output,
             name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
             domain="com.microsoft",
-            num_heads=num_heads_mha // world_size,
-            kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
+            num_heads=num_heads // world_size,
+            kv_num_heads=num_heads // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
+            local_window_size=window_size,
+            do_rotary=int(q_rotary is not None and k_rotary is not None),
+            rotary_interleaved=interleaved,
         )
         model.model.graph.node.remove(node)
         model.model.graph.node.extend([gqa_node])
+
+        if q_rotary is not None:
+            model.model.graph.node.remove(q_rotary)
+        if k_rotary is not None:
+            model.model.graph.node.remove(k_rotary)
+
     return model
 
 
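Note: the PackedMatMul step in the hunk above folds the separate Q/K/V weight initializers into one matrix with np.stack(..., axis=1).reshape(dim, 3 * dim), so a single MatMul against root_input yields the three projections side by side. A self-contained sketch of just that packing (the tiny shapes and array values are illustrative, not taken from the diff):

import numpy as np

# Illustrative shapes only: pretend the hidden size is 4.
dim = 4
qw = np.arange(dim * dim, dtype=np.float32).reshape(dim, dim)        # stand-in for the Q MatMul weight
kw = np.arange(dim * dim, dtype=np.float32).reshape(dim, dim) + 100  # stand-in for the K MatMul weight
vw = np.arange(dim * dim, dtype=np.float32).reshape(dim, dim) + 200  # stand-in for the V MatMul weight

# Same packing as replace_mha_with_gqa above.
qkv_weight = np.stack((qw, kw, vw), axis=1).reshape(dim, 3 * dim)

# One multiply against the shared root_input now equals the three separate projections concatenated.
x = np.ones((2, dim), dtype=np.float32)  # stand-in for root_input activations
packed = x @ qkv_weight
assert np.allclose(packed, np.concatenate([x @ qw, x @ kw, x @ vw], axis=-1))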
onnxruntime/transformers/models/llama/benchmark.py CHANGED
@@ -1,3 +1,8 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
 import argparse
 import datetime
 import gc
@@ -14,11 +19,12 @@ import torch
 from benchmark_helper import measure_memory, setup_logger
 from dist_settings import get_rank, get_size
 from llama_inputs import (
-    add_io_bindings,
+    add_io_bindings_as_ortvalues,
     get_merged_sample_with_past_kv_inputs,
     get_msft_sample_inputs,
     get_sample_inputs,
     get_sample_with_past_kv_inputs,
+    verify_ort_inputs,
 )
 from optimum.onnxruntime import ORTModelForCausalLM
 from torch.profiler import ProfilerActivity, profile, record_function
@@ -203,6 +209,7 @@ def get_model(args: argparse.Namespace):
             torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
             use_auth_token=args.auth,
             use_cache=True,
+            cache_dir=args.cache_dir,
         ).to(args.target_device)
         end_time = time.time()
 
@@ -444,24 +451,12 @@ def run_hf_inference(args, init_inputs, iter_inputs, model):
 
 def run_ort_inference(args, init_inputs, iter_inputs, model):
     def prepare_ort_inputs(inputs, kv_cache_ortvalues):
-        # Check that all model inputs will be provided
-        model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
-        user_inputs = set(inputs.keys())
-        missing_inputs = model_inputs - user_inputs
-        if len(missing_inputs):
-            logger.error(f"The following model inputs are missing: {missing_inputs}")
-            raise Exception("There are missing inputs to the model. Please add them and try again.")
-
-        # Remove unnecessary inputs from model inputs
-        unnecessary_inputs = user_inputs - model_inputs
-        if len(unnecessary_inputs):
-            for unnecessary_input in unnecessary_inputs:
-                logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
-                del inputs[unnecessary_input]
+        # Verify model inputs
+        inputs = verify_ort_inputs(model, inputs)
 
         # Add IO bindings for non-CPU execution providers
         if args.device != "cpu":
-            io_binding, kv_cache_ortvalues = add_io_bindings(
+            io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
                 model, inputs, args.device, int(args.rank), args.use_gqa, kv_cache_ortvalues
            )
            setattr(args, "io_binding", io_binding)  # noqa: B010
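
Note: the inline validation removed above is replaced by a call to verify_ort_inputs, which this version imports from llama_inputs (see the import hunk earlier in this file). The helper's body is not shown in this diff; a minimal sketch that reproduces the removed behavior (parameter names are assumptions):

import logging

logger = logging.getLogger(__name__)


def verify_ort_inputs(model, ort_inputs):
    # Ensure every input the InferenceSession expects has been provided.
    model_inputs = {model_input.name for model_input in model.get_inputs()}
    user_inputs = set(ort_inputs.keys())
    missing_inputs = model_inputs - user_inputs
    if missing_inputs:
        logger.error(f"The following model inputs are missing: {missing_inputs}")
        raise Exception("There are missing inputs to the model. Please add them and try again.")

    # Drop provided inputs the model does not declare, as the removed inline code did.
    for unnecessary_input in user_inputs - model_inputs:
        logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
        del ort_inputs[unnecessary_input]

    return ort_inputs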
@@ -612,6 +607,13 @@ def get_args(rank=0):
     parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
     parser.add_argument("--verbose", default=False, action="store_true")
     parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
+    parser.add_argument(
+        "--cache-dir",
+        type=str,
+        required=True,
+        default="./model_cache",
+        help="Cache dir where Hugging Face files are stored",
+    )
 
     args = parser.parse_args()
 
@@ -662,8 +664,8 @@ def main():
 
     args.rank = rank
     args.world_size = world_size
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-    config = AutoConfig.from_pretrained(args.model_name)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir)
+    config = AutoConfig.from_pretrained(args.model_name, cache_dir=args.cache_dir)
     target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device
     use_fp16 = args.precision == "fp16"
 
onnxruntime/transformers/models/llama/benchmark_all.py CHANGED
@@ -1,3 +1,8 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
 import argparse
 import datetime
 import json
@@ -78,6 +83,13 @@ def get_args():
         help="Path to ONNX model from convert_to_onnx",
     )
 
+    parser.add_argument(
+        "--cache-dir",
+        type=str,
+        default="./model_cache",
+        help="Cache dir where Hugging Face files are stored",
+    )
+
     parser.add_argument(
         "--model-name",
         type=str,
@@ -332,6 +344,8 @@ def main():
             str(args.num_runs),
             "--log-folder",
             args.log_folder,
+            "--cache-dir",
+            args.cache_dir,
             "--auth",
         ]
         logger.info("Benchmark PyTorch without torch.compile")
@@ -362,6 +376,8 @@ def main():
             str(args.num_runs),
             "--log-folder",
             args.log_folder,
+            "--cache-dir",
+            args.cache_dir,
             "--auth",
         ]
         logger.info("Benchmark PyTorch with torch.compile")
@@ -394,6 +410,8 @@ def main():
             str(args.num_runs),
             "--log-folder",
             args.log_folder,
+            "--cache-dir",
+            args.cache_dir,
             "--auth",
         ]
         logger.info("Benchmark Optimum + ONNX Runtime")
@@ -426,6 +444,8 @@ def main():
             str(args.num_runs),
             "--log-folder",
             args.log_folder,
+            "--cache-dir",
+            args.cache_dir,
         ]
         logger.info("Benchmark Microsoft model in ONNX Runtime")
         results = benchmark(args, benchmark_cmd, "ort-msft")
@@ -457,6 +477,8 @@ def main():
             str(args.num_runs),
             "--log-folder",
             args.log_folder,
+            "--cache-dir",
+            args.cache_dir,
         ]
         logger.info("Benchmark convert_to_onnx model in ONNX Runtime")
         results = benchmark(args, benchmark_cmd, "onnxruntime")