onnx-diagnostic 0.7.2__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.7.2"
6
+ __version__ = "0.7.3"
7
7
  __author__ = "Xavier Dupré"
@@ -112,4 +112,14 @@ def make_feeds(
112
112
 
113
113
  if copy:
114
114
  flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
115
- return dict(zip(names, flat))
115
+ # bool, int, float, onnxruntime does not support float, bool, int
116
+ new_flat = []
117
+ for i in flat:
118
+ if isinstance(i, bool):
119
+ i = np.array(i, dtype=np.bool_)
120
+ elif isinstance(i, int):
121
+ i = np.array(i, dtype=np.int64)
122
+ elif isinstance(i, float):
123
+ i = np.array(i, dtype=np.float32)
124
+ new_flat.append(i)
125
+ return dict(zip(names, new_flat))
@@ -717,7 +717,7 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
717
717
  return tuple(to_any(t, to_value) for t in value)
718
718
  if isinstance(value, set):
719
719
  return {to_any(t, to_value) for t in value}
720
- if isinstance(value, dict):
720
+ if type(value) is dict:
721
721
  return {k: to_any(t, to_value) for k, t in value.items()}
722
722
  if value.__class__.__name__ == "DynamicCache":
723
723
  return make_dynamic_cache(
@@ -337,7 +337,7 @@ def _make_exporter_onnx(
337
337
  from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
338
338
 
339
339
  opts = {}
340
- opts["strict"] = "-nostrict" not in exporter
340
+ opts["strict"] = "-strict" in exporter
341
341
  opts["fallback"] = "-fallback" in exporter
342
342
  opts["tracing"] = "-tracing" in exporter
343
343
  opts["jit"] = "-jit" in exporter
@@ -520,6 +520,8 @@ def run_exporter(
520
520
  return res
521
521
 
522
522
  onx, builder = res
523
+ base["onx"] = onx
524
+ base["builder"] = builder
523
525
  if verbose >= 9:
524
526
  print("[run_exporter] onnx model")
525
527
  print(
@@ -28,7 +28,8 @@ def register_class_serialization(
28
28
  ) -> bool:
29
29
  """
30
30
  Registers a class.
31
- It can be undone with :func:`unregister`.
31
+ It can be undone with
32
+ :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`.
32
33
 
33
34
  :param cls: class to register
34
35
  :param f_flatten: see ``torch.utils._pytree.register_pytree_node``
@@ -77,7 +78,8 @@ def register_cache_serialization(
77
78
  patch_transformers: bool = False, patch_diffusers: bool = True, verbose: int = 0
78
79
  ) -> Dict[str, bool]:
79
80
  """
80
- Registers many classes with :func:`register_class_serialization`.
81
+ Registers many classes with
82
+ :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization`.
81
83
  Returns information needed to undo the registration.
82
84
 
83
85
  :param patch_transformers: add serialization function for
@@ -214,8 +214,8 @@ class patched_DynamicCache:
214
214
  if len(self.key_cache) <= layer_idx:
215
215
  # There may be skipped layers, fill them with empty lists
216
216
  for _ in range(len(self.key_cache), layer_idx):
217
- self.key_cache.append(torch.tensor([]))
218
- self.value_cache.append(torch.tensor([]))
217
+ self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
218
+ self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
219
219
  self.key_cache.append(key_states)
220
220
  self.value_cache.append(value_states)
221
221
  elif not self.key_cache[
@@ -231,7 +231,6 @@ class patched_DynamicCache:
231
231
  self.value_cache[layer_idx] = torch.cat(
232
232
  [self.value_cache[layer_idx], value_states], dim=-2
233
233
  )
234
-
235
234
  return self.key_cache[layer_idx], self.value_cache[layer_idx]
236
235
 
237
236
  def crop(self, max_length: int):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-diagnostic
3
- Version: 0.7.2
3
+ Version: 0.7.3
4
4
  Summary: Investigate ONNX models
5
5
  Home-page: https://github.com/sdpython/onnx-diagnostic
6
6
  Author: Xavier Dupré
@@ -64,13 +64,26 @@ onnx-diagnostic: investigate onnx models
64
64
 
65
65
  The main feature is about `patches <https://github.com/sdpython/onnx-diagnostic/tree/main/onnx_diagnostic/torch_export_patches>`_:
66
66
  it helps exporting **pytorch models into ONNX**, mostly designed for LLMs using dynamic caches.
67
+ Patches can be enabled as follows:
67
68
 
68
69
  .. code-block:: python
69
70
 
71
+ from onnx_diagnostic.torch_export_patches import torch_export_patches
72
+
70
73
  with torch_export_patches(patch_transformers=True) as f:
71
74
  ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
72
75
  # ...
73
76
 
77
+ Dynamic shapes are difficult to guess for caches, one function
78
+ returns a structure defining all dimensions as dynamic.
79
+ You need then to remove those which are not dynamic in your model.
80
+
81
+ .. code-block:: python
82
+
83
+ from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
84
+
85
+ dynamic_shapes = all_dynamic_shape_from_inputs(cache)
86
+
74
87
  It also implements tools to investigate, validate exported models (ExportedProgramm, ONNXProgram, ...).
75
88
  See `documentation of onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/>`_ and
76
89
  `torch_export_patches <https://sdpython.github.io/doc/onnx-diagnostic/dev/api/torch_export_patches/index.html#onnx_diagnostic.torch_export_patches.torch_export_patches>`_.
@@ -127,14 +140,26 @@ Snapshot of usefuls tools
127
140
 
128
141
  .. code-block:: python
129
142
 
143
+ from onnx_diagnostic.torch_export_patches import torch_export_patches
144
+
130
145
  with torch_export_patches(patch_transformers=True) as f:
131
146
  ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
132
147
  # ...
133
148
 
149
+ **all_dynamic_shape_from_inputs**
150
+
151
+ .. code-block:: python
152
+
153
+ from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
154
+
155
+ dynamic_shapes = all_dynamic_shape_from_inputs(cache)
156
+
134
157
  **torch_export_rewrite**
135
158
 
136
159
  .. code-block:: python
137
160
 
161
+ from onnx_diagnostic.torch_export_patches import torch_export_rewrite
162
+
138
163
  with torch_export_rewrite(rewrite=[Model.forward]) as f:
139
164
  ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
140
165
  # ...
@@ -1,4 +1,4 @@
1
- onnx_diagnostic/__init__.py,sha256=Lt-QBr--poshkZCAn2mvNtBcQfTKfBUI7__zuZCXklo,173
1
+ onnx_diagnostic/__init__.py,sha256=N1lf8_afRytDUnulPCeDVDPA-M4k7y9x7LbWwX0USZs,173
2
2
  onnx_diagnostic/__main__.py,sha256=YmyV_Aq_ianDlHyKLHMa6h8YK3ZmFPpLVHLKjM91aCk,79
3
3
  onnx_diagnostic/_command_lines_parser.py,sha256=TliHXedXFerv-zO6cBigxKbuKHE0-6TUhhsk1pdkz9M,28072
4
4
  onnx_diagnostic/api.py,sha256=BhCl_yCd78N7TlVtPOHjeYv1QBEy39TjZ647rcHqLh0,345
@@ -22,8 +22,8 @@ onnx_diagnostic/helpers/mini_onnx_builder.py,sha256=p0Xh2Br38xAqUjB2214GiNOIbCgi
22
22
  onnx_diagnostic/helpers/model_builder_helper.py,sha256=RvDyPFqRboEU3HsQV_xi9oy-o3_4KuGFVzs5MhksduY,12552
23
23
  onnx_diagnostic/helpers/onnx_helper.py,sha256=pXXQjfyNTSUF-Kt72U4fnBDkYAnWYMxdSw8m0qk3xmE,39670
24
24
  onnx_diagnostic/helpers/ort_session.py,sha256=UgUUeUslDxEFBc6w6f3HMq_a7bn4TBlItmojqWquSj4,29281
25
- onnx_diagnostic/helpers/rt_helper.py,sha256=BXU_u1syk2RyM0HTFHKEiO6rHHhZW2UFPyUTVdeq8BU,4251
26
- onnx_diagnostic/helpers/torch_helper.py,sha256=MJpoiKZoKzp_ed5LK_2ssIMPo0eohn9WrVAcgPvT2Gk,32362
25
+ onnx_diagnostic/helpers/rt_helper.py,sha256=qbV6zyMs-iH6H65WHC2tu4h0psnHg0TX5fwfO_k-glg,4623
26
+ onnx_diagnostic/helpers/torch_helper.py,sha256=QfUXUPx0lZEqJBgyA97daPRDlT9duTM5Jq5Yjq1jJd8,32358
27
27
  onnx_diagnostic/reference/__init__.py,sha256=rLZsxOlnb7-81F2CzepGnZLejaROg4JvgFaGR9FwVQA,208
28
28
  onnx_diagnostic/reference/evaluator.py,sha256=RzNzjFDeMe-4X51Tb22N6aagazY5ktNq-mRmPcfY5EU,8848
29
29
  onnx_diagnostic/reference/ort_evaluator.py,sha256=1O7dHj8Aspolidg6rB2Nm7hT3HaGb4TxAgjCCD0XVcQ,26159
@@ -88,16 +88,16 @@ onnx_diagnostic/tasks/text_to_image.py,sha256=6z-rFG6MX9aBi8YoYtYI_8OV3M3Tfoi45V
88
88
  onnx_diagnostic/tasks/zero_shot_image_classification.py,sha256=GKaXm8g7cK23h3wJEUc6Q-6mpmLAzQ4YkJbd-eGP7Y4,4496
89
89
  onnx_diagnostic/torch_export_patches/__init__.py,sha256=0SaZedwznm1hQUCvXZsGZORV5vby954wEExr5faepGg,720
90
90
  onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=h_txSp30QmF1R_Q2wL4qpPqY59Dund2P9nAAsvucS8A,21245
91
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=4fXsuQjJq_Ko_EehiVZYypdWTBgFgaaK8ryhAFaR0yo,10561
91
+ onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=wFE2fNxihAA3iua79AEB97_RBVv4wvGxwS9g4RJaSIc,10715
92
92
  onnx_diagnostic/torch_export_patches/patch_expressions.py,sha256=vr4tt61cbDnaaaduzMj4UBZ8OUtr6GfDpIWwOYqjWzs,3213
93
93
  onnx_diagnostic/torch_export_patches/patch_inputs.py,sha256=9b4pmyT00BwLqi7WG-gliep1RUy3gXEgW6BDnlSSA-M,7689
94
94
  onnx_diagnostic/torch_export_patches/patch_module.py,sha256=R2d9IHM-RwsBKDsxuBIJnEqMoxbS9gd4YWFGG2wwV5A,39881
95
95
  onnx_diagnostic/torch_export_patches/patch_module_helper.py,sha256=2U0AdyZuU0W54QTdE7tY7imVzMnpQ5091ADNtTCkT8Y,6967
96
- onnx_diagnostic/torch_export_patches/eval/__init__.py,sha256=VtkQB1o3Q2Fh99OOF6vQ2dynkhwzx2Wx6oB-rRbvTI0,23954
96
+ onnx_diagnostic/torch_export_patches/eval/__init__.py,sha256=57x62uZNA80XiWgkG8Fe0_8YJcIVrvKLPqvwLDPJwgc,24008
97
97
  onnx_diagnostic/torch_export_patches/eval/model_cases.py,sha256=DTvdHPtNQh25Akv5o3D4Jxf1L1-SJ7w14tgvj8AAns8,26577
98
98
  onnx_diagnostic/torch_export_patches/patches/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
99
99
  onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=KaZ8TjDa9ATgT4HllYzzoNf_51q_yOj_GuF5NYjPCrU,18913
100
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=XyLqG4w4ALCVF8Dc8_Meu903saFYGBEBG0utziw9i3Q,44014
100
+ onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=QGfXf1IlsOToQKF9NOHau_l62u1alQbJ6KeiSKFb980,44061
101
101
  onnx_diagnostic/torch_export_patches/serialization/__init__.py,sha256=BHLdRPtNAtNPAS-bPKEj3-foGSPvwAbZXrHzGGPDLEw,1876
102
102
  onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py,sha256=drq3EH_yjcSuIWYsVeUWm8Cx6YCZFU6bP_1PLtPfY5I,945
103
103
  onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py,sha256=9u2jkqnuyBkIF3R2sDEO0Jlkedl-cQhBNXxXXDLSEwE,8885
@@ -115,8 +115,8 @@ onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py,sha256=QXw_Bs2SzfeiQMf-tm
115
115
  onnx_diagnostic/torch_onnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
116
116
  onnx_diagnostic/torch_onnx/runtime_info.py,sha256=1g9F_Jf9AAgYQU4stbsrFXwQl-30mWlQrFbQ7val8Ps,9268
117
117
  onnx_diagnostic/torch_onnx/sbs.py,sha256=1EL25DeYFzlBSiFG_XjePBLvsiItRXbdDrr5-QZW2mA,16878
118
- onnx_diagnostic-0.7.2.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
119
- onnx_diagnostic-0.7.2.dist-info/METADATA,sha256=2jkNpfMIypu51qway6NIH1olWbeF_soM-e8rbwc3jVc,6631
120
- onnx_diagnostic-0.7.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
121
- onnx_diagnostic-0.7.2.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
122
- onnx_diagnostic-0.7.2.dist-info/RECORD,,
118
+ onnx_diagnostic-0.7.3.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
119
+ onnx_diagnostic-0.7.3.dist-info/METADATA,sha256=uMYdGJbm6K04yoi5UidH3uA0nK_PXrnVXe83s9v6yPE,7431
120
+ onnx_diagnostic-0.7.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
121
+ onnx_diagnostic-0.7.3.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
122
+ onnx_diagnostic-0.7.3.dist-info/RECORD,,