onnx-diagnostic 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (29)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +66 -8
  3. onnx_diagnostic/ext_test_case.py +2 -0
  4. onnx_diagnostic/helpers/_log_helper.py +461 -0
  5. onnx_diagnostic/helpers/cache_helper.py +250 -15
  6. onnx_diagnostic/helpers/helper.py +146 -10
  7. onnx_diagnostic/helpers/log_helper.py +404 -315
  8. onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
  9. onnx_diagnostic/helpers/onnx_helper.py +13 -7
  10. onnx_diagnostic/helpers/torch_helper.py +33 -11
  11. onnx_diagnostic/tasks/__init__.py +2 -0
  12. onnx_diagnostic/tasks/feature_extraction.py +86 -5
  13. onnx_diagnostic/tasks/image_text_to_text.py +260 -56
  14. onnx_diagnostic/tasks/mask_generation.py +139 -0
  15. onnx_diagnostic/tasks/text2text_generation.py +2 -2
  16. onnx_diagnostic/tasks/text_generation.py +6 -2
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +7 -1
  18. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
  19. onnx_diagnostic/torch_export_patches/patch_inputs.py +4 -1
  20. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +397 -128
  21. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +57 -40
  22. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
  23. onnx_diagnostic/torch_models/hghub/model_inputs.py +5 -0
  24. onnx_diagnostic/torch_models/validate.py +26 -3
  25. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/METADATA +1 -1
  26. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/RECORD +29 -27
  27. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/WHEEL +0 -0
  28. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/licenses/LICENSE.txt +0 -0
  29. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
@@ -1,15 +1,25 @@
  from typing import Any, List, Set, Tuple
  import torch
- import transformers
  from transformers.cache_utils import (
      DynamicCache,
-     MambaCache,
      EncoderDecoderCache,
+     HybridCache,
      SlidingWindowCache,
      StaticCache,
  )
+
+ try:
+     from transformers.models.mamba.modeling_mamba import MambaCache
+ except ImportError:
+     from transformers.cache_utils import MambaCache
  from transformers.modeling_outputs import BaseModelOutput
- from ...helpers.cache_helper import make_static_cache
+ from ...helpers.cache_helper import (
+     make_dynamic_cache,
+     make_hybrid_cache,
+     make_sliding_window_cache,
+     make_static_cache,
+     CacheKeyValue,
+ )
  from . import make_serialization_function_for_dataclass
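
Note: CacheKeyValue, imported above, is defined in onnx_diagnostic/helpers/cache_helper.py (also changed in this release, +250 -15) and its definition is not part of this diff. Judging by its use below, it normalizes access to the per-layer key/value tensors across transformers versions. A hypothetical sketch of that role, assuming newer transformers expose cache.layers[i].keys/.values while older versions expose flat key_cache/value_cache lists:

    # Illustrative sketch only, not the actual helper from cache_helper.py.
    class CacheKeyValue:
        def __init__(self, cache):
            if hasattr(cache, "layers"):  # newer layered cache layout
                self.key_cache = [layer.keys for layer in cache.layers]
                self.value_cache = [layer.values for layer in cache.layers]
            else:  # legacy layout with flat lists of per-layer tensors
                self.key_cache = cache.key_cache
                self.value_cache = cache.value_cache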
 
@@ -29,6 +39,12 @@ def flatten_mamba_cache(
      mamba_cache: MambaCache,
  ) -> Tuple[List[Any], torch.utils._pytree.Context]:
      """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
+     assert isinstance(mamba_cache.conv_states, list) and isinstance(
+         mamba_cache.ssm_states, list
+     ), (
+         f"Unexpected types for conv_states and ssm_states {type(mamba_cache.conv_states)}, "
+         f"{type(mamba_cache.ssm_states)}"
+     )
      flat = [
          ("conv_states", mamba_cache.conv_states),
          ("ssm_states", mamba_cache.ssm_states),
@@ -85,9 +101,8 @@ def flatten_dynamic_cache(
      dynamic_cache: DynamicCache,
  ) -> Tuple[List[Any], torch.utils._pytree.Context]:
      """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-     if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
-         return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache)
-     flat = [("key_cache", dynamic_cache.key_cache), ("value_cache", dynamic_cache.value_cache)]
+     ca = CacheKeyValue(dynamic_cache)
+     flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)]
      return [f[1] for f in flat], [f[0] for f in flat]

@@ -95,8 +110,6 @@ def flatten_with_keys_dynamic_cache(
      dynamic_cache: DynamicCache,
  ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
      """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-     if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
-         return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache)
      values, context = flatten_dynamic_cache(dynamic_cache)
      return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
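
Note: these flatten/unflatten pairs exist so that PyTorch's pytree machinery, and therefore torch.export, can traverse cache objects; the actual wiring lives in onnx_export_serialization.py (also changed in this release). A sketch of such a registration, using the documented torch.utils._pytree API and the functions from this file:

    import torch
    from transformers.cache_utils import DynamicCache

    # Register the serialization functions so torch.export can flatten and
    # unflatten DynamicCache instances like any other pytree container.
    torch.utils._pytree.register_pytree_node(
        DynamicCache,
        flatten_dynamic_cache,
        unflatten_dynamic_cache,
        serialized_type_name="transformers.cache_utils.DynamicCache",
        flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
    )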
 
@@ -105,15 +118,36 @@ def unflatten_dynamic_cache(
      values: List[Any], context: torch.utils._pytree.Context, output_type=None
  ) -> DynamicCache:
      """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-     if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
-         assert output_type is None, f"output_type={output_type} not supported"
-         return transformers.cache_utils._unflatten_dynamic_cache(values, context)
+     return make_dynamic_cache(list(zip(values[0], values[1])))

-     cache = transformers.cache_utils.DynamicCache()
-     values = dict(zip(context, values))
-     for k, v in values.items():
-         setattr(cache, k, v)
-     return cache
+
+ #############
+ # HybridCache
+ #############
+
+
+ def flatten_hybrid_cache(
+     cache: HybridCache,
+ ) -> Tuple[List[Any], torch.utils._pytree.Context]:
+     """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
+     ca = CacheKeyValue(cache)
+     flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)]
+     return [f[1] for f in flat], [f[0] for f in flat]
+
+
+ def flatten_with_keys_hybrid_cache(
+     cache: HybridCache,
+ ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+     """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
+     values, context = flatten_hybrid_cache(cache)
+     return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+ def unflatten_hybrid_cache(
+     values: List[Any], context: torch.utils._pytree.Context, output_type=None
+ ) -> HybridCache:
+     """Restores a :class:`transformers.cache_utils.HybridCache` from python objects."""
+     return make_hybrid_cache(list(zip(values[0], values[1])))
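
Note: a quick round-trip check for the new HybridCache support, assuming make_hybrid_cache accepts a list of per-layer (key, value) tensor pairs, as its use in unflatten_hybrid_cache above suggests:

    import torch

    # Two layers of [batch, heads, seq, head_dim] key/value tensors.
    pairs = [(torch.randn(1, 1, 8, 16), torch.randn(1, 1, 8, 16)) for _ in range(2)]
    cache = make_hybrid_cache(pairs)

    values, context = flatten_hybrid_cache(cache)
    assert context == ["key_cache", "value_cache"]
    restored = unflatten_hybrid_cache(values, context)  # a HybridCache again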

  #############
@@ -125,12 +159,13 @@ def flatten_static_cache(
      cache: StaticCache,
  ) -> Tuple[List[Any], torch.utils._pytree.Context]:
      """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
-     assert not cache.key_cache or cache.max_cache_len == cache.key_cache[0].shape[2], (
+     ca = CacheKeyValue(cache)
+     assert not ca.key_cache or cache.max_cache_len == ca.key_cache[0].shape[2], (
          f"Serialization does not work when "
          f"cache.max_cache_len={cache.max_cache_len} != "
-         f"cache.key_cache[0].shape[2]={cache.key_cache[0].shape[2]}"
+         f"cache.key_cache[0].shape[2]={ca.key_cache[0].shape[2]}"
      )
-     flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
+     flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)]
      return [f[1] for f in flat], [f[0] for f in flat]

@@ -163,7 +198,8 @@ def flatten_sliding_window_cache(
      Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
      with python objects.
      """
-     flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
+     ca = CacheKeyValue(cache)
+     flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)]
      return [f[1] for f in flat], [f[0] for f in flat]

@@ -183,26 +219,7 @@ def unflatten_sliding_window_cache(
  ) -> SlidingWindowCache:
      """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
      key_cache, value_cache = values
-
-     class _config:
-         def __init__(self):
-             self.head_dim = key_cache[0].shape[-1]
-             self.num_attention_heads = key_cache[0].shape[1]
-             self.num_hidden_layers = len(key_cache)
-             self.sliding_window = key_cache[0].shape[2]
-
-     cache = SlidingWindowCache(
-         _config(),
-         max_batch_size=key_cache[0].shape[0],
-         max_cache_len=key_cache[0].shape[2],  # sliding window
-         device=key_cache[0].device,
-         dtype=key_cache[0].dtype,
-     )
-
-     values = dict(zip(context, values))
-     for k, v in values.items():
-         setattr(cache, k, v)
-     return cache
+     return make_sliding_window_cache(list(zip(values[0], values[1])))
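
Note: the removed inline _config shim is now centralized in make_sliding_window_cache from cache_helper.py, whose definition is not part of this diff. A sketch of what it presumably does, reusing the geometry inference the deleted code performed; the function name and buffer assignment below are illustrative, not the actual helper:

    import torch
    from transformers.cache_utils import SlidingWindowCache

    def make_sliding_window_cache_sketch(key_value_pairs):
        key0 = key_value_pairs[0][0]  # [batch, heads, window, head_dim]

        class _Config:
            head_dim = key0.shape[-1]
            num_attention_heads = key0.shape[1]
            num_hidden_layers = len(key_value_pairs)
            sliding_window = key0.shape[2]

        cache = SlidingWindowCache(
            _Config(),
            max_batch_size=key0.shape[0],
            max_cache_len=key0.shape[2],  # the sliding window size
            device=key0.device,
            dtype=key0.dtype,
        )
        # Overwrite the pre-allocated buffers with the deserialized tensors.
        for i, (k, v) in enumerate(key_value_pairs):
            cache.key_cache[i], cache.value_cache[i] = k, v
        return cache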

  #####################
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
@@ -1366,6 +1366,236 @@ def _ccached_fxmarty_tiny_random_gemmaforcausallm():
      )


+ def _ccached_fxmarty_sam_vit_tiny_random():
+     "fxmarty/sam-vit-tiny-random"
+     return transformers.SamConfig(
+         **{
+             "_commit_hash": "a7c34ea5d2b33a3bc34d34dc9a7b2417c0eaa809",
+             "_name_or_path": "facebook/sam-vit-base",
+             "architectures": ["SamModel"],
+             "initializer_range": 0.02,
+             "mask_decoder_config": {
+                 "_name_or_path": "",
+                 "add_cross_attention": false,
+                 "architectures": null,
+                 "attention_downsample_rate": 2,
+                 "bad_words_ids": null,
+                 "begin_suppress_tokens": null,
+                 "bos_token_id": null,
+                 "chunk_size_feed_forward": 0,
+                 "cross_attention_hidden_size": null,
+                 "decoder_start_token_id": null,
+                 "diversity_penalty": 0.0,
+                 "do_sample": false,
+                 "early_stopping": false,
+                 "encoder_no_repeat_ngram_size": 0,
+                 "eos_token_id": null,
+                 "exponential_decay_length_penalty": null,
+                 "finetuning_task": null,
+                 "forced_bos_token_id": null,
+                 "forced_eos_token_id": null,
+                 "hidden_act": "relu",
+                 "hidden_size": 32,
+                 "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
+                 "iou_head_depth": 3,
+                 "iou_head_hidden_dim": 256,
+                 "is_decoder": false,
+                 "is_encoder_decoder": false,
+                 "label2id": {"LABEL_0": 0, "LABEL_1": 1},
+                 "layer_norm_eps": 1e-06,
+                 "length_penalty": 1.0,
+                 "max_length": 20,
+                 "min_length": 0,
+                 "mlp_dim": 2048,
+                 "model_type": "",
+                 "no_repeat_ngram_size": 0,
+                 "num_attention_heads": 8,
+                 "num_beam_groups": 1,
+                 "num_beams": 1,
+                 "num_hidden_layers": 2,
+                 "num_multimask_outputs": 3,
+                 "num_return_sequences": 1,
+                 "output_attentions": false,
+                 "output_hidden_states": false,
+                 "output_scores": false,
+                 "pad_token_id": null,
+                 "prefix": null,
+                 "problem_type": null,
+                 "pruned_heads": {},
+                 "remove_invalid_values": false,
+                 "repetition_penalty": 1.0,
+                 "return_dict": true,
+                 "return_dict_in_generate": false,
+                 "sep_token_id": null,
+                 "suppress_tokens": null,
+                 "task_specific_params": null,
+                 "temperature": 1.0,
+                 "tf_legacy_loss": false,
+                 "tie_encoder_decoder": false,
+                 "tie_word_embeddings": true,
+                 "tokenizer_class": null,
+                 "top_k": 50,
+                 "top_p": 1.0,
+                 "torch_dtype": null,
+                 "torchscript": false,
+                 "transformers_version": "4.29.0.dev0",
+                 "typical_p": 1.0,
+                 "use_bfloat16": false,
+             },
+             "model_type": "sam",
+             "prompt_encoder_config": {
+                 "_name_or_path": "",
+                 "add_cross_attention": false,
+                 "architectures": null,
+                 "bad_words_ids": null,
+                 "begin_suppress_tokens": null,
+                 "bos_token_id": null,
+                 "chunk_size_feed_forward": 0,
+                 "cross_attention_hidden_size": null,
+                 "decoder_start_token_id": null,
+                 "diversity_penalty": 0.0,
+                 "do_sample": false,
+                 "early_stopping": false,
+                 "encoder_no_repeat_ngram_size": 0,
+                 "eos_token_id": null,
+                 "exponential_decay_length_penalty": null,
+                 "finetuning_task": null,
+                 "forced_bos_token_id": null,
+                 "forced_eos_token_id": null,
+                 "hidden_act": "gelu",
+                 "hidden_size": 32,
+                 "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
+                 "image_embedding_size": 64,
+                 "image_size": 1024,
+                 "is_decoder": false,
+                 "is_encoder_decoder": false,
+                 "label2id": {"LABEL_0": 0, "LABEL_1": 1},
+                 "layer_norm_eps": 1e-06,
+                 "length_penalty": 1.0,
+                 "mask_input_channels": 16,
+                 "max_length": 20,
+                 "min_length": 0,
+                 "model_type": "",
+                 "no_repeat_ngram_size": 0,
+                 "num_beam_groups": 1,
+                 "num_beams": 1,
+                 "num_point_embeddings": 4,
+                 "num_return_sequences": 1,
+                 "output_attentions": false,
+                 "output_hidden_states": false,
+                 "output_scores": false,
+                 "pad_token_id": null,
+                 "patch_size": 16,
+                 "prefix": null,
+                 "problem_type": null,
+                 "pruned_heads": {},
+                 "remove_invalid_values": false,
+                 "repetition_penalty": 1.0,
+                 "return_dict": true,
+                 "return_dict_in_generate": false,
+                 "sep_token_id": null,
+                 "suppress_tokens": null,
+                 "task_specific_params": null,
+                 "temperature": 1.0,
+                 "tf_legacy_loss": false,
+                 "tie_encoder_decoder": false,
+                 "tie_word_embeddings": true,
+                 "tokenizer_class": null,
+                 "top_k": 50,
+                 "top_p": 1.0,
+                 "torch_dtype": null,
+                 "torchscript": false,
+                 "transformers_version": "4.29.0.dev0",
+                 "typical_p": 1.0,
+                 "use_bfloat16": false,
+             },
+             "torch_dtype": "float32",
+             "transformers_version": null,
+             "vision_config": {
+                 "_name_or_path": "",
+                 "add_cross_attention": false,
+                 "architectures": null,
+                 "attention_dropout": 0.0,
+                 "bad_words_ids": null,
+                 "begin_suppress_tokens": null,
+                 "bos_token_id": null,
+                 "chunk_size_feed_forward": 0,
+                 "cross_attention_hidden_size": null,
+                 "decoder_start_token_id": null,
+                 "diversity_penalty": 0.0,
+                 "do_sample": false,
+                 "dropout": 0.0,
+                 "early_stopping": false,
+                 "encoder_no_repeat_ngram_size": 0,
+                 "eos_token_id": null,
+                 "exponential_decay_length_penalty": null,
+                 "finetuning_task": null,
+                 "forced_bos_token_id": null,
+                 "forced_eos_token_id": null,
+                 "global_attn_indexes": [2, 5, 8, 11],
+                 "hidden_act": "gelu",
+                 "hidden_size": 96,
+                 "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
+                 "image_size": 1024,
+                 "initializer_factor": 1.0,
+                 "initializer_range": 1e-10,
+                 "intermediate_size": 768,
+                 "is_decoder": false,
+                 "is_encoder_decoder": false,
+                 "label2id": {"LABEL_0": 0, "LABEL_1": 1},
+                 "layer_norm_eps": 1e-06,
+                 "length_penalty": 1.0,
+                 "max_length": 20,
+                 "min_length": 0,
+                 "mlp_dim": 384,
+                 "mlp_ratio": 4.0,
+                 "model_type": "",
+                 "no_repeat_ngram_size": 0,
+                 "num_attention_heads": 1,
+                 "num_beam_groups": 1,
+                 "num_beams": 1,
+                 "num_channels": 3,
+                 "num_hidden_layers": 12,
+                 "num_pos_feats": 16,
+                 "num_return_sequences": 1,
+                 "output_attentions": false,
+                 "output_channels": 32,
+                 "output_hidden_states": false,
+                 "output_scores": false,
+                 "pad_token_id": null,
+                 "patch_size": 16,
+                 "prefix": null,
+                 "problem_type": null,
+                 "projection_dim": 64,
+                 "pruned_heads": {},
+                 "qkv_bias": true,
+                 "remove_invalid_values": false,
+                 "repetition_penalty": 1.0,
+                 "return_dict": true,
+                 "return_dict_in_generate": false,
+                 "sep_token_id": null,
+                 "suppress_tokens": null,
+                 "task_specific_params": null,
+                 "temperature": 1.0,
+                 "tf_legacy_loss": false,
+                 "tie_encoder_decoder": false,
+                 "tie_word_embeddings": true,
+                 "tokenizer_class": null,
+                 "top_k": 50,
+                 "top_p": 1.0,
+                 "torch_dtype": null,
+                 "torchscript": false,
+                 "transformers_version": "4.29.0.dev0",
+                 "typical_p": 1.0,
+                 "use_abs_pos": true,
+                 "use_bfloat16": false,
+                 "use_rel_pos": true,
+                 "window_size": 14,
+             },
+         }
+     )
+
+
  def _ccached_hf_internal_testing_tiny_random_gptneoxforcausallm():
      "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"
      return transformers.GPTNeoXConfig(
@@ -4330,3 +4560,61 @@ def _ccached_diffusers_tiny_torch_full_checker_unet():
          "up_block_types": ["CrossAttnUpBlock2D", "UpBlock2D"],
          "use_linear_projection": false,
      }
+
+
+ def _ccached_riny_random_gemma_3():
+     "tiny-random/gemma-3"
+     return transformers.Gemma3Config(
+         **{
+             "architectures": ["Gemma3ForConditionalGeneration"],
+             "boi_token_index": 255999,
+             "eoi_token_index": 256000,
+             "eos_token_id": [1, 106],
+             "image_token_index": 262144,
+             "initializer_range": 0.02,
+             "mm_tokens_per_image": 256,
+             "model_type": "gemma3",
+             "text_config": {
+                 "attention_bias": false,
+                 "attention_dropout": 0.0,
+                 "attn_logit_softcapping": null,
+                 "cache_implementation": "hybrid",
+                 "final_logit_softcapping": null,
+                 "head_dim": 32,
+                 "hidden_activation": "gelu_pytorch_tanh",
+                 "hidden_size": 32,
+                 "initializer_range": 0.02,
+                 "intermediate_size": 128,
+                 "max_position_embeddings": 131072,
+                 "model_type": "gemma3_text",
+                 "num_attention_heads": 1,
+                 "num_hidden_layers": 2,
+                 "num_key_value_heads": 1,
+                 "query_pre_attn_scalar": 168,
+                 "rms_norm_eps": 1e-06,
+                 "rope_local_base_freq": 10000.0,
+                 "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
+                 "rope_theta": 1000000.0,
+                 "sliding_window": 1024,
+                 "sliding_window_pattern": 2,
+                 "use_cache": true,
+                 "vocab_size": 262208,
+             },
+             "torch_dtype": "bfloat16",
+             "transformers_version": "4.50.0.dev0",
+             "vision_config": {
+                 "attention_dropout": 0.0,
+                 "hidden_act": "gelu_pytorch_tanh",
+                 "hidden_size": 32,
+                 "image_size": 896,
+                 "intermediate_size": 128,
+                 "layer_norm_eps": 1e-06,
+                 "model_type": "siglip_vision_model",
+                 "num_attention_heads": 1,
+                 "num_channels": 3,
+                 "num_hidden_layers": 2,
+                 "patch_size": 14,
+                 "vision_use_head": false,
+             },
+         }
+     )
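
Note: the gemma-3 text config above declares "cache_implementation": "hybrid", which is what the HybridCache serialization added earlier in this diff supports. These cached configs let the package build tiny untrained models without contacting the Hub; a usage sketch based on the documented entry point (the model id matches the docstring above):

    from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

    # Builds an untrained tiny model plus example inputs from the cached config.
    data = get_untrained_model_with_inputs("tiny-random/gemma-3", verbose=1)
    model, inputs = data["model"], data["inputs"]
    outputs = model(**inputs)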

onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -144,6 +144,11 @@ def get_untrained_model_with_inputs(
              f"[get_untrained_model_with_inputs] config._attn_implementation="
              f"{config._attn_implementation!r}"  # type: ignore[union-attr]
          )
+     elif verbose:
+         print(
+             f"[get_untrained_model_with_inputs] default config._attn_implementation="
+             f"{getattr(config, '_attn_implementation', '?')!r}"  # type: ignore[union-attr]
+         )

      if type(config) is dict and "_diffusers_version" in config:
          import diffusers

onnx_diagnostic/torch_models/validate.py
@@ -288,6 +288,7 @@ def validate_model(
      repeat: int = 1,
      warmup: int = 0,
      inputs2: int = 1,
+     output_names: Optional[List[str]] = None,
  ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
      """
      Validates a model.
@@ -338,6 +339,7 @@ def validate_model(
      :param inputs2: checks that the second set of inputs runs as well;
          this ensures that the model supports dynamism. The value is used
          as an increment to the first set of values (added to dimensions)
+     :param output_names: output names the onnx exporter should use
      :return: two dictionaries, one with some metrics,
          another one with whatever the function produces

@@ -433,6 +435,7 @@ def validate_model(
          )
          print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
          print(f"[validate_model] dump_folder={dump_folder!r}")
+         print(f"[validate_model] output_names={output_names}")
      summary["model_id"] = model_id
      summary["model_subfolder"] = subfolder or ""

@@ -631,6 +634,7 @@ def validate_model(
              optimization=optimization,
              do_run=do_run,
              dump_folder=dump_folder,
+             output_names=output_names,
          )
      else:
          data["inputs_export"] = data["inputs"]
@@ -643,6 +647,7 @@ def validate_model(
              optimization=optimization,
              do_run=do_run,
              dump_folder=dump_folder,
+             output_names=output_names,
          )
      summary.update(summary_export)

@@ -868,6 +873,7 @@ def call_exporter(
      optimization: Optional[str] = None,
      do_run: bool = False,
      dump_folder: Optional[str] = None,
+     output_names: Optional[List[str]] = None,
  ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
      """
      Calls an exporter on a model;
@@ -880,6 +886,7 @@ def call_exporter(
      :param optimization: optimization to do
      :param do_run: runs and computes discrepancies
      :param dump_folder: to dump additional information
+     :param output_names: list of output names to use with the onnx exporter
      :return: two dictionaries, one with some metrics,
          another one with whatever the function produces
      """
@@ -902,6 +909,7 @@ def call_exporter(
              quiet=quiet,
              verbose=verbose,
              optimization=optimization,
+             output_names=output_names,
          )
          return summary, data
      if exporter == "custom" or exporter.startswith("custom"):
@@ -913,6 +921,7 @@ def call_exporter(
              verbose=verbose,
              optimization=optimization,
              dump_folder=dump_folder,
+             output_names=output_names,
          )
          return summary, data
      if exporter == "modelbuilder":
@@ -923,6 +932,7 @@ def call_exporter(
              quiet=quiet,
              verbose=verbose,
              optimization=optimization,
+             output_names=output_names,
          )
          return summary, data
      raise NotImplementedError(
@@ -1090,7 +1100,7 @@ def validate_onnx_model(
      """
      import onnxruntime

-     def _mk(key):
+     def _mk(key, flavour=flavour):
          return f"{key}_{flavour}" if flavour else key

      summary: Dict[str, Any] = {}
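
Note: the change from _mk(key) to _mk(key, flavour=flavour) replaces a late-binding closure with a default argument evaluated at definition time, so a later rebinding of the local flavour can no longer change keys that were meant to use the earlier value. A standalone illustration of the difference:

    # Closures look up free variables when called; defaults are captured once.
    flavour = "cpu"

    def mk_late(key):
        return f"{key}_{flavour}" if flavour else key

    def mk_bound(key, flavour=flavour):
        return f"{key}_{flavour}" if flavour else key

    flavour = "cuda"  # rebinding after both definitions
    print(mk_late("disc"))   # disc_cuda, picked up the new binding
    print(mk_bound("disc"))  # disc_cpu, kept the definition-time value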
@@ -1145,7 +1155,7 @@ def validate_onnx_model(
      )
      sess = _quiet_or_not_quiet(
          quiet,
-         _mk("onnx_ort_create"),
+         _mk("create_onnx_ort"),
          summary,
          data,
          (lambda source=source, providers=providers: cls_runtime(source, providers)),
@@ -1180,7 +1190,7 @@ def validate_onnx_model(

      got = _quiet_or_not_quiet(
          quiet,
-         _mk(f"time_onnx_ort_run{suffix}"),
+         _mk(f"run_onnx_ort{suffix}"),
          summary,
          data,
          (lambda sess=sess, feeds=feeds: sess.run(None, feeds)),
@@ -1211,6 +1221,7 @@ def call_torch_export_onnx(
      quiet: bool = False,
      verbose: int = 0,
      optimization: Optional[str] = None,
+     output_names: Optional[List[str]] = None,
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
      """
      Exports a model into onnx.
@@ -1222,6 +1233,7 @@ def call_torch_export_onnx(
      :param quiet: catch exception or not
      :param verbose: verbosity
      :param optimization: optimization to do
+     :param output_names: output names to use
      :return: two dictionaries, one with some metrics,
          another one with whatever the function produces
      """
@@ -1276,6 +1288,8 @@ def call_torch_export_onnx(
          print("[call_torch_export_onnx] dynamo=False so...")
          print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
          print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
+     if output_names:
+         export_export_kwargs["output_names"] = output_names
      if opset:
          export_export_kwargs["opset_version"] = opset
      if verbose:
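
Note: with dynamo=False these kwargs feed the legacy torch.onnx.export, which accepts output_names and opset_version directly. A minimal, self-contained illustration of what forwarding output_names achieves; the model, file name, and opset below are made up for the example:

    import torch

    model = torch.nn.Linear(4, 2)
    torch.onnx.export(
        model,
        (torch.randn(1, 4),),
        "linear.onnx",
        output_names=["logits"],  # named output instead of an auto-generated one
        opset_version=18,
    )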
@@ -1346,6 +1360,7 @@ def call_torch_export_model_builder(
      quiet: bool = False,
      verbose: int = 0,
      optimization: Optional[str] = None,
+     output_names: Optional[List[str]] = None,
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
      """
      Exports a model into onnx with :epkg:`ModelBuilder`.
@@ -1356,6 +1371,7 @@ def call_torch_export_model_builder(
      :param quiet: catch exception or not
      :param verbose: verbosity
      :param optimization: optimization to do
+     :param output_names: list of output names to use
      :return: two dictionaries, one with some metrics,
          another one with whatever the function produces
      """
@@ -1369,6 +1385,9 @@ def call_torch_export_model_builder(
      provider = data.get("model_device", "cpu")
      dump_folder = data.get("model_dump_folder", "")
      assert dump_folder, "dump_folder cannot be empty with ModelBuilder"
+     assert (
+         not output_names
+     ), f"output_names not empty, not supported yet, output_names={output_names}"
      cache_dir = os.path.join(dump_folder, "cache_mb")
      if not os.path.exists(cache_dir):
          os.makedirs(cache_dir)
@@ -1408,6 +1427,7 @@ def call_torch_export_custom(
      verbose: int = 0,
      optimization: Optional[str] = None,
      dump_folder: Optional[str] = None,
+     output_names: Optional[List[str]] = None,
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
      """
      Exports a model into onnx.
@@ -1420,6 +1440,7 @@ def call_torch_export_custom(
      :param verbose: verbosity
      :param optimization: optimization to do
      :param dump_folder: to store additional information
+     :param output_names: list of output names to use
      :return: two dictionaries, one with some metrics,
          another one with whatever the function produces
      """
@@ -1504,6 +1525,8 @@ def call_torch_export_custom(
      )
      if opset:
          kws["target_opset"] = opset
+     if output_names:
+         kws["output_names"] = output_names

      epo, opt_stats = _quiet_or_not_quiet(

{onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: onnx-diagnostic
- Version: 0.7.4
+ Version: 0.7.6
  Summary: Investigate ONNX models
  Home-page: https://github.com/sdpython/onnx-diagnostic
  Author: Xavier Dupré