onnx-diagnostic 0.5.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. onnx_diagnostic/__init__.py +2 -2
  2. onnx_diagnostic/_command_lines_parser.py +39 -1
  3. onnx_diagnostic/api.py +15 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +14 -5
  5. onnx_diagnostic/ext_test_case.py +15 -1
  6. onnx_diagnostic/helpers/args_helper.py +1 -1
  7. onnx_diagnostic/helpers/graph_helper.py +386 -0
  8. onnx_diagnostic/helpers/helper.py +30 -5
  9. onnx_diagnostic/helpers/model_builder_helper.py +349 -0
  10. onnx_diagnostic/helpers/rt_helper.py +69 -1
  11. onnx_diagnostic/helpers/torch_helper.py +2 -0
  12. onnx_diagnostic/reference/__init__.py +1 -0
  13. onnx_diagnostic/reference/torch_evaluator.py +518 -0
  14. onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
  15. onnx_diagnostic/reference/torch_ops/_op_run.py +326 -0
  16. onnx_diagnostic/reference/torch_ops/access_ops.py +84 -0
  17. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  18. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +118 -0
  19. onnx_diagnostic/reference/torch_ops/generator_ops.py +35 -0
  20. onnx_diagnostic/reference/torch_ops/nn_ops.py +176 -0
  21. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  22. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  23. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  24. onnx_diagnostic/reference/torch_ops/shape_ops.py +120 -0
  25. onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
  26. onnx_diagnostic/tasks/__init__.py +22 -1
  27. onnx_diagnostic/tasks/image_classification.py +2 -2
  28. onnx_diagnostic/tasks/text_generation.py +3 -3
  29. onnx_diagnostic/torch_export_patches/eval/__init__.py +690 -0
  30. onnx_diagnostic/torch_export_patches/eval/model_cases.py +883 -0
  31. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +34 -1
  32. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +6 -1
  33. onnx_diagnostic/torch_export_patches/patch_module_helper.py +148 -28
  34. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +91 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +117 -1
  36. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
  37. onnx_diagnostic/torch_models/test_helper.py +225 -22
  38. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  39. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/METADATA +1 -1
  40. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/RECORD +43 -24
  41. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/WHEEL +1 -1
  42. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/licenses/LICENSE.txt +0 -0
  43. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/top_level.txt +0 -0
@@ -39,9 +39,30 @@ def supported_tasks() -> List[str]:
 
 
 def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
     """Reduces a model size."""
+    head_size0 = (
+        config.head_dim
+        if hasattr(config, "head_dim") and config.head_dim
+        else (
+            config.hidden_size // config.num_attention_heads
+            if hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads")
+            else None
+        )
+    )
     tasks = {mod.__TASK__: mod.reduce_model_config for mod in __TASKS__}
     assert task in tasks, f"Task {task!r} not found in {sorted(tasks)}"
-    return tasks[task](config)
+    res = tasks[task](config)
+    if head_size0 and "head_dim" in res:
+        head_size = (
+            config.head_dim
+            if hasattr(config, "head_dim") and config.head_dim
+            else config.hidden_size // config.num_attention_heads
+        )
+        assert head_size0 == head_size or head_size % 16 == 0, (
+            f"head_size should be a multiple of 16 "
+            f"(head_size0={head_size0}), res={res}, "
+            f"config=\n{config}"
+        )
+    return res
 
 
 def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
@@ -58,8 +58,8 @@ def get_inputs(
     shapes = {
         "pixel_values": {
             0: torch.export.Dim("batch", min=1, max=1024),
-            2: torch.export.Dim("width", min=1, max=4096),
-            3: torch.export.Dim("height", min=1, max=4096),
+            2: "width",
+            3: "height",
         },
     }
     inputs = dict(
@@ -27,7 +27,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     kwargs = dict(
         num_hidden_layers=min(config.num_hidden_layers, 2),
         intermediate_size=256 if config is None else min(512, config.intermediate_size),
-        hidden_size=256 if config is None else min(256, config.hidden_size),
+        hidden_size=512 if config is None else min(512, config.hidden_size),
         cls_cache="MambaCache",
         state_size=8 if config is None else getattr(config, "state_size", None),
         conv_kernel=4 if config is None else getattr(config, "conv_kernel", None),
@@ -44,8 +44,8 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
             else config.num_attention_heads
         ),
         hidden_size=(
-            min(config.hidden_size, 3072 // 4)
-            if config.hidden_size % 4 == 0
+            min(config.hidden_size, 4096 // 4)
+            if config.hidden_size % 64 == 0
            else config.hidden_size
         ),
     )