onnx-diagnostic 0.8.10__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +136 -140
  3. onnx_diagnostic/ci_models/data/Blanca_Lake_Hudak.jpg +0 -0
  4. onnx_diagnostic/ci_models/data/Ice_worm_glacier.jpg +0 -0
  5. onnx_diagnostic/ci_models/data/__init__.py +0 -0
  6. onnx_diagnostic/ci_models/export_phi4_mm.py +10 -7
  7. onnx_diagnostic/export/api.py +13 -4
  8. onnx_diagnostic/export/dynamic_shapes.py +1 -1
  9. onnx_diagnostic/export/validate.py +2 -0
  10. onnx_diagnostic/ext_test_case.py +32 -15
  11. onnx_diagnostic/helpers/args_helper.py +1 -0
  12. onnx_diagnostic/helpers/bench_run.py +0 -1
  13. onnx_diagnostic/helpers/cache_helper.py +102 -36
  14. onnx_diagnostic/helpers/doc_helper.py +7 -4
  15. onnx_diagnostic/helpers/graph_helper.py +6 -6
  16. onnx_diagnostic/helpers/helper.py +39 -0
  17. onnx_diagnostic/helpers/log_helper.py +37 -14
  18. onnx_diagnostic/helpers/memory_peak.py +5 -1
  19. onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
  20. onnx_diagnostic/helpers/model_builder_helper.py +1 -1
  21. onnx_diagnostic/helpers/onnx_helper.py +283 -110
  22. onnx_diagnostic/helpers/ort_session.py +5 -2
  23. onnx_diagnostic/helpers/rt_helper.py +53 -9
  24. onnx_diagnostic/helpers/torch_helper.py +15 -11
  25. onnx_diagnostic/investigate/__init__.py +0 -0
  26. onnx_diagnostic/investigate/input_observer.py +970 -0
  27. onnx_diagnostic/reference/evaluator.py +0 -1
  28. onnx_diagnostic/reference/ort_evaluator.py +0 -1
  29. onnx_diagnostic/reference/report_results_comparison.py +9 -3
  30. onnx_diagnostic/reference/torch_evaluator.py +5 -1
  31. onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
  32. onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
  33. onnx_diagnostic/tasks/feature_extraction.py +0 -1
  34. onnx_diagnostic/torch_export_patches/__init__.py +0 -1
  35. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +32 -14
  36. onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +107 -6
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  39. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +13 -3
  40. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
  41. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +70 -23
  42. onnx_diagnostic/torch_models/code_sample.py +5 -10
  43. onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
  44. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +6 -12
  45. onnx_diagnostic/torch_models/validate.py +1 -1
  46. onnx_diagnostic/torch_onnx/compare.py +0 -1
  47. onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
  48. onnx_diagnostic/torch_onnx/sbs.py +1 -1
  49. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
  50. onnx_diagnostic/typing.py +15 -0
  51. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/METADATA +2 -2
  52. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/RECORD +55 -50
  53. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/WHEEL +1 -1
  54. onnx_diagnostic/api.py +0 -15
  55. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  56. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/top_level.txt +0 -0
