onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/helpers/memory_peak.py
@@ -0,0 +1,249 @@
+ import multiprocessing
+ import os
+ from typing import Dict, Optional
+
+
+ def get_memory_rss(pid: int) -> int:
+     """
+     Returns the physical memory used by a process.
+
+     :param pid: process id, current one is `os.getpid()`
+     :return: physical memory
+
+     It relies on the module :epkg:`psutil`.
+     """
+     import psutil
+
+     process = psutil.Process(pid)
+     mem = process.memory_info().rss
+     return mem
+
+
+ class Monitor:
+     def __init__(self):
+         self.max_peak = 0
+         self.average = 0
+         self.n_measures = 0
+         self.begin = 0
+         self.end = 0
+
+     def to_dict(self, unit: int = 1):
+         funit = float(unit)
+         return dict(
+             peak=self.max_peak / funit,
+             mean=self.average * 1.0 / self.n_measures / funit,
+             n=self.n_measures,
+             begin=self.begin / funit,
+             end=self.end / funit,
+         )
+
+     @property
+     def delta_peak(self):
+         return self.max_peak - self.begin
+
+     @property
+     def delta_end(self):
+         return self.end - self.begin
+
+     @property
+     def delta_avg(self):
+         return self.average / self.n_measures - self.begin
+
+     def __repr__(self):
+         return (
+             f"{self.__class__.__name__}(begin={self.begin}, end={self.end}, "
+             f"peak={self.max_peak}, average={self.average}, n={self.n_measures}, "
+             f"d_end={self.delta_end}, d_peak={self.delta_peak}, d_avg={self.delta_avg}"
+             f")"
+         )
+
+     def update(self, mem):
+         if self.n_measures == 0:
+             self.begin = mem
+         self.max_peak = max(mem, self.max_peak)
+         self.average += mem
+         self.end = mem
+         self.n_measures += 1
+
+     def send(self, conn):
+         conn.send(self.max_peak)
+         conn.send(self.average)
+         conn.send(self.n_measures)
+         conn.send(self.begin)
+         conn.send(self.end)
+
+     @classmethod
+     def recv(cls, conn):
+         m = cls()
+         m.max_peak = conn.recv()
+         m.average = conn.recv()
+         m.n_measures = conn.recv()
+         m.begin = conn.recv()
+         m.end = conn.recv()
+         return m
+
+
+ def _process_memory_spy(conn):
+     # Signals the parent that the spy process has started.
+     conn.send(-2)
+
+     # process id to spy on
+     pid = conn.recv()
+
+     # delay between two measures
+     timeout = conn.recv()
+
+     # do CUDA
+     cuda = conn.recv()
+
+     import psutil
+
+     process = psutil.Process(pid)
+
+     if cuda:
+         from pynvml import (
+             nvmlDeviceGetCount,
+             nvmlDeviceGetHandleByIndex,
+             nvmlDeviceGetMemoryInfo,
+             nvmlInit,
+             nvmlShutdown,
+         )
+
+         nvmlInit()
+         n_gpus = nvmlDeviceGetCount()
+         handles = [nvmlDeviceGetHandleByIndex(i) for i in range(n_gpus)]
+
+         def gpu_used():
+             return [nvmlDeviceGetMemoryInfo(h).used for h in handles]
+
+         gpus = [Monitor() for i in range(n_gpus)]
+     else:
+         gpus = []
+
+     cpu = Monitor()
+
+     conn.send(-2)
+
+     # loop
+     while True:
+         mem = process.memory_info().rss
+         cpu.update(mem)
+         if cuda:
+             for r, g in zip(gpu_used(), gpus):
+                 g.update(r)
+         if conn.poll(timeout=timeout):
+             code = conn.recv()
+             if code == -3:
+                 break
+
+     # final iteration
+     end = process.memory_info().rss
+     cpu.update(end)
+     if cuda:
+         for r, g in zip(gpu_used(), gpus):
+             g.update(r)
+
+     # send
+     cpu.send(conn)
+     conn.send(len(gpus))
+     for g in gpus:
+         g.send(conn)
+     if cuda:
+         nvmlShutdown()
+     conn.close()
+
+
+ class MemorySpy:
+     """
+     Information about the spy. The constructor calls method `start`.
+     Method `stop` can be called to end the measure.
+
+     :param pid: process id of the process to spy on
+     :param delay: delay in seconds between two measures
+     :param cuda: enable cuda monitoring
+     """
+
+     def __init__(self, pid: int, delay: float = 0.01, cuda: bool = False):
+         self.pid = pid
+         self.delay = delay
+         self.cuda = cuda
+         self.start()
+
+     def start(self) -> "MemorySpy":
+         """Starts another process and tells it to spy."""
+         self.parent_conn, self.child_conn = multiprocessing.Pipe()
+         self.child_process = multiprocessing.Process(
+             target=_process_memory_spy, args=(self.child_conn,)
+         )
+         self.child_process.start()
+         data = self.parent_conn.recv()
+         if data != -2:
+             raise RuntimeError(f"The child process is supposed to send -2, not {data}.")
+         self.parent_conn.send(self.pid)
+         self.parent_conn.send(self.delay)
+         self.parent_conn.send(1 if self.cuda else 0)
+         data = self.parent_conn.recv()
+         if data != -2:
+             raise RuntimeError(
+                 f"The child process is supposed to send -2 again, not {data}."
+             )
+         return self
+
+     def stop(self):
+         """Stops spying and returns the collected measures."""
+         self.parent_conn.send(-3)
+
+         cpu = Monitor.recv(self.parent_conn)
+
+         n_gpus = self.parent_conn.recv()
+         gpus = []
+         for _i in range(n_gpus):
+             gpus.append(Monitor.recv(self.parent_conn))
+
+         self.parent_conn.close()
+         self.child_process.join()
+         res = dict(cpu=cpu)
+         if self.cuda:
+             res["gpus"] = gpus
+         return res
+
+
+ def start_spying_on(
+     pid: Optional[int] = None, delay: float = 0.01, cuda: bool = False
+ ) -> MemorySpy:
+     """
+     Starts the memory spy. The function starts another
+     process spying on the one sent as an argument.
+
+     :param pid: process id to spy on, the current one if None.
+     :param delay: delay between two measures.
+     :param cuda: if True, also measures the memory of cuda devices
+
+     Example:
+
+     .. code-block:: python
+
+         from onnx_diagnostic.helpers.memory_peak import start_spying_on, flatten
+
+         p = start_spying_on()
+         # ...
+         # code to measure
+         # ...
+         stat = p.stop()
+         print(stat)
+         print(flatten(stat))
+     """
+     if pid is None:
+         pid = os.getpid()
+     return MemorySpy(pid, delay, cuda)
+
+
+ def flatten(ps, prefix: str = "") -> Dict[str, float]:
+     obs = ps["cpu"].to_dict(unit=2**20)
+     if "gpus" in ps:
+         for i, g in enumerate(ps["gpus"]):
+             for k, v in g.to_dict(unit=2**20).items():
+                 obs[f"gpu{i}_{k}"] = v
+     if prefix:
+         obs = {f"{prefix}{k}": v for k, v in obs.items()}
+     return obs
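
To give a sense of how the module above fits together, here is a minimal usage sketch, assuming the package and psutil are installed; the workload and the "mem_" prefix are illustrative choices, not part of the package:

    from onnx_diagnostic.helpers.memory_peak import start_spying_on, flatten

    # Start a child process that samples the RSS of the current process every 5 ms.
    spy = start_spying_on(delay=0.005)

    data = [0] * (10**7)  # any workload whose memory consumption should be observed

    # Stop sampling: returns {"cpu": Monitor} (plus "gpus" when cuda=True).
    stats = spy.stop()
    print(stats["cpu"].delta_peak)        # peak RSS increase in bytes over the first sample
    print(flatten(stats, prefix="mem_"))  # peak/mean/begin/end in MiB (unit=2**20), plus the sample count n

The measuring happens in a separate process connected through a multiprocessing.Pipe, so the workload itself is not instrumented; accuracy depends on the sampling delay.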