onnx-diagnostic 0.7.2__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/helpers/rt_helper.py +11 -1
- onnx_diagnostic/helpers/torch_helper.py +1 -1
- onnx_diagnostic/torch_export_patches/eval/__init__.py +3 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +4 -2
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2 -3
- {onnx_diagnostic-0.7.2.dist-info → onnx_diagnostic-0.7.3.dist-info}/METADATA +26 -1
- {onnx_diagnostic-0.7.2.dist-info → onnx_diagnostic-0.7.3.dist-info}/RECORD +11 -11
- {onnx_diagnostic-0.7.2.dist-info → onnx_diagnostic-0.7.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.2.dist-info → onnx_diagnostic-0.7.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.2.dist-info → onnx_diagnostic-0.7.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -112,4 +112,14 @@ def make_feeds(
|
|
|
112
112
|
|
|
113
113
|
if copy:
|
|
114
114
|
flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
|
|
115
|
-
|
|
115
|
+
# bool, int, float, onnxruntime does not support float, bool, int
|
|
116
|
+
new_flat = []
|
|
117
|
+
for i in flat:
|
|
118
|
+
if isinstance(i, bool):
|
|
119
|
+
i = np.array(i, dtype=np.bool_)
|
|
120
|
+
elif isinstance(i, int):
|
|
121
|
+
i = np.array(i, dtype=np.int64)
|
|
122
|
+
elif isinstance(i, float):
|
|
123
|
+
i = np.array(i, dtype=np.float32)
|
|
124
|
+
new_flat.append(i)
|
|
125
|
+
return dict(zip(names, new_flat))
|
|
@@ -717,7 +717,7 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
|
|
|
717
717
|
return tuple(to_any(t, to_value) for t in value)
|
|
718
718
|
if isinstance(value, set):
|
|
719
719
|
return {to_any(t, to_value) for t in value}
|
|
720
|
-
if
|
|
720
|
+
if type(value) is dict:
|
|
721
721
|
return {k: to_any(t, to_value) for k, t in value.items()}
|
|
722
722
|
if value.__class__.__name__ == "DynamicCache":
|
|
723
723
|
return make_dynamic_cache(
|
|
@@ -337,7 +337,7 @@ def _make_exporter_onnx(
|
|
|
337
337
|
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
|
|
338
338
|
|
|
339
339
|
opts = {}
|
|
340
|
-
opts["strict"] = "-
|
|
340
|
+
opts["strict"] = "-strict" in exporter
|
|
341
341
|
opts["fallback"] = "-fallback" in exporter
|
|
342
342
|
opts["tracing"] = "-tracing" in exporter
|
|
343
343
|
opts["jit"] = "-jit" in exporter
|
|
@@ -520,6 +520,8 @@ def run_exporter(
|
|
|
520
520
|
return res
|
|
521
521
|
|
|
522
522
|
onx, builder = res
|
|
523
|
+
base["onx"] = onx
|
|
524
|
+
base["builder"] = builder
|
|
523
525
|
if verbose >= 9:
|
|
524
526
|
print("[run_exporter] onnx model")
|
|
525
527
|
print(
|
|
@@ -28,7 +28,8 @@ def register_class_serialization(
|
|
|
28
28
|
) -> bool:
|
|
29
29
|
"""
|
|
30
30
|
Registers a class.
|
|
31
|
-
It can be undone with
|
|
31
|
+
It can be undone with
|
|
32
|
+
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`.
|
|
32
33
|
|
|
33
34
|
:param cls: class to register
|
|
34
35
|
:param f_flatten: see ``torch.utils._pytree.register_pytree_node``
|
|
@@ -77,7 +78,8 @@ def register_cache_serialization(
|
|
|
77
78
|
patch_transformers: bool = False, patch_diffusers: bool = True, verbose: int = 0
|
|
78
79
|
) -> Dict[str, bool]:
|
|
79
80
|
"""
|
|
80
|
-
Registers many classes with
|
|
81
|
+
Registers many classes with
|
|
82
|
+
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization`.
|
|
81
83
|
Returns information needed to undo the registration.
|
|
82
84
|
|
|
83
85
|
:param patch_transformers: add serialization function for
|
|
@@ -214,8 +214,8 @@ class patched_DynamicCache:
|
|
|
214
214
|
if len(self.key_cache) <= layer_idx:
|
|
215
215
|
# There may be skipped layers, fill them with empty lists
|
|
216
216
|
for _ in range(len(self.key_cache), layer_idx):
|
|
217
|
-
self.key_cache.append(torch.tensor([]))
|
|
218
|
-
self.value_cache.append(torch.tensor([]))
|
|
217
|
+
self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
|
|
218
|
+
self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
|
|
219
219
|
self.key_cache.append(key_states)
|
|
220
220
|
self.value_cache.append(value_states)
|
|
221
221
|
elif not self.key_cache[
|
|
@@ -231,7 +231,6 @@ class patched_DynamicCache:
|
|
|
231
231
|
self.value_cache[layer_idx] = torch.cat(
|
|
232
232
|
[self.value_cache[layer_idx], value_states], dim=-2
|
|
233
233
|
)
|
|
234
|
-
|
|
235
234
|
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
|
236
235
|
|
|
237
236
|
def crop(self, max_length: int):
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: onnx-diagnostic
|
|
3
|
-
Version: 0.7.
|
|
3
|
+
Version: 0.7.3
|
|
4
4
|
Summary: Investigate ONNX models
|
|
5
5
|
Home-page: https://github.com/sdpython/onnx-diagnostic
|
|
6
6
|
Author: Xavier Dupré
|
|
@@ -64,13 +64,26 @@ onnx-diagnostic: investigate onnx models
|
|
|
64
64
|
|
|
65
65
|
The main feature is about `patches <https://github.com/sdpython/onnx-diagnostic/tree/main/onnx_diagnostic/torch_export_patches>`_:
|
|
66
66
|
it helps exporting **pytorch models into ONNX**, mostly designed for LLMs using dynamic caches.
|
|
67
|
+
Patches can be enabled as follows:
|
|
67
68
|
|
|
68
69
|
.. code-block:: python
|
|
69
70
|
|
|
71
|
+
from onnx_diagnostic.torch_export_patches import torch_export_patches
|
|
72
|
+
|
|
70
73
|
with torch_export_patches(patch_transformers=True) as f:
|
|
71
74
|
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
|
|
72
75
|
# ...
|
|
73
76
|
|
|
77
|
+
Dynamic shapes are difficult to guess for caches; one function
|
|
78
|
+
returns a structure defining all dimensions as dynamic.
|
|
79
|
+
You then need to remove those which are not dynamic in your model.
|
|
80
|
+
|
|
81
|
+
.. code-block:: python
|
|
82
|
+
|
|
83
|
+
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
|
|
84
|
+
|
|
85
|
+
dynamic_shapes = all_dynamic_shape_from_inputs(cache)
|
|
86
|
+
|
|
74
87
|
It also implements tools to investigate and validate exported models (ExportedProgram, ONNXProgram, ...).
|
|
75
88
|
See `documentation of onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/>`_ and
|
|
76
89
|
`torch_export_patches <https://sdpython.github.io/doc/onnx-diagnostic/dev/api/torch_export_patches/index.html#onnx_diagnostic.torch_export_patches.torch_export_patches>`_.
|
|
@@ -127,14 +140,26 @@ Snapshot of usefuls tools
|
|
|
127
140
|
|
|
128
141
|
.. code-block:: python
|
|
129
142
|
|
|
143
|
+
from onnx_diagnostic.torch_export_patches import torch_export_patches
|
|
144
|
+
|
|
130
145
|
with torch_export_patches(patch_transformers=True) as f:
|
|
131
146
|
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
|
|
132
147
|
# ...
|
|
133
148
|
|
|
149
|
+
**all_dynamic_shape_from_inputs**
|
|
150
|
+
|
|
151
|
+
.. code-block:: python
|
|
152
|
+
|
|
153
|
+
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
|
|
154
|
+
|
|
155
|
+
dynamic_shapes = all_dynamic_shape_from_inputs(cache)
|
|
156
|
+
|
|
134
157
|
**torch_export_rewrite**
|
|
135
158
|
|
|
136
159
|
.. code-block:: python
|
|
137
160
|
|
|
161
|
+
from onnx_diagnostic.torch_export_patches import torch_export_rewrite
|
|
162
|
+
|
|
138
163
|
with torch_export_rewrite(rewrite=[Model.forward]) as f:
|
|
139
164
|
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
|
|
140
165
|
# ...
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
onnx_diagnostic/__init__.py,sha256=
|
|
1
|
+
onnx_diagnostic/__init__.py,sha256=N1lf8_afRytDUnulPCeDVDPA-M4k7y9x7LbWwX0USZs,173
|
|
2
2
|
onnx_diagnostic/__main__.py,sha256=YmyV_Aq_ianDlHyKLHMa6h8YK3ZmFPpLVHLKjM91aCk,79
|
|
3
3
|
onnx_diagnostic/_command_lines_parser.py,sha256=TliHXedXFerv-zO6cBigxKbuKHE0-6TUhhsk1pdkz9M,28072
|
|
4
4
|
onnx_diagnostic/api.py,sha256=BhCl_yCd78N7TlVtPOHjeYv1QBEy39TjZ647rcHqLh0,345
|
|
@@ -22,8 +22,8 @@ onnx_diagnostic/helpers/mini_onnx_builder.py,sha256=p0Xh2Br38xAqUjB2214GiNOIbCgi
|
|
|
22
22
|
onnx_diagnostic/helpers/model_builder_helper.py,sha256=RvDyPFqRboEU3HsQV_xi9oy-o3_4KuGFVzs5MhksduY,12552
|
|
23
23
|
onnx_diagnostic/helpers/onnx_helper.py,sha256=pXXQjfyNTSUF-Kt72U4fnBDkYAnWYMxdSw8m0qk3xmE,39670
|
|
24
24
|
onnx_diagnostic/helpers/ort_session.py,sha256=UgUUeUslDxEFBc6w6f3HMq_a7bn4TBlItmojqWquSj4,29281
|
|
25
|
-
onnx_diagnostic/helpers/rt_helper.py,sha256=
|
|
26
|
-
onnx_diagnostic/helpers/torch_helper.py,sha256=
|
|
25
|
+
onnx_diagnostic/helpers/rt_helper.py,sha256=qbV6zyMs-iH6H65WHC2tu4h0psnHg0TX5fwfO_k-glg,4623
|
|
26
|
+
onnx_diagnostic/helpers/torch_helper.py,sha256=QfUXUPx0lZEqJBgyA97daPRDlT9duTM5Jq5Yjq1jJd8,32358
|
|
27
27
|
onnx_diagnostic/reference/__init__.py,sha256=rLZsxOlnb7-81F2CzepGnZLejaROg4JvgFaGR9FwVQA,208
|
|
28
28
|
onnx_diagnostic/reference/evaluator.py,sha256=RzNzjFDeMe-4X51Tb22N6aagazY5ktNq-mRmPcfY5EU,8848
|
|
29
29
|
onnx_diagnostic/reference/ort_evaluator.py,sha256=1O7dHj8Aspolidg6rB2Nm7hT3HaGb4TxAgjCCD0XVcQ,26159
|
|
@@ -88,16 +88,16 @@ onnx_diagnostic/tasks/text_to_image.py,sha256=6z-rFG6MX9aBi8YoYtYI_8OV3M3Tfoi45V
|
|
|
88
88
|
onnx_diagnostic/tasks/zero_shot_image_classification.py,sha256=GKaXm8g7cK23h3wJEUc6Q-6mpmLAzQ4YkJbd-eGP7Y4,4496
|
|
89
89
|
onnx_diagnostic/torch_export_patches/__init__.py,sha256=0SaZedwznm1hQUCvXZsGZORV5vby954wEExr5faepGg,720
|
|
90
90
|
onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=h_txSp30QmF1R_Q2wL4qpPqY59Dund2P9nAAsvucS8A,21245
|
|
91
|
-
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=
|
|
91
|
+
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=wFE2fNxihAA3iua79AEB97_RBVv4wvGxwS9g4RJaSIc,10715
|
|
92
92
|
onnx_diagnostic/torch_export_patches/patch_expressions.py,sha256=vr4tt61cbDnaaaduzMj4UBZ8OUtr6GfDpIWwOYqjWzs,3213
|
|
93
93
|
onnx_diagnostic/torch_export_patches/patch_inputs.py,sha256=9b4pmyT00BwLqi7WG-gliep1RUy3gXEgW6BDnlSSA-M,7689
|
|
94
94
|
onnx_diagnostic/torch_export_patches/patch_module.py,sha256=R2d9IHM-RwsBKDsxuBIJnEqMoxbS9gd4YWFGG2wwV5A,39881
|
|
95
95
|
onnx_diagnostic/torch_export_patches/patch_module_helper.py,sha256=2U0AdyZuU0W54QTdE7tY7imVzMnpQ5091ADNtTCkT8Y,6967
|
|
96
|
-
onnx_diagnostic/torch_export_patches/eval/__init__.py,sha256=
|
|
96
|
+
onnx_diagnostic/torch_export_patches/eval/__init__.py,sha256=57x62uZNA80XiWgkG8Fe0_8YJcIVrvKLPqvwLDPJwgc,24008
|
|
97
97
|
onnx_diagnostic/torch_export_patches/eval/model_cases.py,sha256=DTvdHPtNQh25Akv5o3D4Jxf1L1-SJ7w14tgvj8AAns8,26577
|
|
98
98
|
onnx_diagnostic/torch_export_patches/patches/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
99
99
|
onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=KaZ8TjDa9ATgT4HllYzzoNf_51q_yOj_GuF5NYjPCrU,18913
|
|
100
|
-
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=
|
|
100
|
+
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=QGfXf1IlsOToQKF9NOHau_l62u1alQbJ6KeiSKFb980,44061
|
|
101
101
|
onnx_diagnostic/torch_export_patches/serialization/__init__.py,sha256=BHLdRPtNAtNPAS-bPKEj3-foGSPvwAbZXrHzGGPDLEw,1876
|
|
102
102
|
onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py,sha256=drq3EH_yjcSuIWYsVeUWm8Cx6YCZFU6bP_1PLtPfY5I,945
|
|
103
103
|
onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py,sha256=9u2jkqnuyBkIF3R2sDEO0Jlkedl-cQhBNXxXXDLSEwE,8885
|
|
@@ -115,8 +115,8 @@ onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py,sha256=QXw_Bs2SzfeiQMf-tm
|
|
|
115
115
|
onnx_diagnostic/torch_onnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
116
116
|
onnx_diagnostic/torch_onnx/runtime_info.py,sha256=1g9F_Jf9AAgYQU4stbsrFXwQl-30mWlQrFbQ7val8Ps,9268
|
|
117
117
|
onnx_diagnostic/torch_onnx/sbs.py,sha256=1EL25DeYFzlBSiFG_XjePBLvsiItRXbdDrr5-QZW2mA,16878
|
|
118
|
-
onnx_diagnostic-0.7.
|
|
119
|
-
onnx_diagnostic-0.7.
|
|
120
|
-
onnx_diagnostic-0.7.
|
|
121
|
-
onnx_diagnostic-0.7.
|
|
122
|
-
onnx_diagnostic-0.7.
|
|
118
|
+
onnx_diagnostic-0.7.3.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
|
|
119
|
+
onnx_diagnostic-0.7.3.dist-info/METADATA,sha256=uMYdGJbm6K04yoi5UidH3uA0nK_PXrnVXe83s9v6yPE,7431
|
|
120
|
+
onnx_diagnostic-0.7.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
121
|
+
onnx_diagnostic-0.7.3.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
|
|
122
|
+
onnx_diagnostic-0.7.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|