@@ -115,7 +115,7 @@ def make_feeds(
115
115
  def _get_dim(i: int, s: Union[str, int], batch: int = 1) -> int:
116
116
  if isinstance(s, int):
117
117
  return s
118
- if s == "batch":
118
+ if s == "batch" or i == 0:
119
119
  return batch
120
120
  # Everything else is cache length or sequence length.
121
121
  return 0
@@ -153,9 +153,13 @@ def make_empty_cache(
153
153
  [i.type for i in sess.get_inputs()[2:]],
154
154
  )
155
155
  """
156
+ assert batch > 0, f"batch size = {batch} must be positive"
156
157
  feeds = {}
157
158
  for name, shape, dtype in zip(onnx_input_names, onnx_input_shapes, onnx_input_types):
158
159
  new_shape = tuple(_get_dim(i, s, batch=batch) for i, s in enumerate(shape))
160
+ assert (
161
+ new_shape and new_shape[0] > 0
162
+ ), f"new_shape={new_shape} cannot have a null batch size, name={name!r}, shape={shape}"
159
163
  feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype))
160
164
  return feeds
161
165
 
@@ -272,6 +276,7 @@ def generate_and_validate(
272
276
  def onnx_generate(
273
277
  model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
274
278
  input_ids: torch.Tensor,
279
+ attention_mask: Optional[torch.Tensor] = None,
275
280
  eos_token_id: int = 2,
276
281
  max_new_tokens=100,
277
282
  return_session: bool = False,
@@ -330,7 +335,9 @@ def onnx_generate(
330
335
  )
331
336
 
332
337
  print("-- generate with onnx")
333
- onnx_outputs = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
338
+ onnx_outputs = onnx_generate(
339
+ model_name, input_ids[:1], eos_token_id=2, max_new_tokens=10
340
+ )
334
341
  print("-- onnx output", onnx_outputs)
335
342
 
336
343
  # The example continues with other functions doing the same.
@@ -364,6 +371,7 @@ def onnx_generate(
364
371
  input_names = session.input_names
365
372
  input_types = session.input_types
366
373
  has_position_ids = "position_ids" in session.input_names
374
+ has_cache_position = "cache_position" in session.input_names
367
375
 
368
376
  assert (
369
377
  len(input_names) > 2
@@ -377,21 +385,46 @@ def onnx_generate(
377
385
  not has_position_ids or input_names[2] == "position_ids"
378
386
  ), f"position_ids must the third input but input_names={input_names}"
379
387
 
388
+ cache_names, cache_shapes, cache_types = [], [], []
389
+ for name, shape, dt in zip(input_names, input_shapes, input_types):
390
+ if name.startswith("past_key_values"):
391
+ cache_names.append(name)
392
+ cache_shapes.append(shape)
393
+ cache_types.append(dt)
394
+
380
395
  # First call: prefill
396
+ empty_cache = make_empty_cache(input_ids.shape[0], cache_names, cache_shapes, cache_types)
381
397
  feeds = dict(
382
398
  input_ids=input_ids,
383
- attention_mask=torch.ones(
384
- input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
385
- ),
386
- **make_empty_cache(
387
- input_ids.shape[0], input_names[2:], input_shapes[2:], input_types[2:]
399
+ attention_mask=(
400
+ attention_mask
401
+ if attention_mask is not None
402
+ else torch.ones(input_ids.shape, dtype=input_ids.dtype, device=input_ids.device)
388
403
  ),
404
+ **empty_cache,
389
405
  )
406
+
390
407
  if has_position_ids:
391
- feeds["position_ids"] = torch.unsqueeze(
408
+ assert (
409
+ input_ids.shape[1] > 0
410
+ ), f"unexpected value for input_ids shape={input_ids.shape}"
411
+ position_ids = torch.unsqueeze(
392
412
  torch.arange(input_ids.shape[1], dtype=torch.int64, device=input_ids.device), 0
393
413
  )
414
+ feeds["position_ids"] = position_ids
415
+
416
+ if has_cache_position:
417
+ assert empty_cache, "no cache means no cache_position"
418
+ first_tensor = next(iter(empty_cache.values()))
419
+ cache_position = torch.arange(
420
+ first_tensor.shape[2],
421
+ input_ids.shape[1] + first_tensor.shape[2],
422
+ dtype=torch.int64,
423
+ device=input_ids.device,
424
+ )
425
+ feeds["cache_position"] = cache_position
394
426
 
427
+ # prefill step
395
428
  outputs = session.run(None, feeds)
396
429
 
397
430
  # Next calls: decode
@@ -424,7 +457,18 @@ def onnx_generate(
424
457
  ),
425
458
  0,
426
459
  )
427
- feeds.update(dict(zip(input_names[3 if has_position_ids else 2 :], outputs[1:])))
460
+ if has_cache_position:
461
+ feeds["cache_position"] = torch.arange(
462
+ input_ids.shape[1],
463
+ input_ids.shape[1] + 1,
464
+ dtype=torch.int64,
465
+ device=input_ids.device,
466
+ )
467
+
468
+ feeds.update(
469
+ dict(zip([n for n in input_names if n.startswith("past_key_values")], outputs[1:]))
470
+ )
471
+ # generate/decoding step
428
472
  outputs = session.run(None, feeds)
429
473
 
430
474
  if return_session:
@@ -19,12 +19,7 @@ from .cache_helper import (
19
19
  CacheKeyValue,
20
20
  )
21
21
  from .mini_onnx_builder import create_onnx_model_from_input_tensors
22
- from .onnx_helper import (
23
- to_array_extended,
24
- tensor_dtype_to_np_dtype,
25
- _STORAGE_TYPE,
26
- onnx_dtype_name,
27
- )
22
+ from .onnx_helper import to_array_extended, tensor_dtype_to_np_dtype, onnx_dtype_name
28
23
 
29
24
 
30
25
  def proto_from_tensor(
@@ -84,13 +79,17 @@ def proto_from_tensor(
84
79
  byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
85
80
  tensor.raw_data = bytes(byte_data)
86
81
  if sys.byteorder == "big":
87
- np_dtype = _STORAGE_TYPE[tensor.data_type] # type: ignore
88
- np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True) # type: ignore
82
+ storage_type = {
83
+ onnx.TensorProto.FLOAT16: np.int16,
84
+ onnx.TensorProto.BFLOAT16: np.int16,
85
+ }
86
+ np_dtype = storage_type[tensor.data_type] # type: ignore
87
+ np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True) # type: ignore
89
88
  else:
90
89
  tensor.raw_data = np_arr.tobytes()
91
90
  if sys.byteorder == "big":
92
91
  np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
93
- np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True)
92
+ np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True)
94
93
  return tensor
95
94
 
96
95
 
@@ -852,9 +851,14 @@ def torch_deepcopy(value: Any) -> Any:
852
851
  from .cache_helper import CacheKeyValue
853
852
 
854
853
  ca = CacheKeyValue(value)
855
- return make_dynamic_cache(
856
- torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))), cls_layers=ca.cls_layers
854
+ pairs = list(zip(ca.key_cache, ca.value_cache))
855
+ assert not hasattr(value, "layers") or len(value.layers) == len(pairs), (
856
+ f"Size mismatch between {len(value.layers)=} and {len(pairs)=}. "
857
+ f"value={string_type(value, with_shape=True)}, "
858
+ f"first key={value.layers[0].keys}, "
859
+ f"first value={value.layers[0].values}"
857
860
  )
861
+ return make_dynamic_cache(torch_deepcopy(pairs), cls_layers=ca.cls_layers)
858
862
  if value.__class__.__name__ == "StaticCache":
859
863
  from .cache_helper import CacheKeyValue
860
864
 
File without changes